17bb8d6f9b2c3fee55b9ee4b4b81490b519b258c
[ealt-edge.git] / example-apps / PDD / pcb-defect-detection / data / io / convert_data_to_tfrecord.py
1 # -*- coding: utf-8 -*-
2 from __future__ import division, print_function, absolute_import
3 import sys
4 sys.path.append('../../')
5 import xml.etree.cElementTree as ET
6 import numpy as np
7 import tensorflow as tf
8 import glob
9 import cv2
10 from libs.label_name_dict.label_dict import *
11 from help_utils.tools import *
12
# Command-line flags (TF1 ``tf.app.flags``). Paths are built by plain string
# concatenation (VOC_dir + xml_dir, VOC_dir + image_dir, save_dir + file name),
# so VOC_dir and save_dir must end with '/'.
# NOTE(review): the VOC_dir default is a machine-specific absolute path pointing
# at the *test* split; for the training split pass
#   --VOC_dir=/home/root1/My_Work/Akraino/MEC_BP/Use_cases/PCB_Demo/Tiny-Defect-Detection-for-PCB/data/pcb_train/PCB_DATASET/
# (and a matching --save_name).
tf.app.flags.DEFINE_string('VOC_dir', '/home/root1/My_Work/Akraino/MEC_BP/Use_cases/PCB_Demo/Tiny-Defect-Detection-for-PCB/data/pcb_test/PCB_DATASET/', 'Voc dir')
# Sub-directories relative to VOC_dir holding the XML annotations / images.
tf.app.flags.DEFINE_string('xml_dir', 'Annotations/Missing_hole', 'xml dir')
tf.app.flags.DEFINE_string('image_dir', 'images/Missing_hole', 'image dir')
# Output file is save_dir + dataset + '_' + save_name + '.tfrecord'.
tf.app.flags.DEFINE_string('save_name', 'test', 'save name')
tf.app.flags.DEFINE_string('save_dir', '../tfrecord/', 'save dir')
# Extension appended to an annotation's file stem to find its image.
tf.app.flags.DEFINE_string('img_format', '.jpg', 'format of image')
tf.app.flags.DEFINE_string('dataset', 'pcb', 'dataset')
FLAGS = tf.app.flags.FLAGS
22
23
def _int64_feature(value):
    """Wrap one integer *value* in a ``tf.train.Feature`` backed by an Int64List."""
    int_list = tf.train.Int64List(value=[value])
    return tf.train.Feature(int64_list=int_list)
26
27
def _bytes_feature(value):
    """Wrap one bytes *value* in a ``tf.train.Feature`` backed by a BytesList."""
    bytes_list = tf.train.BytesList(value=[value])
    return tf.train.Feature(bytes_list=bytes_list)
30
31
def read_xml_gtbox_and_label(xml_path):
    """
    Parse a VOC-style annotation XML into image size and box/label rows.

    :param xml_path: path of the annotation XML file.
    :return: a tuple ``(img_height, img_width, gtbox_label)``:
        - ``img_height`` / ``img_width``: ints read from ``<size><height>`` /
          ``<size><width>``, or ``None`` when the file does not provide them.
        - ``gtbox_label``: ``np.int32`` array with one row per ``<bndbox>``
          inside an ``<object>``. Each row holds the integer values of the
          bndbox's child elements in document order, followed by the class id
          from ``NAME_LABEL_MAP[<name>]``. For a standard VOC bndbox
          (xmin, ymin, xmax, ymax) a row is ``[xmin, ymin, xmax, ymax, label]``.
          When there are no objects the array is empty with shape ``(0,)``.
    :raises ValueError: if an ``<object>`` has a ``<bndbox>`` before (or
        without) its ``<name>``.
    :raises KeyError: if an object name is not a key of ``NAME_LABEL_MAP``.
    """
    root = ET.parse(xml_path).getroot()
    img_width = None
    img_height = None
    box_list = []
    for child_of_root in root:
        if child_of_root.tag == 'size':
            for child_item in child_of_root:
                if child_item.tag == 'width':
                    img_width = int(child_item.text)
                elif child_item.tag == 'height':
                    img_height = int(child_item.text)

        elif child_of_root.tag == 'object':
            label = None
            for child_item in child_of_root:
                if child_item.tag == 'name':
                    # strip() tolerates pretty-printed XML like "<name>\n  x\n</name>".
                    label = NAME_LABEL_MAP[(child_item.text or '').strip()]
                elif child_item.tag == 'bndbox':
                    # Explicit check rather than `assert`, which is removed under `python -O`.
                    if label is None:
                        raise ValueError(
                            'label is none, error: <bndbox> without a preceding '
                            '<name> in {}'.format(xml_path))
                    # NOTE: coordinates must be integer strings; int("12.5") raises ValueError.
                    tmp_box = [int(node.text) for node in child_item]
                    tmp_box.append(label)
                    box_list.append(tmp_box)

    gtbox_label = np.array(box_list, dtype=np.int32)

    return img_height, img_width, gtbox_label
72
73
def convert_pascal_to_tfrecord():
    """
    Convert a VOC-style dataset (XML annotations + images) into one TFRecord.

    Every ``*.xml`` under ``FLAGS.VOC_dir + FLAGS.xml_dir`` is paired with
    ``<file stem> + FLAGS.img_format`` under ``FLAGS.VOC_dir + FLAGS.image_dir``.
    One ``tf.train.Example`` per pair is written to
    ``FLAGS.save_dir + FLAGS.dataset + '_' + FLAGS.save_name + '.tfrecord'``.
    Each example has these features:
        img_name          -- image file name (bytes)
        img_height/width  -- int64, from the XML <size> (falls back to the
                             decoded image's shape when the XML lacks it)
        img               -- raw RGB pixel bytes of the decoded image
        gtboxes_and_label -- raw int32 bytes of read_xml_gtbox_and_label's array
        num_objects       -- int64 number of rows in that array
    Annotations whose image is missing or cannot be decoded are reported and
    skipped.
    """
    import os  # explicit import; previously only available via a star import

    xml_path = FLAGS.VOC_dir + FLAGS.xml_dir
    image_path = FLAGS.VOC_dir + FLAGS.image_dir
    save_path = FLAGS.save_dir + FLAGS.dataset + '_' + FLAGS.save_name + '.tfrecord'
    mkdir(FLAGS.save_dir)

    # List the directory once (the old code re-globbed on every iteration for
    # the progress bar) and sort so the record order is reproducible.
    xml_files = sorted(glob.glob(xml_path + '/*.xml'))
    total = len(xml_files)

    # writer_options = tf.python_io.TFRecordOptions(tf.python_io.TFRecordCompressionType.ZLIB)
    # The `with` block guarantees close() (and so a flush), even on errors;
    # the writer was previously never closed.
    with tf.python_io.TFRecordWriter(path=save_path) as writer:
        for count, xml_file in enumerate(xml_files):
            # to avoid path error in different development platform
            xml_file = xml_file.replace('\\', '/')

            # splitext keeps multi-dot stems intact ("a.b.xml" -> "a.b").
            stem = os.path.splitext(os.path.basename(xml_file))[0]
            img_name = stem + FLAGS.img_format
            img_path = image_path + '/' + img_name

            if not os.path.exists(img_path):
                print('{} is not exist!'.format(img_path))
                continue

            img_height, img_width, gtbox_label = read_xml_gtbox_and_label(xml_file)

            img_bgr = cv2.imread(img_path)
            if img_bgr is None:  # cv2.imread signals a decode failure with None
                print('{} cannot be read!'.format(img_path))
                continue
            img = img_bgr[:, :, ::-1]  # OpenCV loads BGR; store RGB

            if img_height is None or img_width is None:
                img_height, img_width = int(img.shape[0]), int(img.shape[1])

            feature = tf.train.Features(feature={
                # BytesList needs bytes, so the str name is encoded.
                'img_name': _bytes_feature(img_name.encode()),
                'img_height': _int64_feature(img_height),
                'img_width': _int64_feature(img_width),
                # tobytes() == deprecated tostring(); emits a C-order copy.
                'img': _bytes_feature(img.tobytes()),
                'gtboxes_and_label': _bytes_feature(gtbox_label.tobytes()),
                'num_objects': _int64_feature(gtbox_label.shape[0])
            })

            example = tf.train.Example(features=feature)
            writer.write(example.SerializeToString())

            view_bar('Conversion progress', count + 1, total)

    print('\nConversion is complete!')
117
118
if __name__ == '__main__':
    # Script entry point: run the conversion configured by the command-line flags.
    convert_pascal_to_tfrecord()