SHARE
TWEET

Untitled

a guest Aug 26th, 2019 75 Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
  1. import tensorflow as tf
  2. import os
  3. from networks.networks import *
  4. from common import get_checkpoint
  5.  
  6.  
  7. model_type = 'mv2_cpm'
  8. CHECKPOINT_DIR = './weights/'
  9. OUTPUT_NODE_NAMES = ['output']
  10. output_name = './frozen_models/cpm_352.pb'
  11.  
  12. OUTPUT_PATH = './'
  13.  
  14. if __name__ == '__main__':
  15.     output_graph = os.path.join(OUTPUT_PATH, output_name)
  16.  
  17.     tf_config = tf.ConfigProto()
  18.     tf_config.gpu_options.allow_growth = True
  19.     with tf.Session(config=tf_config) as sess:
  20.         # initialize model
  21.         net_in = tf.placeholder(tf.float32, [1, 352, 352, 3], name="input_image")
  22.         net_out, _ = get_network(model_type, net_in, trainable=False)
  23.  
  24.         saver = tf.train.Saver()
  25.         resumed_checkpoint, last_step = get_checkpoint(CHECKPOINT_DIR)
  26.         saver.restore(sess, resumed_checkpoint)
  27.  
  28.         output_graph_def = tf.graph_util.convert_variables_to_constants(
  29.             sess,
  30.             tf.get_default_graph().as_graph_def(),
  31.             OUTPUT_NODE_NAMES)
  32.  
  33.         output_graph_def = tf.graph_util.remove_training_nodes(output_graph_def, protected_nodes='output')
  34.  
  35.         # Finally we serialize and dump the output graph to the filesystem
  36.         with tf.gfile.GFile(output_graph, "wb") as f:
  37.             f.write(output_graph_def.SerializeToString())
  38.         print("%d ops in the final graph." % len(output_graph_def.node))
RAW Paste Data
We use cookies for various purposes including analytics. By continuing to use Pastebin, you agree to our use of cookies as described in the Cookies Policy. OK, I Understand
 
Top