X-Git-Url: https://gerrit.akraino.org/r/gitweb?a=blobdiff_plain;f=example-apps%2FPDD%2Fpcb-defect-detection%2Flibs%2Fbox_utils%2Ftf_ops.py;fp=example-apps%2FPDD%2Fpcb-defect-detection%2Flibs%2Fbox_utils%2Ftf_ops.py;h=86d945af0747842ff7460c371f7cb133f458e30e;hb=a785567fb9acfc68536767d20f60ba917ae85aa1;hp=0000000000000000000000000000000000000000;hpb=94a133e696b9b2a7f73544462c2714986fa7ab4a;p=ealt-edge.git diff --git a/example-apps/PDD/pcb-defect-detection/libs/box_utils/tf_ops.py b/example-apps/PDD/pcb-defect-detection/libs/box_utils/tf_ops.py new file mode 100755 index 0000000..86d945a --- /dev/null +++ b/example-apps/PDD/pcb-defect-detection/libs/box_utils/tf_ops.py @@ -0,0 +1,57 @@ +# -*- coding:utf-8 -*- + +from __future__ import absolute_import, print_function, division + +import tensorflow as tf + +''' +all of these ops are derived from tenosrflow Object Detection API +''' +def indices_to_dense_vector(indices, + size, + indices_value=1., + default_value=0, + dtype=tf.float32): + """Creates dense vector with indices set to specific (the para "indices_value" ) and rest to zeros. + + This function exists because it is unclear if it is safe to use + tf.sparse_to_dense(indices, [size], 1, validate_indices=False) + with indices which are not ordered. + This function accepts a dynamic size (e.g. tf.shape(tensor)[0]) + + Args: + indices: 1d Tensor with integer indices which are to be set to + indices_values. + size: scalar with size (integer) of output Tensor. + indices_value: values of elements specified by indices in the output vector + default_value: values of other elements in the output vector. + dtype: data type. + + Returns: + dense 1D Tensor of shape [size] with indices set to indices_values and the + rest set to default_value. + """ + size = tf.to_int32(size) + zeros = tf.ones([size], dtype=dtype) * default_value + values = tf.ones_like(indices, dtype=dtype) * indices_value + + return tf.dynamic_stitch([tf.range(size), tf.to_int32(indices)], + [zeros, values]) + + + + +def test_plt(): + from PIL import Image + import matplotlib.pyplot as plt + import numpy as np + + a = np.random.rand(20, 30) + print (a.shape) + # plt.subplot() + b = plt.imshow(a) + plt.show() + + +if __name__ == '__main__': + test_plt()