1 # Copyright 2016 The TensorFlow Authors. All Rights Reserved.
3 # Licensed under the Apache License, Version 2.0 (the "License");
4 # you may not use this file except in compliance with the License.
5 # You may obtain a copy of the License at
7 # http://www.apache.org/licenses/LICENSE-2.0
9 # Unless required by applicable law or agreed to in writing, software
10 # distributed under the License is distributed on an "AS IS" BASIS,
11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 # See the License for the specific language governing permissions and
13 # limitations under the License.
14 # ==============================================================================
"""Tests for nets.inception_v3."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np
import tensorflow as tf

from nets import inception
# TF-Slim is reached through `tf.contrib`, which only exists in TensorFlow 1.x.
slim = tf.contrib.slim


class InceptionV3Test(tf.test.TestCase):
  """Tests for `inception.inception_v3` and `inception.inception_v3_base`.

  Most tests only build the graph and check static shapes and end point
  names. The tests that open a session also run the network on random
  inputs and check the runtime output shapes.
  """

  # All end points built by `inception_v3_base`, in construction order.
  # `testBuildOnlyUptoFinalEndpoint` relies on this ordering: building up
  # to the i-th entry must produce exactly the first i + 1 entries.
  _BASE_ENDPOINTS = (
      'Conv2d_1a_3x3', 'Conv2d_2a_3x3', 'Conv2d_2b_3x3',
      'MaxPool_3a_3x3', 'Conv2d_3b_1x1', 'Conv2d_4a_3x3',
      'MaxPool_5a_3x3', 'Mixed_5b', 'Mixed_5c', 'Mixed_5d',
      'Mixed_6a', 'Mixed_6b', 'Mixed_6c', 'Mixed_6d',
      'Mixed_6e', 'Mixed_7a', 'Mixed_7b', 'Mixed_7c')

  def testBuildClassificationNetwork(self):
    batch_size = 5
    height, width = 299, 299
    num_classes = 1000

    inputs = tf.random_uniform((batch_size, height, width, 3))
    logits, end_points = inception.inception_v3(inputs, num_classes)
    self.assertTrue(logits.op.name.startswith('InceptionV3/Logits'))
    self.assertListEqual(logits.get_shape().as_list(),
                         [batch_size, num_classes])
    self.assertIn('Predictions', end_points)
    self.assertListEqual(end_points['Predictions'].get_shape().as_list(),
                         [batch_size, num_classes])

  def testBuildBaseNetwork(self):
    batch_size = 5
    height, width = 299, 299

    inputs = tf.random_uniform((batch_size, height, width, 3))
    final_endpoint, end_points = inception.inception_v3_base(inputs)
    self.assertTrue(final_endpoint.op.name.startswith(
        'InceptionV3/Mixed_7c'))
    self.assertListEqual(final_endpoint.get_shape().as_list(),
                         [batch_size, 8, 8, 2048])
    self.assertItemsEqual(end_points.keys(), self._BASE_ENDPOINTS)

  def testBuildOnlyUptoFinalEndpoint(self):
    batch_size = 5
    height, width = 299, 299
    endpoints = self._BASE_ENDPOINTS

    for index, endpoint in enumerate(endpoints):
      # A fresh graph per end point keeps variable scopes from colliding.
      with tf.Graph().as_default():
        inputs = tf.random_uniform((batch_size, height, width, 3))
        out_tensor, end_points = inception.inception_v3_base(
            inputs, final_endpoint=endpoint)
        self.assertTrue(out_tensor.op.name.startswith(
            'InceptionV3/' + endpoint))
        self.assertItemsEqual(endpoints[:index + 1], end_points)

  def testBuildAndCheckAllEndPointsUptoMixed7c(self):
    batch_size = 5
    height, width = 299, 299

    inputs = tf.random_uniform((batch_size, height, width, 3))
    _, end_points = inception.inception_v3_base(
        inputs, final_endpoint='Mixed_7c')
    endpoints_shapes = {'Conv2d_1a_3x3': [batch_size, 149, 149, 32],
                        'Conv2d_2a_3x3': [batch_size, 147, 147, 32],
                        'Conv2d_2b_3x3': [batch_size, 147, 147, 64],
                        'MaxPool_3a_3x3': [batch_size, 73, 73, 64],
                        'Conv2d_3b_1x1': [batch_size, 73, 73, 80],
                        'Conv2d_4a_3x3': [batch_size, 71, 71, 192],
                        'MaxPool_5a_3x3': [batch_size, 35, 35, 192],
                        'Mixed_5b': [batch_size, 35, 35, 256],
                        'Mixed_5c': [batch_size, 35, 35, 288],
                        'Mixed_5d': [batch_size, 35, 35, 288],
                        'Mixed_6a': [batch_size, 17, 17, 768],
                        'Mixed_6b': [batch_size, 17, 17, 768],
                        'Mixed_6c': [batch_size, 17, 17, 768],
                        'Mixed_6d': [batch_size, 17, 17, 768],
                        'Mixed_6e': [batch_size, 17, 17, 768],
                        'Mixed_7a': [batch_size, 8, 8, 1280],
                        'Mixed_7b': [batch_size, 8, 8, 2048],
                        'Mixed_7c': [batch_size, 8, 8, 2048]}
    self.assertItemsEqual(endpoints_shapes.keys(), end_points.keys())
    for endpoint_name, expected_shape in endpoints_shapes.items():
      self.assertIn(endpoint_name, end_points)
      self.assertListEqual(end_points[endpoint_name].get_shape().as_list(),
                           expected_shape)

  def testModelHasExpectedNumberOfParameters(self):
    batch_size = 5
    height, width = 299, 299
    inputs = tf.random_uniform((batch_size, height, width, 3))
    with slim.arg_scope(inception.inception_v3_arg_scope()):
      inception.inception_v3_base(inputs)
    total_params, _ = slim.model_analyzer.analyze_vars(
        slim.get_model_variables())
    # A parameter count is an exact integer, so compare exactly.
    self.assertEqual(21802784, total_params)

  def testBuildEndPoints(self):
    batch_size = 5
    height, width = 299, 299
    num_classes = 1000

    inputs = tf.random_uniform((batch_size, height, width, 3))
    _, end_points = inception.inception_v3(inputs, num_classes)
    self.assertIn('Logits', end_points)
    logits = end_points['Logits']
    self.assertListEqual(logits.get_shape().as_list(),
                         [batch_size, num_classes])
    self.assertIn('AuxLogits', end_points)
    aux_logits = end_points['AuxLogits']
    self.assertListEqual(aux_logits.get_shape().as_list(),
                         [batch_size, num_classes])
    self.assertIn('Mixed_7c', end_points)
    pre_pool = end_points['Mixed_7c']
    self.assertListEqual(pre_pool.get_shape().as_list(),
                         [batch_size, 8, 8, 2048])
    self.assertIn('PreLogits', end_points)
    pre_logits = end_points['PreLogits']
    self.assertListEqual(pre_logits.get_shape().as_list(),
                         [batch_size, 1, 1, 2048])

  def _checkDepthMultiplier(self, depth_multiplier):
    """Asserts that `depth_multiplier` scales every Conv/Mixed end point.

    Builds the default network plus a second copy under the scope
    'depth_multiplied_net' with `depth_multiplier` applied. It then checks
    that the channel count (last axis) of each Conv*/Mixed* end point equals
    the original depth times `depth_multiplier`.

    Args:
      depth_multiplier: positive float passed to `inception.inception_v3`.
    """
    batch_size = 5
    height, width = 299, 299
    num_classes = 1000

    inputs = tf.random_uniform((batch_size, height, width, 3))
    _, end_points = inception.inception_v3(inputs, num_classes)

    endpoint_keys = [key for key in end_points
                     if key.startswith(('Mixed', 'Conv'))]

    _, end_points_with_multiplier = inception.inception_v3(
        inputs, num_classes, scope='depth_multiplied_net',
        depth_multiplier=depth_multiplier)

    for key in endpoint_keys:
      original_depth = end_points[key].get_shape().as_list()[3]
      new_depth = end_points_with_multiplier[key].get_shape().as_list()[3]
      self.assertEqual(depth_multiplier * original_depth, new_depth)

  def testBuildEndPointsWithDepthMultiplierLessThanOne(self):
    self._checkDepthMultiplier(0.5)

  def testBuildEndPointsWithDepthMultiplierGreaterThanOne(self):
    self._checkDepthMultiplier(2.0)

  def testRaiseValueErrorWithInvalidDepthMultiplier(self):
    batch_size = 5
    height, width = 299, 299
    num_classes = 1000

    inputs = tf.random_uniform((batch_size, height, width, 3))
    with self.assertRaises(ValueError):
      _ = inception.inception_v3(inputs, num_classes, depth_multiplier=-0.1)
    with self.assertRaises(ValueError):
      _ = inception.inception_v3(inputs, num_classes, depth_multiplier=0.0)

  def testHalfSizeImages(self):
    batch_size = 5
    height, width = 150, 150
    num_classes = 1000

    inputs = tf.random_uniform((batch_size, height, width, 3))
    logits, end_points = inception.inception_v3(inputs, num_classes)
    self.assertTrue(logits.op.name.startswith('InceptionV3/Logits'))
    self.assertListEqual(logits.get_shape().as_list(),
                         [batch_size, num_classes])
    pre_pool = end_points['Mixed_7c']
    self.assertListEqual(pre_pool.get_shape().as_list(),
                         [batch_size, 3, 3, 2048])

  def testUnknownImageShape(self):
    tf.reset_default_graph()
    batch_size = 2
    height, width = 299, 299
    num_classes = 1000
    input_np = np.random.uniform(0, 1, (batch_size, height, width, 3))
    with self.test_session() as sess:
      # Spatial dims are unknown at graph-construction time; the real
      # feature-map shape is only checked after running with concrete data.
      inputs = tf.placeholder(tf.float32, shape=(batch_size, None, None, 3))
      logits, end_points = inception.inception_v3(inputs, num_classes)
      self.assertListEqual(logits.get_shape().as_list(),
                           [batch_size, num_classes])
      pre_pool = end_points['Mixed_7c']
      feed_dict = {inputs: input_np}
      tf.global_variables_initializer().run()
      pre_pool_out = sess.run(pre_pool, feed_dict=feed_dict)
      self.assertListEqual(list(pre_pool_out.shape), [batch_size, 8, 8, 2048])

  def testUnknowBatchSize(self):
    # (sic) Method name kept unchanged so existing test IDs stay stable.
    batch_size = 1
    height, width = 299, 299
    num_classes = 1000

    inputs = tf.placeholder(tf.float32, (None, height, width, 3))
    logits, _ = inception.inception_v3(inputs, num_classes)
    self.assertTrue(logits.op.name.startswith('InceptionV3/Logits'))
    # The batch dimension is unknown statically because of the placeholder.
    self.assertListEqual(logits.get_shape().as_list(),
                         [None, num_classes])
    images = tf.random_uniform((batch_size, height, width, 3))

    with self.test_session() as sess:
      sess.run(tf.global_variables_initializer())
      output = sess.run(logits, {inputs: images.eval()})
      self.assertEqual(output.shape, (batch_size, num_classes))

  def testEvaluation(self):
    batch_size = 2
    height, width = 299, 299
    num_classes = 1000

    eval_inputs = tf.random_uniform((batch_size, height, width, 3))
    logits, _ = inception.inception_v3(eval_inputs, num_classes,
                                       is_training=False)
    predictions = tf.argmax(logits, 1)

    with self.test_session() as sess:
      sess.run(tf.global_variables_initializer())
      output = sess.run(predictions)
      self.assertEqual(output.shape, (batch_size,))

  def testTrainEvalWithReuse(self):
    train_batch_size = 5
    eval_batch_size = 2
    height, width = 150, 150
    num_classes = 1000

    train_inputs = tf.random_uniform((train_batch_size, height, width, 3))
    inception.inception_v3(train_inputs, num_classes)
    eval_inputs = tf.random_uniform((eval_batch_size, height, width, 3))
    # Second build reuses the training variables with a different batch size.
    logits, _ = inception.inception_v3(eval_inputs, num_classes,
                                       is_training=False, reuse=True)
    predictions = tf.argmax(logits, 1)

    with self.test_session() as sess:
      sess.run(tf.global_variables_initializer())
      output = sess.run(predictions)
      self.assertEqual(output.shape, (eval_batch_size,))

  def testLogitsNotSqueezed(self):
    num_classes = 25
    images = tf.random_uniform([1, 299, 299, 3])
    logits, _ = inception.inception_v3(images,
                                       num_classes=num_classes,
                                       spatial_squeeze=False)

    with self.test_session() as sess:
      tf.global_variables_initializer().run()
      logits_out = sess.run(logits)
      self.assertListEqual(list(logits_out.shape), [1, 1, 1, num_classes])
# Run the test suite when this file is executed directly.
if __name__ == '__main__':
  tf.test.main()