PCB defect detection application
[ealt-edge.git] / example-apps / PDD / pcb-defect-detection / libs / networks / slim_nets / cifarnet.py
1 # Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 #
3 # Licensed under the Apache License, Version 2.0 (the "License");
4 # you may not use this file except in compliance with the License.
5 # You may obtain a copy of the License at
6 #
7 #     http://www.apache.org/licenses/LICENSE-2.0
8 #
9 # Unless required by applicable law or agreed to in writing, software
10 # distributed under the License is distributed on an "AS IS" BASIS,
11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 # See the License for the specific language governing permissions and
13 # limitations under the License.
14 # ==============================================================================
15 """Contains a variant of the CIFAR-10 model definition."""
16
17 from __future__ import absolute_import
18 from __future__ import division
19 from __future__ import print_function
20
21 import tensorflow as tf
22
# NOTE: `tf.contrib` only exists in TensorFlow 1.x; this module requires TF1.
slim = tf.contrib.slim


def trunc_normal(stddev):
  """Returns a truncated-normal initializer with the given standard deviation.

  Args:
    stddev: standard deviation passed to `tf.truncated_normal_initializer`.

  Returns:
    A `tf.truncated_normal_initializer(stddev=stddev)` instance.
  """
  return tf.truncated_normal_initializer(stddev=stddev)
26
27
def cifarnet(images, num_classes=10, is_training=False,
             dropout_keep_prob=0.5,
             prediction_fn=slim.softmax,
             scope='CifarNet'):
  """Creates a variant of the CifarNet model.

  Note that since the output is a set of 'logits', the values fall in the
  interval of (-infinity, infinity). Consequently, to convert the outputs to a
  probability distribution over the classes, one will need to convert them
  using the softmax function:

        logits, _ = cifarnet.cifarnet(images, is_training=False)
        probabilities = tf.nn.softmax(logits)
        predictions = tf.argmax(logits, 1)

  Args:
    images: A batch of `Tensors` of size [batch_size, height, width, channels].
    num_classes: the number of classes in the dataset. If 0 or None, the
      logits layer is omitted and the output of the 'fc4' layer is returned
      in place of the logits (e.g. for use as a feature extractor).
    is_training: specifies whether or not we're currently training the model.
      This variable will determine the behaviour of the dropout layer.
    dropout_keep_prob: the fraction of activation values that are retained
      by the dropout layer.
    prediction_fn: a function to get predictions out of logits; it is called
      as `prediction_fn(logits, scope='Predictions')`.
    scope: Optional variable_scope.

  Returns:
    logits: the pre-softmax activations, a tensor of size
      [batch_size, `num_classes`]; when `num_classes` is 0/None, the output
      of the 'fc4' layer instead.
    end_points: a dictionary from components of the network to the
      corresponding activation. The 'Logits' and 'Predictions' entries are
      only present when the logits layer is built.
  """
  end_points = {}

  # The values list is meant for tensors (it is used to select the graph);
  # `num_classes` is a plain Python int, so only `images` belongs here.
  with tf.variable_scope(scope, 'CifarNet', [images]):
    # Block 1: 5x5 conv -> 2x2 stride-2 max-pool -> local response norm.
    net = slim.conv2d(images, 64, [5, 5], scope='conv1')
    end_points['conv1'] = net
    net = slim.max_pool2d(net, [2, 2], 2, scope='pool1')
    end_points['pool1'] = net
    net = tf.nn.lrn(net, 4, bias=1.0, alpha=0.001/9.0, beta=0.75, name='norm1')

    # Block 2: same layer types, but here LRN is applied *before* pooling.
    net = slim.conv2d(net, 64, [5, 5], scope='conv2')
    end_points['conv2'] = net
    net = tf.nn.lrn(net, 4, bias=1.0, alpha=0.001/9.0, beta=0.75, name='norm2')
    net = slim.max_pool2d(net, [2, 2], 2, scope='pool2')
    end_points['pool2'] = net

    # Head: flatten -> fc3 (384 units) -> dropout -> fc4 (192 units).
    net = slim.flatten(net)
    end_points['Flatten'] = net
    net = slim.fully_connected(net, 384, scope='fc3')
    end_points['fc3'] = net
    net = slim.dropout(net, dropout_keep_prob, is_training=is_training,
                       scope='dropout3')
    net = slim.fully_connected(net, 192, scope='fc4')
    end_points['fc4'] = net

    if not num_classes:
      # Feature-extractor mode: skip the logits layer, return fc4 output.
      return net, end_points

    # Explicit arguments override any enclosing arg_scope defaults so the
    # final layer is linear (no activation) and unregularized. The weight
    # stddev is 1/192, the fan-in coming from fc4.
    logits = slim.fully_connected(net, num_classes,
                                  biases_initializer=tf.zeros_initializer(),
                                  weights_initializer=trunc_normal(1/192.0),
                                  weights_regularizer=None,
                                  activation_fn=None,
                                  scope='logits')

    end_points['Logits'] = logits
    end_points['Predictions'] = prediction_fn(logits, scope='Predictions')

  return logits, end_points
# NOTE(review): presumably read by a network factory to size the inputs.
cifarnet.default_image_size = 32
91
92
def cifarnet_arg_scope(weight_decay=0.004):
  """Defines the default cifarnet argument scope.

  Within the returned scope:
    * `slim.conv2d` layers use a truncated-normal weight initializer
      (stddev 5e-2) and ReLU activation.
    * `slim.fully_connected` layers use a truncated-normal weight initializer
      (stddev 0.04), a constant 0.1 bias initializer, L2 weight
      regularization scaled by `weight_decay`, and ReLU activation.

  Args:
    weight_decay: The weight decay to use for regularizing the model's fully
      connected layers (convolutions are not regularized by this scope).

  Returns:
    An `arg_scope` to use for the cifarnet model.
  """
  with slim.arg_scope(
      [slim.conv2d],
      weights_initializer=trunc_normal(5e-2),
      activation_fn=tf.nn.relu):
    with slim.arg_scope(
        [slim.fully_connected],
        biases_initializer=tf.constant_initializer(0.1),
        weights_initializer=trunc_normal(0.04),
        weights_regularizer=slim.l2_regularizer(weight_decay),
        activation_fn=tf.nn.relu) as sc:
      # The inner scope's dict also carries the conv2d defaults set above.
      return sc