Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
def linreg_elbo(X, t, max_iter=100, tol=1e-4, display=False):
    """Fit a variational linear-regression model by iteratively raising its ELBO.

    Each iteration updates, in order: the per-object variational parameters
    ``xi``, the scalars ``alpha`` and ``beta``, and the Gaussian factor
    q(w) = N(mu, sigma).  The ELBO itself is evaluated by ``compute_elbo``
    (defined elsewhere) after every iteration.

    Parameters
    ----------
    X : ndarray of shape (N, D)
        Design matrix, one object per row.
    t : array-like of length N
        Regression targets.
    max_iter : int
        Maximum number of update sweeps.  ``0`` returns the initial point.
    tol : float
        Stop as soon as the absolute ELBO change of one sweep is below ``tol``.
    display : bool
        If true, print a per-iteration table of the ELBO and the inf-norm
        changes of every parameter.

    Returns
    -------
    dict
        ``'alpha'``, ``'beta'`` (scalars), ``'mu'`` (length D), ``'sigma'``
        (D x D) of the last iterate, and ``'L'``: the list of ELBO values,
        starting with the value at the initial point.
    """
    N, D = X.shape[0], X.shape[1]
    if display:
        print('%5s %30s %15s %15s %15s %15s %15s' % ('iter', 'ELBO(alpha,beta,xi,mu,sigma)',
                                                     'norm_alpha_diff', 'norm_beta_diff', 'norm_xi_diff',
                                                     'norm_mu_diff', 'norm_sigma_diff'))

    # Initial point of the iterations.
    alpha = 1.
    beta = 1.
    xi = np.ones(N)
    mu = np.ones(D)
    sigma = 1. * eye(D)
    elbo = compute_elbo(X, t, alpha, beta, xi, mu, sigma)
    L = [elbo]

    # Lower bound on xi so that the 1 / xi weights below stay finite.
    delta = 1e-8

    for it in range(max_iter):
        # For every n: x_n^T sigma x_n + (mu^T x_n - t_n)^2, computed with
        # the current q(w).
        val = np.einsum('nd,de,ne->n', X, sigma, X) + (X.dot(mu) - t) ** 2
        xi_new = np.maximum(np.sqrt(val), delta)

        # alpha is computed from the previous iteration's xi.
        alpha_new = 2. * N / (np.sum(xi) + np.sum(val / xi))
        beta_new = D / (np.trace(sigma) + np.sum(mu ** 2))

        # Precision of q(w): beta * I + alpha * sum_n x_n x_n^T / xi_n.
        precision = beta_new * eye(D) + alpha_new * X.T.dot(X / xi_new[:, None])
        sigma_new = inv(precision)
        # mu is the exact zero of the gradient (w.r.t. mu) of
        #   beta/2 * ||mu||^2 + alpha/2 * sum_n (x_n^T mu - t_n)^2 / xi_n,
        # i.e. precision @ mu = alpha * X^T (t / xi).  Solving it directly
        # avoids a coordinate-wise Jacobi sweep, which need not converge.
        mu_new = np.linalg.solve(precision, alpha_new * X.T.dot(t / xi_new))

        elbo_new = compute_elbo(X, t, alpha_new, beta_new, xi_new, mu_new, sigma_new)
        L.append(elbo_new)

        if display:
            alpha_diff = abs(alpha_new - alpha)
            beta_diff = abs(beta_new - beta)
            xi_diff = norm(xi_new - xi, np.inf)
            mu_diff = norm(mu_new - mu, np.inf)
            sigma_diff = norm(sigma_new - sigma, np.inf)
            print('%5d %30.15f %15.8f %15.8f %15.8f %15.8f %15.8f' %
                  (it + 1, elbo_new, alpha_diff, beta_diff, xi_diff, mu_diff, sigma_diff))

        converged = abs(elbo_new - elbo) < tol
        # Every *_new value is a freshly allocated object, so plain rebinding
        # cannot alias state between iterations.
        alpha, beta, xi, mu, sigma, elbo = (alpha_new, beta_new, xi_new,
                                            mu_new, sigma_new, elbo_new)
        if converged:
            break

    return {'alpha': alpha, 'beta': beta, 'mu': mu, 'sigma': sigma, 'L': L}
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement