Guest User

Untitled

a guest
Feb 16th, 2019
105
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 1.17 KB | None | 0 0
import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F
  3.  
  4. def sparse_bce_with_logits(x, i, j):
  5. t1 = x.clamp(min=0).mean()
  6. t2 = - x[(i, j)].sum() / x.numel()
  7. t3 = torch.log(1 + torch.exp(-torch.abs(x))).mean()
  8.  
  9. return t1 + t2 + t3
  10.  
  11. loss = torch.nn.BCEWithLogitsLoss()
  12. sloss = torch.nn.BCELoss()
  13.  
  14. all_res = []
  15. for scale in np.arange(0, 100, 2):
  16. x = torch.randn((100, 10)) * scale
  17. sx = torch.sigmoid(x)
  18. y = (torch.rand((100, 10)) < 0.1).float()
  19.  
  20. i, j = np.where(y.numpy())
  21. i, j = torch.LongTensor(i), torch.LongTensor(j)
  22.  
  23. bce_logits = loss(x, y)
  24. bce_sigmoid = sloss(sx, y)
  25. bce_sigmoid_manual = - (y * sx.log() + (1 - y) * (1 - sx).log()).mean()
  26. bce_logit_manual = (x.clamp(min=0) - x * y + torch.log(1 + torch.exp(-torch.abs(x)))).mean()
  27. bce_logit_sparse = sparse_bce_with_logits(x, i, j)
  28.  
  29. res = {
  30. "sigmoid" : (bce_logits - bce_sigmoid).item(),
  31. "sigmoid_manual" : (bce_logits - bce_sigmoid_manual).item(),
  32. "logit_manual" : (bce_logits - bce_logit_manual).item(),
  33. "logit_sparse" : (bce_logits - bce_logit_sparse).item(),
  34. }
  35. all_res.append(res)
  36.  
  37. pd.DataFrame(all_res)
Add Comment
Please, Sign In to add comment