By: a guest on May 22nd, 2012
r""" Module to compute projections on the positive simplex or the L1-ball

A positive simplex is a set X = { \mathbf{x} | \sum_i x_i = s, x_i \geq 0 }

The (unit) L1-ball is the set X = { \mathbf{x} | || x ||_1 \leq 1 }

Adrien Gaidon - INRIA - 2011
"""


import numpy as np


def euclidean_proj_simplex(v, s=1):
    r""" Compute the Euclidean projection on a positive simplex

    Solves the optimisation problem (using the algorithm from [1]):

        min_w 0.5 * || w - v ||_2^2 , s.t. \sum_i w_i = s, w_i >= 0

    Parameters
    ----------
    v: (n,) numpy array,
       n-dimensional vector to project

    s: int or float, optional, default: 1,
       radius of the simplex (the value the components must sum to)

    Returns
    -------
    w: (n,) numpy array,
       Euclidean projection of v on the simplex

    Notes
    -----
    The complexity of this algorithm is in O(n log(n)) as it involves sorting v.
    Better alternatives exist for high-dimensional sparse vectors (cf. [1]).
    However, this implementation still easily scales to millions of dimensions.

    References
    ----------
    [1] Efficient Projections onto the l1-Ball for Learning in High Dimensions
        John Duchi, Shai Shalev-Shwartz, Yoram Singer, and Tushar Chandra.
        International Conference on Machine Learning (ICML 2008)
        http://www.cs.berkeley.edu/~jduchi/projects/DuchiSiShCh08.pdf
    """
    assert s > 0, "Radius s must be strictly positive (%g <= 0)" % s
    n, = v.shape  # will raise ValueError if v is not 1-D
    # check if we are already on the simplex
    if v.sum() == s and np.all(v >= 0):
        # best projection: itself!
        return v
    # get the array of cumulative sums of a sorted (decreasing) copy of v
    u = np.sort(v)[::-1]
    cssv = np.cumsum(u)
    # get the (0-based) index of the last > 0 component of the optimal
    # solution, i.e. rho + 1 components are strictly positive
    rho = np.nonzero(u * np.arange(1, n+1) > (cssv - s))[0][-1]
    # compute the Lagrange multiplier associated to the simplex constraint
    theta = (cssv[rho] - s) / (rho + 1.0)
    # compute the projection by thresholding v using theta
    w = (v - theta).clip(min=0)
    return w


def euclidean_proj_l1ball(v, s=1):
    """ Compute the Euclidean projection on an L1-ball

    Solves the optimisation problem (using the algorithm from [1]):

        min_w 0.5 * || w - v ||_2^2 , s.t. || w ||_1 <= s

    Parameters
    ----------
    v: (n,) numpy array,
       n-dimensional vector to project

    s: int or float, optional, default: 1,
       radius of the L1-ball

    Returns
    -------
    w: (n,) numpy array,
       Euclidean projection of v on the L1-ball of radius s

    Notes
    -----
    Solves the problem by a reduction to the positive simplex case.
    Reference [1] is the one listed in euclidean_proj_simplex.

    See also
    --------
    euclidean_proj_simplex
    """
    assert s > 0, "Radius s must be strictly positive (%g <= 0)" % s
    n, = v.shape  # will raise ValueError if v is not 1-D
    # compute the vector of absolute values
    u = np.abs(v)
    # check if v is already a solution
    if u.sum() <= s:
        # L1-norm is <= s
        return v
    # v is not already a solution: optimum lies on the boundary (norm == s)
    # project *u* on the simplex
    w = euclidean_proj_simplex(u, s=s)
    # compute the solution to the original problem on v
    w *= np.sign(v)
    return w
  99.         # L1-norm is <= s
  100.         return v
  101.     # v is not already a solution: optimum lies on the boundary (norm == s)
  102.     # project *u* on the simplex
  103.     w = euclidean_proj_simplex(u, s=s)
  104.     # compute the solution to the original problem on v
  105.     w *= np.sign(v)
  106.     return w