Using only matrix-vector operations

When solving a linear system \(Ax = b\), it is relatively common not to have immediate access to the full matrix \(A\), but only to a function \(F(x) = Ax\) computing the matrix-vector product. (We could compute \(A\) from \(F\), for instance by applying \(F\) to each basis vector in turn, but if the matrix is large then this may be very inefficient.)
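To make the cost of recovering \(A\) concrete, here is a minimal sketch of materialising a matrix from its matrix-vector product. The helper name `materialise` and the small matrix `A_true` are hypothetical and exist only for illustration. Note that it needs one evaluation of \(F\) per column of \(A\).

```python
import jax
import jax.numpy as jnp


def materialise(F, n):
    """Recover the dense n-by-n matrix A from the matvec F(x) = A @ x."""
    basis = jnp.eye(n)
    # vmap applies F to each basis vector e_i (the rows of the identity).
    # Row i of the result is A @ e_i, which is column i of A.
    columns = jax.vmap(F)(basis)
    # Transpose so that column i of the output is A @ e_i.
    return columns.T


# Hypothetical example matrix; in practice only F would be available.
A_true = jnp.array([[2.0, 1.0], [0.0, 3.0]])


def F(x):
    return A_true @ x


A_recovered = materialise(F, 2)
```

For an \(n \times n\) system this performs \(n\) matrix-vector products and stores \(n^2\) entries, which is exactly what matrix-free solvers are meant to avoid.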

Example: Newton's method

This comes up, for example, when using Newton's method. We have a function \(f \colon \mathbb{R}^n \to \mathbb{R}^n\), and wish to find the \(\delta \in \mathbb{R}^n\) for which \(\frac{\mathrm{d}f}{\mathrm{d}y}(y) \; \delta = -f(y)\). Here \(\frac{\mathrm{d}f}{\mathrm{d}y}(y) \in \mathbb{R}^{n \times n}\) is a matrix: the Jacobian of \(f\) evaluated at \(y\).

In this case it is possible to use forward-mode automatic differentiation to evaluate \(F(x) = \frac{\mathrm{d}f}{\mathrm{d}y}(y) \; x\), without ever instantiating the whole Jacobian \(\frac{\mathrm{d}f}{\mathrm{d}y}(y)\). Indeed, JAX provides a Jacobian-vector product function, jax.jvp, for exactly this purpose.

import jax

f = ...  # the function R^n -> R^n
y = ...  # the point at which the Jacobian is evaluated

def F(x):
    """Computes (df/dy) @ x."""
    _, out = jax.jvp(f, (y,), (x,))
    return out
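As a sanity check, the matrix-free product can be compared against the dense Jacobian on a small problem. This is a sketch only: the two-dimensional function `f` and the vectors `y` and `x` below are hypothetical stand-ins for the elided definitions above.

```python
import jax
import jax.numpy as jnp


def f(y):
    # Hypothetical nonlinear function R^2 -> R^2, chosen for illustration.
    return jnp.array([y[0] ** 2 + y[1], jnp.sin(y[0]) * y[1]])


y = jnp.array([1.0, 2.0])
x = jnp.array([0.5, -1.0])


def F(x):
    """Computes (df/dy) @ x without forming the Jacobian."""
    _, out = jax.jvp(f, (y,), (x,))
    return out


# Dense reference: build the full Jacobian, then multiply.
dense = jax.jacfwd(f)(y) @ x
matrix_free = F(x)
```

The two results should agree up to floating-point error. Only the second avoids building the \(n \times n\) Jacobian.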

Solving a linear system using only matrix-vector operations

Lineax offers iterative solvers, which are capable of solving a linear system knowing only its matrix-vector products.

import jax.numpy as jnp
import lineax as lx
from jaxtyping import Array, Float  # https://github.com/google/jaxtyping


def f(y: Float[Array, "3"], args) -> Float[Array, "3"]:
    y0, y1, y2 = y
    f0 = 5 * y0 + y1**2
    f1 = y1 - y2 + 5
    f2 = y0 / (1 + 5 * y2**2)
    return jnp.stack([f0, f1, f2])


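y = jnp.array([1.0, 2.0, 3.0])
# Matrix-free operator representing the Jacobian df/dy evaluated at y.
operator = lx.JacobianLinearOperator(f, y, args=None)
# Right-hand side of the linear system.
vector = f(y, args=None)
# CG applied to the normal equations.
solver = lx.NormalCG(rtol=1e-6, atol=1e-6)
solution = lx.linear_solve(operator, vector, solver)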

Warning

Note that iterative solvers are something of a "last resort", and they are not suitable for all problems.

  • CG requires that the problem be positive or negative semidefinite.
  • Normalised CG is CG applied to the "normal equations" \((A^\top A) x = (A^\top b)\); note that \(A^\top A\) is always positive semidefinite. The condition number of \(A^\top A\) is the square of the condition number of \(A\). In practice this means normalised CG may produce low-accuracy results if used with matrices with high condition number.
  • BiCGStab and GMRES will fail on many problems. They are primarily meant as specialised tools for e.g. the matrices that arise when solving elliptic systems.
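The squaring of the condition number mentioned for normalised CG can be checked numerically. This is a minimal sketch using a hypothetical diagonal matrix chosen purely for illustration.

```python
import jax.numpy as jnp

# Hypothetical matrix with singular values 1 and 1e-2.
A = jnp.array([[1.0, 0.0], [0.0, 1e-2]])

# 2-norm condition number of A: roughly 1e2.
cond_A = jnp.linalg.cond(A)
# Condition number of the normal-equations matrix: roughly 1e4.
cond_normal = jnp.linalg.cond(A.T @ A)
```

With single-precision floats a condition number of around \(10^4\) already costs several digits of accuracy, which is why the normal equations can be a poor choice for badly conditioned \(A\).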