Guest User

Iteratively Reweighted Least Squares

a guest
Jan 22nd, 2014
177
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
MatLab 5.03 KB | None | 0 0
  function [a, rw, info] = irls(X, y, wf, s, r2, a0, varargin)
  % Iteratively Reweighted Least Squares (IRLS)
  %
  %   The iteratively reweighted least squares algorithm updates a by
  %   alternating between the following two steps:
  %
  %       1. solve the weighted least squares problem to get a
  %
  %           minimize sum_i w_i || x_i' * a - y_i ||^2 + (r2/2) * ||a||^2
  %
  %       2. update the weights as
  %
  %           w_i = wf(||x_i' * a - y_i|| / s)
  %
  %          Here, wf is a function that computes a weight from the
  %          scaled residual norm.
  %
  %   a = irls(X, y, wf, s, r2, a0, ...);
  %   [a, rw] = irls(X, y, wf, s, r2, a0, ...);
  %   [a, rw, info] = irls(X, y, wf, s, r2, a0, ...);
  %
  %       performs iteratively reweighted least squares to solve for the
  %       coefficient vector/matrix a.
  %
  %   opts = irls('options');
  %
  %       returns the default option struct (fields MaxIter, TolFun, TolX).
  %
  %       Input arguments:
  %       - X:        the design matrix (n x d). In particular, x_i is given
  %                   by the i-th row of X, i.e. X(i,:).
  %
  %       - y:        the response vector/matrix (n x q). y(i,:) corresponds
  %                   to X(i,:).
  %
  %       - wf:       the weighting function (a function handle). It must
  %                   accept a matrix of residual norms and return a matrix
  %                   of weights of the same size.
  %
  %       - s:        the positive scale parameter. Residuals are divided
  %                   by s before their norms are passed to wf.
  %
  %       - r2:       the non-negative regularization coefficient. It is
  %                   forwarded to the weighted solver llsq.
  %
  %       - a0:       the initial guess of a (d x q).
  %
  %       Output arguments:
  %       - a:        the resultant coefficient vector/matrix (d x q).
  %
  %       - rw:       the weights evaluated from the residuals of the final
  %                   a (n x 1).
  %
  %       - info:     a struct with procedural information. Its fields are:
  %                   FunValue    - sum_i rw_i * ||(x_i' * a - y_i) / s||^2 / 2
  %                   LastChange  - change of FunValue in the last iteration
  %                   LastMove    - norm of the change of a in the last
  %                                 iteration
  %                   IsConverged - whether the tolerances were met
  %                   NumIters    - number of iterations performed
  %                   LastChange and LastMove are NaN if no iteration ran.
  %
  %       Suppose X is a matrix of size n x d, and y is a matrix of size
  %       n x q. Then a (and a0) will be a matrix of size d x q.
  %
  %       Additional options that control the procedure can be given
  %       either as a single option struct or as name/value pairs.
  %
  %       - MaxIter:      the maximum number of iterations {100}
  %       - TolFun:       the tolerance of objective value change at
  %                       convergence {1e-8}
  %       - TolX:         the tolerance of change of a at convergence {1e-8}
  %       - Display:      the level of information displaying
  %                       {'none'|'proc'|'iter'}
  %       - Monitor:      the monitor that responds to procedural updates
  %
  
  %   History
  %   -------
  %       - Created by Dahua Lin, on Jan 6, 2011
  %
  
  %% default options
  
  default_options = struct('MaxIter', 100, 'TolFun', 1e-8, 'TolX', 1e-8);
  
  % irls('options') returns the defaults. This check must come before the
  % argument validation below, which would otherwise reject the char X.
  if nargin == 1 && ischar(X) && strcmpi(X, 'options')
      a = default_options;
      return;
  end
  
  %% verify input and check options
  
  if ~(isfloat(X) && ndims(X) == 2)
      error('irls:invalidarg', 'X should be a numeric matrix.');
  end
  [n, d] = size(X);
  
  if ~(isfloat(y) && ndims(y) == 2 && size(y, 1) == n)
      error('irls:invalidarg', 'y should be a numeric matrix with n rows.');
  end
  q = size(y, 2);
  
  if ~isa(wf, 'function_handle')
      error('irls:invalidarg', 'wf should be a function handle.');
  end
  
  if ~(isfloat(s) && isscalar(s) && s > 0)
      error('irls:invalidarg', 's should be a positive scalar.');
  end
  
  if ~(isfloat(r2) && isscalar(r2) && r2 >= 0)
      error('irls:invalidarg', 'r2 should be a non-negative scalar.');
  end
  
  if ~(isfloat(a0) && isequal(size(a0), [d q]))
      error('irls:invalidarg', 'a0 should be a numeric matrix of size d x q.');
  end
  
  if numel(varargin) == 1 && isstruct(varargin{1})
      % a complete option struct was supplied; use it as-is
      options = varargin{1};
  else
      options = smi_optimset(default_options, varargin{:});
  end
  
  omon_level = 0;
  if isfield(options, 'Monitor')
      omon = options.Monitor;
      omon_level = omon.level;
  end
  
  
  %% main
  
  a = a0;
  converged = false;
  it = 0;
  
  % Initialized so that info can still be built when no iteration runs
  % (e.g. MaxIter = 0).
  ch = NaN;
  nrm_da = NaN;
  
  if omon_level >= optim_mon.ProcLevel
      omon.on_proc_start();
  end
  
  % Initial weighting. The residual is scaled by 1/s exactly as inside the
  % loop, so the first weights and objective value use the same scale as
  % every later iteration.
  e = (1/s) * (X * a - y);
  [rn, rn2] = e_to_rn(e);
  rw = wf(rn);
  v = (rw' * rn2) / 2;
  
  
  while ~converged && it < options.MaxIter
  
      it = it + 1;
      if omon_level >= optim_mon.IterLevel
          omon.on_iter_start(it);
      end
  
      a_p = a;
      v_p = v;
  
      % re-solve a under the current weights
      a = llsq(X, y, rw, r2);
  
      % re-evaluate the weights rw and the objective v at the new a
      e = (1/s) * (X * a - y);
      [rn, rn2] = e_to_rn(e);
      rw = wf(rn);
      v = (rw' * rn2) / 2;
  
      % converged only when both the objective change and the move of a
      % fall below their tolerances
      ch = v - v_p;
      nrm_da = norm(a - a_p);
      converged = abs(ch) < options.TolFun && nrm_da < options.TolX;
  
      if omon_level >= optim_mon.IterLevel
          itstat = struct( ...
              'FunValue', v, ...
              'FunChange', ch, ...
              'Move', NaN, ...
              'MoveNorm', nrm_da, ...
              'IsConverged', converged);
          omon.on_iter_end(it, itstat);
      end
  end
  
  
  % info is the third output; it is also needed by the monitor below
  if nargout >= 3 || omon_level >= optim_mon.ProcLevel
      info = struct( ...
          'FunValue', v, ...
          'LastChange', ch, ...
          'LastMove', nrm_da, ...
          'IsConverged', converged, ...
          'NumIters', it);
  end
  
  if omon_level >= optim_mon.ProcLevel
      omon.on_proc_end(info);
  end
  181.  
  182.  
  function [rn, rn2] = e_to_rn(e)
  % Row-wise residual norms rn and squared norms rn2 of the residual matrix e.
  % Both outputs have one entry per row of e.
  
  if size(e, 2) ~= 1
      % several response columns: squared Euclidean norm of each row
      rn2 = dot(e, e, 2);
      rn = sqrt(rn2);
      return;
  end
  
  % single response column: each residual is a scalar
  rn2 = e .^ 2;
  rn = abs(e);
Add Comment
Please, Sign In to add comment