Advertisement
Guest User

Untitled

a guest
Nov 15th, 2018
139
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 0.75 KB | None | 0 0
  1. src_ph = tf.placeholder('int32', [None, None])
  2. prev_tokens = tf.placeholder_with_default([out_voc.bos_ix], [None])
  3. h0 = model.encode(src_ph)
  4. prev_state_ph = [tf.placeholder(var.dtype, var.shape) for var in h0]
  5. new_state, new_logits = model.decode(prev_state_ph, prev_tokens)
  6.  
  7. def translate_line(src):
  8.   _state = sess.run(h0, {src_ph: inp_voc.to_matrix([src])})
  9.   tokens = [out_voc.bos_ix]
  10.   output = []
  11.   for i in range(100):
  12.     output.append(out_voc.tokens[tokens[0]])
  13.     feed = {prev_tokens: tokens}
  14.     for var, val in zip(prev_state_ph, _state):
  15.       feed[var] = val
  16.     _state, _logits = sess.run([new_state, new_logits], feed)
  17.     tokens = _logits.argmax(-1)
  18.     if tokens[0] == model.out_voc.eos_ix:
  19.       break
  20.   return ' '.join(output)
Advertisement
Add Comment
Please sign in to add a comment.
Advertisement