Guest User

Untitled

a guest
Sep 1st, 2025
12
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 0.52 KB | None | 0 0
  1. def conditional_loss(
  2. model_pred: torch.Tensor, target: torch.Tensor, loss_type: str, reduction: str, huber_c: Optional[torch.Tensor] = None, timesteps: Optional[torch.Tensor] = None
  3. ):
  4. """
  5. NOTE: if you're using the scheduled version, huber_c has to depend on the timesteps already
  6. """
  7. if loss_type == "l2":
  8. loss = torch.nn.functional.mse_loss(model_pred, target, reduction=reduction)
  9. elif loss_type == "fft":
  10. loss = mse_complex(dfrft(model_pred.float(), 0.5), dfrft(target.float(), 0.5))
Advertisement
Add Comment
Please, Sign In to add comment