Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
- import time
- import torch
- import itertools
def make_pairs_ids(nregion: int, bsize: int) -> torch.Tensor:
    """Return index triples for every ordered region pair in every batch item.

    Each row is ``(batch_id, i, j)`` with ``batch_id`` in ``range(bsize)``
    and ``i``, ``j`` in ``range(nregion)``. Rows are ordered with
    ``batch_id`` varying slowest and ``j`` varying fastest.

    Args:
        nregion: Number of regions per batch item.
        bsize: Number of batch items.

    Returns:
        A contiguous ``torch.LongTensor`` of shape
        ``(bsize * nregion * nregion, 3)``. When either argument is zero
        the result has shape ``(0, 3)``.
    """
    # A single product over all three axes yields the same row order as
    # concatenating one per-batch (i, j) product after another.
    triples = list(
        itertools.product(range(bsize), range(nregion), range(nregion))
    )
    # The explicit view keeps the result 2-D even when there are no rows;
    # LongTensor([]) alone would be 1-D with shape (0,).
    return torch.LongTensor(triples).view(len(triples), 3)
if __name__ == '__main__':
    # Micro-benchmark: time `niter` forward/backward passes of a Linear layer
    # applied to the difference of every ordered pair of region features,
    # gathered from `mm` via advanced indexing. Requires a CUDA device.
    niter = 10
    bsize = 32
    nregion = 36
    dimh = 2048

    module = torch.nn.Linear(dimh, dimh)
    mm = torch.randn(bsize, nregion, dimh)
    pairs_ids = make_pairs_ids(nregion, bsize)

    module.cuda()
    mm = mm.cuda()
    pairs_ids = pairs_ids.cuda()
    mm = torch.autograd.Variable(mm, requires_grad=True)

    # Drain pending GPU work (e.g. the host-to-device copies above) BEFORE
    # starting the clock, so only the loop below is measured.
    torch.cuda.synchronize()
    t = time.time()
    for _ in range(niter):
        # Index tensors of shape (N, 1) and (N, 2) broadcast to (N, 2), so
        # pair_mm is (N, 2, dimh): slot 0 is mm[b, i], slot 1 is mm[b, j].
        pair_mm = mm[pairs_ids[:, 0][:, None], pairs_ids[:, 1:]]
        outfusion = pair_mm[:, 0, :] - pair_mm[:, 1, :]
        out = module(outfusion)
        # Gradients are not zeroed between iterations; they accumulate.
        out.sum().backward()
    # Wait for all queued kernels to finish before reading the clock.
    # NOTE(review): the pasted source lost its indentation; the final
    # synchronize/print are assumed to sit after the loop (total time for
    # all iterations) because `t` is set only once — confirm.
    torch.cuda.synchronize()
    print(time.time() - t)
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement