X-Git-Url: https://gerrit.akraino.org/r/gitweb?a=blobdiff_plain;f=example-apps%2FPDD%2Fpcb-defect-detection%2Ftools%2Finference.py;fp=example-apps%2FPDD%2Fpcb-defect-detection%2Ftools%2Finference.py;h=e525d46df1b352722059602b51965b072bc631ba;hb=a785567fb9acfc68536767d20f60ba917ae85aa1;hp=0000000000000000000000000000000000000000;hpb=94a133e696b9b2a7f73544462c2714986fa7ab4a;p=ealt-edge.git diff --git a/example-apps/PDD/pcb-defect-detection/tools/inference.py b/example-apps/PDD/pcb-defect-detection/tools/inference.py new file mode 100755 index 0000000..e525d46 --- /dev/null +++ b/example-apps/PDD/pcb-defect-detection/tools/inference.py @@ -0,0 +1,139 @@ +# -*- coding:utf-8 -*- + +from __future__ import absolute_import +from __future__ import print_function +from __future__ import division + +import os, sys +import tensorflow as tf +import time +import cv2 +import argparse +import numpy as np +sys.path.append("../") + +from data.io.image_preprocess import short_side_resize_for_inference_data +from libs.configs import cfgs +from libs.networks import build_whole_network +from libs.box_utils import draw_box_in_img +from help_utils import tools + + +def detect(det_net, inference_save_path, real_test_imgname_list): + + # 1. preprocess img + img_plac = tf.placeholder(dtype=tf.uint8, shape=[None, None, 3]) # is RGB. not GBR + img_batch = tf.cast(img_plac, tf.float32) + img_batch = short_side_resize_for_inference_data(img_tensor=img_batch, + target_shortside_len=cfgs.IMG_SHORT_SIDE_LEN, + length_limitation=cfgs.IMG_MAX_LENGTH) + img_batch = img_batch - tf.constant(cfgs.PIXEL_MEAN) + img_batch = tf.expand_dims(img_batch, axis=0) # [1, None, None, 3] + + detection_boxes, detection_scores, detection_category = det_net.build_whole_detection_network( + input_img_batch=img_batch, + gtboxes_batch=None) + + init_op = tf.group( + tf.global_variables_initializer(), + tf.local_variables_initializer() + ) + + restorer, restore_ckpt = det_net.get_restorer() + + config = tf.ConfigProto() + config.gpu_options.allow_growth = True + + with tf.Session(config=config) as sess: + sess.run(init_op) + if not restorer is None: + restorer.restore(sess, restore_ckpt) + print('restore model') + + for i, a_img_name in enumerate(real_test_imgname_list): + + raw_img = cv2.imread(a_img_name) + start = time.time() + resized_img, detected_boxes, detected_scores, detected_categories = \ + sess.run( + [img_batch, detection_boxes, detection_scores, detection_category], + feed_dict={img_plac: raw_img[:, :, ::-1]} # cv is BGR. But need RGB + ) + end = time.time() + # print("{} cost time : {} ".format(img_name, (end - start))) + + show_indices = detected_scores >= cfgs.SHOW_SCORE_THRSHOLD + show_scores = detected_scores[show_indices] + show_boxes = detected_boxes[show_indices] + show_categories = detected_categories[show_indices] + final_detections = draw_box_in_img.draw_boxes_with_label_and_scores(np.squeeze(resized_img, 0), + boxes=show_boxes, + labels=show_categories, + scores=show_scores) + nake_name = a_img_name.split('/')[-1] + # print (inference_save_path + '/' + nake_name) + cv2.imwrite(inference_save_path + '/' + nake_name, + final_detections[:, :, ::-1]) + + tools.view_bar('{} image cost {}s'.format(a_img_name, (end - start)), i + 1, len(real_test_imgname_list)) + + +def inference(test_dir, inference_save_path): + + test_imgname_list = [os.path.join(test_dir, img_name) for img_name in os.listdir(test_dir) + if img_name.endswith(('.jpg', '.png', '.jpeg', '.tif', '.tiff'))] + assert len(test_imgname_list) != 0, 'test_dir has no imgs there.' \ + ' Note that, we only support img format of (.jpg, .png, and .tiff) ' + + faster_rcnn = build_whole_network.DetectionNetwork(base_network_name=cfgs.NET_NAME, + is_training=False) + detect(det_net=faster_rcnn, inference_save_path=inference_save_path, real_test_imgname_list=test_imgname_list) + + +def parse_args(): + """ + Parse input arguments + """ + parser = argparse.ArgumentParser(description='TestImgs...U need provide the test dir') + + parser.add_argument('--data_dir', dest='data_dir', + help='data path', + default='demos', type=str) + parser.add_argument('--save_dir', dest='save_dir', + help='demo imgs to save', + default='inference_results', type=str) + parser.add_argument('--GPU', dest='GPU', + help='gpu id ', + default='0', type=str) + + if len(sys.argv) == 1: + parser.print_help() + sys.exit(1) + + args = parser.parse_args() + + return args +if __name__ == '__main__': + + args = parse_args() + print('Called with args:') + print(args) + os.environ["CUDA_VISIBLE_DEVICES"] = args.GPU + inference(args.data_dir, + inference_save_path=args.save_dir) + + + + + + + + + + + + + + + +