Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
- import math
- import torch
- import torch.nn.functional as F
- from pytorch_wavelets import DWTForward
- from torch_frft.dfrft_module import dfrft
- # FFT
def mse_complex(x, y):
    """Return the mean squared magnitude of the complex difference ``x - y``.

    Computed element-wise as ``re**2 + im**2`` and then averaged, i.e.
    ``mean(|x - y|**2)`` without the square root inside ``abs``.

    NOTE(review): the difference is expected to be a complex tensor; torch
    raises on ``.imag`` of real-dtype tensors.
    """
    delta = x - y
    squared_magnitude = delta.real.square() + delta.imag.square()
    return squared_magnitude.mean()
def fft_loss(model_pred, target, alphas=(1.0,)):
    """Average complex MSE between fractional Fourier transforms of two tensors.

    For each order ``a`` in ``alphas`` both inputs (cast to float32) are
    transformed with ``dfrft(..., a)`` and compared with ``mse_complex``.
    The result is the unweighted mean over all orders.

    NOTE(review): ``torch_frft.dfrft`` transforms along a single dimension
    (presumably the last one by default) — confirm that is the intended axis.

    Args:
        model_pred: Prediction tensor.
        target: Ground-truth tensor, same shape as ``model_pred``.
        alphas: Iterable of transform orders, or a single number. Order 1.0
            presumably corresponds to an ordinary DFT — verify in torch_frft.

    Returns:
        Scalar tensor with the averaged loss.

    Raises:
        ValueError: If ``alphas`` is empty.
    """
    # Accept a bare scalar order as well as a sequence of orders.
    if isinstance(alphas, (int, float)):
        alphas = (alphas,)
    orders = list(alphas)
    if not orders:
        raise ValueError("fft_loss requires at least one transform order in `alphas`")
    # Cast once rather than once per order.
    pred32 = model_pred.float()
    target32 = target.float()
    losses = [mse_complex(dfrft(pred32, a), dfrft(target32, a)) for a in orders]
    return sum(losses) / len(losses)
- # WVT
def compute_weights(timestep, width, shift):
    """Sample a unit-peak Gaussian over the 0..1000 axis at six fixed points.

    The Gaussian is centred at ``mu = min(1000, 1000 - timestep + shift)``,
    has standard deviation ``width`` and peak value 1.0. It is evaluated at
    ``x_i = i * 1000 / 999`` for ``i`` in ``(0, 200, 400, 600, 800, 999)``.
    These are exactly the samples ``torch.linspace(0, 1000, 1000)[i]`` that
    an on-device version would read back, so index 0 is x = 0 and index 5
    is x = 1000.

    The computation is done on the host in float64. No CUDA device is
    required, and no tensor allocation or device sync happens. Results can
    differ from a float32 GPU evaluation only at the ~1e-7 relative level.

    Args:
        timestep: Diffusion timestep; a Python number or a one-element tensor.
        width: Standard deviation of the Gaussian (must be non-zero).
        shift: Offset added to the reflected timestep before clamping at 1000.

    Returns:
        List of six Python floats.

    Raises:
        ValueError: If ``width`` is zero.
    """
    if width == 0:
        raise ValueError("compute_weights requires a non-zero `width`")
    # float() accepts ints, floats and one-element tensors on any device.
    mu = min(1000.0, 1000.0 - float(timestep) + float(shift))
    two_sigma_sq = 2.0 * float(width) ** 2
    return [
        # i * 1000 / 999 is the linspace(0, 1000, 1000) grid point at index i.
        math.exp(-((i * 1000.0 / 999.0 - mu) ** 2) / two_sigma_sq)
        for i in (0, 200, 400, 600, 800, 999)
    ]
def wavelet_loss(model_pred, target, timesteps):
    """Timestep-weighted MSE between multi-level Haar wavelet decompositions.

    Both tensors are decomposed with a ``num_levels``-level 2-D Haar DWT
    (``DWTForward``, zero padding). ``num_levels = ceil(log2(max(H, W)))``,
    clamped to at least 1. Each of the ``num_levels`` high-pass bands and
    the final low-pass band is compared with an MSE averaged over dims
    2, 3 and 4. That gives one value per (sample, channel), which is scaled
    by a per-sample weight taken from ``compute_weights(timesteps[b], 350, 0)``.

    Weight indices: the finest high-pass band uses
    ``min(num_levels, len(weights) - 1)``, coarser bands use
    successively smaller indices, and the low-pass band uses index 0.

    Args:
        model_pred: Tensor whose dims 2 and 3 are treated as spatial; presumably
            (batch, channels, H, W) — confirm against callers.
        target: Tensor with the same shape as ``model_pred``.
        timesteps: Indexable with at least one timestep per batch element.

    Returns:
        Scalar tensor: mean over bands, samples and channels of the weighted MSE.

    Raises:
        ValueError: If ``timesteps`` is None or the batch is empty.
    """
    if timesteps is None:
        raise ValueError("Wavelet loss requires `timesteps`")
    batch = model_pred.shape[0]
    if batch == 0:
        raise ValueError("Wavelet loss requires a non-empty batch")

    pred = model_pred.float()
    tgt = target.float()
    # At least one level, so a 1x1 input still yields a valid transform.
    num_levels = max(1, math.ceil(math.log2(max(pred.shape[2], pred.shape[3]))))
    # The DWT acts on each sample independently, so a single batched call
    # replaces building and applying a new module for every sample.
    dwt = DWTForward(J=num_levels, mode="zero", wave="haar").to(device=pred.device, dtype=torch.float)
    pred_low, pred_high = dwt(pred)
    tgt_low, tgt_high = dwt(tgt)

    # (batch, n_weights) table of per-sample level weights.
    weights = torch.tensor(
        [compute_weights(timesteps[b], 350, 0) for b in range(batch)],
        device=pred.device,
        dtype=torch.float,
    )
    max_idx = weights.shape[1] - 1

    # High-pass bands first, then the low-pass band. unsqueeze(2) adds a
    # singleton band dim so mean([2, 3, 4]) applies uniformly.
    pred_bands = list(pred_high) + [pred_low.unsqueeze(2)]
    tgt_bands = list(tgt_high) + [tgt_low.unsqueeze(2)]

    loss_levels = []
    # enumerate covers all num_levels + 1 bands. Pairing with
    # range(1, num_levels + 1) made zip() silently drop the low-pass band.
    for w, (p, t) in enumerate(zip(pred_bands, tgt_bands), start=1):
        weight_idx = min(num_levels + 1 - w, max_idx)
        per_channel = F.mse_loss(p, t, reduction="none").mean([2, 3, 4])  # (batch, C)
        loss_levels.append(weights[:, weight_idx : weight_idx + 1] * per_channel)
    return torch.stack(loss_levels, dim=0).mean()
- # COMBINE
def combined_loss(model_pred, target, timesteps, alpha=0.5):
    """Blend the fractional-Fourier loss with the wavelet loss.

    Returns ``alpha * fft_loss(...) + (1 - alpha) * wavelet_loss(...)``.
    Both terms are always evaluated, FFT term first. ``fft_loss`` is called
    with its default ``alphas``.
    """
    fft_weight = alpha
    wavelet_weight = 1 - alpha
    spectral_term = fft_loss(model_pred, target)
    multiscale_term = wavelet_loss(model_pred, target, timesteps)
    return fft_weight * spectral_term + wavelet_weight * multiscale_term
Advertisement
Add Comment
Please, Sign In to add comment