pcb defect detetcion application
[ealt-edge.git] / example-apps / PDD / pcb-defect-detection / libs / box_utils / tf_ops.py
diff --git a/example-apps/PDD/pcb-defect-detection/libs/box_utils/tf_ops.py b/example-apps/PDD/pcb-defect-detection/libs/box_utils/tf_ops.py
new file mode 100755 (executable)
index 0000000..86d945a
--- /dev/null
@@ -0,0 +1,57 @@
+# -*- coding:utf-8 -*-
+
+from __future__ import absolute_import, print_function, division
+
+import tensorflow as tf
+
+'''
+all of these ops are derived from tenosrflow Object Detection API
+'''
+def indices_to_dense_vector(indices,
+                            size,
+                            indices_value=1.,
+                            default_value=0,
+                            dtype=tf.float32):
+  """Creates dense vector with indices set to specific (the para "indices_value" ) and rest to zeros.
+
+  This function exists because it is unclear if it is safe to use
+    tf.sparse_to_dense(indices, [size], 1, validate_indices=False)
+  with indices which are not ordered.
+  This function accepts a dynamic size (e.g. tf.shape(tensor)[0])
+
+  Args:
+    indices: 1d Tensor with integer indices which are to be set to
+        indices_values.
+    size: scalar with size (integer) of output Tensor.
+    indices_value: values of elements specified by indices in the output vector
+    default_value: values of other elements in the output vector.
+    dtype: data type.
+
+  Returns:
+    dense 1D Tensor of shape [size] with indices set to indices_values and the
+        rest set to default_value.
+  """
+  size = tf.to_int32(size)
+  zeros = tf.ones([size], dtype=dtype) * default_value
+  values = tf.ones_like(indices, dtype=dtype) * indices_value
+
+  return tf.dynamic_stitch([tf.range(size), tf.to_int32(indices)],
+                           [zeros, values])
+
+
+
+
+def test_plt():
+  from PIL import Image
+  import matplotlib.pyplot as plt
+  import numpy as np
+
+  a = np.random.rand(20, 30)
+  print (a.shape)
+  # plt.subplot()
+  b = plt.imshow(a)
+  plt.show()
+
+
+if __name__ == '__main__':
+  test_plt()