pcb defect detection application
[ealt-edge.git] / example-apps / PDD / pcb-defect-detection / libs / networks / mobilenet_v2.py
diff --git a/example-apps/PDD/pcb-defect-detection/libs/networks/mobilenet_v2.py b/example-apps/PDD/pcb-defect-detection/libs/networks/mobilenet_v2.py
new file mode 100755 (executable)
index 0000000..eced27c
--- /dev/null
@@ -0,0 +1,126 @@
+# -*- coding: utf-8 -*-
+
+from __future__ import absolute_import, print_function, division
+import tensorflow.contrib.slim as slim
+import tensorflow as tf
+
+from libs.networks.mobilenet import mobilenet_v2
+from libs.networks.mobilenet.mobilenet import training_scope
+from libs.networks.mobilenet.mobilenet_v2 import op
+from libs.networks.mobilenet.mobilenet_v2 import ops
+expand_input = ops.expand_input_by_factor
+
# Conv-defs for the MobilenetV2 *base* (backbone / feature extractor):
# the 3x3 stem conv plus inverted-residual blocks 'expanded_conv' through
# 'expanded_conv_12'. The head (blocks 13-16 and Conv_1) lives in
# V2_HEAD_DEF below. Scope names must stay exactly as written so variables
# line up with the pretrained MobilenetV2 checkpoint — do not rename.
V2_BASE_DEF = dict(
    defaults={
        # Note: these parameters of batch norm affect the architecture
        # that's why they are here and not in training_scope.
        (slim.batch_norm,): {'center': True, 'scale': True},
        (slim.conv2d, slim.fully_connected, slim.separable_conv2d): {
            'normalizer_fn': slim.batch_norm, 'activation_fn': tf.nn.relu6
        },
        (ops.expanded_conv,): {
            # Default expansion factor 6 (the standard MobilenetV2 ratio);
            # individual ops below may override it.
            'expansion_size': expand_input(6),
            'split_expansion': 1,
            'normalizer_fn': slim.batch_norm,
            'residual': True
        },
        (slim.conv2d, slim.separable_conv2d): {'padding': 'SAME'}
    },
    spec=[
        # Stem: stride-2 3x3 conv, 32 channels.
        op(slim.conv2d, stride=2, num_outputs=32, kernel_size=[3, 3]),
        # First bottleneck uses expansion factor 1 (no expansion), as in
        # the original MobilenetV2 architecture.
        op(ops.expanded_conv,
           expansion_size=expand_input(1, divisible_by=1),
           num_outputs=16, scope='expanded_conv'),
        op(ops.expanded_conv, stride=2, num_outputs=24, scope='expanded_conv_1'),
        op(ops.expanded_conv, stride=1, num_outputs=24, scope='expanded_conv_2'),
        op(ops.expanded_conv, stride=2, num_outputs=32, scope='expanded_conv_3'),
        op(ops.expanded_conv, stride=1, num_outputs=32, scope='expanded_conv_4'),
        op(ops.expanded_conv, stride=1, num_outputs=32, scope='expanded_conv_5'),
        op(ops.expanded_conv, stride=2, num_outputs=64, scope='expanded_conv_6'),
        op(ops.expanded_conv, stride=1, num_outputs=64, scope='expanded_conv_7'),
        op(ops.expanded_conv, stride=1, num_outputs=64, scope='expanded_conv_8'),
        op(ops.expanded_conv, stride=1, num_outputs=64, scope='expanded_conv_9'),
        # NOTE(review): the stock MobilenetV2 keeps stride=1 here and widens
        # to 96 channels; blocks 10-12 end the base at the stride-16 stage.
        op(ops.expanded_conv, stride=1, num_outputs=96, scope='expanded_conv_10'),
        op(ops.expanded_conv, stride=1, num_outputs=96, scope='expanded_conv_11'),
        op(ops.expanded_conv, stride=1, num_outputs=96, scope='expanded_conv_12')
    ],
)
+
+
# Conv-defs for the MobilenetV2 *head*: blocks 'expanded_conv_13' through
# 'expanded_conv_16' plus the final 1x1 'Conv_1' projection to 1280
# channels. This is run on cropped ROI features (see mobilenetv2_head).
# The `defaults` dict intentionally mirrors V2_BASE_DEF so both halves
# are built under identical batch-norm / activation settings. Scope names
# must match the pretrained checkpoint — do not rename.
V2_HEAD_DEF = dict(
    defaults={
        # Note: these parameters of batch norm affect the architecture
        # that's why they are here and not in training_scope.
        (slim.batch_norm,): {'center': True, 'scale': True},
        (slim.conv2d, slim.fully_connected, slim.separable_conv2d): {
            'normalizer_fn': slim.batch_norm, 'activation_fn': tf.nn.relu6
        },
        (ops.expanded_conv,): {
            # Default expansion factor 6, the standard MobilenetV2 ratio.
            'expansion_size': expand_input(6),
            'split_expansion': 1,
            'normalizer_fn': slim.batch_norm,
            'residual': True
        },
        (slim.conv2d, slim.separable_conv2d): {'padding': 'SAME'}
    },
    spec=[
        op(ops.expanded_conv, stride=2, num_outputs=160, scope='expanded_conv_13'),
        op(ops.expanded_conv, stride=1, num_outputs=160, scope='expanded_conv_14'),
        op(ops.expanded_conv, stride=1, num_outputs=160, scope='expanded_conv_15'),
        op(ops.expanded_conv, stride=1, num_outputs=320, scope='expanded_conv_16'),
        # Final 1x1 projection to 1280 features, as in stock MobilenetV2.
        op(slim.conv2d, stride=1, kernel_size=[1, 1], num_outputs=1280, scope='Conv_1')
    ],
)
def mobilenetv2_scope(is_training=True,
                      trainable=True,
                      weight_decay=0.00004,
                      stddev=0.09,
                      dropout_keep_prob=0.8,
                      bn_decay=0.997):
  """Build the slim arg_scope used for all MobilenetV2 construction here.

  Wraps the upstream mobilenet `training_scope` and then overrides it so
  that batch-norm layers are frozen: `is_training=False` and
  `trainable=False` are forced on slim.batch_norm regardless of the
  `is_training` argument (i.e. BN statistics and gamma/beta stay fixed,
  which is the usual setup when fine-tuning a detector from an ImageNet
  checkpoint).

  Args:
    is_training: forwarded to `training_scope` (affects e.g. dropout in
      the upstream scope); BN remains frozen either way.
    trainable: whether conv / fully-connected / separable-conv variables
      are trainable.
    weight_decay: L2 regularization strength forwarded to `training_scope`.
    stddev: unused here — kept for signature compatibility with the
      upstream mobilenet scope helpers.
    dropout_keep_prob: unused here — kept for signature compatibility.
    bn_decay: moving-average decay applied to batch norm.

  Returns:
    A slim arg_scope (dict-like) to be re-entered via
    `slim.arg_scope(mobilenetv2_scope(...))` by the builders below.
  """
  # Force-freeze batch norm: neither updates its moving statistics nor
  # trains its parameters.
  batch_norm_params = {
      'is_training': False,
      'trainable': False,
      'decay': bn_decay,
  }
  # Returning the innermost scope captures the combined settings of all
  # three nested arg_scopes; callers re-enter it later.
  with slim.arg_scope(training_scope(is_training=is_training, weight_decay=weight_decay)):
      with slim.arg_scope([slim.conv2d, slim.fully_connected, slim.separable_conv2d],
                          trainable=trainable):
          with slim.arg_scope([slim.batch_norm], **batch_norm_params) as sc:
              return sc
+
+
+
def mobilenetv2_base(img_batch, is_training=True):
    """Run the MobilenetV2 backbone (through expanded_conv_12) on images.

    Args:
        img_batch: 4-D batch of input images fed to the network stem.
        is_training: forwarded to mobilenetv2_scope; note that batch norm
            is frozen inside that scope, and `is_training=False` is passed
            to mobilenet_base itself regardless.

    Returns:
        The final base feature map (used downstream for ROI cropping).
    """
    scope = mobilenetv2_scope(is_training=is_training, trainable=True)
    with slim.arg_scope(scope):
        base_kwargs = dict(
            input_tensor=img_batch,
            num_classes=None,
            is_training=False,
            depth_multiplier=1.0,
            scope='MobilenetV2',
            conv_defs=V2_BASE_DEF,
            finegrain_classification_mode=False,
        )
        net, _ = mobilenet_v2.mobilenet_base(**base_kwargs)
        return net
+
+
def mobilenetv2_head(inputs, is_training=True):
    """Run the MobilenetV2 head (expanded_conv_13..16 + Conv_1) on ROI features.

    Args:
        inputs: 4-D feature crops produced from the base feature map.
        is_training: forwarded to mobilenetv2_scope; batch norm stays
            frozen there, and `is_training=False` is passed to
            mobilenet_v2.mobilenet itself.

    Returns:
        A rank-2 tensor: the head output with the two spatial axes
        (height, width) squeezed out.
    """
    scope = mobilenetv2_scope(is_training=is_training, trainable=True)
    with slim.arg_scope(scope):
        head_kwargs = dict(
            input_tensor=inputs,
            num_classes=None,
            is_training=False,
            depth_multiplier=1.0,
            scope='MobilenetV2',
            conv_defs=V2_HEAD_DEF,
            finegrain_classification_mode=False,
        )
        features, _ = mobilenet_v2.mobilenet(**head_kwargs)
        # Collapse the 1x1 spatial dimensions left by global pooling.
        return tf.squeeze(features, [1, 2])
\ No newline at end of file