Guest User

Untitled

a guest
Sep 12th, 2025
28
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 2.22 KB | None | 0 0
  1. import math
  2. import torch
  3. import torch.nn.functional as F
  4. from pytorch_wavelets import DWTForward
  5. from torch_frft.dfrft_module import dfrft
  6.  
  7.  
  8. # FFT
  9. def mse_complex(x, y):
  10. diff = x - y
  11. return torch.mean(diff.real ** 2 + diff.imag ** 2)
  12.  
  13.  
  14. def fft_loss(model_pred, target, alphas=[1.0]):
  15. losses = []
  16. for a in alphas:
  17. losses.append(mse_complex(dfrft(model_pred.float(), a), dfrft(target.float(), a)))
  18. return sum(losses) / len(losses)
  19.  
  20.  
  21. # WVT
  22. def compute_weights(timestep, width, shift):
  23. timestep = min(1000, 1000 - timestep + shift)
  24. x = torch.linspace(0, 1000, 1000).to("cuda")
  25. func = lambda x, mu, sigma, amplitude: amplitude * torch.exp(-(x - mu) ** 2 / (2 * sigma ** 2))
  26. y = func(x, mu=timestep, sigma=width, amplitude=1.0)
  27. weights = [y[i].item() for i in [0, 200, 400, 600, 800, 999]]
  28. return weights
  29.  
  30.  
  31. def wavelet_loss(model_pred, target, timesteps):
  32. if timesteps is None:
  33. raise ValueError("Wavelet loss requires `timesteps`")
  34.  
  35. loss_levels = []
  36. for b in range(model_pred.shape[0]):
  37. pred = model_pred[b, ...].unsqueeze(0)
  38. tgt = target[b, ...].unsqueeze(0)
  39.  
  40. num_levels = math.ceil(math.log2(max(pred.shape[2], pred.shape[3])))
  41. dwt = DWTForward(J=num_levels, mode="zero", wave="haar").to(device=pred.device, dtype=torch.float)
  42.  
  43. model_pred_xl, model_pred_xh = dwt(pred.float())
  44. target_xl, target_xh = dwt(tgt.float())
  45.  
  46. model_pred_l0 = model_pred_xl.unsqueeze(2)
  47. target_l0 = target_xl.unsqueeze(2)
  48.  
  49. level_weights = compute_weights(timesteps[b], 350, 0)
  50.  
  51. for p, t, w in zip(model_pred_xh + [model_pred_l0], target_xh + [target_l0], range(1, num_levels + 1)):
  52. l = F.mse_loss(p.float(), t.float(), reduction="none")
  53. weight_idx = min(num_levels + 1 - w, len(level_weights) - 1)
  54. loss_levels.append(level_weights[weight_idx] * l.mean([2, 3, 4]))
  55.  
  56. return torch.stack(loss_levels, dim=0).mean()
  57.  
  58.  
  59. # COMBINE
  60. def combined_loss(model_pred, target, timesteps, alpha=0.5):
  61. loss_fft = fft_loss(model_pred, target)
  62. loss_wavelet = wavelet_loss(model_pred, target, timesteps)
  63. return alpha * loss_fft + (1 - alpha) * loss_wavelet
Advertisement
Add Comment
Please, Sign In to add comment