Guest User

Untitled

a guest
Mar 18th, 2018
87
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 1.47 KB | None | 0 0
  1. import torch as th
  2. from torch.autograd import Variable
  3.  
  4. def linear_slice():
  5. """
  6. We want to fech data points from "src_data" and put them in "output" so
  7. that output[i] ~ src_data[lookup_coord[i]], where lookup_coord is an array
  8. specifying the location of the requested points.
  9.  
  10. The values in "src_data" are linearly interpolated, so the mapping is
  11. differentiable a.e.
  12. """
  13.  
  14. dst_sz = 8
  15. src_sz = 4
  16.  
  17. # Some random data we want to interpolate from
  18. src_data = Variable(th.rand(src_sz), requires_grad=True)
  19.  
  20. # interpolation coordinate w.r.t. src_data, in [0, src_sz[
  21. lookup_coord = Variable(th.rand(dst_sz) * src_sz, requires_grad=True)
  22.  
  23. # lower coord upper
  24. # -------+--------*------+---
  25. lower_bin = th.clamp(th.floor(lookup_coord-0.5), min=0)
  26. upper_bin = th.clamp(lower_bin+1, max=src_sz-1) # make sure we're in bounds
  27.  
  28. # Linear interpolation weight
  29. weight = th.abs(lookup_coord-0.5 - lower_bin)
  30.  
  31. # Make the coordinates integers to allow indexing
  32. lower_bin = lower_bin.long()
  33. upper_bin = upper_bin.long()
  34.  
  35. # Interpolate the data from src_data
  36. output = src_data[lower_bin]*(1.0 - weight) + src_data[upper_bin]*weight
  37.  
  38. # Backprop
  39. loss = output.sum()
  40. loss.backward()
  41.  
  42. # Check the gradients
  43. print(src_data.grad)
  44. print(lookup_coord.grad)
  45.  
  46. # We also want to write at locations indexed by an array
  47. data_copy = src_data + 1
  48. data_copy[lower_bin] += 1.0
  49. loss = data_copy.sum()
  50. loss.backward()
  51.  
  52.  
  53. if __name__ == "__main__":
  54. linear_slice()
Add Comment
Please, Sign In to add comment