Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
# Export a MobileNet v1 checkpoint as a frozen GraphDef (.pb) file.
#
# The script rebuilds the MobileNet v1 inference graph, restores the newest
# checkpoint found in `checkpoint_dir`, converts every variable reachable from
# the output node(s) into a constant, and writes the serialized graph to
# `output_graph_name`.
from __future__ import print_function
from __future__ import absolute_import

import sys

import tensorflow as tf
from tensorflow.python.framework import graph_util

from nets.mobilenet_v1 import mobilenet_v1, mobilenet_v1_arg_scope

slim = tf.contrib.slim

checkpoint_dir = '/tmp/flowers-models/mobilenet_v1'
checkpoint_file = tf.train.latest_checkpoint(checkpoint_dir)
if checkpoint_file is None:
    # latest_checkpoint() returns None when the directory holds no checkpoint;
    # fail here with a clear message instead of deep inside saver.restore().
    raise ValueError("no checkpoint found in %r" % checkpoint_dir)

# Comma-separated names of the graph nodes to keep as outputs.
# NOTE(review): presumably the prediction output of mobilenet_v1 — confirm
# against the node names in the built graph.
output_node_names = 'MobilenetV1/Predictions/Reshape_1'
output_graph_name = '/notebooks/test.pb'

with tf.Graph().as_default() as graph:
    # Input placeholder: variable batch size, 224x224 images, 3 channels.
    images = tf.placeholder(
        shape=[None, 224, 224, 3], dtype=tf.float32, name='input')

    # Build the network in inference mode with 5 output classes.
    with slim.arg_scope(mobilenet_v1_arg_scope()):
        logits, end_points = mobilenet_v1(
            images, num_classes=5, is_training=False)

    variables_to_restore = slim.get_variables_to_restore()
    saver = tf.train.Saver(variables_to_restore)
    input_graph_def = graph.as_graph_def()

    # Bind the session to `graph` explicitly so the restore and the export
    # operate on the graph built above regardless of the ambient default.
    with tf.Session(graph=graph) as sess:
        saver.restore(sess, checkpoint_file)

        print("exporting graph...")
        output_graph_def = graph_util.convert_variables_to_constants(
            sess,
            input_graph_def,
            output_node_names.split(","))

        with tf.gfile.GFile(output_graph_name, "wb") as f:
            f.write(output_graph_def.SerializeToString())
Add Comment
Please, Sign In to add comment