Advertisement
Guest User

Untitled

a guest
Jul 22nd, 2019
86
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 4.60 KB | None | 0 0
  1. class FocalLoss(nn.Module):
  2. #def __init__(self):
  3.  
  4. def forward(self, classifications, regressions, anchors, annotations):
  5. alpha = 0.25
  6. gamma = 2.0
  7. batch_size = classifications.shape[0]
  8. classification_losses = []
  9. regression_losses = []
  10.  
  11. anchor = anchors[0, :, :]
  12.  
  13. anchor_widths = anchor[:, 2] - anchor[:, 0]
  14. anchor_heights = anchor[:, 3] - anchor[:, 1]
  15. anchor_ctr_x = anchor[:, 0] + 0.5 * anchor_widths
  16. anchor_ctr_y = anchor[:, 1] + 0.5 * anchor_heights
  17.  
  18. for j in range(batch_size):
  19.  
  20. classification = classifications[j, :, :]
  21. regression = regressions[j, :, :]
  22.  
  23. bbox_annotation = annotations[j, :, :]
  24. bbox_annotation = bbox_annotation[bbox_annotation[:, 4] != -1]
  25.  
  26. if bbox_annotation.shape[0] == 0:
  27. regression_losses.append(torch.tensor(0).float().cuda())
  28. classification_losses.append(torch.tensor(0).float().cuda())
  29.  
  30. continue
  31.  
  32. classification = torch.clamp(classification, 1e-4, 1.0 - 1e-4)
  33.  
  34. IoU = calc_iou(anchors[0, :, :], bbox_annotation[:, :4]) # num_anchors x num_annotations
  35.  
  36. IoU_max, IoU_argmax = torch.max(IoU, dim=1) # num_anchors x 1
  37.  
  38. #import pdb
  39. #pdb.set_trace()
  40.  
  41. # compute the loss for classification
  42. targets = torch.ones(classification.shape) * -1
  43. targets = targets.cuda()
  44.  
  45. targets[torch.lt(IoU_max, 0.4), :] = 0
  46.  
  47. positive_indices = torch.ge(IoU_max, 0.5)
  48.  
  49. num_positive_anchors = positive_indices.sum()
  50.  
  51. assigned_annotations = bbox_annotation[IoU_argmax, :]
  52.  
  53. targets[positive_indices, :] = 0
  54. targets[positive_indices, assigned_annotations[positive_indices, 4].long()] = 1
  55.  
  56. alpha_factor = torch.ones(targets.shape).cuda() * alpha
  57.  
  58. alpha_factor = torch.where(torch.eq(targets, 1.), alpha_factor, 1. - alpha_factor)
  59. focal_weight = torch.where(torch.eq(targets, 1.), 1. - classification, classification)
  60. focal_weight = alpha_factor * torch.pow(focal_weight, gamma)
  61.  
  62. bce = -(targets * torch.log(classification) + (1.0 - targets) * torch.log(1.0 - classification))
  63.  
  64. # cls_loss = focal_weight * torch.pow(bce, gamma)
  65. cls_loss = focal_weight * bce
  66.  
  67. cls_loss = torch.where(torch.ne(targets, -1.0), cls_loss, torch.zeros(cls_loss.shape).cuda())
  68.  
  69. classification_losses.append(cls_loss.sum()/torch.clamp(num_positive_anchors.float(), min=1.0))
  70.  
  71. # compute the loss for regression
  72.  
  73. if positive_indices.sum() > 0:
  74. assigned_annotations = assigned_annotations[positive_indices, :]
  75.  
  76. anchor_widths_pi = anchor_widths[positive_indices]
  77. anchor_heights_pi = anchor_heights[positive_indices]
  78. anchor_ctr_x_pi = anchor_ctr_x[positive_indices]
  79. anchor_ctr_y_pi = anchor_ctr_y[positive_indices]
  80.  
  81. gt_widths = assigned_annotations[:, 2] - assigned_annotations[:, 0]
  82. gt_heights = assigned_annotations[:, 3] - assigned_annotations[:, 1]
  83. gt_ctr_x = assigned_annotations[:, 0] + 0.5 * gt_widths
  84. gt_ctr_y = assigned_annotations[:, 1] + 0.5 * gt_heights
  85.  
  86. # clip widths to 1
  87. gt_widths = torch.clamp(gt_widths, min=1)
  88. gt_heights = torch.clamp(gt_heights, min=1)
  89.  
  90. targets_dx = (gt_ctr_x - anchor_ctr_x_pi) / anchor_widths_pi
  91. targets_dy = (gt_ctr_y - anchor_ctr_y_pi) / anchor_heights_pi
  92. targets_dw = torch.log(gt_widths / anchor_widths_pi)
  93. targets_dh = torch.log(gt_heights / anchor_heights_pi)
  94.  
  95. targets = torch.stack((targets_dx, targets_dy, targets_dw, targets_dh))
  96. targets = targets.t()
  97.  
  98. targets = targets/torch.Tensor([[0.1, 0.1, 0.2, 0.2]]).cuda()
  99.  
  100.  
  101. negative_indices = 1 - positive_indices
  102.  
  103. regression_diff = torch.abs(targets - regression[positive_indices, :])
  104.  
  105. regression_loss = torch.where(
  106. torch.le(regression_diff, 1.0 / 9.0),
  107. 0.5 * 9.0 * torch.pow(regression_diff, 2),
  108. regression_diff - 0.5 / 9.0
  109. )
  110. regression_losses.append(regression_loss.mean())
  111. else:
  112. regression_losses.append(torch.tensor(0).float().cuda())
  113.  
  114. return torch.stack(classification_losses).mean(dim=0, keepdim=True), torch.stack(regression_losses).mean(dim=0, keepdim=True)
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement