Advertisement
Not a member of Pastebin yet?
Sign up — it unlocks many cool features!
class VAE(nn.Module):
    """VAE-style encoder head: map an embedding to a reparameterized Gaussian sample.

    A single linear layer projects the input to ``2 * out_vae_dim`` values.
    These are split into a mean ``mu`` and a log-variance ``logvariance``, and
    a sample is drawn with the reparameterization trick so that gradients flow
    through both.

    Args:
        in_dim: Size of the last dimension of the input embedding. When None,
            the module-level ``WORD_EMBEDDING_DIM`` is used.
        out_dim: Size of the sampled vector. When None, the module-level
            ``CONDITION_DIM`` is used.
    """

    def __init__(self, in_dim=None, out_dim=None):
        # nn.Module.__init__ must run before submodules are assigned as
        # attributes; otherwise ``self.fc = ...`` raises AttributeError.
        super().__init__()
        self.in_vae_dim = WORD_EMBEDDING_DIM if in_dim is None else in_dim
        self.out_vae_dim = CONDITION_DIM if out_dim is None else out_dim
        # One projection produces mu and log-variance concatenated on the last axis.
        self.fc = nn.Linear(self.in_vae_dim, self.out_vae_dim * 2, bias=True)
        self.relu = nn.ReLU()

    def forward(self, word_embedding):
        """Encode ``word_embedding`` and draw one reparameterized sample.

        Args:
            word_embedding: Tensor whose last dimension is ``in_vae_dim``,
                e.g. ``(batch, in_vae_dim)``.

        Returns:
            Tuple ``(sample, mu, logvariance)``. Each element keeps the input's
            leading dimensions and has a last dimension of ``out_vae_dim``.
            ``sample = mu + eps * exp(0.5 * logvariance)`` with ``eps ~ N(0, I)``.
        """
        # out: (..., out_vae_dim * 2)
        # NOTE(review): the ReLU forces mu >= 0 and logvariance >= 0
        # (i.e. variance >= 1); confirm this constraint is intended.
        out = self.relu(self.fc(word_embedding))
        mu = out[..., :self.out_vae_dim]
        logvariance = out[..., self.out_vae_dim:]
        # sigma = exp(log(sigma^2) / 2)
        std = torch.exp(0.5 * logvariance)
        # The noise is constant w.r.t. the parameters, so it needs no grad.
        # randn_like matches std's shape, dtype and device.
        eps = torch.randn_like(std)
        parametrized_output = mu + eps * std
        return parametrized_output, mu, logvariance
Advertisement
Add Comment
Please sign in to add a comment.
Advertisement