1 # Copyright 2016 The TensorFlow Authors. All Rights Reserved.
3 # Licensed under the Apache License, Version 2.0 (the "License");
4 # you may not use this file except in compliance with the License.
5 # You may obtain a copy of the License at
7 # http://www.apache.org/licenses/LICENSE-2.0
9 # Unless required by applicable law or agreed to in writing, software
10 # distributed under the License is distributed on an "AS IS" BASIS,
11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 # See the License for the specific language governing permissions and
13 # limitations under the License.
14 # ==============================================================================
15 """Contains a variant of the CIFAR-10 model definition."""
17 from __future__ import absolute_import
18 from __future__ import division
19 from __future__ import print_function
21 import tensorflow as tf
# Short alias for the TF-Slim API used throughout this module.
slim = tf.contrib.slim


def trunc_normal(stddev):
    """Returns a truncated-normal weight initializer.

    Args:
      stddev: the `stddev` argument forwarded to
        `tf.truncated_normal_initializer`.

    Returns:
      A `tf.truncated_normal_initializer` built with the given `stddev`.
    """
    return tf.truncated_normal_initializer(stddev=stddev)
def cifarnet(images, num_classes=10, is_training=False,
             dropout_keep_prob=0.5,
             prediction_fn=slim.softmax,
             scope='CifarNet'):
    """Creates a variant of the CifarNet model.

    Note that since the output is a set of 'logits', the values fall in the
    interval of (-infinity, infinity). Consequently, to convert the outputs to a
    probability distribution over the classes, one will need to convert them
    using the softmax function:

          logits, _ = cifarnet.cifarnet(images, is_training=False)
          probabilities = tf.nn.softmax(logits)
          predictions = tf.argmax(logits, 1)

    Args:
      images: A batch of `Tensors` of size [batch_size, height, width, channels].
      num_classes: the number of classes in the dataset; the width of the final
        (logits) layer.
      is_training: specifies whether or not we're currently training the model.
        This variable will determine the behaviour of the dropout layer.
      dropout_keep_prob: the fraction of activation values that are retained by
        the dropout layer.
      prediction_fn: a function to get predictions out of logits. It is called
        as `prediction_fn(logits, scope='Predictions')`.
      scope: Optional variable_scope. When None, the default name 'CifarNet'
        is used.

    Returns:
      logits: the pre-softmax activations, a tensor of size
        [batch_size, `num_classes`]
      end_points: a dictionary from components of the network ('conv1',
        'pool1', 'conv2', 'pool2', 'Flatten', 'fc3', 'fc4', 'Logits',
        'Predictions') to the corresponding activation.
    """
    end_points = {}

    with tf.variable_scope(scope, 'CifarNet', [images, num_classes]):
        net = slim.conv2d(images, 64, [5, 5], scope='conv1')
        end_points['conv1'] = net
        net = slim.max_pool2d(net, [2, 2], 2, scope='pool1')
        end_points['pool1'] = net
        # Local response normalization follows pooling in the first stage but
        # precedes pooling in the second stage.
        net = tf.nn.lrn(net, 4, bias=1.0, alpha=0.001 / 9.0, beta=0.75,
                        name='norm1')
        net = slim.conv2d(net, 64, [5, 5], scope='conv2')
        end_points['conv2'] = net
        net = tf.nn.lrn(net, 4, bias=1.0, alpha=0.001 / 9.0, beta=0.75,
                        name='norm2')
        net = slim.max_pool2d(net, [2, 2], 2, scope='pool2')
        end_points['pool2'] = net
        net = slim.flatten(net)
        end_points['Flatten'] = net
        net = slim.fully_connected(net, 384, scope='fc3')
        end_points['fc3'] = net
        # is_training decides whether dropout is actually applied.
        net = slim.dropout(net, dropout_keep_prob, is_training=is_training,
                           scope='dropout3')
        net = slim.fully_connected(net, 192, scope='fc4')
        end_points['fc4'] = net
        # Final linear layer. Every initializer/regularizer/activation is set
        # explicitly so that the fully_connected defaults installed by
        # cifarnet_arg_scope (ReLU, L2 decay, 0.1 biases) do not apply here;
        # activation_fn=None keeps the outputs as unbounded logits.
        logits = slim.fully_connected(
            net,
            num_classes,
            biases_initializer=tf.zeros_initializer(),
            weights_initializer=trunc_normal(1 / 192.0),
            weights_regularizer=None,
            activation_fn=None,
            scope='logits')

        end_points['Logits'] = logits
        end_points['Predictions'] = prediction_fn(logits, scope='Predictions')

    return logits, end_points


# Spatial input size (height and width) this network is designed for.
cifarnet.default_image_size = 32
93 def cifarnet_arg_scope(weight_decay=0.004):
94 """Defines the default cifarnet argument scope.
97 weight_decay: The weight decay to use for regularizing the model.
100 An `arg_scope` to use for the inception v3 model.
104 weights_initializer=tf.truncated_normal_initializer(stddev=5e-2),
105 activation_fn=tf.nn.relu):
107 [slim.fully_connected],
108 biases_initializer=tf.constant_initializer(0.1),
109 weights_initializer=trunc_normal(0.04),
110 weights_regularizer=slim.l2_regularizer(weight_decay),
111 activation_fn=tf.nn.relu) as sc: