X-Git-Url: https://gerrit.akraino.org/r/gitweb?a=blobdiff_plain;f=example-apps%2FPDD%2Fpcb-defect-detection%2Fdata%2Fio%2Fread_tfrecord.py;fp=example-apps%2FPDD%2Fpcb-defect-detection%2Fdata%2Fio%2Fread_tfrecord.py;h=8f1605fbcf157ec7bcc0002a1ff139528f42640d;hb=a785567fb9acfc68536767d20f60ba917ae85aa1;hp=0000000000000000000000000000000000000000;hpb=94a133e696b9b2a7f73544462c2714986fa7ab4a;p=ealt-edge.git diff --git a/example-apps/PDD/pcb-defect-detection/data/io/read_tfrecord.py b/example-apps/PDD/pcb-defect-detection/data/io/read_tfrecord.py new file mode 100755 index 0000000..8f1605f --- /dev/null +++ b/example-apps/PDD/pcb-defect-detection/data/io/read_tfrecord.py @@ -0,0 +1,99 @@ +# -*- coding: utf-8 -*- + +from __future__ import absolute_import +from __future__ import print_function +from __future__ import division + +import numpy as np +import tensorflow as tf +import os +from data.io import image_preprocess +from libs.configs import cfgs + +def read_single_example_and_decode(filename_queue): + + # tfrecord_options = tf.python_io.TFRecordOptions(tf.python_io.TFRecordCompressionType.ZLIB) + + # reader = tf.TFRecordReader(options=tfrecord_options) + reader = tf.TFRecordReader() + _, serialized_example = reader.read(filename_queue) + + features = tf.parse_single_example( + serialized=serialized_example, + features={ + 'img_name': tf.FixedLenFeature([], tf.string), + 'img_height': tf.FixedLenFeature([], tf.int64), + 'img_width': tf.FixedLenFeature([], tf.int64), + 'img': tf.FixedLenFeature([], tf.string), + 'gtboxes_and_label': tf.FixedLenFeature([], tf.string), + 'num_objects': tf.FixedLenFeature([], tf.int64) + } + ) + img_name = features['img_name'] + img_height = tf.cast(features['img_height'], tf.int32) + img_width = tf.cast(features['img_width'], tf.int32) + img = tf.decode_raw(features['img'], tf.uint8) + + img = tf.reshape(img, shape=[img_height, img_width, 3]) + + gtboxes_and_label = tf.decode_raw(features['gtboxes_and_label'], tf.int32) + gtboxes_and_label = tf.reshape(gtboxes_and_label, [-1, 5]) + + num_objects = tf.cast(features['num_objects'], tf.int32) + return img_name, img, gtboxes_and_label, num_objects + + +def read_and_prepocess_single_img(filename_queue, shortside_len, is_training): + + img_name, img, gtboxes_and_label, num_objects = read_single_example_and_decode(filename_queue) + + img = tf.cast(img, tf.float32) + + if is_training: + img, gtboxes_and_label = image_preprocess.short_side_resize(img_tensor=img, gtboxes_and_label=gtboxes_and_label, + target_shortside_len=shortside_len, + length_limitation=cfgs.IMG_MAX_LENGTH) + img, gtboxes_and_label = image_preprocess.random_flip_left_right(img_tensor=img, + gtboxes_and_label=gtboxes_and_label) + + else: + img, gtboxes_and_label = image_preprocess.short_side_resize(img_tensor=img, gtboxes_and_label=gtboxes_and_label, + target_shortside_len=shortside_len, + length_limitation=cfgs.IMG_MAX_LENGTH) + img = img - tf.constant([[cfgs.PIXEL_MEAN]]) # sub pixel mean at last + return img_name, img, gtboxes_and_label, num_objects + + +def next_batch(dataset_name, batch_size, shortside_len, is_training): + ''' + :return: + img_name_batch: shape(1, 1) + img_batch: shape:(1, new_imgH, new_imgW, C) + gtboxes_and_label_batch: shape(1, Num_Of_objects, 5] .each row is [x1, y1, x2, y2, label] + ''' + assert batch_size == 1, "we only support batch_size is 1.We may support large batch_size in the future" + + if dataset_name not in ['ship', 'spacenet', 'pascal', 'coco','pcb']: + raise ValueError('dataSet name must be in pascal, coco spacenet and ship') + + if is_training: + pattern = os.path.join('../data/tfrecord', dataset_name + '_train*') + else: + pattern = os.path.join('../data/tfrecord', dataset_name + '_test*') + + print('tfrecord path is -->', os.path.abspath(pattern)) + + filename_tensorlist = tf.train.match_filenames_once(pattern) + + filename_queue = tf.train.string_input_producer(filename_tensorlist) + + img_name, img, gtboxes_and_label, num_obs = read_and_prepocess_single_img(filename_queue, shortside_len, + is_training=is_training) + img_name_batch, img_batch, gtboxes_and_label_batch, num_obs_batch = \ + tf.train.batch( + [img_name, img, gtboxes_and_label, num_obs], + batch_size=batch_size, + capacity=1, + num_threads=1, + dynamic_pad=True) + return img_name_batch, img_batch, gtboxes_and_label_batch, num_obs_batch