pcb defect detection application
[ealt-edge.git] / example-apps / PDD / pcb-defect-detection / data / io / convert_data_to_tfrecord.py
diff --git a/example-apps/PDD/pcb-defect-detection/data/io/convert_data_to_tfrecord.py b/example-apps/PDD/pcb-defect-detection/data/io/convert_data_to_tfrecord.py
new file mode 100755 (executable)
index 0000000..17bb8d6
--- /dev/null
@@ -0,0 +1,123 @@
+# -*- coding: utf-8 -*-
+from __future__ import division, print_function, absolute_import
+import sys
+sys.path.append('../../')
+import xml.etree.cElementTree as ET
+import numpy as np
+import tensorflow as tf
+import glob
+import cv2
+from libs.label_name_dict.label_dict import *
+from help_utils.tools import *
+
+# Command-line flags for the conversion script.
+# The annotation and image folders are resolved relative to VOC_dir, and the
+# output file is written to save_dir as '<dataset>_<save_name>.tfrecord'.
+# NOTE: the VOC_dir default is a machine-specific absolute path; pass
+# --VOC_dir=... when running anywhere else.
+# Alternative default pointing at the training split (kept for reference):
+#tf.app.flags.DEFINE_string('VOC_dir', '/home/root1/My_Work/Akraino/MEC_BP/Use_cases/PCB_Demo/Tiny-Defect-Detection-for-PCB/data/pcb_train/PCB_DATASET/', 'Voc dir')
+tf.app.flags.DEFINE_string('VOC_dir', '/home/root1/My_Work/Akraino/MEC_BP/Use_cases/PCB_Demo/Tiny-Defect-Detection-for-PCB/data/pcb_test/PCB_DATASET/', 'Voc dir')
+tf.app.flags.DEFINE_string('xml_dir', 'Annotations/Missing_hole', 'xml dir')
+tf.app.flags.DEFINE_string('image_dir', 'images/Missing_hole', 'image dir')
+tf.app.flags.DEFINE_string('save_name', 'test', 'save name')
+# The help text used to read 'save name' (copy-paste from the flag above).
+tf.app.flags.DEFINE_string('save_dir', '../tfrecord/', 'save dir')
+tf.app.flags.DEFINE_string('img_format', '.jpg', 'format of image')
+tf.app.flags.DEFINE_string('dataset', 'pcb', 'dataset')
+FLAGS = tf.app.flags.FLAGS
+
+
+def _int64_feature(value):
+    """Wrap one integer in a ``tf.train.Feature`` holding a single-element int64 list."""
+    int64_list = tf.train.Int64List(value=[value])
+    return tf.train.Feature(int64_list=int64_list)
+
+
+def _bytes_feature(value):
+    """Wrap one bytes object in a ``tf.train.Feature`` holding a single-element bytes list."""
+    bytes_list = tf.train.BytesList(value=[value])
+    return tf.train.Feature(bytes_list=bytes_list)
+
+
+def read_xml_gtbox_and_label(xml_path):
+    """
+    Parse one VOC-style annotation file.
+
+    :param xml_path: the path of voc xml
+    :return: a tuple ``(img_height, img_width, gtbox_label)``.
+           ``img_height`` / ``img_width`` are the ints found in ``<size>``,
+           or None when the corresponding tag is absent.
+           ``gtbox_label`` is an int32 array with one row per ``<bndbox>``:
+           the integer values of the ``<bndbox>`` children in document order
+           followed by the label id (for the usual xmin/ymin/xmax/ymax box
+           that is ``[xmin, ymin, xmax, ymax, label]``). It is empty when the
+           file has no objects.
+    :raises AssertionError: if an ``<object>`` lists ``<bndbox>`` before ``<name>``
+    :raises KeyError: if an object name is not a key of NAME_LABEL_MAP
+    :raises ValueError: if a size or coordinate value is not an integer literal
+    """
+
+    root = ET.parse(xml_path).getroot()
+    img_width = None
+    img_height = None
+    box_list = []
+    for child_of_root in root:
+        if child_of_root.tag == 'size':
+            for child_item in child_of_root:
+                if child_item.tag == 'width':
+                    img_width = int(child_item.text)
+                elif child_item.tag == 'height':
+                    img_height = int(child_item.text)
+
+        elif child_of_root.tag == 'object':
+            # The label must be seen before the box it belongs to.
+            label = None
+            for child_item in child_of_root:
+                if child_item.tag == 'name':
+                    label = NAME_LABEL_MAP[child_item.text]
+                elif child_item.tag == 'bndbox':
+                    tmp_box = [int(node.text) for node in child_item]
+                    # Raised explicitly (same exception type as the former
+                    # assert) so the check is not stripped under ``python -O``.
+                    if label is None:
+                        raise AssertionError('label is none, error')
+                    tmp_box.append(label)
+                    box_list.append(tmp_box)
+
+    gtbox_label = np.array(box_list, dtype=np.int32)
+
+    return img_height, img_width, gtbox_label
+
+
+def convert_pascal_to_tfrecord():
+    """
+    Pack the VOC-style dataset selected by FLAGS into a single TFRecord file.
+
+    Every ``*.xml`` under ``FLAGS.VOC_dir``/``FLAGS.xml_dir`` is paired with the
+    image of the same base name plus ``FLAGS.img_format`` under
+    ``FLAGS.VOC_dir``/``FLAGS.image_dir``. Annotations whose image is missing
+    or cannot be decoded are reported on stdout and skipped.
+
+    The output file is ``<save_dir>/<dataset>_<save_name>.tfrecord``. It holds
+    one ``tf.train.Example`` per image with the features ``img_name``,
+    ``img_height``, ``img_width``, ``img`` (raw bytes of the channel-reversed
+    ``cv2.imread`` array), ``gtboxes_and_label`` (raw bytes of the int32 array
+    returned by ``read_xml_gtbox_and_label``) and ``num_objects``.
+
+    :return: None
+    """
+    # `os` otherwise only reaches this module through the wildcard imports at
+    # the top of the file; import it explicitly so this function does not
+    # depend on what those modules happen to export.
+    import os
+
+    # os.path.join works whether or not the directory flags end with a slash.
+    xml_path = os.path.join(FLAGS.VOC_dir, FLAGS.xml_dir)
+    image_path = os.path.join(FLAGS.VOC_dir, FLAGS.image_dir)
+    save_path = os.path.join(
+        FLAGS.save_dir, FLAGS.dataset + '_' + FLAGS.save_name + '.tfrecord')
+    # mkdir comes from help_utils.tools; presumably it creates the directory
+    # when it does not exist yet -- confirm against that module.
+    mkdir(FLAGS.save_dir)
+
+    # List the annotations once: sorted for a reproducible record order, and
+    # reused as the progress total instead of re-globbing on every iteration.
+    xml_files = sorted(glob.glob(os.path.join(xml_path, '*.xml')))
+    total = len(xml_files)
+
+    # The context manager closes the writer even when a file fails to convert,
+    # so the records written so far are flushed and the handle is released.
+    # (Pass options=tf.python_io.TFRecordOptions(
+    #  tf.python_io.TFRecordCompressionType.ZLIB) to write a compressed file.)
+    with tf.python_io.TFRecordWriter(path=save_path) as writer:
+        for count, xml in enumerate(xml_files):
+            # basename/splitext handle the separators of the current platform
+            # and keep dots that are part of the base name ('a.b.xml' -> 'a.b').
+            base_name = os.path.splitext(os.path.basename(xml))[0]
+            img_name = base_name + FLAGS.img_format
+            img_path = os.path.join(image_path, img_name)
+
+            if not os.path.exists(img_path):
+                print('{} is not exist!'.format(img_path))
+                continue
+
+            img_height, img_width, gtbox_label = read_xml_gtbox_and_label(xml)
+
+            img = cv2.imread(img_path)
+            if img is None:
+                # cv2.imread reports an unreadable/corrupt file by returning
+                # None instead of raising.
+                print('{} can not be read as an image!'.format(img_path))
+                continue
+            # OpenCV decodes to BGR; reverse the channel axis to store RGB.
+            img = img[:, :, ::-1]
+
+            # Fall back to the decoded image when the xml has no <size> entry.
+            if img_height is None or img_width is None:
+                img_height, img_width = img.shape[:2]
+
+            feature = tf.train.Features(feature={
+                'img_name': _bytes_feature(img_name.encode()),
+                'img_height': _int64_feature(img_height),
+                'img_width': _int64_feature(img_width),
+                # tobytes() replaces the deprecated ndarray.tostring() alias.
+                'img': _bytes_feature(img.tobytes()),
+                'gtboxes_and_label': _bytes_feature(gtbox_label.tobytes()),
+                'num_objects': _int64_feature(gtbox_label.shape[0])
+            })
+
+            example = tf.train.Example(features=feature)
+
+            writer.write(example.SerializeToString())
+
+            view_bar('Conversion progress', count + 1, total)
+
+    print('\nConversion is complete!')
+
+
+if __name__ == '__main__':
+    # Script entry point: run the conversion configured by the command-line flags.
+    # To debug the annotation parser on its own, call read_xml_gtbox_and_label()
+    # on a single file, e.g. '../data/dataset/VOCdevkit/VOC2007/Annotations/000005.xml'.
+
+    convert_pascal_to_tfrecord()