b329715d58f9fc26b82abb1d99ff2fbe8c54d679
[ealt-edge.git] / example-apps / PDD / pcb-defect-detection / tools / train_with_placeholder.py
1 # -*- coding:utf-8 -*-
2
3 from __future__ import absolute_import
4 from __future__ import print_function
5 from __future__ import division
6
7 import tensorflow as tf
8 import tensorflow.contrib.slim as slim
9 import os, sys
10 sys.path.append("../")
11 sys.path.append("../data/lib_coco")
12 sys.path.append('../data/lib_coco/PythonAPI/')
13
14 import numpy as np
15 import time
16
17 from libs.configs import cfgs
18 from libs.networks import build_whole_network
19 from data.io import image_preprocess
20 from libs.box_utils import show_box_in_tensor
21 from help_utils import tools
22 from data.lib_coco.get_coco_next_batch import next_img
23
24
# Export the GPU selection from the project config through the CUDA_VISIBLE_DEVICES
# environment variable. This runs at import time, before train() creates its session.
os.environ["CUDA_VISIBLE_DEVICES"] = cfgs.GPU_GROUP
26
27
def preprocess_img(img_plac, gtbox_plac):
    '''
    Build the per-image preprocessing graph that feeds the detection network.

    The image is cast to float32, passed through ``image_preprocess.short_side_resize``
    (configured with cfgs.IMG_SHORT_SIDE_LEN / cfgs.IMG_MAX_LENGTH) and
    ``image_preprocess.random_flip_left_right``, shifted by cfgs.PIXEL_MEAN and finally
    given a leading batch axis.

    :param img_plac: [H, W, 3] uint8 image. In RGB.
    :param gtbox_plac: shape of [-1, 5]. Each row is [xmin, ymin, xmax, ymax, label].
    :return: (img_batch, gtboxes_and_label) -- the preprocessed image with a new axis 0,
             and the boxes/labels as returned by the preprocessing helpers.
    '''

    float_img = tf.cast(img_plac, tf.float32)

    resized_img, resized_boxes = image_preprocess.short_side_resize(
        img_tensor=float_img,
        gtboxes_and_label=gtbox_plac,
        target_shortside_len=cfgs.IMG_SHORT_SIDE_LEN,
        length_limitation=cfgs.IMG_MAX_LENGTH)

    flipped_img, gtboxes_and_label = image_preprocess.random_flip_left_right(
        img_tensor=resized_img,
        gtboxes_and_label=resized_boxes)

    # The double nesting adds two leading axes so the mean broadcasts against the image.
    pixel_mean = tf.constant([[cfgs.PIXEL_MEAN]])
    centered_img = flipped_img - pixel_mean

    img_batch = tf.expand_dims(centered_img, axis=0)
    return img_batch, gtboxes_and_label
50
def train():
    '''
    Build the detection training graph and run the placeholder-fed training loop.

    Steps, in order:
      1. Create the network, the image / ground-truth placeholders and the
         preprocessing ops.
      2. Build the detection network under a slim arg_scope and assemble the RPN and
         Fast R-CNN losses into ``total_loss``.
      3. Register scalar and image summaries.
      4. Create a momentum optimizer with a piecewise-constant learning rate, compute
         (optionally bias-scaled and norm-clipped) gradients and the train op.
      5. Open a session, optionally restore weights, then iterate for
         cfgs.MAX_ITERATION steps, feeding one image per step from ``next_img``.
         Depending on the step, it only trains, trains and prints the losses, or
         trains and writes summaries. Checkpoints are saved periodically and on the
         last step.

    All hyper-parameters and paths are read from ``cfgs``. Takes no arguments and
    returns nothing.
    '''

    faster_rcnn = build_whole_network.DetectionNetwork(base_network_name=cfgs.NET_NAME,
                                                       is_training=True)

    with tf.name_scope('get_batch'):
        # One image and its boxes are fed per step; height, width and the number of
        # boxes are left dynamic.
        img_plac = tf.placeholder(dtype=tf.uint8, shape=[None, None, 3])
        gtbox_plac = tf.placeholder(dtype=tf.int32, shape=[None, 5])

        img_batch, gtboxes_and_label = preprocess_img(img_plac, gtbox_plac)

    biases_regularizer = tf.no_regularizer
    weights_regularizer = tf.contrib.layers.l2_regularizer(cfgs.WEIGHT_DECAY)

    # list as many types of layers as possible, even if they are not used now
    with slim.arg_scope([slim.conv2d, slim.conv2d_in_plane,
                         slim.conv2d_transpose, slim.separable_conv2d, slim.fully_connected],
                        weights_regularizer=weights_regularizer,
                        biases_regularizer=biases_regularizer,
                        biases_initializer=tf.constant_initializer(0.0)):
        final_bbox, final_scores, final_category, loss_dict = faster_rcnn.build_whole_detection_network(
            input_img_batch=img_batch,
            gtboxes_batch=gtboxes_and_label)

    # ----------------------------------------------------------------------------------------------------build loss
    # The regularization term is disabled here: it contributes a constant 0 to the
    # objective (it was formerly tf.add_n(slim.losses.get_regularization_losses())).
    weight_decay_loss = 0
    rpn_location_loss = loss_dict['rpn_loc_loss']
    rpn_cls_loss = loss_dict['rpn_cls_loss']
    rpn_total_loss = rpn_location_loss + rpn_cls_loss

    fastrcnn_cls_loss = loss_dict['fastrcnn_cls_loss']
    fastrcnn_loc_loss = loss_dict['fastrcnn_loc_loss']
    fastrcnn_total_loss = fastrcnn_cls_loss + fastrcnn_loc_loss

    total_loss = rpn_total_loss + fastrcnn_total_loss + weight_decay_loss
    # ____________________________________________________________________________________________________build loss

    # ---------------------------------------------------------------------------------------------------add summary
    tf.summary.scalar('RPN_LOSS/cls_loss', rpn_cls_loss)
    tf.summary.scalar('RPN_LOSS/location_loss', rpn_location_loss)
    tf.summary.scalar('RPN_LOSS/rpn_total_loss', rpn_total_loss)

    tf.summary.scalar('FAST_LOSS/fastrcnn_cls_loss', fastrcnn_cls_loss)
    tf.summary.scalar('FAST_LOSS/fastrcnn_location_loss', fastrcnn_loc_loss)
    tf.summary.scalar('FAST_LOSS/fastrcnn_total_loss', fastrcnn_total_loss)

    tf.summary.scalar('LOSS/total_loss', total_loss)
    tf.summary.scalar('LOSS/regular_weights', weight_decay_loss)

    # Last column of gtboxes_and_label is the label, the leading columns are the box.
    gtboxes_in_img = show_box_in_tensor.draw_boxes_with_categories(img_batch=img_batch,
                                                                   boxes=gtboxes_and_label[:, :-1],
                                                                   labels=gtboxes_and_label[:, -1])
    if cfgs.ADD_BOX_IN_TENSORBOARD:
        detections_in_img = show_box_in_tensor.draw_boxes_with_categories_and_scores(img_batch=img_batch,
                                                                                     boxes=final_bbox,
                                                                                     labels=final_category,
                                                                                     scores=final_scores)
        tf.summary.image('Compare/final_detection', detections_in_img)
    tf.summary.image('Compare/gtboxes', gtboxes_in_img)
    # ___________________________________________________________________________________________________add summary

    global_step = slim.get_or_create_global_step()
    # Learning rate drops by 10x at each of the two configured decay steps.
    lr = tf.train.piecewise_constant(global_step,
                                     boundaries=[np.int64(cfgs.DECAY_STEP[0]), np.int64(cfgs.DECAY_STEP[1])],
                                     values=[cfgs.LR, cfgs.LR / 10., cfgs.LR / 100.])
    tf.summary.scalar('lr', lr)
    optimizer = tf.train.MomentumOptimizer(lr, momentum=cfgs.MOMENTUM)

    # ---------------------------------------------------------------------------------------------compute gradients
    gradients = faster_rcnn.get_gradients(optimizer, total_loss)

    # enlarge_gradients for bias
    if cfgs.MUTILPY_BIAS_GRADIENT:
        gradients = faster_rcnn.enlarge_gradients_for_bias(gradients)

    if cfgs.GRADIENT_CLIPPING_BY_NORM:
        with tf.name_scope('clip_gradients_YJR'):
            gradients = slim.learning.clip_gradient_norms(gradients,
                                                          cfgs.GRADIENT_CLIPPING_BY_NORM)
    # _____________________________________________________________________________________________compute gradients

    # train_op
    train_op = optimizer.apply_gradients(grads_and_vars=gradients,
                                         global_step=global_step)
    summary_op = tf.summary.merge_all()
    init_op = tf.group(
        tf.global_variables_initializer(),
        tf.local_variables_initializer()
    )

    restorer, restore_ckpt = faster_rcnn.get_restorer()
    saver = tf.train.Saver(max_to_keep=30)

    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True

    with tf.Session(config=config) as sess:
        sess.run(init_op)
        if restorer is not None:
            restorer.restore(sess, restore_ckpt)
            print('restore model')

        summary_path = os.path.join(cfgs.SUMMARY_PATH, cfgs.VERSION)
        tools.mkdir(summary_path)
        summary_writer = tf.summary.FileWriter(summary_path, graph=sess.graph)

        # The checkpoint directory does not change between steps. Create it once,
        # including missing parent directories (os.mkdir would fail on those).
        save_dir = os.path.join(cfgs.TRAINED_CKPT, cfgs.VERSION)
        if not os.path.exists(save_dir):
            os.makedirs(save_dir)

        try:
            for step in range(cfgs.MAX_ITERATION):

                img_id, img, gt_info = next_img(step=step)
                feed_dict = {img_plac: img,
                             gtbox_plac: gt_info}

                show_info = step % cfgs.SHOW_TRAIN_INFO_INTE == 0
                write_summary = step % cfgs.SMRY_ITER == 0

                if write_summary:
                    # Summary steps take priority over console reporting.
                    _, global_stepnp, summary_str = sess.run([train_op, global_step, summary_op],
                                                             feed_dict=feed_dict)
                    summary_writer.add_summary(summary_str, global_stepnp)
                    summary_writer.flush()

                elif show_info:
                    training_time = time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(time.time()))
                    start = time.time()

                    _, global_stepnp, rpnLocLoss, rpnClsLoss, rpnTotalLoss, \
                    fastrcnnLocLoss, fastrcnnClsLoss, fastrcnnTotalLoss, totalLoss = \
                        sess.run(
                            [train_op, global_step, rpn_location_loss, rpn_cls_loss, rpn_total_loss,
                             fastrcnn_loc_loss, fastrcnn_cls_loss, fastrcnn_total_loss, total_loss],
                            feed_dict=feed_dict)
                    end = time.time()
                    print(""" {}: step{}    image_name:{} |\t
                              rpn_loc_loss:{} |\t rpn_cla_loss:{} |\t rpn_total_loss:{} |
                              fast_rcnn_loc_loss:{} |\t fast_rcnn_cla_loss:{} |\t fast_rcnn_total_loss:{} |
                              total_loss:{} |\t per_cost_time:{}s""" \
                          .format(training_time, global_stepnp, str(img_id), rpnLocLoss, rpnClsLoss,
                                  rpnTotalLoss, fastrcnnLocLoss, fastrcnnClsLoss, fastrcnnTotalLoss, totalLoss,
                                  (end - start)))

                else:
                    # Plain training step: nothing to report or record.
                    _, global_stepnp = sess.run([train_op, global_step],
                                                feed_dict=feed_dict)

                if (step > 0 and step % cfgs.SAVE_WEIGHTS_INTE == 0) or (step == cfgs.MAX_ITERATION - 1):
                    save_ckpt = os.path.join(save_dir, 'voc_' + str(global_stepnp) + 'model.ckpt')
                    saver.save(sess, save_ckpt)
                    print(' weights had been saved')
        finally:
            # Release the event file even if the loop is interrupted or raises.
            summary_writer.close()
211
212
# Script entry point: run training only when executed directly, not when imported.
if __name__ == '__main__':

    train()
216
217 #
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233