Advertisement
Guest User

Untitled

a guest
Feb 21st, 2018
130
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 0.73 KB | None | 0 0
  1. class ContrastiveLoss(torch.nn.Module):
  2. def __init__(self, margin=1.0):
  3. super(ContrastiveLoss, self).__init__()
  4. self.margin = margin
  5.  
  6. def forward(self, x, y):
  7. #<your code>
  8. L_sum = 0
  9. for i in range(x.size()[0] - 1):
  10. for j in range(i+1, x.size()[0]):
  11. y_ij = 1
  12. if y[i] == y[j]:
  13. y_ij = 0
  14. norm = torch.norm((x[i] - x[j]), p=2) ** 2
  15.  
  16. if self.margin - norm > 0:
  17. L_sum += (1 - y_ij) * norm + y_ij * (self.margin - norm)
  18. else:
  19. L_sum += (1 - y_ij) * norm
  20.  
  21. return 2 * L_sum / (x.shape[0] * (x.shape[0] - 1))
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement