Tranvick

pd_ipm

May 20th, 2015
307
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 4.22 KB | None | 0 0
  1. def pd_ipm(X, t, reg_coef, max_iter=100, tol_feas=1e-10, tol_gap=1e-6, tau_param=10, bt_params=np.array([1e-4,0.8]), display=False, bt_max_iter=1000, callback=None):
  2.     def calc_r(lambdas, mu, v1, v2):
  3.         return [
  4.             t * X.dot(X.T.dot(t * lambdas)) - 1. + mu * t - v1 + v2,
  5.             lambdas.dot(t),
  6.             -1. / tau + v1 * lambdas,
  7.             1. / tau + v2 * (lambdas - C)
  8.         ]
  9.    
  10.     def update_alpha(alpha):
  11.         return lambdas + alpha * d_lambda, mu + alpha * d_mu, v1 + alpha * d_v1, v2 + alpha * d_v2        
  12.    
  13.     def backtrack(lambdas, mu, v1, v2, r_dual, r_primal, r_center1, r_center2):
  14.         alpha = 1.
  15.         if np.any(d_lambda > 0):
  16.             alpha = min(alpha, min((C - lambdas[d_lambda > 0]) / d_lambda[d_lambda > 0]))
  17.         if np.any(d_lambda < 0):
  18.             alpha = min(alpha, min(-lambdas[d_lambda < 0] / d_lambda[d_lambda < 0]))
  19.         if np.any(d_v1 < 0):
  20.             alpha = min(alpha, min(-v1[d_v1 < 0] / d_v1[d_v1 < 0]))
  21.         if np.any(d_v2 < 0):
  22.             alpha = min(alpha, min(-v2[d_v2 < 0] / d_v2[d_v2 < 0]))
  23.         alpha *= 0.95
  24.         r_norm = max(norm(r_dual, np.inf), abs(r_primal), norm(r_center1, np.inf), norm(r_center2, np.inf))
  25.         i = 0
  26.         while i < bt_max_iter:
  27.             i += 1
  28.             lambdas_, mu_, v1_, v2_ = update_alpha(alpha)
  29.             r_dual_, r_primal_, r_center1_, r_center2_ = calc_r(lambdas_, mu_, v1_, v2_)
  30.            
  31.             r_norm_ = max(norm(r_dual_, np.inf), abs(r_primal_), norm(r_center1_, np.inf), norm(r_center2_, np.inf))
  32.             if r_norm_ < (1. - alpha * c1) * r_norm:
  33.                 return alpha
  34.             alpha *= beta
  35.         return alpha if r_norm_ < r_norm else None
  36.                
  37.     n, d = X.shape
  38.     C = reg_coef / n
  39.     c1, beta = bt_params
  40.    
  41.     lambdas = np.zeros(n)
  42.     lambdas[t < 0] = 1. / sum(t < 0)
  43.     lambdas[t > 0] = 1. / sum(t > 0)
  44.     lambdas *= C / max(lambdas) / 2.
  45.     v1 = 1. / lambdas
  46.     v2 = 1. / (C - lambdas)
  47.     mu = 0.
  48.     if display:
  49.         print ('%9s %15s %15s %15s %15s' % ('iteration', 'F', '||r_dual||', '|r_primal|', 'lambda^Tg'))
  50.    
  51.     iteration = 0
  52.     tau = 1.
  53.     while iteration < max_iter:
  54.         iteration += 1
  55.         tau = 2. * n * min(1. / tol_gap, tau_param / (v1.dot(lambdas) + v2.dot(C - lambdas)))
  56.         r_dual, r_primal, r_center1, r_center2 = calc_r(lambdas, mu, v1, v2)  
  57.         if display:
  58.             f = -np.sum(lambdas) + (lambdas * t).dot(X.dot(X.T.dot(lambdas * t))) / 2.
  59.             print ('%9d %15e %15e %15e %15e' % (iteration, f, norm(r_dual, np.inf), abs(r_primal), v1.dot(lambdas) + v2.dot(C - lambdas)))
  60.         if callback is not None:
  61.             callback(-np.sum(lambdas) + (lambdas * t).dot(X.dot(X.T.dot(lambdas * t))) / 2.)
  62.         if norm(r_dual, np.inf) <= tol_feas and abs(r_primal) <= tol_feas and v1.dot(lambdas) + v2.dot(C - lambdas) <= tol_gap:
  63.             break
  64.            
  65.         Ainv = 1. / (v1 / lambdas - v2 / (lambdas - C))
  66.         A = (X.T * t * Ainv * t).dot(X) + np.identity(d)
  67.  
  68.         y = -(r_dual + r_center1 / lambdas - r_center2 / (lambdas - C))
  69.         y0 = r_primal
  70.  
  71.         y1 = (X.T).dot(t * Ainv * y)
  72.         y2 = np.array(solve(A, y1)).reshape(-1)
  73.         Pinvy = Ainv * (y - t * X.dot(y2))
  74.  
  75.         t1 = (X.T).dot(t * Ainv * t)
  76.         t2 = np.array(solve(A, t1)).reshape(-1)
  77.         Pinvt = Ainv * (t - t * X.dot(t2))
  78.  
  79.         d_lambda = Pinvy - Pinvt * (t.dot(Pinvy) - y0) / t.dot(Pinvt)
  80.         d_mu = (t.dot(Pinvy) - y0) / t.dot(Pinvt)
  81.  
  82.         d_v1 = -v1 * d_lambda * 1. / lambdas + (-v1 * lambdas + 1. / tau) / lambdas
  83.         d_v2 = -v2 * d_lambda * 1. / (lambdas - C) - (v2 * (lambdas - C) + 1. / tau) / (lambdas - C)
  84.         alpha = backtrack(lambdas, mu, v1, v2, r_dual, r_primal, r_center1, r_center2)
  85.         if alpha is None:
  86.             if display:
  87.                 print("Bactracking did not found any suitable step")
  88.             break
  89.         lambdas, mu, v1, v2 = update_alpha(alpha)
  90.        
  91.     eps = tol_feas
  92.     w = X.T.dot(t * lambdas)
  93.     b = (sum([t_n - w.dot(x_n) for t_n, l_n, x_n in zip(t, lambdas, X) if l_n > eps and l_n < C - eps]) /
  94.          sum((lambdas > eps) & (lambdas < C - eps)))
  95.     return {'w': w, 'b': b, 'lambda': lambdas}
Advertisement
Add Comment
Please, Sign In to add comment