ad78e42289bd7bfece75fa0fc3e7907db38d594a
[ealt-edge.git] / example-apps / PDD / pcb-defect-detection / libs / export_pbs / test_exportPb.py
1 # -*- coding:utf-8 -*-
2
3 from __future__ import absolute_import
4 from __future__ import print_function
5 from __future__ import division
6
7 import os, sys
8 import tensorflow as tf
9 import time
10 import cv2
11 import argparse
12 import numpy as np
13 sys.path.append("../")
14
15 from data.io.image_preprocess import short_side_resize_for_inference_data
16 from libs.configs import cfgs
17 from libs.networks import build_whole_network
18 from libs.box_utils import draw_box_in_img
19 from help_utils import tools
20
21
22
23
24
def load_graph(frozen_graph_file):
    """Deserialize a frozen GraphDef file and import it into a brand-new tf.Graph.

    Args:
        frozen_graph_file: path to a serialized GraphDef (read in binary mode),
            e.g. the ``.pb`` written when a model is frozen.

    Returns:
        A fresh ``tf.Graph`` containing the imported ops. Because the import uses
        ``name=""``, ops keep their original names (no ``import/`` prefix), so
        callers can fetch tensors directly by their names.
    """
    graph_def = tf.GraphDef()
    with tf.gfile.GFile(frozen_graph_file, 'rb') as pb_file:
        serialized = pb_file.read()
    graph_def.ParseFromString(serialized)

    graph = tf.Graph()
    with graph.as_default():
        # name="" suppresses TensorFlow's default "import/" scope prefix.
        tf.import_graph_def(graph_def, name="")
    return graph
42
43
def test(frozen_graph_path, test_dir, score_threshold=0.5, save_dir=None):
    """Run a frozen detection graph on every image in ``test_dir`` and save drawn results.

    Args:
        frozen_graph_path: path to the frozen ``.pb`` graph (loaded via ``load_graph``).
            The graph must expose an ``input_img:0`` input tensor and a
            ``DetResults:0`` output tensor.
        test_dir: directory whose entries are fed to the network one by one, in
            sorted order. Entries OpenCV cannot decode as images (sub-directories,
            text files, ...) are skipped instead of aborting the whole run.
        score_threshold: detections whose score (column 1 of ``DetResults``) is
            below this value are not drawn. Defaults to 0.5, the previously
            hard-coded value.
        save_dir: directory for the annotated images. ``None`` (the default) keeps
            the original behaviour of writing each result under its bare file
            name, i.e. into the current working directory. Created if missing.
    """
    graph = load_graph(frozen_graph_path)
    print("we are testing ====>>>>", frozen_graph_path)

    img = graph.get_tensor_by_name("input_img:0")
    dets = graph.get_tensor_by_name("DetResults:0")

    if save_dir is not None and not os.path.isdir(save_dir):
        os.makedirs(save_dir)

    with tf.Session(graph=graph) as sess:
        # sorted() gives a deterministic processing order; os.listdir's order is arbitrary.
        for img_path in sorted(os.listdir(test_dir)):
            bgr_img = cv2.imread(os.path.join(test_dir, img_path))
            if bgr_img is None:
                # cv2.imread signals "cannot read/decode" by returning None, not by raising.
                print("skip %s: not a readable image" % img_path)
                continue
            # OpenCV loads images as BGR; reverse the channel axis before feeding the net.
            a_img = bgr_img[:, :, ::-1]
            st = time.time()
            dets_val = sess.run(dets, feed_dict={img: a_img})

            # Rows are used below as [label, score, box coords...]; keep confident ones.
            dets_val = dets_val[dets_val[:, 1] >= score_threshold]
            final_detections = draw_box_in_img.draw_boxes_with_label_and_scores(
                a_img,
                boxes=dets_val[:, 2:],
                labels=dets_val[:, 0],
                scores=dets_val[:, 1])

            out_path = img_path if save_dir is None else os.path.join(save_dir, img_path)
            # Reverse the channels again (mirroring the input conversion) for cv2.imwrite.
            cv2.imwrite(out_path, final_detections[:, :, ::-1])
            # The reported time covers inference, filtering, drawing and writing.
            print("%s cost time: %f" % (img_path, time.time() - st))
67
if __name__ == '__main__':
    # The defaults are the original author's machine-specific paths; pass
    # --pb / --test_dir to run the script anywhere else.
    parser = argparse.ArgumentParser(
        description='Run a frozen detection graph (.pb) over a directory of images.')
    parser.add_argument(
        '--pb', dest='frozen_graph_path',
        default='/home/yjr/PycharmProjects/Faster-RCNN_Tensorflow/output/Pbs/FasterRCNN_Res101_Pascal_Frozen.pb',
        help='path to the frozen graph file')
    parser.add_argument(
        '--test_dir',
        default='/home/yjr/PycharmProjects/Faster-RCNN_Tensorflow/tools/demos',
        help='directory containing the images to run detection on')
    args = parser.parse_args()
    test(args.frozen_graph_path, args.test_dir)
71
72
73
74
75
76
77
78
79
80
81