Advertisement
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
# --- Inference graph for step-by-step decoding ---------------------------
# Source-side input; translate_line feeds it with inp_voc.to_matrix([src]).
src_ph = tf.placeholder('int32', [None, None])

# Tokens emitted at the previous step; falls back to a single BOS token
# when nothing is fed.
prev_tokens = tf.placeholder_with_default([out_voc.bos_ix], [None])

# Initial decoder state produced by encoding the source.
h0 = model.encode(src_ph)

# One placeholder per tensor in h0 (same dtype and shape), so that the state
# returned by one decoder step can be fed back in for the next step.
prev_state_ph = [tf.placeholder(var.dtype, var.shape) for var in h0]

# A single decoder step: (previous state, previous tokens) -> (state, logits).
new_state, new_logits = model.decode(prev_state_ph, prev_tokens)
def translate_line(src, max_len=100):
    """Greedily decode a translation for a single source line.

    The source is encoded once with ``h0``; afterwards the decoder step
    (``new_state``, ``new_logits``) is run repeatedly, each time feeding back
    the previous state and the argmax of the previous logits.

    Args:
        src: The source line. It is wrapped in a one-element list and
            converted with ``inp_voc.to_matrix`` before being fed to the
            graph.
        max_len: Maximum number of decoder steps to run. Defaults to 100,
            the limit that used to be hard-coded.

    Returns:
        The emitted tokens joined with single spaces. Because each token is
        recorded *before* the decoder step that consumes it, the string
        starts with the BOS token and never contains the EOS token that
        terminated decoding.
    """
    state = sess.run(h0, {src_ph: inp_voc.to_matrix([src])})
    tokens = [out_voc.bos_ix]
    output = []
    for _ in range(max_len):
        # Record the token about to be fed to the decoder (BOS on step 0).
        output.append(out_voc.tokens[tokens[0]])

        feed = {prev_tokens: tokens}
        feed.update(zip(prev_state_ph, state))
        state, logits = sess.run([new_state, new_logits], feed)

        # Greedy choice: highest-scoring entry along the last axis.
        tokens = logits.argmax(-1)

        # NOTE(review): EOS is looked up via model.out_voc while the rest of
        # this function uses the module-level out_voc; presumably these are
        # the same vocabulary -- confirm.
        if tokens[0] == model.out_voc.eos_ix:
            break
    return ' '.join(output)
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement