# Decoder method excerpt: beam-search decoding for a seq2seq model with attention
# and input feeding. Assumes `import torch` at module level; the enclosing class is
# expected to provide self.embedding, self.transition, self.encoded2hid, self.attenion,
# self.bos, self.eos and self.enc_odim.
def beam_search(self, beam_size, enc_output, encoded, enc_lengths, nwords, length_limit):
    assert not self.training, "beam_search is meant for eval mode only"

    def repeat_and_reshape(x, ntimes=beam_size):
        # Tile each batch row `ntimes` along a new dim 1, then flatten:
        # (bsz, ...) -> (bsz * beam_size, ...)
        shape = list(x.shape)
        bsz = shape[0]
        shape_repeat = [1] * len(shape)
        shape_repeat.insert(1, ntimes)
        res = x.unsqueeze(1).repeat(shape_repeat).view(bsz * beam_size, *shape[1:])
        return res
    bsz, seq_len, _ = enc_output.shape
    # Tile the encoder-side tensors so each of the beam_size hypotheses has its own copy.
    enc_output = repeat_and_reshape(enc_output)
    encoded = repeat_and_reshape(encoded)
    enc_lengths = repeat_and_reshape(enc_lengths)
    self.attenion.load_enc_output(enc_output, enc_lengths)
    hid = self.encoded2hid(encoded)  # hid: (bsz * beam_size, hdim)
    inp_feed = enc_output.new_zeros(bsz, self.enc_odim)  # input-feeding vector
    inp_feed = repeat_and_reshape(inp_feed)
    widx = enc_lengths.new_full((bsz,), self.bos)  # start every sequence with <bos>
    inp = self.embedding(widx)
    inp = repeat_and_reshape(inp)
    widx = repeat_and_reshape(widx).view(bsz, beam_size)
    widices = [widx]  # each element: (bsz, beam_size)
    finished = torch.zeros(bsz, dtype=torch.bool, device=enc_output.device)
    finished = repeat_and_reshape(finished)  # finished: (bsz * beam_size,)
    pred = []       # each element: (bsz, beam_size)
    datebacks = []  # each element: (bsz, beam_size), parent-beam indices
    logprob_word_cum = enc_output.new_zeros(bsz)
    logprob_word_cum = repeat_and_reshape(logprob_word_cum)  # (bsz * beam_size,)
    for t in range(length_limit):
        if finished.sum() == bsz * beam_size:
            break  # every beam of every batch element has emitted <eos>
        if t != 0:
            # Record the previous step's words so the backtrace below can recover them.
            widices.append(widx)
        hid, inp_feed, logprob_word = self.transition(inp=inp,
                                                      inp_feed=inp_feed,
                                                      hid=hid)
        # logprob_word: (bsz * beam_size, nwords). Add it to the running scores;
        # `masked_fill_` zeroes the contribution of beams that are already finished.
        logprob_word_cur = (logprob_word_cum.clone()
                            .masked_fill_(finished, 0)
                            .view(-1, 1) + logprob_word)
        # logprob_word_cur: (bsz * beam_size, nwords)
        logprob_word_expand, idx_expand = logprob_word_cur.view(bsz, -1).topk(k=beam_size, dim=-1)
        # Split each flat index over (beam_size * nwords) into a word index and a
        # parent-beam index (see the standalone sketch after this method).
        widx = idx_expand % nwords       # widx: (bsz, beam_size)
        dateback = idx_expand // nwords  # parent beam each candidate extends
        # Keep the old cumulative score for finished beams, take the new one otherwise.
        logprob_word_cum = (logprob_word_cum.masked_fill_(~finished, 0)
                            + logprob_word_expand.view(-1).masked_fill_(finished, 0))
        inp = self.embedding(widx.view(-1))  # inp: (bsz * beam_size, edim)
        datebacks.append(dateback)
        finished = finished | widx.view(-1).eq(self.eos)
    # Record the last step's words: widices stops at step t-2, so the final widx
    # is not recorded there and has to be added to pred directly.
    pred.insert(0, widx)
    # widices holds the words of steps -1, 0, ..., t-2 (<bos> plus all but the last
    # step); datebacks holds the parent-beam pointers of steps 0, 1, ..., t-1.
    # Walk back through the stored words, following each step's parent-beam pointers,
    # and prepend them so that pred ends up in left-to-right order.
    for widx, dateback in zip(reversed(widices), reversed(datebacks)):
        # widx, dateback: (bsz, beam_size)
        widx_from = widx.gather(dim=-1, index=dateback)
        pred.insert(0, widx_from)
    pred = torch.stack(pred, dim=-1)  # pred: (bsz, beam_size, seq_len)
    # Keep, for every batch element, the beam with the highest cumulative log-probability.
    _, indices_best = logprob_word_cum.view(bsz, beam_size).max(dim=1)  # indices_best: (bsz,)
    pred = pred[range(bsz), indices_best]
    return pred
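
# --- Standalone sketch (not part of the original paste) --------------------------
# A minimal, runnable illustration of the flat top-k trick used inside the decoding
# loop above: candidate scores for all (beam, word) pairs are flattened, top-k is
# taken once per batch element, and each flat index is split back into a word index
# (idx % nwords) and a parent-beam index (idx // nwords). The sizes below are made
# up for demonstration; only `torch` is required.
import torch

bsz, beam_size, nwords = 2, 3, 7
scores = torch.randn(bsz * beam_size, nwords)                 # per-beam word log-probs
topv, topi = scores.view(bsz, -1).topk(k=beam_size, dim=-1)   # flat indices in [0, beam_size * nwords)
widx = topi % nwords        # chosen word in the vocabulary, (bsz, beam_size)
dateback = topi // nwords   # which beam each chosen word extends, (bsz, beam_size)
# Sanity check: picking scores by (batch, parent beam, word) recovers the top-k values.
batch_idx = torch.arange(bsz).unsqueeze(1)                    # (bsz, 1), broadcasts over beams
assert torch.equal(scores.view(bsz, beam_size, nwords)[batch_idx, dateback, widx], topv)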