Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
class ContrastiveLoss(torch.nn.Module):
    """Pairwise contrastive loss averaged over all unordered pairs in a batch.

    For a batch of N embeddings ``x`` with labels ``y`` the loss is::

        L = 2 / (N * (N - 1)) * sum_{i<j} [ (1 - y_ij) * d_ij
                                            + y_ij * max(0, margin - d_ij) ]

    where ``d_ij`` is the squared Euclidean distance between the flattened
    embeddings ``x[i]`` and ``x[j]``, and ``y_ij`` is 0 when
    ``y[i] == y[j]`` and 1 otherwise. Same-label pairs are penalised by their
    squared distance; different-label pairs are penalised until their squared
    distance reaches ``margin`` (the margin is compared against the squared
    distance, not the plain distance).
    """

    def __init__(self, margin=1.0):
        """Store the hinge ``margin`` used for different-label pairs."""
        super().__init__()
        self.margin = margin

    def forward(self, x, y):
        """Return the mean contrastive loss over all pairs as a 0-d tensor.

        Args:
            x: Embeddings with the batch along dim 0; any trailing dims are
                flattened before distances are computed.
            y: One label per sample (tensor or sequence of length N); samples
                with equal labels form same-label pairs.

        Returns:
            Scalar tensor. With fewer than two samples there are no pairs, so
            a zero that is still attached to ``x``'s autograd graph is
            returned instead of dividing by zero.
        """
        n = x.shape[0]
        if n < 2:
            # No pairs to average over: avoid 0/0 but keep the result
            # differentiable so ``loss.backward()`` works on a tiny batch.
            return x.sum() * 0.0

        flat = x.reshape(n, -1)
        labels = torch.as_tensor(y, device=x.device).reshape(n)

        # Every unordered pair (i < j) at once, replacing the double loop.
        i, j = torch.triu_indices(n, n, offset=1, device=x.device)
        # Squared L2 distance summed directly (same value as norm(...) ** 2).
        sq_dist = (flat[i] - flat[j]).pow(2).sum(dim=1)
        hinge = torch.clamp(self.margin - sq_dist, min=0)
        pair_loss = torch.where(labels[i] == labels[j], sq_dist, hinge)

        # Mean over the N(N-1)/2 pairs equals 2 * sum / (N(N-1)).
        return pair_loss.mean()
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement