Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
def predict(device, net, words, n_vocab, vocab_to_int, int_to_vocab, top_k=5):
    """Prime ``net`` with ``words`` and append one sampled next word in place.

    Every word in ``words`` is fed through the network one token at a time,
    carrying the recurrent state forward.  The indices of the ``top_k``
    highest scores of the final output are then taken, one of them is picked
    uniformly at random, and the matching word is appended to ``words``.

    Args:
        device: Target passed to ``.to(...)`` for the state and input tensors.
        net: Callable model exposing ``eval()``, ``zero_state(batch_size)``
            returning ``(state_h, state_c)``, and ``net(ix, (h, c))``
            returning ``(output, (h, c))``.
        words: Non-empty mutable sequence of seed words; extended in place.
        n_vocab: Unused here; kept so existing callers keep working.
        vocab_to_int: Mapping from word to integer index.
        int_to_vocab: Mapping from integer index back to word.
        top_k: Number of highest-scoring candidates to sample from.

    Raises:
        ValueError: If ``words`` is empty (there is no output to sample from).
        KeyError: If a seed word is missing from ``vocab_to_int``.

    NOTE(review): the original paste had lost its indentation; the loop body
    is assumed to be only the two lines that build ``ix`` and call ``net``,
    since appending to ``words`` inside a loop over ``words`` would never
    terminate -- confirm against the original file.
    """
    if not words:
        raise ValueError("words must contain at least one seed word")

    net.eval()

    # Pure inference: skip autograd bookkeeping for the forward passes.
    with torch.no_grad():
        state_h, state_c = net.zero_state(1)
        state_h = state_h.to(device)
        state_c = state_c.to(device)

        for w in words:
            # One token per step, wrapped as a 1x1 index tensor.
            ix = torch.tensor([[vocab_to_int[w]]]).to(device)
            output, (state_h, state_c) = net(ix, (state_h, state_c))

        # presumably output[0] is (seq_len, n_vocab) scores -- TODO confirm
        scores = output[0]
        # torch.topk raises if k exceeds the size of the last dimension.
        k = min(top_k, scores.shape[-1])
        _, top_ix = torch.topk(scores, k=k)

    choices = top_ix.tolist()
    # Uniform pick among the top-k candidate indices of the first row.
    choice = np.random.choice(choices[0])
    words.append(int_to_vocab[int(choice)])
Add Comment
Please, Sign In to add comment