Advertisement
Not a member of Pastebin yet?
Sign up —
it unlocks many cool features!
def change_point_Guide(x, piece, mask=None):
    """Variational guide for the change point that ends piece ``piece``.

    Registers a learnable probability vector ``change_point_probs_{piece}``
    (constrained to the simplex) over the ``x.shape[0]`` candidate positions
    and samples ``change_point_{piece}`` from a Categorical with those
    probabilities, marked for parallel enumeration.

    Args:
        x: anything with a ``shape`` attribute; only ``x.shape[0]`` (the
            number of candidate positions) is used.
        piece: index of the regression piece; used to name the param and
            the sample site.
        mask: optional tensor selecting which positions receive initial
            probability mass (presumably a boolean tensor of length
            ``x.shape[0]`` — TODO confirm against callers). Only the initial
            value of the param depends on it.

    Returns:
        The sampled change point (a tensor). When ``mask`` is given but has
        no truthy entry, no param or sample site is registered and the plain
        int ``x.shape[0] - 1`` is returned instead.
    """
    n_iter = x.shape[0]
    if mask is None:
        change_point_probs_prior = torch.ones(n_iter)
    elif mask.any():
        change_point_probs_prior = mask.float()
    else:
        # No admissible position: fall back to the last index without
        # registering a sample site.
        return n_iter - 1
    # pyro.param maps the initial value into unconstrained space through the
    # inverse simplex (stick-breaking) transform, so the value must already
    # sum to 1. Un-normalized weights such as all-ones would give a degenerate
    # (or NaN, depending on the torch version) initialisation. Normalizing
    # yields a uniform distribution over the admissible positions.
    change_point_probs_prior = change_point_probs_prior / change_point_probs_prior.sum()
    change_point_probs = pyro.param(
        f'change_point_probs_{piece}',
        change_point_probs_prior,
        constraint=constraints.simplex,
    )
    # Marked for parallel enumeration; presumably run under an enumerating
    # ELBO (e.g. TraceEnum_ELBO) — confirm against the training loop.
    change_point = pyro.sample(
        f'change_point_{piece}',
        dist.Categorical(probs=change_point_probs),
        infer={'enumerate': 'parallel'},
    )
    return change_point
def sample_linear_regressionGuide():
    """Variational guide for the first (full linear) regression piece.

    Creates learnable location/scale params for the slope, the intercept and
    the noise std (scales constrained positive). It then draws ``slope`` and
    ``intercept`` from Normals and ``noise_std`` from a LogNormal.

    Returns:
        tuple: ``(slope, intercept, noise_std)``.
    """
    positive = constraints.positive

    # Param creation order is kept (slope, intercept, noise) so the random
    # initial draws from torch.randn happen in the same sequence.
    slope_q = (
        pyro.param('slope_loc', torch.randn(1)),
        pyro.param('slope_scale', torch.ones(1), constraint=positive),
    )
    intercept_q = (
        pyro.param('intercept_loc', torch.randn(1)),
        pyro.param('intercept_scale', torch.ones(1), constraint=positive),
    )
    noise_q = (
        pyro.param('noise_std_loc', torch.randn(1)),
        pyro.param('noise_std_scale', torch.ones(1), constraint=positive),
    )

    drawn_slope = pyro.sample('slope', dist.Normal(*slope_q))
    drawn_intercept = pyro.sample('intercept', dist.Normal(*intercept_q))
    drawn_noise = pyro.sample('noise_std', dist.LogNormal(*noise_q))
    return drawn_slope, drawn_intercept, drawn_noise
def sample_slope_and_noise_stdGuide(piece):
    """Variational guide for the slope and noise std of one piece.

    Uses per-piece params ``slope_loc_{piece}`` / ``slope_scale_{piece}`` and
    ``noise_std_loc_{piece}`` / ``noise_std_scale_{piece}`` (scales constrained
    positive). It samples ``slope_{piece}`` from a Normal and
    ``noise_std_{piece}`` from a LogNormal.

    Args:
        piece: index of the regression piece; used only to build names.

    Returns:
        tuple: ``(slope, noise_std)``.
    """
    positive = constraints.positive

    slope_q = (
        pyro.param('slope_loc_{}'.format(piece), torch.randn(1)),
        pyro.param('slope_scale_{}'.format(piece), torch.ones(1), constraint=positive),
    )
    noise_q = (
        pyro.param('noise_std_loc_{}'.format(piece), torch.randn(1)),
        pyro.param('noise_std_scale_{}'.format(piece), torch.ones(1), constraint=positive),
    )

    drawn_slope = pyro.sample('slope_{}'.format(piece), dist.Normal(*slope_q))
    drawn_noise = pyro.sample('noise_std_{}'.format(piece), dist.LogNormal(*noise_q))
    return drawn_slope, drawn_noise
def piecewise_regressionGuide(x, y, n_pieces=3):
    """
    Guide for piecewise regression where the last piece has slope 0.

    Piece 0 gets a slope, an intercept and a noise std. Each middle piece
    gets a slope and a noise std. The last piece gets only a noise std.
    Piece 0 and every middle piece are followed by a change point sampled
    via ``change_point_Guide``.

    Args:
        x: inputs, forwarded to ``change_point_Guide`` (which reads only
            ``x.shape[0]``).
        y: not used by the guide; presumably kept so the signature matches
            the model's — confirm against the model.
        n_pieces: number of regression pieces.

    Returns:
        list: one entry per piece — ``(slope, intercept, noise_std)`` for
        piece 0, ``(slope, noise_std)`` for middle pieces, and a bare
        ``noise_std`` tensor for the last piece.

    NOTE(review): with ``n_pieces == 1`` the single piece takes the
    ``piece == 0`` branch (slope, intercept and a change point), not the flat
    last-piece branch. Confirm this matches the model.
    """

    def _sample_flat_piece_noise_std(piece):
        # The last piece is flat, so only its noise std gets a variational
        # posterior (LogNormal with learnable loc and positive scale).
        noise_std_loc = pyro.param(f'noise_std_loc_{piece}', torch.randn(1))
        noise_std_scale = pyro.param(
            f'noise_std_scale_{piece}', torch.ones(1), constraint=constraints.positive
        )
        return pyro.sample(
            f'noise_std_{piece}', dist.LogNormal(noise_std_loc, noise_std_scale)
        )

    last_piece = n_pieces - 1
    lines = []
    for piece in pyro.plate('pieces', n_pieces):
        if piece == 0:
            lines.append(sample_linear_regressionGuide())
            change_point_Guide(x, piece)
        elif piece < last_piece:
            lines.append(sample_slope_and_noise_stdGuide(piece))
            change_point_Guide(x, piece)
        else:
            lines.append(_sample_flat_piece_noise_std(piece))
    return lines
Advertisement
Add Comment
Please sign in to add a comment.
Advertisement