Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
import matplotlib.pyplot as plt
import numpy as np
import pymc as mc
# Number of mixture components and number of simulated observations.
n = 3
ndata = 500

# Simulated data: every observation gets a component label drawn uniformly
# from {0, ..., n - 1}; the boolean masks (v == k) then select which of the
# three normal draws contributes to each entry:
#   component 0: mean 10,  standard deviation 1
#   component 1: mean -10, standard deviation 2
#   component 2: mean 0,   standard deviation 3
v = np.random.randint(0, n, ndata)
data = (
    (v == 0) * (10 + 1 * np.random.randn(ndata))
    + (v == 1) * (-10 + 2 * np.random.randn(ndata))
    + (v == 2) * 3 * np.random.randn(ndata)
)
# Mixture weights: Dirichlet prior with a concentration of 1 for each of the
# n components.
dd = mc.Dirichlet('dd', theta=(1,) * n)

# Latent component label for each of the ndata observations, drawn with
# probabilities dd.
category = mc.Categorical('category', p=dd, size=ndata)

# One precision per component (the trace is later turned into a standard
# deviation with ``** -.5``).
precs = mc.Gamma('precs', alpha=2.5, beta=1, size=n)

# One mean per component. The positional arguments are the prior location
# [-5, 0, 5] and the prior precision 0.0001 (PyMC 2's Normal takes mu, tau).
# NOTE(review): the location list has exactly three entries, so it only
# matches size=n while n == 3.
means = mc.Normal('means', [-5, 0, 5], 0.0001, size=n)
@mc.deterministic
def mean(category=category, means=means):
    """Return the component mean selected for every observation.

    Indexes the per-component ``means`` with the per-observation
    ``category`` labels, giving one mean value per observation.
    """
    return means[category]
@mc.deterministic
def prec(category=category, precs=precs):
    """Return the component precision selected for every observation.

    Indexes the per-component ``precs`` with the per-observation
    ``category`` labels, giving one precision value per observation.
    """
    return precs[category]
# Likelihood: the simulated data are treated as observed normal draws whose
# per-observation mean and precision come from the assigned component.
obs = mc.Normal('obs', mean, prec, value=data, observed=True)

# Collect the named nodes into a single model object.
model = mc.Model({
    'dd': dd,
    'category': category,
    'precs': precs,
    'means': means,
    'obs': obs,
})
# Fit a maximum a posteriori estimate before sampling.
# NOTE(review): presumably this is meant to give the sampler a good starting
# point; confirm against the PyMC 2 MAP documentation.
M = mc.MAP(model)
M.fit()

# Run 100000 MCMC iterations, discarding none as burn-in and keeping every
# 10th sample.
mcmc = mc.MCMC(model)
mcmc.sample(100000, burn=0, thin=10)

# Sampled component means, and the sampled precisions converted to standard
# deviations (precision ** -0.5).
tmeans = mcmc.trace('means').gettrace()
tsd = mcmc.trace('precs').gettrace() ** -.5

# Plot the trace of the component means; tsd is only used by the disabled
# error-bar variant below.
plt.plot(tmeans)
#plt.errorbar(range(len(tmeans)), tmeans, yerr=tsd)
plt.show()
Advertisement
Add Comment
Please sign in to add a comment.
Advertisement