1 # Copyright 2016 The TensorFlow Authors. All Rights Reserved.
3 # Licensed under the Apache License, Version 2.0 (the "License");
4 # you may not use this file except in compliance with the License.
5 # You may obtain a copy of the License at
7 # http://www.apache.org/licenses/LICENSE-2.0
9 # Unless required by applicable law or agreed to in writing, software
10 # distributed under the License is distributed on an "AS IS" BASIS,
11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 # See the License for the specific language governing permissions and
13 # limitations under the License.
14 # ==============================================================================
"""Contains definitions for the preactivation form of Residual Networks.

Residual networks (ResNets) were originally proposed in:
[1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun
    Deep Residual Learning for Image Recognition. arXiv:1512.03385

The full preactivation 'v2' ResNet variant implemented in this module was
introduced in:
[2] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun
    Identity Mappings in Deep Residual Networks. arXiv:1603.05027

The key difference of the full preactivation 'v2' variant compared to the
'v1' variant in [1] is the use of batch normalization before every weight layer.

Typical use:

   from nets import resnet_v2

ResNet-101 for image classification into 1000 classes:

   # inputs has shape [batch, 224, 224, 3]
   with slim.arg_scope(resnet_v2.resnet_arg_scope()):
      net, end_points = resnet_v2.resnet_v2_101(inputs, 1000, is_training=False)

ResNet-101 for semantic segmentation into 21 classes:

   # inputs has shape [batch, 513, 513, 3]
   with slim.arg_scope(resnet_v2.resnet_arg_scope()):
      net, end_points = resnet_v2.resnet_v2_101(inputs,
                                                21,
                                                is_training=is_training,
                                                global_pool=False,
                                                output_stride=16)
"""
49 from __future__ import absolute_import
50 from __future__ import division
51 from __future__ import print_function
53 import tensorflow as tf
55 from nets import resnet_utils
# Short module-level alias for TF-Slim, used by every definition below.
slim = tf.contrib.slim
# Re-exported so callers can write `resnet_v2.resnet_arg_scope()` as in the
# usage examples of the module docstring.
resnet_arg_scope = resnet_utils.resnet_arg_scope
def bottleneck(inputs, depth, depth_bottleneck, stride, rate=1,
               outputs_collections=None, scope=None):
  """Bottleneck residual unit variant with BN before convolutions.

  This is the full preactivation residual unit variant proposed in [2]. See
  Fig. 1(b) of [2] for its definition. Note that we use here the bottleneck
  variant which has an extra bottleneck layer.

  When putting together two consecutive ResNet blocks that use this unit, one
  should use stride = 2 in the last unit of the first block.

  Args:
    inputs: A tensor of size [batch, height, width, channels].
    depth: The depth of the ResNet unit output.
    depth_bottleneck: The depth of the bottleneck layers.
    stride: The ResNet unit's stride. Determines the amount of downsampling of
      the unit's output compared to its input.
    rate: An integer, rate for atrous convolution.
    outputs_collections: Collection to add the ResNet unit output.
    scope: Optional variable_scope.

  Returns:
    The ResNet unit's output.
  """
  with tf.variable_scope(scope, 'bottleneck_v2', [inputs]) as sc:
    depth_in = slim.utils.last_dimension(inputs.get_shape(), min_rank=4)
    # Pre-activation (BN + ReLU), shared by the projection shortcut and the
    # residual branch.
    preact = slim.batch_norm(inputs, activation_fn=tf.nn.relu, scope='preact')
    if depth == depth_in:
      # Channel counts already match: identity shortcut on the raw inputs,
      # passed through resnet_utils.subsample with this unit's stride
      # (presumably a no-op when stride == 1 -- see resnet_utils).
      shortcut = resnet_utils.subsample(inputs, stride, 'shortcut')
    else:
      # Channel counts differ: 1x1 strided projection of the pre-activation.
      shortcut = slim.conv2d(preact, depth, [1, 1], stride=stride,
                             normalizer_fn=None, activation_fn=None,
                             scope='shortcut')

    # 1x1 reduce -> 3x3 (strided, optionally atrous) -> 1x1 expand.
    residual = slim.conv2d(preact, depth_bottleneck, [1, 1], stride=1,
                           scope='conv1')
    residual = resnet_utils.conv2d_same(residual, depth_bottleneck, 3, stride,
                                        rate=rate, scope='conv2')
    # No normalizer/activation on the last conv: in the pre-activation design
    # the unit output is left un-normalized.
    residual = slim.conv2d(residual, depth, [1, 1], stride=1,
                           normalizer_fn=None, activation_fn=None,
                           scope='conv3')

    output = shortcut + residual

    return slim.utils.collect_named_outputs(outputs_collections,
                                            sc.original_name_scope,
                                            output)
def resnet_v2(inputs,
              blocks,
              num_classes=None,
              is_training=True,
              global_pool=True,
              output_stride=None,
              include_root_block=True,
              spatial_squeeze=False,
              reuse=None,
              scope=None):
  """Generator for v2 (preactivation) ResNet models.

  This function generates a family of ResNet v2 models. See the resnet_v2_*()
  methods for specific model instantiations, obtained by selecting different
  block instantiations that produce ResNets of various depths.

  Training for image classification on Imagenet is usually done with [224, 224]
  inputs, resulting in [7, 7] feature maps at the output of the last ResNet
  block for the ResNets defined in [1] that have nominal stride equal to 32.
  However, for dense prediction tasks we advise that one uses inputs with
  spatial dimensions that are multiples of 32 plus 1, e.g., [321, 321]. In
  this case the feature maps at the ResNet output will have spatial shape
  [(height - 1) / output_stride + 1, (width - 1) / output_stride + 1]
  and corners exactly aligned with the input image corners, which greatly
  facilitates alignment of the features to the image. Using as input [225, 225]
  images results in [8, 8] feature maps at the output of the last ResNet block.

  For dense prediction tasks, the ResNet needs to run in fully-convolutional
  (FCN) mode and global_pool needs to be set to False. The ResNets in [1, 2] all
  have nominal stride equal to 32 and a good choice in FCN mode is to use
  output_stride=16 in order to increase the density of the computed features at
  small computational and memory overhead, cf. http://arxiv.org/abs/1606.00915.

  Args:
    inputs: A tensor of size [batch, height_in, width_in, channels].
    blocks: A list of length equal to the number of ResNet blocks. Each element
      is a resnet_utils.Block object describing the units in the block.
    num_classes: Number of predicted classes for classification tasks. If None
      we return the features before the logit layer.
    is_training: Whether the batch_norm layers run in training mode.
    global_pool: If True, we perform global average pooling before computing the
      logits. Set to True for image classification, False for dense prediction.
    output_stride: If None, then the output will be computed at the nominal
      network stride. If output_stride is not None, it specifies the requested
      ratio of input to output spatial resolution.
    include_root_block: If True, include the initial convolution followed by
      max-pooling, if False excludes it. If excluded, `inputs` should be the
      results of an activation-less convolution.
    spatial_squeeze: If True and num_classes is not None, logits are of shape
      [B, C]; otherwise they are of shape [B, 1, 1, C], where B is batch_size
      and C is number of classes.
    reuse: whether or not the network and its variables should be reused. To be
      able to reuse 'scope' must be given.
    scope: Optional variable_scope.

  Returns:
    net: A rank-4 tensor of size [batch, height_out, width_out, channels_out].
      If global_pool is False, then height_out and width_out are reduced by a
      factor of output_stride compared to the respective height_in and width_in,
      else both height_out and width_out equal one. If num_classes is None, then
      net is the output of the last ResNet block, potentially after global
      average pooling. If num_classes is not None, net contains the pre-softmax
      activations (squeezed to [batch, num_classes] when spatial_squeeze is
      True).
    end_points: A dictionary from components of the network to the corresponding
      activation.

  Raises:
    ValueError: If the target output_stride is not valid.
  """
  with tf.variable_scope(scope, 'resnet_v2', [inputs], reuse=reuse) as sc:
    end_points_collection = sc.name + '_end_points'
    with slim.arg_scope([slim.conv2d, bottleneck,
                         resnet_utils.stack_blocks_dense],
                        outputs_collections=end_points_collection):
      with slim.arg_scope([slim.batch_norm], is_training=is_training):
        net = inputs
        if include_root_block:
          if output_stride is not None:
            if output_stride % 4 != 0:
              raise ValueError('The output_stride needs to be a multiple of 4.')
            # conv1 (stride 2) and pool1 (stride 2) below already downsample
            # by 4, so the ResNet blocks only supply the remaining factor.
            # Floor division keeps the value integral under true division.
            output_stride //= 4
          # We do not include batch normalization or activation functions in
          # conv1 because the first ResNet unit will perform these. Cf.
          # Appendix of [2].
          with slim.arg_scope([slim.conv2d],
                              activation_fn=None, normalizer_fn=None):
            net = resnet_utils.conv2d_same(net, 64, 7, stride=2, scope='conv1')
          net = slim.max_pool2d(net, [3, 3], stride=2, scope='pool1')
        net = resnet_utils.stack_blocks_dense(net, blocks, output_stride)
        # This is needed because the pre-activation variant does not have batch
        # normalization or activation functions in the residual unit output. See
        # Appendix of [2].
        net = slim.batch_norm(net, activation_fn=tf.nn.relu, scope='postnorm')
        if global_pool:
          # Global average pooling.
          net = tf.reduce_mean(net, [1, 2], name='pool5', keep_dims=True)
        if num_classes is not None:
          net = slim.conv2d(net, num_classes, [1, 1], activation_fn=None,
                            normalizer_fn=None, scope='logits')
          if spatial_squeeze:
            net = tf.squeeze(net, [1, 2], name='SpatialSqueeze')
        # Convert end_points_collection into a dictionary of end_points.
        end_points = slim.utils.convert_collection_to_dict(
            end_points_collection)
        if num_classes is not None:
          end_points['predictions'] = slim.softmax(net, scope='predictions')
        # `net` holds the logits when num_classes is given and the features
        # otherwise; returning it in both cases avoids referencing an unbound
        # `logits` variable when num_classes is None.
        return net, end_points
resnet_v2.default_image_size = 224
def resnet_v2_block(scope, base_depth, num_units, stride):
  """Helper function for creating a resnet_v2 bottleneck block.

  Args:
    scope: The scope of the block.
    base_depth: The depth of the bottleneck layer for each unit. Every unit
      outputs 4 * base_depth channels.
    num_units: The number of units in the block.
    stride: The stride of the block, implemented as a stride in the last unit.
      All other units have stride=1.

  Returns:
    A resnet_v2 bottleneck block.
  """
  # Build a fresh dict per unit rather than repeating one dict object, so the
  # unit specs never alias each other.
  units = [{
      'depth': base_depth * 4,
      'depth_bottleneck': base_depth,
      'stride': 1
  } for _ in range(num_units - 1)]
  # Only the final unit carries the block's stride.
  units.append({
      'depth': base_depth * 4,
      'depth_bottleneck': base_depth,
      'stride': stride
  })
  return resnet_utils.Block(scope, bottleneck, units)
def resnet_v2_50(inputs,
                 num_classes=None,
                 is_training=True,
                 global_pool=True,
                 output_stride=None,
                 spatial_squeeze=False,
                 reuse=None,
                 scope='resnet_v2_50'):
  """ResNet-50 model of [1]. See resnet_v2() for arg and return description."""
  # Block strides 2, 2, 2, 1 combined with the stride-4 root block give the
  # nominal output stride of 32.
  blocks = [
      resnet_v2_block('block1', base_depth=64, num_units=3, stride=2),
      resnet_v2_block('block2', base_depth=128, num_units=4, stride=2),
      resnet_v2_block('block3', base_depth=256, num_units=6, stride=2),
      resnet_v2_block('block4', base_depth=512, num_units=3, stride=1),
  ]
  return resnet_v2(inputs, blocks, num_classes, is_training=is_training,
                   global_pool=global_pool, output_stride=output_stride,
                   include_root_block=True, spatial_squeeze=spatial_squeeze,
                   reuse=reuse, scope=scope)
resnet_v2_50.default_image_size = resnet_v2.default_image_size
def resnet_v2_101(inputs,
                  num_classes=None,
                  is_training=True,
                  global_pool=True,
                  output_stride=None,
                  spatial_squeeze=False,
                  reuse=None,
                  scope='resnet_v2_101'):
  """ResNet-101 model of [1]. See resnet_v2() for arg and return description."""
  # Same layout as ResNet-50 except for 23 units in block3.
  blocks = [
      resnet_v2_block('block1', base_depth=64, num_units=3, stride=2),
      resnet_v2_block('block2', base_depth=128, num_units=4, stride=2),
      resnet_v2_block('block3', base_depth=256, num_units=23, stride=2),
      resnet_v2_block('block4', base_depth=512, num_units=3, stride=1),
  ]
  return resnet_v2(inputs, blocks, num_classes, is_training=is_training,
                   global_pool=global_pool, output_stride=output_stride,
                   include_root_block=True, spatial_squeeze=spatial_squeeze,
                   reuse=reuse, scope=scope)
resnet_v2_101.default_image_size = resnet_v2.default_image_size
def resnet_v2_152(inputs,
                  num_classes=None,
                  is_training=True,
                  global_pool=True,
                  output_stride=None,
                  spatial_squeeze=False,
                  reuse=None,
                  scope='resnet_v2_152'):
  """ResNet-152 model of [1]. See resnet_v2() for arg and return description."""
  # Deeper block2 (8 units) and block3 (36 units) than ResNet-101.
  blocks = [
      resnet_v2_block('block1', base_depth=64, num_units=3, stride=2),
      resnet_v2_block('block2', base_depth=128, num_units=8, stride=2),
      resnet_v2_block('block3', base_depth=256, num_units=36, stride=2),
      resnet_v2_block('block4', base_depth=512, num_units=3, stride=1),
  ]
  return resnet_v2(inputs, blocks, num_classes, is_training=is_training,
                   global_pool=global_pool, output_stride=output_stride,
                   include_root_block=True, spatial_squeeze=spatial_squeeze,
                   reuse=reuse, scope=scope)
resnet_v2_152.default_image_size = resnet_v2.default_image_size
def resnet_v2_200(inputs,
                  num_classes=None,
                  is_training=True,
                  global_pool=True,
                  output_stride=None,
                  spatial_squeeze=False,
                  reuse=None,
                  scope='resnet_v2_200'):
  """ResNet-200 model of [2]. See resnet_v2() for arg and return description."""
  # Same as ResNet-152 except for 24 units in block2.
  blocks = [
      resnet_v2_block('block1', base_depth=64, num_units=3, stride=2),
      resnet_v2_block('block2', base_depth=128, num_units=24, stride=2),
      resnet_v2_block('block3', base_depth=256, num_units=36, stride=2),
      resnet_v2_block('block4', base_depth=512, num_units=3, stride=1),
  ]
  return resnet_v2(inputs, blocks, num_classes, is_training=is_training,
                   global_pool=global_pool, output_stride=output_stride,
                   include_root_block=True, spatial_squeeze=spatial_squeeze,
                   reuse=reuse, scope=scope)
resnet_v2_200.default_image_size = resnet_v2.default_image_size