Advertisement
Guest User

Untitled

a guest
Aug 20th, 2019
96
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 0.99 KB | None | 0 0
  1. import time
  2. import torch
  3. import itertools
  4.  
  5. def make_pairs_ids(nregion, bsize):
  6. pairs_ids = []
  7. for batch_id in range(bsize):
  8. pairs_id = torch.LongTensor([
  9. (batch_id,i,j) for i,j in \
  10. itertools.product(range(nregion),repeat=2)])
  11. pairs_ids.append(pairs_id)
  12. out = torch.cat(pairs_ids).contiguous()
  13. return out
  14.  
  15. if __name__ == '__main__':
  16. niter=10
  17.  
  18. bsize=32
  19. nregion=36
  20. dimh=2048
  21.  
  22. module = torch.nn.Linear(dimh, dimh)
  23.  
  24. mm = torch.randn(bsize, nregion, dimh)
  25. pairs_ids = make_pairs_ids(nregion, bsize)
  26.  
  27. module.cuda()
  28. mm = mm.cuda()
  29. pairs_ids = pairs_ids.cuda()
  30.  
  31. mm = torch.autograd.Variable(mm, requires_grad=True)
  32.  
  33. t = time.time()
  34. torch.cuda.synchronize()
  35.  
  36. for i in range(niter):
  37. pair_mm = mm[pairs_ids[:,0][:,None], pairs_ids[:,1:]]
  38.  
  39. outfusion = pair_mm[:,0,:] - pair_mm[:,1,:]
  40. out = module(outfusion)
  41. out.sum().backward()
  42.  
  43. torch.cuda.synchronize()
  44. print(time.time() - t)
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement