1 # Copyright 2016 The TensorFlow Authors. All Rights Reserved.
3 # Licensed under the Apache License, Version 2.0 (the "License");
4 # you may not use this file except in compliance with the License.
5 # You may obtain a copy of the License at
7 # http://www.apache.org/licenses/LICENSE-2.0
9 # Unless required by applicable law or agreed to in writing, software
10 # distributed under the License is distributed on an "AS IS" BASIS,
11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 # See the License for the specific language governing permissions and
13 # limitations under the License.
14 # ==============================================================================
15 """Contains building blocks for various versions of Residual Networks.
17 Residual networks (ResNets) were proposed in:
18 Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun
19 Deep Residual Learning for Image Recognition. arXiv:1512.03385, 2015
21 More variants were introduced in:
22 Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun
23 Identity Mappings in Deep Residual Networks. arXiv: 1603.05027, 2016
25 We can obtain different ResNet variants by changing the network depth, width,
26 and form of residual unit. This module implements the infrastructure for
27 building them. Concrete ResNet units and full ResNet networks are implemented in
28 the accompanying resnet_v1.py and resnet_v2.py modules.
30 Compared to https://github.com/KaimingHe/deep-residual-networks, in the current
31 implementation we subsample the output activations in the last residual unit of
32 each block, instead of subsampling the input activations in the first residual
33 unit of each block. The two implementations give identical results but our
34 implementation is more memory efficient.
36 from __future__ import absolute_import
37 from __future__ import division
38 from __future__ import print_function
41 import tensorflow as tf
43 slim = tf.contrib.slim
class Block(collections.namedtuple('Block', ['scope', 'unit_fn', 'args'])):
  """A named tuple describing a ResNet block.

  Its parts are:
    scope: The scope of the `Block`.
    unit_fn: The ResNet unit function which takes as input a `Tensor` and
      returns another `Tensor` with the output of the ResNet unit.
    args: A list of length equal to the number of units in the `Block`. The
      list contains one dictionary of keyword arguments for each unit in the
      block (e.g. `depth`, `depth_bottleneck`, `stride`). Each dictionary is
      unpacked as `**kwargs` into `unit_fn`; a missing `stride` key is treated
      as stride 1 by `stack_blocks_dense`.
  """
def subsample(inputs, factor, scope=None):
  """Subsamples the input along the spatial dimensions.

  Args:
    inputs: A `Tensor` of size [batch, height_in, width_in, channels].
    factor: The subsampling factor.
    scope: Optional variable_scope.

  Returns:
    output: A `Tensor` of size [batch, height_out, width_out, channels] with the
      input, either intact (if factor == 1) or subsampled (if factor > 1).
  """
  if factor == 1:
    # Nothing to subsample: hand the input back unchanged rather than creating
    # a pooling op that would copy it.
    return inputs
  # A 1x1 max-pool with stride `factor` has no pooling window to reduce over,
  # so it simply keeps every `factor`-th spatial position.
  return slim.max_pool2d(inputs, [1, 1], stride=factor, scope=scope)
def conv2d_same(inputs, num_outputs, kernel_size, stride, rate=1, scope=None):
  """Strided 2-D convolution with 'SAME' padding.

  When stride > 1, then we do explicit zero-padding, followed by conv2d with
  'VALID' padding.

  Note that

     net = conv2d_same(inputs, num_outputs, 3, stride=stride)

  is equivalent to

     net = slim.conv2d(inputs, num_outputs, 3, stride=1, padding='SAME')
     net = subsample(net, factor=stride)

  whereas

     net = slim.conv2d(inputs, num_outputs, 3, stride=stride, padding='SAME')

  is different when the input's height or width is even, which is why we add the
  current function. For more details, see ResnetUtilsTest.testConv2DSameEven().

  Args:
    inputs: A 4-D tensor of size [batch, height_in, width_in, channels].
    num_outputs: An integer, the number of output filters.
    kernel_size: An int with the kernel_size of the filters.
    stride: An integer, the output stride.
    rate: An integer, rate for atrous convolution.
    scope: Optional variable_scope.

  Returns:
    output: A 4-D tensor of size [batch, height_out, width_out, channels] with
      the convolution output.
  """
  if stride == 1:
    # Without striding, 'SAME' padding already has the desired alignment.
    return slim.conv2d(inputs, num_outputs, kernel_size, stride=1, rate=rate,
                       padding='SAME', scope=scope)

  # An atrous kernel of size k and rate r spans k + (k - 1) * (r - 1) pixels.
  kernel_size_effective = kernel_size + (kernel_size - 1) * (rate - 1)
  # Pad by a fixed amount that depends only on the kernel, never on the input
  # size; any odd leftover row/column goes at the end.
  pad_total = kernel_size_effective - 1
  pad_beg = pad_total // 2
  pad_end = pad_total - pad_beg
  inputs = tf.pad(inputs,
                  [[0, 0], [pad_beg, pad_end], [pad_beg, pad_end], [0, 0]])
  return slim.conv2d(inputs, num_outputs, kernel_size, stride=stride,
                     rate=rate, padding='VALID', scope=scope)
def stack_blocks_dense(net, blocks, output_stride=None,
                       outputs_collections=None):
  """Stacks ResNet `Blocks` and controls output feature density.

  First, this function creates scopes for the ResNet in the form of
  'block_name/unit_1', 'block_name/unit_2', etc.

  Second, this function allows the user to explicitly control the ResNet
  output_stride, which is the ratio of the input to output spatial resolution.
  This is useful for dense prediction tasks such as semantic segmentation or
  object detection.

  Most ResNets consist of 4 ResNet blocks and subsample the activations by a
  factor of 2 when transitioning between consecutive ResNet blocks. This results
  in a nominal ResNet output_stride equal to 8. If we set the output_stride to
  half the nominal network stride (e.g., output_stride=4), then the final
  responses are computed at twice the spatial resolution.

  Control of the output feature density is implemented by atrous convolution.

  Args:
    net: A `Tensor` of size [batch, height, width, channels].
    blocks: A list of length equal to the number of ResNet `Blocks`. Each
      element is a ResNet `Block` object describing the units in the `Block`.
      Each entry of `Block.args` is a dict of keyword arguments for
      `Block.unit_fn`; its optional 'stride' key defaults to 1.
    output_stride: If `None`, then the output will be computed at the nominal
      network stride. If output_stride is not `None`, it specifies the requested
      ratio of input to output spatial resolution, which needs to be equal to
      the product of unit strides from the start up to some level of the ResNet.
      For example, if the ResNet employs units with strides 1, 2, 1, 3, 4, 1,
      then valid values for the output_stride are 1, 2, 6, 24 or None (which
      is equivalent to output_stride=24).
    outputs_collections: Collection to add the ResNet block outputs.

  Returns:
    net: Output tensor with stride equal to the specified output_stride.

  Raises:
    ValueError: If the target output_stride is not valid.
  """
  # The current_stride variable keeps track of the effective stride of the
  # activations. This allows us to invoke atrous convolution whenever applying
  # the next residual unit would result in the activations having stride larger
  # than the target output_stride.
  current_stride = 1

  # The atrous convolution rate parameter.
  rate = 1

  for block in blocks:
    with tf.variable_scope(block.scope, 'block', [net]) as sc:
      for i, unit in enumerate(block.args):
        if output_stride is not None and current_stride > output_stride:
          raise ValueError('The target output_stride cannot be reached.')

        with tf.variable_scope('unit_%d' % (i + 1), values=[net]):
          unit_stride = unit.get('stride', 1)
          # If we have reached the target output_stride, then we need to employ
          # atrous convolution with stride=1 and multiply the atrous rate by the
          # current unit's stride for use in subsequent layers.
          if output_stride is not None and current_stride == output_stride:
            net = block.unit_fn(net, rate=rate, **dict(unit, stride=1))
            rate *= unit_stride
          else:
            net = block.unit_fn(net, rate=1, **unit)
            current_stride *= unit_stride
      net = slim.utils.collect_named_outputs(outputs_collections, sc.name, net)

  # Also catches a target stride that was overshot by the very last unit.
  if output_stride is not None and current_stride != output_stride:
    raise ValueError('The target output_stride cannot be reached.')

  return net
199 def resnet_arg_scope(weight_decay=0.0001,
200 batch_norm_decay=0.997, #0.997
201 batch_norm_epsilon=1e-5,
202 batch_norm_scale=True):
203 """Defines the default ResNet arg scope.
205 TODO(gpapan): The batch-normalization related default values above are
206 appropriate for use in conjunction with the reference ResNet models
207 released at https://github.com/KaimingHe/deep-residual-networks. When
208 training ResNets from scratch, they might need to be tuned.
211 weight_decay: The weight decay to use for regularizing the model.
212 batch_norm_decay: The moving average decay when estimating layer activation
213 statistics in batch normalization.
214 batch_norm_epsilon: Small constant to prevent division by zero when
215 normalizing activations by their variance in batch normalization.
216 batch_norm_scale: If True, uses an explicit `gamma` multiplier to scale the
217 activations in the batch normalization layer.
220 An `arg_scope` to use for the resnet models.
222 batch_norm_params = {
223 'decay': batch_norm_decay,
224 'epsilon': batch_norm_epsilon,
225 'scale': batch_norm_scale,
226 'updates_collections': tf.GraphKeys.UPDATE_OPS,
231 weights_regularizer=slim.l2_regularizer(weight_decay),
232 weights_initializer=slim.variance_scaling_initializer(),
233 activation_fn=tf.nn.relu,
234 normalizer_fn=slim.batch_norm,
235 normalizer_params=batch_norm_params):
236 with slim.arg_scope([slim.batch_norm], **batch_norm_params):
237 # The following implies padding='SAME' for pool1, which makes feature
238 # alignment easier for dense prediction tasks. This is also used in
239 # https://github.com/facebook/fb.resnet.torch. However the accompanying
240 # code of 'Deep Residual Learning for Image Recognition' uses
241 # padding='VALID' for pool1. You can switch to that choice by setting
242 # slim.arg_scope([slim.max_pool2d], padding='VALID').
243 with slim.arg_scope([slim.max_pool2d], padding='SAME') as arg_sc: