Advertisement
Guest User

Untitled

a guest
Apr 5th, 2020
196
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 2.81 KB | None | 0 0
  1. class BayesianLinearRegression(PyroModule):
  2.     def __init__(self, n_input):
  3.         super().__init__()
  4.         self.n_input = n_input
  5.         self.has_intercept = True
  6.        
  7.     def model(self, x, y=None):
  8.         print('model', x, y)
  9.         if self.has_intercept:
  10.             intercept = pyro.sample(
  11.                 'intercept',
  12.                 dist.Normal(0., 10.).expand([1]).to_event(1)
  13.             )
  14.         else:
  15.             intercept = torch.Tensor([0.])
  16.         coefficients = pyro.sample(
  17.             'coefficients',
  18.             dist.Normal(0., 1.).expand([1, self.n_input]).to_event(2)
  19.         )
  20.         sigma = pyro.sample("sigma", dist.Uniform(0., 10.))
  21.         mean = coefficients @ x.t() + intercept
  22.         with pyro.plate("data", x.shape[0]):
  23.             obs = pyro.sample("obs", dist.Normal(mean, sigma), obs=y)
  24.             return obs
  25.  
  26.     def forward(self, *args, **kwargs):
  27.         return self.model(*args, **kwargs)
  28.    
  29.     def guide(self, x, y=None):
  30.         # Make sure that guide returns a value
  31.         print('guide', x, y)
  32.         sigma_loc = pyro.param(
  33.             'sigma_loc',
  34.             torch.rand(1),
  35.             constraint=constraints.positive
  36.         )
  37.         loc = pyro.param(
  38.             'loc',
  39.             torch.zeros([1, self.n_input])
  40.         )
  41.         scale = pyro.param(
  42.             'scale',
  43.             torch.ones([1, self.n_input]),
  44.             constraint=constraints.positive
  45.         )
  46.         if self.has_intercept:
  47.             intercept_loc = pyro.param('intercept_loc', torch.zeros(1))
  48.             intercept_scale = pyro.param(
  49.                 'intercept_scale',
  50.                 torch.ones(1),
  51.                 constraint=constraints.positive
  52.             )
  53.             pyro.sample(
  54.                 'intercept',
  55.                 dist.Normal(intercept_loc, intercept_scale).to_event(1)
  56.             )
  57.        
  58.         pyro.sample('coefficients', dist.Normal(loc, scale).to_event(2))
  59.         return pyro.sample('sigma', dist.Normal(sigma_loc, torch.Tensor([0.05])))
  60.  
  61.    
  62. class StackedModel(nn.Module):
  63.    
  64.     def __init__(self, n_input, initial_multipliers, university_rating_column, university_ratings):
  65.         super().__init__()
  66.         self.logit = nn.Sigmoid()
  67.         self.linear_regression = BayesianLinearRegression(n_input)
  68.        
  69.         self.initial_multipliers = initial_multipliers
  70.         self.university_rating_column = university_rating_column
  71.         self.university_ratings = university_ratings
  72.        
  73.     def model(self, x, y=None):
  74.         self.linear_regression.model(x)  # Use `model` from `self.linear_regression`
  75.    
  76.     def forward(self, *args, **kwargs):
  77.         return self.model(*args, **kwargs)
  78.    
  79.     def guide(self, x, y=None):
  80.         self.linear_regression.guide(x)  # Use `guide` from `self.linear_regression`
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement