pcb defect detection application
[ealt-edge.git] / example-apps / PDD / pcb-defect-detection / libs / networks / build_whole_network2.py
diff --git a/example-apps/PDD/pcb-defect-detection/libs/networks/build_whole_network2.py b/example-apps/PDD/pcb-defect-detection/libs/networks/build_whole_network2.py
new file mode 100755 (executable)
index 0000000..8281440
--- /dev/null
@@ -0,0 +1,637 @@
+# -*-coding: utf-8 -*-
+
+from __future__ import absolute_import, division, print_function
+
+import os
+import tensorflow as tf
+import tensorflow.contrib.slim as slim
+import numpy as np
+
+from libs.networks import resnet
+from libs.networks import mobilenet_v2
+from libs.box_utils import encode_and_decode
+from libs.box_utils import boxes_utils
+from libs.box_utils import anchor_utils
+from libs.configs import cfgs
+from libs.losses import losses
+from libs.box_utils import show_box_in_tensor
+from libs.detection_oprations.proposal_opr import postprocess_rpn_proposals
+from libs.detection_oprations.anchor_target_layer_without_boxweight import anchor_target_layer
+from libs.detection_oprations.proposal_target_layer import proposal_target_layer
+
+
+class DetectionNetwork(object):
+
+    def __init__(self, base_network_name, is_training):
+
+        self.base_network_name = base_network_name
+        self.is_training = is_training
+        self.num_anchors_per_location = len(cfgs.ANCHOR_SCALES) * len(cfgs.ANCHOR_RATIOS)
+
+    def build_base_network(self, input_img_batch):
+
+        if self.base_network_name.startswith('resnet_v1'):
+            return resnet.resnet_base(input_img_batch, scope_name=self.base_network_name, is_training=self.is_training)
+
+        elif self.base_network_name.startswith('MobilenetV2'):
+            return mobilenet_v2.mobilenetv2_base(input_img_batch, is_training=self.is_training)
+
+        else:
+            raise ValueError('Sry, we only support resnet or mobilenet_v2')
+
    def postprocess_fastrcnn(self, rois, bbox_ppred, scores, img_shape):
        '''
        Turn the Fast-RCNN head outputs into final detections: per-class
        box decoding, clipping to the image, per-class NMS, then concat.

        :param rois: [-1, 4] reference boxes for decoding
        :param bbox_ppred: [-1, (cfgs.Class_num+1) * 4] per-class box deltas
        :param scores: [-1, cfgs.Class_num + 1] class probabilities
        :param img_shape: input image shape, used to clip decoded boxes
        :return: (final_boxes, final_scores, final_category)
        '''

        with tf.name_scope('postprocess_fastrcnn'):
            # detections are outputs only -- block gradients through them
            rois = tf.stop_gradient(rois)
            scores = tf.stop_gradient(scores)
            bbox_ppred = tf.reshape(bbox_ppred, [-1, cfgs.CLASS_NUM + 1, 4])
            bbox_ppred = tf.stop_gradient(bbox_ppred)

            # per-class lists: element i holds the deltas/scores for class i
            bbox_pred_list = tf.unstack(bbox_ppred, axis=1)
            score_list = tf.unstack(scores, axis=1)

            allclasses_boxes = []
            allclasses_scores = []
            categories = []
            # class 0 is background, so start from 1
            for i in range(1, cfgs.CLASS_NUM+1):

                # 1. decode boxes in each class
                tmp_encoded_box = bbox_pred_list[i]
                tmp_score = score_list[i]
                tmp_decoded_boxes = encode_and_decode.decode_boxes(encoded_boxes=tmp_encoded_box,
                                                                   reference_boxes=rois,
                                                                   scale_factors=cfgs.ROI_SCALE_FACTORS)
                # tmp_decoded_boxes = encode_and_decode.decode_boxes(boxes=rois,
                #                                                    deltas=tmp_encoded_box,
                #                                                    scale_factor=cfgs.ROI_SCALE_FACTORS)

                # 2. clip to img boundaries
                tmp_decoded_boxes = boxes_utils.clip_boxes_to_img_boundaries(decode_boxes=tmp_decoded_boxes,
                                                                             img_shape=img_shape)

                # 3. NMS (within this class only)
                keep = tf.image.non_max_suppression(
                    boxes=tmp_decoded_boxes,
                    scores=tmp_score,
                    max_output_size=cfgs.FAST_RCNN_NMS_MAX_BOXES_PER_CLASS,
                    iou_threshold=cfgs.FAST_RCNN_NMS_IOU_THRESHOLD)

                perclass_boxes = tf.gather(tmp_decoded_boxes, keep)
                perclass_scores = tf.gather(tmp_score, keep)

                allclasses_boxes.append(perclass_boxes)
                allclasses_scores.append(perclass_scores)
                # class id for every kept box of this class
                categories.append(tf.ones_like(perclass_scores) * i)

            final_boxes = tf.concat(allclasses_boxes, axis=0)
            final_scores = tf.concat(allclasses_scores, axis=0)
            final_category = tf.concat(categories, axis=0)

            if self.is_training:
                '''
                in training we show the detections in tensorboard, so only
                keep the confident ones here.
                '''
                kept_indices = tf.reshape(tf.where(tf.greater_equal(final_scores, cfgs.SHOW_SCORE_THRSHOLD)), [-1])

                final_boxes = tf.gather(final_boxes, kept_indices)
                final_scores = tf.gather(final_scores, kept_indices)
                final_category = tf.gather(final_category, kept_indices)

        return final_boxes, final_scores, final_category
+
+    def roi_pooling(self, feature_maps, rois, img_shape, scope):
+        '''
+        Here use roi warping as roi_pooling
+
+        :param featuremaps_dict: feature map to crop
+        :param rois: shape is [-1, 4]. [x1, y1, x2, y2]
+        :return:
+        '''
+
+        with tf.variable_scope('ROI_Warping_'+scope):
+            img_h, img_w = tf.cast(img_shape[1], tf.float32), tf.cast(img_shape[2], tf.float32)
+            N = tf.shape(rois)[0]
+            x1, y1, x2, y2 = tf.unstack(rois, axis=1)
+
+            normalized_x1 = x1 / img_w
+            normalized_x2 = x2 / img_w
+            normalized_y1 = y1 / img_h
+            normalized_y2 = y2 / img_h
+
+            normalized_rois = tf.transpose(
+                tf.stack([normalized_y1, normalized_x1, normalized_y2, normalized_x2]), name='get_normalized_rois')
+
+            normalized_rois = tf.stop_gradient(normalized_rois)
+
+            cropped_roi_features = tf.image.crop_and_resize(feature_maps, normalized_rois,
+                                                            box_ind=tf.zeros(shape=[N],
+                                                                             dtype=tf.int32),
+                                                            crop_size=[cfgs.ROI_SIZE, cfgs.ROI_SIZE],
+                                                            name='CROP_AND_RESIZE'
+                                                            )
+            roi_features = slim.max_pool2d(cropped_roi_features,
+                                           [cfgs.ROI_POOL_KERNEL_SIZE, cfgs.ROI_POOL_KERNEL_SIZE],
+                                           stride=cfgs.ROI_POOL_KERNEL_SIZE)
+
+        return roi_features
+
    def build_fastrcnn(self, P_list, rois_list, img_shape):
        '''
        ROI-pool every pyramid level, run the backbone head on the pooled
        features, then predict per-roi class scores and box regressions.

        :param P_list: [P2, ..., P5] feature maps, one per FPN level
        :param rois_list: per-level rois, same ordering as P_list
        :param img_shape: input image shape (for roi normalization)
        :return: (bbox_pred [-1, 4*(CLASS_NUM+1)], cls_score [-1, CLASS_NUM+1])
        '''

        with tf.variable_scope('Fast-RCNN'):
            # 5. ROI Pooling
            with tf.variable_scope('rois_pooling'):
                pooled_features_list = []
                for level_name, p, rois in zip(cfgs.LEVLES, P_list, rois_list):  # exclude P6_rois
                    # p = tf.Print(p, [tf.shape(p)], summarize=10, message=level_name+'SHPAE***')
                    pooled_features = self.roi_pooling(feature_maps=p, rois=rois, img_shape=img_shape,
                                                       scope=level_name)
                    pooled_features_list.append(pooled_features)

                # rois are concatenated per level elsewhere, so this concat keeps
                # features aligned with the concatenated roi order
                pooled_features = tf.concat(pooled_features_list, axis=0) # [minibatch_size, H, W, C]

            # 6. inferecne rois in Fast-RCNN to obtain fc_flatten features
            if self.base_network_name.startswith('resnet'):
                fc_flatten = resnet.restnet_head(inputs=pooled_features,
                                                 is_training=self.is_training,
                                                 scope_name=self.base_network_name)
            elif self.base_network_name.startswith('Mobile'):
                fc_flatten = mobilenet_v2.mobilenetv2_head(inputs=pooled_features,
                                                           is_training=self.is_training)
            else:
                raise NotImplementedError('only support resnet and mobilenet')

            # 7. cls and reg in Fast-RCNN
            with slim.arg_scope([slim.fully_connected], weights_regularizer=slim.l2_regularizer(cfgs.WEIGHT_DECAY)):

                cls_score = slim.fully_connected(fc_flatten,
                                                 num_outputs=cfgs.CLASS_NUM+1,
                                                 weights_initializer=cfgs.INITIALIZER,
                                                 activation_fn=None, trainable=self.is_training,
                                                 scope='cls_fc')

                bbox_pred = slim.fully_connected(fc_flatten,
                                                 num_outputs=(cfgs.CLASS_NUM+1)*4,
                                                 weights_initializer=cfgs.BBOX_INITIALIZER,
                                                 activation_fn=None, trainable=self.is_training,
                                                 scope='reg_fc')
                # for convient. It also produce (cls_num +1) bboxes

                cls_score = tf.reshape(cls_score, [-1, cfgs.CLASS_NUM+1])
                bbox_pred = tf.reshape(bbox_pred, [-1, 4*(cfgs.CLASS_NUM+1)])

        return bbox_pred, cls_score
+
    def assign_levels(self, all_rois, labels=None, bbox_targets=None):
        '''
        Assign each roi to one FPN level by its scale (FPN paper, Eq. 1) and
        split the rois -- and, in training, their labels/targets -- per level.

        :param all_rois: [-1, 4] rois as [xmin, ymin, xmax, ymax]
        :param labels: [-1] roi class labels; only used when training
        :param bbox_targets: [-1, 4*(CLASS_NUM+1)] regression targets; training only
        :return: training -> (rois_list, labels, bbox_targets), where labels and
                 targets are re-ordered to match tf.concat(rois_list, axis=0);
                 inference -> rois_list only
        '''
        with tf.name_scope('assign_levels'):
            # all_rois = tf.Print(all_rois, [tf.shape(all_rois)], summarize=10, message='ALL_ROIS_SHAPE*****')
            xmin, ymin, xmax, ymax = tf.unstack(all_rois, axis=1)

            h = tf.maximum(0., ymax - ymin)
            w = tf.maximum(0., xmax - xmin)

            levels = tf.floor(4. + tf.log(tf.sqrt(w * h + 1e-8) / 224.0) / tf.log(2.))  # 4 + log_2(***)
            # use floor instead of round

            # clamp into [first configured level, 5] so P6 never receives rois
            min_level = int(cfgs.LEVLES[0][-1])
            max_level = min(5, int(cfgs.LEVLES[-1][-1]))
            levels = tf.maximum(levels, tf.ones_like(levels) * min_level)  # level minimum is 2
            levels = tf.minimum(levels, tf.ones_like(levels) * max_level)  # level maximum is 5

            levels = tf.stop_gradient(tf.reshape(levels, [-1]))

            def get_rois(levels, level_i, rois, labels, bbox_targets):
                # gather the subset of rois (and labels/targets) assigned to level_i

                level_i_indices = tf.reshape(tf.where(tf.equal(levels, level_i)), [-1])
                tf.summary.scalar('LEVEL/LEVEL_%d_rois_NUM'%level_i, tf.shape(level_i_indices)[0])
                level_i_rois = tf.gather(rois, level_i_indices)

                # if self.is_training:
                #     level_i_rois = tf.stop_gradient(tf.concat([level_i_rois, [[0, 0, 0., 0.]]], axis=0))
                #     # to avoid the num of level i rois is 0.0, which will broken the BP in tf
                #
                #     level_i_labels = tf.gather(labels, level_i_indices)
                #     level_i_labels = tf.stop_gradient(tf.concat([level_i_labels, [0]], axis=0))
                #
                #     level_i_targets = tf.gather(bbox_targets, level_i_indices)
                #     level_i_targets = tf.stop_gradient(tf.concat([level_i_targets,
                #                                                   tf.zeros(shape=(1, 4*(cfgs.CLASS_NUM+1)), dtype=tf.float32)],
                #                                                  axis=0))
                #
                #     return level_i_rois, level_i_labels, level_i_targets
                if self.is_training:
                    level_i_labels = tf.gather(labels, level_i_indices)
                    level_i_targets = tf.gather(bbox_targets, level_i_indices)
                    return level_i_rois, level_i_labels, level_i_targets
                else:
                    return level_i_rois, None, None

            rois_list = []
            labels_list = []
            targets_list = []
            for i in range(min_level, max_level+1):
                P_i_rois, P_i_labels, P_i_targets = get_rois(levels, level_i=i, rois=all_rois,
                                                             labels=labels,
                                                             bbox_targets=bbox_targets)
                rois_list.append(P_i_rois)
                labels_list.append(P_i_labels)
                targets_list.append(P_i_targets)

            if self.is_training:
                # re-concat labels/targets in the new (per-level) roi order
                all_labels = tf.concat(labels_list, axis=0)
                all_targets = tf.concat(targets_list, axis=0)

                return rois_list, \
                       tf.stop_gradient(all_labels), \
                       tf.stop_gradient(all_targets)
            else:
                return rois_list  # [P2_rois, P3_rois, P4_rois, P5_rois] Note: P6 do not assign rois
+    # ---------------------------------------------------------------------------------------
+
+    def add_anchor_img_smry(self, img, anchors, labels):
+
+        positive_anchor_indices = tf.reshape(tf.where(tf.greater_equal(labels, 1)), [-1])
+        negative_anchor_indices = tf.reshape(tf.where(tf.equal(labels, 0)), [-1])
+
+        positive_anchor = tf.gather(anchors, positive_anchor_indices)
+        negative_anchor = tf.gather(anchors, negative_anchor_indices)
+
+        pos_in_img = show_box_in_tensor.only_draw_boxes(img_batch=img,
+                                                        boxes=positive_anchor)
+        neg_in_img = show_box_in_tensor.only_draw_boxes(img_batch=img,
+                                                        boxes=negative_anchor)
+
+        tf.summary.image('positive_anchor', pos_in_img)
+        tf.summary.image('negative_anchors', neg_in_img)
+
+    def add_roi_batch_img_smry(self, img, rois, labels):
+        positive_roi_indices = tf.reshape(tf.where(tf.greater_equal(labels, 1)), [-1])
+
+        negative_roi_indices = tf.reshape(tf.where(tf.equal(labels, 0)), [-1])
+
+        pos_roi = tf.gather(rois, positive_roi_indices)
+        neg_roi = tf.gather(rois, negative_roi_indices)
+
+
+        pos_in_img = show_box_in_tensor.only_draw_boxes(img_batch=img,
+                                                               boxes=pos_roi)
+        neg_in_img = show_box_in_tensor.only_draw_boxes(img_batch=img,
+                                                               boxes=neg_roi)
+        tf.summary.image('pos_rois', pos_in_img)
+        tf.summary.image('neg_rois', neg_in_img)
+
    def build_loss(self, rpn_box_pred, rpn_bbox_targets, rpn_cls_score, rpn_labels,
                   bbox_pred, bbox_targets, cls_score, labels):
        '''
        Build RPN and Fast-RCNN classification / localization losses.

        :param rpn_box_pred: [-1, 4]
        :param rpn_bbox_targets: [-1, 4]
        :param rpn_cls_score: [-1]
        :param rpn_labels: [-1] (-1 = ignored anchor, 0 = background, >=1 = fg)
        :param bbox_pred: [-1, 4*(cls_num+1)]
        :param bbox_targets: [-1, 4*(cls_num+1)]
        :param cls_score: [-1, cls_num+1]
        :param labels: [-1]
        :return: dict with the four (already weighted) loss tensors
        '''
        with tf.variable_scope('build_loss') as sc:
            with tf.variable_scope('rpn_loss'):

                rpn_bbox_loss = losses.smooth_l1_loss_rpn(bbox_pred=rpn_box_pred,
                                                          bbox_targets=rpn_bbox_targets,
                                                          label=rpn_labels,
                                                          sigma=cfgs.RPN_SIGMA)
                # rpn_cls_loss:
                # rpn_cls_score = tf.reshape(rpn_cls_score, [-1, 2])
                # rpn_labels = tf.reshape(rpn_labels, [-1])
                # ensure rpn_labels shape is [-1]
                # drop anchors labelled -1 (neither fg nor bg) from the cls loss
                rpn_select = tf.reshape(tf.where(tf.not_equal(rpn_labels, -1)), [-1])
                rpn_cls_score = tf.reshape(tf.gather(rpn_cls_score, rpn_select), [-1, 2])
                rpn_labels = tf.reshape(tf.gather(rpn_labels, rpn_select), [-1])
                rpn_cls_loss = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(logits=rpn_cls_score,
                                                                                             labels=rpn_labels))

                rpn_cls_loss = rpn_cls_loss * cfgs.RPN_CLASSIFICATION_LOSS_WEIGHT
                rpn_bbox_loss = rpn_bbox_loss * cfgs.RPN_LOCATION_LOSS_WEIGHT

            with tf.variable_scope('FastRCNN_loss'):
                if not cfgs.FAST_RCNN_MINIBATCH_SIZE == -1:
                    bbox_loss = losses.smooth_l1_loss_rcnn(bbox_pred=bbox_pred,
                                                           bbox_targets=bbox_targets,
                                                           label=labels,
                                                           num_classes=cfgs.CLASS_NUM + 1,
                                                           sigma=cfgs.FASTRCNN_SIGMA)

                    # cls_score = tf.reshape(cls_score, [-1, cfgs.CLASS_NUM + 1])
                    # labels = tf.reshape(labels, [-1])
                    cls_loss = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(
                        logits=cls_score,
                        labels=labels))  # beacause already sample before
                else:
                    '''
                    applying OHEM here
                    '''
                    print(20 * "@@")
                    print("@@" + 10 * " " + "TRAIN WITH OHEM ...")
                    print(20 * "@@")
                    # NOTE(review): sum_ohem_loss is called without the box
                    # predictions, and its single return value is used as BOTH
                    # cls_loss and bbox_loss (so it is weighted twice below) --
                    # confirm against the signature of losses.sum_ohem_loss.
                    cls_loss = bbox_loss = losses.sum_ohem_loss(
                        cls_score=cls_score,
                        label=labels,
                        bbox_targets=bbox_targets,
                        nr_ohem_sampling=128,
                        nr_classes=cfgs.CLASS_NUM + 1)
                cls_loss = cls_loss * cfgs.FAST_RCNN_CLASSIFICATION_LOSS_WEIGHT
                bbox_loss = bbox_loss * cfgs.FAST_RCNN_LOCATION_LOSS_WEIGHT
            loss_dict = {
                'rpn_cls_loss': rpn_cls_loss,
                'rpn_loc_loss': rpn_bbox_loss,
                'fastrcnn_cls_loss': cls_loss,
                'fastrcnn_loc_loss': bbox_loss
            }
        return loss_dict
+
+
    def build_whole_detection_network(self, input_img_batch, gtboxes_batch):
        '''
        Assemble the full FPN Faster-RCNN graph: backbone -> per-level RPN ->
        proposals -> (training-only) target sampling -> Fast-RCNN head.

        :param input_img_batch: image batch tensor (single image in the batch)
        :param gtboxes_batch: ground-truth boxes, reshaped to [M, 5]
                              (box coords + class label; exact column order is
                              defined by the dataset pipeline -- TODO confirm)
        :return: inference -> (final_boxes, final_scores, final_category);
                 training -> those three plus loss_dict
        '''

        if self.is_training:
            # ensure shape is [M, 5]
            gtboxes_batch = tf.reshape(gtboxes_batch, [-1, 5])
            gtboxes_batch = tf.cast(gtboxes_batch, tf.float32)

        img_shape = tf.shape(input_img_batch)

        # 1. build base network
        P_dict = self.build_base_network(input_img_batch)  # [P2, P3, P4, P5, P6]

        # 2. build fpn  by build rpn for each level
        with tf.variable_scope("build_FPN", regularizer=slim.l2_regularizer(cfgs.WEIGHT_DECAY)):
            fpn_box_delta = {}
            fpn_cls_score = {}
            fpn_cls_prob = {}
            for key in cfgs.LEVLES:
                if cfgs.SHARE_HEADS:
                    # one RPN head shared across levels: reuse vars after the first
                    reuse_flag = None if key == cfgs.LEVLES[0] else True
                    scope_list = ['fpn_conv/3x3', 'fpn_cls_score', 'fpn_bbox_pred']
                else:
                    reuse_flag = None
                    scope_list= ['fpn_conv/3x3_%s' % key, 'fpn_cls_score_%s' % key, 'fpn_bbox_pred_%s' % key]
                rpn_conv3x3 = slim.conv2d(
                    P_dict[key], 512, [3, 3],
                    trainable=self.is_training, weights_initializer=cfgs.INITIALIZER, padding="SAME",
                    activation_fn=tf.nn.relu,
                    scope=scope_list[0],
                    reuse=reuse_flag)
                rpn_cls_score = slim.conv2d(rpn_conv3x3, self.num_anchors_per_location*2, [1, 1], stride=1,
                                            trainable=self.is_training, weights_initializer=cfgs.INITIALIZER,
                                            activation_fn=None, padding="VALID",
                                            scope=scope_list[1],
                                            reuse=reuse_flag)
                rpn_box_delta = slim.conv2d(rpn_conv3x3, self.num_anchors_per_location*4, [1, 1], stride=1,
                                            trainable=self.is_training, weights_initializer=cfgs.BBOX_INITIALIZER,
                                            activation_fn=None, padding="VALID",
                                            scope=scope_list[2],
                                            reuse=reuse_flag)
                fpn_box_delta[key] = tf.reshape(rpn_box_delta, [-1, 4])
                fpn_cls_score[key] = tf.reshape(rpn_cls_score, [-1, 2])
                fpn_cls_prob[key] = slim.softmax(fpn_cls_score[key])

        # 3. generate anchors for fpn. (by generate for each level)
        anchors_dict = {}
        anchor_list = []
        with tf.name_scope("generate_FPN_anchors"):
            for key in cfgs.LEVLES:
                p_h, p_w = tf.to_float(tf.shape(P_dict[key])[1]), tf.to_float(tf.shape(P_dict[key])[2])
                id_ = int(key[-1]) - int(cfgs.LEVLES[0][-1]) # such as : 2-2, 3-3
                tmp_anchors = anchor_utils.make_anchors(base_anchor_size=cfgs.BASE_ANCHOR_SIZE_LIST[id_],
                                                        anchor_scales=cfgs.ANCHOR_SCALES,
                                                        anchor_ratios=cfgs.ANCHOR_RATIOS,
                                                        stride=cfgs.ANCHOR_STRIDE_LIST[id_],
                                                        featuremap_height=p_h,
                                                        featuremap_width=p_w,
                                                        name='%s_make_anchors' % key)
                anchors_dict[key] = tmp_anchors
                anchor_list.append(tmp_anchors)
        all_anchors = tf.concat(anchor_list, axis=0)

        # 4. postprocess fpn proposals. such as: decode, clip, NMS
        #    Need to Note: Here we NMS for each level instead of NMS for all anchors.
        rois_list = []
        rois_scores_list = []
        with tf.name_scope("postproces_fpn"):
            for key in cfgs.LEVLES:
                tmp_rois, tmp_roi_scores = postprocess_rpn_proposals(rpn_bbox_pred=fpn_box_delta[key],
                                                                     rpn_cls_prob=fpn_cls_prob[key],
                                                                     img_shape=img_shape,
                                                                     anchors=anchors_dict[key],
                                                                     is_training=self.is_training)
                rois_list.append(tmp_rois)
                rois_scores_list.append(tmp_roi_scores)
            allrois = tf.concat(rois_list, axis=0)
            allrois_scores = tf.concat(rois_scores_list, axis=0)
            # keep only the overall top-k proposals across all levels
            fpn_topk = cfgs.FPN_TOP_K_PER_LEVEL_TRAIN if self.is_training else cfgs.FPN_TOP_K_PER_LEVEL_TEST
            topk = tf.minimum(fpn_topk, tf.shape(allrois)[0])

            rois_scores, topk_indices = tf.nn.top_k(allrois_scores, k=topk)

            rois = tf.stop_gradient(tf.gather(allrois, topk_indices))
            rois_scores = tf.stop_gradient(rois_scores)

            # +++++++++++++++++++++++++++++++++++++add img smry+++++++++++++++++++++++++++++++++++++++++++++++++++++++
            if self.is_training:
                score_gre_05 = tf.reshape(tf.where(tf.greater_equal(rois_scores, 0.5)), [-1])
                score_gre_05_rois = tf.gather(rois, score_gre_05)
                score_gre_05_score = tf.gather(rois_scores, score_gre_05)
                score_gre_05_in_img = show_box_in_tensor.draw_boxes_with_scores(img_batch=input_img_batch,
                                                                                boxes=score_gre_05_rois,
                                                                                scores=score_gre_05_score)
                tf.summary.image('score_greater_05_rois', score_gre_05_in_img)
            # ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++

        # sample for fpn. We should concat all the anchors
        if self.is_training:
            with tf.variable_scope('sample_anchors_minibatch'):
                # anchor_target_layer runs as python/numpy code via py_func
                fpn_labels, fpn_bbox_targets = \
                    tf.py_func(
                        anchor_target_layer,
                        [gtboxes_batch, img_shape, all_anchors],
                        [tf.float32, tf.float32])
                fpn_bbox_targets = tf.reshape(fpn_bbox_targets, [-1, 4])
                fpn_labels = tf.to_int32(fpn_labels, name="to_int32")
                fpn_labels = tf.reshape(fpn_labels, [-1])
                self.add_anchor_img_smry(input_img_batch, all_anchors, fpn_labels)

            with tf.control_dependencies([fpn_labels]):
                with tf.variable_scope('sample_RCNN_minibatch'):
                    rois, labels, bbox_targets = \
                    tf.py_func(proposal_target_layer,
                               [rois, gtboxes_batch],
                               [tf.float32, tf.float32, tf.float32])
                    rois = tf.reshape(rois, [-1, 4])
                    labels = tf.to_int32(labels)
                    labels = tf.reshape(labels, [-1])
                    bbox_targets = tf.reshape(bbox_targets, [-1, 4*(cfgs.CLASS_NUM+1)])
                    self.add_roi_batch_img_smry(input_img_batch, rois, labels)
        if self.is_training:
            # assign_levels re-orders labels/bbox_targets to the per-level roi order
            rois_list, labels, bbox_targets = self.assign_levels(all_rois=rois,
                                                                 labels=labels,
                                                                 bbox_targets=bbox_targets)
        else:
            rois_list = self.assign_levels(all_rois=rois)  # rois_list: [P2_rois, P3_rois, P4_rois, P5_rois]

        # -------------------------------------------------------------------------------------------------------------#
        #                                            Fast-RCNN                                                         #
        # -------------------------------------------------------------------------------------------------------------#

        # 5. build Fast-RCNN
        # rois = tf.Print(rois, [tf.shape(rois)], 'rois shape', summarize=10)
        bbox_pred, cls_score = self.build_fastrcnn(P_list=[P_dict[key] for key in cfgs.LEVLES],
                                                   rois_list=rois_list,
                                                   img_shape=img_shape)
        # bbox_pred shape: [-1, 4*(cls_num+1)].
        # cls_score shape: [-1, cls_num+1]

        cls_prob = slim.softmax(cls_score, 'cls_prob')


        # ----------------------------------------------add smry-------------------------------------------------------
        if self.is_training:
            cls_category = tf.argmax(cls_prob, axis=1)
            fast_acc = tf.reduce_mean(tf.to_float(tf.equal(cls_category, tf.to_int64(labels))))
            tf.summary.scalar('ACC/fast_acc', fast_acc)

        # rois concatenated in per-level order, matching bbox_pred / cls_prob rows
        rois = tf.concat(rois_list, axis=0, name='concat_rois')
        #  6. postprocess_fastrcnn
        if not self.is_training:
            return self.postprocess_fastrcnn(rois=rois, bbox_ppred=bbox_pred, scores=cls_prob, img_shape=img_shape)
        else:
            '''
            when training we also need to build the loss
            '''
            loss_dict = self.build_loss(rpn_box_pred=tf.concat([fpn_box_delta[key] for key in cfgs.LEVLES], axis=0),
                                        rpn_bbox_targets=fpn_bbox_targets,
                                        rpn_cls_score=tf.concat([fpn_cls_score[key] for key in cfgs.LEVLES], axis=0),
                                        rpn_labels=fpn_labels,
                                        bbox_pred=bbox_pred,
                                        bbox_targets=bbox_targets,
                                        cls_score=cls_score,
                                        labels=labels)

            final_bbox, final_scores, final_category = self.postprocess_fastrcnn(rois=rois,
                                                                                 bbox_ppred=bbox_pred,
                                                                                 scores=cls_prob,
                                                                                 img_shape=img_shape)
            return final_bbox, final_scores, final_category, loss_dict
+
+    def get_restorer(self):
+        checkpoint_path = tf.train.latest_checkpoint(os.path.join(cfgs.TRAINED_CKPT, cfgs.VERSION))
+
+        if checkpoint_path != None:
+            restorer = tf.train.Saver()
+            print("model restore from :", checkpoint_path)
+        else:
+            checkpoint_path = cfgs.PRETRAINED_CKPT
+            print("model restore from pretrained mode, path is :", checkpoint_path)
+
+            model_variables = slim.get_model_variables()
+            # for var in model_variables:
+            #     print(var.name)
+            # print(20*"__++__++__")
+
+            def name_in_ckpt_rpn(var):
+                return var.op.name
+
+            def name_in_ckpt_fastrcnn_head(var):
+                '''
+                Fast-RCNN/resnet_v1_50/block4 -->resnet_v1_50/block4
+                Fast-RCNN/MobilenetV2/** -- > MobilenetV2 **
+                :param var:
+                :return:
+                '''
+                return '/'.join(var.op.name.split('/')[1:])
+            nameInCkpt_Var_dict = {}
+            for var in model_variables:
+                if var.name.startswith(self.base_network_name):
+                    var_name_in_ckpt = name_in_ckpt_rpn(var)
+                    nameInCkpt_Var_dict[var_name_in_ckpt] = var
+            restore_variables = nameInCkpt_Var_dict
+            for key, item in restore_variables.items():
+                print("var_in_graph: ", item.name)
+                print("var_in_ckpt: ", key)
+                print(20*"___")
+            restorer = tf.train.Saver(restore_variables)
+            print(20 * "****")
+            print("restore from pretrained_weighs in IMAGE_NET")
+        return restorer, checkpoint_path
+
+    def get_gradients(self, optimizer, loss):
+        '''
+
+        :param optimizer:
+        :param loss:
+        :return:
+
+        return vars and grads that not be fixed
+        '''
+
+        # if cfgs.FIXED_BLOCKS > 0:
+        #     trainable_vars = tf.trainable_variables()
+        #     # trained_vars = slim.get_trainable_variables()
+        #     start_names = [cfgs.NET_NAME + '/block%d'%i for i in range(1, cfgs.FIXED_BLOCKS+1)] + \
+        #                   [cfgs.NET_NAME + '/conv1']
+        #     start_names = tuple(start_names)
+        #     trained_var_list = []
+        #     for var in trainable_vars:
+        #         if not var.name.startswith(start_names):
+        #             trained_var_list.append(var)
+        #     # slim.learning.train()
+        #     grads = optimizer.compute_gradients(loss, var_list=trained_var_list)
+        #     return grads
+        # else:
+        #     return optimizer.compute_gradients(loss)
+        return optimizer.compute_gradients(loss)
+
+    def enlarge_gradients_for_bias(self, gradients):
+
+        final_gradients = []
+        with tf.variable_scope("Gradient_Mult") as scope:
+            for grad, var in gradients:
+                scale = 1.0
+                if cfgs.MUTILPY_BIAS_GRADIENT and './biases' in var.name:
+                    scale = scale * cfgs.MUTILPY_BIAS_GRADIENT
+                if not np.allclose(scale, 1.0):
+                    grad = tf.multiply(grad, scale)
+                final_gradients.append((grad, var))
+        return final_gradients
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+