54cf9b80668d6714d83270eb5cdd627e5fca1e05
[ealt-edge.git] / example-apps / PDD / pcb-defect-detection / libs / detection_oprations / proposal_opr.py
1 # encoding: utf-8
2 """
3 @author: zeming li
4 @contact: zengarden2009@gmail.com
5 """
6
7 from libs.configs import cfgs
8 from libs.box_utils import encode_and_decode
9 from libs.box_utils import boxes_utils
10 import tensorflow as tf
11 import numpy as np
12
13
def postprocess_rpn_proposals(rpn_bbox_pred, rpn_cls_prob, img_shape, anchors, is_training):
    """Convert raw RPN outputs into a set of NMS-filtered proposal boxes.

    Pipeline: decode the regression deltas against the anchors, clip the
    boxes to the image, optionally keep only the top-scoring boxes, then run
    non-maximum suppression.

    :param rpn_bbox_pred: box regression deltas, [-1, 4].
    :param rpn_cls_prob: RPN class probabilities, [-1, 2]; column 1 is used
        as the proposal score.
    :param img_shape: image shape, forwarded to
        ``boxes_utils.clip_boxes_to_img_boundaries``.
    :param anchors: reference anchors, [-1, 4].
    :param is_training: selects the TRAIN vs. TEST pre-NMS top-k and
        post-NMS proposal limits from ``cfgs``.
    :return: tuple ``(final_boxes, final_probs)`` holding the boxes kept by
        NMS and their scores.
    """
    # Phase-dependent budgets: how many boxes enter NMS, and how many may leave it.
    if is_training:
        top_k_before_nms = cfgs.RPN_TOP_K_NMS_TRAIN
        max_kept = cfgs.RPN_MAXIMUM_PROPOSAL_TARIN
    else:
        top_k_before_nms = cfgs.RPN_TOP_K_NMS_TEST
        max_kept = cfgs.RPN_MAXIMUM_PROPOSAL_TEST

    iou_threshold = cfgs.RPN_NMS_IOU_THRESHOLD

    # Column 1 of the RPN probabilities serves as each proposal's score.
    scores = rpn_cls_prob[:, 1]

    # Apply the predicted deltas to the anchors, then keep boxes inside the image.
    proposals = encode_and_decode.decode_boxes(encoded_boxes=rpn_bbox_pred,
                                               reference_boxes=anchors,
                                               scale_factors=cfgs.ANCHOR_SCALE_FACTORS)
    proposals = boxes_utils.clip_boxes_to_img_boundaries(decode_boxes=proposals,
                                                         img_shape=img_shape)

    # A non-positive budget disables pre-NMS pruning. Otherwise the budget is
    # capped at the number of available boxes so top_k never requests more
    # boxes than exist.
    if top_k_before_nms > 0:
        k = tf.minimum(top_k_before_nms, tf.shape(proposals)[0], name='avoid_unenough_boxes')
        scores, order = tf.nn.top_k(scores, k=k)
        proposals = tf.gather(proposals, order)

    kept = tf.image.non_max_suppression(
        boxes=proposals,
        scores=scores,
        max_output_size=max_kept,
        iou_threshold=iou_threshold)

    return tf.gather(proposals, kept), tf.gather(scores, kept)
66