yulmen_

Untitled

May 23rd, 2021
188
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 0.67 KB | None | 0 0
  1. def simple_predict(model, text, length_to_predict):
  2.     targets = []
  3.  
  4.     bpe_string = bpe.encode(text, output_type=yttm.OutputType.SUBWORD, bos=True, eos=False)
  5.    
  6.     bpe_string_id = []
  7.     for s in bpe_string:
  8.         bpe_string_id.append(bpe.subword_to_id(s))
  9.  
  10.     for i in range(length_to_predict - len(bpe_string)):
  11.         x = torch.LongTensor([bpe_string_id])
  12.         predict = model(x)[0]
  13.         pred1 = torch.argmax(torch.sigmoid(predict.cpu().contiguous().view(-1, num_tokens)), dim=1)[-1]
  14.         bpe_string_id.append(pred1)
  15.         bpe_string.append(bpe.id_to_subword(pred1))
  16.         if bpe_string[-1] == '<EOS>':
  17.             break
  18.  
  19.     return bpe_string_id
Advertisement
Add Comment
Please, Sign In to add comment