Guest User

Untitled

a guest
Jun 19th, 2018
92
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 1.34 KB | None | 0 0
  1. from __future__ import print_function
  2. from __future__ import absolute_import
  3.  
  4. import tensorflow as tf
  5. from nets.mobilenet_v1 import mobilenet_v1, mobilenet_v1_arg_scope
  6. from tensorflow.python.framework import graph_util
  7. import sys
  8.  
  9. slim = tf.contrib.slim
  10.  
  11. checkpoint_file = tf.train.latest_checkpoint('/tmp/flowers-models/mobilenet_v1')
  12.  
  13. with tf.Graph().as_default() as graph:
  14. images = tf.placeholder(shape=[None, 224, 224, 3], dtype=tf.float32, name = 'input')
  15.  
  16. with slim.arg_scope(mobilenet_v1_arg_scope()):
  17. logits, end_points = mobilenet_v1(images, num_classes=5, is_training=False)
  18.  
  19. variables_to_restore = slim.get_variables_to_restore()
  20. saver = tf.train.Saver(variables_to_restore)
  21.  
  22. input_graph_def = graph.as_graph_def()
  23. output_node_names = 'MobilenetV1/Predictions/Reshape_1'
  24. output_graph_name = '/notebooks/test.pb'
  25.  
  26. with tf.Session() as sess:
  27. saver.restore(sess, checkpoint_file)
  28.  
  29. print("exporting graph...")
  30. output_graph_def = graph_util.convert_variables_to_constants(
  31. sess,
  32. input_graph_def,
  33. output_node_names.split(","))
  34.  
  35. with tf.gfile.GFile(output_graph_name, "wb") as f:
  36. f.write(output_graph_def.SerializeToString())
Add Comment
Please, Sign In to add comment