Advertisement
Guest User

Untitled

a guest
Feb 21st, 2017
66
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 0.45 KB | None | 0 0
  1. #!/usr/bin/env python
  2.  
  3. """
  4. Pairwise distance module for pytorch.
  5. """
  6.  
import torch
import torch.autograd as autograd
import torch.nn as nn
  9.  
  10. class PairwiseDistance(nn.Module):
  11. def __init__(self, norm_type, dim=1):
  12. super(PairwiseDistance, self).__init__()
  13. self.norm_type = norm_type
  14. self.dim = dim
  15.  
  16. def forward(self, input):
  17. return autograd.variable.Norm(self.norm_type,
  18. self.dim)(input[1]-input[0])
Advertisement
Add Comment
Please sign in to add a comment.
Advertisement