Advertisement
Guest User

Untitled

a guest
Jul 17th, 2019
86
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 1.44 KB | None | 0 0
  1. def make_positions(self, input):
  2. """Replace non-padding symbols with their position numbers."""
  3. if not hasattr(self, 'range_buf'):
  4. self.range_buf = input.new()
  5. seqlen = input.size(1)
  6. if self.range_buf.numel() < seqlen:
  7. # offset positions by the padding index
  8. torch.arange(self.padding_idx + 1, self.padding_idx + 1 + seqlen,
  9. out=self.range_buf)
  10. mask = input.ne(self.padding_idx)
  11. positions = self.range_buf[:seqlen].expand_as(input)
  12. if self.left_pad:
  13. positions = positions - mask.size(1) + mask.long().sum(dim=1).unsqueeze(1)
  14. return input.clone().masked_scatter_(mask, positions[mask])
  15.  
  16.  
# NOTE(review): scratch REPL transcript illustrating how masked_scatter_
# consumes its source tensor (element-by-element vs. pre-masked with a[mask]).
# Kept as a bare string literal so it is never executed; consider moving it
# into a unit test or docs rather than leaving it inline.
'''
import torch

a= torch.arange(0, 10)
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])

inp = torch.tensor([-1, -1, -1, 1,2,3,4,5,6,7])
inp
tensor([-1, -1, -1, 1, 2, 3, 4, 5, 6, 7])

mask = inp.ne(-1)
mask
tensor([0, 0, 0, 1, 1, 1, 1, 1, 1, 1], dtype=torch.uint8)

a[mask]
tensor([3, 4, 5, 6, 7, 8, 9])

inp.clone().masked_scatter_(mask, a)
tensor([-1, -1, -1, 0, 1, 2, 3, 4, 5, 6])

inp.clone().masked_scatter_(mask, a[mask])
tensor([-1, -1, -1, 3, 4, 5, 6, 7, 8, 9])

mask.size(0)
10

a - mask.size(0) + mask.sum()
tensor([-3, -2, -1, 0, 1, 2, 3, 4, 5, 6])

b = a - mask.size(0) + mask.sum()
b
tensor([-3, -2, -1, 0, 1, 2, 3, 4, 5, 6])

inp.clone().masked_scatter_(mask, b[mask])
tensor([-1, -1, -1, 0, 1, 2, 3, 4, 5, 6])

'''
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement