#!/usr/bin/env python2
import torch
import torch.utils.data
import torch.nn as nn
from torch.autograd import Variable

from utils import log_stdnormal, log_normal, log_mean_exp


def batchnormlayer(IN_DIM, OUT_DIM):
    # Linear layer followed by batch normalization.
    return nn.Sequential(nn.Linear(IN_DIM, OUT_DIM), nn.BatchNorm1d(OUT_DIM))


def normaldenselayer(IN_DIM, OUT_DIM):
    # Plain linear layer, used when batch normalization is disabled.
    return nn.Linear(IN_DIM, OUT_DIM)


class Model(nn.Module):
    def __init__(
        self,
        vocab_size=10000,
        latent_dim=50,
        hidden_dim=200,
        batchnorm=False,
        usecuda=True
    ):
        super(Model, self).__init__()
        # Construct the inference (encoder) and generative (decoder) networks.
        self.vocab_size = vocab_size
        self.latent_dim = latent_dim
        self.batchnorm = batchnorm
        self.usecuda = usecuda
        if self.batchnorm:
            denselayer = batchnormlayer
        else:
            denselayer = normaldenselayer
        # ==== Network ===== #
        self.hidden_dim = hidden_dim
        self.fc1 = denselayer(self.vocab_size, self.hidden_dim)
        self.fc2 = denselayer(self.hidden_dim, self.hidden_dim)
        self.fc31 = denselayer(self.hidden_dim, self.latent_dim)  # mean of q(h|X)
        self.fc32 = denselayer(self.hidden_dim, self.latent_dim)  # log-variance of q(h|X)
        self.fc4 = denselayer(self.latent_dim, self.hidden_dim)
        self.fc5 = denselayer(self.hidden_dim, self.vocab_size)
        self.tanh = nn.Tanh()
        self.relu = nn.ReLU()
        self.sigmoid = nn.Sigmoid()
        self.softmax = nn.Softmax(dim=-1)  # softmax over the vocabulary dimension

    def encode(self, x):
        ''' Inference network q(h|X): returns the mean and log-variance of the Gaussian posterior. '''
        h1 = self.relu(self.fc1(x))
        h2 = self.relu(self.fc2(h1))
        return self.fc31(h2), self.fc32(h2)

    def reparametrize(self, mu, logvar, IW=1):
        # Reparameterization trick: draw IW samples per document;
        # eps has shape (IW, batch_size, latent_dim).
        std = logvar.mul(0.5).exp_()
        if self.usecuda:
            eps = torch.cuda.FloatTensor(IW, std.size(0), std.size(1)).normal_()
        else:
            eps = torch.FloatTensor(IW, std.size(0), std.size(1)).normal_()
        eps = Variable(eps)
        return eps.mul(std.expand_as(eps)).add_(mu.expand_as(eps))

    def decode(self, z):
        # Generative network p(X|h): a distribution over the vocabulary.
        h4 = self.relu(self.fc4(z))
        return self.softmax(self.fc5(h4))

    def forward(self, x, IW=1):
        # Returns z (IW, batch, latent_dim), recon_x (IW, batch, vocab_size), mu, logvar.
        mu, logvar = self.encode(x.view(-1, self.vocab_size))
        z = self.reparametrize(mu, logvar, IW)
        return z, self.decode(z.view(-1, self.latent_dim)).view(
            IW, -1, self.vocab_size
        ), mu, logvar

    def components(self):
        # Decode the identity matrix: one vocabulary distribution per latent dimension.
        Ik = Variable(torch.eye(self.latent_dim), requires_grad=False)
        if self.usecuda:
            Ik = Ik.cuda()
        comps = self.decode(Ik)
        return comps
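

# --- Usage sketch (illustrative, not part of the original paste) ---
# A minimal smoke test of the model on CPU, assuming a toy vocabulary size and
# random bag-of-words input; the helper name `_example_forward` is hypothetical.
def _example_forward():
    model = Model(vocab_size=2000, latent_dim=50, hidden_dim=200,
                  batchnorm=False, usecuda=False)
    x = Variable(torch.rand(8, 2000))  # batch of 8 bag-of-words vectors
    z, recon_x, mu, logvar = model(x, IW=1)
    # z: (1, 8, 50), recon_x: (1, 8, 2000), mu/logvar: (8, 50)
    return z, recon_x, mu, logvar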


def nvdm_loss_function(recon_x, x, mu, logvar, temp=1):
    # temp is the temperature for KL annealing.
    # Reconstruction term: negative log-likelihood of the observed word counts
    # under the decoder's distribution over the vocabulary.
    eps = 1e-9
    BCE = -(x * torch.log(recon_x.squeeze() + eps))
    BCE = BCE.sum()
    # KL divergence, see Appendix B of the VAE paper:
    # Kingma and Welling. Auto-Encoding Variational Bayes. ICLR, 2014.
    # https://arxiv.org/abs/1312.6114
    # KLD = -0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
    KLD_element = mu.pow(2).add_(logvar.exp()).mul_(-1).add_(1).add_(logvar)
    KLD = torch.sum(KLD_element).mul_(-0.5)
    return BCE + temp * KLD


def sum_log_perplexity(recon_x, x, mu, logvar):
    # Assumes x has shape [1, input_dim]; 0.5 is used as a pseudo-count per document.
    return nvdm_loss_function(recon_x, x, mu, logvar, temp=1) / (x.sum() + 0.5)
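

# --- Usage sketch (illustrative, not part of the original paste) ---
# Shows how a single-sample (IW=1) forward pass feeds the ELBO loss and the
# per-document perplexity above; `_example_elbo_loss` is a hypothetical helper name.
def _example_elbo_loss():
    model = Model(vocab_size=2000, usecuda=False)
    # ELBO loss over a batch of documents.
    x = Variable(torch.rand(8, 2000))
    z, recon_x, mu, logvar = model(x, IW=1)
    loss = nvdm_loss_function(recon_x, x, mu, logvar, temp=1)
    # Perplexity term for a single document (shape [1, vocab_size]).
    x1 = Variable(torch.rand(1, 2000))
    z1, recon_x1, mu1, logvar1 = model(x1, IW=1)
    perp = sum_log_perplexity(recon_x1, x1, mu1, logvar1)
    return loss, perp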


def nvdm_iw_loss(z, recon_x, x, mu, logvar):
    # Importance-weighted bound over IW samples per document.
    # z:       (iw, bs, zdim)
    # recon_x: (iw, bs, xdim)
    # x:       (bs, xdim)
    # mu:      (bs, zdim)
    # logvar:  (bs, zdim)
    eps = 1e-9
    log_pz = log_stdnormal(z).sum(-1)
    log_px_given_z = (x.expand_as(recon_x) * torch.log(recon_x + eps)).sum(-1)
    log_qz_given_x = log_normal(z, mu.expand_as(z), logvar.expand_as(z)).sum(-1)
    LL = log_mean_exp(log_pz + log_px_given_z - log_qz_given_x, dim=0)
    return -torch.sum(LL)
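

# --- Usage sketch (illustrative, not part of the original paste) ---
# Importance-weighted bound with IW > 1 samples per document. This assumes the
# local `utils` module (log_stdnormal, log_normal, log_mean_exp) is importable;
# `_example_iw_loss` is a hypothetical helper name.
def _example_iw_loss():
    model = Model(vocab_size=2000, usecuda=False)
    x = Variable(torch.rand(8, 2000))
    z, recon_x, mu, logvar = model(x, IW=5)  # 5 importance samples per document
    return nvdm_iw_loss(z, recon_x, x, mu, logvar)


if __name__ == '__main__':
    # Hypothetical smoke test; only meaningful with the `utils` helpers available.
    print(_example_forward()[1].size())
    print(_example_elbo_loss())
    print(_example_iw_loss())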