7617701574dae04460832945d3ffa25872111f4e
[ealt-edge.git] / example-apps / PDD / pcb-defect-detection / libs / networks / slim_nets / resnet_v2.py
1 # Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 #
3 # Licensed under the Apache License, Version 2.0 (the "License");
4 # you may not use this file except in compliance with the License.
5 # You may obtain a copy of the License at
6 #
7 # http://www.apache.org/licenses/LICENSE-2.0
8 #
9 # Unless required by applicable law or agreed to in writing, software
10 # distributed under the License is distributed on an "AS IS" BASIS,
11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 # See the License for the specific language governing permissions and
13 # limitations under the License.
14 # ==============================================================================
15 """Contains definitions for the preactivation form of Residual Networks.
16
17 Residual networks (ResNets) were originally proposed in:
18 [1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun
19     Deep Residual Learning for Image Recognition. arXiv:1512.03385
20
21 The full preactivation 'v2' ResNet variant implemented in this module was
22 introduced by:
23 [2] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun
24     Identity Mappings in Deep Residual Networks. arXiv: 1603.05027
25
26 The key difference of the full preactivation 'v2' variant compared to the
27 'v1' variant in [1] is the use of batch normalization before every weight layer.
28
29 Typical use:
30
31    from tensorflow.contrib.slim.slim_nets import resnet_v2
32
33 ResNet-101 for image classification into 1000 classes:
34
35    # inputs has shape [batch, 224, 224, 3]
36    with slim.arg_scope(resnet_v2.resnet_arg_scope()):
37       net, end_points = resnet_v2.resnet_v2_101(inputs, 1000, is_training=False)
38
39 ResNet-101 for semantic segmentation into 21 classes:
40
41    # inputs has shape [batch, 513, 513, 3]
42    with slim.arg_scope(resnet_v2.resnet_arg_scope(is_training)):
43       net, end_points = resnet_v2.resnet_v2_101(inputs,
44                                                 21,
45                                                 is_training=False,
46                                                 global_pool=False,
47                                                 output_stride=16)
48 """
49 from __future__ import absolute_import
50 from __future__ import division
51 from __future__ import print_function
52
53 import tensorflow as tf
54
55 from nets import resnet_utils
56
57 slim = tf.contrib.slim
58 resnet_arg_scope = resnet_utils.resnet_arg_scope
59
60
61 @slim.add_arg_scope
62 def bottleneck(inputs, depth, depth_bottleneck, stride, rate=1,
63                outputs_collections=None, scope=None):
64   """Bottleneck residual unit variant with BN before convolutions.
65
66   This is the full preactivation residual unit variant proposed in [2]. See
67   Fig. 1(b) of [2] for its definition. Note that we use here the bottleneck
68   variant which has an extra bottleneck layer.
69
70   When putting together two consecutive ResNet blocks that use this unit, one
71   should use stride = 2 in the last unit of the first block.
72
73   Args:
74     inputs: A tensor of size [batch, height, width, channels].
75     depth: The depth of the ResNet unit output.
76     depth_bottleneck: The depth of the bottleneck layers.
77     stride: The ResNet unit's stride. Determines the amount of downsampling of
78       the units output compared to its input.
79     rate: An integer, rate for atrous convolution.
80     outputs_collections: Collection to add the ResNet unit output.
81     scope: Optional variable_scope.
82
83   Returns:
84     The ResNet unit's output.
85   """
86   with tf.variable_scope(scope, 'bottleneck_v2', [inputs]) as sc:
87     depth_in = slim.utils.last_dimension(inputs.get_shape(), min_rank=4)
88     preact = slim.batch_norm(inputs, activation_fn=tf.nn.relu, scope='preact')
89     if depth == depth_in:
90       shortcut = resnet_utils.subsample(inputs, stride, 'shortcut')
91     else:
92       shortcut = slim.conv2d(preact, depth, [1, 1], stride=stride,
93                              normalizer_fn=None, activation_fn=None,
94                              scope='shortcut')
95
96     residual = slim.conv2d(preact, depth_bottleneck, [1, 1], stride=1,
97                            scope='conv1')
98     residual = resnet_utils.conv2d_same(residual, depth_bottleneck, 3, stride,
99                                         rate=rate, scope='conv2')
100     residual = slim.conv2d(residual, depth, [1, 1], stride=1,
101                            normalizer_fn=None, activation_fn=None,
102                            scope='conv3')
103
104     output = shortcut + residual
105
106     return slim.utils.collect_named_outputs(outputs_collections,
107                                             sc.original_name_scope,
108                                             output)
109
110
111 def resnet_v2(inputs,
112               blocks,
113               num_classes=None,
114               is_training=True,
115               global_pool=True,
116               output_stride=None,
117               include_root_block=True,
118               spatial_squeeze=False,
119               reuse=None,
120               scope=None):
121   """Generator for v2 (preactivation) ResNet models.
122
123   This function generates a family of ResNet v2 models. See the resnet_v2_*()
124   methods for specific model instantiations, obtained by selecting different
125   block instantiations that produce ResNets of various depths.
126
127   Training for image classification on Imagenet is usually done with [224, 224]
128   inputs, resulting in [7, 7] feature maps at the output of the last ResNet
129   block for the ResNets defined in [1] that have nominal stride equal to 32.
130   However, for dense prediction tasks we advise that one uses inputs with
131   spatial dimensions that are multiples of 32 plus 1, e.g., [321, 321]. In
132   this case the feature maps at the ResNet output will have spatial shape
133   [(height - 1) / output_stride + 1, (width - 1) / output_stride + 1]
134   and corners exactly aligned with the input image corners, which greatly
135   facilitates alignment of the features to the image. Using as input [225, 225]
136   images results in [8, 8] feature maps at the output of the last ResNet block.
137
138   For dense prediction tasks, the ResNet needs to run in fully-convolutional
139   (FCN) mode and global_pool needs to be set to False. The ResNets in [1, 2] all
140   have nominal stride equal to 32 and a good choice in FCN mode is to use
141   output_stride=16 in order to increase the density of the computed features at
142   small computational and memory overhead, cf. http://arxiv.org/abs/1606.00915.
143
144   Args:
145     inputs: A tensor of size [batch, height_in, width_in, channels].
146     blocks: A list of length equal to the number of ResNet blocks. Each element
147       is a resnet_utils.Block object describing the units in the block.
148     num_classes: Number of predicted classes for classification tasks. If None
149       we return the features before the logit layer.
150     is_training: whether is training or not.
151     global_pool: If True, we perform global average pooling before computing the
152       logits. Set to True for image classification, False for dense prediction.
153     output_stride: If None, then the output will be computed at the nominal
154       network stride. If output_stride is not None, it specifies the requested
155       ratio of input to output spatial resolution.
156     include_root_block: If True, include the initial convolution followed by
157       max-pooling, if False excludes it. If excluded, `inputs` should be the
158       results of an activation-less convolution.
159     spatial_squeeze: if True, logits is of shape [B, C], if false logits is
160         of shape [B, 1, 1, C], where B is batch_size and C is number of classes.
161     reuse: whether or not the network and its variables should be reused. To be
162       able to reuse 'scope' must be given.
163     scope: Optional variable_scope.
164
165
166   Returns:
167     net: A rank-4 tensor of size [batch, height_out, width_out, channels_out].
168       If global_pool is False, then height_out and width_out are reduced by a
169       factor of output_stride compared to the respective height_in and width_in,
170       else both height_out and width_out equal one. If num_classes is None, then
171       net is the output of the last ResNet block, potentially after global
172       average pooling. If num_classes is not None, net contains the pre-softmax
173       activations.
174     end_points: A dictionary from components of the network to the corresponding
175       activation.
176
177   Raises:
178     ValueError: If the target output_stride is not valid.
179   """
180   with tf.variable_scope(scope, 'resnet_v2', [inputs], reuse=reuse) as sc:
181     end_points_collection = sc.name + '_end_points'
182     with slim.arg_scope([slim.conv2d, bottleneck,
183                          resnet_utils.stack_blocks_dense],
184                         outputs_collections=end_points_collection):
185       with slim.arg_scope([slim.batch_norm], is_training=is_training):
186         net = inputs
187         if include_root_block:
188           if output_stride is not None:
189             if output_stride % 4 != 0:
190               raise ValueError('The output_stride needs to be a multiple of 4.')
191             output_stride /= 4
192           # We do not include batch normalization or activation functions in
193           # conv1 because the first ResNet unit will perform these. Cf.
194           # Appendix of [2].
195           with slim.arg_scope([slim.conv2d],
196                               activation_fn=None, normalizer_fn=None):
197             net = resnet_utils.conv2d_same(net, 64, 7, stride=2, scope='conv1')
198           net = slim.max_pool2d(net, [3, 3], stride=2, scope='pool1')
199         net = resnet_utils.stack_blocks_dense(net, blocks, output_stride)
200         # This is needed because the pre-activation variant does not have batch
201         # normalization or activation functions in the residual unit output. See
202         # Appendix of [2].
203         net = slim.batch_norm(net, activation_fn=tf.nn.relu, scope='postnorm')
204         if global_pool:
205           # Global average pooling.
206           net = tf.reduce_mean(net, [1, 2], name='pool5', keep_dims=True)
207         if num_classes is not None:
208           net = slim.conv2d(net, num_classes, [1, 1], activation_fn=None,
209                             normalizer_fn=None, scope='logits')
210         if spatial_squeeze:
211           logits = tf.squeeze(net, [1, 2], name='SpatialSqueeze')
212         else:
213           logits = net
214         # Convert end_points_collection into a dictionary of end_points.
215         end_points = slim.utils.convert_collection_to_dict(
216             end_points_collection)
217         if num_classes is not None:
218           end_points['predictions'] = slim.softmax(logits, scope='predictions')
219         return logits, end_points
220 resnet_v2.default_image_size = 224
221
222
223 def resnet_v2_block(scope, base_depth, num_units, stride):
224   """Helper function for creating a resnet_v2 bottleneck block.
225
226   Args:
227     scope: The scope of the block.
228     base_depth: The depth of the bottleneck layer for each unit.
229     num_units: The number of units in the block.
230     stride: The stride of the block, implemented as a stride in the last unit.
231       All other units have stride=1.
232
233   Returns:
234     A resnet_v2 bottleneck block.
235   """
236   return resnet_utils.Block(scope, bottleneck, [{
237       'depth': base_depth * 4,
238       'depth_bottleneck': base_depth,
239       'stride': 1
240   }] * (num_units - 1) + [{
241       'depth': base_depth * 4,
242       'depth_bottleneck': base_depth,
243       'stride': stride
244   }])
245 resnet_v2.default_image_size = 224
246
247
248 def resnet_v2_50(inputs,
249                  num_classes=None,
250                  is_training=True,
251                  global_pool=True,
252                  output_stride=None,
253                  spatial_squeeze=False,
254                  reuse=None,
255                  scope='resnet_v2_50'):
256   """ResNet-50 model of [1]. See resnet_v2() for arg and return description."""
257   blocks = [
258       resnet_v2_block('block1', base_depth=64, num_units=3, stride=2),
259       resnet_v2_block('block2', base_depth=128, num_units=4, stride=2),
260       resnet_v2_block('block3', base_depth=256, num_units=6, stride=2),
261       resnet_v2_block('block4', base_depth=512, num_units=3, stride=1),
262   ]
263   return resnet_v2(inputs, blocks, num_classes, is_training=is_training,
264                    global_pool=global_pool, output_stride=output_stride,
265                    include_root_block=True, spatial_squeeze=spatial_squeeze,
266                    reuse=reuse, scope=scope)
267 resnet_v2_50.default_image_size = resnet_v2.default_image_size
268
269
270 def resnet_v2_101(inputs,
271                   num_classes=None,
272                   is_training=True,
273                   global_pool=True,
274                   output_stride=None,
275                   spatial_squeeze=False,
276                   reuse=None,
277                   scope='resnet_v2_101'):
278   """ResNet-101 model of [1]. See resnet_v2() for arg and return description."""
279   blocks = [
280       resnet_v2_block('block1', base_depth=64, num_units=3, stride=2),
281       resnet_v2_block('block2', base_depth=128, num_units=4, stride=2),
282       resnet_v2_block('block3', base_depth=256, num_units=23, stride=2),
283       resnet_v2_block('block4', base_depth=512, num_units=3, stride=1),
284   ]
285   return resnet_v2(inputs, blocks, num_classes, is_training=is_training,
286                    global_pool=global_pool, output_stride=output_stride,
287                    include_root_block=True, spatial_squeeze=spatial_squeeze,
288                    reuse=reuse, scope=scope)
289 resnet_v2_101.default_image_size = resnet_v2.default_image_size
290
291
292 def resnet_v2_152(inputs,
293                   num_classes=None,
294                   is_training=True,
295                   global_pool=True,
296                   output_stride=None,
297                   spatial_squeeze=False,
298                   reuse=None,
299                   scope='resnet_v2_152'):
300   """ResNet-152 model of [1]. See resnet_v2() for arg and return description."""
301   blocks = [
302       resnet_v2_block('block1', base_depth=64, num_units=3, stride=2),
303       resnet_v2_block('block2', base_depth=128, num_units=8, stride=2),
304       resnet_v2_block('block3', base_depth=256, num_units=36, stride=2),
305       resnet_v2_block('block4', base_depth=512, num_units=3, stride=1),
306   ]
307   return resnet_v2(inputs, blocks, num_classes, is_training=is_training,
308                    global_pool=global_pool, output_stride=output_stride,
309                    include_root_block=True, spatial_squeeze=spatial_squeeze,
310                    reuse=reuse, scope=scope)
311 resnet_v2_152.default_image_size = resnet_v2.default_image_size
312
313
314 def resnet_v2_200(inputs,
315                   num_classes=None,
316                   is_training=True,
317                   global_pool=True,
318                   output_stride=None,
319                   spatial_squeeze=False,
320                   reuse=None,
321                   scope='resnet_v2_200'):
322   """ResNet-200 model of [2]. See resnet_v2() for arg and return description."""
323   blocks = [
324       resnet_v2_block('block1', base_depth=64, num_units=3, stride=2),
325       resnet_v2_block('block2', base_depth=128, num_units=24, stride=2),
326       resnet_v2_block('block3', base_depth=256, num_units=36, stride=2),
327       resnet_v2_block('block4', base_depth=512, num_units=3, stride=1),
328   ]
329   return resnet_v2(inputs, blocks, num_classes, is_training=is_training,
330                    global_pool=global_pool, output_stride=output_stride,
331                    include_root_block=True, spatial_squeeze=spatial_squeeze,
332                    reuse=reuse, scope=scope)
333 resnet_v2_200.default_image_size = resnet_v2.default_image_size