Advertisement
Mbxvim

pasta

Sep 23rd, 2022 (edited)
942
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 7.58 KB | None | 0 0
  1. import math
  2. import numpy as np
  3. from IPython.display import clear_output
  4. from tqdm import tqdm_notebook as tqdm
  5.  
  6. import matplotlib as mpl
  7. import matplotlib.pyplot as plt
  8. %matplotlib inline
  9. import seaborn as sns
  10. sns.color_palette("bright")
  11. import matplotlib as mpl
  12. import matplotlib.cm as cm
  13.  
  14. import torch
  15. from torch import Tensor
  16. from torch import nn
  17. from torch.nn  import functional as F
  18. from torch.autograd import Variable
  19.  
  20. use_cuda = torch.cuda.is_available()
  21. def ode_solve(z0, t0, t1, f):
  22.     """
  23.    Простейший метод эволюции ОДУ - метод Эйлера
  24.    """
  25.     h_max = 0.05
  26.     n_steps = math.ceil((abs(t1 - t0)/h_max).max().item())
  27.  
  28.     h = (t1 - t0)/n_steps
  29.     t = t0
  30.     z = z0
  31.  
  32.     for i_step in range(n_steps):
  33.         z = z + h * f(z, t)
  34.         t = t + h
  35.     return z
  36. class ODEF(nn.Module):
  37.     def forward_with_grad(self, z, t, grad_outputs):
  38.         """Compute f and a df/dz, a df/dp, a df/dt"""
  39.         batch_size = z.shape[0]
  40.  
  41.         out = self.forward(z, t)
  42.  
  43.         a = grad_outputs
  44.         adfdz, adfdt, *adfdp = torch.autograd.grad(
  45.             (out,), (z, t) + tuple(self.parameters()), grad_outputs=(a),
  46.             allow_unused=True, retain_graph=True
  47.         )
  48.         # метод grad автоматически суммирует градиенты для всех элементов батча,
  49.         # надо expand их обратно
  50.         if adfdp is not None:
  51.             adfdp = torch.cat([p_grad.flatten() for p_grad in adfdp]).unsqueeze(0)
  52.             adfdp = adfdp.expand(batch_size, -1) / batch_size
  53.         if adfdt is not None:
  54.             adfdt = adfdt.expand(batch_size, 1) / batch_size
  55.         return out, adfdz, adfdt, adfdp
  56.  
  57.     def flatten_parameters(self):
  58.         p_shapes = []
  59.         flat_parameters = []
  60.         for p in self.parameters():
  61.             p_shapes.append(p.size())
  62.             flat_parameters.append(p.flatten())
  63.         return torch.cat(flat_parameters)
  64. class ODEAdjoint(torch.autograd.Function):
  65.     @staticmethod
  66.     def forward(ctx, z0, t, flat_parameters, func):
  67.         assert isinstance(func, ODEF)
  68.         bs, *z_shape = z0.size()
  69.         time_len = t.size(0)
  70.  
  71.         with torch.no_grad():
  72.             z = torch.zeros(time_len, bs, *z_shape).to(z0)
  73.             z[0] = z0
  74.             for i_t in range(time_len - 1):
  75.                 z0 = ode_solve(z0, t[i_t], t[i_t+1], func)
  76.                 z[i_t+1] = z0
  77.  
  78.         ctx.func = func
  79.         ctx.save_for_backward(t, z.clone(), flat_parameters)
  80.         return z
  81.  
  82.     @staticmethod
  83.     def backward(ctx, dLdz):
  84.         """
  85.        dLdz shape: time_len, batch_size, *z_shape
  86.        """
  87.         func = ctx.func
  88.         t, z, flat_parameters = ctx.saved_tensors
  89.         time_len, bs, *z_shape = z.size()
  90.         n_dim = np.prod(z_shape)
  91.         n_params = flat_parameters.size(0)
  92.  
  93.         # Динамика аугментированной системы,
  94.         # которую надо эволюционировать обратно во времени
  95.         def augmented_dynamics(aug_z_i, t_i):
  96.             """
  97.            Тензоры здесь - это срезы по времени
  98.            t_i - тензор с размерами: bs, 1
  99.            aug_z_i - тензор с размерами: bs, n_dim*2 + n_params + 1
  100.            """
  101.             # игнорируем параметры и время
  102.             z_i, a = aug_z_i[:, :n_dim], aug_z_i[:, n_dim:2*n_dim]  
  103.             # Unflatten z and a
  104.             z_i = z_i.view(bs, *z_shape)
  105.             a = a.view(bs, *z_shape)
  106.             with torch.set_grad_enabled(True):
  107.                 t_i = t_i.detach().requires_grad_(True)
  108.                 z_i = z_i.detach().requires_grad_(True)
  109.  
  110.                 faug = func.forward_with_grad(z_i, t_i, grad_outputs=a)
  111.                 func_eval, adfdz, adfdt, adfdp = faug
  112.  
  113.                 adfdz = adfdz if adfdz is not None else torch.zeros(bs, *z_shape)
  114.                 adfdp = adfdp if adfdp is not None else torch.zeros(bs, n_params)
  115.                 adfdt = adfdt if adfdt is not None else torch.zeros(bs, 1)
  116.                 adfdz = adfdz.to(z_i)
  117.                 adfdp = adfdp.to(z_i)
  118.                 adfdt = adfdt.to(z_i)
  119.  
  120.             # Flatten f and adfdz
  121.             func_eval = func_eval.view(bs, n_dim)
  122.             adfdz = adfdz.view(bs, n_dim)
  123.             return torch.cat((func_eval, -adfdz, -adfdp, -adfdt), dim=1)
  124.  
  125.         dLdz = dLdz.view(time_len, bs, n_dim)  # flatten dLdz для удобства
  126.         with torch.no_grad():
  127.             ## Создадим плейсхолдеры для возвращаемых градиентов
  128.             # Распространенные назад сопряженные состояния,
  129.             # которые надо поправить градиентами от наблюдений
  130.             adj_z = torch.zeros(bs, n_dim).to(dLdz)
  131.             adj_p = torch.zeros(bs, n_params).to(dLdz)
  132.             # В отличие от z и p, нужно вернуть градиенты для всех моментов времени
  133.             adj_t = torch.zeros(time_len, bs, 1).to(dLdz)
  134.  
  135.             for i_t in range(time_len-1, 0, -1):
  136.                 z_i = z[i_t]
  137.                 t_i = t[i_t]
  138.                 f_i = func(z_i, t_i).view(bs, n_dim)
  139.  
  140.                 # Рассчитаем прямые градиенты от наблюдений
  141.                 dLdz_i = dLdz[i_t]
  142.                 dLdt_i = torch.bmm(torch.transpose(dLdz_i.unsqueeze(-1), 1, 2),
  143.                                    f_i.unsqueeze(-1))[:, 0]
  144.  
  145.                 # Подправим ими сопряженные состояния
  146.                 adj_z += dLdz_i
  147.                 adj_t[i_t] = adj_t[i_t] - dLdt_i
  148.  
  149.                 # Упакуем аугментированные переменные в вектор
  150.                 aug_z = torch.cat((
  151.                     z_i.view(bs, n_dim),
  152.                     adj_z, torch.zeros(bs, n_params).to(z)
  153.                     adj_t[i_t]),
  154.                     dim=-1
  155.                 )
  156.  
  157.                 # Решим (эволюционируем) аугментированную систему назад во времени
  158.                 aug_ans = ode_solve(aug_z, t_i, t[i_t-1], augmented_dynamics)
  159.  
  160.                 # Распакуем переменные обратно из решенной системы
  161.                 adj_z[:] = aug_ans[:, n_dim:2*n_dim]
  162.                 adj_p[:] += aug_ans[:, 2*n_dim:2*n_dim + n_params]
  163.                 adj_t[i_t-1] = aug_ans[:, 2*n_dim + n_params:]
  164.  
  165.                 del aug_z, aug_ans
  166.  
  167.             ## Подправим сопряженное состояние в нулевой момент прямыми градиентами
  168.             # Вычислим прямые градиенты
  169.             dLdz_0 = dLdz[0]
  170.             dLdt_0 = torch.bmm(torch.transpose(dLdz_0.unsqueeze(-1), 1, 2),
  171.                                 f_i.unsqueeze(-1))[:, 0]
  172.  
  173.             # Подправим
  174.             adj_z += dLdz_0
  175.             adj_t[0] = adj_t[0] - dLdt_0
  176.         return adj_z.view(bs, *z_shape), adj_t, adj_p, None
  177. class NeuralODE(nn.Module):
  178.     def __init__(self, func):
  179.         super(NeuralODE, self).__init__()
  180.         assert isinstance(func, ODEF)
  181.         self.func = func
  182.  
  183.     def forward(self, z0, t=Tensor([0., 1.]), return_whole_sequence=False):
  184.         t = t.to(z0)
  185.         z = ODEAdjoint.apply(z0, t, self.func.flatten_parameters(), self.func)
  186.         if return_whole_sequence:
  187.             return z
  188.         else:
  189.             return z[-1]
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement