Advertisement
Guest User

Untitled

a guest
Dec 6th, 2018
82
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 0.74 KB | None | 0 0
  1. class VAE(nn.Module):
  2. def __init__(self):
  3. self.in_vae_dim = WORD_EMBEDDING_DIM
  4. self.out_vae_dim = CONDITION_DIM
  5. self.fc = nn.Linear(self.in_vae_dim, self.out_vae_dim * 2, bias = True)
  6. self.relu = nn.ReLU()
  7. def forward(self, word_embedding):
  8. #out: (BATCH_SIZE, COMDITION_DIM * 2)
  9. out = self.relu(self.fc(word_embedding))
  10. mu = out[:, :self.out_vae_dim]
  11. logvariance = out[:, self.out_vae_dim:]
  12. std = logvariance.mul(0.5).add(mu)
  13. sample_from_normal = torch.new_empty(mu.size(), requires_grad=True, device=device).normal_()
  14.  
  15. parametrized_output = sample_from_normal.mul(std).add(mu)
  16. return parametrized_output, mu, logvariance
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement