
a guest, Jun 25th, 2019
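import torch

# Log-density of a diagonal Gaussian N(mean, exp(log_var)), dropping the
# additive constant -0.5 * log(2 * pi) per dimension (it cancels in the KL below).
def log_Normal_diag(x, mean, log_var, average=False, dim=None):
    log_normal = -0.5 * ( log_var + torch.pow( x - mean, 2 ) / torch.exp( log_var ) )
    if average:
        return torch.mean( log_normal, dim )
    else:
        return torch.sum( log_normal, dim )

# Log-density of the standard normal prior N(0, I), dropping the same constant.
def log_Normal_standard(x, average=False, dim=None):
    log_normal = -0.5 * torch.pow( x , 2 )
    if average:
        return torch.mean( log_normal, dim )
    else:
        return torch.sum( log_normal, dim )

# KL
# Single-sample Monte Carlo estimate of KL(q(z|x) || p(z)) = E_q[log q(z|x) - log p(z)],
# evaluated at one sample z_q ~ q(z|x). z_q, z_q_mean and z_q_logvar are assumed to be
# (batch, latent_dim) tensors produced by the encoder; dim=1 sums over latent dimensions.
log_p_z = log_Normal_standard(z_q, dim=1)
log_q_z = log_Normal_diag(z_q, z_q_mean, z_q_logvar, dim=1)
KL = -(log_p_z - log_q_z)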
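The KL line is a single-sample Monte Carlo estimate, so averaging it over many samples should approach the closed-form KL between a diagonal Gaussian and N(0, I), 0.5 * sum(exp(log_var) + mean^2 - 1 - log_var). The minimal sketch below checks this in pure Python (standard library only, no torch) and also shows that the omitted -0.5 * log(2 * pi) terms do not matter, since they cancel in log q - log p. The helper names, the example mean/log_var values and the sample count are illustrative assumptions, not part of the original paste.

```python
import math
import random

def log_normal_diag(x, mean, log_var):
    # Same expression as log_Normal_diag above, summed over dimensions,
    # still without the -0.5 * log(2 * pi) constant.
    return sum(-0.5 * (lv + (xi - m) ** 2 / math.exp(lv))
               for xi, m, lv in zip(x, mean, log_var))

def log_normal_standard(x):
    # Same expression as log_Normal_standard above, summed over dimensions.
    return sum(-0.5 * xi ** 2 for xi in x)

def kl_single_sample(mean, log_var, rng):
    # Reparameterised sample z = mean + exp(0.5 * log_var) * eps, eps ~ N(0, I).
    z = [m + math.exp(0.5 * lv) * rng.gauss(0.0, 1.0) for m, lv in zip(mean, log_var)]
    # Mirrors KL = -(log_p_z - log_q_z) from the paste.
    return -(log_normal_standard(z) - log_normal_diag(z, mean, log_var))

def kl_monte_carlo(mean, log_var, n_samples=100_000, seed=0):
    # Average of many single-sample estimates (hypothetical helper).
    rng = random.Random(seed)
    return sum(kl_single_sample(mean, log_var, rng) for _ in range(n_samples)) / n_samples

def kl_closed_form(mean, log_var):
    # KL(N(mean, diag(exp(log_var))) || N(0, I)), summed over dimensions.
    return 0.5 * sum(math.exp(lv) + m ** 2 - 1.0 - lv for m, lv in zip(mean, log_var))

mean = [0.5, -1.0, 0.0]
log_var = [0.0, math.log(0.25), math.log(2.0)]
print(kl_closed_form(mean, log_var))   # about 1.0966
print(kl_monte_carlo(mean, log_var))   # close to the closed-form value
```

The closed form has zero sampling variance, so it is usually preferred when both q and p are Gaussian; the sampled estimate used in the paste is the one that still works when the prior's log-density has no closed-form KL against q (for example, a mixture prior).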