Guest User

Untitled

a guest
Sep 11th, 2025
11
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 3.65 KB | None | 0 0
  1. import torch
  2. import torch.nn.functional as F
  3. from torch_frft.dfrft_module import dfrftmtx
  4.  
  5. # Глобальный кэш для матриц FRFT
  6. _frft_matrices_cache = {}
  7.  
  8. def get_frft_matrix(signal_length, order, device):
  9. """
  10. Возвращает матрицу FRFT для заданных параметров из кэша или вычисляет её.
  11. """
  12. key = (signal_length, order, device)
  13. if key not in _frft_matrices_cache:
  14. matrix = dfrftmtx(signal_length, order).to(device)
  15. _frft_matrices_cache[key] = matrix
  16. return _frft_matrices_cache[key]
  17.  
  18. def complex_mse_loss(x, y):
  19. """
  20. MSE для комплексных тензоров.
  21. """
  22. diff = x - y
  23. return torch.mean(diff.real ** 2 + diff.imag ** 2)
  24.  
  25. def amplitude_phase_loss(x, y, phase_weight=0.5):
  26. """
  27. Loss раздельно для амплитуды и фазы комплексного сигнала.
  28. """
  29. amp_loss = F.mse_loss(torch.abs(x), torch.abs(y))
  30. phase_loss = F.mse_loss(torch.angle(x), torch.angle(y))
  31. return amp_loss + phase_weight * phase_loss
  32.  
  33. def gradient_loss(pred, target):
  34. """
  35. Loss на градиентах для улучшения чёткости вывода.
  36. """
  37. pred_grad = torch.diff(pred, dim=-1)
  38. target_grad = torch.diff(target, dim=-1)
  39. return F.l1_loss(pred_grad, target_grad)
  40.  
  41. def multiscale_frft_loss(pred, target, orders=[0.3, 0.5, 0.7], phase_weight=0.5, spatial_weight=0.3, gradient_weight=0.1):
  42. """
  43. Улучшенный многомасштабный FRFT loss с регуляризацией.
  44.  
  45. Args:
  46. pred: Тензор предсказания модели
  47. target: Целевой тензор
  48. orders: Список порядков FRFT преобразования
  49. phase_weight: Вес фазовой компоненты loss
  50. spatial_weight: Вес пространственного MSE loss
  51. gradient_weight: Вес градиентной регуляризации
  52. """
  53. device = pred.device
  54. signal_length = pred.size(-1)
  55. assert pred.shape == target.shape, "Pred and target must have the same shape"
  56.  
  57. # Изменяем форму для поддержки многомерных данных
  58. original_shape = pred.shape
  59. if pred.dim() > 2:
  60. pred = pred.reshape(-1, signal_length)
  61. target = target.reshape(-1, signal_length)
  62.  
  63. # Преобразуем в комплексные числа
  64. pred_complex = pred.to(torch.complex64)
  65. target_complex = target.to(torch.complex64)
  66.  
  67. # Вычисляем компоненты loss
  68. frft_loss = 0.0
  69. spatial_loss = F.mse_loss(pred, target)
  70. grad_loss = gradient_loss(pred, target)
  71.  
  72. # Многомасштабное FRFT преобразование
  73. for order in orders:
  74. matrix = get_frft_matrix(signal_length, order, device)
  75. # Убеждаемся, что матрица тоже имеет тип complex64
  76. if not torch.is_complex(matrix):
  77. matrix = matrix.to(torch.complex64)
  78.  
  79. pred_frft = torch.matmul(matrix, pred_complex.unsqueeze(-1)).squeeze(-1)
  80. target_frft = torch.matmul(matrix, target_complex.unsqueeze(-1)).squeeze(-1)
  81. frft_loss += amplitude_phase_loss(pred_frft, target_frft, phase_weight)
  82.  
  83. frft_loss /= len(orders)
  84.  
  85. # Комбинируем все компоненты loss
  86. total_loss = (1 - spatial_weight) * frft_loss + \
  87. spatial_weight * spatial_loss + \
  88. gradient_weight * grad_loss
  89.  
  90. return total_loss
Advertisement
Add Comment
Please, Sign In to add comment