# encoding: utf-8
"""
@author: zeming li
@contact: zengarden2009@gmail.com
"""

from libs.configs import cfgs
from libs.box_utils import encode_and_decode
from libs.box_utils import boxes_utils
import tensorflow as tf
import numpy as np


def postprocess_rpn_proposals(rpn_bbox_pred, rpn_cls_prob, img_shape, anchors, is_training):
    """
    Convert raw RPN outputs into a scored, NMS-filtered set of proposal boxes.

    Pipeline: decode regressions against the anchors, clip the boxes to the
    image, optionally keep only the highest-scoring candidates, then run
    non-maximum suppression.

    :param rpn_bbox_pred: [-1, 4], encoded box regressions
    :param rpn_cls_prob: [-1, 2], class probabilities; column 1 is used as the
        proposal score (presumably the foreground probability -- confirm
        against the RPN head)
    :param img_shape: forwarded unchanged to
        boxes_utils.clip_boxes_to_img_boundaries
    :param anchors: [-1, 4], reference boxes used to decode rpn_bbox_pred
    :param is_training: truthy selects the *_TRAIN budgets from cfgs,
        falsy selects the *_TEST budgets
    :return: (final_boxes, final_probs) -- the boxes selected by NMS and
        their matching scores
    """
    # Candidate budgets before/after NMS depend on the run mode.
    # NOTE: the original also carried a commented-out alternative that used
    # cfgs.FPN_TOP_K_PER_LEVEL_TRAIN / cfgs.FPN_TOP_K_PER_LEVEL_TEST for both
    # budgets; it is not active.
    if is_training:
        top_k_budget, nms_budget = cfgs.RPN_TOP_K_NMS_TRAIN, cfgs.RPN_MAXIMUM_PROPOSAL_TARIN
    else:
        top_k_budget, nms_budget = cfgs.RPN_TOP_K_NMS_TEST, cfgs.RPN_MAXIMUM_PROPOSAL_TEST

    iou_threshold = cfgs.RPN_NMS_IOU_THRESHOLD

    # Column 1 of the class probabilities serves as the per-box score.
    scores = rpn_cls_prob[:, 1]

    # Step 1: turn the encoded regressions into boxes relative to the anchors.
    boxes = encode_and_decode.decode_boxes(
        encoded_boxes=rpn_bbox_pred,
        reference_boxes=anchors,
        scale_factors=cfgs.ANCHOR_SCALE_FACTORS)

    # Step 2: restrict the boxes to the image extent.
    boxes = boxes_utils.clip_boxes_to_img_boundaries(
        decode_boxes=boxes,
        img_shape=img_shape)

    # Step 3: pre-NMS top-k selection; a non-positive budget skips this step.
    if top_k_budget > 0:
        # Never request more candidates than there are boxes.
        num_boxes = tf.shape(boxes)[0]
        k = tf.minimum(top_k_budget, num_boxes, name='avoid_unenough_boxes')
        scores, order = tf.nn.top_k(scores, k=k)
        boxes = tf.gather(boxes, order)

    # Step 4: non-maximum suppression, capped at nms_budget outputs.
    selected = tf.image.non_max_suppression(
        boxes=boxes,
        scores=scores,
        max_output_size=nms_budget,
        iou_threshold=iou_threshold)

    final_boxes = tf.gather(boxes, selected)
    final_probs = tf.gather(scores, selected)

    return final_boxes, final_probs