Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
class BayesianLinearRegression(PyroModule):
    """Bayesian linear regression with an explicit, fully factorised guide.

    The model draws an optional intercept, a ``(1, n_input)`` coefficient
    matrix and a noise scale ``sigma`` from their priors and scores the
    targets under ``Normal(coefficients @ x.T + intercept, sigma)``.  The
    guide registers one learnable distribution per latent site, using the
    same site names as the model.

    Args:
        n_input: Number of feature columns expected in ``x``.
        has_intercept: When true (the default) an ``intercept`` site is
            sampled in both model and guide; otherwise the intercept is
            fixed at zero and no such site exists.
    """

    def __init__(self, n_input: int, *, has_intercept: bool = True):
        super().__init__()
        self.n_input = n_input
        self.has_intercept = has_intercept

    def model(self, x, y=None):
        """Sample the latent sites and the ``obs`` site.

        Args:
            x: 2-D tensor with one row per observation and ``n_input``
                columns (it is transposed and left-multiplied by the
                ``(1, n_input)`` coefficient matrix).
            y: Observed targets passed as ``obs=`` to the likelihood, or
                ``None`` to sample them.  Presumably one value per row of
                ``x`` -- confirm against the caller.

        Returns:
            The value of the ``obs`` sample site.
        """
        if self.has_intercept:
            intercept = pyro.sample(
                "intercept",
                dist.Normal(0.0, 10.0).expand([1]).to_event(1),
            )
        else:
            # Match the dtype/device of the data instead of hard-coding a
            # CPU float32 tensor.
            intercept = x.new_zeros(1)

        coefficients = pyro.sample(
            "coefficients",
            dist.Normal(0.0, 1.0).expand([1, self.n_input]).to_event(2),
        )
        sigma = pyro.sample("sigma", dist.Uniform(0.0, 10.0))

        # (1, n_input) @ (n_input, N) -> (1, N); the intercept broadcasts.
        mean = coefficients @ x.t() + intercept
        with pyro.plate("data", x.shape[0]):
            obs = pyro.sample("obs", dist.Normal(mean, sigma), obs=y)
        return obs

    def forward(self, *args, **kwargs):
        """Delegate to :meth:`model` so calling the module runs the model."""
        return self.model(*args, **kwargs)

    def guide(self, x, y=None):
        """Variational guide with one independent factor per latent site.

        Args:
            x: Unused; accepted so the signature matches :meth:`model`.
            y: Unused; accepted so the signature matches :meth:`model`.

        Returns:
            The sampled value of the ``sigma`` site.
        """
        sigma_loc = pyro.param(
            "sigma_loc",
            torch.rand(1),
            constraint=constraints.positive,
        )
        loc = pyro.param("loc", torch.zeros([1, self.n_input]))
        scale = pyro.param(
            "scale",
            torch.ones([1, self.n_input]),
            constraint=constraints.positive,
        )

        if self.has_intercept:
            intercept_loc = pyro.param("intercept_loc", torch.zeros(1))
            intercept_scale = pyro.param(
                "intercept_scale",
                torch.ones(1),
                constraint=constraints.positive,
            )
            pyro.sample(
                "intercept",
                dist.Normal(intercept_loc, intercept_scale).to_event(1),
            )

        pyro.sample("coefficients", dist.Normal(loc, scale).to_event(2))

        # ``sigma`` is used as a Normal scale in the model and has a
        # Uniform(0, 10) prior, so the guide must only propose positive
        # values.  A plain Normal around ``sigma_loc`` (initialised in
        # [0, 1)) can go negative; a LogNormal with median ``sigma_loc``
        # keeps the same parameter while staying strictly positive.
        return pyro.sample("sigma", dist.LogNormal(sigma_loc.log(), 0.05))
class StackedModel(nn.Module):
    """Wrapper that delegates its model and guide to a Bayesian regression.

    The constructor builds a ``BayesianLinearRegression`` over ``n_input``
    features and stores the remaining arguments as attributes.  In the code
    shown here ``logit``, ``initial_multipliers``,
    ``university_rating_column`` and ``university_ratings`` are only stored;
    ``model`` and ``guide`` do not read them.

    Args:
        n_input: Number of feature columns, forwarded to the regression.
        initial_multipliers: Stored as-is on the instance.
        university_rating_column: Stored as-is on the instance.
        university_ratings: Stored as-is on the instance.
    """

    def __init__(self, n_input, initial_multipliers, university_rating_column, university_ratings):
        super().__init__()
        # NOTE: despite the attribute name this is a sigmoid layer.
        self.logit = nn.Sigmoid()
        self.linear_regression = BayesianLinearRegression(n_input)
        self.initial_multipliers = initial_multipliers
        self.university_rating_column = university_rating_column
        self.university_ratings = university_ratings

    def model(self, x, y=None):
        """Run the inner regression model and return its result.

        ``y`` is forwarded so that supplied targets reach the inner model
        instead of being silently dropped.
        """
        return self.linear_regression.model(x, y)

    def forward(self, *args, **kwargs):
        """Delegate to :meth:`model` so calling the module runs the model."""
        return self.model(*args, **kwargs)

    def guide(self, x, y=None):
        """Run the inner regression guide with the same arguments as the model."""
        return self.linear_regression.guide(x, y)
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement