"""
Copyright (C) 2010 David Fong and Michael Saunders

LSMR uses an iterative method.

07 Jun 2010: Documentation updated
03 Jun 2010: First release version in Python

David Chin-lung Fong            clfong@stanford.edu
Institute for Computational and Mathematical Engineering
Stanford University

Michael Saunders                saunders@stanford.edu
Systems Optimization Laboratory
Dept of MS&E, Stanford University.
"""

__all__ = ['lsmr']

# numpy.infty was removed in NumPy 2.0; numpy.inf is the same value.
from numpy import zeros, inf as infty, atleast_1d, result_type
from numpy.linalg import norm
from math import sqrt
from scipy.sparse.linalg.interface import aslinearoperator

from .lsqr import _sym_ortho


def lsmr(A, b, damp=0.0, atol=1e-6, btol=1e-6, conlim=1e8,
         maxiter=None, show=False, x0=None):
    """Iterative solver for least-squares problems.

    lsmr solves the system of linear equations ``Ax = b``. If the system
    is inconsistent, it solves the least-squares problem ``min ||b - Ax||_2``.
    ``A`` is a rectangular matrix of dimension m-by-n, where all cases are
    allowed: m = n, m > n, or m < n. ``b`` is a vector of length m.
    The matrix A may be dense or sparse (usually sparse).

    Parameters
    ----------
    A : {matrix, sparse matrix, ndarray, LinearOperator}
        Matrix A in the linear system.
        Alternatively, ``A`` can be a linear operator which can
        produce ``Ax`` and ``A^H x`` using, e.g.,
        ``scipy.sparse.linalg.LinearOperator``.
    b : array_like, shape (m,)
        Vector ``b`` in the linear system.
    damp : float, optional
        Damping factor for regularized least-squares. `lsmr` solves
        the regularized least-squares problem::

         min ||(b) - (  A   )x||
             ||(0)   (damp*I) ||_2

        where damp is a scalar. If damp is 0 (the default), the system
        is solved without regularization.
    atol, btol : float, optional
        Stopping tolerances. `lsmr` continues iterations until a
        certain backward error estimate is smaller than some quantity
        depending on atol and btol. Let ``r = b - Ax`` be the
        residual vector for the current approximate solution ``x``.
        If ``Ax = b`` seems to be consistent, `lsmr` terminates
        when ``norm(r) <= atol * norm(A) * norm(x) + btol * norm(b)``.
        Otherwise, `lsmr` terminates when ``norm(A^H r) <=
        atol * norm(A) * norm(r)``. If both tolerances are 1.0e-6 (say),
        the final ``norm(r)`` should be accurate to about 6
        digits. (The final ``x`` will usually have fewer correct digits,
        depending on ``cond(A)`` and the size of `damp`.) The default
        value of both `atol` and `btol` is 1.0e-6.
        Ideally, they should be estimates of the relative error in the
        entries of ``A`` and ``b`` respectively. For example, if the entries
        of ``A`` have 7 correct digits, set ``atol = 1e-7``. This prevents
        the algorithm from doing unnecessary work beyond the
        uncertainty of the input data.
    conlim : float, optional
        `lsmr` terminates if an estimate of ``cond(A)`` exceeds
        `conlim`. For compatible systems ``Ax = b``, `conlim` could be
        as large as 1.0e+12 (say). For least-squares problems,
        `conlim` should be less than 1.0e+8. The default value is
        1e+8. Maximum precision can be obtained by
        setting ``atol = btol = conlim = 0``, but the number of
        iterations may then be excessive.
    maxiter : int, optional
        `lsmr` terminates if the number of iterations reaches
        `maxiter`. The default is ``maxiter = min(m, n)``. For
        ill-conditioned systems, a larger value of `maxiter` may be
        needed.
    show : bool, optional
        Print iteration logs if ``show=True``.
    x0 : array_like, shape (n,), optional
        Initial guess of ``x``. If None, zeros are used.

        .. versionadded:: 1.0.0

    Returns
    -------
    x : ndarray of float
        Least-squares solution.
    istop : int
        istop gives the reason for stopping::

          istop   = 0 means x=0 is a solution.  If x0 was given, then x=x0 is a
                      solution.
                  = 1 means x is an approximate solution to A*x = b,
                      according to atol and btol.
                  = 2 means x approximately solves the least-squares problem
                      according to atol.
                  = 3 means COND(A) seems to be greater than CONLIM.
                  = 4 is the same as 1 with atol = btol = eps (machine
                      precision)
                  = 5 is the same as 2 with atol = eps.
                  = 6 is the same as 3 with CONLIM = 1/eps.
                  = 7 means ITN reached maxiter before the other stopping
                      conditions were satisfied.

    itn : int
        Number of iterations used.
    normr : float
        ``norm(b-Ax)``
    normar : float
        ``norm(A^H (b - Ax))``
    norma : float
        ``norm(A)``
    conda : float
        Estimate of the condition number of ``A``.
    normx : float
        ``norm(x)``

    Notes
    -----

    .. versionadded:: 0.11.0

    References
    ----------
    .. [1] D. C.-L. Fong and M. A. Saunders,
           "LSMR: An iterative algorithm for sparse least-squares problems",
           SIAM J. Sci. Comput., vol. 33, pp. 2950-2971, 2011.
           :arxiv:`1006.0758`
    .. [2] LSMR Software, https://web.stanford.edu/group/SOL/software/lsmr/

    Examples
    --------
    >>> import numpy as np
    >>> from scipy.sparse import csc_matrix
    >>> from scipy.sparse.linalg import lsmr
    >>> A = csc_matrix([[1., 0.], [1., 1.], [0., 1.]], dtype=float)

    The first example has the trivial solution `[0, 0]`:

    >>> b = np.array([0., 0., 0.], dtype=float)
    >>> x, istop, itn, normr = lsmr(A, b)[:4]
    >>> istop
    0
    >>> x
    array([ 0., 0.])

    The returned stopping code `istop=0` indicates that a vector of zeros was
    found as a solution. The returned solution `x` indeed contains `[0., 0.]`.
    The next example has a non-trivial solution:

    >>> b = np.array([1., 0., -1.], dtype=float)
    >>> x, istop, itn, normr = lsmr(A, b)[:4]
    >>> istop
    1
    >>> x
    array([ 1., -1.])
    >>> itn
    1
    >>> normr
    4.440892098500627e-16

    As indicated by `istop=1`, `lsmr` found a solution obeying the tolerance
    limits. The given solution `[1., -1.]` solves the equation. The
    remaining return values report the number of iterations (`itn=1`) and
    the residual norm `normr`, i.e., the remaining difference between the
    left and right sides of the solved equation.
    The final example demonstrates the behavior in the case where there is no
    solution for the equation:

    >>> b = np.array([1., 0.01, -1.], dtype=float)
    >>> x, istop, itn, normr = lsmr(A, b)[:4]
    >>> istop
    2
    >>> x
    array([ 1.00333333, -0.99666667])
    >>> A.dot(x)-b
    array([ 0.00333333, -0.00333333, 0.00333333])
    >>> normr
    0.005773502691896255

    `istop=2` indicates that the system is inconsistent and thus `x` is an
    approximate solution to the corresponding least-squares problem. `normr`
    contains the norm of the minimal residual that was found.
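
    Two further usage sketches, which are illustrations rather than part of
    the original examples. ``A`` may be passed as a `LinearOperator` when
    only the products ``A v`` and ``A^H u`` are available; the dense array
    ``M`` below is an assumed stand-in for such an operator:

    >>> import numpy as np
    >>> from scipy.sparse.linalg import LinearOperator, lsmr
    >>> M = np.array([[1., 0.], [1., 1.], [0., 1.]])
    >>> op = LinearOperator(M.shape, matvec=lambda v: M @ v,
    ...                     rmatvec=lambda u: M.T @ u, dtype=float)
    >>> x_op = lsmr(op, np.array([1., 0., -1.]))[0]
    >>> np.allclose(x_op, [1., -1.])
    True

    With ``damp > 0``, the result should agree with the solution of the
    regularized normal equations ``(A^T A + damp**2 I) x = A^T b``:

    >>> b = np.array([1., 0.01, -1.])
    >>> x_damp = lsmr(M, b, damp=1.0)[0]
    >>> x_ref = np.linalg.solve(M.T @ M + np.eye(2), M.T @ b)
    >>> np.allclose(x_damp, x_ref)
    True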
"""

    A = aslinearoperator(A)
    b = atleast_1d(b)
    if b.ndim > 1:
        b = b.squeeze()

    msg = ('The exact solution is x = 0, or x = x0, if x0 was given ',
           'Ax - b is small enough, given atol, btol ',
           'The least-squares solution is good enough, given atol ',
           'The estimate of cond(Abar) has exceeded conlim ',
           'Ax - b is small enough for this machine ',
           'The least-squares solution is good enough for this machine',
           'Cond(Abar) seems to be too large for this machine ',
           'The iteration limit has been reached ')

    hdg1 = ' itn x(1) norm r norm Ar'
    hdg2 = ' compatible LS norm A cond A'
    pfreq = 20   # print frequency (for repeating the heading)
    pcount = 0   # print counter

    m, n = A.shape

    # stores the num of singular values
    minDim = min([m, n])

    if maxiter is None:
        maxiter = minDim

    if x0 is None:
        dtype = result_type(A, b, float)
    else:
        dtype = result_type(A, b, x0, float)

    if show:
        print(' ')
        print('LSMR Least-squares solution of Ax = b\n')
        print(f'The matrix A has {m} rows and {n} columns')
        print('damp = %20.14e\n' % (damp))
        print('atol = %8.2e conlim = %8.2e\n' % (atol, conlim))
        print('btol = %8.2e maxiter = %8g\n' % (btol, maxiter))
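
    u = b
    normb = norm(b)
    if x0 is None:
        x = zeros(n, dtype)
        beta = normb.copy()
    else:
        # Work on a copy in the working dtype: x is updated in place in the
        # main loop, which would otherwise modify the caller's x0.
        x = atleast_1d(x0).astype(dtype)
        u = u - A.matvec(x)
        beta = norm(u)

    if beta > 0:
        u = (1 / beta) * u
        v = A.rmatvec(u)
        alpha = norm(v)
    else:
        v = zeros(n, dtype)
        alpha = 0

    if alpha > 0:
        v = (1 / alpha) * v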

    # Initialize variables for 1st iteration.

    itn = 0
    zetabar = alpha * beta
    alphabar = alpha
    rho = 1
    rhobar = 1
    cbar = 1
    sbar = 0

    h = v.copy()
    hbar = zeros(n, dtype)

    # Initialize variables for estimation of ||r||.

    betadd = beta
    betad = 0
    rhodold = 1
    tautildeold = 0
    thetatilde = 0
    zeta = 0
    d = 0

    # Initialize variables for estimation of ||A|| and cond(A)

    normA2 = alpha * alpha
    maxrbar = 0
    minrbar = 1e+100
    normA = sqrt(normA2)
    condA = 1
    normx = 0

    # Items for use in stopping rules, normb set earlier
    istop = 0
    ctol = 0
    if conlim > 0:
        ctol = 1 / conlim
    normr = beta

    # Reverse the order here from the original matlab code because
    # there was an error on return when arnorm==0
    normar = alpha * beta
    if normar == 0:
        if show:
            print(msg[0])
        return x, istop, itn, normr, normar, normA, condA, normx

    if show:
        print(' ')
        print(hdg1, hdg2)
        test1 = 1
        test2 = alpha / beta
        str1 = '%6g %12.5e' % (itn, x[0])
        str2 = ' %10.3e %10.3e' % (normr, normar)
        str3 = ' %8.1e %8.1e' % (test1, test2)
        print(''.join([str1, str2, str3]))

    # Main iteration loop.
    while itn < maxiter:
        itn = itn + 1

        # Perform the next step of the bidiagonalization to obtain the
        # next  beta, u, alpha, v.  These satisfy the relations
        #         beta*u  =  A*v   -  alpha*u,
        #        alpha*v  =  A'*u  -  beta*v.

        u *= -alpha
        u += A.matvec(v)
        beta = norm(u)

        if beta > 0:
            u *= (1 / beta)
            v *= -beta
            v += A.rmatvec(u)
            alpha = norm(v)
            if alpha > 0:
                v *= (1 / alpha)

        # At this point, beta = beta_{k+1}, alpha = alpha_{k+1}.

        # Construct rotation Qhat_{k,2k+1}.

        chat, shat, alphahat = _sym_ortho(alphabar, damp)

        # Use a plane rotation (Q_i) to turn B_i to R_i

        rhoold = rho
        c, s, rho = _sym_ortho(alphahat, beta)
        thetanew = s*alpha
        alphabar = c*alpha

        # Use a plane rotation (Qbar_i) to turn R_i^T to R_i^bar

        rhobarold = rhobar
        zetaold = zeta
        thetabar = sbar * rho
        rhotemp = cbar * rho
        cbar, sbar, rhobar = _sym_ortho(cbar * rho, thetanew)
        zeta = cbar * zetabar
        zetabar = - sbar * zetabar

        # Update h, h_hat, x.

        hbar *= - (thetabar * rho / (rhoold * rhobarold))
        hbar += h
        x += (zeta / (rho * rhobar)) * hbar
        h *= - (thetanew / rho)
        h += v

        # Estimate of ||r||.

        # Apply rotation Qhat_{k,2k+1}.
        betaacute = chat * betadd
        betacheck = -shat * betadd

        # Apply rotation Q_{k,k+1}.
        betahat = c * betaacute
        betadd = -s * betaacute

        # Apply rotation Qtilde_{k-1}.
        # betad = betad_{k-1} here.

        thetatildeold = thetatilde
        ctildeold, stildeold, rhotildeold = _sym_ortho(rhodold, thetabar)
        thetatilde = stildeold * rhobar
        rhodold = ctildeold * rhobar
        betad = - stildeold * betad + ctildeold * betahat

        # betad   = betad_k here.
        # rhodold = rhod_k  here.

        tautildeold = (zetaold - thetatildeold * tautildeold) / rhotildeold
        taud = (zeta - thetatilde * tautildeold) / rhodold
        d = d + betacheck * betacheck
        normr = sqrt(d + (betad - taud)**2 + betadd * betadd)

        # Estimate ||A||.
        normA2 = normA2 + beta * beta
        normA = sqrt(normA2)
        normA2 = normA2 + alpha * alpha

        # Estimate cond(A).
        maxrbar = max(maxrbar, rhobarold)
        if itn > 1:
            minrbar = min(minrbar, rhobarold)
        condA = max(maxrbar, rhotemp) / min(minrbar, rhotemp)

        # Test for convergence.

        # Compute norms for convergence testing.
        normar = abs(zetabar)
        normx = norm(x)

        # Now use these norms to estimate certain other quantities,
        # some of which will be small near a solution.

        test1 = normr / normb
        if (normA * normr) != 0:
            test2 = normar / (normA * normr)
        else:
            test2 = infty
        test3 = 1 / condA
        t1 = test1 / (1 + normA * normx / normb)
        rtol = btol + atol * normA * normx / normb

        # The following tests guard against extremely small values of
        # atol, btol or ctol.  (The user may have set any or all of
        # the parameters atol, btol, conlim to 0.)
        # The effect is equivalent to the normal tests using
        # atol = eps,  btol = eps,  conlim = 1/eps.

        if itn >= maxiter:
            istop = 7
        if 1 + test3 <= 1:
            istop = 6
        if 1 + test2 <= 1:
            istop = 5
        if 1 + t1 <= 1:
            istop = 4

        # Allow for tolerances set by the user.

        if test3 <= ctol:
            istop = 3
        if test2 <= atol:
            istop = 2
        if test1 <= rtol:
            istop = 1

        # See if it is time to print something.

        if show:
            if (n <= 40) or (itn <= 10) or (itn >= maxiter - 10) or \
               (itn % 10 == 0) or (test3 <= 1.1 * ctol) or \
               (test2 <= 1.1 * atol) or (test1 <= 1.1 * rtol) or \
               (istop != 0):

                if pcount >= pfreq:
                    pcount = 0
                    print(' ')
                    print(hdg1, hdg2)
                pcount = pcount + 1
                str1 = '%6g %12.5e' % (itn, x[0])
                str2 = ' %10.3e %10.3e' % (normr, normar)
                str3 = ' %8.1e %8.1e' % (test1, test2)
                str4 = ' %8.1e %8.1e' % (normA, condA)
                print(''.join([str1, str2, str3, str4]))

        if istop > 0:
            break

    # Print the stopping condition.

    if show:
        print(' ')
        print('LSMR finished')
        print(msg[istop])
        print('istop =%8g normr =%8.1e' % (istop, normr))
        print(' normA =%8.1e normAr =%8.1e' % (normA, normar))
        print('itn =%8g condA =%8.1e' % (itn, condA))
        print(' normx =%8.1e' % (normx))
        print(str1, str2)
        print(str3, str4)

    return x, istop, itn, normr, normar, normA, condA, normx