Advertisement
Not a member of Pastebin yet?
Sign Up — it unlocks many cool features!
- if beam_search: #beam search
- k = 10 #store best k options
- ## CONVERT BEST_OPTIONS WITH USE_GPU = TRUE OR FALSE.
- best_options = [(Variable(torch.zeros(1)).cuda(), Variable(torch.LongTensor([EN.vocab.stoi["<s>"]])).cuda(), states)]
- length = 0
- while length < max_trg_len:
- options = [] #same format as best_options
- for lprob, sentence, current_state in best_options:
- last_word = sentence[-1]
- if last_word.data[0] != EN.vocab.stoi["</s>"]:
- probs, new_state = decoder(last_word.unsqueeze(1), current_state, outputs)
- probs = probs.squeeze()
- for index in torch.topk(probs, k)[1]: #only care about top k options in probs for next word.
- options.append((torch.add(probs[index], lprob), torch.cat([sentence, index]), new_state))
- else:
- options.append((lprob, sentence, current_state))
- options.sort(key = lambda x: x[0].data[0], reverse=True)
- best_options = options[:k] #sorts by first element, which is lprob.
- length = length + 1
- best_options.sort(key = lambda x: x[0].data[0], reverse=True)
- best_choice = best_options[0] #best overall
- sentence = best_choice[1].data
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement