789d2bdc3d8e3fb76662a3d7032e8adb5f91b5df
[ealt-edge.git] / example-apps / PDD / pcb-defect-detection / libs / networks / slim_nets / lenet.py
1 # Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 #
3 # Licensed under the Apache License, Version 2.0 (the "License");
4 # you may not use this file except in compliance with the License.
5 # You may obtain a copy of the License at
6 #
7 #     http://www.apache.org/licenses/LICENSE-2.0
8 #
9 # Unless required by applicable law or agreed to in writing, software
10 # distributed under the License is distributed on an "AS IS" BASIS,
11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 # See the License for the specific language governing permissions and
13 # limitations under the License.
14 # ==============================================================================
15 """Contains a variant of the LeNet model definition."""
16
17 from __future__ import absolute_import
18 from __future__ import division
19 from __future__ import print_function
20
21 import tensorflow as tf
22
23 slim = tf.contrib.slim
24
25
26 def lenet(images, num_classes=10, is_training=False,
27           dropout_keep_prob=0.5,
28           prediction_fn=slim.softmax,
29           scope='LeNet'):
30   """Creates a variant of the LeNet model.
31
32   Note that since the output is a set of 'logits', the values fall in the
33   interval of (-infinity, infinity). Consequently, to convert the outputs to a
34   probability distribution over the characters, one will need to convert them
35   using the softmax function:
36
37         logits = lenet.lenet(images, is_training=False)
38         probabilities = tf.nn.softmax(logits)
39         predictions = tf.argmax(logits, 1)
40
41   Args:
42     images: A batch of `Tensors` of size [batch_size, height, width, channels].
43     num_classes: the number of classes in the dataset.
44     is_training: specifies whether or not we're currently training the model.
45       This variable will determine the behaviour of the dropout layer.
46     dropout_keep_prob: the percentage of activation values that are retained.
47     prediction_fn: a function to get predictions out of logits.
48     scope: Optional variable_scope.
49
50   Returns:
51     logits: the pre-softmax activations, a tensor of size
52       [batch_size, `num_classes`]
53     end_points: a dictionary from components of the network to the corresponding
54       activation.
55   """
56   end_points = {}
57
58   with tf.variable_scope(scope, 'LeNet', [images, num_classes]):
59     net = slim.conv2d(images, 32, [5, 5], scope='conv1')
60     net = slim.max_pool2d(net, [2, 2], 2, scope='pool1')
61     net = slim.conv2d(net, 64, [5, 5], scope='conv2')
62     net = slim.max_pool2d(net, [2, 2], 2, scope='pool2')
63     net = slim.flatten(net)
64     end_points['Flatten'] = net
65
66     net = slim.fully_connected(net, 1024, scope='fc3')
67     net = slim.dropout(net, dropout_keep_prob, is_training=is_training,
68                        scope='dropout3')
69     logits = slim.fully_connected(net, num_classes, activation_fn=None,
70                                   scope='fc4')
71
72   end_points['Logits'] = logits
73   end_points['Predictions'] = prediction_fn(logits, scope='Predictions')
74
75   return logits, end_points
76 lenet.default_image_size = 28
77
78
79 def lenet_arg_scope(weight_decay=0.0):
80   """Defines the default lenet argument scope.
81
82   Args:
83     weight_decay: The weight decay to use for regularizing the model.
84
85   Returns:
86     An `arg_scope` to use for the inception v3 model.
87   """
88   with slim.arg_scope(
89       [slim.conv2d, slim.fully_connected],
90       weights_regularizer=slim.l2_regularizer(weight_decay),
91       weights_initializer=tf.truncated_normal_initializer(stddev=0.1),
92       activation_fn=tf.nn.relu) as sc:
93     return sc