#!/usr/bin/env python2
import torch
import torch.utils.data
import torch.nn as nn
from torch.autograd import Variable
from utils import log_stdnormal, log_normal, log_mean_exp


def batchnormlayer(IN_DIM, OUT_DIM):
    return nn.Sequential(nn.Linear(IN_DIM, OUT_DIM), nn.BatchNorm1d(OUT_DIM))


def normaldenselayer(IN_DIM, OUT_DIM):
    return nn.Linear(IN_DIM, OUT_DIM)


class Model(nn.Module):
    def __init__(
        self,
        vocab_size=10000,
        latent_dim=50,
        hidden_dim=200,
        batchnorm=False,
        usecuda=True
    ):
        super(Model, self).__init__()

        # construct inference network
        self.vocab_size = vocab_size
        self.latent_dim = latent_dim

        self.batchnorm = batchnorm
        self.usecuda = usecuda

        if self.batchnorm:
            denselayer = batchnormlayer
        else:
            denselayer = normaldenselayer

        # ==== Network ===== #
        self.hidden_dim = hidden_dim

        self.fc1 = denselayer(self.vocab_size, self.hidden_dim)
        self.fc2 = denselayer(self.hidden_dim, self.hidden_dim)
        self.fc31 = denselayer(self.hidden_dim, self.latent_dim)
        self.fc32 = denselayer(self.hidden_dim, self.latent_dim)

        self.fc4 = denselayer(self.latent_dim, self.hidden_dim)
        self.fc5 = denselayer(self.hidden_dim, self.vocab_size)

        self.tanh = nn.Tanh()
        self.relu = nn.ReLU()
        self.sigmoid = nn.Sigmoid()
        self.softmax = nn.Softmax(dim=-1)  # normalize over the vocabulary dimension

    ''' Inference Network, q(h|X) '''

    def encode(self, x):  # inference network q(h|X)
        h1 = self.relu(self.fc1(x))
        h2 = self.relu(self.fc2(h1))
        return self.fc31(h2), self.fc32(h2)

    def reparametrize(self, mu, logvar, IW=1):
        # Draw IW samples per example via the reparameterization trick:
        # z = mu + sigma * eps with eps ~ N(0, I); output shape (IW, batch, latent_dim).
        std = logvar.mul(0.5).exp_()
        if self.usecuda:
            eps = torch.cuda.FloatTensor(IW, std.size(0), std.size(1)).normal_()
        else:
            eps = torch.FloatTensor(IW, std.size(0), std.size(1)).normal_()
        eps = Variable(eps)
        return eps.mul(std.expand_as(eps)).add_(mu.expand_as(eps))

    def decode(self, z):
        h4 = self.relu(self.fc4(z))
        return self.softmax(self.fc5(h4))

    def forward(self, x, IW=1):
        mu, logvar = self.encode(x.view(-1, self.vocab_size))
        z = self.reparametrize(mu, logvar, IW)
        return z, self.decode(z.view(-1, self.latent_dim)).view(
            IW, -1, self.vocab_size
        ), mu, logvar

    def components(self):
        # Decode each latent basis vector (the rows of an identity matrix) to
        # obtain one word distribution per latent dimension, i.e. the "topics".
        Ik = Variable(torch.eye(self.latent_dim), requires_grad=False)
        if self.usecuda:
            Ik = Ik.cuda()

        comps = self.decode(Ik)
        return comps


def nvdm_loss_function(recon_x, x, mu, logvar, temp=1):
    # temp is the temperature for annealing
    # Log-likelihood for words
    eps = 1e-9
    BCE = -(x * torch.log(recon_x.squeeze() + eps))
    BCE = BCE.sum()

    # see Appendix B from VAE paper:
    # Kingma and Welling. Auto-Encoding Variational Bayes. ICLR, 2014
    # https://arxiv.org/abs/1312.6114
    # 0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
    KLD_element = mu.pow(2).add_(logvar.exp()).mul_(-1).add_(1).add_(logvar)
    KLD = torch.sum(KLD_element).mul_(-0.5)

    return BCE + temp * KLD


def sum_log_perplexity(recon_x, x, mu, logvar):
    # assumes x is of shape [1, input_dim]
    # 0.5 is assumed as the pseudo count for each doc
    return nvdm_loss_function(recon_x, x, mu, logvar, temp=1) / (x.sum() + 0.5)


def nvdm_iw_loss(z, recon_x, x, mu, logvar):
    # z size: (iw, bs, zdim)
    # recon_x size: (iw, bs, xdim)
    # x size: (bs, xdim)
    # mu size: (bs, zdim)
    # logvar size: (bs, zdim)
    eps = 1e-9
    log_pz = log_stdnormal(z).sum(-1)
    log_px_given_z = (x.expand_as(recon_x) * torch.log(recon_x + eps)).sum(-1)
    log_qz_given_xr = log_normal(z, mu.expand_as(z),
                                 logvar.expand_as(z)).sum(-1)

    LL = log_mean_exp(log_pz + log_px_given_z - log_qz_given_xr, dim=0)
    return -torch.sum(LL)
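

# ---------------------------------------------------------------------------
# Usage sketch (added for illustration; not part of the original paste).
# A minimal, hypothetical training step showing how the model and the losses
# above might fit together. Assumptions: documents are bag-of-words count
# vectors of shape (batch, vocab_size), usecuda=False, and the names `docs`,
# `optimizer`, and the hyperparameters below are illustrative placeholders.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    import torch.optim  # assumed available alongside the torch imports above

    model = Model(vocab_size=10000, latent_dim=50, hidden_dim=200,
                  batchnorm=False, usecuda=False)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

    docs = Variable(torch.rand(32, 10000))  # stand-in bag-of-words batch

    # Standard single-sample ELBO objective (IW=1).
    z, recon, mu, logvar = model(docs, IW=1)
    loss = nvdm_loss_function(recon, docs, mu, logvar, temp=1)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    # Importance-weighted bound with 5 samples per document.
    z, recon, mu, logvar = model(docs, IW=5)
    iw_loss = nvdm_iw_loss(z, recon, docs, mu, logvar)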