EG version upgrade to 1.3
[ealt-edge.git] / example-apps / PDD / pcb-defect-detection / tools / inference_for_coco.py
1 # -*- coding:utf-8 -*-
2
3 from __future__ import absolute_import
4 from __future__ import print_function
5 from __future__ import division
6
7 import os, sys
8 import tensorflow as tf
9 import time
10 import cv2
11 import pickle
12 import numpy as np
13 sys.path.append("../")
14 sys.path.insert(0, '/home/yjr/PycharmProjects/Faster-RCNN_TF/data/lib_coco/PythonAPI')
15 from data.io.image_preprocess import short_side_resize_for_inference_data
16 from libs.configs import cfgs
17 from libs.networks import build_whole_network
18 from libs.val_libs import voc_eval
19 from libs.box_utils import draw_box_in_img
20 from libs.label_name_dict.coco_dict import LABEL_NAME_MAP, classes_originID
21 from help_utils import tools
22 from data.lib_coco.PythonAPI.pycocotools.coco import COCO
23 import json
24
# Restrict CUDA to the GPU(s) configured in cfgs. This is done at import time,
# i.e. before the tf.Session in eval_with_plac is created.
os.environ.update(CUDA_VISIBLE_DEVICES=cfgs.GPU_GROUP)
26
27
def eval_with_plac(det_net, imgId_list, coco, out_json_root, draw_imgs=False,
                   img_root="/home/yjr/DataSet/COCO/2017/test2017"):
    """Run the detector on COCO images and write per-image COCO-format results.

    Builds the inference graph once and restores the checkpoint returned by
    ``det_net.get_restorer()`` (if any). Then, for every image id it:

    * reads the image file (name from ``coco.loadImgs``) from ``img_root``,
    * feeds it as RGB through the resize + mean-subtraction graph and the
      detector,
    * rescales the predicted boxes back to the original image size and
      converts them to COCO ``[x, y, width, height]`` format,
    * writes the detections to ``<out_json_root>/each_img/<imgid>.json``.

    Args:
        det_net: detection network exposing ``build_whole_detection_network``
            and ``get_restorer``.
        imgId_list: COCO image ids to process.
        coco: ``pycocotools`` ``COCO`` object used to resolve file names.
        out_json_root: root directory for per-image JSON files; the
            ``each_img`` sub-directory is created if missing.
        draw_imgs: if True, also save visualisations of detections scoring at
            least ``cfgs.SHOW_SCORE_THRSHOLD`` into ``cfgs.TEST_SAVE_PATH``.
        img_root: directory holding the image files (defaults to the path
            that was previously hard-coded).

    Raises:
        IOError: if OpenCV cannot read an image.
    """
    # 1. Preprocess: uint8 RGB image -> float, short-side resize, mean
    #    subtraction, batch of one.
    img_plac = tf.placeholder(dtype=tf.uint8, shape=[None, None, 3])  # RGB, not BGR
    img_batch = tf.cast(img_plac, tf.float32)
    img_batch = short_side_resize_for_inference_data(img_tensor=img_batch,
                                                     target_shortside_len=cfgs.IMG_SHORT_SIDE_LEN,
                                                     length_limitation=cfgs.IMG_MAX_LENGTH)
    img_batch = img_batch - tf.constant(cfgs.PIXEL_MEAN)
    img_batch = tf.expand_dims(img_batch, axis=0)

    detection_boxes, detection_scores, detection_category = det_net.build_whole_detection_network(
        input_img_batch=img_batch,
        gtboxes_batch=None)

    init_op = tf.group(
        tf.global_variables_initializer(),
        tf.local_variables_initializer()
    )

    restorer, restore_ckpt = det_net.get_restorer()

    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True

    # Ensure output directories exist: open() below fails otherwise, and
    # cv2.imwrite may fail without raising.
    each_img_dir = os.path.join(out_json_root, 'each_img')
    if not os.path.isdir(each_img_dir):
        os.makedirs(each_img_dir)
    if draw_imgs and not os.path.isdir(cfgs.TEST_SAVE_PATH):
        os.makedirs(cfgs.TEST_SAVE_PATH)

    num_total = len(imgId_list)

    with tf.Session(config=config) as sess:
        sess.run(init_op)
        if restorer is not None:
            restorer.restore(sess, restore_ckpt)
            print('restore model')

        for i, imgid in enumerate(imgId_list):
            imgname = coco.loadImgs(ids=[imgid])[0]['file_name']
            img_path = os.path.join(img_root, imgname)
            raw_img = cv2.imread(img_path)
            if raw_img is None:
                # cv2.imread signals a missing/unreadable file by returning None.
                raise IOError('cannot read image: {}'.format(img_path))

            raw_h, raw_w = raw_img.shape[0], raw_img.shape[1]
            start = time.time()
            resized_img, detected_boxes, detected_scores, detected_categories = \
                sess.run(
                    [img_batch, detection_boxes, detection_scores, detection_category],
                    feed_dict={img_plac: raw_img[:, :, ::-1]}  # cv2 loads BGR; the graph expects RGB
                )
            end = time.time()

            if draw_imgs:
                show_indices = detected_scores >= cfgs.SHOW_SCORE_THRSHOLD
                final_detections = draw_box_in_img.draw_boxes_with_label_and_scores(
                    np.squeeze(resized_img, 0),
                    boxes=detected_boxes[show_indices],
                    labels=detected_categories[show_indices],
                    scores=detected_scores[show_indices])
                cv2.imwrite(os.path.join(cfgs.TEST_SAVE_PATH, str(imgid) + '.jpg'),
                            final_detections[:, :, ::-1])

            # Boxes are predicted on the resized batch (shape [1, h, w, c]);
            # map them back to the original image resolution.
            resized_h, resized_w = resized_img.shape[1], resized_img.shape[2]
            xmin = detected_boxes[:, 0] * raw_w / resized_w
            ymin = detected_boxes[:, 1] * raw_h / resized_h
            xmax = detected_boxes[:, 2] * raw_w / resized_w
            ymax = detected_boxes[:, 3] * raw_h / resized_h

            # COCO result format wants [x, y, width, height].
            boxes = np.stack([xmin, ymin, xmax - xmin, ymax - ymin], axis=1)

            a_img_detect_result = []
            for label, score, bbox in zip(detected_categories, detected_scores, boxes):
                # Drop near-zero scores before doing the label lookup.
                if score < 0.00001:
                    continue
                cat_id = classes_originID[LABEL_NAME_MAP[label]]
                a_img_detect_result.append({"image_id": imgid,
                                            "category_id": cat_id,
                                            "bbox": bbox.tolist(),
                                            "score": float(score)})

            with open(os.path.join(each_img_dir, str(imgid) + '.json'), 'w') as f:
                json.dump(a_img_detect_result, f)

            tools.view_bar('{} image cost {}s'.format(imgid, (end - start)), i + 1, num_total)
125
126
def eval(num_imgs,
         annotation_path='/home/yjr/DataSet/COCO/2017/test_annotations/image_info_test-dev2017.json',
         draw_imgs=True):
    """Run COCO inference and merge per-image results into one JSON file.

    ``eval_with_plac`` writes one file per image to
    ``<cfgs.EVALUATE_DIR>/<cfgs.VERSION>/each_img/<imgid>.json``. These are
    then concatenated into
    ``<cfgs.EVALUATE_DIR>/<cfgs.VERSION>/coco2017test_results.json``.

    Args:
        num_imgs: number of images to process, taken from the start of
            ``coco.getImgIds()``; ``np.inf`` processes all of them.
        annotation_path: COCO annotation / image-info file listing the images
            (defaults to the previously hard-coded test-dev 2017 file).
        draw_imgs: forwarded to ``eval_with_plac`` (saves visualisations).
    """
    # Other annotation files previously used with this script:
    #   /home/yjr/DataSet/COCO/2017/test_annotations/image_info_test2017.json
    #   /home/yjr/DataSet/COCO/2017/annotations/instances_train2017.json
    print("load coco .... it will cost about 17s..")
    coco = COCO(annotation_path)

    imgId_list = coco.getImgIds()
    if num_imgs != np.inf:
        # Slicing requires an int; num_imgs may arrive as a float.
        imgId_list = imgId_list[:int(num_imgs)]

    faster_rcnn = build_whole_network.DetectionNetwork(base_network_name=cfgs.NET_NAME,
                                                       is_training=False)
    save_dir = os.path.join(cfgs.EVALUATE_DIR, cfgs.VERSION)
    each_img_dir = os.path.join(save_dir, 'each_img')
    if not os.path.isdir(each_img_dir):
        os.makedirs(each_img_dir)

    eval_with_plac(det_net=faster_rcnn, coco=coco, imgId_list=imgId_list, out_json_root=save_dir,
                   draw_imgs=draw_imgs)
    print("each img over**************")

    # Merge everything first, then write. A missing per-image file then raises
    # before the final result file is opened and truncated.
    final_detections = []
    for imgid in imgId_list:
        with open(os.path.join(each_img_dir, str(imgid) + '.json')) as f:
            final_detections.extend(json.load(f))

    with open(os.path.join(save_dir, 'coco2017test_results.json'), 'w') as wf:
        json.dump(final_detections, wf)
158
159
if __name__ == '__main__':
    # Script entry point: evaluate every image in the annotation file.
    eval(num_imgs=np.inf)
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179