Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
def _model_device(model):
    """Return the device of ``model``'s first parameter, or CPU if it has none."""
    params = getattr(model, "parameters", None)
    first = next(iter(params()), None) if callable(params) else None
    return first.device if first is not None else torch.device("cpu")


def simple_predict(model, text, length_to_predict):
    """Greedily extend ``text`` one BPE token at a time using ``model``.

    The prompt is BPE-encoded with a leading BOS token (no EOS). Each step
    re-runs the model on the full sequence so far and appends the
    highest-scoring token at the last position. Generation stops when the
    sequence holds ``length_to_predict`` tokens (prompt included) or the
    subword ``'<EOS>'`` is produced.

    Relies on module-level names defined elsewhere in the file: ``bpe``
    (BPE model providing ``encode`` / ``subword_to_id`` / ``id_to_subword``),
    ``yttm`` (for ``yttm.OutputType``), and ``num_tokens`` (size of the last
    axis the model output is reshaped to).

    Args:
        model: callable taking a ``(1, seq_len)`` LongTensor of token ids;
            ``model(x)[0]`` must be reshapeable to ``(-1, num_tokens)``
            (presumably ``(seq_len, num_tokens)`` scores — TODO confirm).
        text: prompt string.
        length_to_predict: total token budget, prompt tokens included. If the
            encoded prompt already has this many tokens, nothing is generated.

    Returns:
        list[int]: token ids of the prompt followed by the generated tokens.
    """
    # encode() takes a list of sentences; wrap the single prompt and take its result.
    sentences = [text] if isinstance(text, str) else text
    subwords = list(
        bpe.encode(sentences, output_type=yttm.OutputType.SUBWORD, bos=True, eos=False)[0]
    )
    token_ids = [bpe.subword_to_id(s) for s in subwords]
    # Build inputs on the model's device so GPU-resident models work too.
    device = _model_device(model)

    # Pure inference: no autograd graph needs to be built.
    # NOTE(review): model.eval() is not called here (it would mutate the
    # model); callers should put the model in eval mode if it uses dropout.
    with torch.no_grad():
        for _ in range(length_to_predict - len(subwords)):
            x = torch.tensor([token_ids], dtype=torch.long, device=device)
            scores = model(x)[0].reshape(-1, num_tokens)
            # Only the last position matters. argmax on raw scores equals argmax
            # after sigmoid (monotonic), without saturation ties.
            next_id = int(scores[-1].argmax())
            token_ids.append(next_id)
            subwords.append(bpe.id_to_subword(next_id))
            if subwords[-1] == '<EOS>':
                break
    return token_ids
Advertisement
Add Comment
Please, Sign In to add comment