X-Git-Url: https://gerrit.akraino.org/r/gitweb?a=blobdiff_plain;f=example-apps%2FPDD%2Fpcb-defect-detection%2Ftools%2Ftest.py;fp=example-apps%2FPDD%2Fpcb-defect-detection%2Ftools%2Ftest.py;h=0000000000000000000000000000000000000000;hb=3ed2c61d9d7e7916481650c41bfe5604f7db22e9;hp=86984d9d3a90c22e15895d1e939796a2f5928262;hpb=e6d40ddb2640f434a9d7d7ed99566e5e8fa60cc1;p=ealt-edge.git diff --git a/example-apps/PDD/pcb-defect-detection/tools/test.py b/example-apps/PDD/pcb-defect-detection/tools/test.py deleted file mode 100755 index 86984d9..0000000 --- a/example-apps/PDD/pcb-defect-detection/tools/test.py +++ /dev/null @@ -1,153 +0,0 @@ -# -*- coding:utf-8 -*- - -from __future__ import absolute_import -from __future__ import print_function -from __future__ import division - -import os, sys -import tensorflow as tf -import time -import cv2 -import argparse -import numpy as np -sys.path.append("../") - -from data.io.image_preprocess import short_side_resize_for_inference_data -from libs.configs import cfgs -from libs.networks import build_whole_network -from libs.box_utils import draw_box_in_img -from help_utils import tools - - -def detect(det_net, inference_save_path, real_test_imgname_list): - - # 1. preprocess img - img_plac = tf.placeholder(dtype=tf.uint8, shape=[None, None, 3]) # is RGB. not GBR - img_batch = tf.cast(img_plac, tf.float32) - img_batch = short_side_resize_for_inference_data(img_tensor=img_batch, - target_shortside_len=cfgs.IMG_SHORT_SIDE_LEN, - length_limitation=cfgs.IMG_MAX_LENGTH) - img_batch = img_batch - tf.constant(cfgs.PIXEL_MEAN) - img_batch = tf.expand_dims(img_batch, axis=0) # [1, None, None, 3] - - detection_boxes, detection_scores, detection_category = det_net.build_whole_detection_network( - input_img_batch=img_batch, - gtboxes_batch=None) - - init_op = tf.group( - tf.global_variables_initializer(), - tf.local_variables_initializer() - ) - - restorer, restore_ckpt = det_net.get_restorer() - - config = tf.ConfigProto() - config.gpu_options.allow_growth = True - - with tf.Session(config=config) as sess: - sess.run(init_op) - if not restorer is None: - restorer.restore(sess, restore_ckpt) - print('restore model') - - for i, a_img_name in enumerate(real_test_imgname_list): - - raw_img = cv2.imread(a_img_name) - start = time.time() - resized_img, detected_boxes, detected_scores, detected_categories = \ - sess.run( - [img_batch, detection_boxes, detection_scores, detection_category], - feed_dict={img_plac: raw_img[:, :, ::-1]} # cv is BGR. But need RGB - ) - end = time.time() - # print("{} cost time : {} ".format(img_name, (end - start))) - - raw_h, raw_w = raw_img.shape[0], raw_img.shape[1] - - xmin, ymin, xmax, ymax = detected_boxes[:, 0], detected_boxes[:, 1], \ - detected_boxes[:, 2], detected_boxes[:, 3] - - resized_h, resized_w = resized_img.shape[1], resized_img.shape[2] - - xmin = xmin * raw_w / resized_w - xmax = xmax * raw_w / resized_w - - ymin = ymin * raw_h / resized_h - ymax = ymax * raw_h / resized_h - - detected_boxes = np.transpose(np.stack([xmin, ymin, xmax, ymax])) - - show_indices = detected_scores >= cfgs.SHOW_SCORE_THRSHOLD - show_scores = detected_scores[show_indices] - show_boxes = detected_boxes[show_indices] - show_categories = detected_categories[show_indices] - final_detections = draw_box_in_img.draw_boxes_with_label_and_scores(raw_img - np.array(cfgs.PIXEL_MEAN), - boxes=show_boxes, - labels=show_categories, - scores=show_scores) - nake_name = a_img_name.split('/')[-1] - # print (inference_save_path + '/' + nake_name) - cv2.imwrite(inference_save_path + '/' + nake_name, - final_detections[:, :, ::-1]) - - tools.view_bar('{} image cost {}s'.format(a_img_name, (end - start)), i + 1, len(real_test_imgname_list)) - - -def test(test_dir, inference_save_path): - - test_imgname_list = [os.path.join(test_dir, img_name) for img_name in os.listdir(test_dir) - if img_name.endswith(('.jpg', '.png', '.jpeg', '.tif', '.tiff'))] - assert len(test_imgname_list) != 0, 'test_dir has no imgs there.' \ - ' Note that, we only support img format of (.jpg, .png, and .tiff) ' - - faster_rcnn = build_whole_network.DetectionNetwork(base_network_name=cfgs.NET_NAME, - is_training=False) - detect(det_net=faster_rcnn, inference_save_path=inference_save_path, real_test_imgname_list=test_imgname_list) - - -def parse_args(): - """ - Parse input arguments - """ - parser = argparse.ArgumentParser(description='TestImgs...U need provide the test dir') - parser.add_argument('--data_dir', dest='data_dir', - help='data path', - default='demos', type=str) - parser.add_argument('--save_dir', dest='save_dir', - help='demo imgs to save', - default='inference_results', type=str) - parser.add_argument('--GPU', dest='GPU', - help='gpu id ', - default='0', type=str) - - if len(sys.argv) == 1: - parser.print_help() - sys.exit(1) - - args = parser.parse_args() - - return args -if __name__ == '__main__': - - args = parse_args() - print('Called with args:') - print(args) - os.environ["CUDA_VISIBLE_DEVICES"] = args.GPU - test(args.data_dir, - inference_save_path=args.save_dir) - - - - - - - - - - - - - - - -