Advertisement
Not a member of Pastebin yet?
Sign up —
it unlocks many cool features!
# Freeze the GPT-2 345M sampling graph into a single self-contained GraphDef.
#
# Builds the sampling graph via the local `model` / `sample` modules, restores
# the latest checkpoint found under models/<model_name>, folds every variable
# into a constant, and writes the serialized result to `output_path`.
import json
import os

import fire  # not used below; kept as in the original script
import numpy as np
import tensorflow as tf

import model
import sample
import encoder  # not used below; kept as in the original script

# Sampling configuration.
model_name = '345M'
seed = None
length = 40
temperature = 1
top_k = 0
output_path = 'output_graph.pb'

# Single source of truth for the model directory. Both the hparams file and
# the checkpoint lookup are derived from it, so they cannot drift apart.
model_dir = os.path.join('models', model_name)

hparams = model.default_hparams()
with open(os.path.join(model_dir, 'hparams.json')) as f:
    hparams.override_from_dict(json.load(f))

with tf.Session(graph=tf.Graph()) as sess:
    # Prompt placeholder of shape [1, None]: one sequence of variable length
    # (presumably BPE token ids; confirm against sample.sample_sequence).
    context = tf.placeholder(tf.int32, [1, None])
    np.random.seed(seed)
    tf.set_random_seed(seed)
    output = sample.sample_sequence(
        hparams=hparams, length=length,
        context=context,
        batch_size=1,
        temperature=temperature, top_k=top_k,
    )

    saver = tf.train.Saver()
    ckpt = tf.train.latest_checkpoint(model_dir)
    if ckpt is None:
        # latest_checkpoint returns None instead of raising when nothing is
        # found; fail here with a clear message rather than inside restore().
        raise FileNotFoundError(
            'No checkpoint found in {!r}'.format(model_dir))
    saver.restore(sess, ckpt)

    # Debugging aid: list every node name in the graph, e.g. to pick the
    # input/output node names used by whoever loads the frozen graph.
    print([n.name for n in tf.get_default_graph().as_graph_def().node])

    # Freeze the graph. convert_variables_to_constants expects *node* (op)
    # names. output.name is a tensor name carrying a ':<index>' suffix and is
    # rejected as "not in graph", so pass the producing op's name instead.
    frozen_graph_def = tf.graph_util.convert_variables_to_constants(
        sess, sess.graph_def, [output.op.name],
    )

    # Save the frozen graph.
    with open(output_path, 'wb') as f:
        f.write(frozen_graph_def.SerializeToString())
Advertisement
Add Comment
Please sign in to add a comment.
Advertisement