Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
- #!/usr/bin/env python
- """
- Pairwise distance module for pytorch.
- """
- import torch.autograd as autograd
- import torch.nn as nn
class PairwiseDistance(nn.Module):
    """Module computing the p-norm of the difference between two inputs.

    The distance is the norm of ``input[1] - input[0]`` taken along a
    single dimension.

    Args:
        norm_type: Order ``p`` of the norm (e.g. ``1`` or ``2``).
        dim: Dimension along which the norm is reduced. Defaults to ``1``.
    """

    def __init__(self, norm_type, dim=1):
        super(PairwiseDistance, self).__init__()
        self.norm_type = norm_type
        self.dim = dim

    def forward(self, input):
        """Return the ``norm_type``-norm of ``input[1] - input[0]``.

        Args:
            input: Indexable object whose elements ``0`` and ``1`` support
                subtraction and whose difference exposes ``.norm`` --
                presumably a pair of same-shaped tensors; confirm against
                callers.

        Returns:
            The norm of the difference, reduced along ``self.dim``.
        """
        difference = input[1] - input[0]
        # Use the tensor's public ``norm`` method rather than instantiating
        # the legacy ``autograd.variable.Norm`` Function class directly.
        # NOTE(review): whether the reduced dimension is kept follows the
        # installed PyTorch's default for ``norm`` -- confirm if callers
        # depend on the output shape.
        return difference.norm(self.norm_type, self.dim)
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement