Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
- # === FFT (FRFT) loss ===
- def fft_loss(model_pred, target, timesteps, alpha_min=0.0, alpha_max=2.0, eps=1e-6):
- """
- FRFT-Charbonnier loss с адаптивным alpha по timestep.
- Args:
- model_pred: предсказание модели (тензор)
- target: целевое изображение (тензор)
- timesteps: тензор текущих шагов диффузии
- alpha_min: минимальный угол FRFT
- alpha_max: максимальный угол FRFT
- eps: числовой стабилизатор для Charbonnier
- Returns:
- усреднённый loss
- """
- # нормализуем timesteps к [0,1] и масштабируем в диапазон [alpha_min, alpha_max]
- t_norm = timesteps.float() / timesteps.max()
- # alpha для каждого примера в батче
- alphas = alpha_min + t_norm * (alpha_max - alpha_min)
- # считаем DFRFT по каждому alpha в батче
- losses = []
- for i, a in enumerate(alphas):
- pred_frft = dfrft(model_pred[i:i+1].float(), a.item())
- tgt_frft = dfrft(target[i:i+1].float(), a.item())
- diff = pred_frft - tgt_frft
- loss = torch.sqrt(diff.real**2 + diff.imag**2 + eps**2)
- losses.append(loss.mean())
- # усредняем loss по батчу
- return torch.stack(losses).mean()
- Loss функция:
- elif loss_type == "fft":
- loss = fft_loss(model_pred, target, timesteps)
Advertisement
Add Comment
Please, Sign In to add comment