Solvers
If you're not sure what to use, then pick lineax.AutoLinearSolver and it will automatically dispatch to an efficient solver depending on what structure your linear operator is declared to exhibit. (See the tags page.)
lineax.AbstractLinearSolver
Abstract base class for all linear solvers.
init(operator: lineax.AbstractLinearOperator, options: dict[str, Any]) -> ~_SolverState
Do any initial computation on just the operator.
For example, an LU solver would compute the LU decomposition of the operator (and this does not require knowing the vector yet).
It is common to need to solve the linear system \(Ax = b\) multiple times in succession, with the same operator \(A\) and multiple vectors \(b\). This method improves efficiency by making it possible to re-use the computation performed on just the operator.
Example

```python
import lineax as lx

operator = lx.MatrixLinearOperator(...)
vector1 = ...
vector2 = ...
solver = lx.LU()
state = solver.init(operator, options={})
solution1 = lx.linear_solve(operator, vector1, solver, state=state)
solution2 = lx.linear_solve(operator, vector2, solver, state=state)
```
Arguments:
- `operator`: a linear operator.
- `options`: a dictionary of any extra options that the solver may wish to accept.
Returns:
A PyTree of arbitrary Python objects.
compute(state: ~_SolverState, vector: PyTree[Array], options: dict[str, Any]) -> tuple[PyTree[Array], lineax.RESULTS, dict[str, Any]]
Solves a linear system.
Arguments:
- `state`: as returned from `lineax.AbstractLinearSolver.init`.
- `vector`: the vector to solve against.
- `options`: a dictionary of any extra options that the solver may wish to accept. For example, `lineax.CG` accepts a `preconditioner` option.
Returns:
A 3-tuple of:
- The solution to the linear system.
- An integer indicating the success or failure of the solve. This may be converted into a human-readable error message via `lx.RESULTS[...]`.
- A dictionary of extra statistics about the solve, e.g. the number of steps taken.
transpose(state: ~_SolverState, options: dict[str, Any]) -> tuple[~_SolverState, dict[str, Any]]
Transposes the result of lineax.AbstractLinearSolver.init.
That is, it should be the case that the following two computations produce the same state:

```python
state_transpose, _ = solver.transpose(solver.init(operator, options), options)
state_transpose2 = solver.init(operator.T, options)
```
It is relatively common (in particular when differentiating through a linear solve) to need to solve both \(Ax = b\) and \(A^T x = b\). This method makes it possible to avoid computing both `solver.init(operator)` and `solver.init(operator.T)` if one can be cheaply computed from the other, as sketched below.
Arguments:
- `state`: as returned from `solver.init`.
- `options`: any extra options that were passed to `solver.init`.
Returns:
A 2-tuple of:
- The state of the transposed operator.
- The options for the transposed operator.
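For instance, here is a sketch of solving both \(Ax = b\) and \(A^T x = b\) while running `init` only once (the matrix and vector are illustrative placeholders):

```python
import jax.numpy as jnp
import lineax as lx

matrix = jnp.array([[1.0, 2.0], [3.0, 4.0]])
operator = lx.MatrixLinearOperator(matrix)
b = jnp.array([1.0, 0.0])

solver = lx.LU()
state = solver.init(operator, options={})
# Derive the transposed state from the existing one, instead of
# re-running `solver.init` on `operator.T`.
state_T, options_T = solver.transpose(state, {})
solution = lx.linear_solve(operator, b, solver, state=state)
solution_T = lx.linear_solve(operator.T, b, solver, state=state_T, options=options_T)
```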
conj(state: ~_SolverState, options: dict[str, Any]) -> tuple[~_SolverState, dict[str, Any]]
Conjugate the result of lineax.AbstractLinearSolver.init.
That is, it should be the case that the following two computations produce the same state:

```python
state_conj, _ = solver.conj(solver.init(operator, options), options)
state_conj2 = solver.init(conj(operator), options)
```
Arguments:
- `state`: as returned from `solver.init`.
- `options`: any extra options that were passed to `solver.init`.
Returns:
A 2-tuple of:
- The state of the conjugated operator.
- The options for the conjugated operator.
assume_full_rank() -> bool
Does this solver assume that all operators are full rank?
When `False`, a more expensive backward pass is needed to account for the extra generality. In a custom linear solver, it is always safe to return `False`.
Arguments:
Nothing.
Returns:
Either True or False.
lineax.AutoLinearSolver(lineax.AbstractLinearSolver)
Automatically determines a good linear solver based on the structure of the operator.
- If `well_posed=True`:
    - If the operator is diagonal, then use `lineax.Diagonal`.
    - If the operator is tridiagonal, then use `lineax.Tridiagonal`.
    - If the operator is triangular, then use `lineax.Triangular`.
    - If the matrix is positive or negative (semi-)definite, then use `lineax.Cholesky`.
    - Else use `lineax.LU`.

  This is a good choice if you want to be certain that an error is raised for ill-posed systems.
- If `well_posed=False`:
    - If the operator is diagonal, then use `lineax.Diagonal`.
    - Else use `lineax.SVD`.

  This is a good choice if you want to be certain that you can handle ill-posed systems.
- If `well_posed=None`:
    - If the operator is non-square, then use `lineax.QR`.
    - If the operator is diagonal, then use `lineax.Diagonal`.
    - If the operator is tridiagonal, then use `lineax.Tridiagonal`.
    - If the operator is triangular, then use `lineax.Triangular`.
    - If the matrix is positive or negative (semi-)definite, then use `lineax.Cholesky`.
    - Else use `lineax.LU`.

  This is a good choice if your primary concern is computational efficiency. It will handle ill-posed systems as long as it is not computationally expensive to do so.
__init__(well_posed: bool | None)
Arguments:
- `well_posed`: whether to only handle well-posed systems or not, as discussed above.
select_solver(operator: lineax.AbstractLinearOperator) -> lineax.AbstractLinearSolver
Check which solver `lineax.AutoLinearSolver` will dispatch to.
Arguments:
- `operator`: a linear operator.
Returns:
The linear solver that will be used.
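For example, here is a sketch of inspecting the dispatch; declaring the operator's structure via a tag is what drives the choice (the exact printed form may vary):

```python
import jax.numpy as jnp
import lineax as lx

matrix = jnp.array([[2.0, 0.0], [0.0, 3.0]])
# Declaring the operator as diagonal lets AutoLinearSolver pick lineax.Diagonal.
operator = lx.MatrixLinearOperator(matrix, lx.diagonal_tag)
solver = lx.AutoLinearSolver(well_posed=True)
print(solver.select_solver(operator))  # A lineax.Diagonal instance, per the rules above.
```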
lineax.LU(lineax.AbstractLinearSolver)
LU solver for linear systems.
This solver can only handle square nonsingular operators.
__init__()
Arguments:
Nothing.
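A minimal end-to-end solve with this solver might look like the following sketch (the matrix and vector are illustrative placeholders):

```python
import jax.numpy as jnp
import lineax as lx

matrix = jnp.array([[0.0, 1.0], [2.0, 3.0]])
operator = lx.MatrixLinearOperator(matrix)
vector = jnp.array([1.0, 0.0])
solution = lx.linear_solve(operator, vector, solver=lx.LU())
# solution.value (approximately) satisfies matrix @ solution.value == vector.
```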
Least squares solvers
These are capable of solving ill-posed linear problems.
lineax.QR(lineax.AbstractLinearSolver)
QR solver for linear systems.
This solver can handle non-square operators.
This is usually the preferred solver when dealing with non-square operators.
Info
Note that whilst this does handle non-square operators, it still can only handle full-rank operators.
This is because JAX does not currently support a rank-revealing/pivoted QR decomposition, see issue #12897.
For such use cases, switch to lineax.SVD instead.
__init__()
Arguments:
Nothing.
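For example, a sketch of a least-squares solve of an overdetermined, full-rank system:

```python
import jax.numpy as jnp
import lineax as lx

# An overdetermined (3 x 2) full-rank system: QR returns its least-squares solution.
matrix = jnp.array([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])
operator = lx.MatrixLinearOperator(matrix)
vector = jnp.array([1.0, 1.0, 1.0])
solution = lx.linear_solve(operator, vector, solver=lx.QR())
```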
lineax.SVD(lineax.AbstractLinearSolver)
SVD solver for linear systems.
This solver can handle any operator, even non-square or singular ones. In these cases it will return the pseudoinverse solution to the linear system.
Equivalent to scipy.linalg.lstsq.
__init__(rcond: float | None = None)
Arguments:
- `rcond`: the cutoff for handling zero entries on the diagonal. Defaults to machine precision times `max(N, M)`, where `(N, M)` is the shape of the operator. (I.e. `N` is the output size and `M` is the input size.)
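For example, a sketch of solving a singular system, for which the pseudoinverse solution is returned:

```python
import jax.numpy as jnp
import lineax as lx

# A rank-1 (singular) matrix: SVD returns the minimum-norm least-squares solution.
matrix = jnp.array([[1.0, 1.0], [1.0, 1.0]])
operator = lx.MatrixLinearOperator(matrix)
vector = jnp.array([2.0, 2.0])
solution = lx.linear_solve(operator, vector, solver=lx.SVD())
# Here solution.value is [1.0, 1.0].
```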
lineax.Normal(lineax.AbstractLinearSolver)
Wrapper for an inner solver of positive (semi)definite systems. The wrapped solver handles possibly non-square systems \(Ax = b\), where \(A\) has shape \(m \times n\), by applying the inner solver to the normal equations: \(A^* A x = A^* b\) if \(m \ge n\), and otherwise \(A A^* y = b\) with \(x = A^* y\).
If the inner solver solves systems with positive definite \(A\), the wrapped solver solves systems with full rank \(A\).
If the inner solver solves systems with positive semidefinite \(A\), the wrapped solver solves systems with arbitrary, possibly rank deficient, \(A\).
Note that this squares the condition number, so applying this method to an iterative inner solver may result in slow convergence and high sensitivity to roundoff error. In this case it may be advantageous to choose an appropriate preconditioner or initial solution guess for the problem.
This wrapper adjusts the following options (as passed to `lx.linear_solve(..., options=...)`) before forwarding them to the inner solver.

- `preconditioner`: A `lineax.AbstractLinearOperator` to be used as preconditioner. Defaults to `lineax.IdentityLinearOperator`. This should be an approximation of the (pseudo)inverse of \(A\). When passed to the inner solver, the preconditioner \(M\) is replaced by \(M M^*\) and \(M^* M\) in the first and second versions of the normal equations, respectively.
- `y0`: An initial estimate of the solution of the linear system \(Ax = b\). Defaults to all zeros. In the second version of the normal equations, \(y_0\) is replaced with \(M^* y_0\), where \(M\) is the given outer preconditioner.
Info
Good choices of inner solvers are the direct `lineax.Cholesky` and the iterative `lineax.CG`.
__init__(inner_solver: lineax.AbstractLinearSolver[~_InnerSolverState])
Arguments:
- `inner_solver`: the solver to wrap. It should support solving positive definite systems or positive semidefinite systems.
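As a sketch of the API described above, wrapping `lineax.Cholesky` to solve a full-rank overdetermined system via the normal equations:

```python
import jax.numpy as jnp
import lineax as lx

# Overdetermined full-rank system; the inner Cholesky solver is applied
# to the normal equations A^* A x = A^* b.
matrix = jnp.array([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])
operator = lx.MatrixLinearOperator(matrix)
vector = jnp.array([1.0, 1.0, 1.0])
solver = lx.Normal(lx.Cholesky())
solution = lx.linear_solve(operator, vector, solver)
```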
Info
In addition to these, lineax.Diagonal(well_posed=False) (below) also supports ill-posed problems.
Structure-exploiting solvers
These require special structure in the operator. (And will throw an error if passed an operator without that structure.) In return, they are able to solve the linear problem much more efficiently.
lineax.Cholesky(lineax.AbstractLinearSolver)
Cholesky solver for linear systems. This is generally the preferred solver for positive or negative definite systems.
Equivalent to scipy.linalg.solve(..., assume_a="pos").
The operator must be square, nonsingular, and either positive or negative definite.
__init__()
Arguments:
Nothing.
lineax.Diagonal(lineax.AbstractLinearSolver)
Diagonal solver for linear systems.
Requires that the operator be diagonal. Then \(Ax = b\), with \(A = \operatorname{diag}[a]\), is solved simply by doing an elementwise division \(x = b / a\).
This solver can handle singular operators (i.e. diagonal entries with value 0).
__init__(well_posed: bool = False, rcond: float | None = None)
Arguments:
- `well_posed`: if `False`, then singular operators are accepted, and the pseudoinverse solution is returned. If `True` then passing a singular operator will cause an error to be raised instead.
- `rcond`: the cutoff for handling zero entries on the diagonal. Defaults to machine precision times `N`, where `N` is the input (or output) size of the operator. Only used if `well_posed=False`.
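For example, a sketch of the singular case, in which the pseudoinverse solution zeroes the component corresponding to the zero diagonal entry:

```python
import jax.numpy as jnp
import lineax as lx

operator = lx.DiagonalLinearOperator(jnp.array([2.0, 0.0, 4.0]))
vector = jnp.array([2.0, 1.0, 8.0])
solution = lx.linear_solve(operator, vector, solver=lx.Diagonal(well_posed=False))
# solution.value is [1.0, 0.0, 2.0]: elementwise division, with 1/0 treated as 0.
```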
lineax.Triangular(lineax.AbstractLinearSolver)
Triangular solver for linear systems.
The operator should either be lower triangular or upper triangular.
__init__()
Arguments:
Nothing.
lineax.Tridiagonal(lineax.AbstractLinearSolver)
Tridiagonal solver for linear systems, using the LAPACK/cuSPARSE implementation of Gaussian elimination with partial pivoting (which increases stability).
__init__()
Arguments:
Nothing.
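For example, a sketch of solving a tridiagonal system; this assumes the operator is declared tridiagonal via `lx.tridiagonal_tag` (a `lineax.TridiagonalLinearOperator` should also work):

```python
import jax.numpy as jnp
import lineax as lx

matrix = jnp.array([[2.0, 1.0, 0.0],
                    [1.0, 2.0, 1.0],
                    [0.0, 1.0, 2.0]])
# The operator must be declared tridiagonal for this solver to accept it.
operator = lx.MatrixLinearOperator(matrix, lx.tridiagonal_tag)
solution = lx.linear_solve(operator, jnp.ones(3), solver=lx.Tridiagonal())
```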
Info
In addition to these, lineax.CG also requires special structure (positive or negative definiteness).
Iterative solvers
These solvers use only matrix-vector products, and do not require instantiating the whole matrix. This makes them good when used alongside e.g. lineax.JacobianLinearOperator or lineax.FunctionLinearOperator, which only provide matrix-vector products.
Warning
Note that lineax.BiCGStab and lineax.GMRES may fail to converge on some (typically non-sparse) problems.
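For instance, here is a sketch of a matrix-free solve, in which the operator is defined only through its matrix-vector product:

```python
import jax
import jax.numpy as jnp
import lineax as lx

def matvec(x):
    # A linear map given only as a function; the matrix is never materialised.
    return 3.0 * x - 0.5 * jnp.roll(x, 1)

# FunctionLinearOperator needs the structure (shape/dtype) of its input.
in_structure = jax.eval_shape(lambda: jnp.zeros(8))
operator = lx.FunctionLinearOperator(matvec, in_structure)
solver = lx.GMRES(rtol=1e-6, atol=1e-6)
solution = lx.linear_solve(operator, jnp.ones(8), solver)
```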
lineax.CG(lineax.AbstractLinearSolver)
Conjugate gradient solver for linear systems.
The operator should be positive or negative definite.
Equivalent to scipy.sparse.linalg.cg.
This supports the following options (as passed to `lx.linear_solve(..., options=...)`).

- `preconditioner`: A positive definite `lineax.AbstractLinearOperator` to be used as preconditioner. Defaults to `lineax.IdentityLinearOperator`. This method uses left preconditioning, so it is the preconditioned residual that is minimized, though the actual termination criterion uses the un-preconditioned residual.
- `y0`: The initial estimate of the solution to the linear system. Defaults to all zeros.
__init__(rtol: float, atol: float, norm: Callable = <function max_norm>, stabilise_every: int | None = 10, max_steps: int | None = None)
Arguments:
- `rtol`: Relative tolerance for terminating solve.
- `atol`: Absolute tolerance for terminating solve.
- `norm`: The norm to use when computing whether the error falls within the tolerance. Defaults to the max norm.
- `stabilise_every`: The conjugate gradient is an iterative method that produces candidate solutions \(x_1, x_2, \ldots\), and terminates once \(r_i = \| Ax_i - b \|\) is small enough. For computational efficiency, the values \(r_i\) are computed using other internal quantities, and not by directly evaluating the formula above. However, this computation of \(r_i\) is susceptible to drift due to limited floating-point precision. Every `stabilise_every` steps, \(r_i\) is computed directly using the formula above, in order to stabilise the computation.
- `max_steps`: The maximum number of iterations to run the solver for. If more steps than this are required, then the solve is halted with a failure.
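For example, a sketch of a CG solve; note that the operator must be declared positive (semi-)definite via a tag:

```python
import jax.numpy as jnp
import lineax as lx

matrix = jnp.array([[4.0, 1.0], [1.0, 3.0]])
# CG requires the operator to be tagged as positive (semi-)definite.
operator = lx.MatrixLinearOperator(matrix, lx.positive_semidefinite_tag)
vector = jnp.array([1.0, 2.0])
solver = lx.CG(rtol=1e-6, atol=1e-6)
solution = lx.linear_solve(operator, vector, solver, options={"y0": jnp.zeros(2)})
```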
lineax.BiCGStab(lineax.AbstractLinearSolver)
Biconjugate gradient stabilised method for linear systems.
The operator should be square.
Equivalent to jax.scipy.sparse.linalg.bicgstab.
This supports the following options (as passed to `lx.linear_solve(..., options=...)`).

- `preconditioner`: A `lineax.AbstractLinearOperator` to be used as a preconditioner. Defaults to `lineax.IdentityLinearOperator`. This method uses right preconditioning.
- `y0`: The initial estimate of the solution to the linear system. Defaults to all zeros.
__init__(rtol: float, atol: float, norm: Callable = <function max_norm>, max_steps: int | None = None)
Arguments:
- `rtol`: Relative tolerance for terminating solve.
- `atol`: Absolute tolerance for terminating solve.
- `norm`: The norm to use when computing whether the error falls within the tolerance. Defaults to the max norm.
- `max_steps`: The maximum number of iterations to run the solver for. If more steps than this are required, then the solve is halted with a failure.
lineax.GMRES(lineax.AbstractLinearSolver)
GMRES solver for linear systems.
The operator should be square.
Similar to jax.scipy.sparse.linalg.gmres.
This supports the following options (as passed to `lx.linear_solve(..., options=...)`).

- `preconditioner`: A `lineax.AbstractLinearOperator` to be used as preconditioner. Defaults to `lineax.IdentityLinearOperator`. This method uses left preconditioning, so it is the preconditioned residual that is minimized, though the actual termination criterion uses the un-preconditioned residual.
- `y0`: The initial estimate of the solution to the linear system. Defaults to all zeros.
__init__(rtol: float, atol: float, norm: Callable = <function max_norm>, max_steps: int | None = None, restart: int = 20, stagnation_iters: int = 20)
Arguments:
- `rtol`: Relative tolerance for terminating solve.
- `atol`: Absolute tolerance for terminating solve.
- `norm`: The norm to use when computing whether the error falls within the tolerance. Defaults to the max norm.
- `max_steps`: The maximum number of iterations to run the solver for. If more steps than this are required, then the solve is halted with a failure.
- `restart`: Size of the Krylov subspace built between restarts. The returned solution is the projection of the true solution onto this subspace, so this directly bounds the accuracy of the algorithm. Default is 20.
- `stagnation_iters`: The maximum number of restarts for which the residual may fail to decrease. If more than `stagnation_iters` restarts are performed without sufficient decrease in the residual, the algorithm is halted.