Advertisement
Guest User

Untitled

a guest
Jun 25th, 2019
82
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 0.61 KB | None | 0 0
  1. # KL
  2. log_p_z = log_Normal_standard(z_q, dim=1)
  3. log_q_z = log_Normal_diag(z_q, z_q_mean, z_q_logvar, dim=1)
  4. KL = -(log_p_z - log_q_z)
  5.  
  6. def log_Normal_diag(x, mean, log_var, average=False, dim=None):
  7. log_normal = -0.5 * ( log_var + torch.pow( x - mean, 2 ) / torch.exp( log_var ) )
  8. if average:
  9. return torch.mean( log_normal, dim )
  10. else:
  11. return torch.sum( log_normal, dim )
  12.  
  13. def log_Normal_standard(x, average=False, dim=None):
  14. log_normal = -0.5 * torch.pow( x , 2 )
  15. if average:
  16. return torch.mean( log_normal, dim )
  17. else:
  18. return torch.sum( log_normal, dim )
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement