EG version upgrade to 1.3
[ealt-edge.git] / example-apps / PDD / pcb-defect-detection / libs / networks / slim_nets / inception_v2.py
1 # Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 #
3 # Licensed under the Apache License, Version 2.0 (the "License");
4 # you may not use this file except in compliance with the License.
5 # You may obtain a copy of the License at
6 #
7 # http://www.apache.org/licenses/LICENSE-2.0
8 #
9 # Unless required by applicable law or agreed to in writing, software
10 # distributed under the License is distributed on an "AS IS" BASIS,
11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 # See the License for the specific language governing permissions and
13 # limitations under the License.
14 # ==============================================================================
15 """Contains the definition for inception v2 classification network."""
16
17 from __future__ import absolute_import
18 from __future__ import division
19 from __future__ import print_function
20
21 import tensorflow as tf
22
23 from nets import inception_utils
24
slim = tf.contrib.slim


def trunc_normal(stddev):
  """Returns a zero-mean truncated-normal weight initializer.

  Defined with `def` rather than by binding a lambda to a name (PEP 8), so the
  helper has a real name in tracebacks and can carry documentation. Call sites
  are unchanged, e.g. `trunc_normal(0.09)`.

  Args:
    stddev: standard deviation passed to `tf.truncated_normal_initializer`;
      the mean is fixed at 0.0.

  Returns:
    The result of `tf.truncated_normal_initializer(0.0, stddev)`.
  """
  return tf.truncated_normal_initializer(0.0, stddev)
27
28
def inception_v2_base(inputs,
                      final_endpoint='Mixed_5c',
                      min_depth=16,
                      depth_multiplier=1.0,
                      scope=None):
  """Inception v2 (6a2).

  Constructs an Inception v2 network from inputs to the given final endpoint.
  This method can construct the network up to the layer inception(5b) as
  described in http://arxiv.org/abs/1502.03167; the last endpoint built here
  is 'Mixed_5c'.

  Both `depth_multiplier` and `final_endpoint` are validated before any op or
  variable is created, so an invalid request raises without leaving a
  partially (or fully) built network behind in the graph.

  Args:
    inputs: a tensor of shape [batch_size, height, width, channels].
    final_endpoint: specifies the endpoint to construct the network up to. It
      can be one of ['Conv2d_1a_7x7', 'MaxPool_2a_3x3', 'Conv2d_2b_1x1',
      'Conv2d_2c_3x3', 'MaxPool_3a_3x3', 'Mixed_3b', 'Mixed_3c', 'Mixed_4a',
      'Mixed_4b', 'Mixed_4c', 'Mixed_4d', 'Mixed_4e', 'Mixed_5a', 'Mixed_5b',
      'Mixed_5c'].
    min_depth: Minimum depth value (number of channels) for all convolution ops.
      Enforced when depth_multiplier < 1, and not an active constraint when
      depth_multiplier >= 1.
    depth_multiplier: Float multiplier for the depth (number of channels)
      for all convolution ops. The value must be greater than zero. Typical
      usage will be to set this value in (0, 1) to reduce the number of
      parameters or computation cost of the model.
    scope: Optional variable_scope.

  Returns:
    tensor_out: output tensor corresponding to the final_endpoint.
    end_points: a dict mapping every endpoint name built so far (up to and
      including final_endpoint) to its output tensor, for external use such
      as summaries or losses.

  Raises:
    ValueError: if depth_multiplier <= 0, or if final_endpoint is not set to
      one of the predefined values. The depth_multiplier check runs first.
  """

  # end_points will collect relevant activations for external use, for example
  # summaries or losses.
  end_points = {}

  if depth_multiplier <= 0:
    raise ValueError('depth_multiplier is not greater than zero.')

  # Reject an unknown endpoint before creating any ops. Without this check the
  # whole network would be added to the graph before the error surfaced.
  valid_endpoints = ('Conv2d_1a_7x7', 'MaxPool_2a_3x3', 'Conv2d_2b_1x1',
                     'Conv2d_2c_3x3', 'MaxPool_3a_3x3', 'Mixed_3b', 'Mixed_3c',
                     'Mixed_4a', 'Mixed_4b', 'Mixed_4c', 'Mixed_4d', 'Mixed_4e',
                     'Mixed_5a', 'Mixed_5b', 'Mixed_5c')
  if final_endpoint not in valid_endpoints:
    raise ValueError('Unknown final endpoint %s' % final_endpoint)

  def depth(d):
    """Thinned channel count: d scaled by depth_multiplier, at least min_depth."""
    return max(int(d * depth_multiplier), min_depth)

  with tf.variable_scope(scope, 'InceptionV2', [inputs]):
    with slim.arg_scope(
        [slim.conv2d, slim.max_pool2d, slim.avg_pool2d, slim.separable_conv2d],
        stride=1, padding='SAME'):

      # Note that sizes in the comments below assume an input spatial size of
      # 224x224, however, the inputs can be of any size greater than 32x32.

      # 224 x 224 x 3
      end_point = 'Conv2d_1a_7x7'
      # depthwise_multiplier here is different from depth_multiplier.
      # depthwise_multiplier determines the output channels of the initial
      # depthwise conv (see docs for tf.nn.separable_conv2d), while
      # depth_multiplier controls the # channels of the subsequent 1x1
      # convolution. Must have
      #   in_channels * depthwise_multiplier <= out_channels
      # so that the separable convolution is not overparameterized.
      depthwise_multiplier = min(int(depth(64) / 3), 8)
      net = slim.separable_conv2d(
          inputs, depth(64), [7, 7], depth_multiplier=depthwise_multiplier,
          stride=2, weights_initializer=trunc_normal(1.0),
          scope=end_point)
      end_points[end_point] = net
      if end_point == final_endpoint: return net, end_points
      # 112 x 112 x 64
      end_point = 'MaxPool_2a_3x3'
      net = slim.max_pool2d(net, [3, 3], scope=end_point, stride=2)
      end_points[end_point] = net
      if end_point == final_endpoint: return net, end_points
      # 56 x 56 x 64
      end_point = 'Conv2d_2b_1x1'
      net = slim.conv2d(net, depth(64), [1, 1], scope=end_point,
                        weights_initializer=trunc_normal(0.1))
      end_points[end_point] = net
      if end_point == final_endpoint: return net, end_points
      # 56 x 56 x 64
      end_point = 'Conv2d_2c_3x3'
      net = slim.conv2d(net, depth(192), [3, 3], scope=end_point)
      end_points[end_point] = net
      if end_point == final_endpoint: return net, end_points
      # 56 x 56 x 192
      end_point = 'MaxPool_3a_3x3'
      net = slim.max_pool2d(net, [3, 3], scope=end_point, stride=2)
      end_points[end_point] = net
      if end_point == final_endpoint: return net, end_points
      # 28 x 28 x 192
      # Inception module.
      end_point = 'Mixed_3b'
      with tf.variable_scope(end_point):
        with tf.variable_scope('Branch_0'):
          branch_0 = slim.conv2d(net, depth(64), [1, 1], scope='Conv2d_0a_1x1')
        with tf.variable_scope('Branch_1'):
          branch_1 = slim.conv2d(
              net, depth(64), [1, 1],
              weights_initializer=trunc_normal(0.09),
              scope='Conv2d_0a_1x1')
          branch_1 = slim.conv2d(branch_1, depth(64), [3, 3],
                                 scope='Conv2d_0b_3x3')
        with tf.variable_scope('Branch_2'):
          branch_2 = slim.conv2d(
              net, depth(64), [1, 1],
              weights_initializer=trunc_normal(0.09),
              scope='Conv2d_0a_1x1')
          branch_2 = slim.conv2d(branch_2, depth(96), [3, 3],
                                 scope='Conv2d_0b_3x3')
          branch_2 = slim.conv2d(branch_2, depth(96), [3, 3],
                                 scope='Conv2d_0c_3x3')
        with tf.variable_scope('Branch_3'):
          branch_3 = slim.avg_pool2d(net, [3, 3], scope='AvgPool_0a_3x3')
          branch_3 = slim.conv2d(
              branch_3, depth(32), [1, 1],
              weights_initializer=trunc_normal(0.1),
              scope='Conv2d_0b_1x1')
        net = tf.concat(axis=3, values=[branch_0, branch_1, branch_2, branch_3])
        end_points[end_point] = net
        if end_point == final_endpoint: return net, end_points
      # 28 x 28 x 256
      end_point = 'Mixed_3c'
      with tf.variable_scope(end_point):
        with tf.variable_scope('Branch_0'):
          branch_0 = slim.conv2d(net, depth(64), [1, 1], scope='Conv2d_0a_1x1')
        with tf.variable_scope('Branch_1'):
          branch_1 = slim.conv2d(
              net, depth(64), [1, 1],
              weights_initializer=trunc_normal(0.09),
              scope='Conv2d_0a_1x1')
          branch_1 = slim.conv2d(branch_1, depth(96), [3, 3],
                                 scope='Conv2d_0b_3x3')
        with tf.variable_scope('Branch_2'):
          branch_2 = slim.conv2d(
              net, depth(64), [1, 1],
              weights_initializer=trunc_normal(0.09),
              scope='Conv2d_0a_1x1')
          branch_2 = slim.conv2d(branch_2, depth(96), [3, 3],
                                 scope='Conv2d_0b_3x3')
          branch_2 = slim.conv2d(branch_2, depth(96), [3, 3],
                                 scope='Conv2d_0c_3x3')
        with tf.variable_scope('Branch_3'):
          branch_3 = slim.avg_pool2d(net, [3, 3], scope='AvgPool_0a_3x3')
          branch_3 = slim.conv2d(
              branch_3, depth(64), [1, 1],
              weights_initializer=trunc_normal(0.1),
              scope='Conv2d_0b_1x1')
        net = tf.concat(axis=3, values=[branch_0, branch_1, branch_2, branch_3])
        end_points[end_point] = net
        if end_point == final_endpoint: return net, end_points
      # 28 x 28 x 320
      # Grid-size reduction module: the stride-2 branches halve the spatial size.
      end_point = 'Mixed_4a'
      with tf.variable_scope(end_point):
        with tf.variable_scope('Branch_0'):
          branch_0 = slim.conv2d(
              net, depth(128), [1, 1],
              weights_initializer=trunc_normal(0.09),
              scope='Conv2d_0a_1x1')
          branch_0 = slim.conv2d(branch_0, depth(160), [3, 3], stride=2,
                                 scope='Conv2d_1a_3x3')
        with tf.variable_scope('Branch_1'):
          branch_1 = slim.conv2d(
              net, depth(64), [1, 1],
              weights_initializer=trunc_normal(0.09),
              scope='Conv2d_0a_1x1')
          branch_1 = slim.conv2d(
              branch_1, depth(96), [3, 3], scope='Conv2d_0b_3x3')
          branch_1 = slim.conv2d(
              branch_1, depth(96), [3, 3], stride=2, scope='Conv2d_1a_3x3')
        with tf.variable_scope('Branch_2'):
          branch_2 = slim.max_pool2d(
              net, [3, 3], stride=2, scope='MaxPool_1a_3x3')
        net = tf.concat(axis=3, values=[branch_0, branch_1, branch_2])
        end_points[end_point] = net
        if end_point == final_endpoint: return net, end_points
      # 14 x 14 x 576
      end_point = 'Mixed_4b'
      with tf.variable_scope(end_point):
        with tf.variable_scope('Branch_0'):
          branch_0 = slim.conv2d(net, depth(224), [1, 1], scope='Conv2d_0a_1x1')
        with tf.variable_scope('Branch_1'):
          branch_1 = slim.conv2d(
              net, depth(64), [1, 1],
              weights_initializer=trunc_normal(0.09),
              scope='Conv2d_0a_1x1')
          branch_1 = slim.conv2d(
              branch_1, depth(96), [3, 3], scope='Conv2d_0b_3x3')
        with tf.variable_scope('Branch_2'):
          branch_2 = slim.conv2d(
              net, depth(96), [1, 1],
              weights_initializer=trunc_normal(0.09),
              scope='Conv2d_0a_1x1')
          branch_2 = slim.conv2d(branch_2, depth(128), [3, 3],
                                 scope='Conv2d_0b_3x3')
          branch_2 = slim.conv2d(branch_2, depth(128), [3, 3],
                                 scope='Conv2d_0c_3x3')
        with tf.variable_scope('Branch_3'):
          branch_3 = slim.avg_pool2d(net, [3, 3], scope='AvgPool_0a_3x3')
          branch_3 = slim.conv2d(
              branch_3, depth(128), [1, 1],
              weights_initializer=trunc_normal(0.1),
              scope='Conv2d_0b_1x1')
        net = tf.concat(axis=3, values=[branch_0, branch_1, branch_2, branch_3])
        end_points[end_point] = net
        if end_point == final_endpoint: return net, end_points
      # 14 x 14 x 576
      end_point = 'Mixed_4c'
      with tf.variable_scope(end_point):
        with tf.variable_scope('Branch_0'):
          branch_0 = slim.conv2d(net, depth(192), [1, 1], scope='Conv2d_0a_1x1')
        with tf.variable_scope('Branch_1'):
          branch_1 = slim.conv2d(
              net, depth(96), [1, 1],
              weights_initializer=trunc_normal(0.09),
              scope='Conv2d_0a_1x1')
          branch_1 = slim.conv2d(branch_1, depth(128), [3, 3],
                                 scope='Conv2d_0b_3x3')
        with tf.variable_scope('Branch_2'):
          branch_2 = slim.conv2d(
              net, depth(96), [1, 1],
              weights_initializer=trunc_normal(0.09),
              scope='Conv2d_0a_1x1')
          branch_2 = slim.conv2d(branch_2, depth(128), [3, 3],
                                 scope='Conv2d_0b_3x3')
          branch_2 = slim.conv2d(branch_2, depth(128), [3, 3],
                                 scope='Conv2d_0c_3x3')
        with tf.variable_scope('Branch_3'):
          branch_3 = slim.avg_pool2d(net, [3, 3], scope='AvgPool_0a_3x3')
          branch_3 = slim.conv2d(
              branch_3, depth(128), [1, 1],
              weights_initializer=trunc_normal(0.1),
              scope='Conv2d_0b_1x1')
        net = tf.concat(axis=3, values=[branch_0, branch_1, branch_2, branch_3])
        end_points[end_point] = net
        if end_point == final_endpoint: return net, end_points
      # 14 x 14 x 576
      end_point = 'Mixed_4d'
      with tf.variable_scope(end_point):
        with tf.variable_scope('Branch_0'):
          branch_0 = slim.conv2d(net, depth(160), [1, 1], scope='Conv2d_0a_1x1')
        with tf.variable_scope('Branch_1'):
          branch_1 = slim.conv2d(
              net, depth(128), [1, 1],
              weights_initializer=trunc_normal(0.09),
              scope='Conv2d_0a_1x1')
          branch_1 = slim.conv2d(branch_1, depth(160), [3, 3],
                                 scope='Conv2d_0b_3x3')
        with tf.variable_scope('Branch_2'):
          branch_2 = slim.conv2d(
              net, depth(128), [1, 1],
              weights_initializer=trunc_normal(0.09),
              scope='Conv2d_0a_1x1')
          branch_2 = slim.conv2d(branch_2, depth(160), [3, 3],
                                 scope='Conv2d_0b_3x3')
          branch_2 = slim.conv2d(branch_2, depth(160), [3, 3],
                                 scope='Conv2d_0c_3x3')
        with tf.variable_scope('Branch_3'):
          branch_3 = slim.avg_pool2d(net, [3, 3], scope='AvgPool_0a_3x3')
          branch_3 = slim.conv2d(
              branch_3, depth(96), [1, 1],
              weights_initializer=trunc_normal(0.1),
              scope='Conv2d_0b_1x1')
        net = tf.concat(axis=3, values=[branch_0, branch_1, branch_2, branch_3])
        end_points[end_point] = net
        if end_point == final_endpoint: return net, end_points

      # 14 x 14 x 576
      end_point = 'Mixed_4e'
      with tf.variable_scope(end_point):
        with tf.variable_scope('Branch_0'):
          branch_0 = slim.conv2d(net, depth(96), [1, 1], scope='Conv2d_0a_1x1')
        with tf.variable_scope('Branch_1'):
          branch_1 = slim.conv2d(
              net, depth(128), [1, 1],
              weights_initializer=trunc_normal(0.09),
              scope='Conv2d_0a_1x1')
          branch_1 = slim.conv2d(branch_1, depth(192), [3, 3],
                                 scope='Conv2d_0b_3x3')
        with tf.variable_scope('Branch_2'):
          branch_2 = slim.conv2d(
              net, depth(160), [1, 1],
              weights_initializer=trunc_normal(0.09),
              scope='Conv2d_0a_1x1')
          branch_2 = slim.conv2d(branch_2, depth(192), [3, 3],
                                 scope='Conv2d_0b_3x3')
          branch_2 = slim.conv2d(branch_2, depth(192), [3, 3],
                                 scope='Conv2d_0c_3x3')
        with tf.variable_scope('Branch_3'):
          branch_3 = slim.avg_pool2d(net, [3, 3], scope='AvgPool_0a_3x3')
          branch_3 = slim.conv2d(
              branch_3, depth(96), [1, 1],
              weights_initializer=trunc_normal(0.1),
              scope='Conv2d_0b_1x1')
        net = tf.concat(axis=3, values=[branch_0, branch_1, branch_2, branch_3])
        end_points[end_point] = net
        if end_point == final_endpoint: return net, end_points
      # 14 x 14 x 576
      # Second grid-size reduction module (stride-2 branches).
      end_point = 'Mixed_5a'
      with tf.variable_scope(end_point):
        with tf.variable_scope('Branch_0'):
          branch_0 = slim.conv2d(
              net, depth(128), [1, 1],
              weights_initializer=trunc_normal(0.09),
              scope='Conv2d_0a_1x1')
          branch_0 = slim.conv2d(branch_0, depth(192), [3, 3], stride=2,
                                 scope='Conv2d_1a_3x3')
        with tf.variable_scope('Branch_1'):
          branch_1 = slim.conv2d(
              net, depth(192), [1, 1],
              weights_initializer=trunc_normal(0.09),
              scope='Conv2d_0a_1x1')
          branch_1 = slim.conv2d(branch_1, depth(256), [3, 3],
                                 scope='Conv2d_0b_3x3')
          branch_1 = slim.conv2d(branch_1, depth(256), [3, 3], stride=2,
                                 scope='Conv2d_1a_3x3')
        with tf.variable_scope('Branch_2'):
          branch_2 = slim.max_pool2d(net, [3, 3], stride=2,
                                     scope='MaxPool_1a_3x3')
        net = tf.concat(axis=3, values=[branch_0, branch_1, branch_2])
        end_points[end_point] = net
        if end_point == final_endpoint: return net, end_points
      # 7 x 7 x 1024
      end_point = 'Mixed_5b'
      with tf.variable_scope(end_point):
        with tf.variable_scope('Branch_0'):
          branch_0 = slim.conv2d(net, depth(352), [1, 1], scope='Conv2d_0a_1x1')
        with tf.variable_scope('Branch_1'):
          branch_1 = slim.conv2d(
              net, depth(192), [1, 1],
              weights_initializer=trunc_normal(0.09),
              scope='Conv2d_0a_1x1')
          branch_1 = slim.conv2d(branch_1, depth(320), [3, 3],
                                 scope='Conv2d_0b_3x3')
        with tf.variable_scope('Branch_2'):
          branch_2 = slim.conv2d(
              net, depth(160), [1, 1],
              weights_initializer=trunc_normal(0.09),
              scope='Conv2d_0a_1x1')
          branch_2 = slim.conv2d(branch_2, depth(224), [3, 3],
                                 scope='Conv2d_0b_3x3')
          branch_2 = slim.conv2d(branch_2, depth(224), [3, 3],
                                 scope='Conv2d_0c_3x3')
        with tf.variable_scope('Branch_3'):
          branch_3 = slim.avg_pool2d(net, [3, 3], scope='AvgPool_0a_3x3')
          branch_3 = slim.conv2d(
              branch_3, depth(128), [1, 1],
              weights_initializer=trunc_normal(0.1),
              scope='Conv2d_0b_1x1')
        net = tf.concat(axis=3, values=[branch_0, branch_1, branch_2, branch_3])
        end_points[end_point] = net
        if end_point == final_endpoint: return net, end_points

      # 7 x 7 x 1024
      end_point = 'Mixed_5c'
      with tf.variable_scope(end_point):
        with tf.variable_scope('Branch_0'):
          branch_0 = slim.conv2d(net, depth(352), [1, 1], scope='Conv2d_0a_1x1')
        with tf.variable_scope('Branch_1'):
          branch_1 = slim.conv2d(
              net, depth(192), [1, 1],
              weights_initializer=trunc_normal(0.09),
              scope='Conv2d_0a_1x1')
          branch_1 = slim.conv2d(branch_1, depth(320), [3, 3],
                                 scope='Conv2d_0b_3x3')
        with tf.variable_scope('Branch_2'):
          branch_2 = slim.conv2d(
              net, depth(192), [1, 1],
              weights_initializer=trunc_normal(0.09),
              scope='Conv2d_0a_1x1')
          branch_2 = slim.conv2d(branch_2, depth(224), [3, 3],
                                 scope='Conv2d_0b_3x3')
          branch_2 = slim.conv2d(branch_2, depth(224), [3, 3],
                                 scope='Conv2d_0c_3x3')
        with tf.variable_scope('Branch_3'):
          # Unlike the earlier modules, this pooling branch uses max pooling.
          branch_3 = slim.max_pool2d(net, [3, 3], scope='MaxPool_0a_3x3')
          branch_3 = slim.conv2d(
              branch_3, depth(128), [1, 1],
              weights_initializer=trunc_normal(0.1),
              scope='Conv2d_0b_1x1')
        net = tf.concat(axis=3, values=[branch_0, branch_1, branch_2, branch_3])
        end_points[end_point] = net
        if end_point == final_endpoint: return net, end_points
    # Not reachable once final_endpoint has been validated above. Kept as a
    # safety net in case the endpoint list and the stages above drift apart.
    raise ValueError('Unknown final endpoint %s' % final_endpoint)
414
415
def inception_v2(inputs,
                 num_classes=1000,
                 is_training=True,
                 dropout_keep_prob=0.8,
                 min_depth=16,
                 depth_multiplier=1.0,
                 prediction_fn=slim.softmax,
                 spatial_squeeze=True,
                 reuse=None,
                 scope='InceptionV2'):
  """Inception v2 model for classification.

  Builds the Inception v2 feature extractor (inception_v2_base) and puts a
  classification head on top of it. The head is average pooling, dropout and
  a 1x1 convolution producing the logits, as described in
  http://arxiv.org/abs/1502.03167.

  The default image size used to train this network is 224x224 (exposed as
  `inception_v2.default_image_size`).

  Args:
    inputs: a tensor of shape [batch_size, height, width, channels].
    num_classes: number of predicted classes.
    is_training: whether the model is built for training. It is applied to
      every batch_norm and dropout layer through a slim.arg_scope.
    dropout_keep_prob: fraction of activations kept by the dropout layer.
    min_depth: Minimum depth value (number of channels) for all convolution ops.
      Enforced when depth_multiplier < 1, and not an active constraint when
      depth_multiplier >= 1.
    depth_multiplier: Float multiplier for the depth (number of channels)
      for all convolution ops. The value must be greater than zero. Typical
      usage will be to set this value in (0, 1) to reduce the number of
      parameters or computation cost of the model.
    prediction_fn: callable invoked as prediction_fn(logits,
      scope='Predictions') to turn logits into predictions.
    spatial_squeeze: if True, logits is of shape [B, C], if false logits is
      of shape [B, 1, 1, C], where B is batch_size and C is number of classes.
    reuse: whether or not the network and its variables should be reused. To be
      able to reuse 'scope' must be given.
    scope: Optional variable_scope.

  Returns:
    logits: the pre-softmax activations, a tensor of size
      [batch_size, num_classes] (unsqueezed when spatial_squeeze is False).
    end_points: a dictionary from components of the network to the
      corresponding activation, including the 'Logits' and 'Predictions'
      entries added here.

  Raises:
    ValueError: if depth_multiplier <= 0.
  """
  if depth_multiplier <= 0:
    raise ValueError('depth_multiplier is not greater than zero.')

  with tf.variable_scope(scope, 'InceptionV2', [inputs, num_classes],
                         reuse=reuse) as outer_scope:
    with slim.arg_scope([slim.batch_norm, slim.dropout],
                        is_training=is_training):
      features, end_points = inception_v2_base(
          inputs, scope=outer_scope, min_depth=min_depth,
          depth_multiplier=depth_multiplier)

      # Classification head: pool -> dropout -> 1x1 conv producing logits.
      with tf.variable_scope('Logits'):
        # The 7x7 window is shrunk when the feature map is statically smaller.
        pool_size = _reduced_kernel_size_for_small_input(features, [7, 7])
        pool_scope = 'AvgPool_1a_{}x{}'.format(*pool_size)
        pooled = slim.avg_pool2d(features, pool_size, padding='VALID',
                                 scope=pool_scope)
        # 1 x 1 x 1024 for the default 224 x 224 input.
        dropped = slim.dropout(pooled, keep_prob=dropout_keep_prob,
                               scope='Dropout_1b')
        logits = slim.conv2d(dropped, num_classes, [1, 1], activation_fn=None,
                             normalizer_fn=None, scope='Conv2d_1c_1x1')
        if spatial_squeeze:
          # Drop the two spatial axes: [B, 1, 1, C] -> [B, C].
          logits = tf.squeeze(logits, [1, 2], name='SpatialSqueeze')

      end_points['Logits'] = logits
      end_points['Predictions'] = prediction_fn(logits, scope='Predictions')
  return logits, end_points


inception_v2.default_image_size = 224
487
488
489 def _reduced_kernel_size_for_small_input(input_tensor, kernel_size):
490   """Define kernel size which is automatically reduced for small input.
491
492   If the shape of the input images is unknown at graph construction time this
493   function assumes that the input images are is large enough.
494
495   Args:
496     input_tensor: input tensor of size [batch_size, height, width, channels].
497     kernel_size: desired kernel size of length 2: [kernel_height, kernel_width]
498
499   Returns:
500     a tensor with the kernel size.
501
502   TODO(jrru): Make this function work with unknown shapes. Theoretically, this
503   can be done with the code below. Problems are two-fold: (1) If the shape was
504   known, it will be lost. (2) inception.slim.ops._two_element_tuple cannot
505   handle tensors that define the kernel size.
506       shape = tf.shape(input_tensor)
507       return = tf.pack([tf.minimum(shape[1], kernel_size[0]),
508                         tf.minimum(shape[2], kernel_size[1])])
509
510   """
511   shape = input_tensor.get_shape().as_list()
512   if shape[1] is None or shape[2] is None:
513     kernel_size_out = kernel_size
514   else:
515     kernel_size_out = [min(shape[1], kernel_size[0]),
516                        min(shape[2], kernel_size[1])]
517   return kernel_size_out
518
519
# Public alias: Inception v2 reuses the shared Inception arg_scope factory from
# inception_utils unchanged, so callers can build the model inside it, e.g.
#   with slim.arg_scope(inception_v2_arg_scope()):
#     logits, end_points = inception_v2(images)
# NOTE(review): the factory's parameters and defaults live in inception_utils
# and are not visible in this file -- confirm there before relying on them.
inception_v2_arg_scope = inception_utils.inception_arg_scope