Advertisement
Guest User

Untitled

a guest
Oct 17th, 2019
80
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 7.33 KB | None | 0 0
  # Linear Least Squares
def lstsq(a, b, cond=None, overwrite_a=False, overwrite_b=False,
          check_finite=True, lapack_driver=None):
    """
    Compute least-squares solution to equation Ax = b.

    Compute a vector x such that the 2-norm ``|b - A x|`` is minimized.

    Parameters
    ----------
    a : (M, N) array_like
        Left-hand side array.
    b : (M,) or (M, K) array_like
        Right-hand side array.
    cond : float, optional
        Cutoff for 'small' singular values; used to determine effective
        rank of `a`. Singular values smaller than
        ``cond * largest_singular_value`` are considered zero. If None
        (default), the machine epsilon of the selected LAPACK routine's
        dtype is used.
    overwrite_a : bool, optional
        Discard data in `a` (may enhance performance). Default is False.
    overwrite_b : bool, optional
        Discard data in `b` (may enhance performance). Default is False.
    check_finite : bool, optional
        Whether to check that the input matrices contain only finite numbers.
        Disabling may give a performance gain, but may result in problems
        (crashes, non-termination) if the inputs do contain infinities or
        NaNs.
    lapack_driver : str, optional
        Which LAPACK driver is used to solve the least-squares problem.
        Options are ``'gelsd'``, ``'gelsy'``, ``'gelss'``. If None (default),
        ``lstsq.default_lapack_driver`` is used, which is initially
        ``'gelsd'`` and is a good choice. However, ``'gelsy'`` can be
        slightly faster on many problems. ``'gelss'`` was used historically.
        It is generally slow but uses less memory.

        .. versionadded:: 0.17.0

    Returns
    -------
    x : (N,) or (N, K) ndarray
        Least-squares solution. It is 1-D when `b` is 1-D and 2-D when `b`
        is 2-D.
    residues : (K,) ndarray or float
        Square of the 2-norm for each column in ``b - a x``, if ``M > N`` and
        ``rank(a) == N`` (a scalar if `b` is 1-D). Otherwise a
        (0,)-shaped array is returned.
    rank : int
        Effective rank of `a`.
    s : (min(M, N),) ndarray or None
        Singular values of `a`. The condition number of `a` is
        ``abs(s[0] / s[-1])``.

    Raises
    ------
    LinAlgError
        If computation does not converge.

    ValueError
        When parameters are not compatible (`a` is not 2-D, `a` and `b`
        have a different number of rows, or `lapack_driver` is unknown),
        or when the LAPACK routine reports an illegal argument.

    See Also
    --------
    scipy.optimize.nnls : linear least squares with non-negativity constraint

    Notes
    -----
    When ``'gelsy'`` is used as a driver, `residues` is set to a (0,)-shaped
    array and `s` is always ``None``.

    Zero-sized problems (``M == 0`` or ``N == 0``) are answered without
    calling LAPACK: `x` is all zeros, `rank` is 0 and `s` is a (0,)-shaped
    array. In that case `residues` is the squared 2-norm of each column of
    `b` when ``N == 0`` and a (0,)-shaped array otherwise.

    Examples
    --------
    >>> import numpy as np
    >>> from scipy.linalg import lstsq
    >>> import matplotlib.pyplot as plt

    Suppose we have the following data:

    >>> x = np.array([1, 2.5, 3.5, 4, 5, 7, 8.5])
    >>> y = np.array([0.3, 1.1, 1.5, 2.0, 3.2, 6.6, 8.6])

    We want to fit a quadratic polynomial of the form ``y = a + b*x**2``
    to this data.  We first form the "design matrix" M, with a constant
    column of 1s and a column containing ``x**2``:

    >>> M = x[:, np.newaxis]**[0, 2]
    >>> M
    array([[  1.  ,   1.  ],
           [  1.  ,   6.25],
           [  1.  ,  12.25],
           [  1.  ,  16.  ],
           [  1.  ,  25.  ],
           [  1.  ,  49.  ],
           [  1.  ,  72.25]])

    We want to find the least-squares solution to ``M.dot(p) = y``,
    where ``p`` is a vector with length 2 that holds the parameters
    ``a`` and ``b``.

    >>> p, res, rnk, s = lstsq(M, y)
    >>> p
    array([ 0.20925829,  0.12013861])

    Plot the data and the fitted curve.

    >>> plt.plot(x, y, 'o', label='data')
    >>> xx = np.linspace(0, 9, 101)
    >>> yy = p[0] + p[1]*xx**2
    >>> plt.plot(xx, yy, label='least squares fit, $y = a + bx^2$')
    >>> plt.xlabel('x')
    >>> plt.ylabel('y')
    >>> plt.legend(framealpha=1, shadow=True)
    >>> plt.grid(alpha=0.25)
    >>> plt.show()

    """
    a1 = _asarray_validated(a, check_finite=check_finite)
    b1 = _asarray_validated(b, check_finite=check_finite)
    if len(a1.shape) != 2:
        raise ValueError('Input array a should be 2-D')
    m, n = a1.shape
    if len(b1.shape) == 2:
        nrhs = b1.shape[1]
    else:
        nrhs = 1
    if m != b1.shape[0]:
        raise ValueError('Shape mismatch: a and b should have the same number'
                         ' of rows ({} != {}).'.format(m, b1.shape[0]))
    if m == 0 or n == 0:  # Zero-sized problem, confuses LAPACK
        x = np.zeros((n,) + b1.shape[1:], dtype=np.common_type(a1, b1))
        if n == 0:
            # No unknowns: the residual is all of b.
            residues = np.linalg.norm(b1, axis=0)**2
        else:
            residues = np.empty((0,))
        return x, residues, 0, np.empty((0,))

    driver = lapack_driver
    if driver is None:
        driver = lstsq.default_lapack_driver
    if driver not in ('gelsd', 'gelsy', 'gelss'):
        raise ValueError('LAPACK driver "%s" is not found' % driver)

    lapack_func, lapack_lwork = get_lapack_funcs((driver,
                                                  '%s_lwork' % driver),
                                                 (a1, b1))
    real_data = lapack_func.dtype.kind == 'f'

    if m < n:
        # need to extend b matrix as it will be filled with
        # a larger solution matrix
        if len(b1.shape) == 2:
            b2 = np.zeros((n, nrhs), dtype=lapack_func.dtype)
            b2[:m, :] = b1
        else:
            b2 = np.zeros(n, dtype=lapack_func.dtype)
            b2[:m] = b1
        b1 = b2

    # Working copies made during validation/padding are ours to clobber.
    overwrite_a = overwrite_a or _datacopied(a1, a)
    overwrite_b = overwrite_b or _datacopied(b1, b)

    if cond is None:
        cond = np.finfo(lapack_func.dtype).eps

    if driver == 'gelsy':
        lwork = _compute_lwork(lapack_lwork, m, n, nrhs, cond)
        jptv = np.zeros((a1.shape[1], 1), dtype=np.int32)
        _v, x, _j, rank, info = lapack_func(a1, b1, jptv, cond,
                                            lwork, False, False)
        if info < 0:
            raise ValueError("illegal value in %d-th argument of internal "
                             "gelsy" % -info)
        if m > n:
            # Only the first n rows of the returned array hold the solution.
            x = x[:n]
        return x, np.array([], x.dtype), rank, None

    # SVD-based drivers: 'gelss' and 'gelsd'.
    if driver == 'gelss':
        lwork = _compute_lwork(lapack_lwork, m, n, nrhs, cond)
        _v, x, s, rank, _work, info = lapack_func(a1, b1, cond, lwork,
                                                  overwrite_a=overwrite_a,
                                                  overwrite_b=overwrite_b)
    else:  # driver == 'gelsd'
        # NOTE(review): the two trailing ``False`` arguments appear to be the
        # wrapper's overwrite_a/overwrite_b flags, i.e. this path never
        # overwrites its inputs -- confirm against the ?gelsd signature.
        if real_data:
            lwork, iwork = _compute_lwork(lapack_lwork, m, n, nrhs, cond)
            x, s, rank, info = lapack_func(a1, b1, lwork,
                                           iwork, cond, False, False)
        else:  # complex data
            lwork, rwork, iwork = _compute_lwork(lapack_lwork, m, n,
                                                 nrhs, cond)
            x, s, rank, info = lapack_func(a1, b1, lwork, rwork, iwork,
                                           cond, False, False)
    if info > 0:
        raise LinAlgError("SVD did not converge in Linear Least Squares")
    if info < 0:
        # Report the driver actually used; ``lapack_driver`` may be None.
        raise ValueError('illegal value in %d-th argument of internal %s'
                         % (-info, driver))
    resids = np.asarray([], dtype=x.dtype)
    if m > n:
        if rank == n:
            # For a full-rank overdetermined system the rows below the
            # solution carry the residual; sum their squared magnitudes.
            resids = np.sum(np.abs(x[n:])**2, axis=0)
        x = x[:n]
    return x, resids, rank, s


lstsq.default_lapack_driver = 'gelsd'
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement