X-Git-Url: https://gerrit.akraino.org/r/gitweb?a=blobdiff_plain;f=example-apps%2FPDD%2Fpcb-defect-detection%2Flibs%2Fnetworks%2Fslim_nets%2Fnets_factory_test.py;fp=example-apps%2FPDD%2Fpcb-defect-detection%2Flibs%2Fnetworks%2Fslim_nets%2Fnets_factory_test.py;h=0000000000000000000000000000000000000000;hb=3ed2c61d9d7e7916481650c41bfe5604f7db22e9;hp=b4ab1f822c9e85ab41b25e57589479e95377de18;hpb=e6d40ddb2640f434a9d7d7ed99566e5e8fa60cc1;p=ealt-edge.git diff --git a/example-apps/PDD/pcb-defect-detection/libs/networks/slim_nets/nets_factory_test.py b/example-apps/PDD/pcb-defect-detection/libs/networks/slim_nets/nets_factory_test.py deleted file mode 100755 index b4ab1f8..0000000 --- a/example-apps/PDD/pcb-defect-detection/libs/networks/slim_nets/nets_factory_test.py +++ /dev/null @@ -1,61 +0,0 @@ -# Copyright 2016 Google Inc. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== - -"""Tests for slim.inception.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import tensorflow as tf - -from nets import nets_factory - -slim = tf.contrib.slim - - -class NetworksTest(tf.test.TestCase): - - def testGetNetworkFn(self): - batch_size = 5 - num_classes = 1000 - for net in nets_factory.networks_map: - with self.test_session(): - net_fn = nets_factory.get_network_fn(net, num_classes) - # Most networks use 224 as their default_image_size - image_size = getattr(net_fn, 'default_image_size', 224) - inputs = tf.random_uniform((batch_size, image_size, image_size, 3)) - logits, end_points = net_fn(inputs) - self.assertTrue(isinstance(logits, tf.Tensor)) - self.assertTrue(isinstance(end_points, dict)) - self.assertEqual(logits.get_shape().as_list()[0], batch_size) - self.assertEqual(logits.get_shape().as_list()[-1], num_classes) - - def testGetNetworkFnArgScope(self): - batch_size = 5 - num_classes = 10 - net = 'cifarnet' - with self.test_session(use_gpu=True): - net_fn = nets_factory.get_network_fn(net, num_classes) - image_size = getattr(net_fn, 'default_image_size', 224) - with slim.arg_scope([slim.model_variable, slim.variable], - device='/CPU:0'): - inputs = tf.random_uniform((batch_size, image_size, image_size, 3)) - net_fn(inputs) - weights = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, 'CifarNet/conv1')[0] - self.assertDeviceEqual('/CPU:0', weights.device) - -if __name__ == '__main__': - tf.test.main()