Advertisement
Guest User

Untitled

a guest
Jul 23rd, 2019
98
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 1.75 KB | None | 0 0
  1. def generate(model_path, reversed_dictionary, gen_word_count, predict_full_sentence=True,start_word=0):
  2. test_input = Input(batch_size=1, num_steps=num_steps, data=test_data)
  3.  
  4. x = tf.placeholder(dtype=tf.int32, shape=[None, None])
  5. y = tf.placeholder(dtype=tf.int32, shape=[None, None])
  6.  
  7. m = Model(x, y, input=test_input, is_training=False, vocab_size=vocabulary)
  8. saver = tf.train.Saver()
  9. with tf.Session() as sess:
  10. current_state_fw = np.zeros((NUM_LAYERS, 2, 1, HIDDEN_SIZE))
  11. current_state_bw = np.zeros((NUM_LAYERS, 2, 1, HIDDEN_SIZE))
  12.  
  13. # restore the trained model
  14. saver.restore(sess, model_path)
  15.  
  16. sentence = start_word # 0 is 'the'
  17. sentence = np.reshape(sentence, (-1, 1))
  18. for step in range(gen_word_count):
  19. if predict_full_sentence:
  20. pred, current_state_fw, current_state_bw = sess.run([m.predict, m.state_fw, m.state_bw],
  21. feed_dict={m.init_state_fw: current_state_fw,
  22. m.init_state_bw: current_state_bw,
  23. x: sentence})
  24. else:
  25.  
  26. pred, current_state_fw, current_state_bw = sess.run([m.predict, m.state_fw, m.state_bw],
  27. feed_dict={m.init_state_fw: current_state_fw,
  28. m.init_state_bw: current_state_bw,
  29. x: np.reshape(sentence[:,-1], (-1, 1))})
  30. sentence = np.append(sentence, pred[-1])
  31. sentence = np.reshape(sentence, (1, -1))
  32. print(" ".join([reversed_dictionary[x] for x in sentence[0]]))
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement