--- /dev/null
+# -*- coding: utf-8 -*-
+from __future__ import division, print_function, absolute_import
+import glob
+import os
+import sys
+import xml.etree.cElementTree as ET
+
+import cv2
+import numpy as np
+import tensorflow as tf
+
+sys.path.append('../../')
+from libs.label_name_dict.label_dict import *
+from help_utils.tools import *
+
+# Command-line flags describing where the VOC-style dataset lives and where the
+# resulting tfrecord is written. All of them can be overridden on the command line.
+# Alternative default for 'VOC_dir' pointing at the pcb_train copy of the dataset:
+# tf.app.flags.DEFINE_string('VOC_dir', '/home/root1/My_Work/Akraino/MEC_BP/Use_cases/PCB_Demo/Tiny-Defect-Detection-for-PCB/data/pcb_train/PCB_DATASET/', 'Voc dir')
+tf.app.flags.DEFINE_string('VOC_dir', '/home/root1/My_Work/Akraino/MEC_BP/Use_cases/PCB_Demo/Tiny-Defect-Detection-for-PCB/data/pcb_test/PCB_DATASET/', 'Voc dir')
+# Sub-directories of VOC_dir holding the annotation xml files and the images.
+tf.app.flags.DEFINE_string('xml_dir', 'Annotations/Missing_hole', 'xml dir')
+tf.app.flags.DEFINE_string('image_dir', 'images/Missing_hole', 'image dir')
+# The output file is <save_dir>/<dataset>_<save_name>.tfrecord
+tf.app.flags.DEFINE_string('save_name', 'test', 'save name')
+tf.app.flags.DEFINE_string('save_dir', '../tfrecord/', 'save dir')
+# Extension appended to an xml file's base name to find its image.
+tf.app.flags.DEFINE_string('img_format', '.jpg', 'format of image')
+tf.app.flags.DEFINE_string('dataset', 'pcb', 'dataset')
+FLAGS = tf.app.flags.FLAGS
+
+
+def _int64_feature(value):
+    """Wrap a single integer in a ``tf.train.Feature`` backed by an ``Int64List``."""
+    int_list = tf.train.Int64List(value=[value])
+    return tf.train.Feature(int64_list=int_list)
+
+
+def _bytes_feature(value):
+    """Wrap a single bytes value in a ``tf.train.Feature`` backed by a ``BytesList``."""
+    byte_list = tf.train.BytesList(value=[value])
+    return tf.train.Feature(bytes_list=byte_list)
+
+
+def read_xml_gtbox_and_label(xml_path):
+    """
+    Parse one VOC-style annotation xml file.
+
+    :param xml_path: the path of voc xml
+    :return: a tuple ``(img_height, img_width, gtbox_label)``.
+             ``img_height`` / ``img_width`` are ints read from the ``<size>`` element,
+             or None when the element (or the field) is absent.
+             ``gtbox_label`` is an int32 numpy array with one row per ``<bndbox>``:
+             the integer values of the bndbox children in document order followed by
+             the label id, e.g. [xmin, ymin, xmax, ymax, label] for a four-corner box.
+             With no boxes at all the array is empty with shape (0,).
+    :raises KeyError: if an object's name is not a key of NAME_LABEL_MAP
+    :raises AssertionError: if an object has a bndbox but no name
+    """
+
+    root = ET.parse(xml_path).getroot()
+    img_width = None
+    img_height = None
+    box_list = []
+    for child_of_root in root:
+        if child_of_root.tag == 'size':
+            for child_item in child_of_root:
+                if child_item.tag == 'width':
+                    img_width = int(child_item.text)
+                elif child_item.tag == 'height':
+                    img_height = int(child_item.text)
+
+        elif child_of_root.tag == 'object':
+            label = None
+            object_boxes = []
+            for child_item in child_of_root:
+                if child_item.tag == 'name':
+                    label = NAME_LABEL_MAP[child_item.text]
+                elif child_item.tag == 'bndbox':
+                    object_boxes.append([int(node.text) for node in child_item])
+
+            # The label is attached only after the whole object has been scanned,
+            # so parsing no longer depends on <name> appearing before <bndbox>.
+            # An explicit raise (instead of `assert`) keeps the check under `python -O`.
+            if object_boxes and label is None:
+                raise AssertionError('label is none, error')
+            box_list.extend(box + [label] for box in object_boxes)
+
+    gtbox_label = np.array(box_list, dtype=np.int32)
+
+    return img_height, img_width, gtbox_label
+
+
+def convert_pascal_to_tfrecord():
+    """
+    Convert the VOC-style dataset selected by FLAGS into one tfrecord file.
+
+    Every ``*.xml`` file in ``FLAGS.VOC_dir/FLAGS.xml_dir`` is paired with the image of
+    the same base name (plus ``FLAGS.img_format``) in ``FLAGS.VOC_dir/FLAGS.image_dir``.
+    Each pair becomes one ``tf.train.Example`` with the features ``img_name``,
+    ``img_height``, ``img_width``, ``img`` (raw bytes of the channel-reversed
+    ``cv2.imread`` result), ``gtboxes_and_label`` (raw int32 bytes) and ``num_objects``.
+    Annotations whose image is missing or unreadable are reported and skipped.
+    The output is written to ``FLAGS.save_dir/<dataset>_<save_name>.tfrecord``.
+
+    :return: None
+    """
+    # os.path.join tolerates a VOC_dir / save_dir given without a trailing slash.
+    xml_path = os.path.join(FLAGS.VOC_dir, FLAGS.xml_dir)
+    image_path = os.path.join(FLAGS.VOC_dir, FLAGS.image_dir)
+    save_path = os.path.join(FLAGS.save_dir, FLAGS.dataset + '_' + FLAGS.save_name + '.tfrecord')
+    mkdir(FLAGS.save_dir)
+
+    # List the annotations once instead of re-running glob on every iteration
+    # just to obtain the progress-bar total.
+    xml_files = glob.glob(xml_path + '/*.xml')
+    total = len(xml_files)
+
+    # writer_options = tf.python_io.TFRecordOptions(tf.python_io.TFRecordCompressionType.ZLIB)
+    # writer = tf.python_io.TFRecordWriter(path=save_path, options=writer_options)
+    # The context manager closes the writer, so buffered records reach the disk
+    # even if a later annotation raises.
+    with tf.python_io.TFRecordWriter(path=save_path) as writer:
+        for count, xml in enumerate(xml_files):
+            # to avoid path error in different development platform
+            xml = xml.replace('\\', '/')
+
+            # splitext removes only the final extension, so base names that
+            # themselves contain dots still map to the right image file.
+            img_name = os.path.splitext(os.path.basename(xml))[0] + FLAGS.img_format
+            img_path = image_path + '/' + img_name
+
+            if not os.path.exists(img_path):
+                print('{} is not exist!'.format(img_path))
+                continue
+
+            # cv2.imread returns None (it does not raise) for files it cannot decode.
+            bgr_img = cv2.imread(img_path)
+            if bgr_img is None:
+                print('{} cannot be read!'.format(img_path))
+                continue
+            # Reverse the channel axis (OpenCV loads BGR, the record stores RGB).
+            img = bgr_img[:, :, ::-1]
+
+            img_height, img_width, gtbox_label = read_xml_gtbox_and_label(xml)
+            if img_height is None or img_width is None:
+                # The xml has no usable <size>; fall back to the decoded image.
+                img_height, img_width = img.shape[:2]
+
+            feature = tf.train.Features(feature={
+                # BytesList stores bytes, so the str file name is encoded first.
+                'img_name': _bytes_feature(img_name.encode()),
+                'img_height': _int64_feature(img_height),
+                'img_width': _int64_feature(img_width),
+                # tobytes() is the supported spelling of the deprecated tostring().
+                'img': _bytes_feature(img.tobytes()),
+                'gtboxes_and_label': _bytes_feature(gtbox_label.tobytes()),
+                'num_objects': _int64_feature(gtbox_label.shape[0])
+            })
+
+            example = tf.train.Example(features=feature)
+
+            writer.write(example.SerializeToString())
+
+            view_bar('Conversion progress', count + 1, total)
+
+    print('\nConversion is complete!')
+
+
+if __name__ == '__main__':
+    # Debugging aid: uncomment to run the xml parser on a single annotation file.
+    # xml_path = '../data/dataset/VOCdevkit/VOC2007/Annotations/000005.xml'
+    # read_xml_gtbox_and_label(xml_path)
+
+    convert_pascal_to_tfrecord()