44e66446baa42f49e164131eb4c1a97b46a9918d
[ealt-edge.git] / example-apps / PDD / pcb-defect-detection / libs / networks / slim_nets / mobilenet_v1_test.py
1 # Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 #
3 # Licensed under the Apache License, Version 2.0 (the "License");
4 # you may not use this file except in compliance with the License.
5 # You may obtain a copy of the License at
6 #
7 # http://www.apache.org/licenses/LICENSE-2.0
8 #
9 # Unless required by applicable law or agreed to in writing, software
10 # distributed under the License is distributed on an "AS IS" BASIS,
11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 # See the License for the specific language governing permissions and
13 # limitations under the License.
14 # =============================================================================
15 """Tests for MobileNet v1."""
16
17 from __future__ import absolute_import
18 from __future__ import division
19 from __future__ import print_function
20
21 import numpy as np
22 import tensorflow as tf
23
24 from nets import mobilenet_v1
25
# Short alias for the TF-Slim API used throughout the tests below
# (slim.arg_scope, slim.conv2d, slim.separable_conv2d, slim.batch_norm,
# slim.model_analyzer, slim.get_model_variables).
# NOTE(review): tf.contrib is not available in TensorFlow 2.x — confirm the
# TensorFlow version this example app is pinned to.
26 slim = tf.contrib.slim
27
28
class MobilenetV1Test(tf.test.TestCase):
  """Graph-construction and execution tests for MobileNet v1.

  Most tests only build the graph and check op names and static shapes; the
  remaining ones also run the model in a session on random inputs.
  """

  # Every endpoint of mobilenet_v1_base with the default conv_defs, in the
  # order the layers are built. testBuildOnlyUptoFinalEndpoint depends on this
  # order when it truncates the network at each endpoint in turn. Stored as a
  # tuple so no test can mutate the shared class attribute.
  _ALL_BASE_ENDPOINTS = ('Conv2d_0',
                         'Conv2d_1_depthwise', 'Conv2d_1_pointwise',
                         'Conv2d_2_depthwise', 'Conv2d_2_pointwise',
                         'Conv2d_3_depthwise', 'Conv2d_3_pointwise',
                         'Conv2d_4_depthwise', 'Conv2d_4_pointwise',
                         'Conv2d_5_depthwise', 'Conv2d_5_pointwise',
                         'Conv2d_6_depthwise', 'Conv2d_6_pointwise',
                         'Conv2d_7_depthwise', 'Conv2d_7_pointwise',
                         'Conv2d_8_depthwise', 'Conv2d_8_pointwise',
                         'Conv2d_9_depthwise', 'Conv2d_9_pointwise',
                         'Conv2d_10_depthwise', 'Conv2d_10_pointwise',
                         'Conv2d_11_depthwise', 'Conv2d_11_pointwise',
                         'Conv2d_12_depthwise', 'Conv2d_12_pointwise',
                         'Conv2d_13_depthwise', 'Conv2d_13_pointwise')

  def _assertEndpointShapes(self, end_points, expected_shapes):
    """Asserts `end_points` has exactly the expected names and static shapes.

    Args:
      end_points: dict of endpoint name -> tensor, as returned by
        mobilenet_v1.mobilenet_v1_base.
      expected_shapes: dict of endpoint name -> expected static shape given
        as a list of ints.
    """
    self.assertItemsEqual(expected_shapes.keys(), end_points.keys())
    # .items() rather than the Python 2-only .iteritems() so this runs on
    # Python 3 as well.
    for endpoint_name, expected_shape in expected_shapes.items():
      self.assertIn(endpoint_name, end_points)
      self.assertListEqual(end_points[endpoint_name].get_shape().as_list(),
                           expected_shape)

  def testBuildClassificationNetwork(self):
    """Builds the full classifier and checks logits/prediction shapes."""
    batch_size = 5
    height, width = 224, 224
    num_classes = 1000

    inputs = tf.random_uniform((batch_size, height, width, 3))
    logits, end_points = mobilenet_v1.mobilenet_v1(inputs, num_classes)
    self.assertTrue(logits.op.name.startswith('MobilenetV1/Logits'))
    self.assertListEqual(logits.get_shape().as_list(),
                         [batch_size, num_classes])
    self.assertIn('Predictions', end_points)
    self.assertListEqual(end_points['Predictions'].get_shape().as_list(),
                         [batch_size, num_classes])

  def testBuildBaseNetwork(self):
    """Builds the feature extractor and checks its output and endpoints."""
    batch_size = 5
    height, width = 224, 224

    inputs = tf.random_uniform((batch_size, height, width, 3))
    net, end_points = mobilenet_v1.mobilenet_v1_base(inputs)
    self.assertTrue(net.op.name.startswith('MobilenetV1/Conv2d_13'))
    self.assertListEqual(net.get_shape().as_list(),
                         [batch_size, 7, 7, 1024])
    self.assertItemsEqual(end_points.keys(), self._ALL_BASE_ENDPOINTS)

  def testBuildOnlyUptoFinalEndpoint(self):
    """Truncating at each endpoint yields exactly the endpoints up to it."""
    batch_size = 5
    height, width = 224, 224
    endpoints = self._ALL_BASE_ENDPOINTS
    for index, endpoint in enumerate(endpoints):
      # Fresh graph per truncation point so the networks share no ops or
      # variables.
      with tf.Graph().as_default():
        inputs = tf.random_uniform((batch_size, height, width, 3))
        out_tensor, end_points = mobilenet_v1.mobilenet_v1_base(
            inputs, final_endpoint=endpoint)
        self.assertTrue(out_tensor.op.name.startswith(
            'MobilenetV1/' + endpoint))
        self.assertItemsEqual(endpoints[:index + 1], end_points)

  def testBuildCustomNetworkUsingConvDefs(self):
    """A custom conv_defs list replaces the default layer stack."""
    batch_size = 5
    height, width = 224, 224
    conv_defs = [
        mobilenet_v1.Conv(kernel=[3, 3], stride=2, depth=32),
        mobilenet_v1.DepthSepConv(kernel=[3, 3], stride=1, depth=64),
        mobilenet_v1.DepthSepConv(kernel=[3, 3], stride=2, depth=128),
        mobilenet_v1.DepthSepConv(kernel=[3, 3], stride=1, depth=512)
    ]

    inputs = tf.random_uniform((batch_size, height, width, 3))
    net, end_points = mobilenet_v1.mobilenet_v1_base(
        inputs, final_endpoint='Conv2d_3_pointwise', conv_defs=conv_defs)
    self.assertTrue(net.op.name.startswith('MobilenetV1/Conv2d_3'))
    self.assertListEqual(net.get_shape().as_list(),
                         [batch_size, 56, 56, 512])
    expected_endpoints = ['Conv2d_0',
                          'Conv2d_1_depthwise', 'Conv2d_1_pointwise',
                          'Conv2d_2_depthwise', 'Conv2d_2_pointwise',
                          'Conv2d_3_depthwise', 'Conv2d_3_pointwise']
    self.assertItemsEqual(end_points.keys(), expected_endpoints)

  def testBuildAndCheckAllEndPointsUptoConv2d_13(self):
    """Checks every endpoint shape with the default output stride."""
    batch_size = 5
    height, width = 224, 224

    inputs = tf.random_uniform((batch_size, height, width, 3))
    with slim.arg_scope([slim.conv2d, slim.separable_conv2d],
                        normalizer_fn=slim.batch_norm):
      _, end_points = mobilenet_v1.mobilenet_v1_base(
          inputs, final_endpoint='Conv2d_13_pointwise')
    endpoints_shapes = {'Conv2d_0': [batch_size, 112, 112, 32],
                        'Conv2d_1_depthwise': [batch_size, 112, 112, 32],
                        'Conv2d_1_pointwise': [batch_size, 112, 112, 64],
                        'Conv2d_2_depthwise': [batch_size, 56, 56, 64],
                        'Conv2d_2_pointwise': [batch_size, 56, 56, 128],
                        'Conv2d_3_depthwise': [batch_size, 56, 56, 128],
                        'Conv2d_3_pointwise': [batch_size, 56, 56, 128],
                        'Conv2d_4_depthwise': [batch_size, 28, 28, 128],
                        'Conv2d_4_pointwise': [batch_size, 28, 28, 256],
                        'Conv2d_5_depthwise': [batch_size, 28, 28, 256],
                        'Conv2d_5_pointwise': [batch_size, 28, 28, 256],
                        'Conv2d_6_depthwise': [batch_size, 14, 14, 256],
                        'Conv2d_6_pointwise': [batch_size, 14, 14, 512],
                        'Conv2d_7_depthwise': [batch_size, 14, 14, 512],
                        'Conv2d_7_pointwise': [batch_size, 14, 14, 512],
                        'Conv2d_8_depthwise': [batch_size, 14, 14, 512],
                        'Conv2d_8_pointwise': [batch_size, 14, 14, 512],
                        'Conv2d_9_depthwise': [batch_size, 14, 14, 512],
                        'Conv2d_9_pointwise': [batch_size, 14, 14, 512],
                        'Conv2d_10_depthwise': [batch_size, 14, 14, 512],
                        'Conv2d_10_pointwise': [batch_size, 14, 14, 512],
                        'Conv2d_11_depthwise': [batch_size, 14, 14, 512],
                        'Conv2d_11_pointwise': [batch_size, 14, 14, 512],
                        'Conv2d_12_depthwise': [batch_size, 7, 7, 512],
                        'Conv2d_12_pointwise': [batch_size, 7, 7, 1024],
                        'Conv2d_13_depthwise': [batch_size, 7, 7, 1024],
                        'Conv2d_13_pointwise': [batch_size, 7, 7, 1024]}
    self._assertEndpointShapes(end_points, endpoints_shapes)

  def testOutputStride16BuildAndCheckAllEndPointsUptoConv2d_13(self):
    """output_stride=16 keeps Conv2d_12 onward at 14x14 instead of 7x7."""
    batch_size = 5
    height, width = 224, 224
    output_stride = 16

    inputs = tf.random_uniform((batch_size, height, width, 3))
    with slim.arg_scope([slim.conv2d, slim.separable_conv2d],
                        normalizer_fn=slim.batch_norm):
      _, end_points = mobilenet_v1.mobilenet_v1_base(
          inputs, output_stride=output_stride,
          final_endpoint='Conv2d_13_pointwise')
    endpoints_shapes = {'Conv2d_0': [batch_size, 112, 112, 32],
                        'Conv2d_1_depthwise': [batch_size, 112, 112, 32],
                        'Conv2d_1_pointwise': [batch_size, 112, 112, 64],
                        'Conv2d_2_depthwise': [batch_size, 56, 56, 64],
                        'Conv2d_2_pointwise': [batch_size, 56, 56, 128],
                        'Conv2d_3_depthwise': [batch_size, 56, 56, 128],
                        'Conv2d_3_pointwise': [batch_size, 56, 56, 128],
                        'Conv2d_4_depthwise': [batch_size, 28, 28, 128],
                        'Conv2d_4_pointwise': [batch_size, 28, 28, 256],
                        'Conv2d_5_depthwise': [batch_size, 28, 28, 256],
                        'Conv2d_5_pointwise': [batch_size, 28, 28, 256],
                        'Conv2d_6_depthwise': [batch_size, 14, 14, 256],
                        'Conv2d_6_pointwise': [batch_size, 14, 14, 512],
                        'Conv2d_7_depthwise': [batch_size, 14, 14, 512],
                        'Conv2d_7_pointwise': [batch_size, 14, 14, 512],
                        'Conv2d_8_depthwise': [batch_size, 14, 14, 512],
                        'Conv2d_8_pointwise': [batch_size, 14, 14, 512],
                        'Conv2d_9_depthwise': [batch_size, 14, 14, 512],
                        'Conv2d_9_pointwise': [batch_size, 14, 14, 512],
                        'Conv2d_10_depthwise': [batch_size, 14, 14, 512],
                        'Conv2d_10_pointwise': [batch_size, 14, 14, 512],
                        'Conv2d_11_depthwise': [batch_size, 14, 14, 512],
                        'Conv2d_11_pointwise': [batch_size, 14, 14, 512],
                        'Conv2d_12_depthwise': [batch_size, 14, 14, 512],
                        'Conv2d_12_pointwise': [batch_size, 14, 14, 1024],
                        'Conv2d_13_depthwise': [batch_size, 14, 14, 1024],
                        'Conv2d_13_pointwise': [batch_size, 14, 14, 1024]}
    self._assertEndpointShapes(end_points, endpoints_shapes)

  def testOutputStride8BuildAndCheckAllEndPointsUptoConv2d_13(self):
    """output_stride=8 keeps Conv2d_6 onward at 28x28."""
    batch_size = 5
    height, width = 224, 224
    output_stride = 8

    inputs = tf.random_uniform((batch_size, height, width, 3))
    with slim.arg_scope([slim.conv2d, slim.separable_conv2d],
                        normalizer_fn=slim.batch_norm):
      _, end_points = mobilenet_v1.mobilenet_v1_base(
          inputs, output_stride=output_stride,
          final_endpoint='Conv2d_13_pointwise')
    endpoints_shapes = {'Conv2d_0': [batch_size, 112, 112, 32],
                        'Conv2d_1_depthwise': [batch_size, 112, 112, 32],
                        'Conv2d_1_pointwise': [batch_size, 112, 112, 64],
                        'Conv2d_2_depthwise': [batch_size, 56, 56, 64],
                        'Conv2d_2_pointwise': [batch_size, 56, 56, 128],
                        'Conv2d_3_depthwise': [batch_size, 56, 56, 128],
                        'Conv2d_3_pointwise': [batch_size, 56, 56, 128],
                        'Conv2d_4_depthwise': [batch_size, 28, 28, 128],
                        'Conv2d_4_pointwise': [batch_size, 28, 28, 256],
                        'Conv2d_5_depthwise': [batch_size, 28, 28, 256],
                        'Conv2d_5_pointwise': [batch_size, 28, 28, 256],
                        'Conv2d_6_depthwise': [batch_size, 28, 28, 256],
                        'Conv2d_6_pointwise': [batch_size, 28, 28, 512],
                        'Conv2d_7_depthwise': [batch_size, 28, 28, 512],
                        'Conv2d_7_pointwise': [batch_size, 28, 28, 512],
                        'Conv2d_8_depthwise': [batch_size, 28, 28, 512],
                        'Conv2d_8_pointwise': [batch_size, 28, 28, 512],
                        'Conv2d_9_depthwise': [batch_size, 28, 28, 512],
                        'Conv2d_9_pointwise': [batch_size, 28, 28, 512],
                        'Conv2d_10_depthwise': [batch_size, 28, 28, 512],
                        'Conv2d_10_pointwise': [batch_size, 28, 28, 512],
                        'Conv2d_11_depthwise': [batch_size, 28, 28, 512],
                        'Conv2d_11_pointwise': [batch_size, 28, 28, 512],
                        'Conv2d_12_depthwise': [batch_size, 28, 28, 512],
                        'Conv2d_12_pointwise': [batch_size, 28, 28, 1024],
                        'Conv2d_13_depthwise': [batch_size, 28, 28, 1024],
                        'Conv2d_13_pointwise': [batch_size, 28, 28, 1024]}
    self._assertEndpointShapes(end_points, endpoints_shapes)

  def testBuildAndCheckAllEndPointsApproximateFaceNet(self):
    """depth_multiplier=0.75 on 128x128 inputs scales every depth by 0.75."""
    batch_size = 5
    height, width = 128, 128

    inputs = tf.random_uniform((batch_size, height, width, 3))
    with slim.arg_scope([slim.conv2d, slim.separable_conv2d],
                        normalizer_fn=slim.batch_norm):
      _, end_points = mobilenet_v1.mobilenet_v1_base(
          inputs, final_endpoint='Conv2d_13_pointwise', depth_multiplier=0.75)
    # For the Conv2d_0 layer FaceNet has depth=16
    endpoints_shapes = {'Conv2d_0': [batch_size, 64, 64, 24],
                        'Conv2d_1_depthwise': [batch_size, 64, 64, 24],
                        'Conv2d_1_pointwise': [batch_size, 64, 64, 48],
                        'Conv2d_2_depthwise': [batch_size, 32, 32, 48],
                        'Conv2d_2_pointwise': [batch_size, 32, 32, 96],
                        'Conv2d_3_depthwise': [batch_size, 32, 32, 96],
                        'Conv2d_3_pointwise': [batch_size, 32, 32, 96],
                        'Conv2d_4_depthwise': [batch_size, 16, 16, 96],
                        'Conv2d_4_pointwise': [batch_size, 16, 16, 192],
                        'Conv2d_5_depthwise': [batch_size, 16, 16, 192],
                        'Conv2d_5_pointwise': [batch_size, 16, 16, 192],
                        'Conv2d_6_depthwise': [batch_size, 8, 8, 192],
                        'Conv2d_6_pointwise': [batch_size, 8, 8, 384],
                        'Conv2d_7_depthwise': [batch_size, 8, 8, 384],
                        'Conv2d_7_pointwise': [batch_size, 8, 8, 384],
                        'Conv2d_8_depthwise': [batch_size, 8, 8, 384],
                        'Conv2d_8_pointwise': [batch_size, 8, 8, 384],
                        'Conv2d_9_depthwise': [batch_size, 8, 8, 384],
                        'Conv2d_9_pointwise': [batch_size, 8, 8, 384],
                        'Conv2d_10_depthwise': [batch_size, 8, 8, 384],
                        'Conv2d_10_pointwise': [batch_size, 8, 8, 384],
                        'Conv2d_11_depthwise': [batch_size, 8, 8, 384],
                        'Conv2d_11_pointwise': [batch_size, 8, 8, 384],
                        'Conv2d_12_depthwise': [batch_size, 4, 4, 384],
                        'Conv2d_12_pointwise': [batch_size, 4, 4, 768],
                        'Conv2d_13_depthwise': [batch_size, 4, 4, 768],
                        'Conv2d_13_pointwise': [batch_size, 4, 4, 768]}
    self._assertEndpointShapes(end_points, endpoints_shapes)

  def testModelHasExpectedNumberOfParameters(self):
    """The base network with batch norm has 3,217,920 model parameters."""
    batch_size = 5
    height, width = 224, 224
    inputs = tf.random_uniform((batch_size, height, width, 3))
    with slim.arg_scope([slim.conv2d, slim.separable_conv2d],
                        normalizer_fn=slim.batch_norm):
      mobilenet_v1.mobilenet_v1_base(inputs)
      total_params, _ = slim.model_analyzer.analyze_vars(
          slim.get_model_variables())
      # Plain int literal: the Python 2 `L` long-integer suffix is a
      # SyntaxError on Python 3 and prevents the whole module from importing.
      self.assertEqual(3217920, total_params)

  def testBuildEndPointsWithDepthMultiplierLessThanOne(self):
    """depth_multiplier=0.5 halves the depth of every Conv endpoint."""
    batch_size = 5
    height, width = 224, 224
    num_classes = 1000

    inputs = tf.random_uniform((batch_size, height, width, 3))
    _, end_points = mobilenet_v1.mobilenet_v1(inputs, num_classes)

    endpoint_keys = [key for key in end_points if key.startswith('Conv')]

    # A distinct scope so the second network does not collide with the first.
    _, end_points_with_multiplier = mobilenet_v1.mobilenet_v1(
        inputs, num_classes, scope='depth_multiplied_net',
        depth_multiplier=0.5)

    for key in endpoint_keys:
      original_depth = end_points[key].get_shape().as_list()[3]
      new_depth = end_points_with_multiplier[key].get_shape().as_list()[3]
      self.assertEqual(0.5 * original_depth, new_depth)

  def testBuildEndPointsWithDepthMultiplierGreaterThanOne(self):
    """depth_multiplier=2.0 doubles the depth of every Conv endpoint."""
    batch_size = 5
    height, width = 224, 224
    num_classes = 1000

    inputs = tf.random_uniform((batch_size, height, width, 3))
    _, end_points = mobilenet_v1.mobilenet_v1(inputs, num_classes)

    # NOTE(review): the 'Mixed' prefix looks carried over from the Inception
    # tests; none of the MobileNet endpoints checked in this file start with
    # it.
    endpoint_keys = [key for key in end_points
                     if key.startswith(('Mixed', 'Conv'))]

    # A distinct scope so the second network does not collide with the first.
    _, end_points_with_multiplier = mobilenet_v1.mobilenet_v1(
        inputs, num_classes, scope='depth_multiplied_net',
        depth_multiplier=2.0)

    for key in endpoint_keys:
      original_depth = end_points[key].get_shape().as_list()[3]
      new_depth = end_points_with_multiplier[key].get_shape().as_list()[3]
      self.assertEqual(2.0 * original_depth, new_depth)

  def testRaiseValueErrorWithInvalidDepthMultiplier(self):
    """Negative and zero depth multipliers are rejected with ValueError."""
    batch_size = 5
    height, width = 224, 224
    num_classes = 1000

    inputs = tf.random_uniform((batch_size, height, width, 3))
    with self.assertRaises(ValueError):
      _ = mobilenet_v1.mobilenet_v1(
          inputs, num_classes, depth_multiplier=-0.1)
    with self.assertRaises(ValueError):
      _ = mobilenet_v1.mobilenet_v1(
          inputs, num_classes, depth_multiplier=0.0)

  def testHalfSizeImages(self):
    """112x112 inputs still give logits and a 4x4 final feature map."""
    batch_size = 5
    height, width = 112, 112
    num_classes = 1000

    inputs = tf.random_uniform((batch_size, height, width, 3))
    logits, end_points = mobilenet_v1.mobilenet_v1(inputs, num_classes)
    self.assertTrue(logits.op.name.startswith('MobilenetV1/Logits'))
    self.assertListEqual(logits.get_shape().as_list(),
                         [batch_size, num_classes])
    pre_pool = end_points['Conv2d_13_pointwise']
    self.assertListEqual(pre_pool.get_shape().as_list(),
                         [batch_size, 4, 4, 1024])

  def testUnknownImageShape(self):
    """Spatial dims unknown at graph-build time are resolved at run time."""
    tf.reset_default_graph()
    batch_size = 2
    height, width = 224, 224
    num_classes = 1000
    input_np = np.random.uniform(0, 1, (batch_size, height, width, 3))
    with self.test_session() as sess:
      inputs = tf.placeholder(tf.float32, shape=(batch_size, None, None, 3))
      logits, end_points = mobilenet_v1.mobilenet_v1(inputs, num_classes)
      self.assertTrue(logits.op.name.startswith('MobilenetV1/Logits'))
      self.assertListEqual(logits.get_shape().as_list(),
                           [batch_size, num_classes])
      pre_pool = end_points['Conv2d_13_pointwise']
      feed_dict = {inputs: input_np}
      tf.global_variables_initializer().run()
      pre_pool_out = sess.run(pre_pool, feed_dict=feed_dict)
      self.assertListEqual(list(pre_pool_out.shape), [batch_size, 7, 7, 1024])

  # The method name's typo ("Unknow") is kept so existing test filters that
  # reference it keep matching.
  def testUnknowBatchSize(self):
    """A None batch dimension propagates to the logits and runs at batch 1."""
    batch_size = 1
    height, width = 224, 224
    num_classes = 1000

    inputs = tf.placeholder(tf.float32, (None, height, width, 3))
    logits, _ = mobilenet_v1.mobilenet_v1(inputs, num_classes)
    self.assertTrue(logits.op.name.startswith('MobilenetV1/Logits'))
    self.assertListEqual(logits.get_shape().as_list(),
                         [None, num_classes])
    images = tf.random_uniform((batch_size, height, width, 3))

    with self.test_session() as sess:
      sess.run(tf.global_variables_initializer())
      output = sess.run(logits, {inputs: images.eval()})
      self.assertEqual(output.shape, (batch_size, num_classes))

  def testEvaluation(self):
    """Builds the model with is_training=False and runs a prediction."""
    batch_size = 2
    height, width = 224, 224
    num_classes = 1000

    eval_inputs = tf.random_uniform((batch_size, height, width, 3))
    logits, _ = mobilenet_v1.mobilenet_v1(eval_inputs, num_classes,
                                          is_training=False)
    predictions = tf.argmax(logits, 1)

    with self.test_session() as sess:
      sess.run(tf.global_variables_initializer())
      output = sess.run(predictions)
      self.assertEqual(output.shape, (batch_size,))

  def testTrainEvalWithReuse(self):
    """A second network built with reuse=True runs on a different batch size."""
    train_batch_size = 5
    eval_batch_size = 2
    height, width = 150, 150
    num_classes = 1000

    train_inputs = tf.random_uniform((train_batch_size, height, width, 3))
    mobilenet_v1.mobilenet_v1(train_inputs, num_classes)
    eval_inputs = tf.random_uniform((eval_batch_size, height, width, 3))
    logits, _ = mobilenet_v1.mobilenet_v1(eval_inputs, num_classes,
                                          reuse=True)
    predictions = tf.argmax(logits, 1)

    with self.test_session() as sess:
      sess.run(tf.global_variables_initializer())
      output = sess.run(predictions)
      self.assertEqual(output.shape, (eval_batch_size,))

  def testLogitsNotSqueezed(self):
    """spatial_squeeze=False keeps the 1x1 spatial dims on the logits."""
    num_classes = 25
    images = tf.random_uniform([1, 224, 224, 3])
    logits, _ = mobilenet_v1.mobilenet_v1(images,
                                          num_classes=num_classes,
                                          spatial_squeeze=False)

    with self.test_session() as sess:
      tf.global_variables_initializer().run()
      logits_out = sess.run(logits)
      self.assertListEqual(list(logits_out.shape), [1, 1, 1, num_classes])
447
448
# Script entry point: hand control to TensorFlow's test runner, which runs
# every test case defined in this module.
449 if __name__ == '__main__':
450   tf.test.main()