PCB defect detection application
[ealt-edge.git] / example-apps / PDD / pcb-defect-detection / data / io / image_preprocess.py
1 # -*- coding: utf-8 -*-
2
3 from __future__ import absolute_import
4 from __future__ import print_function
5 from __future__ import division
6
7 import tensorflow as tf
8
9 import numpy as np
10
11
def max_length_limitation(length, length_limitation):
    """Clip ``length`` so that it does not exceed ``length_limitation``.

    The comparison is built in-graph with ``tf.cond``, so either argument may
    be a scalar tensor (e.g. derived from ``tf.shape``) or a Python number.

    :param length: candidate side length (scalar).
    :param length_limitation: upper bound for the side length (scalar).
    :return: ``length`` when it is strictly less than ``length_limitation``,
        otherwise ``length_limitation``.
    """
    def _keep_length():
        return length

    def _use_limit():
        return length_limitation

    below_limit = tf.less(length, length_limitation)
    return tf.cond(below_limit, true_fn=_keep_length, false_fn=_use_limit)
16
def short_side_resize(img_tensor, gtboxes_and_label, target_shortside_len, length_limitation=1200):
    """Resize an image so its short side equals ``target_shortside_len`` and rescale its boxes.

    The long side is scaled proportionally (integer floor division) and then
    clipped to ``length_limitation`` through ``max_length_limitation`` to avoid
    running out of memory.
    NOTE(review): when that clip applies, the two axes are scaled by different
    factors, so the image aspect ratio is not preserved. The boxes are rescaled
    per axis, so they stay aligned with the resized image.

    :param img_tensor: image tensor of shape [h, w, c].
    :param gtboxes_and_label: tensor of shape [-1, 5]; each row is
        [xmin, ymin, xmax, ymax, label].
    :param target_shortside_len: desired length of the shorter image side.
    :param length_limitation: maximum allowed length of the longer side.
    :return: (resized image of shape [new_h, new_w, c] -- produced by
        ``tf.image.resize_bilinear``, which outputs float32 --,
        boxes-and-labels tensor of shape [-1, 5] in the resized coordinates).
    """
    dynamic_shape = tf.shape(img_tensor)
    img_h = dynamic_shape[0]
    img_w = dynamic_shape[1]

    def _wide_size():
        # Width is the long side: pin the height, scale and clip the width.
        return (target_shortside_len,
                max_length_limitation(target_shortside_len * img_w // img_h, length_limitation))

    def _tall_or_square_size():
        # Height is the long side (or h == w): pin the width, scale and clip the height.
        return (max_length_limitation(target_shortside_len * img_h // img_w, length_limitation),
                target_shortside_len)

    new_h, new_w = tf.cond(tf.less(img_h, img_w),
                           true_fn=_wide_size,
                           false_fn=_tall_or_square_size)

    # resize_bilinear works on a batch, so add a leading axis and drop it again
    # afterwards to get back to a rank-3 image.
    batched = tf.expand_dims(img_tensor, axis=0)
    resized = tf.squeeze(tf.image.resize_bilinear(batched, [new_h, new_w]), axis=0)

    xmin, ymin, xmax, ymax, label = tf.unstack(gtboxes_and_label, axis=1)
    # Each coordinate is scaled by new/old along its own axis.
    scaled_columns = [
        xmin * new_w // img_w,
        ymin * new_h // img_h,
        xmax * new_w // img_w,
        ymax * new_h // img_h,
        label,
    ]
    return resized, tf.transpose(tf.stack(scaled_columns, axis=0))
42
43
def short_side_resize_for_inference_data(img_tensor, target_shortside_len, length_limitation=1200, is_resize=True):
    """Resize an inference image so its short side equals ``target_shortside_len``.

    This applies the same sizing rule as ``short_side_resize``, but there are
    no boxes. The long side is scaled proportionally (integer floor division)
    and clipped to ``length_limitation`` via ``max_length_limitation``.

    :param img_tensor: image tensor of shape [h, w, c].
    :param target_shortside_len: desired length of the shorter image side.
    :param length_limitation: maximum allowed length of the longer side.
    :param is_resize: when falsy, ``img_tensor`` is returned unchanged.
    :return: the resized image of shape [new_h, new_w, c], or ``img_tensor``
        itself when ``is_resize`` is falsy.
        NOTE(review): the resize path goes through ``tf.image.resize_bilinear``
        (float32 output), while the pass-through path keeps the input dtype.
        Callers may therefore see different dtypes depending on ``is_resize``.
    """
    if not is_resize:
        return img_tensor

    dynamic_shape = tf.shape(img_tensor)
    img_h = dynamic_shape[0]
    img_w = dynamic_shape[1]

    def _wide_size():
        # Width is the long side: pin the height, scale and clip the width.
        return (target_shortside_len,
                max_length_limitation(target_shortside_len * img_w // img_h, length_limitation))

    def _tall_or_square_size():
        # Height is the long side (or h == w): pin the width, scale and clip the height.
        return (max_length_limitation(target_shortside_len * img_h // img_w, length_limitation),
                target_shortside_len)

    new_h, new_w = tf.cond(tf.less(img_h, img_w),
                           true_fn=_wide_size,
                           false_fn=_tall_or_square_size)

    # Add a batch axis for resize_bilinear, then squeeze it back to rank 3.
    batched = tf.expand_dims(img_tensor, axis=0)
    return tf.squeeze(tf.image.resize_bilinear(batched, [new_h, new_w]), axis=0)
59
def flip_left_to_right(img_tensor, gtboxes_and_label):
    """Mirror an image horizontally and mirror its boxes to match.

    :param img_tensor: image tensor of shape [h, w, c].
    :param gtboxes_and_label: tensor of shape [-1, 5]; each row is
        [xmin, ymin, xmax, ymax, label].
    :return: (flipped image, boxes-and-labels tensor of shape [-1, 5] whose
        x coordinates are mirrored about the image width).
    """
    # Only the width is needed to mirror x coordinates.
    img_w = tf.shape(img_tensor)[1]

    img_tensor = tf.image.flip_left_right(img_tensor)

    xmin, ymin, xmax, ymax, label = tf.unstack(gtboxes_and_label, axis=1)
    # Mirroring maps x -> w - x, so the old xmax becomes the new xmin and vice
    # versa; swapping them keeps xmin <= xmax.
    # NOTE(review): `w - x` treats coordinates as continuous in [0, w]; for
    # inclusive pixel indices `w - 1 - x` would be exact -- confirm convention.
    # NOTE(review): img_w is int32 (tf.shape default), so the box tensor is
    # assumed to be int32 as well -- confirm against the caller.
    new_xmin = img_w - xmax
    new_xmax = img_w - xmin

    return img_tensor, tf.transpose(tf.stack([new_xmin, ymin, new_xmax, ymax, label], axis=0))
71
def random_flip_left_right(img_tensor, gtboxes_and_label):
    """With probability 0.5, flip the image and its boxes horizontally.

    :param img_tensor: image tensor of shape [h, w, c].
    :param gtboxes_and_label: tensor of shape [-1, 5]; each row is
        [xmin, ymin, xmax, ymax, label].
    :return: (image, boxes-and-labels), either both flipped by
        ``flip_left_to_right`` or both returned unchanged.
    """
    # A scalar drawn uniformly from [0, 1); values below 0.5 trigger the flip.
    coin = tf.random_uniform(shape=[], minval=0, maxval=1)

    def _flipped():
        return flip_left_to_right(img_tensor, gtboxes_and_label)

    def _unchanged():
        return img_tensor, gtboxes_and_label

    out_img, out_boxes = tf.cond(tf.less(coin, 0.5), _flipped, _unchanged)
    return out_img, out_boxes
78
79
80