Advertisement
Guest User

Untitled

a guest
Dec 11th, 2018
83
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 2.18 KB | None | 0 0
  1. def log_likelihood(generated_set, validation_set, test_set):
  2.     """
  3.    Возвращает оценку логарифма правдоподобия модели GAN методом
  4.    Парзеновского окна со стандартным нормальным ядром.
  5.    Подсказка: sigma должна настраиваться по валидационной выборке, а
  6.    правдоподобие считаться по тестовой.
  7.    Подсказка: вместо sigma можно настраивать log_sigma.
  8.    Подсказка: для настойки sigma допустимо использовать как перебор по сетке,
  9.    так и другие методы опимизации.
  10.    Вход: generated_set - сэмплы из генеративной модели.
  11.    Вход: validation_set - валидационная выборка.
  12.    Вход: test_set - тестовая выборка.
  13.    Возвращаемое значение: float (не Tensor!) - оценка логарифма правдоподобия.
  14.    """
  15.     # ваш код здесь
  16.     generated_set = generated_set.view(generated_set.shape[0], digit_size*digit_size)
  17.     print(generated_set.shape, validation_set.shape, test_set.shape)
  18.     D = generated_set.shape[1]
  19.     const = -D*math.log(2 * math.pi)/2
  20.     log_sq_sigma = torch.randn(1, device=device, requires_grad=True)
  21.     nt = LBFGS([log_sq_sigma])
  22.  
  23.     def compute_log_density(second_set):
  24.         a = (generated_set**2).sum(dim=1)[:, None]
  25.         b = (second_set**2).sum(dim=1)[None, :]
  26.         c = generated_set@second_set.transpose(0, 1)
  27.         quad = -(a + b - 2*c)/(2*log_sq_sigma.exp())
  28.         log_quad = log_mean_exp(quad).mean()
  29.         log_det = -D*log_sq_sigma/2
  30.         return log_det + log_quad
  31.        
  32.     def closure():
  33.         nt.zero_grad()
  34.         neg_log_density = -compute_log_density(validation_set)
  35.         neg_log_density.backward()
  36.         print('log_sq_sigma: ', log_sq_sigma)
  37.         return neg_log_density
  38.    
  39.     nt.step(closure)
  40.     log_density = compute_log_density(test_set)
  41.     return const + log_density.detach().item()
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement