# piecewise regression guide

a guest
Mar 26th, 2020
75
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
# change_point_Guide: guide-side sampler for the change point of one piece.
#
# Reads n_iter = x.shape[0], registers a learnable probability vector named
# 'change_point_probs_<piece>' (constrained to the simplex) and draws
# 'change_point_<piece>' from a Categorical over it. The sample site is marked
# for parallel enumeration. The visible `else` branch skips sampling entirely
# and returns n_iter-1 instead.
#
# NOTE(review): this paste is incomplete. The listing numbers jump 2 -> 4 and
# 4 -> 7, so the `if` that pairs with the `else` below (and two further lines)
# are missing. How `mask` is used cannot be determined from what is left;
# recover the original before running this function.
# NOTE(review): the initial value torch.ones(n_iter) does not sum to 1 although
# the param is constrained to the simplex -- confirm whether it should be
# normalised (e.g. divided by n_iter).
1. def change_point_Guide(x, piece, mask = None):
2.     n_iter = x.shape[0]
4.         change_point_probs_prior = torch.ones(n_iter)
7.     else:
8.         return n_iter-1
9.     #change_point_probs_concentration = pyro.param('change_point_probs_concentration_{}'.format(piece), change_point_probs_prior, constraint=constraints.positive)
10.     #change_point_probs = pyro.sample('change_point_probs_{}'.format(piece), dist.Dirichlet(change_point_probs_concentration))
11.     change_point_probs = pyro.param('change_point_probs_{}'.format(piece), change_point_probs_prior, constraint=constraints.simplex)
12.     change_point = pyro.sample('change_point_{}'.format(piece), dist.Categorical(probs=change_point_probs),infer={'enumerate': 'parallel'})
13.     #change_point = pyro.sample('change_point_{}'.format(piece), dist.Categorical(probs=change_point_probs),infer={'enumerate': 'sequential'})
14.     #change_point = pyro.sample('change_point_{}'.format(piece), dist.Categorical(probs=change_point_probs))
15.     return change_point
16.
def sample_linear_regressionGuide():
    """Sample slope, intercept and noise std from their variational posteriors.

    For each of the three latents a learnable ``<name>_loc`` / ``<name>_scale``
    pair is registered with ``pyro.param``. Locations start from a
    ``torch.randn(1)`` draw, scales start at one and are constrained to be
    positive. Site and parameter names carry no piece suffix;
    ``piecewise_regressionGuide`` calls this function for piece 0 only.

    Returns:
        Tuple ``(slope, intercept, noise_std)``. ``slope`` and ``intercept``
        are drawn from a Normal, ``noise_std`` from a LogNormal (so it is
        always positive).
    """
    slope_loc = pyro.param('slope_loc', torch.randn(1))
    slope_scale = pyro.param(
        'slope_scale', torch.ones(1), constraint=constraints.positive)

    intercept_loc = pyro.param('intercept_loc', torch.randn(1))
    intercept_scale = pyro.param(
        'intercept_scale', torch.ones(1), constraint=constraints.positive)

    noise_std_loc = pyro.param('noise_std_loc', torch.randn(1))
    noise_std_scale = pyro.param(
        'noise_std_scale', torch.ones(1), constraint=constraints.positive)

    slope = pyro.sample('slope', dist.Normal(slope_loc, slope_scale))
    intercept = pyro.sample(
        'intercept', dist.Normal(intercept_loc, intercept_scale))
    # LogNormal keeps the sampled standard deviation strictly positive.
    noise_std = pyro.sample(
        'noise_std', dist.LogNormal(noise_std_loc, noise_std_scale))
    return slope, intercept, noise_std
31.
def sample_slope_and_noise_stdGuide(piece):
    """Sample the slope and noise std of one piece from the guide.

    Every parameter and sample site name is suffixed with ``_<piece>`` so
    that each piece owns its own variational parameters. Locations start
    from a ``torch.randn(1)`` draw; scales start at one and are constrained
    to be positive.

    Args:
        piece: Value formatted into the site names (the piece index in
            ``piecewise_regressionGuide``).

    Returns:
        Tuple ``(slope, noise_std)``; ``slope`` is drawn from a Normal and
        ``noise_std`` from a LogNormal.
    """
    slope_loc = pyro.param('slope_loc_{}'.format(piece), torch.randn(1))
    slope_scale = pyro.param(
        'slope_scale_{}'.format(piece),
        torch.ones(1),
        constraint=constraints.positive,
    )

    noise_std_loc = pyro.param(
        'noise_std_loc_{}'.format(piece), torch.randn(1))
    noise_std_scale = pyro.param(
        'noise_std_scale_{}'.format(piece),
        torch.ones(1),
        constraint=constraints.positive,
    )

    slope = pyro.sample(
        'slope_{}'.format(piece), dist.Normal(slope_loc, slope_scale))
    # LogNormal keeps the sampled standard deviation strictly positive.
    noise_std = pyro.sample(
        'noise_std_{}'.format(piece),
        dist.LogNormal(noise_std_loc, noise_std_scale),
    )
    return slope, noise_std
42.
def piecewise_regressionGuide(x, y, n_pieces=3):
    """Guide for a piecewise regression where the last piece has slope 0.

    Iterates over the pieces inside a sequential ``pyro.plate('pieces', ...)``
    and samples the latents of each one:

    * piece 0: slope, intercept and noise std, followed by a change point;
    * pieces ``1 .. n_pieces - 2``: slope and noise std, followed by a
      change point;
    * the last piece (when ``n_pieces > 1``): noise std only, since its
      slope is fixed at 0 and no change point follows it.

    Args:
        x: Object exposing ``.shape``; it is only forwarded to
            ``change_point_Guide``.
        y: Not used by the guide. Presumably kept so the signature mirrors
            the model's -- confirm against the model definition.
        n_pieces: Number of pieces.

    Returns:
        List with one entry per piece, in order: a
        ``(slope, intercept, noise_std)`` tuple for piece 0,
        ``(slope, noise_std)`` tuples for the intermediate pieces and a bare
        ``noise_std`` for the last piece.
    """
    lines = []
    for piece in pyro.plate('pieces', n_pieces):
        if piece == 0:
            lines.append(sample_linear_regressionGuide())
            # The sampled change point is not needed here; the call is made
            # for the sample site it registers.
            change_point_Guide(x, piece)
        elif piece < n_pieces - 1:
            slope, noise_std = sample_slope_and_noise_stdGuide(piece)
            lines.append((slope, noise_std))
            change_point_Guide(x, piece)
        else:
            # Last piece: the slope is fixed at 0, so only the noise level
            # has a variational posterior.
            noise_std_loc = pyro.param(
                'noise_std_loc_{}'.format(piece), torch.randn(1))
            noise_std_scale = pyro.param(
                'noise_std_scale_{}'.format(piece),
                torch.ones(1),
                constraint=constraints.positive,
            )
            noise_std = pyro.sample(
                'noise_std_{}'.format(piece),
                dist.LogNormal(noise_std_loc, noise_std_scale),
            )
            lines.append(noise_std)
    return lines
RAW Paste Data