X-Git-Url: https://gerrit.akraino.org/r/gitweb?a=blobdiff_plain;f=example-apps%2FPDD%2Fpcb-defect-detection%2Flibs%2Fexport_pbs%2FexportPb.py;fp=example-apps%2FPDD%2Fpcb-defect-detection%2Flibs%2Fexport_pbs%2FexportPb.py;h=0000000000000000000000000000000000000000;hb=3ed2c61d9d7e7916481650c41bfe5604f7db22e9;hp=8f4b217e8fc97316e98df4872dc019c15f05e4f9;hpb=e6d40ddb2640f434a9d7d7ed99566e5e8fa60cc1;p=ealt-edge.git diff --git a/example-apps/PDD/pcb-defect-detection/libs/export_pbs/exportPb.py b/example-apps/PDD/pcb-defect-detection/libs/export_pbs/exportPb.py deleted file mode 100755 index 8f4b217..0000000 --- a/example-apps/PDD/pcb-defect-detection/libs/export_pbs/exportPb.py +++ /dev/null @@ -1,87 +0,0 @@ -# -*- coding: utf-8 -*- - -from __future__ import absolute_import, print_function, division - -import os, sys -import tensorflow as tf -import tensorflow.contrib.slim as slim -from tensorflow.python.tools import freeze_graph - -sys.path.append('../../') -from data.io.image_preprocess import short_side_resize_for_inference_data -from libs.configs import cfgs -from libs.networks import build_whole_network - -CKPT_PATH = '/home/yjr/PycharmProjects/Faster-RCNN_Tensorflow/output/trained_weights/FasterRCNN_20180517/voc_200000model.ckpt' -OUT_DIR = '../../output/Pbs' -PB_NAME = 'FasterRCNN_Res101_Pascal.pb' - - -def build_detection_graph(): - # 1. preprocess img - img_plac = tf.placeholder(dtype=tf.uint8, shape=[None, None, 3], - name='input_img') # is RGB. not GBR - raw_shape = tf.shape(img_plac) - raw_h, raw_w = tf.to_float(raw_shape[0]), tf.to_float(raw_shape[1]) - - img_batch = tf.cast(img_plac, tf.float32) - img_batch = short_side_resize_for_inference_data(img_tensor=img_batch, - target_shortside_len=cfgs.IMG_SHORT_SIDE_LEN, - length_limitation=cfgs.IMG_MAX_LENGTH) - img_batch = img_batch - tf.constant(cfgs.PIXEL_MEAN) - img_batch = tf.expand_dims(img_batch, axis=0) # [1, None, None, 3] - - det_net = build_whole_network.DetectionNetwork(base_network_name=cfgs.NET_NAME, - is_training=False) - - detected_boxes, detection_scores, detection_category = det_net.build_whole_detection_network( - input_img_batch=img_batch, - gtboxes_batch=None) - - xmin, ymin, xmax, ymax = detected_boxes[:, 0], detected_boxes[:, 1], \ - detected_boxes[:, 2], detected_boxes[:, 3] - - resized_shape = tf.shape(img_batch) - resized_h, resized_w = tf.to_float(resized_shape[1]), tf.to_float(resized_shape[2]) - - xmin = xmin * raw_w / resized_w - xmax = xmax * raw_w / resized_w - - ymin = ymin * raw_h / resized_h - ymax = ymax * raw_h / resized_h - - boxes = tf.transpose(tf.stack([xmin, ymin, xmax, ymax])) - dets = tf.concat([tf.reshape(detection_category, [-1, 1]), - tf.reshape(detection_scores, [-1, 1]), - boxes], axis=1, name='DetResults') - - return dets - - -def export_frozenPB(): - - tf.reset_default_graph() - - dets = build_detection_graph() - - saver = tf.train.Saver() - - with tf.Session() as sess: - print("we have restred the weights from =====>>\n", CKPT_PATH) - saver.restore(sess, CKPT_PATH) - - tf.train.write_graph(sess.graph_def, OUT_DIR, PB_NAME) - freeze_graph.freeze_graph(input_graph=os.path.join(OUT_DIR, PB_NAME), - input_saver='', - input_binary=False, - input_checkpoint=CKPT_PATH, - output_node_names="DetResults", - restore_op_name="save/restore_all", - filename_tensor_name='save/Const:0', - output_graph=os.path.join(OUT_DIR, PB_NAME.replace('.pb', '_Frozen.pb')), - clear_devices=False, - initializer_nodes='') - -if __name__ == '__main__': - os.environ["CUDA_VISIBLE_DEVICES"] = '' - export_frozenPB()