EG version upgrade to 1.3
[ealt-edge.git] / example-apps / PDD / pcb-defect-detection / tools / test.py
1 # -*- coding:utf-8 -*-
2
3 from __future__ import absolute_import
4 from __future__ import print_function
5 from __future__ import division
6
7 import os, sys
8 import tensorflow as tf
9 import time
10 import cv2
11 import argparse
12 import numpy as np
13 sys.path.append("../")
14
15 from data.io.image_preprocess import short_side_resize_for_inference_data
16 from libs.configs import cfgs
17 from libs.networks import build_whole_network
18 from libs.box_utils import draw_box_in_img
19 from help_utils import tools
20
21
def _rescale_boxes_to_raw(boxes, resized_hw, raw_hw):
    """Map boxes from resized-image pixels back to raw-image pixels.

    :param boxes: array indexed as boxes[:, 0..3] = xmin, ymin, xmax, ymax,
                  expressed in the coordinates of the resized image.
    :param resized_hw: (height, width) of the image fed through the network.
    :param raw_hw: (height, width) of the image as read from disk.
    :return: array of shape [N, 4] in the same column order, in raw-image pixels.
    """
    resized_h, resized_w = resized_hw
    raw_h, raw_w = raw_hw

    xmin = boxes[:, 0] * raw_w / resized_w
    xmax = boxes[:, 2] * raw_w / resized_w
    ymin = boxes[:, 1] * raw_h / resized_h
    ymax = boxes[:, 3] * raw_h / resized_h

    return np.stack([xmin, ymin, xmax, ymax], axis=1)


def detect(det_net, inference_save_path, real_test_imgname_list):
    """Run the detector on every listed image and save an annotated copy.

    The inference graph is built once; weights are restored when det_net
    supplies a restorer. Each image is read with OpenCV, fed to the graph as
    RGB, and the boxes scoring at least cfgs.SHOW_SCORE_THRSHOLD are drawn
    and written to inference_save_path under the image's base name.

    :param det_net: object exposing build_whole_detection_network() and get_restorer().
    :param inference_save_path: directory for annotated images; created if missing.
    :param real_test_imgname_list: paths of the images to process.
    """

    # 1. preprocess img
    img_plac = tf.placeholder(dtype=tf.uint8, shape=[None, None, 3])  # is RGB, not BGR
    img_batch = tf.cast(img_plac, tf.float32)
    img_batch = short_side_resize_for_inference_data(img_tensor=img_batch,
                                                     target_shortside_len=cfgs.IMG_SHORT_SIDE_LEN,
                                                     length_limitation=cfgs.IMG_MAX_LENGTH)
    img_batch = img_batch - tf.constant(cfgs.PIXEL_MEAN)
    img_batch = tf.expand_dims(img_batch, axis=0)  # [1, None, None, 3]

    # 2. build the detection graph on top of the preprocessed batch
    detection_boxes, detection_scores, detection_category = det_net.build_whole_detection_network(
        input_img_batch=img_batch,
        gtboxes_batch=None)

    init_op = tf.group(
        tf.global_variables_initializer(),
        tf.local_variables_initializer()
    )

    restorer, restore_ckpt = det_net.get_restorer()

    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True

    # cv2.imwrite does not create missing directories, so make sure the
    # target exists before any result is written.
    if inference_save_path and not os.path.isdir(inference_save_path):
        os.makedirs(inference_save_path)

    # loop invariants
    pixel_mean = np.array(cfgs.PIXEL_MEAN)
    img_total = len(real_test_imgname_list)

    with tf.Session(config=config) as sess:
        sess.run(init_op)
        if restorer is not None:
            restorer.restore(sess, restore_ckpt)
            print('restore model')

        for i, a_img_name in enumerate(real_test_imgname_list):

            raw_img = cv2.imread(a_img_name)
            if raw_img is None:
                # cv2.imread reports an unreadable/corrupt file by returning None
                print('skip {}: cannot be read as an image'.format(a_img_name))
                continue

            start = time.time()
            resized_img, detected_boxes, detected_scores, detected_categories = \
                sess.run(
                    [img_batch, detection_boxes, detection_scores, detection_category],
                    feed_dict={img_plac: raw_img[:, :, ::-1]}  # cv is BGR. But need RGB
                )
            end = time.time()

            # resized_img is [1, H, W, 3]; boxes come back in its coordinates
            detected_boxes = _rescale_boxes_to_raw(detected_boxes,
                                                   resized_hw=resized_img.shape[1:3],
                                                   raw_hw=raw_img.shape[:2])

            show_indices = detected_scores >= cfgs.SHOW_SCORE_THRSHOLD
            show_scores = detected_scores[show_indices]
            show_boxes = detected_boxes[show_indices]
            show_categories = detected_categories[show_indices]

            # NOTE(review): raw_img is still BGR here while the graph subtracts
            # PIXEL_MEAN from RGB input, and the result is channel-flipped again
            # before saving. Whether the saved colours are right depends on what
            # draw_boxes_with_label_and_scores does with the mean -- TODO confirm.
            final_detections = draw_box_in_img.draw_boxes_with_label_and_scores(raw_img - pixel_mean,
                                                                                boxes=show_boxes,
                                                                                labels=show_categories,
                                                                                scores=show_scores)

            save_name = os.path.join(inference_save_path, os.path.basename(a_img_name))
            if not cv2.imwrite(save_name, final_detections[:, :, ::-1]):
                # imwrite signals failure through its return value only
                print('failed to write {}'.format(save_name))

            tools.view_bar('{} image cost {}s'.format(a_img_name, (end - start)), i + 1, img_total)
94
95
def test(test_dir, inference_save_path):
    """Collect the images in test_dir and run detection on them.

    :param test_dir: directory whose direct entries are scanned for image files.
    :param inference_save_path: directory handed to detect() for the annotated results.
    :raises AssertionError: if test_dir holds no file with a supported extension.
    """
    supported_exts = ('.jpg', '.png', '.jpeg', '.tif', '.tiff')

    # os.listdir order is arbitrary, so sort for a reproducible run order;
    # compare lower-cased names so '.JPG', '.PNG', ... are accepted too.
    test_imgname_list = [os.path.join(test_dir, img_name) for img_name in sorted(os.listdir(test_dir))
                         if img_name.lower().endswith(supported_exts)]

    if not test_imgname_list:
        # Raised explicitly (instead of `assert`) so the check survives `python -O`
        # while callers still see the same exception type.
        raise AssertionError('test_dir has no imgs there.'
                             ' Note that, we only support img format of'
                             ' (.jpg, .jpeg, .png, .tif, and .tiff) ')

    faster_rcnn = build_whole_network.DetectionNetwork(base_network_name=cfgs.NET_NAME,
                                                       is_training=False)
    detect(det_net=faster_rcnn, inference_save_path=inference_save_path, real_test_imgname_list=test_imgname_list)
106
107
def parse_args():
    """
    Read the command-line options of the test script.

    Returns the argparse namespace with `data_dir`, `save_dir` and `GPU`
    (all strings). When the script is started without any argument the
    usage text is printed and the process exits with status 1.
    """
    cli = argparse.ArgumentParser(description='TestImgs...U need provide the test dir')

    # (flag, destination, help text, default value)
    option_table = (
        ('--data_dir', 'data_dir', 'data path', 'demos'),
        ('--save_dir', 'save_dir', 'demo imgs to save', 'inference_results'),
        ('--GPU', 'GPU', 'gpu id ', '0'),
    )
    for flag, target, description, fallback in option_table:
        cli.add_argument(flag, dest=target, help=description, default=fallback, type=str)

    if len(sys.argv) != 1:
        return cli.parse_args()

    # bare invocation: show the usage instead of silently running on defaults
    cli.print_help()
    sys.exit(1)
if __name__ == '__main__':
    # Script entry point: echo the parsed options, pin the visible GPU,
    # then run detection over the requested directory.
    cli_args = parse_args()
    print('Called with args:')
    print(cli_args)

    # must be set before TensorFlow creates its session in test()/detect()
    os.environ["CUDA_VISIBLE_DEVICES"] = cli_args.GPU

    test(cli_args.data_dir, inference_save_path=cli_args.save_dir)
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153