Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F
def sparse_bce_with_logits(x, i, j):
    """Mean binary cross-entropy with logits for a sparse 0/1 target.

    Equivalent to ``BCEWithLogitsLoss()(x, y)`` where ``y`` is a tensor of
    zeros with ones exactly at the positions ``(i, j)``.  Uses the standard
    numerically stable decomposition
    ``max(x, 0) - x*y + log(1 + exp(-|x|))`` averaged over all elements;
    because ``y`` is sparse, the ``x*y`` term reduces to a sum over the
    listed positive positions divided by the element count.

    Args:
        x: logit tensor (2-D, indexed as ``x[i, j]``).
        i, j: LongTensors of row/column indices of the positive targets.

    Returns:
        Scalar tensor holding the mean loss.
    """
    positive_part = x.clamp(min=0).mean()
    # Only the positive targets contribute to the x*y term.
    target_part = x[(i, j)].sum() / x.numel()
    # log(1 + exp(-|x|)) never overflows since the exponent is <= 0.
    stabilizer = torch.log(1 + torch.exp(-torch.abs(x))).mean()
    return positive_part - target_part + stabilizer
# Compare four BCE formulations against torch's reference BCEWithLogitsLoss
# while the logit magnitude grows: the sigmoid-based variants lose precision
# (and eventually NaN) once sigmoid saturates, while the logit-space ones stay
# numerically close to the reference.
loss = torch.nn.BCEWithLogitsLoss()
sloss = torch.nn.BCELoss()
all_res = []
for scale in np.arange(0, 100, 2):
    # Logits whose magnitude grows with `scale`; ~10% of targets are positive.
    logits = torch.randn((100, 10)) * scale
    probs = torch.sigmoid(logits)
    targets = (torch.rand((100, 10)) < 0.1).float()
    # Coordinates of the positive targets, as index tensors.
    rows, cols = np.where(targets.numpy())
    rows, cols = torch.LongTensor(rows), torch.LongTensor(cols)
    reference = loss(logits, targets)
    bce_sigmoid = sloss(probs, targets)
    bce_sigmoid_manual = - (targets * probs.log() + (1 - targets) * (1 - probs).log()).mean()
    bce_logit_manual = (logits.clamp(min=0) - logits * targets + torch.log(1 + torch.exp(-torch.abs(logits)))).mean()
    bce_logit_sparse = sparse_bce_with_logits(logits, rows, cols)
    # Record each variant's deviation from the reference loss.
    all_res.append({
        "sigmoid" : (reference - bce_sigmoid).item(),
        "sigmoid_manual" : (reference - bce_sigmoid_manual).item(),
        "logit_manual" : (reference - bce_logit_manual).item(),
        "logit_sparse" : (reference - bce_logit_sparse).item(),
    })
pd.DataFrame(all_res)
Add Comment
Please, Sign In to add comment