Guest User

Untitled

a guest
Feb 15th, 2019
98
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 0.50 KB | None | 0 0
  def predict(device, net, words, n_vocab, vocab_to_int, int_to_vocab, top_k=5):
      """Sample one next word from ``net`` and append it to ``words`` in place.

      The seed ``words`` are fed to the model one token at a time to build up
      the recurrent ``(h, c)`` state. The next word is then drawn uniformly at
      random from the ``top_k`` highest-scoring entries of the final step's
      output.

      Args:
          device: torch device (or device string) used for the state and the
              input token tensors.
          net: model exposing ``eval()``, ``zero_state(batch_size)`` returning
              an ``(h, c)`` pair, and ``net(ix, (h, c)) -> (output, (h, c))``.
          words: non-empty list of seed words. It is mutated in place: the
              sampled word is appended to it.
          n_vocab: vocabulary size. Not used by this function; kept for
              interface compatibility.
          vocab_to_int: mapping from word to token index.
          int_to_vocab: mapping (or sequence) from token index to word.
          top_k: number of top-scoring candidates to sample from. Clamped to
              the size of the output's last dimension.

      Returns:
          None. The result is communicated by appending to ``words``.

      Raises:
          ValueError: if ``words`` is empty (there is no state to predict from).
          Whatever ``vocab_to_int[w]`` raises for an unknown seed word
          (``KeyError`` for a dict).
      """
      if not words:
          # Without at least one seed word the loop below never runs and there
          # is no model output to sample from.
          raise ValueError("predict() needs at least one seed word")

      # NOTE: the model is left in eval mode on return; callers that resume
      # training must call net.train() themselves.
      net.eval()

      # Inference only: no_grad avoids building an autograd graph over the
      # whole seed sequence.
      with torch.no_grad():
          state_h, state_c = net.zero_state(1)
          state_h = state_h.to(device)
          state_c = state_c.to(device)
          for w in words:
              # A (1, 1) tensor holding one token for a batch of one. The
              # model's (seq, batch) vs (batch, seq) layout is assumed; TODO
              # confirm.
              ix = torch.tensor([[vocab_to_int[w]]], device=device)
              output, (state_h, state_c) = net(ix, (state_h, state_c))

          # torch.topk works over the last dimension by default. Clamp k so a
          # top_k larger than that dimension does not raise.
          k = min(top_k, output.size(-1))
          _, top_ix = torch.topk(output[0], k=k)

      # Presumably output[0] is (1, vocab), so choices[0] is the list of the
      # k best token indices. Verify against the model's output shape.
      choices = top_ix.tolist()
      # int() turns the numpy scalar into a plain Python int for the lookup.
      choice = int(np.random.choice(choices[0]))
      words.append(int_to_vocab[choice])
Add Comment
Please sign in to add a comment.