PCB defect detection application
[ealt-edge.git] / example-apps / PDD / pcb-defect-detection / libs / networks / resnet.py
# -*- coding: utf-8 -*-

from __future__ import absolute_import, print_function, division


import tensorflow as tf
import tensorflow.contrib.slim as slim
from libs.configs import cfgs
from tensorflow.contrib.slim.nets import resnet_v1
from tensorflow.contrib.slim.nets import resnet_utils
from tensorflow.contrib.slim.python.slim.nets.resnet_v1 import resnet_v1_block
import tfplot as tfp

def resnet_arg_scope(
        is_training=True, weight_decay=cfgs.WEIGHT_DECAY, batch_norm_decay=0.997,
        batch_norm_epsilon=1e-5, batch_norm_scale=True):
    '''
    By default we do not use BN to train ResNet, since the batch size is too
    small. So is_training and trainable are both False in the batch_norm
    params.
    '''
    batch_norm_params = {
        'is_training': False, 'decay': batch_norm_decay,
        'epsilon': batch_norm_epsilon, 'scale': batch_norm_scale,
        'trainable': False,
        'updates_collections': tf.GraphKeys.UPDATE_OPS
    }

    with slim.arg_scope(
            [slim.conv2d],
            weights_regularizer=slim.l2_regularizer(weight_decay),
            weights_initializer=slim.variance_scaling_initializer(),
            trainable=is_training,
            activation_fn=tf.nn.relu,
            normalizer_fn=slim.batch_norm,
            normalizer_params=batch_norm_params):
        with slim.arg_scope([slim.batch_norm], **batch_norm_params) as arg_sc:
            return arg_sc
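

# A minimal usage sketch (illustrative only, not called by the pipeline):
# layers built under the returned scope pick up the shared weight regularizer,
# initializer and frozen batch-norm settings. The placeholder shape and the
# 'demo_conv' scope name below are assumptions for the demo.
def _demo_resnet_arg_scope():
    images = tf.placeholder(tf.float32, [1, 224, 224, 3], name='demo_images')
    with slim.arg_scope(resnet_arg_scope(is_training=False)):
        # This conv2d inherits l2 regularization and non-trainable batch norm.
        return slim.conv2d(images, 64, [3, 3], scope='demo_conv')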


def fusion_two_layer(C_i, P_j, scope):
    '''
    i = j + 1
    :param C_i: shape is [1, h, w, c]
    :param P_j: shape is [1, h/2, w/2, 256]
    :return: P_i, shape is [1, h, w, 256]
    '''
    with tf.variable_scope(scope):
        level_name = scope.split('_')[1]
        h, w = tf.shape(C_i)[1], tf.shape(C_i)[2]
        upsample_p = tf.image.resize_bilinear(P_j,
                                              size=[h, w],
                                              name='up_sample_' + level_name)

        reduce_dim_c = slim.conv2d(C_i,
                                   num_outputs=256,
                                   kernel_size=[1, 1], stride=1,
                                   scope='reduce_dim_' + level_name)

        add_f = 0.5 * upsample_p + 0.5 * reduce_dim_c

        # P_i = slim.conv2d(add_f,
        #                   num_outputs=256, kernel_size=[3, 3], stride=1,
        #                   padding='SAME',
        #                   scope='fusion_'+level_name)
        return add_f
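

# Illustrative sketch only: fusing a hypothetical C4 (stride 16) with P5
# (stride 32) yields a P4 at C4's spatial size with 256 channels. The shapes
# and the 'build_P4' scope name below are assumptions for the demo.
def _demo_fusion_two_layer():
    C4 = tf.placeholder(tf.float32, [1, 64, 64, 1024], name='demo_C4')
    P5 = tf.placeholder(tf.float32, [1, 32, 32, 256], name='demo_P5')
    # P5 is bilinearly upsampled to C4's size, C4 is reduced to 256 channels
    # by a 1x1 conv, and the two maps are averaged -> shape [1, 64, 64, 256].
    return fusion_two_layer(C_i=C4, P_j=P5, scope='build_P4')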


def add_heatmap(feature_maps, name):
    '''
    Plot a heatmap of the given feature maps as an image summary.

    :param feature_maps: [B, H, W, C]
    :param name: summary name
    :return:
    '''

    def figure_attention(activation):
        fig, ax = tfp.subplots()
        im = ax.imshow(activation, cmap='jet')
        fig.colorbar(im)
        return fig

    heatmap = tf.reduce_sum(feature_maps, axis=-1)
    heatmap = tf.squeeze(heatmap, axis=0)
    tfp.summary.plot(name, figure_attention, [heatmap])
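
# Usage (as done throughout resnet_base below), e.g.:
#     add_heatmap(C2, name='Layer2/C2_heat')
# This renders the channel-summed activations with a jet colormap and writes
# them to TensorBoard via tfplot.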


def resnet_base(img_batch, scope_name, is_training=True):
    '''
    This code is derived from light-head rcnn:
    https://github.com/zengarden/light_head_rcnn

    It is convenient to freeze blocks, so we adopt this mode.
    '''
    if scope_name == 'resnet_v1_50':
        middle_num_units = 6
    elif scope_name == 'resnet_v1_101':
        middle_num_units = 23
    else:
        raise NotImplementedError('We only support resnet_v1_50 or resnet_v1_101. '
                                  'Check your network name.')

    blocks = [resnet_v1_block('block1', base_depth=64, num_units=3, stride=2),
              resnet_v1_block('block2', base_depth=128, num_units=4, stride=2),
              resnet_v1_block('block3', base_depth=256, num_units=middle_num_units, stride=2),
              resnet_v1_block('block4', base_depth=512, num_units=3, stride=1)]
    # When using FPN, the stride list is [1, 2, 2].

    with slim.arg_scope(resnet_arg_scope(is_training=False)):
        with tf.variable_scope(scope_name, scope_name):
            # Do the first few layers manually, because 'SAME' padding can
            # behave inconsistently for images of different sizes:
            # sometimes it pads 0, sometimes 1.
            net = resnet_utils.conv2d_same(
                img_batch, 64, 7, stride=2, scope='conv1')
            net = tf.pad(net, [[0, 0], [1, 1], [1, 1], [0, 0]])
            net = slim.max_pool2d(
                net, [3, 3], stride=2, padding='VALID', scope='pool1')

    not_freezed = [False] * cfgs.FIXED_BLOCKS + (4 - cfgs.FIXED_BLOCKS) * [True]
    # cfgs.FIXED_BLOCKS can be 1~3.
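    # e.g. cfgs.FIXED_BLOCKS == 2 gives [False, False, True, True]: block1 and
    # block2 stay frozen while block3 and block4 remain trainable.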
    with slim.arg_scope(resnet_arg_scope(is_training=(is_training and not_freezed[0]))):
        C2, end_points_C2 = resnet_v1.resnet_v1(net,
                                                blocks[0:1],
                                                global_pool=False,
                                                include_root_block=False,
                                                scope=scope_name)

    # C2 = tf.Print(C2, [tf.shape(C2)], summarize=10, message='C2_shape')
    add_heatmap(C2, name='Layer2/C2_heat')

    with slim.arg_scope(resnet_arg_scope(is_training=(is_training and not_freezed[1]))):
        C3, end_points_C3 = resnet_v1.resnet_v1(C2,
                                                blocks[1:2],
                                                global_pool=False,
                                                include_root_block=False,
                                                scope=scope_name)

    # C3 = tf.Print(C3, [tf.shape(C3)], summarize=10, message='C3_shape')
    add_heatmap(C3, name='Layer3/C3_heat')
    with slim.arg_scope(resnet_arg_scope(is_training=(is_training and not_freezed[2]))):
        C4, end_points_C4 = resnet_v1.resnet_v1(C3,
                                                blocks[2:3],
                                                global_pool=False,
                                                include_root_block=False,
                                                scope=scope_name)

    add_heatmap(C4, name='Layer4/C4_heat')

    # C4 = tf.Print(C4, [tf.shape(C4)], summarize=10, message='C4_shape')
    with slim.arg_scope(resnet_arg_scope(is_training=is_training)):
        C5, end_points_C5 = resnet_v1.resnet_v1(C4,
                                                blocks[3:4],
                                                global_pool=False,
                                                include_root_block=False,
                                                scope=scope_name)
    # C5 = tf.Print(C5, [tf.shape(C5)], summarize=10, message='C5_shape')
    add_heatmap(C5, name='Layer5/C5_heat')

    feature_dict = {'C2': end_points_C2['{}/block1/unit_2/bottleneck_v1'.format(scope_name)],
                    'C3': end_points_C3['{}/block2/unit_3/bottleneck_v1'.format(scope_name)],
                    'C4': end_points_C4['{}/block3/unit_{}/bottleneck_v1'.format(scope_name, middle_num_units - 1)],
                    'C5': end_points_C5['{}/block4/unit_3/bottleneck_v1'.format(scope_name)],
                    # 'C5': end_points_C5['{}/block4'.format(scope_name)],
                    }

    # feature_dict = {'C2': C2,
    #                 'C3': C3,
    #                 'C4': C4,
    #                 'C5': C5}

    pyramid_dict = {}
    with tf.variable_scope('build_pyramid'):
        with slim.arg_scope([slim.conv2d], weights_regularizer=slim.l2_regularizer(cfgs.WEIGHT_DECAY),
                            activation_fn=None, normalizer_fn=None):

            P5 = slim.conv2d(C5,
                             num_outputs=256,
                             kernel_size=[1, 1],
                             stride=1, scope='build_P5')
            if "P6" in cfgs.LEVLES:
                P6 = slim.max_pool2d(P5, kernel_size=[1, 1], stride=2, scope='build_P6')
                pyramid_dict['P6'] = P6

            pyramid_dict['P5'] = P5

            for level in range(4, 1, -1):  # build [P4, P3, P2]

                pyramid_dict['P%d' % level] = fusion_two_layer(C_i=feature_dict["C%d" % level],
                                                               P_j=pyramid_dict["P%d" % (level + 1)],
                                                               scope='build_P%d' % level)
            for level in range(4, 1, -1):
                pyramid_dict['P%d' % level] = slim.conv2d(pyramid_dict['P%d' % level],
                                                          num_outputs=256, kernel_size=[3, 3], padding="SAME",
                                                          stride=1, scope="fuse_P%d" % level)
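            # The 3x3 convs above smooth the fused maps to reduce the aliasing
            # effect of upsampling, following the standard FPN design.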
    for level in range(5, 1, -1):
        add_heatmap(pyramid_dict['P%d' % level], name='Layer%d/P%d_heat' % (level, level))

    # return [P2, P3, P4, P5, P6]
    print("we are in Pyramid::-======>>>>")
    print(cfgs.LEVLES)
    print("base_anchor_size are: ", cfgs.BASE_ANCHOR_SIZE_LIST)
    print(20 * "__")
    return [pyramid_dict[level_name] for level_name in cfgs.LEVLES]
    # return pyramid_dict  # could return the dict and fetch each level by
    # key, but the order of levels must stay consistent. We return a list
    # rather than a dict, since dicts are unordered.
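

# Illustrative usage sketch (not called by the pipeline): building the FPN
# levels for a single image. The input shape is hypothetical, and cfgs.LEVLES
# is assumed to be e.g. ['P2', 'P3', 'P4', 'P5', 'P6'].
def _demo_resnet_base():
    img_batch = tf.placeholder(tf.float32, [1, 600, 600, 3], name='demo_img')
    # Returns one 256-channel feature map per entry of cfgs.LEVLES.
    return resnet_base(img_batch, scope_name='resnet_v1_50', is_training=True)

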
def restnet_head(inputs, is_training, scope_name):
    '''
    :param inputs: [minibatch_size, 7, 7, 256]
    :param is_training:
    :param scope_name:
    :return: fc features, [minibatch_size, 1024]
    '''

    with tf.variable_scope('build_fc_layers'):

        # fc1 = slim.conv2d(inputs=inputs,
        #                   num_outputs=1024,
        #                   kernel_size=[7, 7],
        #                   padding='VALID',
        #                   scope='fc1')  # shape is [minibatch_size, 1, 1, 1024]
        # fc1 = tf.squeeze(fc1, [1, 2], name='squeeze_fc1')

        inputs = slim.flatten(inputs=inputs, scope='flatten_inputs')

        fc1 = slim.fully_connected(inputs, num_outputs=1024, scope='fc1')

        fc2 = slim.fully_connected(fc1, num_outputs=1024, scope='fc2')

        # fc3 = slim.fully_connected(fc2, num_outputs=1024, scope='fc3')
        # fc3 (commented out above) could be added to increase the capacity
        # of the Fast R-CNN head.

        return fc2
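

# Illustrative sketch only: the head consumes per-RoI features (e.g. from a
# 7x7 RoI pooling/crop) and emits one 1024-d vector per RoI. The minibatch
# size of 256 below is an assumption for the demo.
def _demo_restnet_head():
    roi_feats = tf.placeholder(tf.float32, [256, 7, 7, 256], name='demo_rois')
    return restnet_head(roi_feats, is_training=True, scope_name='resnet_v1_50')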