Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
- import numpy as np
- from scipy.stats import dirichlet
- from scipy.special import psi, polygamma
- eps = 1e-100  # tiny offset added to the sampled observations in the demo below so that no entry is exactly zero before logs are taken
- max_iter = 10  # iteration count used by parameter_estimation for both its outer loop and its inner Newton loop
def parameter_estimation(theta, old_alpha, n_iter=None):
    """
    Estimate a Dirichlet parameter given a set of multinomial parameters.

    The estimate is obtained by a fixed-point iteration: every outer step
    computes the target psi(alpha_k) = psi(sum(alpha)) + mean_n log(theta_nk)
    and then inverts the digamma function psi with Newton's method.

    The Dirichlet parameter alpha can be decomposed into precision s and
    mean m, where s = sum_k alpha_k, m = alpha / s, and alpha = s * m.

    Arguments
    theta     : a set of multinomials, N x K matrix (N = # of observations,
                K = dimension of the Dirichlet). Every entry must be
                strictly positive because its logarithm is taken.
    old_alpha : initial guess of alpha (K-vector). Only its sum (the
                precision s) influences the result, since alpha is
                recomputed from scratch in the first step.
    n_iter    : number of outer fixed-point steps, also used as the number
                of inner Newton steps per outer step. Defaults to the
                module-level max_iter.

    Returns
    The estimated alpha (K-vector). If n_iter is smaller than 1, the
    initial guess is returned unchanged, as a float array.
    """
    if n_iter is None:
        n_iter = max_iter

    # sufficient statistics: per-dimension mean of log(theta)
    log_p_bar = np.mean(np.log(theta), 0)

    # float copy, so the caller's array is never modified and a value is
    # always available to return even when no iteration runs
    alpha = np.array(old_alpha, dtype=float)

    for _ in range(n_iter):
        # value that psi(alpha_k) has to reach in this fixed-point step
        digamma_alpha = psi(np.sum(alpha)) + log_p_bar

        # starting point for the digamma inversion: exp(y) + 0.5, replaced
        # by 1 / (psi(1) - y) wherever that first guess falls below 0.6
        alpha = np.exp(digamma_alpha) + 0.5
        small = alpha < 0.6
        alpha[small] = 1.0 / (-digamma_alpha[small] + psi(1.))

        # Newton iterations solving psi(alpha) = digamma_alpha
        for _ in range(n_iter):
            alpha = alpha - (psi(alpha) - digamma_alpha) / polygamma(1, alpha)

    return alpha
if __name__ == '__main__':
    # Demo: sample from a random Dirichlet, estimate its parameter and
    # compare log-likelihoods under the true, initial and estimated alpha.
    N = 100  # number of observations
    K = 5  # dimension of Dirichlet

    # ground truth alpha: a gamma-distributed precision times a random mean
    _alpha = np.random.gamma(1, 1) * np.random.dirichlet([1.] * K)

    # draw N samples from Dir(_alpha); eps keeps entries away from exact zero
    obs = np.random.dirichlet(_alpha, size=N) + eps
    obs /= np.sum(obs, 1)[:, np.newaxis]  # renormalize for added eps

    # first guess on alpha
    initial_alpha = np.random.dirichlet([1.] * K) * np.random.gamma(1, 1)

    def total_logpdf(alpha):
        # summed Dirichlet log-density of all observations under alpha
        return sum(dirichlet.logpdf(row, alpha) for row in obs)

    g_ll = total_logpdf(_alpha)  # log-likelihood with ground truth parameter
    ll = total_logpdf(initial_alpha)  # log-likelihood with initial guess of alpha

    # single-argument print(...) produces the same output on Python 2 and 3
    print('likelihood p(obs|_alpha) = %.3f' % g_ll)
    print('likelihood p(obs|initial_alpha) = %.3f' % ll)

    # estimating
    est_alpha = parameter_estimation(obs, initial_alpha)

    ll = total_logpdf(est_alpha)  # log-likelihood with estimated parameter

    print('likelihood p(obs|est_alpha) = %.3f' % ll)
    print('likelihood difference = %.3f' % (g_ll - ll))
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement