Guest User

piecewise regression guide

a guest
Mar 26th, 2020
75
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
  1. def change_point_Guide(x, piece, mask = None):
  2.     n_iter = x.shape[0]
  3.     if mask is None:
  4.         change_point_probs_prior = torch.ones(n_iter)
  5.     elif mask.any():
  6.         change_point_probs_prior = mask.float()
  7.     else:
  8.         return n_iter-1
  9.     #change_point_probs_concentration = pyro.param('change_point_probs_concentration_{}'.format(piece), change_point_probs_prior, constraint=constraints.positive)
  10.     #change_point_probs = pyro.sample('change_point_probs_{}'.format(piece), dist.Dirichlet(change_point_probs_concentration))
  11.     change_point_probs = pyro.param('change_point_probs_{}'.format(piece), change_point_probs_prior, constraint=constraints.simplex)
  12.     change_point = pyro.sample('change_point_{}'.format(piece), dist.Categorical(probs=change_point_probs),infer={'enumerate': 'parallel'})
  13.     #change_point = pyro.sample('change_point_{}'.format(piece), dist.Categorical(probs=change_point_probs),infer={'enumerate': 'sequential'})
  14.     #change_point = pyro.sample('change_point_{}'.format(piece), dist.Categorical(probs=change_point_probs))
  15.     return change_point
  16.  
  17. def sample_linear_regressionGuide():
  18.     slope_loc = pyro.param('slope_loc', torch.randn(1))
  19.     slope_scale = pyro.param('slope_scale', torch.ones(1), constraint=constraints.positive)
  20.    
  21.     intercept_loc = pyro.param('intercept_loc', torch.randn(1))
  22.     intercept_scale = pyro.param('intercept_scale', torch.ones(1), constraint=constraints.positive)
  23.    
  24.     noise_std_loc = pyro.param('noise_std_loc', torch.randn(1))
  25.     noise_std_scale = pyro.param('noise_std_scale', torch.ones(1), constraint=constraints.positive)
  26.  
  27.     slope = pyro.sample('slope', dist.Normal(slope_loc,slope_scale))
  28.     intercept = pyro.sample('intercept', dist.Normal(intercept_loc,intercept_scale))
  29.     noise_std = pyro.sample('noise_std', dist.LogNormal(noise_std_loc,noise_std_scale))
  30.     return slope, intercept, noise_std
  31.  
  32. def sample_slope_and_noise_stdGuide(piece):
  33.     slope_loc = pyro.param('slope_loc_{}'.format(piece), torch.randn(1))
  34.     slope_scale = pyro.param('slope_scale_{}'.format(piece), torch.ones(1), constraint=constraints.positive)
  35.    
  36.     noise_std_loc = pyro.param('noise_std_loc_{}'.format(piece), torch.randn(1))
  37.     noise_std_scale = pyro.param('noise_std_scale_{}'.format(piece), torch.ones(1), constraint=constraints.positive)
  38.  
  39.     slope = pyro.sample('slope_{}'.format(piece), dist.Normal(slope_loc,slope_scale))
  40.     noise_std = pyro.sample('noise_std_{}'.format(piece), dist.LogNormal(noise_std_loc,noise_std_scale))
  41.     return slope, noise_std
  42.  
  43. def piecewise_regressionGuide(x, y, n_pieces = 3):
  44.     """
  45.    piecewise regression where last piece has slope 0
  46.    """
  47.     N = x.shape[0]
  48.     lines = []
  49.     for piece in pyro.plate('pieces', n_pieces):
  50.         if piece == 0:
  51.             lines.append(sample_linear_regressionGuide())
  52.             change_point_Guide(x, piece)
  53.         elif piece < n_pieces-1:
  54.             slope, noise_std = sample_slope_and_noise_stdGuide(piece)
  55.             lines.append((slope,noise_std))
  56.             change_point_Guide(x, piece)
  57.         else:
  58.             noise_std_loc = pyro.param('noise_std_loc_{}'.format(piece), torch.randn(1))
  59.             noise_std_scale = pyro.param('noise_std_scale_{}'.format(piece), torch.ones(1), constraint=constraints.positive)
  60.             noise_std = pyro.sample('noise_std_{}'.format(piece), dist.LogNormal(noise_std_loc,noise_std_scale))
  61.             lines.append(noise_std)
  62.     return lines
RAW Paste Data