Advertisement
Guest User

Untitled

a guest
Jul 17th, 2019
156
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 2.98 KB | None | 0 0
  1. import tensorflow as tf
  2. import numpy as np
  3.  
  4. from PIL import Image
  5.  
  6. def load_graph(frozen_graph_filename):
  7. """
  8. Args:
  9. frozen_graph_filename (str): Full path to the .pb file.
  10. """
  11. # We load the protobuf file from the disk and parse it to retrieve the
  12. # unserialized graph_def
  13. with tf.gfile.GFile(frozen_graph_filename, "rb") as f:
  14. graph_def = tf.GraphDef()
  15. graph_def.ParseFromString(f.read())
  16.  
  17. # Then, we import the graph_def into a new Graph and returns it
  18. with tf.Graph().as_default() as graph:
  19. # The name var will prefix every op/nodes in your graph
  20. # Since we load everything in a new graph, this is not needed
  21. tf.import_graph_def(graph_def, name="prefix")
  22. return graph
  23.  
  24. def segment(graph, image_file):
  25. """
  26. Does the segmentation on the given image.
  27. Args:
  28. graph (Tensorflow Graph)
  29. image_file (str): Full path to your image
  30. Returns:
  31. segmentation_mask (np.array): The segmentation mask of the image.
  32. """
  33. # We access the input and output nodes
  34. x = graph.get_tensor_by_name('prefix/ImageTensor:0')
  35. y = graph.get_tensor_by_name('prefix/SemanticPredictions:0')
  36.  
  37. # We launch a Session
  38. with tf.Session(graph=graph) as sess:
  39.  
  40. image = Image.open(image_file)
  41. image = image.resize((299, 299))
  42. image_array = np.array(image)
  43. image_array = np.expand_dims(image_array, axis=0)
  44.  
  45. # Note: we don't nee to initialize/restore anything
  46. # There is no Variables in this graph, only hardcoded constants
  47. pred = sess.run(y, feed_dict={x: image_array})
  48.  
  49. pred = pred.squeeze()
  50.  
  51. return pred
  52.  
  53. def get_n_rgb_colors(n):
  54. """
  55. Get n evenly spaced RGB colors.
  56. Returns:
  57. rgb_colors (list): List of RGB colors.
  58. """
  59. max_value = 16581375 #255**3
  60. interval = int(max_value / n)
  61. colors = [hex(I)[2:].zfill(6) for I in range(0, max_value, interval)]
  62.  
  63. rgb_colors = [(int(i[:2], 16), int(i[2:4], 16), int(i[4:], 16)) for i in colors]
  64.  
  65. return rgb_colors
  66.  
  67. def parse_pred(pred, n_classes):
  68. """
  69. Parses a prediction and returns the prediction as a PIL.Image.
  70. Args:
  71. pred (np.array)
  72. Returns:
  73. parsed_pred (PIL.Image): Parsed prediction that we can view as an image.
  74. """
  75. uni = np.unique(pred)
  76.  
  77. empty = np.empty((pred.shape[1], pred.shape[0], 3))
  78.  
  79. colors = get_n_rgb_colors(n_classes)
  80.  
  81. for i, u in enumerate(uni):
  82. idx = np.transpose((pred == u).nonzero())
  83. c = colors[u]
  84. empty[idx[:,0], idx[:,1]] = [c[0],c[1],c[2]]
  85.  
  86. parsed_pred = np.array(empty, dtype=np.uint8)
  87. parsed_pred = Image.fromarray(parsed_pred)
  88.  
  89. return parsed_pred
  90.  
  91. if __name__ == '__main__':
  92. N_CLASSES = THE NUMBER OF CLASSES YOUR MODEL HAS
  93. MODEL_FILE = 'FULL PATH TO YOUR PB FILE'
  94. IMAGE_FILE = 'FULL PATH TO YOUR IMAGE'
  95.  
  96. graph = load_graph(MODEL_FILE)
  97. prediction = segment(graph, IMAGE_FILE)
  98. segmented_image = parse_pred(prediction, N_CLASSES)
  99.  
  100. segmented_image.show()
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement