pcb defect detetcion application
[ealt-edge.git] / example-apps / PDD / pcb-defect-detection / libs / detection_oprations / proposal_opr.py
diff --git a/example-apps/PDD/pcb-defect-detection/libs/detection_oprations/proposal_opr.py b/example-apps/PDD/pcb-defect-detection/libs/detection_oprations/proposal_opr.py
new file mode 100755 (executable)
index 0000000..54cf9b8
--- /dev/null
@@ -0,0 +1,66 @@
+# encoding: utf-8
+"""
+@author: zeming li
+@contact: zengarden2009@gmail.com
+"""
+
+from libs.configs import cfgs
+from libs.box_utils import encode_and_decode
+from libs.box_utils import boxes_utils
+import tensorflow as tf
+import numpy as np
+
+
+def postprocess_rpn_proposals(rpn_bbox_pred, rpn_cls_prob, img_shape, anchors, is_training):
+    '''
+
+    :param rpn_bbox_pred: [-1, 4]
+    :param rpn_cls_prob: [-1, 2]
+    :param img_shape:
+    :param anchors:[-1, 4]
+    :param is_training:
+    :return:
+    '''
+
+    if is_training:
+        pre_nms_topN = cfgs.RPN_TOP_K_NMS_TRAIN
+        post_nms_topN = cfgs.RPN_MAXIMUM_PROPOSAL_TARIN
+        # pre_nms_topN = cfgs.FPN_TOP_K_PER_LEVEL_TRAIN
+        # post_nms_topN = pre_nms_topN
+    else:
+        pre_nms_topN = cfgs.RPN_TOP_K_NMS_TEST
+        post_nms_topN = cfgs.RPN_MAXIMUM_PROPOSAL_TEST
+        # pre_nms_topN = cfgs.FPN_TOP_K_PER_LEVEL_TEST
+        # post_nms_topN = pre_nms_topN
+
+    nms_thresh = cfgs.RPN_NMS_IOU_THRESHOLD
+
+    cls_prob = rpn_cls_prob[:, 1]
+
+    # 1. decode boxes
+    decode_boxes = encode_and_decode.decode_boxes(encoded_boxes=rpn_bbox_pred,
+                                                  reference_boxes=anchors,
+                                                  scale_factors=cfgs.ANCHOR_SCALE_FACTORS)
+
+    # 2. clip to img boundaries
+    decode_boxes = boxes_utils.clip_boxes_to_img_boundaries(decode_boxes=decode_boxes,
+                                                            img_shape=img_shape)
+
+    # 3. get top N to NMS
+    if pre_nms_topN > 0:
+        pre_nms_topN = tf.minimum(pre_nms_topN, tf.shape(decode_boxes)[0], name='avoid_unenough_boxes')
+        cls_prob, top_k_indices = tf.nn.top_k(cls_prob, k=pre_nms_topN)
+        decode_boxes = tf.gather(decode_boxes, top_k_indices)
+
+    # 4. NMS
+    keep = tf.image.non_max_suppression(
+        boxes=decode_boxes,
+        scores=cls_prob,
+        max_output_size=post_nms_topN,
+        iou_threshold=nms_thresh)
+
+    final_boxes = tf.gather(decode_boxes, keep)
+    final_probs = tf.gather(cls_prob, keep)
+
+    return final_boxes, final_probs
+