Advertisement
Not a member of Pastebin yet?
Sign up —
it unlocks many cool features!
def generate(model_path, reversed_dictionary, gen_word_count, predict_full_sentence=True, start_word=0):
    """Generate ``gen_word_count`` words from a saved checkpoint and print them.

    Builds an inference ``Model`` (``is_training=False``, batch size 1) on two
    int32 placeholders and restores its weights from ``model_path``. It then
    repeatedly appends the model's last prediction to the running sentence.
    Finally it prints the whole sentence, seed included, as space-joined words.

    Args:
        model_path: checkpoint path passed to ``tf.train.Saver.restore``.
        reversed_dictionary: mapping from word id to word, used to render the
            generated ids.
        gen_word_count: number of words to generate after the seed.
        predict_full_sentence: if True, feed the whole sentence generated so far
            at every step, starting from zero state each time. If False, feed
            only the newest word and carry the fw/bw state between steps.
        start_word: id of the seed word (the original code notes 0 is 'the').

    Relies on module-level ``Input``, ``Model``, ``num_steps``, ``test_data``,
    ``vocabulary``, ``NUM_LAYERS`` and ``HIDDEN_SIZE`` defined elsewhere.
    """
    test_input = Input(batch_size=1, num_steps=num_steps, data=test_data)
    x = tf.placeholder(dtype=tf.int32, shape=[None, None])
    y = tf.placeholder(dtype=tf.int32, shape=[None, None])
    m = Model(x, y, input=test_input, is_training=False, vocab_size=vocabulary)
    saver = tf.train.Saver()

    state_shape = (NUM_LAYERS, 2, 1, HIDDEN_SIZE)
    with tf.Session() as sess:
        current_state_fw = np.zeros(state_shape)
        current_state_bw = np.zeros(state_shape)
        # restore the trained model
        saver.restore(sess, model_path)

        # Keep the sentence as a single row, shape (1, seq_len), matching the
        # reshape done after every generation step below.
        sentence = np.reshape(start_word, (1, -1))
        for _ in range(gen_word_count):
            if predict_full_sentence:
                # The whole sentence is re-fed from its first word, so begin from
                # zero state. Carrying the previous step's final state would make
                # the model see the prefix twice.
                # NOTE(review): assumes m.init_state_* is the state before the
                # first fed word — confirm against Model.
                current_state_fw = np.zeros(state_shape)
                current_state_bw = np.zeros(state_shape)
                model_input = sentence
            else:
                # Feed only the newest word; the carried state stands in for the rest.
                model_input = np.reshape(sentence[:, -1], (-1, 1))
            pred, current_state_fw, current_state_bw = sess.run(
                [m.predict, m.state_fw, m.state_bw],
                feed_dict={m.init_state_fw: current_state_fw,
                           m.init_state_bw: current_state_bw,
                           x: model_input})
            # np.append flattens, so restore the single-row shape.
            sentence = np.reshape(np.append(sentence, pred[-1]), (1, -1))
        # Use a distinct name so the `x` placeholder is never shadowed.
        print(" ".join(reversed_dictionary[word_id] for word_id in sentence[0]))
Advertisement
Add Comment
Please sign in to add a comment.
Advertisement