EG version upgrade to 1.3
[ealt-edge.git] / example-apps / PDD / pcb-defect-detection / libs / export_pbs / exportPb.py
1 # -*- coding: utf-8 -*-
2
3 from __future__ import absolute_import, print_function, division
4
5 import os, sys
6 import tensorflow as tf
7 import tensorflow.contrib.slim as slim
8 from tensorflow.python.tools import freeze_graph
9
10 sys.path.append('../../')
11 from data.io.image_preprocess import short_side_resize_for_inference_data
12 from libs.configs import cfgs
13 from libs.networks import build_whole_network
14
# Checkpoint whose weights are restored and frozen into the exported graph.
# NOTE(review): this is an absolute, machine-specific path -- adjust it to the
# local checkpoint location before running the export.
CKPT_PATH = (
    "/home/yjr/PycharmProjects/Faster-RCNN_Tensorflow/output/trained_weights/"
    "FasterRCNN_20180517/voc_200000model.ckpt"
)
# Directory (relative to the working directory) that receives the graph files.
OUT_DIR = "../../output/Pbs"
# File name of the graph definition; the frozen graph name is derived from it.
PB_NAME = "FasterRCNN_Res101_Pascal.pb"
18
19
def build_detection_graph():
    """Assemble the single-image inference graph in the default TF graph.

    The graph takes one uint8 image through the placeholder ``input_img``,
    resizes it, subtracts ``cfgs.PIXEL_MEAN``, runs the detection network in
    inference mode and maps the predicted boxes back to the coordinates of
    the original (un-resized) image.

    Returns:
        The tensor named ``DetResults``: one row per detection with the
        columns ``[category, score, xmin, ymin, xmax, ymax]``.
    """
    # --- 1. input and preprocessing -------------------------------------
    # Channel order is expected to be RGB, not BGR.
    image_in = tf.placeholder(dtype=tf.uint8, shape=[None, None, 3],
                              name='input_img')
    input_shape = tf.shape(image_in)
    raw_h = tf.to_float(input_shape[0])
    raw_w = tf.to_float(input_shape[1])

    net_input = tf.cast(image_in, tf.float32)
    net_input = short_side_resize_for_inference_data(
        img_tensor=net_input,
        target_shortside_len=cfgs.IMG_SHORT_SIDE_LEN,
        length_limitation=cfgs.IMG_MAX_LENGTH)
    net_input = net_input - tf.constant(cfgs.PIXEL_MEAN)
    # Add the leading batch axis: [1, None, None, 3].
    net_input = tf.expand_dims(net_input, axis=0)

    # --- 2. detection network (inference mode, no ground-truth boxes) ----
    detector = build_whole_network.DetectionNetwork(
        base_network_name=cfgs.NET_NAME,
        is_training=False)
    network_outputs = detector.build_whole_detection_network(
        input_img_batch=net_input,
        gtboxes_batch=None)
    detected_boxes, detection_scores, detection_category = network_outputs

    # --- 3. rescale boxes from the resized image back to the raw image ---
    xmin, ymin, xmax, ymax = [detected_boxes[:, col] for col in range(4)]

    net_shape = tf.shape(net_input)
    resized_h = tf.to_float(net_shape[1])
    resized_w = tf.to_float(net_shape[2])

    xmin = xmin * raw_w / resized_w
    xmax = xmax * raw_w / resized_w

    ymin = ymin * raw_h / resized_h
    ymax = ymax * raw_h / resized_h

    # --- 4. pack everything into one output tensor -----------------------
    # Stacking yields [4, N]; the transpose turns it into [N, 4].
    rescaled_boxes = tf.transpose(tf.stack([xmin, ymin, xmax, ymax]))
    category_column = tf.reshape(detection_category, [-1, 1])
    score_column = tf.reshape(detection_scores, [-1, 1])

    return tf.concat([category_column, score_column, rescaled_boxes],
                     axis=1, name='DetResults')
59
60
def export_frozenPB():
    """Build the inference graph, restore ``CKPT_PATH`` and export a frozen graph.

    Two files are written into ``OUT_DIR``:

    * ``PB_NAME`` -- the graph definition, written by ``tf.train.write_graph``
      in its default text format;
    * ``<stem>_Frozen<ext>`` (derived from ``PB_NAME``) -- the output of
      ``freeze_graph`` using the weights from ``CKPT_PATH`` and ``DetResults``
      as the only output node.
    """
    tf.reset_default_graph()

    # Called for its side effect of populating the default graph; the output
    # node is referenced below by its name ("DetResults"), not by the tensor.
    build_detection_graph()

    saver = tf.train.Saver()

    graph_path = os.path.join(OUT_DIR, PB_NAME)
    # Split on the real extension so the frozen file can never resolve to the
    # same path as the input graph (str.replace leaves a name without '.pb'
    # untouched and rewrites every '.pb' occurrence inside the name).
    stem, ext = os.path.splitext(PB_NAME)
    frozen_path = os.path.join(OUT_DIR, stem + '_Frozen' + ext)

    with tf.Session() as sess:
        saver.restore(sess, CKPT_PATH)
        # Report success only after the restore has actually happened.
        print("we have restred the weights from =====>>\n", CKPT_PATH)

        tf.train.write_graph(sess.graph_def, OUT_DIR, PB_NAME)
        # input_binary=False because the graph above was written as text.
        freeze_graph.freeze_graph(input_graph=graph_path,
                                  input_saver='',
                                  input_binary=False,
                                  input_checkpoint=CKPT_PATH,
                                  output_node_names="DetResults",
                                  restore_op_name="save/restore_all",
                                  filename_tensor_name='save/Const:0',
                                  output_graph=frozen_path,
                                  clear_devices=False,
                                  initializer_nodes='')
84
if __name__ == '__main__':
    # An empty CUDA_VISIBLE_DEVICES hides all GPUs from CUDA, so the export
    # presumably runs on CPU only.
    os.environ.update({"CUDA_VISIBLE_DEVICES": ''})
    export_frozenPB()