from __future__ import absolute_import
from __future__ import print_function
from __future__ import division

import os
import sys
import time

import cv2
import tensorflow as tf

sys.path.append("../")

from data.io.image_preprocess import short_side_resize_for_inference_data
from libs.configs import cfgs
from libs.networks import build_whole_network
from libs.box_utils import draw_box_in_img
from help_utils import tools
def load_graph(frozen_graph_file):
    """Load a frozen TensorFlow GraphDef file into a new ``tf.Graph``.

    Args:
        frozen_graph_file: path to a serialized (frozen) ``GraphDef`` file.

    Returns:
        A new ``tf.Graph`` containing the imported nodes. The nodes are
        imported with an empty name scope, so tensors keep their original
        names (e.g. ``"input_img:0"`` rather than ``"import/input_img:0"``).
    """
    # Parse the serialized GraphDef from disk.
    with tf.gfile.GFile(frozen_graph_file, 'rb') as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())

    # Import into a fresh graph (made default only inside this `with`),
    # not the process-wide default graph. name="" disables the default
    # "import/" prefix so callers can look tensors up by their original names.
    with tf.Graph().as_default() as graph:
        tf.import_graph_def(graph_def,
                            input_map=None,
                            return_elements=None,
                            name="",
                            producer_op_list=None)
    return graph
def test(frozen_graph_path, test_dir, score_threshold=0.5, save_dir=None):
    """Run a frozen detection graph over every image in a directory.

    Each readable image in ``test_dir`` is fed to the graph. Detections
    scoring below ``score_threshold`` are dropped, the rest are drawn on the
    image, and the annotated image is written to ``save_dir`` under its
    original file name. The time taken per image is printed.

    Args:
        frozen_graph_path: path to the frozen graph file; it must expose the
            tensors ``input_img:0`` and ``DetResults:0``.
        test_dir: directory whose files are treated as input images; files
            OpenCV cannot decode are skipped.
        score_threshold: minimum score (column 1 of each detection row) for
            a detection to be drawn. Defaults to 0.5, the previous
            hard-coded value.
        save_dir: directory for the annotated images; created if missing.
            ``None`` (default) means the current working directory.
            NOTE(review): the original output path was truncated in this copy
            of the file; confirm this default matches the intended one. If
            the working directory is ``test_dir`` the inputs are overwritten.
    """
    graph = load_graph(frozen_graph_path)
    print("we are testing ====>>>>", frozen_graph_path)

    img = graph.get_tensor_by_name("input_img:0")
    dets = graph.get_tensor_by_name("DetResults:0")

    if save_dir is None:
        save_dir = os.getcwd()
    elif not os.path.isdir(save_dir):
        # cv2.imwrite does not create directories; it just returns False.
        os.makedirs(save_dir)

    with tf.Session(graph=graph) as sess:
        for img_path in sorted(os.listdir(test_dir)):
            bgr_img = cv2.imread(os.path.join(test_dir, img_path))
            if bgr_img is None:
                # cv2.imread returns None (instead of raising) for files it
                # cannot decode, e.g. non-image files in test_dir.
                print("skip %s: not a readable image" % img_path)
                continue
            # OpenCV loads images as BGR; the graph is fed the
            # channel-reversed (RGB) array.
            a_img = bgr_img[:, :, ::-1]

            # Timing covers inference, drawing and saving.
            st = time.time()
            dets_val = sess.run(dets, feed_dict={img: a_img})

            # Each detection row is laid out as [label, score, box...]:
            # column 0 -> labels, column 1 -> scores, columns 2+ -> boxes.
            keep = dets_val[:, 1] >= score_threshold
            dets_val = dets_val[keep]
            final_detections = draw_box_in_img.draw_boxes_with_label_and_scores(
                a_img,
                boxes=dets_val[:, 2:],
                labels=dets_val[:, 0],
                scores=dets_val[:, 1])

            # Reverse the channels back to BGR for cv2.imwrite.
            out_path = os.path.join(save_dir, img_path)
            if not cv2.imwrite(out_path, final_detections[:, :, ::-1]):
                print("failed to write %s" % out_path)
            print("%s cost time: %f" % (img_path, time.time() - st))
if __name__ == '__main__':
    # Command-line entry point. With no arguments it runs on the original
    # author's hard-coded paths (kept as defaults), which will usually need
    # overriding on other machines.
    import argparse

    parser = argparse.ArgumentParser(
        description='Run a frozen detection graph over a directory of images.')
    parser.add_argument(
        '--pb',
        default='/home/yjr/PycharmProjects/Faster-RCNN_Tensorflow/output/Pbs/FasterRCNN_Res101_Pascal_Frozen.pb',
        help='path to the frozen graph (.pb) file')
    parser.add_argument(
        '--test_dir',
        default='/home/yjr/PycharmProjects/Faster-RCNN_Tensorflow/tools/demos',
        help='directory containing the input images')
    args = parser.parse_args()

    test(args.pb, args.test_dir)