Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
# KL term: log q(z) - log p(z), evaluated at z_q and summed over dim 1.
# NOTE(review): presumably z_q is a sample drawn from q(z | x) with parameters
# (z_q_mean, z_q_logvar), making this a single-sample KL estimate — confirm.
# Both helpers omit the same per-element -0.5*log(2*pi) constant, and it
# cancels in the difference below.
log_q_z = log_Normal_diag(z_q, z_q_mean, z_q_logvar, dim=1)
log_p_z = log_Normal_standard(z_q, dim=1)
KL = log_q_z - log_p_z
def log_Normal_diag(x, mean, log_var, average=False, dim=None, *, normalized=False):
    """Return the log-density of a diagonal Gaussian, reduced over ``dim``.

    The per-element term is ``-0.5 * (log_var + (x - mean)**2 / exp(log_var))``.
    By default the constant ``-0.5 * log(2*pi)`` per element is omitted, so the
    result equals the log-density only up to an additive constant. That
    constant cancels when two such values, taken over the same number of
    elements, are subtracted.

    Args:
        x: Tensor of points at which the density is evaluated.
        mean: Gaussian mean, broadcast against ``x``.
        log_var: Log of the per-element variance, broadcast against ``x``.
        average: If True, average over ``dim`` instead of summing.
        dim: Dimension(s) to reduce. ``None`` reduces over all elements and
            yields a 0-d tensor.
        normalized: If True, include the ``-0.5 * log(2*pi)`` constant per
            element so the result is the exact log-density.

    Returns:
        Tensor of reduced (summed or averaged) log-density values.
    """
    import math

    log_normal = -0.5 * (log_var + torch.pow(x - mean, 2) / torch.exp(log_var))
    if normalized:
        log_normal = log_normal - 0.5 * math.log(2 * math.pi)
    reduce = torch.mean if average else torch.sum
    # Older PyTorch releases reject an explicit ``dim=None``; omit it instead so
    # the "reduce everything" case works across versions.
    if dim is None:
        return reduce(log_normal)
    return reduce(log_normal, dim)
def log_Normal_standard(x, average=False, dim=None, *, normalized=False):
    """Return the log-density of a standard normal N(0, 1), reduced over ``dim``.

    The per-element term is ``-0.5 * x**2``. By default the constant
    ``-0.5 * log(2*pi)`` per element is omitted, so the result equals the
    log-density only up to an additive constant.

    Args:
        x: Tensor of points at which the density is evaluated.
        average: If True, average over ``dim`` instead of summing.
        dim: Dimension(s) to reduce. ``None`` reduces over all elements and
            yields a 0-d tensor.
        normalized: If True, include the ``-0.5 * log(2*pi)`` constant per
            element so the result is the exact log-density.

    Returns:
        Tensor of reduced (summed or averaged) log-density values.
    """
    import math

    log_normal = -0.5 * torch.pow(x, 2)
    if normalized:
        log_normal = log_normal - 0.5 * math.log(2 * math.pi)
    reduce = torch.mean if average else torch.sum
    # Older PyTorch releases reject an explicit ``dim=None``; omit it instead so
    # the "reduce everything" case works across versions.
    if dim is None:
        return reduce(log_normal)
    return reduce(log_normal, dim)
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement