Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
"""Freeze a trained network checkpoint into a standalone GraphDef (.pb) file.

Builds the network selected by ``model_type`` on a fixed-size input
placeholder, restores the latest weights from ``CHECKPOINT_DIR``, folds all
variables into constants, strips training-only nodes, and writes the result
to ``os.path.join(OUTPUT_PATH, output_name)``.

Uses the TensorFlow 1.x graph/session API (``tf.placeholder``, ``tf.Session``,
``tf.graph_util``, ``tf.gfile``).
"""
import os

import tensorflow as tf

# NOTE(review): star import kept for compatibility; this script only needs
# ``get_network`` from it (presumably defined there — confirm).
from networks.networks import *
from common import get_checkpoint

model_type = 'mv2_cpm'
CHECKPOINT_DIR = './weights/'
# Graph node name(s) to keep when freezing. Assumes get_network names its
# final op 'output' — TODO confirm against networks.networks.
OUTPUT_NODE_NAMES = ['output']
# Height == width of the input image the frozen graph is built for. The
# output filename is derived from it so the two cannot drift apart.
INPUT_SIZE = 352
output_name = './frozen_models/cpm_%d.pb' % INPUT_SIZE
OUTPUT_PATH = './'


def main():
    """Build the network, restore weights, freeze it and write the .pb file.

    Raises:
        ValueError: if ``get_checkpoint`` yields no checkpoint path.
    """
    output_graph = os.path.join(OUTPUT_PATH, output_name)
    # GFile does not create missing parent directories (e.g. ./frozen_models).
    output_dir = os.path.dirname(output_graph)
    if output_dir:
        tf.gfile.MakeDirs(output_dir)

    tf_config = tf.ConfigProto()
    # Allocate GPU memory on demand instead of reserving it all up front.
    tf_config.gpu_options.allow_growth = True
    with tf.Session(config=tf_config) as sess:
        # Batch of one 3-channel image; the frozen graph accepts only this shape.
        net_in = tf.placeholder(
            tf.float32, [1, INPUT_SIZE, INPUT_SIZE, 3], name="input_image")
        # Only the ops added to the default graph matter here; the returned
        # tensors are not used directly.
        get_network(model_type, net_in, trainable=False)

        saver = tf.train.Saver()
        # Second return value (presumably the checkpoint's step) is unused.
        resumed_checkpoint, _ = get_checkpoint(CHECKPOINT_DIR)
        if not resumed_checkpoint:
            raise ValueError('No checkpoint found in %r' % CHECKPOINT_DIR)
        saver.restore(sess, resumed_checkpoint)

        # Prune the graph to what the output node(s) need and replace every
        # variable with a constant holding its restored value.
        output_graph_def = tf.graph_util.convert_variables_to_constants(
            sess,
            sess.graph.as_graph_def(),
            OUTPUT_NODE_NAMES)
        # protected_nodes must be a list of names: a bare string turns the
        # library's membership test into a substring check.
        output_graph_def = tf.graph_util.remove_training_nodes(
            output_graph_def, protected_nodes=OUTPUT_NODE_NAMES)

    # Serialize the frozen graph to disk.
    with tf.gfile.GFile(output_graph, "wb") as f:
        f.write(output_graph_def.SerializeToString())
    print("%d ops in the final graph." % len(output_graph_def.node))


if __name__ == '__main__':
    main()
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement