Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
- import torch
- import torch.nn.functional as F
- from torch_frft.dfrft_module import dfrftmtx
- # Глобальный кэш для матриц FRFT
- _frft_matrices_cache = {}
- def get_frft_matrix(signal_length, order, device):
- """
- Возвращает матрицу FRFT для заданных параметров из кэша или вычисляет её.
- """
- key = (signal_length, order, device)
- if key not in _frft_matrices_cache:
- matrix = dfrftmtx(signal_length, order).to(device)
- _frft_matrices_cache[key] = matrix
- return _frft_matrices_cache[key]
- def complex_mse_loss(x, y):
- """
- MSE для комплексных тензоров.
- """
- diff = x - y
- return torch.mean(diff.real ** 2 + diff.imag ** 2)
- def amplitude_phase_loss(x, y, phase_weight=0.5):
- """
- Loss раздельно для амплитуды и фазы комплексного сигнала.
- """
- amp_loss = F.mse_loss(torch.abs(x), torch.abs(y))
- phase_loss = F.mse_loss(torch.angle(x), torch.angle(y))
- return amp_loss + phase_weight * phase_loss
- def gradient_loss(pred, target):
- """
- Loss на градиентах для улучшения чёткости вывода.
- """
- pred_grad = torch.diff(pred, dim=-1)
- target_grad = torch.diff(target, dim=-1)
- return F.l1_loss(pred_grad, target_grad)
- def multiscale_frft_loss(pred, target, orders=[0.3, 0.5, 0.7], phase_weight=0.5, spatial_weight=0.3, gradient_weight=0.1):
- """
- Улучшенный многомасштабный FRFT loss с регуляризацией.
- Args:
- pred: Тензор предсказания модели
- target: Целевой тензор
- orders: Список порядков FRFT преобразования
- phase_weight: Вес фазовой компоненты loss
- spatial_weight: Вес пространственного MSE loss
- gradient_weight: Вес градиентной регуляризации
- """
- device = pred.device
- signal_length = pred.size(-1)
- assert pred.shape == target.shape, "Pred and target must have the same shape"
- # Изменяем форму для поддержки многомерных данных
- original_shape = pred.shape
- if pred.dim() > 2:
- pred = pred.reshape(-1, signal_length)
- target = target.reshape(-1, signal_length)
- # Преобразуем в комплексные числа
- pred_complex = pred.to(torch.complex64)
- target_complex = target.to(torch.complex64)
- # Вычисляем компоненты loss
- frft_loss = 0.0
- spatial_loss = F.mse_loss(pred, target)
- grad_loss = gradient_loss(pred, target)
- # Многомасштабное FRFT преобразование
- for order in orders:
- matrix = get_frft_matrix(signal_length, order, device)
- # Убеждаемся, что матрица тоже имеет тип complex64
- if not torch.is_complex(matrix):
- matrix = matrix.to(torch.complex64)
- pred_frft = torch.matmul(matrix, pred_complex.unsqueeze(-1)).squeeze(-1)
- target_frft = torch.matmul(matrix, target_complex.unsqueeze(-1)).squeeze(-1)
- frft_loss += amplitude_phase_loss(pred_frft, target_frft, phase_weight)
- frft_loss /= len(orders)
- # Комбинируем все компоненты loss
- total_loss = (1 - spatial_weight) * frft_loss + \
- spatial_weight * spatial_loss + \
- gradient_weight * grad_loss
- return total_loss
Advertisement
Add Comment
Please, Sign In to add comment