X-Git-Url: https://gerrit.akraino.org/r/gitweb?a=blobdiff_plain;f=example-apps%2FPDD%2Fpcb-defect-detection%2Flibs%2Fnetworks%2Fslim_nets%2Fnets_factory.py;fp=example-apps%2FPDD%2Fpcb-defect-detection%2Flibs%2Fnetworks%2Fslim_nets%2Fnets_factory.py;h=7c0416167d3009a02266809658904cadad57acba;hb=a785567fb9acfc68536767d20f60ba917ae85aa1;hp=0000000000000000000000000000000000000000;hpb=94a133e696b9b2a7f73544462c2714986fa7ab4a;p=ealt-edge.git diff --git a/example-apps/PDD/pcb-defect-detection/libs/networks/slim_nets/nets_factory.py b/example-apps/PDD/pcb-defect-detection/libs/networks/slim_nets/nets_factory.py new file mode 100755 index 0000000..7c04161 --- /dev/null +++ b/example-apps/PDD/pcb-defect-detection/libs/networks/slim_nets/nets_factory.py @@ -0,0 +1,112 @@ +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Contains a factory for building various models.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function +import functools + +import tensorflow as tf + +from nets import alexnet +from nets import cifarnet +from nets import inception +from nets import lenet +from nets import mobilenet_v1 +from nets import overfeat +from nets import resnet_v1 +from nets import resnet_v2 +from nets import vgg + +slim = tf.contrib.slim + +networks_map = {'alexnet_v2': alexnet.alexnet_v2, + 'cifarnet': cifarnet.cifarnet, + 'overfeat': overfeat.overfeat, + 'vgg_a': vgg.vgg_a, + 'vgg_16': vgg.vgg_16, + 'vgg_19': vgg.vgg_19, + 'inception_v1': inception.inception_v1, + 'inception_v2': inception.inception_v2, + 'inception_v3': inception.inception_v3, + 'inception_v4': inception.inception_v4, + 'inception_resnet_v2': inception.inception_resnet_v2, + 'lenet': lenet.lenet, + 'resnet_v1_50': resnet_v1.resnet_v1_50, + 'resnet_v1_101': resnet_v1.resnet_v1_101, + 'resnet_v1_152': resnet_v1.resnet_v1_152, + 'resnet_v1_200': resnet_v1.resnet_v1_200, + 'resnet_v2_50': resnet_v2.resnet_v2_50, + 'resnet_v2_101': resnet_v2.resnet_v2_101, + 'resnet_v2_152': resnet_v2.resnet_v2_152, + 'resnet_v2_200': resnet_v2.resnet_v2_200, + 'mobilenet_v1': mobilenet_v1.mobilenet_v1, + } + +arg_scopes_map = {'alexnet_v2': alexnet.alexnet_v2_arg_scope, + 'cifarnet': cifarnet.cifarnet_arg_scope, + 'overfeat': overfeat.overfeat_arg_scope, + 'vgg_a': vgg.vgg_arg_scope, + 'vgg_16': vgg.vgg_arg_scope, + 'vgg_19': vgg.vgg_arg_scope, + 'inception_v1': inception.inception_v3_arg_scope, + 'inception_v2': inception.inception_v3_arg_scope, + 'inception_v3': inception.inception_v3_arg_scope, + 'inception_v4': inception.inception_v4_arg_scope, + 'inception_resnet_v2': + inception.inception_resnet_v2_arg_scope, + 'lenet': lenet.lenet_arg_scope, + 'resnet_v1_50': resnet_v1.resnet_arg_scope, + 'resnet_v1_101': resnet_v1.resnet_arg_scope, + 'resnet_v1_152': resnet_v1.resnet_arg_scope, + 'resnet_v1_200': resnet_v1.resnet_arg_scope, + 'resnet_v2_50': resnet_v2.resnet_arg_scope, + 'resnet_v2_101': resnet_v2.resnet_arg_scope, + 'resnet_v2_152': resnet_v2.resnet_arg_scope, + 'resnet_v2_200': resnet_v2.resnet_arg_scope, + 'mobilenet_v1': mobilenet_v1.mobilenet_v1_arg_scope, + } + + +def get_network_fn(name, num_classes, weight_decay=0.0, is_training=False): + """Returns a network_fn such as `logits, end_points = network_fn(images)`. + + Args: + name: The name of the network. + num_classes: The number of classes to use for classification. + weight_decay: The l2 coefficient for the model weights. + is_training: `True` if the model is being used for training and `False` + otherwise. + + Returns: + network_fn: A function that applies the model to a batch of images. It has + the following signature: + logits, end_points = network_fn(images) + Raises: + ValueError: If network `name` is not recognized. + """ + if name not in networks_map: + raise ValueError('Name of network unknown %s' % name) + arg_scope = arg_scopes_map[name](weight_decay=weight_decay) + func = networks_map[name] + @functools.wraps(func) + def network_fn(images): + with slim.arg_scope(arg_scope): + return func(images, num_classes, is_training=is_training) + if hasattr(func, 'default_image_size'): + network_fn.default_image_size = func.default_image_size + + return network_fn