d01c54b11d51651c548e8f85de658303087c9117
[ealt-edge.git] / example-apps / PDD / pcb-defect-detection / libs / box_utils / anchor_utils.py
1 # -*- coding: utf-8 -*-
2 from __future__ import absolute_import, print_function, division
3
4 import tensorflow as tf
5 from libs.configs import cfgs
6
7
def make_anchors(base_anchor_size, anchor_scales, anchor_ratios,
                 featuremap_height, featuremap_width,
                 stride, name='make_anchors'):
    """Generate every anchor box for one feature-map level.

    Each grid cell receives ``len(anchor_scales) * len(anchor_ratios)``
    anchors. They are centred at ``(col * stride, row * stride)``, shifted by
    ``stride / 2`` in both directions when ``cfgs.USE_CENTER_OFFSET`` is set.

    Aspect-ratio convention: ``enum_ratios`` returns ``(hs, ws)``, but the
    pair is bound here as ``(widths, heights)``. Each anchor therefore has
    width ``size * sqrt(r)`` and height ``size / sqrt(r)``, i.e.
    ``width / height == r``, where ``size = base_anchor_size * scale``.
    NOTE(review): this is kept as-is so the per-location anchor order does
    not change for existing callers; confirm before "fixing" it.

    :param base_anchor_size: side length of the square base anchor
        (e.g. 256) before scaling.
    :param anchor_scales: sequence of multipliers applied to the base size.
    :param anchor_ratios: sequence of aspect ratios (see convention above).
    :param featuremap_height: number of grid rows; passed to ``tf.range``.
    :param featuremap_width: number of grid columns; passed to ``tf.range``.
    :param stride: spacing between adjacent anchor centres.
    :param name: name of the enclosing ``tf.variable_scope``.
    :return: float32 tensor of shape
        ``[featuremap_height * featuremap_width * A, 4]``, where
        ``A = len(anchor_ratios) * len(anchor_scales)``. Rows are
        ``[x_min, y_min, x_max, y_max]``, ordered row-major over the grid,
        then by ratio, then by scale.
    """
    with tf.variable_scope(name):
        # Laid out as [x_center, y_center, w, h]; only w and h are used.
        base_box = tf.constant([0, 0, base_anchor_size, base_anchor_size],
                               tf.float32)

        # Per-location anchor sizes, each of shape [A, 1] (see the note on
        # the (hs, ws) -> (widths, heights) binding in the docstring).
        widths, heights = enum_ratios(enum_scales(base_box, anchor_scales),
                                      anchor_ratios)

        x_coords = tf.range(featuremap_width, dtype=tf.float32) * stride
        y_coords = tf.range(featuremap_height, dtype=tf.float32) * stride
        if cfgs.USE_CENTER_OFFSET:
            half_stride = stride / 2.
            x_coords = x_coords + half_stride
            y_coords = y_coords + half_stride

        # Both grids have shape [featuremap_height, featuremap_width].
        grid_x, grid_y = tf.meshgrid(x_coords, y_coords)

        # Pair every cell (rows) with every per-location size (columns):
        # the outputs below have shape [H * W, A]. tf.meshgrid flattens its
        # inputs anyway; the explicit reshape just makes that visible.
        widths, grid_x = tf.meshgrid(widths, tf.reshape(grid_x, [-1]))
        heights, grid_y = tf.meshgrid(heights, tf.reshape(grid_y, [-1]))

        centers = tf.reshape(tf.stack([grid_x, grid_y], axis=2), [-1, 2])
        sizes = tf.reshape(tf.stack([widths, heights], axis=2), [-1, 2])

        half_sizes = 0.5 * sizes
        return tf.concat([centers - half_sizes, centers + half_sizes], axis=1)
51
52
def enum_scales(base_anchor, anchor_scales):
    """Scale a base anchor by each factor in ``anchor_scales``.

    :param base_anchor: float32 tensor of shape [4] (every component is
        multiplied, including the centre coordinates).
    :param anchor_scales: sequence of scale factors; ``len()`` must work.
    :return: float32 tensor of shape ``[len(anchor_scales), 4]``, where row
        ``i`` is ``base_anchor * anchor_scales[i]``.
    """
    num_scales = len(anchor_scales)
    # A column vector, so broadcasting against the [4] base gives [S, 4].
    scale_column = tf.constant(anchor_scales, dtype=tf.float32,
                               shape=(num_scales, 1))
    scaled = base_anchor * scale_column
    return scaled
58
59
def enum_ratios(anchors, anchor_ratios):
    """Expand each anchor into one (height, width) pair per aspect ratio.

    For every ratio ``r`` the area ``w * h`` is preserved while the shape
    changes: the returned height is ``h * sqrt(r)`` and the returned width
    is ``w / sqrt(r)``. For a square input anchor (w == h) this gives
    height / width == r, i.e. ratio = h / w.

    :param anchors: float32 rank-2 tensor whose columns 2 and 3 hold anchor
        width and height (e.g. [x_center, y_center, w, h] rows). Only those
        two columns are read.
    :param anchor_ratios: sequence of aspect ratios. It is converted to
        float32, so integer ratios such as ``[1, 2]`` and float64 arrays are
        accepted.
    :return: tuple ``(hs, ws)``. NOTE the order: heights first, then widths.
        Each is a float32 tensor of shape
        ``[len(anchor_ratios) * num_anchors, 1]``, ratio-major: all anchors
        for ratio 0, then all anchors for ratio 1, and so on.
    """
    ws = anchors[:, 2]  # for base anchor: w == h
    hs = anchors[:, 3]
    # Force float32 for two reasons. Without an explicit dtype, integer
    # ratios become an int32 constant, which tf.sqrt rejects. Float64 input
    # would also fail to combine with the float32 anchor sizes below.
    sqrt_ratios = tf.sqrt(tf.constant(anchor_ratios, dtype=tf.float32))

    # A [num_ratios, 1] column against [num_anchors] broadcasts to
    # [num_ratios, num_anchors]; flatten that into a single column.
    ws = tf.reshape(ws / sqrt_ratios[:, tf.newaxis], [-1, 1])
    hs = tf.reshape(hs * sqrt_ratios[:, tf.newaxis], [-1, 1])

    return hs, ws
75
76
if __name__ == '__main__':
    # Quick smoke run: build anchors for a 63x32 feature map and print
    # the shape of the resulting array.
    import os
    os.environ["CUDA_VISIBLE_DEVICES"] = '0'
    demo_anchors = make_anchors(base_anchor_size=256,
                                anchor_scales=[1.0],
                                anchor_ratios=[0.5, 2.0, 1.0],
                                featuremap_height=63,
                                featuremap_width=32,
                                stride=16)
    init_op = tf.global_variables_initializer()
    with tf.Session() as sess:
        sess.run(init_op)
        print(sess.run(demo_anchors).shape)