Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
- def log_likelihood(generated_set, validation_set, test_set):
- """
- Возвращает оценку логарифма правдоподобия модели GAN методом
- Парзеновского окна со стандартным нормальным ядром.
- Подсказка: sigma должна настраиваться по валидационной выборке, а
- правдоподобие считаться по тестовой.
- Подсказка: вместо sigma можно настраивать log_sigma.
- Подсказка: для настойки sigma допустимо использовать как перебор по сетке,
- так и другие методы опимизации.
- Вход: generated_set - сэмплы из генеративной модели.
- Вход: validation_set - валидационная выборка.
- Вход: test_set - тестовая выборка.
- Возвращаемое значение: float (не Tensor!) - оценка логарифма правдоподобия.
- """
- # ваш код здесь
- generated_set = generated_set.view(generated_set.shape[0], digit_size*digit_size)
- print(generated_set.shape, validation_set.shape, test_set.shape)
- D = generated_set.shape[1]
- const = -D*math.log(2 * math.pi)/2
- log_sq_sigma = torch.randn(1, device=device, requires_grad=True)
- nt = LBFGS([log_sq_sigma])
- def compute_log_density(second_set):
- a = (generated_set**2).sum(dim=1)[:, None]
- b = (second_set**2).sum(dim=1)[None, :]
- c = generated_set@second_set.transpose(0, 1)
- quad = -(a + b - 2*c)/(2*log_sq_sigma.exp())
- log_quad = log_mean_exp(quad).mean()
- log_det = -D*log_sq_sigma/2
- return log_det + log_quad
- def closure():
- nt.zero_grad()
- neg_log_density = -compute_log_density(validation_set)
- neg_log_density.backward()
- print('log_sq_sigma: ', log_sq_sigma)
- return neg_log_density
- nt.step(closure)
- log_density = compute_log_density(test_set)
- return const + log_density.detach().item()
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement