PCB defect detection application
[ealt-edge.git] / example-apps / PDD / pcb-defect-detection / libs / networks / mobilenet_v2.py
1 # -*- coding: utf-8 -*-
2
3 from __future__ import absolute_import, print_function, division
4 import tensorflow.contrib.slim as slim
5 import tensorflow as tf
6
7 from libs.networks.mobilenet import mobilenet_v2
8 from libs.networks.mobilenet.mobilenet import training_scope
9 from libs.networks.mobilenet.mobilenet_v2 import op
10 from libs.networks.mobilenet.mobilenet_v2 import ops
11 expand_input = ops.expand_input_by_factor
12
# Layer specification for the MobileNetV2 trunk built by mobilenetv2_base:
# a stride-2 stem conv followed by expanded_conv .. expanded_conv_12.
# The stride-2 layers (stem, _1, _3, _6) give a cumulative output stride of 16.
V2_BASE_DEF = {
    'defaults': {
        # Batch-norm centering/scaling determines which variables exist, so it
        # belongs to the architecture definition rather than to training_scope.
        (slim.batch_norm,): {'center': True, 'scale': True},
        (slim.conv2d, slim.fully_connected, slim.separable_conv2d): {
            'normalizer_fn': slim.batch_norm,
            'activation_fn': tf.nn.relu6,
        },
        (ops.expanded_conv,): {
            'expansion_size': expand_input(6),
            'split_expansion': 1,
            'normalizer_fn': slim.batch_norm,
            'residual': True,
        },
        (slim.conv2d, slim.separable_conv2d): {'padding': 'SAME'},
    },
    'spec': [
        op(slim.conv2d, stride=2, num_outputs=32, kernel_size=[3, 3]),
        # The first bottleneck does not expand its input (factor 1).
        op(ops.expanded_conv,
           expansion_size=expand_input(1, divisible_by=1),
           num_outputs=16, scope='expanded_conv'),
    ] + list(
        # (stride, num_outputs) for expanded_conv_1 .. expanded_conv_12.
        op(ops.expanded_conv, stride=stride, num_outputs=channels,
           scope='expanded_conv_{}'.format(idx))
        for idx, (stride, channels) in enumerate(
            [(2, 24), (1, 24),
             (2, 32), (1, 32), (1, 32),
             (2, 64), (1, 64), (1, 64), (1, 64),
             (1, 96), (1, 96), (1, 96)],
            start=1)
    ),
}
48
49
# Layer specification for the MobileNetV2 head built by mobilenetv2_head:
# expanded_conv_13 .. expanded_conv_16 (the first with stride 2) followed by
# the 1x1, 1280-channel Conv_1 projection.
V2_HEAD_DEF = {
    'defaults': {
        # Batch-norm centering/scaling determines which variables exist, so it
        # belongs to the architecture definition rather than to training_scope.
        (slim.batch_norm,): {'center': True, 'scale': True},
        (slim.conv2d, slim.fully_connected, slim.separable_conv2d): {
            'normalizer_fn': slim.batch_norm,
            'activation_fn': tf.nn.relu6,
        },
        (ops.expanded_conv,): {
            'expansion_size': expand_input(6),
            'split_expansion': 1,
            'normalizer_fn': slim.batch_norm,
            'residual': True,
        },
        (slim.conv2d, slim.separable_conv2d): {'padding': 'SAME'},
    },
    'spec': list(
        # (stride, num_outputs) for expanded_conv_13 .. expanded_conv_16.
        op(ops.expanded_conv, stride=stride, num_outputs=channels,
           scope='expanded_conv_{}'.format(idx))
        for idx, (stride, channels) in enumerate(
            [(2, 160), (1, 160), (1, 160), (1, 320)],
            start=13)
    ) + [
        op(slim.conv2d, stride=1, kernel_size=[1, 1], num_outputs=1280,
           scope='Conv_1'),
    ],
}
def mobilenetv2_scope(is_training=True,
                      trainable=True,
                      weight_decay=0.00004,
                      stddev=0.09,
                      dropout_keep_prob=0.8,
                      bn_decay=0.997):
    """Build the slim arg_scope used by the MobileNetV2 base and head.

    The scope is ``training_scope(...)`` with two overrides layered on top:

    * conv2d / fully_connected / separable_conv2d weights get the given
      ``trainable`` flag;
    * batch normalization is *frozen*: ``slim.batch_norm`` is configured with
      ``is_training=False`` (the stored moving statistics are used) and
      ``trainable=False`` (its parameters are not trained), regardless of
      the ``is_training`` argument. BN layers are still applied.

    Args:
      is_training: forwarded to ``training_scope``. It does not affect batch
        normalization, which is always frozen here.
      trainable: whether conv / fully-connected / separable-conv variables
        are trainable.
      weight_decay: forwarded to ``training_scope``.
      stddev: forwarded to ``training_scope`` (weight-initializer stddev).
      dropout_keep_prob: forwarded to ``training_scope``.
      bn_decay: forwarded to ``training_scope`` and also used as the decay of
        the frozen batch-norm layers.

    Returns:
      The arg_scope dict, suitable for ``with slim.arg_scope(...)``.
    """
    # Previously only is_training/weight_decay were forwarded, so overrides of
    # stddev / dropout_keep_prob / bn_decay were silently ignored. The defaults
    # match training_scope's, so default behavior is unchanged.
    base_scope = training_scope(is_training=is_training,
                                weight_decay=weight_decay,
                                stddev=stddev,
                                dropout_keep_prob=dropout_keep_prob,
                                bn_decay=bn_decay)
    batch_norm_params = {
        # Freeze BN: always use moving statistics, never update gamma/beta.
        'is_training': False,
        'trainable': False,
        'decay': bn_decay,
    }
    with slim.arg_scope(base_scope):
        with slim.arg_scope([slim.conv2d, slim.fully_connected, slim.separable_conv2d],
                            trainable=trainable):
            with slim.arg_scope([slim.batch_norm], **batch_norm_params) as sc:
                return sc
95
96
97
def mobilenetv2_base(img_batch, is_training=True):
    """Run the MobileNetV2 trunk described by ``V2_BASE_DEF`` on ``img_batch``.

    Args:
      img_batch: input image tensor (presumably NHWC float images; TODO
        confirm against the caller).
      is_training: forwarded to ``mobilenetv2_scope`` only. The network call
        below always passes ``is_training=False``.

    Returns:
      The feature map produced by the last layer of ``V2_BASE_DEF``. The
      endpoint dictionary returned by ``mobilenet_base`` is discarded.
    """
    arg_sc = mobilenetv2_scope(is_training=is_training, trainable=True)
    with slim.arg_scope(arg_sc):
        features, _ = mobilenet_v2.mobilenet_base(
            input_tensor=img_batch,
            num_classes=None,
            is_training=False,
            depth_multiplier=1.0,
            scope='MobilenetV2',
            conv_defs=V2_BASE_DEF,
            finegrain_classification_mode=False,
        )
    return features
112
113
def mobilenetv2_head(inputs, is_training=True):
    """Run the MobileNetV2 head described by ``V2_HEAD_DEF`` on ``inputs``.

    The same ``'MobilenetV2'`` scope as ``mobilenetv2_base`` is used,
    presumably so variable names line up with a full MobileNetV2 checkpoint
    (confirm against the checkpoint-restore code).

    Args:
      inputs: feature tensor to feed into the head layers (presumably
        per-ROI crops of the ``mobilenetv2_base`` output; TODO confirm).
      is_training: forwarded to ``mobilenetv2_scope`` only. The network call
        below always passes ``is_training=False``.

    Returns:
      The head output with axes 1 and 2 squeezed out.
    """
    arg_sc = mobilenetv2_scope(is_training=is_training, trainable=True)
    with slim.arg_scope(arg_sc):
        pooled, _ = mobilenet_v2.mobilenet(
            input_tensor=inputs,
            num_classes=None,
            is_training=False,
            depth_multiplier=1.0,
            scope='MobilenetV2',
            conv_defs=V2_HEAD_DEF,
            finegrain_classification_mode=False,
        )
    # Assumes axes 1 and 2 have size 1 (e.g. a globally pooled
    # (batch, 1, 1, channels) tensor). tf.squeeze fails otherwise; TODO confirm.
    return tf.squeeze(pooled, [1, 2])