Advertisement
Guest User

Untitled

a guest
Mar 28th, 2020
95
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 5.20 KB | None | 0 0
  1. from collections import defaultdict
  2. from typing import Dict
  3.  
  4. import matplotlib.pyplot as plt
  5. import pyro
  6. import pyro.distributions as dist
  7. import torch
  8. import numpy as np
  9.  
  10. from pyro.infer.mcmc import mcmc_kernel
  11. from pyro.infer.mcmc.util import initialize_model
  12.  
  13.  
  14. class MetropolisHastings(mcmc_kernel.MCMCKernel):
  15.     """Implementation of Metropolis Hastings sampler for MCMC."""
  16.  
  17.     def __init__(self, model, proposal_dist, priors):
  18.         """Inits MetropolisHastings.
  19.        :param model: Probabilistic model to estimate (likelihood).
  20.        :param proposal_dist: Distribution to generate next parameter value.
  21.        :param priors: Prior distribution over parameter space.
  22.        """
  23.         self.model = model
  24.         self._proposal_dist = proposal_dist
  25.         self._priors = priors
  26.  
  27.         self._model_args = None
  28.         self._model_kwargs = None
  29.  
  30.         self._initial_params = None
  31.  
  32.         self._step = 0
  33.         self._warmup_steps = None
  34.  
  35.         self._generated_samples = {
  36.             'accepted': defaultdict(list),
  37.             'rejected': defaultdict(list),
  38.             'counts': {
  39.                 'accepted': 0,
  40.                 'rejected': 0,
  41.             }
  42.         }
  43.  
  44.     @property
  45.     def initial_params(self):
  46.         return self._initial_params
  47.  
  48.     @initial_params.setter
  49.     def initial_params(self, params):
  50.         self._initial_params = params
  51.  
  52.     def logging(self):
  53.         """Provides statistics for progress bar."""
  54.         return {
  55.             '#accepted': self._generated_samples['counts']['accepted'],
  56.             '#rejected': self._generated_samples['counts']['rejected'],
  57.         }
  58.  
  59.     def setup(self, warmup_steps, *args, **kwargs):
  60.         """Sets up the sampler."""
  61.         self._warmup_steps = warmup_steps
  62.  
  63.         init_params, _, _, _ = initialize_model(
  64.             self.model, args, kwargs,
  65.         )
  66.         if self._initial_params is None:
  67.             self._initial_params = init_params
  68.  
  69.         self._model_args = args
  70.         self._model_kwargs = kwargs
  71.  
  72.     def _next_parameters_proposal(
  73.         self,
  74.         curr_params: Dict[str, torch.Tensor]
  75.     ) -> Dict[str, torch.Tensor]:
  76.         """Samples new parameters from the proposal distribution.
  77.        Use `pyro.sample` and include current step in the sample name,
  78.        eg. `c_0`, `c_1`...
  79.        :param curr_params: Current parameter values.
  80.        :return: New  parameter values.
  81.        """
  82.         new_params = {k: pyro.sample(f'{k}_{self._step}', self._proposal_dist(v)) for (k, v) in curr_params.items()}
  83.         return new_params
  84.  
  85.  
  86.     def _get_log_likelihood(self, params):
  87.         """Calculates the log-likelihood of the provided params.
  88.        Use `pyro.condition` and `pyro.poutine.trace`.
  89.        """
  90.         conditioned = pyro.condition(self.model, data=params)
  91.         traced = pyro.poutine.trace(conditioned).get_trace(*self._model_args, **self._model_kwargs)
  92.         return traced.log_prob_sum()
  93.  
  94.     def _get_log_priors(self, params):
  95.         """Calculates log-prob of prior distribution for the provided params."""
  96.         vals = [self._priors[pname](pval) for pname, pval in params.items()]
  97.         return torch.tensor(vals).log().sum()
  98.  
  99.     def _should_accept(self, curr_lp, new_lp):
  100.         """Decides whether to accept or reject the new params configuration."""
  101.         curr_post = self._get_log_likelihood(curr_lp) + self._get_log_priors(curr_lp)
  102.         new_post = self._get_log_likelihood(new_lp) + self._get_log_priors(new_lp)
  103.         ratio = (new_post - curr_post).exp().item()
  104.         p = min(1, ratio)
  105.         r = np.random.uniform()
  106.         return r < p
  107.  
  108.     def sample(self, params):
  109.         """Returns new params configuration given current.
  110.        The man body of the Metropolis Hastings sampling algorithm.
  111.        """
  112.         # Implement this
  113.         new_params = self._next_parameters_proposal(curr_params=params)
  114.         accept = self._should_accept(curr_lp=params, new_lp=new_params)
  115.  
  116.         if self._step > self._warmup_steps:
  117.             for pname in params.keys():
  118.                 k = 'accepted' if accept else 'rejected'
  119.                 self._generated_samples[k][pname].append(
  120.                     (self._step, new_params[pname].item())
  121.                 )
  122.                 self._generated_samples['counts'][k] += 1
  123.  
  124.         self._step += 1
  125.  
  126.         return new_params.copy() if accept else params.copy()
  127.  
  128.  
  129. def plot_accepted_rejected_samples(kernel):
  130.     """Plots samples generated by the `kernel` dung estimation."""
  131.     samples = kernel._generated_samples
  132.  
  133.     fig, axs = plt.subplots(ncols=2, figsize=(15, 5))
  134.  
  135.     for name, ax in zip(('c', 'std'), axs.ravel()):
  136.         x_rej, y_rej = zip(*samples['rejected'][name])
  137.         ax.plot(
  138.             x_rej, y_rej,
  139.             marker='o', linestyle='',
  140.             color='red', label='Rejected',
  141.             alpha=0.5
  142.         )
  143.  
  144.         x_acc, y_acc = zip(*samples['accepted'][name])
  145.         ax.plot(
  146.             x_acc, y_acc,
  147.             marker='X', linestyle='',
  148.             color='green', label='Accepted',
  149.             ms=10
  150.         )
  151.  
  152.         ax.set_xlabel('Iteration')
  153.         ax.set_ylabel('Value')
  154.         ax.set_title(name)
  155.         ax.legend()
  156.  
  157.     plt.show()
Advertisement
Add Comment
Please sign in to add a comment.
Advertisement