Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
import itertools
import operator
from queue import PriorityQueue

import torch
class BeamSearchNode(object):
    """One partial hypothesis (a node of the search tree) used by beam search.

    Each node links back to its parent via ``prevNode``, so the full token
    sequence of a hypothesis is recovered by walking ``prevNode`` to the root.
    Nodes are orderable (see ``__lt__``) so ``(score, node)`` tuples can be
    stored in a priority queue even when two scores tie.
    """

    def __init__(self, hiddenstate, previousNode, wordId, logProb, length):
        '''
        :param hiddenstate: decoder hidden state that is fed, together with
            ``wordId``, into the next decoding step
        :param previousNode: parent node, or ``None`` for the root
        :param wordId: token emitted at this node (a token-id tensor in beam_decode)
        :param logProb: cumulative log-probability of the path ending at this node
        :param length: number of tokens on the path, including the start token
        '''
        self.h = hiddenstate
        self.prevNode = previousNode
        self.wordid = wordId
        self.logp = logProb
        self.leng = length

    def eval(self, alpha=1.0):
        """Return the length-normalised score of this hypothesis (higher is better).

        The cumulative log-probability is divided by the number of generated
        tokens (``length - 1``); the small epsilon keeps the root's ``0 / 0``
        finite. ``alpha`` weights a shaping reward that is currently always 0.
        """
        reward = 0  # placeholder: add a reward-shaping function here
        return self.logp / float(self.leng - 1 + 1e-6) + alpha * reward

    def __lt__(self, other):
        """Order nodes by score so tuples containing nodes never fail to compare on ties."""
        if not isinstance(other, BeamSearchNode):
            return NotImplemented
        return self.eval() < other.eval()
def _backtrack(node):
    """Return the ``wordid`` values on the path from the root to *node*, root first."""
    path = []
    while node is not None:
        path.append(node.wordid)
        node = node.prevNode
    path.reverse()
    return path


def _beam_search_one(decoder_hidden, encoder_outputs, decoder, beam_width, topk, max_nodes):
    """Run best-first beam search for one sentence; return up to *topk* token paths, best first."""
    # torch.tensor accepts any device; the legacy torch.LongTensor(...) constructor
    # rejects non-CPU devices.
    start = torch.tensor([[SOS_token]], dtype=torch.long, device=device)
    root = BeamSearchNode(decoder_hidden, None, start, 0, 1)

    # The counter breaks score ties so the queue never has to compare two
    # BeamSearchNode objects directly.
    tiebreak = itertools.count()
    nodes = PriorityQueue()
    # Scores are negated because PriorityQueue pops the smallest entry first.
    nodes.put((-root.eval(), next(tiebreak), root))
    qsize = 1
    endnodes = []

    # PriorityQueue.get() blocks forever on an empty queue, so test emptiness first.
    while not nodes.empty():
        # give up when decoding takes too long
        if qsize > max_nodes:
            break

        # fetch the best node
        score, _, n = nodes.get()

        # The root carries SOS, so it is excluded even if SOS_token == EOS_token.
        if n.wordid.item() == EOS_token and n.prevNode is not None:
            endnodes.append((score, n))
            # stop once the required number of finished sentences is reached
            if len(endnodes) >= topk:
                break
            continue

        # decode one step from this node's token and hidden state
        decoder_output, decoder_hidden, _extra = decoder(n.wordid, n.h, encoder_outputs)
        log_prob, indexes = torch.topk(decoder_output, beam_width)
        for new_k in range(beam_width):
            decoded_t = indexes[0][new_k].view(1, -1)
            log_p = log_prob[0][new_k].item()
            child = BeamSearchNode(decoder_hidden, n, decoded_t, n.logp + log_p, n.leng + 1)
            nodes.put((-child.eval(), next(tiebreak), child))
        # net growth of the queue: beam_width pushed, one popped
        qsize += beam_width - 1

    # No hypothesis finished: fall back to the best unfinished ones, without
    # blocking if the queue holds fewer than topk entries.
    if not endnodes:
        while len(endnodes) < topk and not nodes.empty():
            score, _, n = nodes.get()
            endnodes.append((score, n))

    # ascending negated score == best hypothesis first
    return [_backtrack(n) for _, n in sorted(endnodes, key=operator.itemgetter(0))]


def beam_decode(target_tensor, decoder_hiddens, encoder_outputs, decoder,
                beam_width=5, topk=1, max_nodes=2000):
    '''
    Beam-search decode and return the best hypothesis tokens as words.

    :param target_tensor: target indexes tensor of shape [B, T] where B is the batch size and T is the maximum length of the output sentence; only ``size(0)`` is read, to decide how many times to decode
    :param decoder_hiddens: input tensor of shape [1, B, H] for start of the decoding
    :param encoder_outputs: if you are using attention mechanism you can pass encoder outputs, [T, B, H] where T is the maximum length of input sentence; passed unchanged to ``decoder``
    :param decoder: callable ``decoder(input, hidden, encoder_outputs)`` returning a 3-tuple ``(output, hidden, extra)``; ``output`` is indexed as ``[0][k]`` after ``torch.topk``, so presumably [1, V] log-probabilities (they are summed into the path score) — confirm
    :param beam_width: number of best next tokens expanded from each popped node
    :param topk: how many finished sentences to collect
    :param max_nodes: give up once the queued-node count exceeds this value
    :return: decoded_words — the words (via ``output_lang.index2word``) of the collected hypotheses, best first, concatenated into one list; each hypothesis starts with the start token and, if it finished, ends with the end token
    '''
    utterances = []
    for _ in range(target_tensor.size(0)):
        # NOTE(review): every row is decoded from the same decoder_hiddens /
        # encoder_outputs, and only the last row's result is returned, so this
        # is only meaningful for a single sentence. Slice per-row state out of
        # decoder_hiddens / encoder_outputs before relying on B > 1.
        utterances = _beam_search_one(decoder_hiddens, encoder_outputs, decoder,
                                      beam_width, topk, max_nodes)
    tokens = itertools.chain.from_iterable(utterances)
    return [output_lang.index2word[token.item()] for token in tokens]
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement