Advertisement
Guest User

Untitled

a guest
Aug 26th, 2019
94
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 1.30 KB | None | 0 0
  1. import tensorflow as tf
  2. import os
  3. from networks.networks import *
  4. from common import get_checkpoint
  5.  
  6.  
  7. model_type = 'mv2_cpm'
  8. CHECKPOINT_DIR = './weights/'
  9. OUTPUT_NODE_NAMES = ['output']
  10. output_name = './frozen_models/cpm_352.pb'
  11.  
  12. OUTPUT_PATH = './'
  13.  
  14. if __name__ == '__main__':
  15. output_graph = os.path.join(OUTPUT_PATH, output_name)
  16.  
  17. tf_config = tf.ConfigProto()
  18. tf_config.gpu_options.allow_growth = True
  19. with tf.Session(config=tf_config) as sess:
  20. # initialize model
  21. net_in = tf.placeholder(tf.float32, [1, 352, 352, 3], name="input_image")
  22. net_out, _ = get_network(model_type, net_in, trainable=False)
  23.  
  24. saver = tf.train.Saver()
  25. resumed_checkpoint, last_step = get_checkpoint(CHECKPOINT_DIR)
  26. saver.restore(sess, resumed_checkpoint)
  27.  
  28. output_graph_def = tf.graph_util.convert_variables_to_constants(
  29. sess,
  30. tf.get_default_graph().as_graph_def(),
  31. OUTPUT_NODE_NAMES)
  32.  
  33. output_graph_def = tf.graph_util.remove_training_nodes(output_graph_def, protected_nodes='output')
  34.  
  35. # Finally we serialize and dump the output graph to the filesystem
  36. with tf.gfile.GFile(output_graph, "wb") as f:
  37. f.write(output_graph_def.SerializeToString())
  38. print("%d ops in the final graph." % len(output_graph_def.node))
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement