pcb defect detetcion application
[ealt-edge.git] / example-apps / PDD / pcb-defect-detection / libs / export_pbs / exportPb.py
diff --git a/example-apps/PDD/pcb-defect-detection/libs/export_pbs/exportPb.py b/example-apps/PDD/pcb-defect-detection/libs/export_pbs/exportPb.py
new file mode 100755 (executable)
index 0000000..8f4b217
--- /dev/null
@@ -0,0 +1,87 @@
+# -*- coding: utf-8 -*-
+
+from __future__ import absolute_import, print_function, division
+
+import os, sys
+import tensorflow as tf
+import tensorflow.contrib.slim as slim
+from tensorflow.python.tools import freeze_graph
+
+sys.path.append('../../')
+from data.io.image_preprocess import short_side_resize_for_inference_data
+from libs.configs import cfgs
+from libs.networks import build_whole_network
+
+CKPT_PATH = '/home/yjr/PycharmProjects/Faster-RCNN_Tensorflow/output/trained_weights/FasterRCNN_20180517/voc_200000model.ckpt'
+OUT_DIR = '../../output/Pbs'
+PB_NAME = 'FasterRCNN_Res101_Pascal.pb'
+
+
+def build_detection_graph():
+    # 1. preprocess img
+    img_plac = tf.placeholder(dtype=tf.uint8, shape=[None, None, 3],
+                              name='input_img')  # is RGB. not GBR
+    raw_shape = tf.shape(img_plac)
+    raw_h, raw_w = tf.to_float(raw_shape[0]), tf.to_float(raw_shape[1])
+
+    img_batch = tf.cast(img_plac, tf.float32)
+    img_batch = short_side_resize_for_inference_data(img_tensor=img_batch,
+                                                     target_shortside_len=cfgs.IMG_SHORT_SIDE_LEN,
+                                                     length_limitation=cfgs.IMG_MAX_LENGTH)
+    img_batch = img_batch - tf.constant(cfgs.PIXEL_MEAN)
+    img_batch = tf.expand_dims(img_batch, axis=0)  # [1, None, None, 3]
+
+    det_net = build_whole_network.DetectionNetwork(base_network_name=cfgs.NET_NAME,
+                                                   is_training=False)
+
+    detected_boxes, detection_scores, detection_category = det_net.build_whole_detection_network(
+        input_img_batch=img_batch,
+        gtboxes_batch=None)
+
+    xmin, ymin, xmax, ymax = detected_boxes[:, 0], detected_boxes[:, 1], \
+                             detected_boxes[:, 2], detected_boxes[:, 3]
+
+    resized_shape = tf.shape(img_batch)
+    resized_h, resized_w = tf.to_float(resized_shape[1]), tf.to_float(resized_shape[2])
+
+    xmin = xmin * raw_w / resized_w
+    xmax = xmax * raw_w / resized_w
+
+    ymin = ymin * raw_h / resized_h
+    ymax = ymax * raw_h / resized_h
+
+    boxes = tf.transpose(tf.stack([xmin, ymin, xmax, ymax]))
+    dets = tf.concat([tf.reshape(detection_category, [-1, 1]),
+                     tf.reshape(detection_scores, [-1, 1]),
+                     boxes], axis=1, name='DetResults')
+
+    return dets
+
+
+def export_frozenPB():
+
+    tf.reset_default_graph()
+
+    dets = build_detection_graph()
+
+    saver = tf.train.Saver()
+
+    with tf.Session() as sess:
+        print("we have restred the weights from =====>>\n", CKPT_PATH)
+        saver.restore(sess, CKPT_PATH)
+
+        tf.train.write_graph(sess.graph_def, OUT_DIR, PB_NAME)
+        freeze_graph.freeze_graph(input_graph=os.path.join(OUT_DIR, PB_NAME),
+                                  input_saver='',
+                                  input_binary=False,
+                                  input_checkpoint=CKPT_PATH,
+                                  output_node_names="DetResults",
+                                  restore_op_name="save/restore_all",
+                                  filename_tensor_name='save/Const:0',
+                                  output_graph=os.path.join(OUT_DIR, PB_NAME.replace('.pb', '_Frozen.pb')),
+                                  clear_devices=False,
+                                  initializer_nodes='')
+
+if __name__ == '__main__':
+    os.environ["CUDA_VISIBLE_DEVICES"] = ''
+    export_frozenPB()