Advertisement
Guest User

Untitled

a guest
May 4th, 2015
219
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 2.77 KB | None | 0 0
  1. def linreg_elbo(X, t, max_iter=100, tol=1e-4, display=False):
  2.  
  3. N = X.shape[0]
  4. D = X.shape[1]
  5. if display:
  6. print('%5s %30s %15s %15s %15s %15s %15s' % ('iter', 'ELBO(alpha,beta,xi,mu,sigma)',
  7. 'norm_alpha_diff', 'norm_beta_diff', 'norm_xi_diff',
  8. 'norm_mu_diff', 'norm_sigma_diff'))
  9. L = []
  10. alpha_old = 1.
  11. beta_old = 1.
  12. xi_old = np.array([1.] * N)
  13. mu_old = np.array([1.] * D)
  14. sigma_old = 1. * eye(D)
  15. elbo_old = compute_elbo(X, t, alpha_old, beta_old, xi_old, mu_old, sigma_old)
  16. L.append(elbo_old)
  17. delta = 1e-8
  18. for it in range(max_iter):
  19. xi_new = np.zeros(N)
  20.  
  21. alpha_denominator = sum(xi_old)
  22. for n in range(N):
  23. x_n = X[n,].reshape(D, 1)
  24. x_nT = x_n.transpose()
  25. mux_n = sum(mu_old * X[n,])
  26. val_n = np.trace(sigma_old.dot(x_n).dot(x_nT)) + (mux_n - t[n]) ** 2
  27. xi_new[n] = max(sqrt(val_n), delta)
  28. alpha_denominator += val_n / xi_old[n]
  29. alpha_new = 2. * N / alpha_denominator
  30.  
  31. beta_new = D / (np.trace(sigma_old) + sum(mu_old ** 2))
  32.  
  33. mu_new = 1. * mu_old
  34. for j in range(D):
  35. mu_j_denominator = beta_new + sum(alpha_new * (X[:,j] ** 2) / xi_new)
  36. mu_j_nominator = sum(alpha_new * X[:,j] * (t + mu_old[j] * X[:,j] - X.dot(mu_old)) / xi_new)
  37. mu_new[j] = mu_j_nominator / mu_j_denominator
  38.  
  39. sigma_new = beta_new * eye(D)
  40. for n in range(N):
  41. x_n = X[n,].reshape(D, 1)
  42. x_nT = x_n.transpose()
  43. sigma_new += alpha_new * x_n.dot(x_nT) / xi_new[n]
  44. sigma_new = inv(sigma_new)
  45.  
  46. elbo_new = compute_elbo(X, t, alpha_new, beta_new, xi_new, mu_new, sigma_new)
  47. L.append(elbo_new)
  48.  
  49. alpha_diff = abs(alpha_new - alpha_old)
  50. beta_diff = abs(beta_new - beta_old)
  51. xi_diff = norm(xi_new - xi_old, np.inf)
  52. mu_diff = norm(mu_new - mu_old, np.inf)
  53. sigma_diff = norm(sigma_new - sigma_old, np.inf)
  54.  
  55. if display:
  56. print('%5d %30.15f %15.8f %15.8f %15.8f %15.8f %15.8f' %
  57. (it+1, L[it+1], alpha_diff, beta_diff, xi_diff, mu_diff, sigma_diff))
  58.  
  59. if abs(elbo_new - elbo_old) < tol:
  60. break
  61.  
  62. alpha_old = 1. * alpha_new
  63. beta_old = 1. * beta_new
  64. xi_old = 1. * xi_new
  65. mu_old = 1. * mu_new
  66. sigma_old = 1. * sigma_new
  67. elbo_old = 1. * elbo_new
  68.  
  69. result = {}
  70. result['alpha'] = alpha_new
  71. result['beta'] = beta_new
  72. result['mu'] = mu_new
  73. result['sigma'] = sigma_new
  74. result['L'] = L
  75. return result
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement