Guest User

Untitled

a guest
Sep 30th, 2025
25
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 1.52 KB | None | 0 0
  1. # === FFT (FRFT) loss ===
  2. def fft_loss(model_pred, target, timesteps, alpha_min=0.0, alpha_max=2.0, eps=1e-6):
  3. """
  4. FRFT-Charbonnier loss с адаптивным alpha по timestep.
  5.  
  6. Args:
  7. model_pred: предсказание модели (тензор)
  8. target: целевое изображение (тензор)
  9. timesteps: тензор текущих шагов диффузии
  10. alpha_min: минимальный угол FRFT
  11. alpha_max: максимальный угол FRFT
  12. eps: числовой стабилизатор для Charbonnier
  13. Returns:
  14. усреднённый loss
  15. """
  16. # нормализуем timesteps к [0,1] и масштабируем в диапазон [alpha_min, alpha_max]
  17. t_norm = timesteps.float() / timesteps.max()
  18. # alpha для каждого примера в батче
  19. alphas = alpha_min + t_norm * (alpha_max - alpha_min)
  20.  
  21. # считаем DFRFT по каждому alpha в батче
  22. losses = []
  23. for i, a in enumerate(alphas):
  24. pred_frft = dfrft(model_pred[i:i+1].float(), a.item())
  25. tgt_frft = dfrft(target[i:i+1].float(), a.item())
  26. diff = pred_frft - tgt_frft
  27. loss = torch.sqrt(diff.real**2 + diff.imag**2 + eps**2)
  28. losses.append(loss.mean())
  29.  
  30. # усредняем loss по батчу
  31. return torch.stack(losses).mean()
  32.  
  33.  
  34. Loss функция:
  35. elif loss_type == "fft":
  36. loss = fft_loss(model_pred, target, timesteps)
Advertisement
Add Comment
Please, Sign In to add comment