daily pastebin goal
90%
SHARE
TWEET

Untitled

a guest May 16th, 2018 112 Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
  1. def line_search(oracle, x_k, t, d_k, alpha_0=1., c1=1e-4):
  2.     """
  3.     Returns
  4.     -------
  5.     alpha : float or None if failure
  6.         Chosen step size
  7.     """
  8.  
  9.     phi = lambda alpha: oracle.f_t_directional(x_k, t, d_k, alpha)
  10.     derphi = lambda alpha: oracle.grad_directional(x_k, t, d_k, alpha)
  11.     phi_0 = phi(0.)
  12.     derphi_0 = derphi(0.)
  13.     alpha = alpha_0
  14.  
  15.     while phi(alpha) > phi_0 + c1 * alpha * derphi_0:
  16.         alpha = alpha / 2.
  17.  
  18.     return alpha
  19.  
  20.  
  21. def newton_barrier(oracle, common_x_0, t, tolerance=1e-5,
  22.                    max_iter=100, theta=0.99, c1=1e-4):
  23.  
  24.     common_x_k = common_x_0
  25.     start_grad_norm = norm(oracle.common_grad(common_x_0, t))
  26.     x_k, u_k = np.split(common_x_k, 2)
  27.     prev_x_k, prev_u_k = x_k, u_k
  28.  
  29.     for iter_num in range(max_iter + 1):
  30.  
  31.         if not np.any(np.isfinite(common_x_k)):
  32.             return prev_x_k, prev_u_k
  33.  
  34.         try:
  35.             common_grad = oracle.common_grad(common_x_k, t)
  36.             common_grad_norm = norm(common_grad)
  37.             if common_grad_norm ** 2 < tolerance * start_grad_norm ** 2:
  38.                 break
  39.  
  40.             common_hess = oracle.common_hess(common_x_k, t)
  41.  
  42.         except ValueError:
  43.             return prev_x_k, prev_u_k
  44.  
  45.         try:
  46.             if sp.issparse(common_hess):
  47.                 common_d_k = spla.spsolve(common_hess, -common_grad)
  48.             else:
  49.                 c, lower = cho_factor(common_hess)
  50.                 common_d_k = cho_solve((c, lower), -common_grad)
  51.         except LinAlgError:
  52.             return prev_x_k, prev_u_k
  53.  
  54.         d_x, d_u = np.split(common_d_k, 2)
  55.         x_k, u_k = np.split(common_x_k, 2)
  56.         prev_x_k = x_k
  57.         prev_u_k = u_k
  58.  
  59.         dx_du = d_x - d_u
  60.         _dx_du = -d_x - d_u
  61.  
  62.         pos_mask = dx_du > 0
  63.         neg_mask = _dx_du > 0
  64.  
  65.         alphas_first = theta * (u_k - x_k)[pos_mask] / dx_du[pos_mask]
  66.         alpha_second = theta * (u_k + x_k)[neg_mask] / _dx_du[neg_mask]
  67.  
  68.         alpha_max_to_use = np.concatenate(([1.], alphas_first, alpha_second))
  69.         alpha_max_to_use = alpha_max_to_use.min()
  70.  
  71.         alpha = line_search(oracle=oracle, x_k=common_x_k, t=t,
  72.                             d_k=common_d_k, alpha_0=alpha_max_to_use, c1=c1)
  73.  
  74.         common_x_k = common_x_k + alpha * common_d_k
  75.  
  76.     x_star, u_star = np.split(common_x_k, 2)
  77.     return x_star, u_star
  78.  
  79.  
  80. def barrier_method_lasso(A, b, reg_coef, x_0, u_0, tolerance=1e-5,
  81.                          tolerance_inner=1e-8, max_iter=100,
  82.                          max_iter_inner=20, t_0=1, gamma=10,
  83.                          c1=1e-4, lasso_duality_gap=None,
  84.                          trace=False, display=False):
  85.     """
  86.     Log-barrier method for solving the problem:
  87.         minimize    f(x, u) := 1/2 * ||Ax - b||_2^2 + reg_coef * \sum_i u_i
  88.         subject to  -u_i <= x_i <= u_i.
  89.  
  90.     The method constructs the following barrier-approximation of the problem:
  91.         phi_t(x, u) := t * f(x, u) - sum_i( log(u_i + x_i) + log(u_i - x_i) )
  92.     and minimize it as unconstrained problem by Newton's method.
  93.  
  94.     In the outer loop `t` is increased and we have a sequence of approximations
  95.         { phi_t(x, u) } and solutions { (x_t, u_t)^{*} } which converges in `t`
  96.     to the solution of the original problem.
  97.  
  98.     Parameters
  99.     ----------
  100.     A : np.array
  101.         Feature matrix for the regression problem.
  102.     b : np.array
  103.         Given vector of responses.
  104.     reg_coef : float
  105.         Regularization coefficient.
  106.     x_0 : np.array
  107.         Starting value for x in optimization algorithm.
  108.     u_0 : np.array
  109.         Starting value for u in optimization algorithm.
  110.     tolerance : float
  111.         Epsilon value for stopping criterion.
  112.     max_iter : int
  113.         Maximum number of iterations for interior point method.
  114.     max_iter_inner : int
  115.         Maximum number of iterations for inner Newton's method.
  116.     t_0 : float
  117.         Starting value for `t`.
  118.     gamma : float
  119.         Multiplier for changing `t` during the iterations:
  120.         t_{k + 1} = gamma * t_k.
  121.     c1 : float
  122.         Armijo's constant for line search in Newton's method.
  123.     lasso_duality_gap : callable object or None.
  124.         If calable the signature is lasso_duality_gap(x, Ax_b, ATAx_b, b, regcoef)
  125.         Returns duality gap value for esimating the progress of method.
  126.     trace : bool
  127.         If True, the progress information is appended into history dictionary
  128.         during training. Otherwise None is returned instead of history.
  129.     display : bool
  130.         If True, debug information is displayed during optimization.
  131.         Printing format is up to a student and is not checked in any way.
  132.  
  133.     Returns
  134.     -------
  135.     (x_star, u_star) : tuple of np.array
  136.         The point found by the optimization procedure.
  137.     message : string
  138.         "success" or the description of error:
  139.             - 'iterations_exceeded': if after max_iter iterations of the method x_k still doesn't satisfy
  140.                 the stopping criterion.
  141.             - 'computational_error': in case of getting Infinity or None value during the computations.
  142.     history : dictionary of lists or None
  143.         Dictionary containing the progress information or None if trace=False.
  144.         Dictionary has to be organized as follows:
  145.             - history['time'] : list of floats, containing time in seconds passed from the start of the method
  146.             - history['func'] : list of function values f(x_k) on every step of the algorithm
  147.             - history['duality_gap'] : list of duality gaps
  148.             - history['x'] : list of np.arrays, containing the trajectory of the algorithm. ONLY STORE IF x.size <= 2
  149.     """
  150.     history = defaultdict(list) if trace else None
  151.     converged = False
  152.     start_time = time()
  153.     x_k = x_0
  154.     u_k = u_0
  155.     t_k = t_0
  156.     theta = 0.99
  157.  
  158.     oracle = oracles.LassoBarrierOracul(A, b, reg_coef)
  159.  
  160.     for iter_num in range(max_iter + 1):
  161.  
  162.         gap = lasso_duality_gap(x_k, oracle.Ax_b(x_k),
  163.                                 oracle.ATAx_b(x_k), b, reg_coef)
  164.  
  165.         f_k = oracle.f(x_k, u_k)
  166.  
  167.         if trace:
  168.             history['time'].append(time() - start_time)
  169.             history['func'].append(f_k)
  170.             history['duality_gap'].append(gap)
  171.             if x_k.size <= 2:
  172.                 history['x'].append(x_k)
  173.  
  174.         if display:
  175.             print("f_k = {}".format(f_k))
  176.  
  177.         if gap < tolerance:
  178.             converged = True
  179.             break
  180.  
  181.         common_x_k = np.concatenate((x_k, u_k), axis=0)
  182.         x_k, u_k = newton_barrier(oracle, common_x_k, t_k,
  183.                                   tolerance=tolerance_inner,
  184.                                   max_iter=max_iter_inner, theta=theta, c1=c1)
  185.         t_k *= gamma
  186.  
  187.     if converged:
  188.         return (x_k, u_k), 'success', history
  189.     else:
  190.         return (x_k, u_k), 'iterations_exceeded', history
RAW Paste Data
Top