Tranvick

linreg_elbo

May 5th, 2015
359
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 1.19 KB | None | 0 0
  1. def linreg_elbo(X, t, max_iter=100, tol=1e-4, display=False):
  2.     delta = 1e-8
  3.     n, d = X.shape
  4.     alpha = beta = 1.
  5.     xi = np.ones(n)
  6.     mu = np.ones(d)
  7.     sigma = np.identity(d)
  8.     L = [compute_elbo(X, t, alpha, beta, xi, mu, sigma)]
  9.     if display:
  10.         print('%9s %15s' % ('iteration', 'F'))
  11.         print('%9d %15.11f' % (0, L[-1]))
  12.    
  13.     for iteration in xrange(1, max_iter + 1):
  14.         xi_old = xi
  15.         xi = np.array([sigma.dot(x.reshape((-1, 1))).dot(x.reshape(1, -1)).trace() + mu.dot(x) ** 2 - 2 * mu.dot(x) * t_n + t_n ** 2
  16.               for x, t_n in zip(X, t)])
  17.         xi = np.sqrt(xi)
  18.         xi[xi < delta] = delta
  19.        
  20.         alpha = 2 * n / (np.sum(xi_old) + np.sum(xi ** 2 / xi_old))
  21.         beta = d / (sigma.trace() + mu.dot(mu))
  22.        
  23.         sigma = inv(alpha * ((X).T / xi).dot(X) + beta * np.identity(d))
  24.         mu = alpha * sigma.dot(np.sum(X.T * t / xi, axis=1))
  25.         L.append(compute_elbo(X, t, alpha, beta, xi, mu, sigma))
  26.         if display:
  27.             print('%9d %15.11f' % (iteration, L[-1]))
  28.         if abs(L[-1] - L[-2]) <= tol:
  29.             break
  30.    
  31.     return {'alpha': alpha, 'beta': beta, 'mu': mu, 'sigma': sigma, 'L': L}
Advertisement
Add Comment
Please, Sign In to add comment