# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Contains the definition of the Inception Resnet V2 architecture.

As described in http://arxiv.org/abs/1602.07261.

  Inception-v4, Inception-ResNet and the Impact of Residual Connections
    on Learning
  Christian Szegedy, Sergey Ioffe, Vincent Vanhoucke, Alex Alemi
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function


import tensorflow as tf

slim = tf.contrib.slim
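# NOTE: tf.contrib (and hence tf.contrib.slim) exists only in TensorFlow 1.x;
# this file targets the TF1 graph-mode API.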


def block35(net, scale=1.0, activation_fn=tf.nn.relu, scope=None, reuse=None):
  """Builds the 35x35 resnet block."""
  with tf.variable_scope(scope, 'Block35', [net], reuse=reuse):
    with tf.variable_scope('Branch_0'):
      tower_conv = slim.conv2d(net, 32, 1, scope='Conv2d_1x1')
    with tf.variable_scope('Branch_1'):
      tower_conv1_0 = slim.conv2d(net, 32, 1, scope='Conv2d_0a_1x1')
      tower_conv1_1 = slim.conv2d(tower_conv1_0, 32, 3, scope='Conv2d_0b_3x3')
    with tf.variable_scope('Branch_2'):
      tower_conv2_0 = slim.conv2d(net, 32, 1, scope='Conv2d_0a_1x1')
      tower_conv2_1 = slim.conv2d(tower_conv2_0, 48, 3, scope='Conv2d_0b_3x3')
      tower_conv2_2 = slim.conv2d(tower_conv2_1, 64, 3, scope='Conv2d_0c_3x3')
    mixed = tf.concat(axis=3, values=[tower_conv, tower_conv1_1, tower_conv2_2])
    up = slim.conv2d(mixed, net.get_shape()[3], 1, normalizer_fn=None,
                     activation_fn=None, scope='Conv2d_1x1')
    net += scale * up
    if activation_fn:
      net = activation_fn(net)
  return net


def block17(net, scale=1.0, activation_fn=tf.nn.relu, scope=None, reuse=None):
  """Builds the 17x17 resnet block."""
  with tf.variable_scope(scope, 'Block17', [net], reuse=reuse):
    with tf.variable_scope('Branch_0'):
      tower_conv = slim.conv2d(net, 192, 1, scope='Conv2d_1x1')
    with tf.variable_scope('Branch_1'):
      tower_conv1_0 = slim.conv2d(net, 128, 1, scope='Conv2d_0a_1x1')
      tower_conv1_1 = slim.conv2d(tower_conv1_0, 160, [1, 7],
                                  scope='Conv2d_0b_1x7')
      tower_conv1_2 = slim.conv2d(tower_conv1_1, 192, [7, 1],
                                  scope='Conv2d_0c_7x1')
    mixed = tf.concat(axis=3, values=[tower_conv, tower_conv1_2])
    up = slim.conv2d(mixed, net.get_shape()[3], 1, normalizer_fn=None,
                     activation_fn=None, scope='Conv2d_1x1')
    net += scale * up
    if activation_fn:
      net = activation_fn(net)
  return net


def block8(net, scale=1.0, activation_fn=tf.nn.relu, scope=None, reuse=None):
  """Builds the 8x8 resnet block."""
  with tf.variable_scope(scope, 'Block8', [net], reuse=reuse):
    with tf.variable_scope('Branch_0'):
      tower_conv = slim.conv2d(net, 192, 1, scope='Conv2d_1x1')
    with tf.variable_scope('Branch_1'):
      tower_conv1_0 = slim.conv2d(net, 192, 1, scope='Conv2d_0a_1x1')
      tower_conv1_1 = slim.conv2d(tower_conv1_0, 224, [1, 3],
                                  scope='Conv2d_0b_1x3')
      tower_conv1_2 = slim.conv2d(tower_conv1_1, 256, [3, 1],
                                  scope='Conv2d_0c_3x1')
    mixed = tf.concat(axis=3, values=[tower_conv, tower_conv1_2])
    up = slim.conv2d(mixed, net.get_shape()[3], 1, normalizer_fn=None,
                     activation_fn=None, scope='Conv2d_1x1')
    net += scale * up
    if activation_fn:
      net = activation_fn(net)
  return net


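# Usage sketch (illustrative addition, not part of the original file): each
# residual block above preserves the spatial size and depth of its input, so a
# single block can be applied to a standalone feature map. The 35x35x320 shape
# assumed here matches the Mixed_5b activation that block35 consumes in
# inception_resnet_v2_base below.
def _block35_usage_sketch():
  """Applies one Block35 to a dummy feature map; output shape equals input."""
  feature_map = tf.placeholder(tf.float32, [None, 35, 35, 320])
  # scale=0.17 is the damping factor the base network uses for these blocks.
  return block35(feature_map, scale=0.17)

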
def inception_resnet_v2_base(inputs,
                             final_endpoint='Conv2d_7b_1x1',
                             output_stride=16,
                             align_feature_maps=False,
                             scope=None):
  """Inception model from http://arxiv.org/abs/1602.07261.

  Constructs an Inception Resnet v2 network from inputs to the given final
  endpoint. This method can construct the network up to the final inception
  block Conv2d_7b_1x1.

  Args:
    inputs: a tensor of size [batch_size, height, width, channels].
    final_endpoint: specifies the endpoint to construct the network up to. It
      can be one of ['Conv2d_1a_3x3', 'Conv2d_2a_3x3', 'Conv2d_2b_3x3',
      'MaxPool_3a_3x3', 'Conv2d_3b_1x1', 'Conv2d_4a_3x3', 'MaxPool_5a_3x3',
      'Mixed_5b', 'Mixed_6a', 'PreAuxLogits', 'Mixed_7a', 'Conv2d_7b_1x1']
    output_stride: A scalar that specifies the requested ratio of input to
      output spatial resolution. Only supports 8 and 16.
    align_feature_maps: When true, changes all the VALID paddings in the network
      to SAME padding so that the feature maps are aligned.
    scope: Optional variable_scope.

  Returns:
    tensor_out: output tensor corresponding to the final_endpoint.
    end_points: a dictionary of activations for external use, for example
                summaries or losses.

  Raises:
    ValueError: if final_endpoint is not set to one of the predefined values,
      or if the output_stride is not 8 or 16, or if the output_stride is 8 and
      we request an end point after 'PreAuxLogits'.
  """
  if output_stride != 8 and output_stride != 16:
    raise ValueError('output_stride must be 8 or 16.')

  padding = 'SAME' if align_feature_maps else 'VALID'

  end_points = {}

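  # Records each endpoint and reports whether the requested final_endpoint has
  # been reached, so graph construction can stop early.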
  def add_and_check_final(name, net):
    end_points[name] = net
    return name == final_endpoint

  with tf.variable_scope(scope, 'InceptionResnetV2', [inputs]):
    with slim.arg_scope([slim.conv2d, slim.max_pool2d, slim.avg_pool2d],
                        stride=1, padding='SAME'):
      # 149 x 149 x 32
      net = slim.conv2d(inputs, 32, 3, stride=2, padding=padding,
                        scope='Conv2d_1a_3x3')
      if add_and_check_final('Conv2d_1a_3x3', net): return net, end_points

      # 147 x 147 x 32
      net = slim.conv2d(net, 32, 3, padding=padding,
                        scope='Conv2d_2a_3x3')
      if add_and_check_final('Conv2d_2a_3x3', net): return net, end_points
      # 147 x 147 x 64
      net = slim.conv2d(net, 64, 3, scope='Conv2d_2b_3x3')
      if add_and_check_final('Conv2d_2b_3x3', net): return net, end_points
      # 73 x 73 x 64
      net = slim.max_pool2d(net, 3, stride=2, padding=padding,
                            scope='MaxPool_3a_3x3')
      if add_and_check_final('MaxPool_3a_3x3', net): return net, end_points
      # 73 x 73 x 80
      net = slim.conv2d(net, 80, 1, padding=padding,
                        scope='Conv2d_3b_1x1')
      if add_and_check_final('Conv2d_3b_1x1', net): return net, end_points
      # 71 x 71 x 192
      net = slim.conv2d(net, 192, 3, padding=padding,
                        scope='Conv2d_4a_3x3')
      if add_and_check_final('Conv2d_4a_3x3', net): return net, end_points
      # 35 x 35 x 192
      net = slim.max_pool2d(net, 3, stride=2, padding=padding,
                            scope='MaxPool_5a_3x3')
      if add_and_check_final('MaxPool_5a_3x3', net): return net, end_points

      # 35 x 35 x 320
      with tf.variable_scope('Mixed_5b'):
        with tf.variable_scope('Branch_0'):
          tower_conv = slim.conv2d(net, 96, 1, scope='Conv2d_1x1')
        with tf.variable_scope('Branch_1'):
          tower_conv1_0 = slim.conv2d(net, 48, 1, scope='Conv2d_0a_1x1')
          tower_conv1_1 = slim.conv2d(tower_conv1_0, 64, 5,
                                      scope='Conv2d_0b_5x5')
        with tf.variable_scope('Branch_2'):
          tower_conv2_0 = slim.conv2d(net, 64, 1, scope='Conv2d_0a_1x1')
          tower_conv2_1 = slim.conv2d(tower_conv2_0, 96, 3,
                                      scope='Conv2d_0b_3x3')
          tower_conv2_2 = slim.conv2d(tower_conv2_1, 96, 3,
                                      scope='Conv2d_0c_3x3')
        with tf.variable_scope('Branch_3'):
          tower_pool = slim.avg_pool2d(net, 3, stride=1, padding='SAME',
                                       scope='AvgPool_0a_3x3')
          tower_pool_1 = slim.conv2d(tower_pool, 64, 1,
                                     scope='Conv2d_0b_1x1')
        net = tf.concat(
            [tower_conv, tower_conv1_1, tower_conv2_2, tower_pool_1], 3)

      if add_and_check_final('Mixed_5b', net): return net, end_points
      # TODO(alemi): Register intermediate endpoints
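      # slim.repeat applies block35 ten times in sequence, creating a separate
      # variable scope for each repetition; scale=0.17 damps each residual
      # branch before it is added back to the trunk.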
      net = slim.repeat(net, 10, block35, scale=0.17)

      # 17 x 17 x 1088 if output_stride == 16,
      # 33 x 33 x 1088 if output_stride == 8
      use_atrous = output_stride == 8
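      # With output_stride == 8, Mixed_6a keeps stride 1 and the 17x17 blocks
      # below switch to rate-2 (atrous) convolutions, preserving spatial
      # resolution without shrinking the effective receptive field.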

      with tf.variable_scope('Mixed_6a'):
        with tf.variable_scope('Branch_0'):
          tower_conv = slim.conv2d(net, 384, 3, stride=1 if use_atrous else 2,
                                   padding=padding,
                                   scope='Conv2d_1a_3x3')
        with tf.variable_scope('Branch_1'):
          tower_conv1_0 = slim.conv2d(net, 256, 1, scope='Conv2d_0a_1x1')
          tower_conv1_1 = slim.conv2d(tower_conv1_0, 256, 3,
                                      scope='Conv2d_0b_3x3')
          tower_conv1_2 = slim.conv2d(tower_conv1_1, 384, 3,
                                      stride=1 if use_atrous else 2,
                                      padding=padding,
                                      scope='Conv2d_1a_3x3')
        with tf.variable_scope('Branch_2'):
          tower_pool = slim.max_pool2d(net, 3, stride=1 if use_atrous else 2,
                                       padding=padding,
                                       scope='MaxPool_1a_3x3')
        net = tf.concat([tower_conv, tower_conv1_2, tower_pool], 3)

      if add_and_check_final('Mixed_6a', net): return net, end_points

      # TODO(alemi): register intermediate endpoints
      with slim.arg_scope([slim.conv2d], rate=2 if use_atrous else 1):
        net = slim.repeat(net, 20, block17, scale=0.10)
      if add_and_check_final('PreAuxLogits', net): return net, end_points

      if output_stride == 8:
        # TODO(gpapan): Properly support output_stride for the rest of the net.
        raise ValueError('output_stride==8 is only supported up to the '
                         'PreAuxLogits end_point for now.')

      # 8 x 8 x 2080
      with tf.variable_scope('Mixed_7a'):
        with tf.variable_scope('Branch_0'):
          tower_conv = slim.conv2d(net, 256, 1, scope='Conv2d_0a_1x1')
          tower_conv_1 = slim.conv2d(tower_conv, 384, 3, stride=2,
                                     padding=padding,
                                     scope='Conv2d_1a_3x3')
        with tf.variable_scope('Branch_1'):
          tower_conv1 = slim.conv2d(net, 256, 1, scope='Conv2d_0a_1x1')
          tower_conv1_1 = slim.conv2d(tower_conv1, 288, 3, stride=2,
                                      padding=padding,
                                      scope='Conv2d_1a_3x3')
        with tf.variable_scope('Branch_2'):
          tower_conv2 = slim.conv2d(net, 256, 1, scope='Conv2d_0a_1x1')
          tower_conv2_1 = slim.conv2d(tower_conv2, 288, 3,
                                      scope='Conv2d_0b_3x3')
          tower_conv2_2 = slim.conv2d(tower_conv2_1, 320, 3, stride=2,
                                      padding=padding,
                                      scope='Conv2d_1a_3x3')
        with tf.variable_scope('Branch_3'):
          tower_pool = slim.max_pool2d(net, 3, stride=2,
                                       padding=padding,
                                       scope='MaxPool_1a_3x3')
        net = tf.concat(
            [tower_conv_1, tower_conv1_1, tower_conv2_2, tower_pool], 3)

      if add_and_check_final('Mixed_7a', net): return net, end_points

      # TODO(alemi): register intermediate endpoints
      net = slim.repeat(net, 9, block8, scale=0.20)
      net = block8(net, activation_fn=None)

      # 8 x 8 x 1536
      net = slim.conv2d(net, 1536, 1, scope='Conv2d_7b_1x1')
      if add_and_check_final('Conv2d_7b_1x1', net): return net, end_points

    raise ValueError('final_endpoint (%s) not recognized' % final_endpoint)


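# Usage sketch (illustrative addition, not part of the original file): the base
# network can serve as a backbone feature extractor, e.g. pulling out the
# 17x17 'PreAuxLogits' activations for a downstream task. The input shape used
# here is an assumption for demonstration.
def _base_feature_extractor_sketch():
  """Builds the base network up to PreAuxLogits and returns that endpoint."""
  images = tf.placeholder(tf.float32, [None, 299, 299, 3])
  with slim.arg_scope(inception_resnet_v2_arg_scope()):
    _, end_points = inception_resnet_v2_base(
        images, final_endpoint='PreAuxLogits')
  return end_points['PreAuxLogits']

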
def inception_resnet_v2(inputs, num_classes=1001, is_training=True,
                        dropout_keep_prob=0.8,
                        reuse=None,
                        scope='InceptionResnetV2',
                        create_aux_logits=True):
  """Creates the Inception Resnet V2 model.

  Args:
    inputs: a 4-D tensor of size [batch_size, height, width, 3].
    num_classes: number of predicted classes.
    is_training: whether the model is being trained.
    dropout_keep_prob: float, the fraction to keep before final layer.
    reuse: whether or not the network and its variables should be reused. To be
      able to reuse, 'scope' must be given.
    scope: Optional variable_scope.
    create_aux_logits: Whether to include the auxiliary logits.

  Returns:
    logits: the logits outputs of the model.
    end_points: the set of end_points from the inception model.
  """
  end_points = {}

  with tf.variable_scope(scope, 'InceptionResnetV2', [inputs, num_classes],
                         reuse=reuse) as scope:
    with slim.arg_scope([slim.batch_norm, slim.dropout],
                        is_training=is_training):

      net, end_points = inception_resnet_v2_base(inputs, scope=scope)

      if create_aux_logits:
        with tf.variable_scope('AuxLogits'):
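          # The auxiliary head branches off the 17x17 'PreAuxLogits'
          # activations, giving the middle of the network a direct loss
          # signal during training.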
          aux = end_points['PreAuxLogits']
          aux = slim.avg_pool2d(aux, 5, stride=3, padding='VALID',
                                scope='Conv2d_1a_3x3')
          aux = slim.conv2d(aux, 128, 1, scope='Conv2d_1b_1x1')
          aux = slim.conv2d(aux, 768, aux.get_shape()[1:3],
                            padding='VALID', scope='Conv2d_2a_5x5')
          aux = slim.flatten(aux)
          aux = slim.fully_connected(aux, num_classes, activation_fn=None,
                                     scope='Logits')
          end_points['AuxLogits'] = aux

      with tf.variable_scope('Logits'):
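        # Global average pooling: the kernel spans the full spatial extent of
        # the incoming feature map (8x8 for a 299x299 input).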
        net = slim.avg_pool2d(net, net.get_shape()[1:3], padding='VALID',
                              scope='AvgPool_1a_8x8')
        net = slim.flatten(net)

        net = slim.dropout(net, dropout_keep_prob, is_training=is_training,
                           scope='Dropout')

        end_points['PreLogitsFlatten'] = net

        logits = slim.fully_connected(net, num_classes, activation_fn=None,
                                      scope='Logits')
        end_points['Logits'] = logits
        end_points['Predictions'] = tf.nn.softmax(logits, name='Predictions')

    return logits, end_points
inception_resnet_v2.default_image_size = 299


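# Usage sketch (illustrative addition, not part of the original file): building
# the full classifier for inference. The 1001 classes match the function's
# default; checkpoint restoring is omitted here.
def _classification_usage_sketch():
  """Returns class probabilities for a batch of 299x299 RGB images."""
  images = tf.placeholder(
      tf.float32, [None, inception_resnet_v2.default_image_size,
                   inception_resnet_v2.default_image_size, 3])
  with slim.arg_scope(inception_resnet_v2_arg_scope()):
    _, end_points = inception_resnet_v2(images, num_classes=1001,
                                        is_training=False)
  return end_points['Predictions']

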
def inception_resnet_v2_arg_scope(weight_decay=0.00004,
                                  batch_norm_decay=0.9997,
                                  batch_norm_epsilon=0.001):
  """Yields the scope with the default parameters for inception_resnet_v2.

  Args:
    weight_decay: the weight decay for weights variables.
    batch_norm_decay: decay for the moving averages maintained by batch_norm.
    batch_norm_epsilon: small float added to variance to avoid dividing by
      zero.

  Returns:
    an arg_scope with the parameters needed for inception_resnet_v2.
  """
  # Set weight_decay for weights in conv2d and fully_connected layers.
  with slim.arg_scope([slim.conv2d, slim.fully_connected],
                      weights_regularizer=slim.l2_regularizer(weight_decay),
                      biases_regularizer=slim.l2_regularizer(weight_decay)):

    batch_norm_params = {
        'decay': batch_norm_decay,
        'epsilon': batch_norm_epsilon,
    }
    # Set activation_fn and parameters for batch_norm.
    with slim.arg_scope([slim.conv2d], activation_fn=tf.nn.relu,
                        normalizer_fn=slim.batch_norm,
                        normalizer_params=batch_norm_params) as scope:
      return scope
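

# Minimal smoke test (an illustrative sketch, not part of the original file):
# builds the graph once on random inputs and prints the logits shape. Runs
# only when this module is executed directly.
if __name__ == '__main__':
  with tf.Graph().as_default():
    fake_images = tf.random_uniform([2, 299, 299, 3])
    with slim.arg_scope(inception_resnet_v2_arg_scope()):
      test_logits, _ = inception_resnet_v2(fake_images, num_classes=10,
                                           is_training=False)
    print(test_logits.get_shape())  # Expected: (2, 10)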