Advertisement
Guest User

test_network.py

a guest
Jul 20th, 2017
48
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 0.98 KB | None | 0 0
  1. import sys
  2. import os
  3. os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
  4.  
  5. import tensorflow as tf
  6.  
  7.  
  8. label_lines = [line.rstrip() for line
  9. in tf.gfile.GFile(
  10. '/tf_files/retrained_labels.txt')]
  11.  
  12. with tf.gfile.FastGFile('/tf_files/retrained_graph.pb', 'rb') as f:
  13. graph_def = tf.GraphDef()
  14. graph_def.ParseFromString(f.read())
  15. _ = tf.import_graph_def(graph_def, name='')
  16.  
  17. sess = tf.Session()
  18. softmax_tensor = sess.graph.get_tensor_by_name('final_result:0')
  19.  
  20. test_dir = '/tf_files/Test'
  21. for img in os.listdir(test_dir):
  22. image_data = tf.gfile.FastGFile(os.path.join(test_dir, img), 'rb').read()
  23. predictions = sess.run(softmax_tensor,
  24. {'DecodeJpeg/contents:0': image_data})
  25.  
  26. top_k = predictions[0].argsort()[-len(predictions[0]):][::-1]
  27.  
  28. print(img)
  29. for node_id in top_k[:5]:
  30. human_string = label_lines[node_id]
  31. score = predictions[0][node_id]
  32. print(' %s (score = %.5f)' % (human_string, score))
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement