Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
def make_positions(self, input):
    """Replace non-padding symbols with their position numbers.

    Entries equal to ``self.padding_idx`` are left unchanged. Every other entry
    is replaced by its position, counted from ``self.padding_idx + 1`` at the
    first column. When ``self.left_pad`` is true, each row's positions are
    shifted so that, for a left-padded row, its first non-padding symbol gets
    ``self.padding_idx + 1`` regardless of how much padding precedes it.

    A ``padding_idx + 1 ...`` range is cached on ``self.range_buf`` and reused
    across calls. It is rebuilt when it is missing, too short, or on a
    different device / dtype than ``input``.

    Args:
        input: tensor whose dim 1 is the sequence length. Presumably a 2-D
            ``(batch, seq_len)`` tensor of token indices -- TODO confirm
            against callers.

    Returns:
        A new tensor with the same shape and dtype as ``input``; ``input``
        itself is not modified.
    """
    seqlen = input.size(1)
    buf = getattr(self, 'range_buf', None)
    # A buffer cached by an earlier call may be on another device or have
    # another dtype than this input (e.g. after the model/data moved to GPU).
    # Reusing it would make masked_scatter_ fail, so rebuild it in that case
    # as well as when it is too short.
    if (buf is None or buf.numel() < seqlen
            or buf.device != input.device or buf.dtype != input.dtype):
        # offset positions by the padding index
        buf = torch.arange(self.padding_idx + 1, self.padding_idx + 1 + seqlen,
                           dtype=input.dtype, device=input.device)
        self.range_buf = buf
    mask = input.ne(self.padding_idx)
    positions = buf[:seqlen].expand_as(input)
    if self.left_pad:
        # Shift each row by (non-padding count - full length) so a left-padded
        # row's first real token lands on padding_idx + 1.
        positions = positions - mask.size(1) + mask.long().sum(dim=1).unsqueeze(1)
    # masked_scatter_ consumes its source in order, so pass only the masked
    # positions to keep each value aligned with the slot it fills.
    return input.clone().masked_scatter_(mask, positions[mask])
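'''
Scratch REPL session checking the masked_scatter_ and left-pad arithmetic used
in make_positions. (Recorded on an older PyTorch, where comparison ops returned
uint8 masks; current releases return torch.bool masks.)

>>> import torch
>>> a = torch.arange(0, 10)
>>> a
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])
>>> inp = torch.tensor([-1, -1, -1, 1, 2, 3, 4, 5, 6, 7])
>>> inp
tensor([-1, -1, -1, 1, 2, 3, 4, 5, 6, 7])
>>> mask = inp.ne(-1)
>>> mask
tensor([0, 0, 0, 1, 1, 1, 1, 1, 1, 1], dtype=torch.uint8)
>>> a[mask]
tensor([3, 4, 5, 6, 7, 8, 9])
>>> # masked_scatter_ consumes its source from the start, so passing all of `a`
>>> # fills the masked slots with 0..6 instead of the aligned values 3..9:
>>> inp.clone().masked_scatter_(mask, a)
tensor([-1, -1, -1, 0, 1, 2, 3, 4, 5, 6])
>>> inp.clone().masked_scatter_(mask, a[mask])
tensor([-1, -1, -1, 3, 4, 5, 6, 7, 8, 9])
>>> # Left-pad shift: subtract the full length and add the non-padding count.
>>> mask.size(0)
10
>>> a - mask.size(0) + mask.sum()
tensor([-3, -2, -1, 0, 1, 2, 3, 4, 5, 6])
>>> b = a - mask.size(0) + mask.sum()
>>> b
tensor([-3, -2, -1, 0, 1, 2, 3, 4, 5, 6])
>>> inp.clone().masked_scatter_(mask, b[mask])
tensor([-1, -1, -1, 0, 1, 2, 3, 4, 5, 6])
'''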
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement