X-Git-Url: https://gerrit.akraino.org/r/gitweb?a=blobdiff_plain;f=example-apps%2FPDD%2Fpcb-defect-detection%2Fdata%2Fio%2Fconvert_data_to_tfrecord.py;fp=example-apps%2FPDD%2Fpcb-defect-detection%2Fdata%2Fio%2Fconvert_data_to_tfrecord.py;h=17bb8d6f9b2c3fee55b9ee4b4b81490b519b258c;hb=a785567fb9acfc68536767d20f60ba917ae85aa1;hp=0000000000000000000000000000000000000000;hpb=94a133e696b9b2a7f73544462c2714986fa7ab4a;p=ealt-edge.git diff --git a/example-apps/PDD/pcb-defect-detection/data/io/convert_data_to_tfrecord.py b/example-apps/PDD/pcb-defect-detection/data/io/convert_data_to_tfrecord.py new file mode 100755 index 0000000..17bb8d6 --- /dev/null +++ b/example-apps/PDD/pcb-defect-detection/data/io/convert_data_to_tfrecord.py @@ -0,0 +1,123 @@ +# -*- coding: utf-8 -*- +from __future__ import division, print_function, absolute_import +import sys +sys.path.append('../../') +import xml.etree.cElementTree as ET +import numpy as np +import tensorflow as tf +import glob +import cv2 +from libs.label_name_dict.label_dict import * +from help_utils.tools import * + +#tf.app.flags.DEFINE_string('VOC_dir', '/home/root1/My_Work/Akraino/MEC_BP/Use_cases/PCB_Demo/Tiny-Defect-Detection-for-PCB/data/pcb_train/PCB_DATASET/', 'Voc dir') +tf.app.flags.DEFINE_string('VOC_dir', '/home/root1/My_Work/Akraino/MEC_BP/Use_cases/PCB_Demo/Tiny-Defect-Detection-for-PCB/data/pcb_test/PCB_DATASET/', 'Voc dir') +tf.app.flags.DEFINE_string('xml_dir', 'Annotations/Missing_hole', 'xml dir') +tf.app.flags.DEFINE_string('image_dir', 'images/Missing_hole', 'image dir') +tf.app.flags.DEFINE_string('save_name', 'test', 'save name') +tf.app.flags.DEFINE_string('save_dir', '../tfrecord/', 'save name') +tf.app.flags.DEFINE_string('img_format', '.jpg', 'format of image') +tf.app.flags.DEFINE_string('dataset', 'pcb', 'dataset') +FLAGS = tf.app.flags.FLAGS + + +def _int64_feature(value): + return tf.train.Feature(int64_list=tf.train.Int64List(value=[value])) + + +def _bytes_feature(value): + return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value])) + + +def read_xml_gtbox_and_label(xml_path): + """ + :param xml_path: the path of voc xml + :return: a list contains gtboxes and labels, shape is [num_of_gtboxes, 9], + and has [x1, y1, x2, y2, x3, y3, x4, y4, label] in a per row + """ + + tree = ET.parse(xml_path) + root = tree.getroot() + img_width = None + img_height = None + box_list = [] + for child_of_root in root: + # if child_of_root.tag == 'filename': + # assert child_of_root.text == xml_path.split('/')[-1].split('.')[0] \ + # + FLAGS.img_format, 'xml_name and img_name cannot match' + + if child_of_root.tag == 'size': + for child_item in child_of_root: + if child_item.tag == 'width': + img_width = int(child_item.text) + if child_item.tag == 'height': + img_height = int(child_item.text) + + if child_of_root.tag == 'object': + label = None + for child_item in child_of_root: + if child_item.tag == 'name': + label = NAME_LABEL_MAP[child_item.text] + if child_item.tag == 'bndbox': + tmp_box = [] + for node in child_item: + tmp_box.append(int(node.text)) + assert label is not None, 'label is none, error' + tmp_box.append(label) + box_list.append(tmp_box) + + gtbox_label = np.array(box_list, dtype=np.int32) + + return img_height, img_width, gtbox_label + + +def convert_pascal_to_tfrecord(): + xml_path = FLAGS.VOC_dir + FLAGS.xml_dir + image_path = FLAGS.VOC_dir + FLAGS.image_dir + save_path = FLAGS.save_dir + FLAGS.dataset + '_' + FLAGS.save_name + '.tfrecord' + mkdir(FLAGS.save_dir) + + # writer_options = tf.python_io.TFRecordOptions(tf.python_io.TFRecordCompressionType.ZLIB) + # writer = tf.python_io.TFRecordWriter(path=save_path, options=writer_options) + writer = tf.python_io.TFRecordWriter(path=save_path) + for count, xml in enumerate(glob.glob(xml_path + '/*.xml')): + # to avoid path error in different development platform + xml = xml.replace('\\', '/') + + img_name = xml.split('/')[-1].split('.')[0] + FLAGS.img_format + img_path = image_path + '/' + img_name + + if not os.path.exists(img_path): + print('{} is not exist!'.format(img_path)) + continue + + img_height, img_width, gtbox_label = read_xml_gtbox_and_label(xml) + + # img = np.array(Image.open(img_path)) + img = cv2.imread(img_path)[:, :, ::-1] + + feature = tf.train.Features(feature={ + # do not need encode() in linux + # 'img_name': _bytes_feature(img_name.encode()), + 'img_name': _bytes_feature(img_name.encode()), + 'img_height': _int64_feature(img_height), + 'img_width': _int64_feature(img_width), + 'img': _bytes_feature(img.tostring()), + 'gtboxes_and_label': _bytes_feature(gtbox_label.tostring()), + 'num_objects': _int64_feature(gtbox_label.shape[0]) + }) + + example = tf.train.Example(features=feature) + + writer.write(example.SerializeToString()) + + view_bar('Conversion progress', count + 1, len(glob.glob(xml_path + '/*.xml'))) + + print('\nConversion is complete!') + + +if __name__ == '__main__': + # xml_path = '../data/dataset/VOCdevkit/VOC2007/Annotations/000005.xml' + # read_xml_gtbox_and_label(xml_path) + + convert_pascal_to_tfrecord()