Advertisement
Guest User

Untitled

a guest
Feb 25th, 2018
65
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 1.41 KB | None | 0 0
  1. if beam_search: #beam search
  2. k = 10 #store best k options
  3. ## CONVERT BEST_OPTIONS WITH USE_GPU = TRUE OR FALSE.
  4. best_options = [(Variable(torch.zeros(1)).cuda(), Variable(torch.LongTensor([EN.vocab.stoi["<s>"]])).cuda(), states)]
  5. length = 0
  6. while length < max_trg_len:
  7. options = [] #same format as best_options
  8. for lprob, sentence, current_state in best_options:
  9. last_word = sentence[-1]
  10. if last_word.data[0] != EN.vocab.stoi["</s>"]:
  11. probs, new_state = decoder(last_word.unsqueeze(1), current_state, outputs)
  12. probs = probs.squeeze()
  13. for index in torch.topk(probs, k)[1]: #only care about top k options in probs for next word.
  14. options.append((torch.add(probs[index], lprob), torch.cat([sentence, index]), new_state))
  15. else:
  16. options.append((lprob, sentence, current_state))
  17. options.sort(key = lambda x: x[0].data[0], reverse=True)
  18. best_options = options[:k] #sorts by first element, which is lprob.
  19. length = length + 1
  20. best_options.sort(key = lambda x: x[0].data[0], reverse=True)
  21. best_choice = best_options[0] #best overall
  22. sentence = best_choice[1].data
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement