# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Contains the definition of the Inception V4 architecture.

As described in http://arxiv.org/abs/1602.07261.

  Inception-v4, Inception-ResNet and the Impact of Residual Connections
    on Learning
  Christian Szegedy, Sergey Ioffe, Vincent Vanhoucke, Alex Alemi
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow as tf

from nets import inception_utils

slim = tf.contrib.slim


def block_inception_a(inputs, scope=None, reuse=None):
  """Builds Inception-A block for Inception v4 network."""
  # By default use stride=1 and SAME padding
  with slim.arg_scope([slim.conv2d, slim.avg_pool2d, slim.max_pool2d],
                      stride=1, padding='SAME'):
    with tf.variable_scope(scope, 'BlockInceptionA', [inputs], reuse=reuse):
      with tf.variable_scope('Branch_0'):
        branch_0 = slim.conv2d(inputs, 96, [1, 1], scope='Conv2d_0a_1x1')
      with tf.variable_scope('Branch_1'):
        branch_1 = slim.conv2d(inputs, 64, [1, 1], scope='Conv2d_0a_1x1')
        branch_1 = slim.conv2d(branch_1, 96, [3, 3], scope='Conv2d_0b_3x3')
      with tf.variable_scope('Branch_2'):
        branch_2 = slim.conv2d(inputs, 64, [1, 1], scope='Conv2d_0a_1x1')
        branch_2 = slim.conv2d(branch_2, 96, [3, 3], scope='Conv2d_0b_3x3')
        branch_2 = slim.conv2d(branch_2, 96, [3, 3], scope='Conv2d_0c_3x3')
      with tf.variable_scope('Branch_3'):
        branch_3 = slim.avg_pool2d(inputs, [3, 3], scope='AvgPool_0a_3x3')
        branch_3 = slim.conv2d(branch_3, 96, [1, 1], scope='Conv2d_0b_1x1')
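      # The four branches are depth-concatenated below; at 96 channels each,
      # the block emits 4 * 96 = 384 channels at the input spatial resolution.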
      return tf.concat(axis=3, values=[branch_0, branch_1, branch_2, branch_3])


def block_reduction_a(inputs, scope=None, reuse=None):
  """Builds Reduction-A block for Inception v4 network."""
  # By default use stride=1 and SAME padding
  with slim.arg_scope([slim.conv2d, slim.avg_pool2d, slim.max_pool2d],
                      stride=1, padding='SAME'):
    with tf.variable_scope(scope, 'BlockReductionA', [inputs], reuse=reuse):
      with tf.variable_scope('Branch_0'):
        branch_0 = slim.conv2d(inputs, 384, [3, 3], stride=2, padding='VALID',
                               scope='Conv2d_1a_3x3')
      with tf.variable_scope('Branch_1'):
        branch_1 = slim.conv2d(inputs, 192, [1, 1], scope='Conv2d_0a_1x1')
        branch_1 = slim.conv2d(branch_1, 224, [3, 3], scope='Conv2d_0b_3x3')
        branch_1 = slim.conv2d(branch_1, 256, [3, 3], stride=2,
                               padding='VALID', scope='Conv2d_1a_3x3')
      with tf.variable_scope('Branch_2'):
        branch_2 = slim.max_pool2d(inputs, [3, 3], stride=2, padding='VALID',
                                   scope='MaxPool_1a_3x3')
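      # All three branches use stride 2, halving the spatial grid; the output
      # depth is 384 + 256 + (input depth), i.e. 1024 for the 384-channel
      # Inception-A grid.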
      return tf.concat(axis=3, values=[branch_0, branch_1, branch_2])


def block_inception_b(inputs, scope=None, reuse=None):
  """Builds Inception-B block for Inception v4 network."""
  # By default use stride=1 and SAME padding
  with slim.arg_scope([slim.conv2d, slim.avg_pool2d, slim.max_pool2d],
                      stride=1, padding='SAME'):
    with tf.variable_scope(scope, 'BlockInceptionB', [inputs], reuse=reuse):
      with tf.variable_scope('Branch_0'):
        branch_0 = slim.conv2d(inputs, 384, [1, 1], scope='Conv2d_0a_1x1')
      with tf.variable_scope('Branch_1'):
        branch_1 = slim.conv2d(inputs, 192, [1, 1], scope='Conv2d_0a_1x1')
        branch_1 = slim.conv2d(branch_1, 224, [1, 7], scope='Conv2d_0b_1x7')
        branch_1 = slim.conv2d(branch_1, 256, [7, 1], scope='Conv2d_0c_7x1')
      with tf.variable_scope('Branch_2'):
        branch_2 = slim.conv2d(inputs, 192, [1, 1], scope='Conv2d_0a_1x1')
        branch_2 = slim.conv2d(branch_2, 192, [7, 1], scope='Conv2d_0b_7x1')
        branch_2 = slim.conv2d(branch_2, 224, [1, 7], scope='Conv2d_0c_1x7')
        branch_2 = slim.conv2d(branch_2, 224, [7, 1], scope='Conv2d_0d_7x1')
        branch_2 = slim.conv2d(branch_2, 256, [1, 7], scope='Conv2d_0e_1x7')
      with tf.variable_scope('Branch_3'):
        branch_3 = slim.avg_pool2d(inputs, [3, 3], scope='AvgPool_0a_3x3')
        branch_3 = slim.conv2d(branch_3, 128, [1, 1], scope='Conv2d_0b_1x1')
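      # Depth concat of 384 + 256 + 256 + 128 = 1024 output channels; the
      # 1x7/7x1 pairs are factorized stand-ins for full 7x7 convolutions.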
      return tf.concat(axis=3, values=[branch_0, branch_1, branch_2, branch_3])


def block_reduction_b(inputs, scope=None, reuse=None):
  """Builds Reduction-B block for Inception v4 network."""
  # By default use stride=1 and SAME padding
  with slim.arg_scope([slim.conv2d, slim.avg_pool2d, slim.max_pool2d],
                      stride=1, padding='SAME'):
    with tf.variable_scope(scope, 'BlockReductionB', [inputs], reuse=reuse):
      with tf.variable_scope('Branch_0'):
        branch_0 = slim.conv2d(inputs, 192, [1, 1], scope='Conv2d_0a_1x1')
        branch_0 = slim.conv2d(branch_0, 192, [3, 3], stride=2,
                               padding='VALID', scope='Conv2d_1a_3x3')
      with tf.variable_scope('Branch_1'):
        branch_1 = slim.conv2d(inputs, 256, [1, 1], scope='Conv2d_0a_1x1')
        branch_1 = slim.conv2d(branch_1, 256, [1, 7], scope='Conv2d_0b_1x7')
        branch_1 = slim.conv2d(branch_1, 320, [7, 1], scope='Conv2d_0c_7x1')
        branch_1 = slim.conv2d(branch_1, 320, [3, 3], stride=2,
                               padding='VALID', scope='Conv2d_1a_3x3')
      with tf.variable_scope('Branch_2'):
        branch_2 = slim.max_pool2d(inputs, [3, 3], stride=2, padding='VALID',
                                   scope='MaxPool_1a_3x3')
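      # All branches stride by 2; the output depth is 192 + 320 + (input
      # depth), i.e. 1536 for the 1024-channel Inception-B grid.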
      return tf.concat(axis=3, values=[branch_0, branch_1, branch_2])


def block_inception_c(inputs, scope=None, reuse=None):
  """Builds Inception-C block for Inception v4 network."""
  # By default use stride=1 and SAME padding
  with slim.arg_scope([slim.conv2d, slim.avg_pool2d, slim.max_pool2d],
                      stride=1, padding='SAME'):
    with tf.variable_scope(scope, 'BlockInceptionC', [inputs], reuse=reuse):
      with tf.variable_scope('Branch_0'):
        branch_0 = slim.conv2d(inputs, 256, [1, 1], scope='Conv2d_0a_1x1')
      with tf.variable_scope('Branch_1'):
        branch_1 = slim.conv2d(inputs, 384, [1, 1], scope='Conv2d_0a_1x1')
        branch_1 = tf.concat(axis=3, values=[
            slim.conv2d(branch_1, 256, [1, 3], scope='Conv2d_0b_1x3'),
            slim.conv2d(branch_1, 256, [3, 1], scope='Conv2d_0c_3x1')])
      with tf.variable_scope('Branch_2'):
        branch_2 = slim.conv2d(inputs, 384, [1, 1], scope='Conv2d_0a_1x1')
        branch_2 = slim.conv2d(branch_2, 448, [3, 1], scope='Conv2d_0b_3x1')
        branch_2 = slim.conv2d(branch_2, 512, [1, 3], scope='Conv2d_0c_1x3')
        branch_2 = tf.concat(axis=3, values=[
            slim.conv2d(branch_2, 256, [1, 3], scope='Conv2d_0d_1x3'),
            slim.conv2d(branch_2, 256, [3, 1], scope='Conv2d_0e_3x1')])
      with tf.variable_scope('Branch_3'):
        branch_3 = slim.avg_pool2d(inputs, [3, 3], scope='AvgPool_0a_3x3')
        branch_3 = slim.conv2d(branch_3, 256, [1, 1], scope='Conv2d_0b_1x1')
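      # Depth concat of 256 + (256 + 256) + (256 + 256) + 256 = 1536 output
      # channels; Branch_1 and Branch_2 each fan out into parallel 1x3 and
      # 3x1 convolutions whose outputs are concatenated.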
      return tf.concat(axis=3, values=[branch_0, branch_1, branch_2, branch_3])


def inception_v4_base(inputs, final_endpoint='Mixed_7d', scope=None):
  """Creates the Inception V4 network up to the given final endpoint.

  Args:
    inputs: a 4-D tensor of size [batch_size, height, width, 3].
    final_endpoint: specifies the endpoint to construct the network up to.
      It can be one of [ 'Conv2d_1a_3x3', 'Conv2d_2a_3x3', 'Conv2d_2b_3x3',
      'Mixed_3a', 'Mixed_4a', 'Mixed_5a', 'Mixed_5b', 'Mixed_5c', 'Mixed_5d',
      'Mixed_5e', 'Mixed_6a', 'Mixed_6b', 'Mixed_6c', 'Mixed_6d', 'Mixed_6e',
      'Mixed_6f', 'Mixed_6g', 'Mixed_6h', 'Mixed_7a', 'Mixed_7b', 'Mixed_7c',
      'Mixed_7d']
    scope: Optional variable_scope.

  Returns:
    logits: the logits outputs of the model.
    end_points: the set of end_points from the inception model.

  Raises:
    ValueError: if final_endpoint is not set to one of the predefined values.
  """
  end_points = {}

  def add_and_check_final(name, net):
    end_points[name] = net
    return name == final_endpoint

  with tf.variable_scope(scope, 'InceptionV4', [inputs]):
    with slim.arg_scope([slim.conv2d, slim.max_pool2d, slim.avg_pool2d],
                        stride=1, padding='SAME'):
      # 299 x 299 x 3
      net = slim.conv2d(inputs, 32, [3, 3], stride=2,
                        padding='VALID', scope='Conv2d_1a_3x3')
      if add_and_check_final('Conv2d_1a_3x3', net): return net, end_points
      # 149 x 149 x 32
      net = slim.conv2d(net, 32, [3, 3], padding='VALID',
                        scope='Conv2d_2a_3x3')
      if add_and_check_final('Conv2d_2a_3x3', net): return net, end_points
      # 147 x 147 x 32
      net = slim.conv2d(net, 64, [3, 3], scope='Conv2d_2b_3x3')
      if add_and_check_final('Conv2d_2b_3x3', net): return net, end_points
      # 147 x 147 x 64
      with tf.variable_scope('Mixed_3a'):
        with tf.variable_scope('Branch_0'):
          branch_0 = slim.max_pool2d(net, [3, 3], stride=2, padding='VALID',
                                     scope='MaxPool_0a_3x3')
        with tf.variable_scope('Branch_1'):
          branch_1 = slim.conv2d(net, 96, [3, 3], stride=2, padding='VALID',
                                 scope='Conv2d_0a_3x3')
        net = tf.concat(axis=3, values=[branch_0, branch_1])
        if add_and_check_final('Mixed_3a', net): return net, end_points

      # 73 x 73 x 160
      with tf.variable_scope('Mixed_4a'):
        with tf.variable_scope('Branch_0'):
          branch_0 = slim.conv2d(net, 64, [1, 1], scope='Conv2d_0a_1x1')
          branch_0 = slim.conv2d(branch_0, 96, [3, 3], padding='VALID',
                                 scope='Conv2d_1a_3x3')
        with tf.variable_scope('Branch_1'):
          branch_1 = slim.conv2d(net, 64, [1, 1], scope='Conv2d_0a_1x1')
          branch_1 = slim.conv2d(branch_1, 64, [1, 7], scope='Conv2d_0b_1x7')
          branch_1 = slim.conv2d(branch_1, 64, [7, 1], scope='Conv2d_0c_7x1')
          branch_1 = slim.conv2d(branch_1, 96, [3, 3], padding='VALID',
                                 scope='Conv2d_1a_3x3')
        net = tf.concat(axis=3, values=[branch_0, branch_1])
        if add_and_check_final('Mixed_4a', net): return net, end_points

      # 71 x 71 x 192
      with tf.variable_scope('Mixed_5a'):
        with tf.variable_scope('Branch_0'):
          branch_0 = slim.conv2d(net, 192, [3, 3], stride=2, padding='VALID',
                                 scope='Conv2d_1a_3x3')
        with tf.variable_scope('Branch_1'):
          branch_1 = slim.max_pool2d(net, [3, 3], stride=2, padding='VALID',
                                     scope='MaxPool_1a_3x3')
        net = tf.concat(axis=3, values=[branch_0, branch_1])
        if add_and_check_final('Mixed_5a', net): return net, end_points

      # 35 x 35 x 384
      # 4 x Inception-A blocks
      for idx in range(4):
        block_scope = 'Mixed_5' + chr(ord('b') + idx)
        net = block_inception_a(net, block_scope)
        if add_and_check_final(block_scope, net): return net, end_points

      # 35 x 35 x 384
      # Reduction-A block
      net = block_reduction_a(net, 'Mixed_6a')
      if add_and_check_final('Mixed_6a', net): return net, end_points

      # 17 x 17 x 1024
      # 7 x Inception-B blocks
      for idx in range(7):
        block_scope = 'Mixed_6' + chr(ord('b') + idx)
        net = block_inception_b(net, block_scope)
        if add_and_check_final(block_scope, net): return net, end_points

      # 17 x 17 x 1024
      # Reduction-B block
      net = block_reduction_b(net, 'Mixed_7a')
      if add_and_check_final('Mixed_7a', net): return net, end_points

      # 8 x 8 x 1536
      # 3 x Inception-C blocks
      for idx in range(3):
        block_scope = 'Mixed_7' + chr(ord('b') + idx)
        net = block_inception_c(net, block_scope)
        if add_and_check_final(block_scope, net): return net, end_points
  raise ValueError('Unknown final endpoint %s' % final_endpoint)
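

# Illustrative use of the base (names here are hypothetical): the network can
# be cut at any intermediate endpoint to act as a feature extractor, e.g.
#   features, _ = inception_v4_base(images, final_endpoint='Mixed_6h')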


def inception_v4(inputs, num_classes=1001, is_training=True,
                 dropout_keep_prob=0.8,
                 reuse=None,
                 scope='InceptionV4',
                 create_aux_logits=True):
  """Creates the Inception V4 model.

  Args:
    inputs: a 4-D tensor of size [batch_size, height, width, 3].
    num_classes: number of predicted classes.
    is_training: whether the network is in training mode.
    dropout_keep_prob: float, the fraction to keep before the final layer.
    reuse: whether or not the network and its variables should be reused. To be
      able to reuse, 'scope' must be given.
    scope: Optional variable_scope.
    create_aux_logits: Whether to include the auxiliary logits.

  Returns:
    logits: the logits outputs of the model.
    end_points: the set of end_points from the inception model.
  """
  with tf.variable_scope(scope, 'InceptionV4', [inputs], reuse=reuse) as scope:
    with slim.arg_scope([slim.batch_norm, slim.dropout],
                        is_training=is_training):
      net, end_points = inception_v4_base(inputs, scope=scope)

      with slim.arg_scope([slim.conv2d, slim.max_pool2d, slim.avg_pool2d],
                          stride=1, padding='SAME'):
        # Auxiliary Head logits
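        # The auxiliary head branches off the 17 x 17 Mixed_6h grid and
        # provides an additional gradient signal during training.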
        if create_aux_logits:
          with tf.variable_scope('AuxLogits'):
            # 17 x 17 x 1024
            aux_logits = end_points['Mixed_6h']
            aux_logits = slim.avg_pool2d(aux_logits, [5, 5], stride=3,
                                         padding='VALID',
                                         scope='AvgPool_1a_5x5')
            aux_logits = slim.conv2d(aux_logits, 128, [1, 1],
                                     scope='Conv2d_1b_1x1')
            aux_logits = slim.conv2d(aux_logits, 768,
                                     aux_logits.get_shape()[1:3],
                                     padding='VALID', scope='Conv2d_2a')
            aux_logits = slim.flatten(aux_logits)
            aux_logits = slim.fully_connected(aux_logits, num_classes,
                                              activation_fn=None,
                                              scope='Aux_logits')
            end_points['AuxLogits'] = aux_logits

        # Final pooling and prediction
        with tf.variable_scope('Logits'):
          # 8 x 8 x 1536
          net = slim.avg_pool2d(net, net.get_shape()[1:3], padding='VALID',
                                scope='AvgPool_1a')
          # 1 x 1 x 1536
          net = slim.dropout(net, dropout_keep_prob, scope='Dropout_1b')
          net = slim.flatten(net, scope='PreLogitsFlatten')
          end_points['PreLogitsFlatten'] = net
          # 1536
          logits = slim.fully_connected(net, num_classes, activation_fn=None,
                                        scope='Logits')
          end_points['Logits'] = logits
          end_points['Predictions'] = tf.nn.softmax(logits, name='Predictions')
    return logits, end_points
inception_v4.default_image_size = 299


inception_v4_arg_scope = inception_utils.inception_arg_scope
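

# A minimal usage sketch, assuming a TF 1.x environment where tf.contrib.slim
# is available (as imported above); `images` is a hypothetical input batch:
#
#   images = tf.placeholder(tf.float32, [None, 299, 299, 3])
#   with slim.arg_scope(inception_v4_arg_scope()):
#     logits, end_points = inception_v4(images, num_classes=1001,
#                                       is_training=False)
#   # end_points['Predictions'] holds the softmax class probabilities.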