pcb defect detection application
[ealt-edge.git] / example-apps / PDD / pcb-defect-detection / libs / networks / build_whole_network.py
1 # -*-coding: utf-8 -*-
2
3 from __future__ import absolute_import, division, print_function
4
5 import os
6 import tensorflow as tf
7 import tensorflow.contrib.slim as slim
8 import numpy as np
9
10 from libs.networks import resnet
11 from libs.networks import mobilenet_v2
12 from libs.box_utils import encode_and_decode
13 from libs.box_utils import boxes_utils
14 from libs.box_utils import anchor_utils
15 from libs.configs import cfgs
16 from libs.losses import losses
17 from libs.box_utils import show_box_in_tensor
18 from libs.detection_oprations.proposal_opr import postprocess_rpn_proposals
19 from libs.detection_oprations.anchor_target_layer_without_boxweight import anchor_target_layer
20 from libs.detection_oprations.proposal_target_layer import proposal_target_layer
21
22
23 class DetectionNetwork(object):
24
25     def __init__(self, base_network_name, is_training):
26
27         self.base_network_name = base_network_name
28         self.is_training = is_training
29         self.num_anchors_per_location = len(cfgs.ANCHOR_SCALES) * len(cfgs.ANCHOR_RATIOS)
30
31     def build_base_network(self, input_img_batch):
32
33         if self.base_network_name.startswith('resnet_v1'):
34             return resnet.resnet_base(input_img_batch, scope_name=self.base_network_name, is_training=self.is_training)
35
36         elif self.base_network_name.startswith('MobilenetV2'):
37             return mobilenet_v2.mobilenetv2_base(input_img_batch, is_training=self.is_training)
38
39         else:
40             raise ValueError('Sry, we only support resnet or mobilenet_v2')
41
42     def postprocess_fastrcnn(self, rois, bbox_ppred, scores, img_shape):
43         '''
44
45         :param rois:[-1, 4]
46         :param bbox_ppred: [-1, (cfgs.Class_num+1) * 4]
47         :param scores: [-1, cfgs.Class_num + 1]
48         :return:
49         '''
50
51         with tf.name_scope('postprocess_fastrcnn'):
52             rois = tf.stop_gradient(rois)
53             scores = tf.stop_gradient(scores)
54             bbox_ppred = tf.reshape(bbox_ppred, [-1, cfgs.CLASS_NUM + 1, 4])
55             bbox_ppred = tf.stop_gradient(bbox_ppred)
56
57             bbox_pred_list = tf.unstack(bbox_ppred, axis=1)
58             score_list = tf.unstack(scores, axis=1)
59
60             allclasses_boxes = []
61             allclasses_scores = []
62             categories = []
63             for i in range(1, cfgs.CLASS_NUM+1):
64
65                 # 1. decode boxes in each class
66                 tmp_encoded_box = bbox_pred_list[i]
67                 tmp_score = score_list[i]
68                 tmp_decoded_boxes = encode_and_decode.decode_boxes(encoded_boxes=tmp_encoded_box,
69                                                                    reference_boxes=rois,
70                                                                    scale_factors=cfgs.ROI_SCALE_FACTORS)
71                 # tmp_decoded_boxes = encode_and_decode.decode_boxes(boxes=rois,
72                 #                                                    deltas=tmp_encoded_box,
73                 #                                                    scale_factor=cfgs.ROI_SCALE_FACTORS)
74
75                 # 2. clip to img boundaries
76                 tmp_decoded_boxes = boxes_utils.clip_boxes_to_img_boundaries(decode_boxes=tmp_decoded_boxes,
77                                                                              img_shape=img_shape)
78
79                 # 3. NMS
80                 keep = tf.image.non_max_suppression(
81                     boxes=tmp_decoded_boxes,
82                     scores=tmp_score,
83                     max_output_size=cfgs.FAST_RCNN_NMS_MAX_BOXES_PER_CLASS,
84                     iou_threshold=cfgs.FAST_RCNN_NMS_IOU_THRESHOLD)
85
86                 perclass_boxes = tf.gather(tmp_decoded_boxes, keep)
87                 perclass_scores = tf.gather(tmp_score, keep)
88
89                 allclasses_boxes.append(perclass_boxes)
90                 allclasses_scores.append(perclass_scores)
91                 categories.append(tf.ones_like(perclass_scores) * i)
92
93             final_boxes = tf.concat(allclasses_boxes, axis=0)
94             final_scores = tf.concat(allclasses_scores, axis=0)
95             final_category = tf.concat(categories, axis=0)
96
97             if self.is_training:
98                 '''
99                 in training. We should show the detecitons in the tensorboard. So we add this.
100                 '''
101                 kept_indices = tf.reshape(tf.where(tf.greater_equal(final_scores, cfgs.SHOW_SCORE_THRSHOLD)), [-1])
102
103                 final_boxes = tf.gather(final_boxes, kept_indices)
104                 final_scores = tf.gather(final_scores, kept_indices)
105                 final_category = tf.gather(final_category, kept_indices)
106
107         return final_boxes, final_scores, final_category
108
109     def roi_pooling(self, feature_maps, rois, img_shape, scope):
110         '''
111         Here use roi warping as roi_pooling
112
113         :param featuremaps_dict: feature map to crop
114         :param rois: shape is [-1, 4]. [x1, y1, x2, y2]
115         :return:
116         '''
117
118         with tf.variable_scope('ROI_Warping_'+scope):
119             img_h, img_w = tf.cast(img_shape[1], tf.float32), tf.cast(img_shape[2], tf.float32)
120             N = tf.shape(rois)[0]
121             x1, y1, x2, y2 = tf.unstack(rois, axis=1)
122
123             normalized_x1 = x1 / img_w
124             normalized_x2 = x2 / img_w
125             normalized_y1 = y1 / img_h
126             normalized_y2 = y2 / img_h
127
128             normalized_rois = tf.transpose(
129                 tf.stack([normalized_y1, normalized_x1, normalized_y2, normalized_x2]), name='get_normalized_rois')
130
131             normalized_rois = tf.stop_gradient(normalized_rois)
132
133             cropped_roi_features = tf.image.crop_and_resize(feature_maps, normalized_rois,
134                                                             box_ind=tf.zeros(shape=[N, ],
135                                                                              dtype=tf.int32),
136                                                             crop_size=[cfgs.ROI_SIZE, cfgs.ROI_SIZE],
137                                                             name='CROP_AND_RESIZE'
138                                                             )
139             roi_features = slim.max_pool2d(cropped_roi_features,
140                                            [cfgs.ROI_POOL_KERNEL_SIZE, cfgs.ROI_POOL_KERNEL_SIZE],
141                                            stride=cfgs.ROI_POOL_KERNEL_SIZE)
142
143         return roi_features
144
145     def build_fastrcnn(self, P_list, rois_list, img_shape):
146
147         with tf.variable_scope('Fast-RCNN'):
148             # 5. ROI Pooling
149             with tf.variable_scope('rois_pooling'):
150                 pooled_features_list = []
151                 for level_name, p, rois in zip(cfgs.LEVLES, P_list, rois_list):  # exclude P6_rois
152                     # p = tf.Print(p, [tf.shape(p)], summarize=10, message=level_name+'SHPAE***')
153                     pooled_features = self.roi_pooling(feature_maps=p, rois=rois, img_shape=img_shape,
154                                                        scope=level_name)
155                     pooled_features_list.append(pooled_features)
156
157                 pooled_features = tf.concat(pooled_features_list, axis=0) # [minibatch_size, H, W, C]
158
159             # 6. inferecne rois in Fast-RCNN to obtain fc_flatten features
160             if self.base_network_name.startswith('resnet'):
161                 fc_flatten = resnet.restnet_head(inputs=pooled_features,
162                                                  is_training=self.is_training,
163                                                  scope_name=self.base_network_name)
164             elif self.base_network_name.startswith('Mobile'):
165                 fc_flatten = mobilenet_v2.mobilenetv2_head(inputs=pooled_features,
166                                                            is_training=self.is_training)
167             else:
168                 raise NotImplementedError('only support resnet and mobilenet')
169
170             # 7. cls and reg in Fast-RCNN
171             with slim.arg_scope([slim.fully_connected], weights_regularizer=slim.l2_regularizer(cfgs.WEIGHT_DECAY)):
172
173                 cls_score = slim.fully_connected(fc_flatten,
174                                                  num_outputs=cfgs.CLASS_NUM+1,
175                                                  weights_initializer=cfgs.INITIALIZER,
176                                                  activation_fn=None, trainable=self.is_training,
177                                                  scope='cls_fc')
178
179                 bbox_pred = slim.fully_connected(fc_flatten,
180                                                  num_outputs=(cfgs.CLASS_NUM+1)*4,
181                                                  weights_initializer=cfgs.BBOX_INITIALIZER,
182                                                  activation_fn=None, trainable=self.is_training,
183                                                  scope='reg_fc')
184                 # for convient. It also produce (cls_num +1) bboxes
185
186                 cls_score = tf.reshape(cls_score, [-1, cfgs.CLASS_NUM+1])
187                 bbox_pred = tf.reshape(bbox_pred, [-1, 4*(cfgs.CLASS_NUM+1)])
188
189                 return bbox_pred, cls_score
190
191     def assign_levels(self, all_rois, labels=None, bbox_targets=None):
192         '''
193
194         :param all_rois:
195         :param labels:
196         :param bbox_targets:
197         :return:
198         '''
199         with tf.name_scope('assign_levels'):
200             # all_rois = tf.Print(all_rois, [tf.shape(all_rois)], summarize=10, message='ALL_ROIS_SHAPE*****')
201             xmin, ymin, xmax, ymax = tf.unstack(all_rois, axis=1)
202
203             h = tf.maximum(0., ymax - ymin)
204             w = tf.maximum(0., xmax - xmin)
205
206             levels = tf.floor(4. + tf.log(tf.sqrt(w * h + 1e-8) / 224.0) / tf.log(2.))  # 4 + log_2(***)
207             # use floor instead of round
208
209             min_level = int(cfgs.LEVLES[0][-1])
210             max_level = min(5, int(cfgs.LEVLES[-1][-1]))
211             levels = tf.maximum(levels, tf.ones_like(levels) * min_level)  # level minimum is 2
212             levels = tf.minimum(levels, tf.ones_like(levels) * max_level)  # level maximum is 5
213
214             levels = tf.stop_gradient(tf.reshape(levels, [-1]))
215
216             def get_rois(levels, level_i, rois, labels, bbox_targets):
217
218                 level_i_indices = tf.reshape(tf.where(tf.equal(levels, level_i)), [-1])
219                 # level_i_indices = tf.Print(level_i_indices, [tf.shape(tf.where(tf.equal(levels, level_i)))[0]], message="SHAPE%d***"%level_i,
220                 #                            summarize=10)
221                 tf.summary.scalar('LEVEL/LEVEL_%d_rois_NUM'%level_i, tf.shape(level_i_indices)[0])
222                 level_i_rois = tf.gather(rois, level_i_indices)
223
224                 if self.is_training:
225                     # If you use low version tensorflow, you may uncomment these code.
226                     level_i_rois = tf.stop_gradient(tf.concat([level_i_rois, [[0, 0, 0., 0.]]], axis=0))
227                     # # to avoid the num of level i rois is 0.0, which will broken the BP in tf
228                     #
229                     level_i_labels = tf.gather(labels, level_i_indices)
230                     level_i_labels = tf.stop_gradient(tf.concat([level_i_labels, [0]], axis=0))
231                     
232                     level_i_targets = tf.gather(bbox_targets, level_i_indices)
233                     level_i_targets = tf.stop_gradient(tf.concat([level_i_targets,
234                                                                     tf.zeros(shape=(1, 4*(cfgs.CLASS_NUM+1)), dtype=tf.float32)],
235                                                                     axis=0))
236                     #level_i_rois = tf.stop_gradient(level_i_rois)
237                     #level_i_labels = tf.gather(labels, level_i_indices)
238
239                     #level_i_targets = tf.gather(bbox_targets, level_i_indices)
240
241                     return level_i_rois, level_i_labels, level_i_targets
242                 else:
243                     return level_i_rois, None, None
244
245             rois_list = []
246             labels_list = []
247             targets_list = []
248             for i in range(min_level, max_level+1):
249                 P_i_rois, P_i_labels, P_i_targets = get_rois(levels, level_i=i, rois=all_rois,
250                                                              labels=labels,
251                                                              bbox_targets=bbox_targets)
252                 rois_list.append(P_i_rois)
253                 labels_list.append(P_i_labels)
254                 targets_list.append(P_i_targets)
255
256             if self.is_training:
257                 all_labels = tf.concat(labels_list, axis=0)
258                 all_targets = tf.concat(targets_list, axis=0)
259                 return rois_list, all_labels, all_targets
260             else:
261                 return rois_list  # [P2_rois, P3_rois, P4_rois, P5_rois] Note: P6 do not assign rois
262
263     def add_anchor_img_smry(self, img, anchors, labels):
264
265         positive_anchor_indices = tf.reshape(tf.where(tf.greater_equal(labels, 1)), [-1])
266         negative_anchor_indices = tf.reshape(tf.where(tf.equal(labels, 0)), [-1])
267
268         positive_anchor = tf.gather(anchors, positive_anchor_indices)
269         negative_anchor = tf.gather(anchors, negative_anchor_indices)
270
271         pos_in_img = show_box_in_tensor.only_draw_boxes(img_batch=img,
272                                                         boxes=positive_anchor)
273         neg_in_img = show_box_in_tensor.only_draw_boxes(img_batch=img,
274                                                         boxes=negative_anchor)
275
276         tf.summary.image('positive_anchor', pos_in_img)
277         tf.summary.image('negative_anchors', neg_in_img)
278
279     def add_roi_batch_img_smry(self, img, rois, labels):
280         positive_roi_indices = tf.reshape(tf.where(tf.greater_equal(labels, 1)), [-1])
281
282         negative_roi_indices = tf.reshape(tf.where(tf.equal(labels, 0)), [-1])
283
284         pos_roi = tf.gather(rois, positive_roi_indices)
285         neg_roi = tf.gather(rois, negative_roi_indices)
286
287
288         pos_in_img = show_box_in_tensor.only_draw_boxes(img_batch=img,
289                                                                boxes=pos_roi)
290         neg_in_img = show_box_in_tensor.only_draw_boxes(img_batch=img,
291                                                                boxes=neg_roi)
292         tf.summary.image('pos_rois', pos_in_img)
293         tf.summary.image('neg_rois', neg_in_img)
294
295     def build_loss(self, rpn_box_pred, rpn_bbox_targets, rpn_cls_score, rpn_labels,
296                    bbox_pred, bbox_targets, cls_score, labels):
297         '''
298
299         :param rpn_box_pred: [-1, 4]
300         :param rpn_bbox_targets: [-1, 4]
301         :param rpn_cls_score: [-1]
302         :param rpn_labels: [-1]
303         :param bbox_pred: [-1, 4*(cls_num+1)]
304         :param bbox_targets: [-1, 4*(cls_num+1)]
305         :param cls_score: [-1, cls_num+1]
306         :param labels: [-1]
307         :return:
308         '''
309         with tf.variable_scope('build_loss') as sc:
310             with tf.variable_scope('rpn_loss'):
311
312                 rpn_bbox_loss = losses.smooth_l1_loss_rpn(bbox_pred=rpn_box_pred,
313                                                           bbox_targets=rpn_bbox_targets,
314                                                           label=rpn_labels,
315                                                           sigma=cfgs.RPN_SIGMA)
316                 # rpn_cls_loss:
317                 # rpn_cls_score = tf.reshape(rpn_cls_score, [-1, 2])
318                 # rpn_labels = tf.reshape(rpn_labels, [-1])
319                 # ensure rpn_labels shape is [-1]
320                 rpn_select = tf.reshape(tf.where(tf.not_equal(rpn_labels, -1)), [-1])
321                 rpn_cls_score = tf.reshape(tf.gather(rpn_cls_score, rpn_select), [-1, 2])
322                 rpn_labels = tf.reshape(tf.gather(rpn_labels, rpn_select), [-1])
323                 rpn_cls_loss = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(logits=rpn_cls_score,
324                                                                                              labels=rpn_labels))
325
326                 rpn_cls_loss = rpn_cls_loss * cfgs.RPN_CLASSIFICATION_LOSS_WEIGHT
327                 rpn_bbox_loss = rpn_bbox_loss * cfgs.RPN_LOCATION_LOSS_WEIGHT
328
329             with tf.variable_scope('FastRCNN_loss'):
330                 if not cfgs.FAST_RCNN_MINIBATCH_SIZE == -1:
331                     bbox_loss = losses.smooth_l1_loss_rcnn(bbox_pred=bbox_pred,
332                                                            bbox_targets=bbox_targets,
333                                                            label=labels,
334                                                            num_classes=cfgs.CLASS_NUM + 1,
335                                                            sigma=cfgs.FASTRCNN_SIGMA)
336
337                     # cls_score = tf.reshape(cls_score, [-1, cfgs.CLASS_NUM + 1])
338                     # labels = tf.reshape(labels, [-1])
339                     cls_loss = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(
340                         logits=cls_score,
341                         labels=labels))  # beacause already sample before
342                 else:
343                     ''' 
344                     applying OHEM here
345                     '''
346                     print(20 * "@@")
347                     print("@@" + 10 * " " + "TRAIN WITH OHEM ...")
348                     print(20 * "@@")
349                     cls_loss, bbox_loss = losses.sum_ohem_loss(
350                         cls_score=cls_score,
351                         label=labels,
352                         bbox_targets=bbox_targets,
353                         bbox_pred=bbox_pred,
354                         num_ohem_samples=128,
355                         num_classes=cfgs.CLASS_NUM + 1)
356                 cls_loss = cls_loss * cfgs.FAST_RCNN_CLASSIFICATION_LOSS_WEIGHT
357                 bbox_loss = bbox_loss * cfgs.FAST_RCNN_LOCATION_LOSS_WEIGHT
358             loss_dict = {
359                 'rpn_cls_loss': rpn_cls_loss,
360                 'rpn_loc_loss': rpn_bbox_loss,
361                 'fastrcnn_cls_loss': cls_loss,
362                 'fastrcnn_loc_loss': bbox_loss
363             }
364         return loss_dict
365
    def build_whole_detection_network(self, input_img_batch, gtboxes_batch):
        '''
        Assemble the complete detection graph: backbone, RPN heads on every
        pyramid level, anchors, proposals, (training only) anchor and proposal
        sampling, level assignment, the Fast-RCNN head and the post-processing.

        :param input_img_batch: input image batch; its tf.shape is used as img_shape
        :param gtboxes_batch: ground-truth boxes, only touched in training mode,
                              where they are reshaped to [-1, 5] and cast to float32
        :return: inference: (boxes, scores, category) from postprocess_fastrcnn;
                 training: (final_bbox, final_scores, final_category, loss_dict)
        '''

        if self.is_training:
            # ensure shape is [M, 5]
            gtboxes_batch = tf.reshape(gtboxes_batch, [-1, 5])
            gtboxes_batch = tf.cast(gtboxes_batch, tf.float32)

        img_shape = tf.shape(input_img_batch)

        # 1. build base network
        P_list = self.build_base_network(input_img_batch)  # [P2, P3, P4, P5, P6]

        # 2. build rpn
        with tf.variable_scope('build_rpn',
                               regularizer=slim.l2_regularizer(cfgs.WEIGHT_DECAY)):

            fpn_cls_score =[]
            fpn_box_pred = []
            for level_name, p in zip(cfgs.LEVLES, P_list):
                if cfgs.SHARE_HEADS:
                    # shared heads: the first level creates the variables, later levels reuse them
                    reuse_flag = None if level_name==cfgs.LEVLES[0] else True
                    scope_list=['rpn_conv/3x3', 'rpn_cls_score', 'rpn_bbox_pred']
                else:
                    # separate heads: every level gets its own, level-suffixed scopes
                    reuse_flag = None
                    scope_list= ['rpn_conv/3x3_%s' % level_name, 'rpn_cls_score_%s' % level_name, 'rpn_bbox_pred_%s' % level_name]
                rpn_conv3x3 = slim.conv2d(
                    p, 512, [3, 3],
                    trainable=self.is_training, weights_initializer=cfgs.INITIALIZER, padding="SAME",
                    activation_fn=tf.nn.relu,
                    scope=scope_list[0],
                    reuse=reuse_flag)
                # 2 scores and 4 box deltas per anchor at every location
                rpn_cls_score = slim.conv2d(rpn_conv3x3, self.num_anchors_per_location*2, [1, 1], stride=1,
                                            trainable=self.is_training, weights_initializer=cfgs.INITIALIZER,
                                            activation_fn=None, padding="VALID",
                                            scope=scope_list[1],
                                            reuse=reuse_flag)
                rpn_box_pred = slim.conv2d(rpn_conv3x3, self.num_anchors_per_location*4, [1, 1], stride=1,
                                           trainable=self.is_training, weights_initializer=cfgs.BBOX_INITIALIZER,
                                           activation_fn=None, padding="VALID",
                                           scope=scope_list[2],
                                           reuse=reuse_flag)
                rpn_box_pred = tf.reshape(rpn_box_pred, [-1, 4])
                rpn_cls_score = tf.reshape(rpn_cls_score, [-1, 2])

                fpn_cls_score.append(rpn_cls_score)
                fpn_box_pred.append(rpn_box_pred)

            # concatenate the per-level predictions in level order
            fpn_cls_score = tf.concat(fpn_cls_score, axis=0, name='fpn_cls_score')
            fpn_box_pred = tf.concat(fpn_box_pred, axis=0, name='fpn_box_pred')
            fpn_cls_prob = slim.softmax(fpn_cls_score, scope='fpn_cls_prob')

        # 3. generate_anchors
        # anchors are concatenated in the same level order as the predictions above
        all_anchors = []
        for i in range(len(cfgs.LEVLES)):
            level_name, p = cfgs.LEVLES[i], P_list[i]

            p_h, p_w = tf.shape(p)[1], tf.shape(p)[2]
            featuremap_height = tf.cast(p_h, tf.float32)
            featuremap_width = tf.cast(p_w, tf.float32)
            anchors = anchor_utils.make_anchors(base_anchor_size=cfgs.BASE_ANCHOR_SIZE_LIST[i],
                                                anchor_scales=cfgs.ANCHOR_SCALES,
                                                anchor_ratios=cfgs.ANCHOR_RATIOS,
                                                featuremap_height=featuremap_height,
                                                featuremap_width=featuremap_width,
                                                stride=cfgs.ANCHOR_STRIDE_LIST[i],
                                                name="make_anchors_for%s" % level_name)
            all_anchors.append(anchors)
        all_anchors = tf.concat(all_anchors, axis=0, name='all_anchors_of_FPN')

        # 4. postprocess rpn proposals. such as: decode, clip, NMS
        with tf.variable_scope('postprocess_FPN'):
            rois, roi_scores = postprocess_rpn_proposals(rpn_bbox_pred=fpn_box_pred,
                                                         rpn_cls_prob=fpn_cls_prob,
                                                         img_shape=img_shape,
                                                         anchors=all_anchors,
                                                         is_training=self.is_training)
            # rois shape [-1, 4]
            # +++++++++++++++++++++++++++++++++++++add img smry+++++++++++++++++++++++++++++++++++++++++++++++++++++++

            if self.is_training:
                # image summary of the proposals whose score is at least 0.5
                score_gre_05 = tf.reshape(tf.where(tf.greater_equal(roi_scores, 0.5)), [-1])
                score_gre_05_rois = tf.gather(rois, score_gre_05)
                score_gre_05_score = tf.gather(roi_scores, score_gre_05)
                score_gre_05_in_img = show_box_in_tensor.draw_boxes_with_scores(img_batch=input_img_batch,
                                                                                boxes=score_gre_05_rois,
                                                                                scores=score_gre_05_score)
                tf.summary.image('score_greater_05_rois', score_gre_05_in_img)
            # ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++

        if self.is_training:
            with tf.variable_scope('sample_anchors_minibatch'):
                # anchor labels and regression targets are computed in Python via py_func
                fpn_labels, fpn_bbox_targets = \
                    tf.py_func(
                        anchor_target_layer,
                        [gtboxes_batch, img_shape, all_anchors],
                        [tf.float32, tf.float32])
                fpn_bbox_targets = tf.reshape(fpn_bbox_targets, [-1, 4])
                fpn_labels = tf.to_int32(fpn_labels, name="to_int32")
                fpn_labels = tf.reshape(fpn_labels, [-1])
                self.add_anchor_img_smry(input_img_batch, all_anchors, fpn_labels)

            # --------------------------------------add smry-----------------------------------------------------------

            # RPN accuracy summary, computed over anchors whose label is not -1
            fpn_cls_category = tf.argmax(fpn_cls_prob, axis=1)
            kept_rpppn = tf.reshape(tf.where(tf.not_equal(fpn_labels, -1)), [-1])
            fpn_cls_category = tf.gather(fpn_cls_category, kept_rpppn)
            acc = tf.reduce_mean(tf.to_float(tf.equal(fpn_cls_category,
                                                      tf.to_int64(tf.gather(fpn_labels, kept_rpppn)))))
            tf.summary.scalar('ACC/fpn_accuracy', acc)

            # the proposal sampling ops only run after fpn_labels has been computed
            with tf.control_dependencies([fpn_labels]):
                with tf.variable_scope('sample_RCNN_minibatch'):
                    rois, labels, bbox_targets = \
                    tf.py_func(proposal_target_layer,
                               [rois, gtboxes_batch],
                               [tf.float32, tf.float32, tf.float32])
                    rois = tf.reshape(rois, [-1, 4])
                    labels = tf.to_int32(labels)
                    labels = tf.reshape(labels, [-1])
                    bbox_targets = tf.reshape(bbox_targets, [-1, 4*(cfgs.CLASS_NUM+1)])
                    self.add_roi_batch_img_smry(input_img_batch, rois, labels)
        if self.is_training:
            # in training, labels and bbox_targets are regrouped in the same level order as the rois
            rois_list, labels, bbox_targets = self.assign_levels(all_rois=rois,
                                                                 labels=labels,
                                                                 bbox_targets=bbox_targets)
        else:
            rois_list = self.assign_levels(all_rois=rois)  # rois_list: [P2_rois, P3_rois, P4_rois, P5_rois]

        # -------------------------------------------------------------------------------------------------------------#
        #                                            Fast-RCNN                                                         #
        # -------------------------------------------------------------------------------------------------------------#

        # 5. build Fast-RCNN
        # rois = tf.Print(rois, [tf.shape(rois)], 'rois shape', summarize=10)
        bbox_pred, cls_score = self.build_fastrcnn(P_list=P_list, rois_list=rois_list,
                                                   img_shape=img_shape)
        # bbox_pred shape: [-1, 4*(cls_num+1)].
        # cls_score shape: [-1, cls_num+1]

        cls_prob = slim.softmax(cls_score, 'cls_prob')


        # ----------------------------------------------add smry-------------------------------------------------------
        if self.is_training:
            # Fast-RCNN classification accuracy summary
            cls_category = tf.argmax(cls_prob, axis=1)
            fast_acc = tf.reduce_mean(tf.to_float(tf.equal(cls_category, tf.to_int64(labels))))
            tf.summary.scalar('ACC/fast_acc', fast_acc)

        # re-concatenate the per-level rois so they line up with bbox_pred / cls_prob
        rois = tf.concat(rois_list, axis=0, name='concat_rois')
        #  6. postprocess_fastrcnn
        if not self.is_training:
            return self.postprocess_fastrcnn(rois=rois, bbox_ppred=bbox_pred, scores=cls_prob, img_shape=img_shape)
        else:
            '''
            when trian. We need build Loss
            '''
            loss_dict = self.build_loss(rpn_box_pred=fpn_box_pred,
                                        rpn_bbox_targets=fpn_bbox_targets,
                                        rpn_cls_score=fpn_cls_score,
                                        rpn_labels=fpn_labels,
                                        bbox_pred=bbox_pred,
                                        bbox_targets=bbox_targets,
                                        cls_score=cls_score,
                                        labels=labels)

            final_bbox, final_scores, final_category = self.postprocess_fastrcnn(rois=rois,
                                                                                 bbox_ppred=bbox_pred,
                                                                                 scores=cls_prob,
                                                                                 img_shape=img_shape)
            return final_bbox, final_scores, final_category, loss_dict
536
537     def get_restorer(self):
538         checkpoint_path = tf.train.latest_checkpoint(os.path.join(cfgs.TRAINED_CKPT, cfgs.VERSION))
539
540         if checkpoint_path != None:
541             restorer = tf.train.Saver()
542             print("model restore from :", checkpoint_path)
543         else:
544             checkpoint_path = cfgs.PRETRAINED_CKPT
545             print("model restore from pretrained mode, path is :", checkpoint_path)
546
547             model_variables = slim.get_model_variables()
548             # for var in model_variables:
549             #     print(var.name)
550             # print(20*"__++__++__")
551
552             def name_in_ckpt_rpn(var):
553                 return var.op.name
554
555             def name_in_ckpt_fastrcnn_head(var):
556                 '''
557                 Fast-RCNN/resnet_v1_50/block4 -->resnet_v1_50/block4
558                 Fast-RCNN/MobilenetV2/** -- > MobilenetV2 **
559                 :param var:
560                 :return:
561                 '''
562                 return '/'.join(var.op.name.split('/')[1:])
563             nameInCkpt_Var_dict = {}
564             for var in model_variables:
565                 if var.name.startswith(self.base_network_name):
566                     var_name_in_ckpt = name_in_ckpt_rpn(var)
567                     nameInCkpt_Var_dict[var_name_in_ckpt] = var
568             restore_variables = nameInCkpt_Var_dict
569             for key, item in restore_variables.items():
570                 print("var_in_graph: ", item.name)
571                 print("var_in_ckpt: ", key)
572                 print(20*"___")
573             restorer = tf.train.Saver(restore_variables)
574             print(20 * "****")
575             print("restore from pretrained_weighs in IMAGE_NET")
576         return restorer, checkpoint_path
577
578     def get_gradients(self, optimizer, loss):
579         '''
580
581         :param optimizer:
582         :param loss:
583         :return:
584
585         return vars and grads that not be fixed
586         '''
587
588         # if cfgs.FIXED_BLOCKS > 0:
589         #     trainable_vars = tf.trainable_variables()
590         #     # trained_vars = slim.get_trainable_variables()
591         #     start_names = [cfgs.NET_NAME + '/block%d'%i for i in range(1, cfgs.FIXED_BLOCKS+1)] + \
592         #                   [cfgs.NET_NAME + '/conv1']
593         #     start_names = tuple(start_names)
594         #     trained_var_list = []
595         #     for var in trainable_vars:
596         #         if not var.name.startswith(start_names):
597         #             trained_var_list.append(var)
598         #     # slim.learning.train()
599         #     grads = optimizer.compute_gradients(loss, var_list=trained_var_list)
600         #     return grads
601         # else:
602         #     return optimizer.compute_gradients(loss)
603         return optimizer.compute_gradients(loss)
604
605     def enlarge_gradients_for_bias(self, gradients):
606
607         final_gradients = []
608         with tf.variable_scope("Gradient_Mult") as scope:
609             for grad, var in gradients:
610                 scale = 1.0
611                 if cfgs.MUTILPY_BIAS_GRADIENT and './biases' in var.name:
612                     scale = scale * cfgs.MUTILPY_BIAS_GRADIENT
613                 if not np.allclose(scale, 1.0):
614                     grad = tf.multiply(grad, scale)
615                 final_gradients.append((grad, var))
616         return final_gradients
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636