X-Git-Url: https://gerrit.akraino.org/r/gitweb?a=blobdiff_plain;f=example-apps%2FPDD%2Fpcb-defect-detection%2Flibs%2Fdetection_oprations%2Fproposal_opr.py;fp=example-apps%2FPDD%2Fpcb-defect-detection%2Flibs%2Fdetection_oprations%2Fproposal_opr.py;h=54cf9b80668d6714d83270eb5cdd627e5fca1e05;hb=a785567fb9acfc68536767d20f60ba917ae85aa1;hp=0000000000000000000000000000000000000000;hpb=94a133e696b9b2a7f73544462c2714986fa7ab4a;p=ealt-edge.git diff --git a/example-apps/PDD/pcb-defect-detection/libs/detection_oprations/proposal_opr.py b/example-apps/PDD/pcb-defect-detection/libs/detection_oprations/proposal_opr.py new file mode 100755 index 0000000..54cf9b8 --- /dev/null +++ b/example-apps/PDD/pcb-defect-detection/libs/detection_oprations/proposal_opr.py @@ -0,0 +1,66 @@ +# encoding: utf-8 +""" +@author: zeming li +@contact: zengarden2009@gmail.com +""" + +from libs.configs import cfgs +from libs.box_utils import encode_and_decode +from libs.box_utils import boxes_utils +import tensorflow as tf +import numpy as np + + +def postprocess_rpn_proposals(rpn_bbox_pred, rpn_cls_prob, img_shape, anchors, is_training): + ''' + + :param rpn_bbox_pred: [-1, 4] + :param rpn_cls_prob: [-1, 2] + :param img_shape: + :param anchors:[-1, 4] + :param is_training: + :return: + ''' + + if is_training: + pre_nms_topN = cfgs.RPN_TOP_K_NMS_TRAIN + post_nms_topN = cfgs.RPN_MAXIMUM_PROPOSAL_TARIN + # pre_nms_topN = cfgs.FPN_TOP_K_PER_LEVEL_TRAIN + # post_nms_topN = pre_nms_topN + else: + pre_nms_topN = cfgs.RPN_TOP_K_NMS_TEST + post_nms_topN = cfgs.RPN_MAXIMUM_PROPOSAL_TEST + # pre_nms_topN = cfgs.FPN_TOP_K_PER_LEVEL_TEST + # post_nms_topN = pre_nms_topN + + nms_thresh = cfgs.RPN_NMS_IOU_THRESHOLD + + cls_prob = rpn_cls_prob[:, 1] + + # 1. decode boxes + decode_boxes = encode_and_decode.decode_boxes(encoded_boxes=rpn_bbox_pred, + reference_boxes=anchors, + scale_factors=cfgs.ANCHOR_SCALE_FACTORS) + + # 2. clip to img boundaries + decode_boxes = boxes_utils.clip_boxes_to_img_boundaries(decode_boxes=decode_boxes, + img_shape=img_shape) + + # 3. get top N to NMS + if pre_nms_topN > 0: + pre_nms_topN = tf.minimum(pre_nms_topN, tf.shape(decode_boxes)[0], name='avoid_unenough_boxes') + cls_prob, top_k_indices = tf.nn.top_k(cls_prob, k=pre_nms_topN) + decode_boxes = tf.gather(decode_boxes, top_k_indices) + + # 4. NMS + keep = tf.image.non_max_suppression( + boxes=decode_boxes, + scores=cls_prob, + max_output_size=post_nms_topN, + iou_threshold=nms_thresh) + + final_boxes = tf.gather(decode_boxes, keep) + final_probs = tf.gather(cls_prob, keep) + + return final_boxes, final_probs +