Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
- import tensorflow as tf
- import numpy as np
- from PIL import Image
- def load_graph(frozen_graph_filename):
- """
- Args:
- frozen_graph_filename (str): Full path to the .pb file.
- """
- # We load the protobuf file from the disk and parse it to retrieve the
- # unserialized graph_def
- with tf.gfile.GFile(frozen_graph_filename, "rb") as f:
- graph_def = tf.GraphDef()
- graph_def.ParseFromString(f.read())
- # Then, we import the graph_def into a new Graph and returns it
- with tf.Graph().as_default() as graph:
- # The name var will prefix every op/nodes in your graph
- # Since we load everything in a new graph, this is not needed
- tf.import_graph_def(graph_def, name="prefix")
- return graph
- def segment(graph, image_file):
- """
- Does the segmentation on the given image.
- Args:
- graph (Tensorflow Graph)
- image_file (str): Full path to your image
- Returns:
- segmentation_mask (np.array): The segmentation mask of the image.
- """
- # We access the input and output nodes
- x = graph.get_tensor_by_name('prefix/ImageTensor:0')
- y = graph.get_tensor_by_name('prefix/SemanticPredictions:0')
- # We launch a Session
- with tf.Session(graph=graph) as sess:
- image = Image.open(image_file)
- image = image.resize((299, 299))
- image_array = np.array(image)
- image_array = np.expand_dims(image_array, axis=0)
- # Note: we don't nee to initialize/restore anything
- # There is no Variables in this graph, only hardcoded constants
- pred = sess.run(y, feed_dict={x: image_array})
- pred = pred.squeeze()
- return pred
- def get_n_rgb_colors(n):
- """
- Get n evenly spaced RGB colors.
- Returns:
- rgb_colors (list): List of RGB colors.
- """
- max_value = 16581375 #255**3
- interval = int(max_value / n)
- colors = [hex(I)[2:].zfill(6) for I in range(0, max_value, interval)]
- rgb_colors = [(int(i[:2], 16), int(i[2:4], 16), int(i[4:], 16)) for i in colors]
- return rgb_colors
- def parse_pred(pred, n_classes):
- """
- Parses a prediction and returns the prediction as a PIL.Image.
- Args:
- pred (np.array)
- Returns:
- parsed_pred (PIL.Image): Parsed prediction that we can view as an image.
- """
- uni = np.unique(pred)
- empty = np.empty((pred.shape[1], pred.shape[0], 3))
- colors = get_n_rgb_colors(n_classes)
- for i, u in enumerate(uni):
- idx = np.transpose((pred == u).nonzero())
- c = colors[u]
- empty[idx[:,0], idx[:,1]] = [c[0],c[1],c[2]]
- parsed_pred = np.array(empty, dtype=np.uint8)
- parsed_pred = Image.fromarray(parsed_pred)
- return parsed_pred
- if __name__ == '__main__':
- N_CLASSES = THE NUMBER OF CLASSES YOUR MODEL HAS
- MODEL_FILE = 'FULL PATH TO YOUR PB FILE'
- IMAGE_FILE = 'FULL PATH TO YOUR IMAGE'
- graph = load_graph(MODEL_FILE)
- prediction = segment(graph, IMAGE_FILE)
- segmented_image = parse_pred(prediction, N_CLASSES)
- segmented_image.show()
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement