Not a member of Pastebin yet? Sign up — it unlocks many cool features!
import numpy as np
import scipy.linalg
import scipy.optimize
import scipy.stats
# --- Problem setup: a ground-truth 4-d Gaussian and two projected views ---
np.random.seed(1)

# True mean and covariance we will try to recover from partial measurements.
cov0 = np.array([[1.19, -0.08, 0.00, -0.08],
                 [-0.08, 0.68, 0.02, -0.04],
                 [0.00, 0.02, 0.90, -0.05],
                 [-0.08, -0.04, -0.05, 0.65]])
mu0 = np.array([1, 2, 3, 4])
print('input', mu0)
print(cov0)

# Two independent samples of different sizes from the same distribution.
N1 = 100000
N2 = 120000
xs1 = np.random.multivariate_normal(mu0, cov0, size=N1)
xs2 = np.random.multivariate_normal(mu0, cov0, size=N2)

# Dataset 1 observes dimensions (0, 1); dataset 2 observes (1, 2, 3).
proj1 = np.array([[1, 0, 0, 0],
                  [0, 1, 0, 0]])
proj2 = np.array([[0, 1, 0, 0],
                  [0, 0, 1, 0],
                  [0, 0, 0, 1]])

# Projected ("measured") datasets.
xmeas1 = xs1 @ proj1.T
xmeas2 = xs2 @ proj2.T

ndim = 4  # dimensionality of the latent Gaussian
def like(p, getVecMat=False):
    """Negative log-posterior of the two projected datasets.

    Parameters
    ----------
    p : array-like, shape (2*ndim + ndim*(ndim-1)//2,)
        Packed parameters: the first ``ndim`` entries are the mean, the
        next ``ndim`` are the covariance diagonal, and the remainder
        fills the strict lower triangle (mirrored to keep it symmetric).
    getVecMat : bool
        If True, return the unpacked ``(mean, covar)`` pair instead of
        the objective value (only when ``covar`` is positive definite).

    Returns
    -------
    float or (ndarray, ndarray)
        Negative log-posterior — 1e30 when ``covar`` is not positive
        definite — or the unpacked mean/covariance when ``getVecMat``.
    """
    mean = p[:ndim]
    # Unpack the symmetric covariance matrix from the flat vector.
    covar = np.zeros((ndim, ndim))
    covar[np.arange(ndim), np.arange(ndim)] = p[ndim:2 * ndim]
    ix, iy = np.tril_indices(ndim, -1)
    covar[ix, iy] = p[2 * ndim:]
    covar[iy, ix] = p[2 * ndim:]
    # Reject non-positive-definite candidates with a huge penalty.
    # Fix: the original called scipy.linalg.eigh without importing
    # scipy.linalg (it only worked via scipy.stats' side effects);
    # np.linalg.eigvalsh is already in scope and skips eigenvectors.
    eigv = np.linalg.eigvalsh(covar)
    if np.any(eigv <= 0):
        return 1e30
    if getVecMat:
        return mean, covar
    # Each dataset only observes a linear projection of the Gaussian,
    # so project the candidate mean/covariance into each view.
    c1 = np.dot(proj1, np.dot(covar, proj1.T))
    c2 = np.dot(proj2, np.dot(covar, proj2.T))
    m1 = np.dot(mean, proj1.T)
    m2 = np.dot(mean, proj2.T)
    L1 = scipy.stats.multivariate_normal(m1, c1).logpdf(xmeas1)
    L2 = scipy.stats.multivariate_normal(m2, c2).logpdf(xmeas2)
    # Weak Wishart prior (a 1-D scale is treated as a diagonal matrix);
    # it regularizes covariance entries the data cannot identify.
    lprior = scipy.stats.wishart(df=4, scale=np.ones(4)).logpdf(covar)
    logp = L1.sum() + L2.sum() + lprior
    return -logp
# --- Fit: derivative-free Nelder-Mead first (robust to the hard 1e30
# penalty wall), then the default gradient-based method to polish. ---
x0 = np.r_[np.ones(4), np.ones(4), np.zeros(6)]
R = scipy.optimize.minimize(like, x0, method='Nelder-Mead')
R = scipy.optimize.minimize(like, R.x)

# Unpack and report the fitted mean and covariance.
retmean, retcov = like(R.x, True)
print('output', np.round(retmean, 3))
print(np.round(retcov, 3))
- $ python xx.py
- input [1 2 3 4]
- [[ 1.19 -0.08 0. -0.08]
- [-0.08 0.68 0.02 -0.04]
- [ 0. 0.02 0.9 -0.05]
- [-0.08 -0.04 -0.05 0.65]]
- output [1.001 1.998 3.002 4. ]
- [[ 1.188 -0.084 -0.299 -0.056]
- [-0.084 0.682 0.019 -0.041]
- [-0.299 0.019 0.901 -0.053]
- [-0.056 -0.041 -0.053 0.652]]
Advertisement
Add Comment
Please sign in to add a comment.