PCB defect detection application
[ealt-edge.git] / example-apps / PDD / pcb-defect-detection / data / io / read_tfrecord.py
1 # -*- coding: utf-8 -*-
2
3 from __future__ import absolute_import
4 from __future__ import print_function
5 from __future__ import division
6
7 import numpy as np
8 import tensorflow as tf
9 import os
10 from data.io import image_preprocess
11 from libs.configs import cfgs
12
def read_single_example_and_decode(filename_queue):
    """Read one serialized example from ``filename_queue`` and decode it.

    Records are read with a plain ``tf.TFRecordReader`` created without any
    compression options.

    Args:
        filename_queue: queue of TFRecord file names to read from.

    Returns:
        A tuple ``(img_name, img, gtboxes_and_label, num_objects)``:
          * img_name: scalar ``tf.string`` tensor.
          * img: ``uint8`` tensor reshaped to ``[img_height, img_width, 3]``.
          * gtboxes_and_label: ``int32`` tensor reshaped to ``[-1, 5]``, one
            row per object (presumably ``[x1, y1, x2, y2, label]``).
          * num_objects: scalar ``int32`` tensor.
    """
    record_reader = tf.TFRecordReader()
    _, raw_record = record_reader.read(filename_queue)

    feature_spec = {
        'img_name': tf.FixedLenFeature([], tf.string),
        'img_height': tf.FixedLenFeature([], tf.int64),
        'img_width': tf.FixedLenFeature([], tf.int64),
        'img': tf.FixedLenFeature([], tf.string),
        'gtboxes_and_label': tf.FixedLenFeature([], tf.string),
        'num_objects': tf.FixedLenFeature([], tf.int64),
    }
    parsed = tf.parse_single_example(serialized=raw_record, features=feature_spec)

    height = tf.cast(parsed['img_height'], tf.int32)
    width = tf.cast(parsed['img_width'], tf.int32)

    # The image is stored as a flat byte string; restore its H x W x 3 layout.
    pixels = tf.reshape(tf.decode_raw(parsed['img'], tf.uint8),
                        shape=[height, width, 3])

    # Boxes and labels are packed as a flat int32 buffer, 5 values per object.
    boxes = tf.reshape(tf.decode_raw(parsed['gtboxes_and_label'], tf.int32),
                       [-1, 5])

    object_count = tf.cast(parsed['num_objects'], tf.int32)
    return parsed['img_name'], pixels, boxes, object_count
44
45
def read_and_prepocess_single_img(filename_queue, shortside_len, is_training):
    """Read one example and apply the detector's input preprocessing.

    Steps, in order:
      1. cast the image to float32;
      2. ``image_preprocess.short_side_resize`` with ``shortside_len`` and
         ``cfgs.IMG_MAX_LENGTH`` (both training and evaluation);
      3. ``image_preprocess.random_flip_left_right`` (training only);
      4. subtract ``cfgs.PIXEL_MEAN``.

    Args:
        filename_queue: queue of TFRecord file names.
        shortside_len: target short-side length forwarded to
            ``short_side_resize``.
        is_training: when True, the random left-right flip is applied.

    Returns:
        ``(img_name, img, gtboxes_and_label, num_objects)``, where ``img`` is
        a float32 tensor with the pixel mean subtracted.
    """
    img_name, img, gtboxes_and_label, num_objects = read_single_example_and_decode(filename_queue)

    img = tf.cast(img, tf.float32)

    # Resizing is common to training and evaluation, so it is done once here.
    img, gtboxes_and_label = image_preprocess.short_side_resize(
        img_tensor=img,
        gtboxes_and_label=gtboxes_and_label,
        target_shortside_len=shortside_len,
        length_limitation=cfgs.IMG_MAX_LENGTH)

    if is_training:
        # Augmentation is applied only while training.
        img, gtboxes_and_label = image_preprocess.random_flip_left_right(
            img_tensor=img,
            gtboxes_and_label=gtboxes_and_label)

    # Subtract the pixel mean last. The [[...]] wrapping gives the constant
    # leading singleton dims so it broadcasts over H x W (assumes PIXEL_MEAN
    # has one entry per channel -- TODO confirm in cfgs).
    img = img - tf.constant([[cfgs.PIXEL_MEAN]])
    return img_name, img, gtboxes_and_label, num_objects
65
66
def next_batch(dataset_name, batch_size, shortside_len, is_training,
               tfrecord_dir='../data/tfrecord'):
    """Build the queue-based input pipeline for one dataset split.

    Args:
        dataset_name: one of 'ship', 'spacenet', 'pascal', 'coco', 'pcb'.
        batch_size: must be 1; larger batches are not supported yet.
        shortside_len: target short-side length used during preprocessing.
        is_training: selects '<dataset_name>_train*' (True) or
            '<dataset_name>_test*' (False) TFRecord files, and enables the
            training-only augmentation.
        tfrecord_dir: directory holding the TFRecord files. The default keeps
            the original relative path, which is resolved against the current
            working directory.

    Returns:
        img_name_batch: shape (1,), tf.string.
        img_batch: shape (1, new_imgH, new_imgW, C), float32.
        gtboxes_and_label_batch: shape (1, Num_Of_objects, 5); each row is
            [x1, y1, x2, y2, label].
        num_obs_batch: shape (1,), int32.

    Raises:
        ValueError: if batch_size is not 1 or dataset_name is unsupported.

    NOTE(review): per the TF1 docs, tf.train.match_filenames_once creates a
    local variable. The caller must therefore run
    tf.local_variables_initializer() and start the queue runners before
    evaluating the returned tensors.
    """
    # Explicit raise rather than assert, so the check survives `python -O`.
    if batch_size != 1:
        raise ValueError("we only support batch_size is 1.We may support large batch_size in the future")

    supported = ('ship', 'spacenet', 'pascal', 'coco', 'pcb')
    if dataset_name not in supported:
        raise ValueError('dataSet name must be one of {}, got {!r}'.format(
            ', '.join(supported), dataset_name))

    split_suffix = '_train*' if is_training else '_test*'
    pattern = os.path.join(tfrecord_dir, dataset_name + split_suffix)

    print('tfrecord path is -->', os.path.abspath(pattern))

    filename_tensorlist = tf.train.match_filenames_once(pattern)
    filename_queue = tf.train.string_input_producer(filename_tensorlist)

    img_name, img, gtboxes_and_label, num_obs = read_and_prepocess_single_img(
        filename_queue, shortside_len, is_training=is_training)

    # dynamic_pad=True lets tensors whose shapes vary per example (resized
    # image, variable number of boxes) be batched by padding.
    img_name_batch, img_batch, gtboxes_and_label_batch, num_obs_batch = tf.train.batch(
        [img_name, img, gtboxes_and_label, num_obs],
        batch_size=batch_size,
        capacity=1,
        num_threads=1,
        dynamic_pad=True)
    return img_name_batch, img_batch, gtboxes_and_label_batch, num_obs_batch