PCB defect detection application
[ealt-edge.git] / example-apps / PDD / pcb-defect-detection / tools / eval.py
1 # -*- coding:utf-8 -*-
2
3 from __future__ import absolute_import
4 from __future__ import print_function
5 from __future__ import division
6
7 import os, sys
8 import tensorflow as tf
9 import time
10 import cv2
11 import pickle
12 import numpy as np
13 sys.path.append("../")
14
15 from data.io.image_preprocess import short_side_resize_for_inference_data
16 from libs.configs import cfgs
17 from libs.networks import build_whole_network
18 from libs.val_libs import voc_eval
19 from libs.box_utils import draw_box_in_img
20 import argparse
21 from help_utils import tools
22
23
def eval_with_plac(det_net, real_test_imgname_list, img_root, draw_imgs=False):
    """Run the detection network on every listed image and collect detections.

    The inference graph is built once around a uint8 image placeholder; each
    image is read with OpenCV, fed through the graph, and its boxes are scaled
    from the resized network input back to the original image size.

    Args:
        det_net: detection network object providing
            ``build_whole_detection_network`` and ``get_restorer``.
        real_test_imgname_list: image file names (relative to ``img_root``).
        img_root: directory that contains the images.
        draw_imgs: when True, detections scoring at least
            ``cfgs.SHOW_SCORE_THRSHOLD`` are drawn and written as
            ``<name>.jpg`` under ``cfgs.TEST_SAVE_PATH``.

    Returns:
        A list with one array per input image, in input order. Each row is
        ``[category, score, xmin, ymin, xmax, ymax]`` with the box expressed
        in original-image pixel coordinates.

    Raises:
        IOError: if an image cannot be read by ``cv2.imread``.
    """

    # 1. preprocess img
    img_plac = tf.placeholder(dtype=tf.uint8, shape=[None, None, 3])  # is RGB. not BGR
    img_batch = tf.cast(img_plac, tf.float32)

    img_batch = short_side_resize_for_inference_data(img_tensor=img_batch,
                                                     target_shortside_len=cfgs.IMG_SHORT_SIDE_LEN,
                                                     length_limitation=cfgs.IMG_MAX_LENGTH)
    img_batch = img_batch - tf.constant(cfgs.PIXEL_MEAN)
    img_batch = tf.expand_dims(img_batch, axis=0)

    # 2. build the inference graph (no ground-truth boxes at test time)
    detection_boxes, detection_scores, detection_category = det_net.build_whole_detection_network(
        input_img_batch=img_batch,
        gtboxes_batch=None)

    init_op = tf.group(
        tf.global_variables_initializer(),
        tf.local_variables_initializer()
    )

    restorer, restore_ckpt = det_net.get_restorer()

    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True

    compute_time = 0
    compute_imgnum = 0

    # Create the output directory once instead of checking it for every image.
    if draw_imgs and not os.path.exists(cfgs.TEST_SAVE_PATH):
        os.makedirs(cfgs.TEST_SAVE_PATH)

    with tf.Session(config=config) as sess:
        sess.run(init_op)
        if restorer is not None:
            restorer.restore(sess, restore_ckpt)
            print('restore model')

        all_boxes = []
        total_imgs = len(real_test_imgname_list)
        for i, a_img_name in enumerate(real_test_imgname_list):

            img_path = os.path.join(img_root, a_img_name)
            raw_img = cv2.imread(img_path)
            if raw_img is None:
                # cv2.imread signals failure by returning None rather than raising;
                # fail loudly so results stay aligned with the image list.
                raise IOError('failed to read image: {}'.format(img_path))
            raw_h, raw_w = raw_img.shape[0], raw_img.shape[1]

            # Only the session run is timed, not image decoding or drawing.
            start = time.time()
            resized_img, detected_boxes, detected_scores, detected_categories = \
                sess.run(
                    [img_batch, detection_boxes, detection_scores, detection_category],
                    feed_dict={img_plac: raw_img[:, :, ::-1]}  # cv is BGR. But need RGB
                )
            end = time.time()
            compute_time = compute_time + (end - start)
            compute_imgnum = compute_imgnum + 1

            if draw_imgs:
                show_indices = detected_scores >= cfgs.SHOW_SCORE_THRSHOLD
                show_scores = detected_scores[show_indices]
                show_boxes = detected_boxes[show_indices]
                show_categories = detected_categories[show_indices]
                final_detections = draw_box_in_img.draw_boxes_with_label_and_scores(np.squeeze(resized_img, 0),
                                                                                    boxes=show_boxes,
                                                                                    labels=show_categories,
                                                                                    scores=show_scores)
                # Channel order is flipped again before handing the image to OpenCV.
                cv2.imwrite(cfgs.TEST_SAVE_PATH + '/' + a_img_name + '.jpg',
                            final_detections[:, :, ::-1])

            xmin, ymin, xmax, ymax = detected_boxes[:, 0], detected_boxes[:, 1], \
                                     detected_boxes[:, 2], detected_boxes[:, 3]

            # resized_img carries a leading batch axis, so height/width are axes 1 and 2.
            resized_h, resized_w = resized_img.shape[1], resized_img.shape[2]

            # Map boxes from the resized network input back to the raw image size.
            xmin = xmin * raw_w / resized_w
            xmax = xmax * raw_w / resized_w

            ymin = ymin * raw_h / resized_h
            ymax = ymax * raw_h / resized_h

            boxes = np.transpose(np.stack([xmin, ymin, xmax, ymax]))
            dets = np.hstack((detected_categories.reshape(-1, 1),
                              detected_scores.reshape(-1, 1),
                              boxes))
            all_boxes.append(dets)

            tools.view_bar('{} image cost {}s'.format(a_img_name, (end - start)), i + 1, total_imgs)

        if compute_imgnum > 0:
            print('\n average_training_time_per_image is' + str(compute_time / compute_imgnum))
        else:
            # Nothing was processed; avoid dividing by zero.
            print('\n no images were evaluated')
        return all_boxes
116
117
def eval(num_imgs, eval_dir, annotation_dir, showbox):
    """Evaluate the detector on the images found in ``eval_dir``.

    Args:
        num_imgs: maximum number of images to evaluate; pass ``np.inf`` to
            evaluate every image in the directory.
        eval_dir: directory holding the test images.
        annotation_dir: directory holding the annotations, forwarded to
            ``voc_eval.voc_evaluate_detections`` as ``test_annotation_path``.
        showbox: when True, detection results are drawn and saved during
            evaluation (forwarded as ``draw_imgs``).
    """

    image_suffixes = ('.jpg', '.jpeg', '.png', '.tif', '.tiff')
    # os.listdir returns entries in arbitrary order; sort so that a limited
    # run (num_imgs < total) always evaluates the same subset.
    test_imgname_list = sorted(item for item in os.listdir(eval_dir)
                               if item.endswith(image_suffixes))
    if num_imgs == np.inf:
        real_test_imgname_list = test_imgname_list
    else:
        real_test_imgname_list = test_imgname_list[: num_imgs]

    faster_rcnn = build_whole_network.DetectionNetwork(base_network_name=cfgs.NET_NAME,
                                                       is_training=False)
    all_boxes = eval_with_plac(det_net=faster_rcnn,
                               real_test_imgname_list=real_test_imgname_list,
                               img_root=eval_dir,
                               draw_imgs=showbox)

    voc_eval.voc_evaluate_detections(all_boxes=all_boxes,
                                     test_annotation_path=annotation_dir,
                                     test_imgid_list=real_test_imgname_list)
148
def _str2bool(value):
    """Convert a command-line token such as 'True', 'false', '1' or '0' to a bool."""
    if isinstance(value, bool):
        return value
    token = value.strip().lower()
    if token in ('true', 't', 'yes', 'y', '1'):
        return True
    if token in ('false', 'f', 'no', 'n', '0', ''):
        return False
    raise argparse.ArgumentTypeError('boolean value expected, got {!r}'.format(value))


def parse_args():
    """Parse the command-line options of the evaluation script.

    Returns:
        argparse.Namespace with the attributes ``eval_imgs``,
        ``test_annotation_dir``, ``showbox``, ``GPU`` and ``eval_num``.
    """

    # The text is a description, not the program name, so pass it by keyword
    # (the first positional parameter of ArgumentParser is ``prog``).
    parser = argparse.ArgumentParser(description='evaluate the result with Pascal2007 stdand')

    parser.add_argument('--eval_imgs', dest='eval_imgs',
                        help='evaluate imgs dir ',
                        default='../data/pcb_test/JPEGImages', type=str)
    parser.add_argument('--annotation_dir', dest='test_annotation_dir',
                        help='the dir save annotations',
                        default='../data/pcb_test/Annotations', type=str)
    # type=bool would treat every non-empty string (including "False") as True,
    # so the value is parsed explicitly.
    parser.add_argument('--showbox', dest='showbox',
                        help='whether show detecion results when evaluation',
                        default=False, type=_str2bool)
    parser.add_argument('--GPU', dest='GPU',
                        help='gpu id',
                        default='2', type=str)
    parser.add_argument('--eval_num', dest='eval_num',
                        help='the num of eval imgs',
                        default=100, type=int)
    args = parser.parse_args()
    return args
173
174
if __name__ == '__main__':

    # Script entry point: parse the options, echo them, then run the evaluation.
    args = parse_args()
    separator = "--" * 20
    print(separator)
    print(args)
    print(separator)

    # Select the GPU(s) visible to CUDA from the --GPU option.
    os.environ["CUDA_VISIBLE_DEVICES"] = args.GPU

    # np.inf evaluates every image found; an integer such as 10 limits the run.
    # NOTE(review): the parsed --eval_num option is not forwarded here, so it
    # currently has no effect - confirm whether that is intended.
    eval(np.inf,
         eval_dir=args.eval_imgs,
         annotation_dir=args.test_annotation_dir,
         showbox=args.showbox)
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201