Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
- # Linear Least Squares
def lstsq(a, b, cond=None, overwrite_a=False, overwrite_b=False,
          check_finite=True, lapack_driver=None):
    """
    Compute least-squares solution to equation Ax = b.

    Compute a vector x such that the 2-norm ``|b - A x|`` is minimized.

    Parameters
    ----------
    a : (M, N) array_like
        Left hand side array
    b : (M,) or (M, K) array_like
        Right hand side array
    cond : float, optional
        Cutoff for 'small' singular values; used to determine effective
        rank of a. Singular values smaller than
        ``cond * largest_singular_value`` are considered zero. If None,
        the machine epsilon of the LAPACK working dtype is used.
    overwrite_a : bool, optional
        Discard data in `a` (may enhance performance). Default is False.
    overwrite_b : bool, optional
        Discard data in `b` (may enhance performance). Default is False.
    check_finite : bool, optional
        Whether to check that the input matrices contain only finite numbers.
        Disabling may give a performance gain, but may result in problems
        (crashes, non-termination) if the inputs do contain infinities or NaNs.
    lapack_driver : str, optional
        Which LAPACK driver is used to solve the least-squares problem.
        Options are ``'gelsd'``, ``'gelsy'``, ``'gelss'``. Default
        (``'gelsd'``) is a good choice.  However, ``'gelsy'`` can be slightly
        faster on many problems.  ``'gelss'`` was used historically.  It is
        generally slow but uses less memory. If None, the value of
        ``lstsq.default_lapack_driver`` is used.

        .. versionadded:: 0.17.0

    Returns
    -------
    x : (N,) or (N, K) ndarray
        Least-squares solution. Return shape matches shape of `b`.
    residues : (K,) ndarray or float
        Square of the 2-norm for each column in ``b - a x``, if ``M > N`` and
        ``rank(A) == n`` (returns a scalar if b is 1-D). Otherwise a
        (0,)-shaped array is returned.
    rank : int
        Effective rank of `a`.
    s : (min(M, N),) ndarray or None
        Singular values of `a`. The condition number of a is
        ``abs(s[0] / s[-1])``.

    Raises
    ------
    LinAlgError
        If computation does not converge.

    ValueError
        When parameters are not compatible.

    See Also
    --------
    scipy.optimize.nnls : linear least squares with non-negativity constraint

    Notes
    -----
    When ``'gelsy'`` is used as a driver, `residues` is set to a (0,)-shaped
    array and `s` is always ``None``.

    When `a` has zero rows or zero columns, LAPACK is not called: `x` is
    all zeros, `rank` is 0 and `s` is a (0,)-shaped array.

    Examples
    --------
    >>> import numpy as np
    >>> from scipy.linalg import lstsq
    >>> import matplotlib.pyplot as plt

    Suppose we have the following data:

    >>> x = np.array([1, 2.5, 3.5, 4, 5, 7, 8.5])
    >>> y = np.array([0.3, 1.1, 1.5, 2.0, 3.2, 6.6, 8.6])

    We want to fit a quadratic polynomial of the form ``y = a + b*x**2``
    to this data.  We first form the "design matrix" M, with a constant
    column of 1s and a column containing ``x**2``:

    >>> M = x[:, np.newaxis]**[0, 2]
    >>> M
    array([[  1.  ,   1.  ],
           [  1.  ,   6.25],
           [  1.  ,  12.25],
           [  1.  ,  16.  ],
           [  1.  ,  25.  ],
           [  1.  ,  49.  ],
           [  1.  ,  72.25]])

    We want to find the least-squares solution to ``M.dot(p) = y``,
    where ``p`` is a vector with length 2 that holds the parameters
    ``a`` and ``b``.

    >>> p, res, rnk, s = lstsq(M, y)
    >>> p
    array([ 0.20925829,  0.12013861])

    Plot the data and the fitted curve.

    >>> plt.plot(x, y, 'o', label='data')
    >>> xx = np.linspace(0, 9, 101)
    >>> yy = p[0] + p[1]*xx**2
    >>> plt.plot(xx, yy, label='least squares fit, $y = a + bx^2$')
    >>> plt.xlabel('x')
    >>> plt.ylabel('y')
    >>> plt.legend(framealpha=1, shadow=True)
    >>> plt.grid(alpha=0.25)
    >>> plt.show()

    """
    a1 = _asarray_validated(a, check_finite=check_finite)
    b1 = _asarray_validated(b, check_finite=check_finite)
    if len(a1.shape) != 2:
        raise ValueError('Input array a should be 2-D')
    m, n = a1.shape
    if len(b1.shape) == 2:
        nrhs = b1.shape[1]
    else:
        nrhs = 1
    if m != b1.shape[0]:
        raise ValueError('Shape mismatch: a and b should have the same number'
                         ' of rows ({} != {}).'.format(m, b1.shape[0]))
    if m == 0 or n == 0:  # Zero-sized problem, confuses LAPACK
        x = np.zeros((n,) + b1.shape[1:], dtype=np.common_type(a1, b1))
        if n == 0:
            # No unknowns: the residual is all of b.
            residues = np.linalg.norm(b1, axis=0)**2
        else:
            residues = np.empty((0,))
        return x, residues, 0, np.empty((0,))

    # Resolve the driver once; everything below (including error messages)
    # must use the resolved name, since `lapack_driver` may be None.
    driver = lapack_driver
    if driver is None:
        driver = lstsq.default_lapack_driver
    if driver not in ('gelsd', 'gelsy', 'gelss'):
        raise ValueError('LAPACK driver "%s" is not found' % driver)

    lapack_func, lapack_lwork = get_lapack_funcs((driver,
                                                  '%s_lwork' % driver),
                                                 (a1, b1))
    real_data = lapack_func.dtype.kind == 'f'

    if m < n:
        # need to extend b matrix as it will be filled with
        # a larger solution matrix
        if len(b1.shape) == 2:
            b2 = np.zeros((n, nrhs), dtype=lapack_func.dtype)
            b2[:m, :] = b1
        else:
            b2 = np.zeros(n, dtype=lapack_func.dtype)
            b2[:m] = b1
        b1 = b2

    # If validation already made a private copy, overwriting it is harmless.
    overwrite_a = overwrite_a or _datacopied(a1, a)
    overwrite_b = overwrite_b or _datacopied(b1, b)

    if cond is None:
        cond = np.finfo(lapack_func.dtype).eps

    if driver in ('gelss', 'gelsd'):
        if driver == 'gelss':
            lwork = _compute_lwork(lapack_lwork, m, n, nrhs, cond)
            _, x, s, rank, _, info = lapack_func(a1, b1, cond, lwork,
                                                 overwrite_a=overwrite_a,
                                                 overwrite_b=overwrite_b)
        elif real_data:
            lwork, iwork = _compute_lwork(lapack_lwork, m, n, nrhs, cond)
            x, s, rank, info = lapack_func(a1, b1, lwork,
                                           iwork, cond, False, False)
        else:  # complex data
            lwork, rwork, iwork = _compute_lwork(lapack_lwork, m, n,
                                                 nrhs, cond)
            x, s, rank, info = lapack_func(a1, b1, lwork, rwork, iwork,
                                           cond, False, False)
        if info > 0:
            raise LinAlgError("SVD did not converge in Linear Least Squares")
        if info < 0:
            # Report the driver actually used, not the (possibly None)
            # `lapack_driver` argument.
            raise ValueError('illegal value in %d-th argument of internal %s'
                             % (-info, driver))
        resids = np.asarray([], dtype=x.dtype)
        if m > n:
            if rank == n:
                # For a full-rank overdetermined system the trailing rows
                # of the LAPACK output hold the residual components.
                resids = np.sum(np.abs(x[n:])**2, axis=0)
            x = x[:n]
        return x, resids, rank, s

    # driver == 'gelsy'
    lwork = _compute_lwork(lapack_lwork, m, n, nrhs, cond)
    jptv = np.zeros((a1.shape[1], 1), dtype=np.int32)
    _, x, _, rank, info = lapack_func(a1, b1, jptv, cond,
                                      lwork, False, False)
    if info < 0:
        raise ValueError("illegal value in %d-th argument of internal "
                         "gelsy" % -info)
    if m > n:
        x = x[:n]
    return x, np.array([], x.dtype), rank, None


# Driver used when `lapack_driver` is None; read at call time, so it may be
# reassigned to change the default.
lstsq.default_lapack_driver = 'gelsd'
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement