# Method of a seq2seq decoder class; the surrounding class (self.transition,
# self.attenion, self.embedding, self.encoded2hid, self.bos, self.eos, ...) is
# defined elsewhere and is not part of this paste.
import torch


def beam_search(self, beam_size, enc_output, encoded, enc_lengths, nwords, length_limit):
    assert not self.training

    def repeat_and_reshape(x, ntimes=beam_size):
        # x: (bsz, ...) -> (bsz * ntimes, ...), tiling every batch entry ntimes times
        shape = list(x.shape)
        bsz = shape[0]
        shape_repeat = [1] * len(shape)
        shape_repeat.insert(1, ntimes)
        res = x.unsqueeze(1).repeat(shape_repeat).view(bsz * ntimes, *shape[1:])
        return res

    bsz, seq_len, _ = enc_output.shape
    enc_output = repeat_and_reshape(enc_output)
    encoded = repeat_and_reshape(encoded)
    enc_lengths = repeat_and_reshape(enc_lengths)
    self.attenion.load_enc_output(enc_output, enc_lengths)
    hid = self.encoded2hid(encoded)  # hid: (bsz * beam_size, hdim)
    inp_feed = enc_output.new_zeros(bsz, self.enc_odim)
    inp_feed = repeat_and_reshape(inp_feed)
    widx = torch.LongTensor([self.bos]).expand(bsz).to(enc_lengths)
    inp = self.embedding(widx)
    inp = repeat_and_reshape(inp)
    widx = repeat_and_reshape(widx).view(bsz, beam_size)
    widices = [widx]  # widx: (bsz, beam_size)
    finished = torch.zeros(bsz, dtype=torch.bool, device=enc_output.device)
    finished = repeat_and_reshape(finished)  # finished: (bsz * beam_size,)
    pred = []  # elem: (bsz, beam_size)
    datebacks = []  # elem: (bsz, beam_size); back-pointers into the previous beam
    logprob_word_cum = torch.zeros(bsz).to(enc_output)
    logprob_word_cum = repeat_and_reshape(logprob_word_cum)  # (bsz * beam_size,)

    for t in range(length_limit):
        if finished.sum() == bsz * beam_size:
            break
        if t != 0:  # record the previous step's widx so hypotheses can be traced back later
            widices.append(widx)
        hid, inp_feed, logprob_word = self.transition(inp=inp,
                                                      inp_feed=inp_feed,
                                                      hid=hid)
        logprob_word_cur = (logprob_word_cum.clone()
                            .masked_fill_(finished, 0)
                            .view(-1, 1) + logprob_word)
        # logprob_word_cur: (bsz * beam_size, nwords)
        # the masked_fill_ removes the influence of finished sequences.

        logprob_word_expand, idx_expand = logprob_word_cur.view(bsz, -1).topk(k=beam_size, dim=-1)
        widx = idx_expand % nwords        # widx: (bsz, beam_size), chosen word indices
        dateback = idx_expand // nwords   # integer floor division: which beam each hypothesis came from
        logprob_word_cum = (logprob_word_cum.masked_fill_(~finished, 0)
                            + logprob_word_expand.view(-1).masked_fill_(finished, 0))
        inp = self.embedding(widx.view(-1))  # inp: (bsz * beam_size, edim)
        datebacks.append(dateback)

        finished = finished | widx.view(-1).eq(self.eos)

    pred.insert(0, widx)  # note: widx != widices[-1] (the final widx is never recorded in widices)
    # widices: -1, 0, 1, ..., t-2; datebacks: 0, 1, 2, ..., t-1
    for widx, dateback in zip(reversed(widices), reversed(datebacks)):
        # widx, dateback: (bsz, beam_size)
        widx_from = widx.gather(dim=-1, index=dateback)
        pred.insert(0, widx_from)

    pred = torch.stack(pred, dim=-1)  # pred: (bsz, beam_size, seq_len)
    _, indices_best = logprob_word_cum.view(bsz, beam_size).max(dim=1)  # indices_best: (bsz,)
    pred = pred[torch.arange(bsz, device=pred.device), indices_best]
    return pred
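
For reference, the core bookkeeping above (tiling each batch entry to (bsz * beam_size, ...) and splitting one topk over beam_size * nwords candidates into a word index and a source-beam index) can be exercised in isolation. The sketch below is not part of the original paste: the sizes and scores are made up purely for illustration.

# Minimal, self-contained sketch of the beam bookkeeping (assumed toy sizes).
import torch

bsz, beam_size, nwords = 2, 3, 11

# Tile a (bsz, dim) tensor to (bsz * beam_size, dim), as repeat_and_reshape does.
x = torch.randn(bsz, 4)
x_tiled = x.unsqueeze(1).repeat(1, beam_size, 1).view(bsz * beam_size, 4)
assert torch.equal(x_tiled[:beam_size], x[0].expand(beam_size, 4))

# Per-beam cumulative score + per-word log-probabilities -> joint candidate scores.
logprob_word_cum = torch.randn(bsz * beam_size)
logprob_word = torch.log_softmax(torch.randn(bsz * beam_size, nwords), dim=-1)
logprob_word_cur = logprob_word_cum.view(-1, 1) + logprob_word

# One topk over beam_size * nwords candidates per batch entry.
scores, idx_expand = logprob_word_cur.view(bsz, -1).topk(k=beam_size, dim=-1)
widx = idx_expand % nwords        # chosen word in each surviving hypothesis
dateback = idx_expand // nwords   # which previous beam it extends
print(widx.shape, dateback.shape)  # torch.Size([2, 3]) torch.Size([2, 3])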