+++ /dev/null
-# -*- coding:utf-8 -*-
-
-from __future__ import absolute_import
-from __future__ import print_function
-from __future__ import division
-
-import os, sys
-import tensorflow as tf
-import time
-import cv2
-import argparse
-import numpy as np
-sys.path.append("../")
-
-from data.io.image_preprocess import short_side_resize_for_inference_data
-from libs.configs import cfgs
-from libs.networks import build_whole_network
-from libs.box_utils import draw_box_in_img
-from help_utils import tools
-
-
-def detect(det_net, inference_save_path, real_test_imgname_list):
-
- # 1. preprocess img
- img_plac = tf.placeholder(dtype=tf.uint8, shape=[None, None, 3]) # is RGB. not GBR
- img_batch = tf.cast(img_plac, tf.float32)
- img_batch = short_side_resize_for_inference_data(img_tensor=img_batch,
- target_shortside_len=cfgs.IMG_SHORT_SIDE_LEN,
- length_limitation=cfgs.IMG_MAX_LENGTH)
- img_batch = img_batch - tf.constant(cfgs.PIXEL_MEAN)
- img_batch = tf.expand_dims(img_batch, axis=0) # [1, None, None, 3]
-
- detection_boxes, detection_scores, detection_category = det_net.build_whole_detection_network(
- input_img_batch=img_batch,
- gtboxes_batch=None)
-
- init_op = tf.group(
- tf.global_variables_initializer(),
- tf.local_variables_initializer()
- )
-
- restorer, restore_ckpt = det_net.get_restorer()
-
- config = tf.ConfigProto()
- config.gpu_options.allow_growth = True
-
- with tf.Session(config=config) as sess:
- sess.run(init_op)
- if not restorer is None:
- restorer.restore(sess, restore_ckpt)
- print('restore model')
-
- for i, a_img_name in enumerate(real_test_imgname_list):
-
- raw_img = cv2.imread(a_img_name)
- start = time.time()
- resized_img, detected_boxes, detected_scores, detected_categories = \
- sess.run(
- [img_batch, detection_boxes, detection_scores, detection_category],
- feed_dict={img_plac: raw_img[:, :, ::-1]} # cv is BGR. But need RGB
- )
- end = time.time()
- # print("{} cost time : {} ".format(img_name, (end - start)))
-
- raw_h, raw_w = raw_img.shape[0], raw_img.shape[1]
-
- xmin, ymin, xmax, ymax = detected_boxes[:, 0], detected_boxes[:, 1], \
- detected_boxes[:, 2], detected_boxes[:, 3]
-
- resized_h, resized_w = resized_img.shape[1], resized_img.shape[2]
-
- xmin = xmin * raw_w / resized_w
- xmax = xmax * raw_w / resized_w
-
- ymin = ymin * raw_h / resized_h
- ymax = ymax * raw_h / resized_h
-
- detected_boxes = np.transpose(np.stack([xmin, ymin, xmax, ymax]))
-
- show_indices = detected_scores >= cfgs.SHOW_SCORE_THRSHOLD
- show_scores = detected_scores[show_indices]
- show_boxes = detected_boxes[show_indices]
- show_categories = detected_categories[show_indices]
- final_detections = draw_box_in_img.draw_boxes_with_label_and_scores(raw_img - np.array(cfgs.PIXEL_MEAN),
- boxes=show_boxes,
- labels=show_categories,
- scores=show_scores)
- nake_name = a_img_name.split('/')[-1]
- # print (inference_save_path + '/' + nake_name)
- cv2.imwrite(inference_save_path + '/' + nake_name,
- final_detections[:, :, ::-1])
-
- tools.view_bar('{} image cost {}s'.format(a_img_name, (end - start)), i + 1, len(real_test_imgname_list))
-
-
-def test(test_dir, inference_save_path):
-
- test_imgname_list = [os.path.join(test_dir, img_name) for img_name in os.listdir(test_dir)
- if img_name.endswith(('.jpg', '.png', '.jpeg', '.tif', '.tiff'))]
- assert len(test_imgname_list) != 0, 'test_dir has no imgs there.' \
- ' Note that, we only support img format of (.jpg, .png, and .tiff) '
-
- faster_rcnn = build_whole_network.DetectionNetwork(base_network_name=cfgs.NET_NAME,
- is_training=False)
- detect(det_net=faster_rcnn, inference_save_path=inference_save_path, real_test_imgname_list=test_imgname_list)
-
-
-def parse_args():
- """
- Parse input arguments
- """
- parser = argparse.ArgumentParser(description='TestImgs...U need provide the test dir')
- parser.add_argument('--data_dir', dest='data_dir',
- help='data path',
- default='demos', type=str)
- parser.add_argument('--save_dir', dest='save_dir',
- help='demo imgs to save',
- default='inference_results', type=str)
- parser.add_argument('--GPU', dest='GPU',
- help='gpu id ',
- default='0', type=str)
-
- if len(sys.argv) == 1:
- parser.print_help()
- sys.exit(1)
-
- args = parser.parse_args()
-
- return args
-if __name__ == '__main__':
-
- args = parse_args()
- print('Called with args:')
- print(args)
- os.environ["CUDA_VISIBLE_DEVICES"] = args.GPU
- test(args.data_dir,
- inference_save_path=args.save_dir)
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-