40e48faf3ff11cf051d4cebb039fcce984a59b37
[ealt-edge.git] / example-apps / PDD / pcb-defect-detection / libs / networks / mobilenet / mobilenet_v2_test.py
1 # Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 #
3 # Licensed under the Apache License, Version 2.0 (the "License");
4 # you may not use this file except in compliance with the License.
5 # You may obtain a copy of the License at
6 #
7 #     http://www.apache.org/licenses/LICENSE-2.0
8 #
9 # Unless required by applicable law or agreed to in writing, software
10 # distributed under the License is distributed on an "AS IS" BASIS,
11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 # See the License for the specific language governing permissions and
13 # limitations under the License.
14 # ==============================================================================
15 """Tests for mobilenet_v2."""
16
17 from __future__ import absolute_import
18 from __future__ import division
19 from __future__ import print_function
20 import copy
21 import tensorflow as tf
22 from nets.mobilenet import conv_blocks as ops
23 from nets.mobilenet import mobilenet
24 from nets.mobilenet import mobilenet_v2
25
26
27 slim = tf.contrib.slim
28
29
30 def find_ops(optype):
31   """Find ops of a given type in graphdef or a graph.
32
33   Args:
34     optype: operation type (e.g. Conv2D)
35   Returns:
36      List of operations.
37   """
38   gd = tf.get_default_graph()
39   return [var for var in gd.get_operations() if var.type == optype]
40
41
42 class MobilenetV2Test(tf.test.TestCase):
43
44   def setUp(self):
45     tf.reset_default_graph()
46
47   def testCreation(self):
48     spec = dict(mobilenet_v2.V2_DEF)
49     _, ep = mobilenet.mobilenet(
50         tf.placeholder(tf.float32, (10, 224, 224, 16)), conv_defs=spec)
51     num_convs = len(find_ops('Conv2D'))
52
53     # This is mostly a sanity test. No deep reason for these particular
54     # constants.
55     #
56     # All but first 2 and last one have  two convolutions, and there is one
57     # extra conv that is not in the spec. (logits)
58     self.assertEqual(num_convs, len(spec['spec']) * 2 - 2)
59     # Check that depthwise are exposed.
60     for i in range(2, 17):
61       self.assertIn('layer_%d/depthwise_output' % i, ep)
62
63   def testCreationNoClasses(self):
64     spec = copy.deepcopy(mobilenet_v2.V2_DEF)
65     net, ep = mobilenet.mobilenet(
66         tf.placeholder(tf.float32, (10, 224, 224, 16)), conv_defs=spec,
67         num_classes=None)
68     self.assertIs(net, ep['global_pool'])
69
70   def testImageSizes(self):
71     for input_size, output_size in [(224, 7), (192, 6), (160, 5),
72                                     (128, 4), (96, 3)]:
73       tf.reset_default_graph()
74       _, ep = mobilenet_v2.mobilenet(
75           tf.placeholder(tf.float32, (10, input_size, input_size, 3)))
76
77       self.assertEqual(ep['layer_18/output'].get_shape().as_list()[1:3],
78                        [output_size] * 2)
79
80   def testWithSplits(self):
81     spec = copy.deepcopy(mobilenet_v2.V2_DEF)
82     spec['overrides'] = {
83         (ops.expanded_conv,): dict(split_expansion=2),
84     }
85     _, _ = mobilenet.mobilenet(
86         tf.placeholder(tf.float32, (10, 224, 224, 16)), conv_defs=spec)
87     num_convs = len(find_ops('Conv2D'))
88     # All but 3 op has 3 conv operatore, the remainign 3 have one
89     # and there is one unaccounted.
90     self.assertEqual(num_convs, len(spec['spec']) * 3 - 5)
91
92   def testWithOutputStride8(self):
93     out, _ = mobilenet.mobilenet_base(
94         tf.placeholder(tf.float32, (10, 224, 224, 16)),
95         conv_defs=mobilenet_v2.V2_DEF,
96         output_stride=8,
97         scope='MobilenetV2')
98     self.assertEqual(out.get_shape().as_list()[1:3], [28, 28])
99
100   def testDivisibleBy(self):
101     tf.reset_default_graph()
102     mobilenet_v2.mobilenet(
103         tf.placeholder(tf.float32, (10, 224, 224, 16)),
104         conv_defs=mobilenet_v2.V2_DEF,
105         divisible_by=16,
106         min_depth=32)
107     s = [op.outputs[0].get_shape().as_list()[-1] for op in find_ops('Conv2D')]
108     s = set(s)
109     self.assertSameElements([32, 64, 96, 160, 192, 320, 384, 576, 960, 1280,
110                              1001], s)
111
112   def testDivisibleByWithArgScope(self):
113     tf.reset_default_graph()
114     # Verifies that depth_multiplier arg scope actually works
115     # if no default min_depth is provided.
116     with slim.arg_scope((mobilenet.depth_multiplier,), min_depth=32):
117       mobilenet_v2.mobilenet(
118           tf.placeholder(tf.float32, (10, 224, 224, 2)),
119           conv_defs=mobilenet_v2.V2_DEF, depth_multiplier=0.1)
120       s = [op.outputs[0].get_shape().as_list()[-1] for op in find_ops('Conv2D')]
121       s = set(s)
122       self.assertSameElements(s, [32, 192, 128, 1001])
123
124   def testFineGrained(self):
125     tf.reset_default_graph()
126     # Verifies that depth_multiplier arg scope actually works
127     # if no default min_depth is provided.
128
129     mobilenet_v2.mobilenet(
130         tf.placeholder(tf.float32, (10, 224, 224, 2)),
131         conv_defs=mobilenet_v2.V2_DEF, depth_multiplier=0.01,
132         finegrain_classification_mode=True)
133     s = [op.outputs[0].get_shape().as_list()[-1] for op in find_ops('Conv2D')]
134     s = set(s)
135     # All convolutions will be 8->48, except for the last one.
136     self.assertSameElements(s, [8, 48, 1001, 1280])
137
138   def testMobilenetBase(self):
139     tf.reset_default_graph()
140     # Verifies that mobilenet_base returns pre-pooling layer.
141     with slim.arg_scope((mobilenet.depth_multiplier,), min_depth=32):
142       net, _ = mobilenet_v2.mobilenet_base(
143           tf.placeholder(tf.float32, (10, 224, 224, 16)),
144           conv_defs=mobilenet_v2.V2_DEF, depth_multiplier=0.1)
145       self.assertEqual(net.get_shape().as_list(), [10, 7, 7, 128])
146
147   def testWithOutputStride16(self):
148     tf.reset_default_graph()
149     out, _ = mobilenet.mobilenet_base(
150         tf.placeholder(tf.float32, (10, 224, 224, 16)),
151         conv_defs=mobilenet_v2.V2_DEF,
152         output_stride=16)
153     self.assertEqual(out.get_shape().as_list()[1:3], [14, 14])
154
155   def testWithOutputStride8AndExplicitPadding(self):
156     tf.reset_default_graph()
157     out, _ = mobilenet.mobilenet_base(
158         tf.placeholder(tf.float32, (10, 224, 224, 16)),
159         conv_defs=mobilenet_v2.V2_DEF,
160         output_stride=8,
161         use_explicit_padding=True,
162         scope='MobilenetV2')
163     self.assertEqual(out.get_shape().as_list()[1:3], [28, 28])
164
165   def testWithOutputStride16AndExplicitPadding(self):
166     tf.reset_default_graph()
167     out, _ = mobilenet.mobilenet_base(
168         tf.placeholder(tf.float32, (10, 224, 224, 16)),
169         conv_defs=mobilenet_v2.V2_DEF,
170         output_stride=16,
171         use_explicit_padding=True)
172     self.assertEqual(out.get_shape().as_list()[1:3], [14, 14])
173
174
175 if __name__ == '__main__':
176   tf.test.main()