

How does this differ from jax.numpy.linalg.solve, jax.scipy.{...} etc.?

Lineax offers several improvements. Most notably:

  • Several new solvers. For example, lineax.QR has no counterpart in core JAX. (It is also much faster than jax.numpy.linalg.lstsq, the closest equivalent, which uses an SVD instead.)

  • Several new operators. For example, lineax.JacobianLinearOperator has no counterpart in core JAX.

  • A consistent API. The built-in JAX operations all differ from each other slightly, and are split across jax.numpy, jax.scipy, and jax.scipy.sparse.

  • Numerically stable gradients. The existing JAX implementations will sometimes return NaNs!

  • Faster compile times and run times in some places.

Most of these differences arise because JAX aims to mimic the existing NumPy/SciPy APIs. (That is, it's not JAX's fault that it doesn't take the approach that Lineax does!)

How do I represent a {lower, upper} triangular matrix?

Typically: create a full matrix, with the {lower, upper} part containing your values, and the opposite {upper, lower} part containing all zeros. Then use, e.g., operator = lx.MatrixLinearOperator(matrix, lx.lower_triangular_tag).

This is the most efficient way to store a triangular matrix in JAX's ndarray-based programming model.

What about other operations from linear algebra? (Determinants, eigenvalues, etc.)

See jax.numpy.linalg and jax.scipy.linalg.

How do I solve multiple systems of equations (i.e. AX = B)?

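Solvers implemented in Lineax target single systems of linear equations (i.e. Ax = b). However, multiple systems that share the same operator can be solved with minimal effort using jax.vmap or equinox.filter_vmap: in_axes=(None, 1) shares the operator across every solve and maps over the columns of B.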

```python
import equinox as eqx
import jax
import lineax as lx
multi_linear_solve = eqx.filter_vmap(lx.linear_solve, in_axes=(None, 1))
# or
multi_linear_solve = jax.vmap(lx.linear_solve, in_axes=(None, 1))
```