Using only matrix-vector operations
When solving a linear system \(Ax = b\), it is relatively common not to have immediate access to the full matrix \(A\), but only to a function \(F(x) = Ax\) computing the matrix-vector product. (We could compute \(A\) from \(F\), but if the matrix is large then this may be very inefficient.)
Example: Newton's method
For example, this comes up when using Newton's method. In this case, we have a function \(f \colon \mathbb{R}^n \to \mathbb{R}^n\), and wish to find the \(\delta \in \mathbb{R}^n\) for which \(\frac{\mathrm{d}f}{\mathrm{d}y}(y) \; \delta = -f(y)\). (Where \(\frac{\mathrm{d}f}{\mathrm{d}y}(y) \in \mathbb{R}^{n \times n}\) is a matrix: it is the Jacobian of \(f\).)
In this case it is possible to use forward-mode autodifferentiation to evaluate \(F(x) = \frac{\mathrm{d}f}{\mathrm{d}y}(y) \; x\), without ever instantiating the whole Jacobian \(\frac{\mathrm{d}f}{\mathrm{d}y}(y)\). Indeed, JAX has a Jacobian-vector product function for exactly this purpose.
import jax

f = ...
y = ...

def F(x):
    """Computes (df/dy) @ x."""
    _, out = jax.jvp(f, (y,), (x,))
    return out
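As a quick sanity check that this matches the full Jacobian-vector product (using a placeholder f, y and x purely for illustration; they are not part of the example above):

import jax
import jax.numpy as jnp

f = jnp.sin                    # acts elementwise, so df/dy is a diagonal matrix
y = jnp.array([0.1, 0.2, 0.3])

def F(x):
    """Computes (df/dy) @ x without materialising the Jacobian."""
    _, out = jax.jvp(f, (y,), (x,))
    return out

x = jnp.array([1.0, 2.0, 3.0])
jacobian = jax.jacfwd(f)(y)    # the full 3x3 Jacobian, computed here only for comparison
assert jnp.allclose(F(x), jacobian @ x)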
Solving a linear system using only matrix-vector operations
Lineax offers iterative solvers, which can solve a linear system using only its matrix-vector products.
import jax.numpy as jnp
import lineax as lx
from jaxtyping import Array, Float  # https://github.com/google/jaxtyping


def f(y: Float[Array, "3"], args) -> Float[Array, "3"]:
    y0, y1, y2 = y
    f0 = 5 * y0 + y1**2
    f1 = y1 - y2 + 5
    f2 = y0 / (1 + 5 * y2**2)
    return jnp.stack([f0, f1, f2])


y = jnp.array([1.0, 2.0, 3.0])
# Represents (df/dy)(y) via Jacobian-vector products, without materialising the matrix.
operator = lx.JacobianLinearOperator(f, y, args=None)
vector = f(y, args=None)
# NormalCG is an iterative solver: it only ever uses matrix-vector products.
solver = lx.NormalCG(rtol=1e-6, atol=1e-6)
solution = lx.linear_solve(operator, vector, solver)
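The solution itself can be read off as solution.value. Connecting this back to the Newton's method example (a sketch only: note that the right-hand side above is \(f(y)\) rather than \(-f(y)\), so the sign must be flipped):

delta = -solution.value  # the Newton step satisfies (df/dy)(y) delta = -f(y)
y_new = y + delta        # one step of Newton's method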
Warning
Note that iterative solvers are something of a "last resort", and they are not suitable for all problems:
- CG requires that the problem be positive or negative semidefinite. (In Lineax this is declared via tags; see the sketch after this list.)
- NormalCG (this is CG applied to the "normal equations" \((A^\top A) x = (A^\top b)\); note that \(A^\top A\) is always positive semidefinite) squares the condition number of \(A\). In practice this means it may produce low-accuracy results when used with ill-conditioned matrices.
- BiCGStab and GMRES will fail on many problems. They are primarily meant as specialised tools for e.g. the matrices that arise when solving elliptic systems.
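For example, plain CG becomes applicable once the operator is known (and declared) to be positive semidefinite. A minimal sketch, using a hypothetical diagonal (and hence symmetric positive definite) matrix-vector product wrapped in lx.FunctionLinearOperator and tagged with lx.positive_semidefinite_tag:

import jax
import jax.numpy as jnp
import lineax as lx

def F(x):
    # Hypothetical operator: multiplication by diag(2, 3, 4), which is
    # symmetric positive definite, so plain CG is applicable.
    return jnp.array([2.0, 3.0, 4.0]) * x

in_structure = jax.eval_shape(lambda: jnp.zeros(3))
operator = lx.FunctionLinearOperator(F, in_structure, tags=lx.positive_semidefinite_tag)
vector = jnp.array([1.0, 1.0, 1.0])
solver = lx.CG(rtol=1e-6, atol=1e-6)
solution = lx.linear_solve(operator, vector, solver)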