Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
- def conditional_loss(
- model_pred: torch.Tensor, target: torch.Tensor, loss_type: str, reduction: str, huber_c: Optional[torch.Tensor] = None, timesteps: Optional[torch.Tensor] = None
- ):
- """
- NOTE: if you're using the scheduled version, huber_c has to depend on the timesteps already
- """
- if loss_type == "l2":
- loss = torch.nn.functional.mse_loss(model_pred, target, reduction=reduction)
- elif loss_type == "fft":
- loss = mse_complex(dfrft(model_pred.float(), 0.5), dfrft(target.float(), 0.5))
Advertisement
Add Comment
Please, Sign In to add comment