Advertisement
Guest User

Untitled

a guest
Jun 16th, 2019
72
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 1.00 KB | None | 0 0
  1. import fire
  2. import json
  3. import os
  4. import numpy as np
  5. import tensorflow as tf
  6.  
  7. import model, sample, encoder
  8.  
  9. seed=None
  10. length=40
  11. temperature=1
  12. top_k=0
  13.  
  14. hparams = model.default_hparams()
  15. with open('models/345M/hparams.json') as f:
  16. hparams.override_from_dict(json.load(f))
  17.  
  18. with tf.Session(graph=tf.Graph()) as sess:
  19. context = tf.placeholder(tf.int32, [1, None])
  20. np.random.seed(seed)
  21. tf.set_random_seed(seed)
  22. output = sample.sample_sequence(
  23. hparams=hparams, length=length,
  24. context=context,
  25. batch_size=1,
  26. temperature=temperature, top_k=top_k
  27. )
  28.  
  29. saver = tf.train.Saver()
  30. ckpt = tf.train.latest_checkpoint(os.path.join('models', '345M'))
  31. saver.restore(sess, ckpt)
  32.  
  33. print([n.name for n in tf.get_default_graph().as_graph_def().node])
  34.  
  35. # Freeze the graph
  36. frozen_graph_def = tf.graph_util.convert_variables_to_constants(sess,sess.graph_def,[output.name])
  37.  
  38. # Save the frozen graph
  39. with open('output_graph.pb', 'wb') as f:
  40. f.write(frozen_graph_def.SerializeToString())
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement