EG version upgrade to 1.3
[ealt-edge.git] / example-apps / PDD / pcb-defect-detection / libs / networks / slim_nets / inception_resnet_v2_test.py
1 # Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 #
3 # Licensed under the Apache License, Version 2.0 (the "License");
4 # you may not use this file except in compliance with the License.
5 # You may obtain a copy of the License at
6 #
7 # http://www.apache.org/licenses/LICENSE-2.0
8 #
9 # Unless required by applicable law or agreed to in writing, software
10 # distributed under the License is distributed on an "AS IS" BASIS,
11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 # See the License for the specific language governing permissions and
13 # limitations under the License.
14 # ==============================================================================
15 """Tests for slim.inception_resnet_v2."""
16 from __future__ import absolute_import
17 from __future__ import division
18 from __future__ import print_function
19
20 import tensorflow as tf
21
22 from nets import inception
23
24
class InceptionTest(tf.test.TestCase):
  """Graph-construction and execution tests for Inception-ResNet-v2."""

  # End points returned by `inception.inception_resnet_v2_base` when it is
  # built without a `final_endpoint` argument, in the order they are created
  # (testBuildOnlyUptoFinalEndpoint relies on this order when truncating).
  _BASE_ENDPOINTS = ('Conv2d_1a_3x3', 'Conv2d_2a_3x3', 'Conv2d_2b_3x3',
                     'MaxPool_3a_3x3', 'Conv2d_3b_1x1', 'Conv2d_4a_3x3',
                     'MaxPool_5a_3x3', 'Mixed_5b', 'Mixed_6a',
                     'PreAuxLogits', 'Mixed_7a', 'Conv2d_7b_1x1')

  def _assert_endpoint_shapes(self, end_points, expected_shapes):
    """Asserts `end_points` has exactly the expected names and static shapes.

    Args:
      end_points: dict mapping end-point name to tensor, as returned by
        `inception.inception_resnet_v2_base`.
      expected_shapes: dict mapping end-point name to the expected static
        shape, given as a list of ints.
    """
    self.assertItemsEqual(expected_shapes.keys(), end_points.keys())
    for name, expected_shape in expected_shapes.items():
      self.assertIn(name, end_points)
      self.assertListEqual(end_points[name].get_shape().as_list(),
                           expected_shape)

  def testBuildLogits(self):
    """Checks logits and aux logits are built with [batch, classes] shape."""
    batch_size = 5
    height, width = 299, 299
    num_classes = 1000
    with self.test_session():
      inputs = tf.random_uniform((batch_size, height, width, 3))
      logits, endpoints = inception.inception_resnet_v2(inputs, num_classes)
      self.assertIn('AuxLogits', endpoints)
      auxlogits = endpoints['AuxLogits']
      self.assertTrue(
          auxlogits.op.name.startswith('InceptionResnetV2/AuxLogits'))
      self.assertListEqual(auxlogits.get_shape().as_list(),
                           [batch_size, num_classes])
      self.assertTrue(logits.op.name.startswith('InceptionResnetV2/Logits'))
      self.assertListEqual(logits.get_shape().as_list(),
                           [batch_size, num_classes])

  def testBuildWithoutAuxLogits(self):
    """Checks create_aux_logits=False omits the AuxLogits end point."""
    batch_size = 5
    height, width = 299, 299
    num_classes = 1000
    with self.test_session():
      inputs = tf.random_uniform((batch_size, height, width, 3))
      logits, endpoints = inception.inception_resnet_v2(inputs, num_classes,
                                                        create_aux_logits=False)
      self.assertNotIn('AuxLogits', endpoints)
      self.assertTrue(logits.op.name.startswith('InceptionResnetV2/Logits'))
      self.assertListEqual(logits.get_shape().as_list(),
                           [batch_size, num_classes])

  def testBuildEndPoints(self):
    """Checks the shapes of the Logits, AuxLogits and pre-pool end points."""
    batch_size = 5
    height, width = 299, 299
    num_classes = 1000
    with self.test_session():
      inputs = tf.random_uniform((batch_size, height, width, 3))
      _, end_points = inception.inception_resnet_v2(inputs, num_classes)
      self.assertIn('Logits', end_points)
      logits = end_points['Logits']
      self.assertListEqual(logits.get_shape().as_list(),
                           [batch_size, num_classes])
      self.assertIn('AuxLogits', end_points)
      aux_logits = end_points['AuxLogits']
      self.assertListEqual(aux_logits.get_shape().as_list(),
                           [batch_size, num_classes])
      pre_pool = end_points['Conv2d_7b_1x1']
      self.assertListEqual(pre_pool.get_shape().as_list(),
                           [batch_size, 8, 8, 1536])

  def testBuildBaseNetwork(self):
    """Checks the base network's output tensor and full end-point set."""
    batch_size = 5
    height, width = 299, 299

    inputs = tf.random_uniform((batch_size, height, width, 3))
    net, end_points = inception.inception_resnet_v2_base(inputs)
    self.assertTrue(net.op.name.startswith('InceptionResnetV2/Conv2d_7b_1x1'))
    self.assertListEqual(net.get_shape().as_list(),
                         [batch_size, 8, 8, 1536])
    self.assertItemsEqual(end_points.keys(), self._BASE_ENDPOINTS)

  def testBuildOnlyUptoFinalEndpoint(self):
    """Checks final_endpoint truncates the network at each end point."""
    batch_size = 5
    height, width = 299, 299
    for index, endpoint in enumerate(self._BASE_ENDPOINTS):
      # A fresh graph per end point so the variable scopes do not collide.
      with tf.Graph().as_default():
        inputs = tf.random_uniform((batch_size, height, width, 3))
        out_tensor, end_points = inception.inception_resnet_v2_base(
            inputs, final_endpoint=endpoint)
        # The op-name check is skipped for 'PreAuxLogits'; presumably that
        # end point labels a tensor created under another scope (its shape
        # matches Mixed_6a below) -- confirm against inception_resnet_v2_base.
        if endpoint != 'PreAuxLogits':
          self.assertTrue(out_tensor.op.name.startswith(
              'InceptionResnetV2/' + endpoint))
        self.assertItemsEqual(self._BASE_ENDPOINTS[:index + 1], end_points)

  def testBuildAndCheckAllEndPointsUptoPreAuxLogits(self):
    """Checks every end-point shape up to PreAuxLogits with default options."""
    batch_size = 5
    height, width = 299, 299

    inputs = tf.random_uniform((batch_size, height, width, 3))
    _, end_points = inception.inception_resnet_v2_base(
        inputs, final_endpoint='PreAuxLogits')
    endpoints_shapes = {'Conv2d_1a_3x3': [5, 149, 149, 32],
                        'Conv2d_2a_3x3': [5, 147, 147, 32],
                        'Conv2d_2b_3x3': [5, 147, 147, 64],
                        'MaxPool_3a_3x3': [5, 73, 73, 64],
                        'Conv2d_3b_1x1': [5, 73, 73, 80],
                        'Conv2d_4a_3x3': [5, 71, 71, 192],
                        'MaxPool_5a_3x3': [5, 35, 35, 192],
                        'Mixed_5b': [5, 35, 35, 320],
                        'Mixed_6a': [5, 17, 17, 1088],
                        'PreAuxLogits': [5, 17, 17, 1088]
                       }
    self._assert_endpoint_shapes(end_points, endpoints_shapes)

  def testBuildAndCheckAllEndPointsUptoPreAuxLogitsWithAlignedFeatureMaps(self):
    """Checks end-point shapes up to PreAuxLogits with align_feature_maps."""
    batch_size = 5
    height, width = 299, 299

    inputs = tf.random_uniform((batch_size, height, width, 3))
    _, end_points = inception.inception_resnet_v2_base(
        inputs, final_endpoint='PreAuxLogits', align_feature_maps=True)
    # Spatial sizes differ from the default build (e.g. 150 vs 149 after the
    # first conv) because align_feature_maps changes the padding behaviour.
    endpoints_shapes = {'Conv2d_1a_3x3': [5, 150, 150, 32],
                        'Conv2d_2a_3x3': [5, 150, 150, 32],
                        'Conv2d_2b_3x3': [5, 150, 150, 64],
                        'MaxPool_3a_3x3': [5, 75, 75, 64],
                        'Conv2d_3b_1x1': [5, 75, 75, 80],
                        'Conv2d_4a_3x3': [5, 75, 75, 192],
                        'MaxPool_5a_3x3': [5, 38, 38, 192],
                        'Mixed_5b': [5, 38, 38, 320],
                        'Mixed_6a': [5, 19, 19, 1088],
                        'PreAuxLogits': [5, 19, 19, 1088]
                       }
    self._assert_endpoint_shapes(end_points, endpoints_shapes)

  def testBuildAndCheckAllEndPointsUptoPreAuxLogitsWithOutputStrideEight(self):
    """Checks end-point shapes up to PreAuxLogits with output_stride=8."""
    batch_size = 5
    height, width = 299, 299

    inputs = tf.random_uniform((batch_size, height, width, 3))
    _, end_points = inception.inception_resnet_v2_base(
        inputs, final_endpoint='PreAuxLogits', output_stride=8)
    # Identical to the default build except that Mixed_6a no longer halves
    # the spatial resolution (33x33 instead of 17x17).
    endpoints_shapes = {'Conv2d_1a_3x3': [5, 149, 149, 32],
                        'Conv2d_2a_3x3': [5, 147, 147, 32],
                        'Conv2d_2b_3x3': [5, 147, 147, 64],
                        'MaxPool_3a_3x3': [5, 73, 73, 64],
                        'Conv2d_3b_1x1': [5, 73, 73, 80],
                        'Conv2d_4a_3x3': [5, 71, 71, 192],
                        'MaxPool_5a_3x3': [5, 35, 35, 192],
                        'Mixed_5b': [5, 35, 35, 320],
                        'Mixed_6a': [5, 33, 33, 1088],
                        'PreAuxLogits': [5, 33, 33, 1088]
                       }
    self._assert_endpoint_shapes(end_points, endpoints_shapes)

  def testVariablesSetDevice(self):
    """Checks variables land on the device active when the model is built."""
    batch_size = 5
    height, width = 299, 299
    num_classes = 1000
    with self.test_session():
      inputs = tf.random_uniform((batch_size, height, width, 3))
      # Force all Variables to reside on the device.
      with tf.variable_scope('on_cpu'), tf.device('/cpu:0'):
        inception.inception_resnet_v2(inputs, num_classes)
      with tf.variable_scope('on_gpu'), tf.device('/gpu:0'):
        inception.inception_resnet_v2(inputs, num_classes)
      for v in tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='on_cpu'):
        self.assertDeviceEqual(v.device, '/cpu:0')
      for v in tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='on_gpu'):
        self.assertDeviceEqual(v.device, '/gpu:0')

  def testHalfSizeImages(self):
    """Checks 150x150 inputs build and yield a 3x3 pre-pool feature map."""
    batch_size = 5
    height, width = 150, 150
    num_classes = 1000
    with self.test_session():
      inputs = tf.random_uniform((batch_size, height, width, 3))
      logits, end_points = inception.inception_resnet_v2(inputs, num_classes)
      self.assertTrue(logits.op.name.startswith('InceptionResnetV2/Logits'))
      self.assertListEqual(logits.get_shape().as_list(),
                           [batch_size, num_classes])
      pre_pool = end_points['Conv2d_7b_1x1']
      self.assertListEqual(pre_pool.get_shape().as_list(),
                           [batch_size, 3, 3, 1536])

  def testUnknownBatchSize(self):
    """Checks a None batch dimension builds and runs with a concrete batch."""
    batch_size = 1
    height, width = 299, 299
    num_classes = 1000
    with self.test_session() as sess:
      inputs = tf.placeholder(tf.float32, (None, height, width, 3))
      logits, _ = inception.inception_resnet_v2(inputs, num_classes)
      self.assertTrue(logits.op.name.startswith('InceptionResnetV2/Logits'))
      self.assertListEqual(logits.get_shape().as_list(),
                           [None, num_classes])
      images = tf.random_uniform((batch_size, height, width, 3))
      sess.run(tf.global_variables_initializer())
      # Evaluate the random images first, then feed the concrete values
      # through the placeholder.
      output = sess.run(logits, {inputs: images.eval()})
      self.assertEqual(output.shape, (batch_size, num_classes))

  def testEvaluation(self):
    """Checks the model runs in inference mode and yields one class per row."""
    batch_size = 2
    height, width = 299, 299
    num_classes = 1000
    with self.test_session() as sess:
      eval_inputs = tf.random_uniform((batch_size, height, width, 3))
      logits, _ = inception.inception_resnet_v2(eval_inputs,
                                                num_classes,
                                                is_training=False)
      predictions = tf.argmax(logits, 1)
      sess.run(tf.global_variables_initializer())
      output = sess.run(predictions)
      self.assertEqual(output.shape, (batch_size,))

  def testTrainEvalWithReuse(self):
    """Checks an eval graph can reuse the variables of a training graph."""
    train_batch_size = 5
    eval_batch_size = 2
    height, width = 150, 150
    num_classes = 1000
    with self.test_session() as sess:
      train_inputs = tf.random_uniform((train_batch_size, height, width, 3))
      inception.inception_resnet_v2(train_inputs, num_classes)
      eval_inputs = tf.random_uniform((eval_batch_size, height, width, 3))
      # reuse=True makes the second build share the first build's variables
      # instead of creating a new set.
      logits, _ = inception.inception_resnet_v2(eval_inputs,
                                                num_classes,
                                                is_training=False,
                                                reuse=True)
      predictions = tf.argmax(logits, 1)
      sess.run(tf.global_variables_initializer())
      output = sess.run(predictions)
      self.assertEqual(output.shape, (eval_batch_size,))
262
263
if __name__ == '__main__':
  # Run this module's test cases through TensorFlow's test runner.
  tf.test.main()