import math

import torch
import torch.nn.functional as F
from torch import nn
from torch.distributions import MultivariateNormal
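
# ---------------------------------------------------------------------------
# Helper definitions. The paste uses EPS, kl_gaussian, softrelu,
# matrix_diag_part, compute_linear_var and compute_relu_var without defining
# them, so the sketches below are reconstructions based on the standard
# moment-matching formulas for deterministic variational inference
# (Wu et al., 2019). The original project's versions may differ, and the
# off-diagonal ReLU covariance here is a deliberately crude approximation.
# ---------------------------------------------------------------------------

EPS = 1e-8


def matrix_diag_part(m):
    # Diagonal of every matrix in a batch: [batch, n, n] -> [batch, n].
    return torch.diagonal(m, dim1=-2, dim2=-1)


def standard_gaussian(x):
    # Standard normal pdf.
    return torch.exp(-0.5 * x * x) / math.sqrt(2 * math.pi)


def gaussian_cdf(x):
    # Standard normal cdf.
    return 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))


def softrelu(x):
    # E[relu(z)] for z ~ N(x, 1): pdf(x) + x * cdf(x).
    return standard_gaussian(x) + x * gaussian_cdf(x)


def kl_gaussian(mean, var, mean_prior, var_prior):
    # KL( N(mean, diag(var)) || N(mean_prior, diag(var_prior)) ),
    # summed over all parameter entries.
    return 0.5 * torch.sum(torch.log(var_prior / var)
                           + (var + (mean - mean_prior) ** 2) / var_prior
                           - 1.0)


def compute_linear_var(x_mean, x_var, W, W_var, bias, bias_var):
    # Covariance of y = xW + b for x ~ N(x_mean, x_var) with full covariance
    # [batch, in, in] and factorized Gaussian W ([in, out]) and b:
    #   Cov[y] = W^T x_var W + diag(E[x^2] W_var + bias_var).
    x_var_diag = matrix_diag_part(x_var)
    xx = x_var_diag + x_mean * x_mean                # E[x^2], [batch, in]
    diag_term = F.linear(xx, W_var.t()) + bias_var   # [batch, out]
    full_term = torch.matmul(torch.matmul(W.t(), x_var), W)
    return full_term + torch.diag_embed(diag_term)


def compute_relu_var(x_var, x_var_diag, mu):
    # Covariance of relu(z) for z ~ N with diagonal stds s = sqrt(x_var_diag)
    # and standardized means mu. The diagonal uses the exact second moment
    #   E[relu(z)^2] = s^2 * ((mu^2 + 1) * cdf(mu) + mu * pdf(mu));
    # the off-diagonal entries use the first-order approximation
    #   Cov_ij ~= cdf(mu_i) * cdf(mu_j) * x_var_ij  (an assumption).
    pdf, cdf = standard_gaussian(mu), gaussian_cdf(mu)
    second_moment = x_var_diag * ((mu * mu + 1.0) * cdf + mu * pdf)
    mean = torch.sqrt(x_var_diag + EPS) * softrelu(mu)
    var_diag = torch.clamp(second_moment - mean * mean, min=EPS)
    cov = cdf.unsqueeze(-1) * cdf.unsqueeze(-2) * x_var
    return (cov - torch.diag_embed(matrix_diag_part(cov))
            + torch.diag_embed(var_diag))
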

class LinearGaussian(nn.Module):
    def __init__(self, in_features, out_features, certain=False,
                 deterministic=True):
        """
        Applies the linear transformation y = xW + b,
        where W and b are Gaussian random variables.

        :param in_features: input dimension
        :param out_features: output dimension
        :param certain: if True, then x is equal to its mean and has no
            variance
        :param deterministic: if True, propagate moments analytically;
            if False, sample via local reparametrization (MCVI)
        """
        super().__init__()

        self.in_features = in_features
        self.out_features = out_features
        self.W = nn.Parameter(torch.Tensor(in_features, out_features))
        self.bias = nn.Parameter(torch.Tensor(out_features))
        self.W_logvar = nn.Parameter(torch.Tensor(in_features, out_features))
        self.bias_logvar = nn.Parameter(torch.Tensor(out_features))

        self.__initialize_weights()
        self.__construct_priors()

        self.certain = certain
        self.deterministic = deterministic
        self.mean_forward = False
        self.zero_mean = False

    def __initialize_weights(self):
        nn.init.xavier_normal_(self.W)
        nn.init.normal_(self.bias)
        nn.init.uniform_(self.W_logvar, a=-10, b=-7)
        nn.init.uniform_(self.bias_logvar, a=-10, b=-7)

    def __construct_priors(self):
        self.W_mean_prior = nn.Parameter(torch.zeros_like(self.W),
                                         requires_grad=False)
        self.W_var_prior = nn.Parameter(torch.ones_like(self.W_logvar) * 0.1,
                                        requires_grad=False)

        self.bias_mean_prior = nn.Parameter(torch.zeros_like(self.bias),
                                            requires_grad=False)
        self.bias_var_prior = nn.Parameter(
            torch.ones_like(self.bias_logvar) * 0.1,
            requires_grad=False)

    def compute_kl(self):
        weights_kl = kl_gaussian(self.W, torch.exp(self.W_logvar),
                                 self.W_mean_prior, self.W_var_prior)
        bias_kl = kl_gaussian(self.bias, torch.exp(self.bias_logvar),
                              self.bias_mean_prior, self.bias_var_prior)
        return weights_kl + bias_kl

    def set_flag(self, flag_name, value):
        setattr(self, flag_name, value)
        for m in self.children():
            if hasattr(m, 'set_flag'):
                m.set_flag(flag_name, value)

    def forward(self, x):
        """
        Compute the expectation and variance after the linear transform
        y = xW + b.

        :param x: input, size [batch, in_features]
        :return: in deterministic mode, a tuple (y_mean, y_var) with shapes
                     y_mean: [batch, out_features]
                     y_var:  [batch, out_features, out_features];
                 in MCVI mode, a tuple (sample, None), where
                     sample: [batch, out_features] is a local
                     reparametrization sample of the output
        """
        x = self._apply_activation(x)
        if self.zero_mean:
            return self.__zero_mean_forward(x)
        elif self.mean_forward:
            return self.__mean_forward(x)
        elif self.deterministic:
            return self.__det_forward(x)
        else:
            return self.__mcvi_forward(x)

    def __mcvi_forward(self, x):
        W_var = torch.exp(self.W_logvar)
        bias_var = torch.exp(self.bias_logvar)
        if self.certain:
            x_mean = x
            x_var = None
        else:
            x_mean = x[0]
            x_var = x[1]

        y_mean = F.linear(x_mean, self.W.t()) + self.bias
        if self.certain or not self.deterministic:
            # Upstream MCVI layers emit (sample, None), so the output
            # covariance reduces to a diagonal built from E[x^2].
            xx = x_mean * x_mean
            y_var = torch.diag_embed(F.linear(xx, W_var.t()) + bias_var)
        else:
            y_var = compute_linear_var(x_mean, x_var, self.W, W_var, self.bias,
                                       bias_var)

        dst = MultivariateNormal(loc=y_mean, covariance_matrix=y_var)
        sample = dst.rsample()
        return sample, None

    def __det_forward(self, x):
        W_var = torch.exp(self.W_logvar)
        bias_var = torch.exp(self.bias_logvar)
        if self.certain:
            x_mean = x
            x_var = None
        else:
            x_mean = x[0]
            x_var = x[1]

        y_mean = F.linear(x_mean, self.W.t()) + self.bias
        if self.certain:
            xx = x_mean * x_mean
            y_var = torch.diag_embed(F.linear(xx, W_var.t()) + bias_var)
        else:
            y_var = compute_linear_var(x_mean, x_var, self.W, W_var, self.bias,
                                       bias_var)
        return y_mean, y_var

    def __mean_forward(self, x):
        if not isinstance(x, tuple):
            x_mean = x
        else:
            x_mean = x[0]

        y_mean = F.linear(x_mean, self.W.t()) + self.bias
        return y_mean, None

    def __zero_mean_forward(self, x):
        if not isinstance(x, tuple):
            x_mean = x
            x_var = None
        else:
            x_mean = x[0]
            x_var = x[1]

        y_mean = F.linear(x_mean, torch.zeros_like(self.W).t()) + self.bias
        W_var = torch.exp(self.W_logvar)
        bias_var = torch.exp(self.bias_logvar)
        if x_var is None:
            xx = x_mean * x_mean
            y_var = torch.diag_embed(F.linear(xx, W_var.t()) + bias_var)
        else:
            y_var = compute_linear_var(x_mean, x_var, torch.zeros_like(self.W),
                                       W_var, self.bias, bias_var)

        if self.deterministic:
            return y_mean, y_var
        else:
            dst = MultivariateNormal(loc=y_mean, covariance_matrix=y_var)
            sample = dst.rsample()
            return sample, None

    def _apply_activation(self, x):
        # Identity in the base class. Note the single leading underscore:
        # the original double-underscore name was mangled per-class, so the
        # ReluGaussian override below was silently never called.
        return x

    def __repr__(self):
        return self.__class__.__name__ + '(' \
               + 'in_features=' + str(self.in_features) \
               + ', out_features=' + str(self.out_features) + ')'


class ReluGaussian(LinearGaussian):
    def _apply_activation(self, x):
        # A plain tensor means a certain input with no variance (e.g. the
        # first layer): fall back to an ordinary ReLU. This guard is an
        # addition; the original indexed x[0]/x[1] unconditionally.
        if not isinstance(x, tuple):
            return F.relu(x)

        x_mean = x[0]
        x_var = x[1]
        if not self.deterministic:
            # MCVI mode: x_mean is a sample, apply ReLU pointwise.
            z_mean = F.relu(x_mean)
            z_var = None
        else:
            # Deterministic mode: moment-match relu(z) for z ~ N(x_mean, x_var).
            x_var_diag = matrix_diag_part(x_var)
            sqrt_x_var_diag = torch.sqrt(x_var_diag + EPS)
            mu = x_mean / (sqrt_x_var_diag + EPS)
            z_mean = sqrt_x_var_diag * softrelu(mu)
            z_var = compute_relu_var(x_var, x_var_diag, mu)

        return z_mean, z_var


# Smoke test: a certain input through a ReLU moment-matching layer.
x = torch.randn(2, 10)
layer = ReluGaussian(10, 2, certain=True)
y_mean, y_var = layer(x)
print(y_mean.shape, y_var.shape)  # torch.Size([2, 2]), torch.Size([2, 2, 2])
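
# Usage sketch (not in the original paste): a minimal two-layer Bayesian net,
# a LinearGaussian input layer fed with certain data followed by a
# ReluGaussian output layer, plus the total KL term an ELBO would need.
# Layer sizes and batches here are illustrative.
net = nn.Sequential(LinearGaussian(10, 16, certain=True),
                    ReluGaussian(16, 2))
out_mean, out_var = net(torch.randn(4, 10))
kl = sum(m.compute_kl() for m in net if hasattr(m, 'compute_kl'))
print(out_mean.shape, out_var.shape, kl.item())

# Switching every layer to MCVI mode with the flag helper; the network then
# returns (sample, None) instead of propagated moments.
for m in net:
    m.set_flag('deterministic', False)
sample, _ = net(torch.randn(4, 10))
print(sample.shape)  # torch.Size([4, 2])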