import torch
from torch import nn
import torch.nn.functional as F
from torch.distributions import MultivariateNormal

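# ---------------------------------------------------------------------------
# The paste uses EPS, kl_gaussian, compute_linear_var, matrix_diag_part,
# softrelu and compute_relu_var without defining them; they presumably live
# in an accompanying utils module. The definitions below are hedged stand-ins
# based on standard Gaussian moment-matching formulas, not the author's
# original implementations.
# ---------------------------------------------------------------------------

EPS = 1e-8  # small jitter to keep divisions and square roots stable

_std_normal = torch.distributions.Normal(0.0, 1.0)


def matrix_diag_part(x):
    # Diagonal of each matrix in a batch: [batch, n, n] -> [batch, n].
    return torch.diagonal(x, dim1=-2, dim2=-1)


def softrelu(x):
    # SR(x) = phi(x) + x * Phi(x): the mean of relu(z) for z ~ N(x, 1).
    return torch.exp(_std_normal.log_prob(x)) + x * _std_normal.cdf(x)


def kl_gaussian(mean_q, var_q, mean_p, var_p):
    # Closed-form KL(q || p) between fully factorized Gaussians,
    # summed over all parameter dimensions.
    return 0.5 * torch.sum(
        torch.log(var_p) - torch.log(var_q)
        + (var_q + (mean_q - mean_p) ** 2) / var_p
        - 1.0)


def compute_linear_var(x_mean, x_var, W, W_var, bias, bias_var):
    # Covariance of y = xW + b for Gaussian x (full covariance x_var) and
    # factorized Gaussian W, b:
    #   Cov[y] = W^T Cov[x] W + diag(E[x^2] Var[W] + Var[b]).
    x_sq = x_mean * x_mean + matrix_diag_part(x_var)         # E[x_k^2], [batch, in]
    diag_term = F.linear(x_sq, W_var.t()) + bias_var         # [batch, out]
    full_term = torch.matmul(torch.matmul(W.t(), x_var), W)  # [batch, out, out]
    return full_term + torch.diag_embed(diag_term)


def compute_relu_var(x_var, x_var_diag, mu):
    # Diagonal approximation of Cov[relu(x)], where mu = mean / sqrt(var).
    # The full-covariance expression from the deterministic variational
    # inference literature also handles off-diagonal terms; this sketch
    # keeps only the marginal variances, using
    #   E[relu(z)^2] = s^2 * ((mu^2 + 1) * Phi(mu) + mu * phi(mu)).
    pdf = torch.exp(_std_normal.log_prob(mu))
    cdf = _std_normal.cdf(mu)
    second_moment = x_var_diag * ((mu * mu + 1.0) * cdf + mu * pdf)
    var = second_moment - x_var_diag * softrelu(mu) ** 2
    return torch.diag_embed(torch.clamp(var, min=EPS))
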
class LinearGaussian(nn.Module):
    def __init__(self, in_features, out_features, certain=False,
                 deterministic=True):
        """
        Applies the linear transformation y = xW + b,
        where W and b are factorized Gaussian random variables.

        :param in_features: input dimension
        :param out_features: output dimension
        :param certain: if True, then x is treated as exact (equal to its
                        mean, with zero variance)
        :param deterministic: if True, propagate moments in closed form;
                              otherwise sample with the local
                              reparametrization trick (MCVI)
        """

        super().__init__()

        self.in_features = in_features
        self.out_features = out_features

        # Variational posterior: means and log-variances for W and b.
        self.W = nn.Parameter(torch.Tensor(in_features, out_features))
        self.bias = nn.Parameter(torch.Tensor(out_features))

        self.W_logvar = nn.Parameter(torch.Tensor(in_features, out_features))
        self.bias_logvar = nn.Parameter(torch.Tensor(out_features))

        self._initialize_weights()
        self._construct_priors()

        self.certain = certain
        self.deterministic = deterministic
        self.mean_forward = False
        self.zero_mean = False

    def _initialize_weights(self):
        nn.init.xavier_normal_(self.W)
        nn.init.normal_(self.bias)

        nn.init.uniform_(self.W_logvar, a=-10, b=-7)
        nn.init.uniform_(self.bias_logvar, a=-10, b=-7)

    def _construct_priors(self):
        # Fixed zero-mean Gaussian priors with variance 0.1 over W and b.
        self.W_mean_prior = nn.Parameter(torch.zeros_like(self.W),
                                         requires_grad=False)
        self.W_var_prior = nn.Parameter(torch.ones_like(self.W_logvar) * 0.1,
                                        requires_grad=False)

        self.bias_mean_prior = nn.Parameter(torch.zeros_like(self.bias),
                                            requires_grad=False)
        self.bias_var_prior = nn.Parameter(
            torch.ones_like(self.bias_logvar) * 0.1,
            requires_grad=False)

    def compute_kl(self):
        weights_kl = kl_gaussian(self.W, torch.exp(self.W_logvar),
                                 self.W_mean_prior, self.W_var_prior)
        bias_kl = kl_gaussian(self.bias, torch.exp(self.bias_logvar),
                              self.bias_mean_prior, self.bias_var_prior)
        return weights_kl + bias_kl

    def set_flag(self, flag_name, value):
        setattr(self, flag_name, value)
        for m in self.children():
            if hasattr(m, 'set_flag'):
                m.set_flag(flag_name, value)

    def forward(self, x):
        """
        Compute the expectation and variance after the linear transform
        y = xW + b.

        :param x: input of size [batch, in_features], or a tuple
                  (x_mean, x_var)
        :return: in deterministic mode, a tuple (y_mean, y_var) with shapes
                 y_mean: [batch, out_features]
                 y_var:  [batch, out_features, out_features]

                 in MCVI mode, a tuple (sample, None), where
                 sample: [batch, out_features] is a local-reparametrization
                 sample of the output
        """
        x = self._apply_activation(x)
        if self.zero_mean:
            return self._zero_mean_forward(x)
        elif self.mean_forward:
            return self._mean_forward(x)
        elif self.deterministic:
            return self._det_forward(x)
        else:
            return self._mcvi_forward(x)

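    # MCVI mode: the previous layer hands over a sample, so the input carries
    # no variance and only the parameter uncertainty contributes. The output
    # covariance is then diagonal,
    #   Var[y_j] = sum_i x_i^2 Var[W_ij] + Var[b_j],
    # and drawing from N(y_mean, y_var) reproduces the local
    # reparametrization trick.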
    def _mcvi_forward(self, x):
        W_var = torch.exp(self.W_logvar)
        bias_var = torch.exp(self.bias_logvar)

        if self.certain:
            x_mean = x
            x_var = None
        else:
            x_mean = x[0]
            x_var = x[1]

        y_mean = F.linear(x_mean, self.W.t()) + self.bias

        if x_var is None:
            # Exact input (certain data or an MCVI sample): diagonal
            # covariance from the parameter uncertainty alone.
            xx = x_mean * x_mean
            y_var = torch.diag_embed(F.linear(xx, W_var.t()) + bias_var)
        else:
            y_var = compute_linear_var(x_mean, x_var, self.W, W_var, self.bias,
                                       bias_var)

        dst = MultivariateNormal(loc=y_mean, covariance_matrix=y_var)
        sample = dst.rsample()
        return sample, None

    def _det_forward(self, x):
        W_var = torch.exp(self.W_logvar)
        bias_var = torch.exp(self.bias_logvar)

        if self.certain:
            x_mean = x
            x_var = None
        else:
            x_mean = x[0]
            x_var = x[1]

        y_mean = F.linear(x_mean, self.W.t()) + self.bias

        if x_var is None:
            xx = x_mean * x_mean
            y_var = torch.diag_embed(F.linear(xx, W_var.t()) + bias_var)
        else:
            # Propagate the full input covariance through the layer.
            y_var = compute_linear_var(x_mean, x_var, self.W, W_var, self.bias,
                                       bias_var)

        return y_mean, y_var

    def _mean_forward(self, x):
        if not isinstance(x, tuple):
            x_mean = x
        else:
            x_mean = x[0]

        y_mean = F.linear(x_mean, self.W.t()) + self.bias
        return y_mean, None

    def _zero_mean_forward(self, x):
        # Same as the usual forward passes, but with the weight means zeroed
        # out; only the bias and the weight variances contribute.
        if not isinstance(x, tuple):
            x_mean = x
            x_var = None
        else:
            x_mean = x[0]
            x_var = x[1]

        y_mean = F.linear(x_mean, torch.zeros_like(self.W).t()) + self.bias

        W_var = torch.exp(self.W_logvar)
        bias_var = torch.exp(self.bias_logvar)

        if x_var is None:
            xx = x_mean * x_mean
            y_var = torch.diag_embed(F.linear(xx, W_var.t()) + bias_var)
        else:
            y_var = compute_linear_var(x_mean, x_var, torch.zeros_like(self.W),
                                       W_var, self.bias, bias_var)

        if self.deterministic:
            return y_mean, y_var
        else:
            dst = MultivariateNormal(loc=y_mean, covariance_matrix=y_var)
            sample = dst.rsample()
            return sample, None

    def _apply_activation(self, x):
        # Identity for the plain linear layer; subclasses override this.
        # The single leading underscore matters: double-underscore names are
        # mangled per class, so a subclass override of __apply_activation
        # would never be called from forward().
        return x

    def __repr__(self):
        return '{}(in_features={}, out_features={})'.format(
            self.__class__.__name__, self.in_features, self.out_features)


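# For a Gaussian pre-activation z ~ N(m, s^2), the ReLU output moments have
# closed forms in terms of mu = m / s:
#   E[relu(z)]   = s * SR(mu),  SR(mu) = phi(mu) + mu * Phi(mu)
#   E[relu(z)^2] = s^2 * ((mu^2 + 1) * Phi(mu) + mu * phi(mu))
# ReluGaussian applies this moment matching before its linear transform.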
class ReluGaussian(LinearGaussian):
    def _apply_activation(self, x):
        if not isinstance(x, tuple):
            # Certain input carries no variance: an ordinary ReLU suffices.
            return F.relu(x)

        x_mean, x_var = x

        if not self.deterministic:
            # MCVI: x is a sample, so apply the ReLU directly.
            z_mean = F.relu(x_mean)
            z_var = None
        else:
            # Deterministic mode: match the first two moments of relu(x).
            x_var_diag = matrix_diag_part(x_var)
            sqrt_x_var_diag = torch.sqrt(x_var_diag + EPS)
            mu = x_mean / (sqrt_x_var_diag + EPS)

            z_mean = sqrt_x_var_diag * softrelu(mu)
            z_var = compute_relu_var(x_var, x_var_diag, mu)

        return z_mean, z_var


x = torch.randn(2, 10)
layer = ReluGaussian(10, 2, certain=True)
y_mean, y_var = layer(x)
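# Quick smoke test (assumes the stand-in helpers above; shapes follow the
# forward() docstring).
print(y_mean.shape, y_var.shape)  # torch.Size([2, 2]) torch.Size([2, 2, 2])
print(layer.compute_kl())         # scalar KL between posterior and prior

# Switch to MCVI mode: forward() now returns (sample, None).
layer.set_flag('deterministic', False)
sample, _ = layer(x)
print(sample.shape)               # torch.Size([2, 2])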