Advertisement
Guest User

Untitled

a guest
Oct 7th, 2013
265
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 1.16 KB | None | 0 0
  1. import numpy as np
  2. import pymc as mc
  3. import matplotlib.pyplot as plt
  4.  
  5. n = 3
  6. ndata = 500
  7.  
  8. # simulated data
  9. v = np.random.randint( 0, n, ndata)
  10. data = (v==0)*(10+ 1*np.random.randn(ndata)) \
  11.        + (v==1)*(-10 + 2*np.random.randn(ndata)) \
  12.        + (v==2)*3*np.random.randn(ndata)
  13.  
  14. dd = mc.Dirichlet('dd', theta=(1,)*n)
  15. category = mc.Categorical('category', p=dd, size=ndata)
  16. precs = mc.Gamma('precs', alpha=2.5, beta=1, size=n)
  17. means = mc.Normal('means', [-5, 0, 5], 0.0001, size=n)
  18.  
  19. @mc.deterministic
  20. def mean(category=category, means=means):
  21.     return means[category]
  22.  
  23. @mc.deterministic
  24. def prec(category=category, precs=precs):
  25.     return precs[category]
  26.  
  27. obs = mc.Normal('obs', mean, prec, value=data, observed = True)
  28.  
  29. model = mc.Model({'dd': dd,
  30.                   'category': category,
  31.                   'precs': precs,
  32.                   'means': means,
  33.                   'obs': obs})
  34.  
  35. M = mc.MAP(model)
  36. M.fit()
  37. mcmc = mc.MCMC(model)
  38. mcmc.sample(100000,burn=0,thin=10)
  39.  
  40. tmeans = mcmc.trace('means').gettrace()
  41. tsd = mcmc.trace('precs').gettrace()**-.5
  42. plt.plot(tmeans)
  43. #plt.errorbar(range(len(tmeans)), tmeans, yerr=tsd)
  44. plt.show()
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement