PCB defect detection application
[ealt-edge.git] / example-apps / PDD / pcb-defect-detection / tools / train.py
1 # -*- coding:utf-8 -*-
2
3 from __future__ import absolute_import
4 from __future__ import print_function
5 from __future__ import division
6
7 import tensorflow as tf
8 import tensorflow.contrib.slim as slim
9 import os, sys
10 import numpy as np
11 import time
12 sys.path.append("../")
13
14 from libs.configs import cfgs
15 # from libs.networks import build_whole_network2
16 from libs.networks import build_whole_network
17 from data.io.read_tfrecord import next_batch
18 from libs.box_utils import show_box_in_tensor
19 from help_utils import tools
20
# Default to GPU "2" (the original hard-coded choice), but respect a
# CUDA_VISIBLE_DEVICES value the caller already exported, so the device can be
# chosen from the environment without editing this file. This runs at import
# time, before train() creates any TensorFlow session.
os.environ.setdefault("CUDA_VISIBLE_DEVICES", "2")
22
23
def train():
    """Build the Faster R-CNN training graph and run the training loop.

    All hyper-parameters are read from ``cfgs``: network name, dataset,
    batch size, learning-rate schedule, iteration counts and output paths.

    Side effects:
      * writes TensorBoard summaries under ``cfgs.SUMMARY_PATH/cfgs.VERSION``;
      * saves checkpoints named ``pcb_<global_step>model.ckpt`` under
        ``cfgs.TRAINED_CKPT/cfgs.VERSION`` every ``cfgs.SAVE_WEIGHTS_INTE``
        steps and on the last iteration;
      * prints progress and timing information to stdout.
    """
    faster_rcnn = build_whole_network.DetectionNetwork(base_network_name=cfgs.NET_NAME,
                                                       is_training=True)

    with tf.name_scope('get_batch'):
        img_name_batch, img_batch, gtboxes_and_label_batch, num_objects_batch = \
            next_batch(dataset_name=cfgs.DATASET_NAME,  # 'pascal', 'coco'
                       batch_size=cfgs.BATCH_SIZE,
                       shortside_len=cfgs.IMG_SHORT_SIDE_LEN,
                       is_training=True)
        # One row per ground-truth entry: the first 4 columns are box
        # coordinates and the last column is the class label (see the
        # [:, :-1] / [:, -1] slicing further down).
        gtboxes_and_label = tf.reshape(gtboxes_and_label_batch, [-1, 5])

    biases_regularizer = tf.no_regularizer
    weights_regularizer = tf.contrib.layers.l2_regularizer(cfgs.WEIGHT_DECAY)

    # List as many layer types as possible, even if they are not used now, so
    # that every layer the network creates gets the same regularizers and
    # bias initializer.
    with slim.arg_scope([slim.conv2d, slim.conv2d_in_plane,
                         slim.conv2d_transpose, slim.separable_conv2d, slim.fully_connected],
                        weights_regularizer=weights_regularizer,
                        biases_regularizer=biases_regularizer,
                        biases_initializer=tf.constant_initializer(0.0)):
        final_bbox, final_scores, final_category, loss_dict = faster_rcnn.build_whole_detection_network(
            input_img_batch=img_batch,
            gtboxes_batch=gtboxes_and_label)

    # ------------------------------------------------------------------ build loss
    weight_decay_loss = tf.add_n(slim.losses.get_regularization_losses())
    rpn_location_loss = loss_dict['rpn_loc_loss']
    rpn_cls_loss = loss_dict['rpn_cls_loss']
    rpn_total_loss = rpn_location_loss + rpn_cls_loss

    fastrcnn_cls_loss = loss_dict['fastrcnn_cls_loss']
    fastrcnn_loc_loss = loss_dict['fastrcnn_loc_loss']
    fastrcnn_total_loss = fastrcnn_cls_loss + fastrcnn_loc_loss

    total_loss = rpn_total_loss + fastrcnn_total_loss + weight_decay_loss

    # ----------------------------------------------------------------- add summary
    tf.summary.scalar('RPN_LOSS/cls_loss', rpn_cls_loss)
    tf.summary.scalar('RPN_LOSS/location_loss', rpn_location_loss)
    tf.summary.scalar('RPN_LOSS/rpn_total_loss', rpn_total_loss)

    tf.summary.scalar('FAST_LOSS/fastrcnn_cls_loss', fastrcnn_cls_loss)
    tf.summary.scalar('FAST_LOSS/fastrcnn_location_loss', fastrcnn_loc_loss)
    tf.summary.scalar('FAST_LOSS/fastrcnn_total_loss', fastrcnn_total_loss)

    tf.summary.scalar('LOSS/total_loss', total_loss)
    tf.summary.scalar('LOSS/regular_weights', weight_decay_loss)

    gtboxes_in_img = show_box_in_tensor.draw_boxes_with_categories(img_batch=img_batch,
                                                                   boxes=gtboxes_and_label[:, :-1],
                                                                   labels=gtboxes_and_label[:, -1])
    if cfgs.ADD_BOX_IN_TENSORBOARD:
        detections_in_img = show_box_in_tensor.draw_boxes_with_categories_and_scores(img_batch=img_batch,
                                                                                     boxes=final_bbox,
                                                                                     labels=final_category,
                                                                                     scores=final_scores)
        tf.summary.image('Compare/final_detection', detections_in_img)
    tf.summary.image('Compare/gtboxes', gtboxes_in_img)

    # ------------------------------------------------------- learning rate / optimizer
    global_step = slim.get_or_create_global_step()
    # Step schedule: LR, then LR/10 after DECAY_STEP[0], then LR/100 after
    # DECAY_STEP[1]. The boundaries are cast to int64, presumably to match the
    # global step's dtype.
    lr = tf.train.piecewise_constant(global_step,
                                     boundaries=[np.int64(cfgs.DECAY_STEP[0]), np.int64(cfgs.DECAY_STEP[1])],
                                     values=[cfgs.LR, cfgs.LR / 10., cfgs.LR / 100.])
    tf.summary.scalar('lr', lr)
    optimizer = tf.train.MomentumOptimizer(lr, momentum=cfgs.MOMENTUM)

    # ------------------------------------------------------------ compute gradients
    gradients = faster_rcnn.get_gradients(optimizer, total_loss)

    # Optionally enlarge the gradients of bias variables.
    if cfgs.MUTILPY_BIAS_GRADIENT:
        gradients = faster_rcnn.enlarge_gradients_for_bias(gradients)

    if cfgs.GRADIENT_CLIPPING_BY_NORM:
        with tf.name_scope('clip_gradients_YJR'):
            gradients = slim.learning.clip_gradient_norms(gradients,
                                                          cfgs.GRADIENT_CLIPPING_BY_NORM)

    # --------------------------------------------------------------------- train op
    train_op = optimizer.apply_gradients(grads_and_vars=gradients,
                                         global_step=global_step)
    summary_op = tf.summary.merge_all()
    init_op = tf.group(
        tf.global_variables_initializer(),
        tf.local_variables_initializer()
    )

    restorer, restore_ckpt = faster_rcnn.get_restorer()
    saver = tf.train.Saver(max_to_keep=30)

    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True

    # Accumulated wall time and count of the timed "info" steps only.
    compute_time = 0
    compute_imgnum = 0
    save_dir = os.path.join(cfgs.TRAINED_CKPT, cfgs.VERSION)

    with tf.Session(config=config) as sess:
        sess.run(init_op)
        if restorer is not None:
            restorer.restore(sess, restore_ckpt)
            print('restore model')
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess, coord)

        summary_writer = None
        # try/finally so the input threads are always stopped and joined (and
        # the summary writer closed) even if training raises.
        try:
            summary_path = os.path.join(cfgs.SUMMARY_PATH, cfgs.VERSION)
            tools.mkdir(summary_path)
            summary_writer = tf.summary.FileWriter(summary_path, graph=sess.graph)
            training_time = time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(time.time()))

            for step in range(cfgs.MAX_ITERATION):
                write_summary = step % cfgs.SMRY_ITER == 0
                show_info = step % cfgs.SHOW_TRAIN_INFO_INTE == 0

                if write_summary:
                    # Summary steps take precedence when both intervals coincide.
                    _, global_stepnp, summary_str = sess.run([train_op, global_step, summary_op])
                    summary_writer.add_summary(summary_str, global_stepnp)
                    summary_writer.flush()
                elif show_info:
                    training_time1 = time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(time.time()))
                    start = time.time()

                    _, global_stepnp, img_name, rpnLocLoss, rpnClsLoss, rpnTotalLoss, \
                    fastrcnnLocLoss, fastrcnnClsLoss, fastrcnnTotalLoss, totalLoss = \
                        sess.run(
                            [train_op, global_step, img_name_batch, rpn_location_loss, rpn_cls_loss, rpn_total_loss,
                             fastrcnn_loc_loss, fastrcnn_cls_loss, fastrcnn_total_loss, total_loss])

                    end = time.time()
                    compute_time += end - start
                    compute_imgnum += 1
                    print(""" {}: step{}    image_name:{} |\t
                              rpn_loc_loss:{} |\t rpn_cla_loss:{} |\t rpn_total_loss:{} |
                              fast_rcnn_loc_loss:{} |\t fast_rcnn_cla_loss:{} |\t fast_rcnn_total_loss:{} |
                              total_loss:{} |\t per_cost_time:{}s""".format(
                        training_time1, global_stepnp, str(img_name[0]), rpnLocLoss, rpnClsLoss,
                        rpnTotalLoss, fastrcnnLocLoss, fastrcnnClsLoss, fastrcnnTotalLoss, totalLoss,
                        (end - start)))
                else:
                    _, global_stepnp = sess.run([train_op, global_step])

                if (step > 0 and step % cfgs.SAVE_WEIGHTS_INTE == 0) or (step == cfgs.MAX_ITERATION - 1):
                    if not os.path.exists(save_dir):
                        # makedirs (not mkdir) so a missing TRAINED_CKPT parent is created too.
                        os.makedirs(save_dir)
                    save_ckpt = os.path.join(save_dir, 'pcb_' + str(global_stepnp) + 'model.ckpt')
                    saver.save(sess, save_ckpt)
                    print(' weights had been saved')

            # Guard the average: if every info step coincided with a summary
            # step (or none ran), no timing samples were collected.
            if compute_imgnum > 0:
                print('average_training_time_per_image is' + str(compute_time / compute_imgnum))
            else:
                print('average_training_time_per_image is unavailable: no timed steps were run')
            print('traning start time is ' + training_time)
            end_training_time = time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(time.time()))
            print('traning end time is ' + end_training_time)
        finally:
            if summary_writer is not None:
                summary_writer.close()
            coord.request_stop()
            coord.join(threads)
190
191
if __name__ == "__main__":
    # Start training only when this file is executed directly, not on import.
    train()
195
196
197
198
199
200
201
202
203
204