4 @contact: zengarden2009@gmail.com
7 from libs.configs import cfgs
8 from libs.box_utils import encode_and_decode
9 from libs.box_utils import boxes_utils
10 import tensorflow as tf
def postprocess_rpn_proposals(rpn_bbox_pred, rpn_cls_prob, img_shape, anchors, is_training):
    """Turn raw RPN outputs into a set of scored region proposals.

    Pipeline:
      1. decode the regression output against ``anchors``;
      2. clip the decoded boxes to the image boundaries;
      3. keep the ``pre_nms_topN`` highest-scoring boxes (skipped when the
         configured limit is not positive);
      4. run NMS and keep at most ``post_nms_topN`` survivors.

    :param rpn_bbox_pred: [-1, 4] box regression output, one row per anchor.
    :param rpn_cls_prob: [-1, 2] class probabilities; column 1 is used as the
        proposal score.
    :param img_shape: image shape forwarded to
        ``boxes_utils.clip_boxes_to_img_boundaries`` (layout is whatever that
        helper expects -- TODO confirm).
    :param anchors: [-1, 4] reference boxes the regression output is decoded
        against.
    :param is_training: selects the ``*_TRAIN`` (True) or ``*_TEST`` (False)
        top-N limits from ``cfgs``.
    :return: ``(final_boxes, final_probs)``: the boxes kept by NMS
        (``[K, 4]``) and their scores (``[K]``), with ``K <= post_nms_topN``.
    """
    if is_training:
        pre_nms_topN = cfgs.RPN_TOP_K_NMS_TRAIN
        # NOTE: "TARIN" is the spelling referenced here; presumably it matches
        # the cfgs attribute name, so do not rename it in this file alone.
        post_nms_topN = cfgs.RPN_MAXIMUM_PROPOSAL_TARIN
    else:
        pre_nms_topN = cfgs.RPN_TOP_K_NMS_TEST
        post_nms_topN = cfgs.RPN_MAXIMUM_PROPOSAL_TEST
    # FPN alternative (previously left as commented-out code): use
    # cfgs.FPN_TOP_K_PER_LEVEL_{TRAIN,TEST} for pre_nms_topN and set
    # post_nms_topN = pre_nms_topN.

    nms_thresh = cfgs.RPN_NMS_IOU_THRESHOLD

    cls_prob = rpn_cls_prob[:, 1]

    # 1. decode boxes
    decode_boxes = encode_and_decode.decode_boxes(encoded_boxes=rpn_bbox_pred,
                                                  reference_boxes=anchors,
                                                  scale_factors=cfgs.ANCHOR_SCALE_FACTORS)

    # 2. clip to img boundaries
    # NOTE(review): assumes the helper's second parameter is named img_shape -- confirm.
    decode_boxes = boxes_utils.clip_boxes_to_img_boundaries(decode_boxes=decode_boxes,
                                                            img_shape=img_shape)

    # 3. keep the top-N boxes by score before NMS. A non-positive limit
    # disables this step, because tf.nn.top_k rejects a negative k.
    if pre_nms_topN > 0:
        # Never ask top_k for more boxes than actually exist.
        pre_nms_topN = tf.minimum(pre_nms_topN, tf.shape(decode_boxes)[0],
                                  name='avoid_unenough_boxes')
        cls_prob, top_k_indices = tf.nn.top_k(cls_prob, k=pre_nms_topN)
        decode_boxes = tf.gather(decode_boxes, top_k_indices)

    # 4. NMS
    keep = tf.image.non_max_suppression(
        boxes=decode_boxes,
        scores=cls_prob,
        max_output_size=post_nms_topN,
        iou_threshold=nms_thresh)

    final_boxes = tf.gather(decode_boxes, keep)
    final_probs = tf.gather(cls_prob, keep)

    return final_boxes, final_probs