Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
- import tensorflow as tf, sys
- import simplejson as json
- from flask import Flask, jsonify
# The Flask application object; the "/lola" route below is registered on it.
app = Flask(import_name=__name__)
def start_tf_session(graph_path="tfmodel/retrained_graph_cats.pb"):
    """Load the retrained cat-classifier graph and open a session on it.

    The serialized GraphDef stored at *graph_path* is imported into
    TensorFlow's default graph. Only afterwards is a ``tf.Session`` created,
    so that session runs against the graph containing the imported nodes.

    Args:
        graph_path: path of the serialized ``.pb`` graph to load. The default
            is the file the script has always used.

    Returns:
        A new, open ``tf.Session``. Nothing in this function closes it.
    """
    # The label file is NOT read here: the previous version loaded it into a
    # local that was never used (leaking the file handle), while the labels
    # the service actually uses are loaded separately at module level.
    with tf.gfile.FastGFile(graph_path, 'rb') as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())
        # An empty name= adds no prefix, so tensors keep their original names
        # (e.g. 'final_result:0', 'DecodeJpeg/contents:0').
        tf.import_graph_def(graph_def, name='')
    print('Tensorflow session (for cats) created.')
    return tf.Session()
# Look up the classifier's output ("softmax") tensor in the session's graph.
def get_tensor(sess, tensor_name='final_result:0'):
    """Return the tensor named *tensor_name* from *sess*'s graph.

    Nothing is created here. The tensor must already exist in the graph the
    session runs, and this function only looks it up by name.

    Args:
        sess: a ``tf.Session`` (anything exposing ``.graph.get_tensor_by_name``).
        tensor_name: fully qualified tensor name of the form
            "<op name>:<output index>". The default is 'final_result:0', the
            output this script has always read.

    Returns:
        Whatever ``sess.graph.get_tensor_by_name(tensor_name)`` returns. For a
        real session, that is the ``tf.Tensor``. Per the TensorFlow docs, the
        lookup raises KeyError if the graph has no such tensor.
    """
    print('softmax_tensor (for cats) got gotten.')
    return sess.graph.get_tensor_by_name(tensor_name)
# This function gets called by the web service. It runs the TensorFlow model.
def image_labels(image_name):
    """Classify the image file at *image_name* and return every label's score.

    Uses the module-level ``sess``, ``softmax_tensor`` and ``label_lines``,
    which must be set up before the first call.

    Args:
        image_name: path of the image to classify. Its raw bytes are fed to
            the graph's 'DecodeJpeg/contents:0' input, so it is presumably
            expected to be a JPEG.

    Returns:
        dict mapping each label string to its score, rendered as
        ``str(round(score, 5))``. Entries are inserted from highest to lowest
        score.
    """
    print('loading image')
    # Read inside `with` so the file handle is closed immediately (the
    # previous one-liner left it open until garbage collection).
    with tf.gfile.FastGFile(image_name, 'rb') as image_file:
        image_data = image_file.read()

    print('running prediction')
    predictions = sess.run(softmax_tensor,
                           {'DecodeJpeg/contents:0': image_data})

    print('sorting and organizing results')
    # Only the first row of predictions is used (presumably the single image
    # fed above). Each position i scores label_lines[i].
    scores = predictions[0]
    # argsort is ascending; reversing it gives most-confident labels first.
    # (The old `[-len(scores):]` slice selected the whole array and was a no-op.)
    top_k = scores.argsort()[::-1]

    # Save the results from the TensorFlow run as a dictionary.
    results = dict()
    for node_id in top_k:
        human_string = label_lines[node_id]
        score = scores[node_id]
        print('%s (score = %.5f)' % (human_string, score))
        results[human_string] = str(round(score, 5))
    return results
# Start the app!
print('Starting Tensorflow (for cats)')
# Start a TensorFlow session, and keep it in a module-level variable so
# image_labels() can run predictions with it.
sess = start_tf_session()
# Get a reference to the softmax tensor from the newly created session.
softmax_tensor = get_tensor(sess)
# Load the labels file from the custom-trained model, one label per line with
# trailing whitespace stripped. image_labels() indexes this list by the
# model's output-score position, so line order must match the model's output
# order. The `with` closes the file once it has been read.
with tf.gfile.GFile("tfmodel/retrained_labels_cats.txt") as labels_file:
    label_lines = [line.rstrip() for line in labels_file]
# Create a web service to analyze the Raspberry Pi image.
@app.route("/lola")
def lola():
    """Classify "rpicam.jpg" and return its label scores as a JSON object.

    The file is opened by relative path, i.e. relative to the process's
    current working directory. The response body maps each label to its
    score string, as produced by image_labels().
    """
    print('Analyzing image for Lola.')
    # Get the labels associated with the input image.
    labels = image_labels("rpicam.jpg")
    print('returning json')
    # Pass the dict directly (documented jsonify form) instead of unpacking
    # label text into Python keyword arguments with **labels.
    return jsonify(labels)


# Run the Flask web service using Flask's built-in server with its default
# host/port settings.
if __name__ == "__main__":
    app.run()
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement