Linear least squares

The solution to a well-posed linear system \(Ax = b\) is given by \(x = A^{-1}b\). If the matrix is rectangular or not invertible, then we may generalise the notion of solution to \(x = A^{\dagger}b\), where \(A^{\dagger}\) denotes the Moore--Penrose pseudoinverse.

Lineax can handle problems of this type too.


For reference: in core JAX, problems of this type are handled using jax.numpy.linalg.lstsq.
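The following small sketch makes the generalisation concrete. It uses NumPy rather than Lineax only so that it runs without extra dependencies, and the random matrices are arbitrary choices. It checks that \(A^{\dagger}b\) agrees with \(A^{-1}b\) for a square invertible matrix. It also checks that, for a tall full-rank matrix, \(A^{\dagger}b\) agrees with the least-squares solution from np.linalg.lstsq, whose residual is orthogonal to the columns of \(A\).

```python
import numpy as np

rng = np.random.default_rng(0)

# Square, invertible case: the pseudoinverse agrees with the ordinary inverse.
square = rng.normal(size=(4, 4))
b_square = rng.normal(size=4)
x_pinv = np.linalg.pinv(square) @ b_square
x_solve = np.linalg.solve(square, b_square)

# Tall, full-column-rank case: A^+ b is the least-squares solution.
tall = rng.normal(size=(6, 3))
b_tall = rng.normal(size=6)
x_pinv_tall = np.linalg.pinv(tall) @ b_tall
x_lstsq, *_ = np.linalg.lstsq(tall, b_tall, rcond=None)

# The least-squares residual is orthogonal to the columns of A (normal equations).
residual = b_tall - tall @ x_lstsq
print(np.allclose(x_pinv, x_solve), np.allclose(x_pinv_tall, x_lstsq))
print("largest |A^T r| entry:", np.abs(tall.T @ residual).max())
```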

Picking a solver

By default, the linear solve will fail. This will be a compile-time failure if using a rectangular matrix:

import jax.random as jr
import lineax as lx

vector = jr.normal(jr.PRNGKey(1), (3,))

rectangular_matrix = jr.normal(jr.PRNGKey(0), (3, 4))
rectangular_operator = lx.MatrixLinearOperator(rectangular_matrix)
lx.linear_solve(rectangular_operator, vector)
ValueError: Cannot use `AutoLinearSolver(well_posed=True)` with a non-square operator. If you are trying solve a least-squares problem then you should pass `solver=AutoLinearSolver(well_posed=False)`. By default `lineax.linear_solve` assumes that the operator is square and nonsingular.

Or it will happen at run time if using a rank-deficient matrix:

deficient_matrix = jr.normal(jr.PRNGKey(0), (3, 3)).at[0].set(0)
deficient_operator = lx.MatrixLinearOperator(deficient_matrix)
lx.linear_solve(deficient_operator, vector)
XlaRuntimeError: INTERNAL: Generated function failed: CpuCallback error: RuntimeError: The linear solver returned non-finite (NaN or inf) output. This usually means that the
operator was not well-posed, and that the solver does not support this.

If you are trying solve a linear least-squares problem then you should pass
`solver=AutoLinearSolver(well_posed=False)`. By default `lineax.linear_solve`
assumes that the operator is square and nonsingular.

If you *were* expecting this solver to work with this operator, then it may be because:

(a) the operator is singular, and your code has a bug; or

(b) the operator was nearly singular (i.e. it had a high condition number:
    `jnp.linalg.cond(operator.as_matrix())` is large), and the solver suffered from
    numerical instability issues; or

(c) the operator is declared to exhibit a certain property (e.g. positive definiteness)
    that is does not actually satisfy.

Whilst linear least squares and pseudoinverses are a strict generalisation of linear solves and inverses (respectively), Lineax will not attempt to handle the ill-posed case automatically. This is because the algorithms for handling this case are much more computationally expensive.
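The following NumPy sketch illustrates the rank-deficient failure. Its matrix has a zeroed first row, mirroring deficient_matrix above, although the specific entries are an arbitrary choice. An LU-based solver such as np.linalg.solve assumes a square nonsingular matrix and fails here, while the pseudoinverse still returns a finite least-squares solution. NumPy reports the failure by raising an exception, whereas Lineax, per the error message above, detects non-finite output.

```python
import numpy as np

# A 3x3 matrix whose first row is zero, so it has rank 2.
deficient = np.array([[0.0, 0.0, 0.0],
                      [1.0, 2.0, 3.0],
                      [4.0, 5.0, 7.0]])
b = np.array([1.0, 2.0, 3.0])

# An LU-based solve hits an exactly zero pivot and raises.
try:
    np.linalg.solve(deficient, b)
    lu_failed = False
except np.linalg.LinAlgError:
    lu_failed = True

# The pseudoinverse still gives a finite answer: the minimum-norm minimiser of ||A x - b||.
x = np.linalg.pinv(deficient) @ b
print("LU failed:", lu_failed)
print("rank:", np.linalg.matrix_rank(deficient))
print("pseudoinverse solution:", x)
```

The first equation (0 = 1) cannot be satisfied, so the least-squares solution matches only the last two equations exactly.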

If your matrix may be rectangular, but is still known to be full rank, then you can set the solver to allow this case like so:

rectangular_solution = lx.linear_solve(
    rectangular_operator, vector, solver=lx.AutoLinearSolver(well_posed=None)
)
print("rectangular_solution: ", rectangular_solution.value)
rectangular_solution:  [-0.3214848  -0.75565964 -0.6034579  -0.01326615]
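A wide matrix with full row rank, like the 3×4 rectangular_matrix above, has infinitely many exact solutions. One standard way to pick the minimum-norm one with a QR factorisation is to factorise the transpose, \(A^\top = QR\), so that \(A = R^\top Q^\top\). Solving \(R^\top y = b\) and setting \(x = Qy\) then gives \(Ax = b\), with \(x\) lying in the row space of \(A\). The NumPy sketch below assumes full row rank. It illustrates the technique and is not Lineax's implementation of lineax.QR.

```python
import numpy as np

def min_norm_solve_qr(a, b):
    """Minimum-norm solution of a @ x = b for a wide matrix a with full row rank."""
    # Reduced QR of the tall transpose: a.T = q @ r, with r square and upper triangular.
    q, r = np.linalg.qr(a.T)
    # a @ x = r.T @ (q.T @ x) = b. Solve the lower-triangular system r.T @ y = b
    # (a general solver is used for brevity; a triangular solve would be cheaper).
    y = np.linalg.solve(r.T, b)
    # x = q @ y lies in the row space of a, so it has no null-space component.
    return q @ y

rng = np.random.default_rng(0)
a = rng.normal(size=(3, 4))
b = rng.normal(size=3)
x = min_norm_solve_qr(a, b)
print(x)
```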

If your matrix may be either rectangular or rank-deficient, then you can set the solver to allow this case like so:

deficient_solution = lx.linear_solve(
    deficient_operator, vector, solver=lx.AutoLinearSolver(well_posed=False)
)
print("deficient_solution: ", deficient_solution.value)
deficient_solution:  [ 0.06046088 -1.0412765   0.8860444 ]

Most users will want to use lineax.AutoLinearSolver, and not think about the details of which algorithm is selected.

If you want to pick a particular algorithm, then that can be done too. lineax.QR is capable of handling rectangular full-rank operators, and lineax.SVD is capable of handling rank-deficient operators. (And in fact these are the algorithms that AutoLinearSolver is selecting in the examples above.)
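For rank-deficient problems, the pseudoinverse can be built from the SVD \(A = U \Sigma V^\top\) by inverting only the singular values that are not negligible: \(x = \sum_{\sigma_i > \tau} \sigma_i^{-1} v_i u_i^\top b\). The NumPy sketch below assumes a cutoff \(\tau = \max(m, n)\,\varepsilon\,\sigma_{\max}\), which matches NumPy's default for lstsq. The cutoff used by lineax.SVD may differ, so this is an illustration of the technique rather than Lineax's code.

```python
import numpy as np

def svd_lstsq(a, b):
    """Minimum-norm least-squares solution of a @ x ~= b via a truncated SVD."""
    u, s, vt = np.linalg.svd(a, full_matrices=False)
    # Treat singular values below a relative cutoff as exactly zero.
    cutoff = max(a.shape) * np.finfo(a.dtype).eps * s.max()
    keep = s > cutoff
    # Invert only the retained singular values; the others contribute nothing.
    coeffs = (u[:, keep].T @ b) / s[keep]
    return vt[keep].T @ coeffs

rng = np.random.default_rng(0)
deficient = rng.normal(size=(3, 3))
deficient[0] = 0.0  # zero the first row, as in deficient_matrix above
b = rng.normal(size=3)
x = svd_lstsq(deficient, b)
print(x)
```

Dropping the tiny singular value is what keeps the answer finite. The solution then has no component along the null space of the matrix.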

Differences from jax.numpy.linalg.lstsq?

Lineax offers both speed and correctness advantages over the built-in algorithm. (This is partly because the built-in function has to have the same API as NumPy, so JAX is constrained in how it can be implemented.)

Speed (forward)

First, in the rectangular full-rank case, the QR algorithm is much faster than the SVD algorithm, which is what jax.numpy.linalg.lstsq uses:

import timeit

import jax
import jax.numpy as jnp
import numpy as np

matrix = jr.normal(jr.PRNGKey(0), (500, 200))
vector = jr.normal(jr.PRNGKey(1), (500,))

@jax.jit
def solve_jax(matrix, vector):
    out, *_ = jnp.linalg.lstsq(matrix, vector)
    return out

@jax.jit
def solve_lineax(matrix, vector):
    operator = lx.MatrixLinearOperator(matrix)
    solver = lx.QR()  # or lx.AutoLinearSolver(well_posed=None)
    solution = lx.linear_solve(operator, vector, solver)
    return solution.value

solution_jax = solve_jax(matrix, vector)
solution_lineax = solve_lineax(matrix, vector)
with np.printoptions(threshold=10):
    print("JAX solution:", solution_jax)
    print("Lineax solution:", solution_lineax)
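# block_until_ready() waits for JAX's asynchronous dispatch, so the timing includes the computation.
time_jax = timeit.repeat(lambda: solve_jax(matrix, vector).block_until_ready(), number=1, repeat=10)
time_lineax = timeit.repeat(lambda: solve_lineax(matrix, vector).block_until_ready(), number=1, repeat=10)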
print("JAX time:", min(time_jax))
print("Lineax time:", min(time_lineax))
JAX solution: [-0.10002219  0.09477127 -0.10846332 ... -0.08007179 -0.01216239
 -0.030862  ]
Lineax solution: [-0.1000222   0.0947713  -0.10846333 ... -0.08007187 -0.01216241

JAX time: 0.011344402999384329
Lineax time: 0.0028611960005946457
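For a tall matrix with full column rank, the least-squares solution can be read off a reduced QR factorisation, \(x = R^{-1} Q^\top b\). The SVD route, \(x = V \Sigma^{-1} U^\top b\), first needs a singular value decomposition, which is typically considerably more expensive to compute. The NumPy sketch below is not Lineax's implementation. It solves a problem the same size as the benchmark above both ways, checks that the two answers agree, and times the two factorisations. Timings will vary with the machine and the BLAS/LAPACK build.

```python
import timeit

import numpy as np

def back_substitute(r, y):
    """Solve r @ x = y for an upper-triangular r by back substitution."""
    n = r.shape[0]
    x = np.zeros(n)
    for i in range(n - 1, -1, -1):
        x[i] = (y[i] - r[i, i + 1:] @ x[i + 1:]) / r[i, i]
    return x

def lstsq_qr(a, b):
    # Reduced QR: a = q @ r, q has orthonormal columns, r is square upper triangular.
    q, r = np.linalg.qr(a)
    return back_substitute(r, q.T @ b)

def lstsq_svd(a, b):
    u, s, vt = np.linalg.svd(a, full_matrices=False)
    return vt.T @ ((u.T @ b) / s)

rng = np.random.default_rng(0)
a = rng.normal(size=(500, 200))
b = rng.normal(size=500)
x_qr = lstsq_qr(a, b)
x_svd = lstsq_svd(a, b)
print("max difference:", np.abs(x_qr - x_svd).max())
print("QR time:", min(timeit.repeat(lambda: np.linalg.qr(a), number=1, repeat=5)))
print("SVD time:", min(timeit.repeat(lambda: np.linalg.svd(a, full_matrices=False), number=1, repeat=5)))
```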

Speed (gradients)

Lineax also uses a slightly more efficient autodifferentiation implementation, which makes it faster even when both libraries are using the SVD algorithm.

@jax.jit
@jax.grad
def grad_jax(matrix):
    out, *_ = jnp.linalg.lstsq(matrix, vector)
    return out.sum()

@jax.jit
@jax.grad
def grad_lineax(matrix):
    operator = lx.MatrixLinearOperator(matrix)
    solution = lx.linear_solve(operator, vector, lx.SVD())
    return solution.value.sum()

gradients_jax = grad_jax(matrix)
gradients_lineax = grad_lineax(matrix)
with np.printoptions(threshold=10, edgeitems=2):
    print("JAX gradients:", gradients_jax)
    print("Lineax gradients:", gradients_lineax)
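# block_until_ready() waits for JAX's asynchronous dispatch, so the timing includes the computation.
time_jax = timeit.repeat(lambda: grad_jax(matrix).block_until_ready(), number=1, repeat=10)
time_lineax = timeit.repeat(lambda: grad_lineax(matrix).block_until_ready(), number=1, repeat=10)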
print("JAX time:", min(time_jax))
print("Lineax time:", min(time_lineax))
JAX gradients: [[-1.75446249e-03  2.00700224e-03 ... -3.16517282e-04 -6.08515576e-04]
 [ 1.81865180e-04  4.51280124e-04 ... -1.64618701e-04 -6.53692259e-05]
 [-7.27269216e-04  1.27710134e-03 ... -2.64510425e-04 -3.38940619e-04]
 [ 6.55723223e-03 -3.18011409e-03 ... -1.10758876e-04  1.43246143e-03]]
Lineax gradients: [[-1.7544631e-03  2.0070139e-03 ... -3.1653541e-04 -6.0847402e-04]
 [ 1.8186278e-04  4.5128341e-04 ... -1.6459504e-04 -6.5359738e-05]
 [-7.2721508e-04  1.2771402e-03 ... -2.6450949e-04 -3.3894143e-04]
 [ 6.5572355e-03 -3.1801097e-03 ... -1.1071599e-04  1.4324478e-03]]

JAX time: 0.016591553001489956
Lineax time: 0.012212782999995397
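A gradient does not have to be computed by differentiating the SVD factors. When \(A\) has full column rank, \(x = (A^\top A)^{-1} A^\top b\), and differentiating gives a closed form that involves only solves with \(M = A^\top A\). The derivation below is standard linear algebra and is offered as a sketch. It is not a description of how Lineax implements its derivative rules. With \(r = b - Ax\), and using \(\mathbf{1}^\top x\) (the sum, as in the code above) as the scalar being differentiated:

\[
\begin{aligned}
\mathrm{d}x &= -M^{-1}\left(\mathrm{d}A^\top A + A^\top \mathrm{d}A\right)x + M^{-1}\mathrm{d}A^\top b
= M^{-1}\,\mathrm{d}A^\top r - A^{\dagger}\,\mathrm{d}A\,x, \\
\nabla_A\left(\mathbf{1}^\top x\right) &= r\,w^\top - (A w)\,x^\top, \qquad w = M^{-1}\mathbf{1}.
\end{aligned}
\]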

Correctness (gradients)

Core JAX unfortunately has a bug that means jax.numpy.linalg.lstsq sometimes produces NaN gradients, for example at the identity matrix. Lineax does not:

@jax.jit
@jax.grad
def grad_jax(matrix):
    out, *_ = jnp.linalg.lstsq(matrix, jnp.arange(3.0))
    return out.sum()

@jax.jit
@jax.grad
def grad_lineax(matrix):
    operator = lx.MatrixLinearOperator(matrix)
    solution = lx.linear_solve(operator, jnp.arange(3.0), lx.SVD())
    return solution.value.sum()

print("JAX gradients:", grad_jax(jnp.eye(3)))
print("Lineax gradients:", grad_lineax(jnp.eye(3)))
JAX gradients: [[nan nan nan]
 [nan nan nan]
 [nan nan nan]]
Lineax gradients: [[ 0. -1. -2.]
 [ 0. -1. -2.]
 [ 0. -1. -2.]]
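One plausible source of the NaNs is assumed here, since this page does not say what the bug is. Generic derivative rules for the SVD contain factors like \(1/(\sigma_i^2 - \sigma_j^2)\), which are undefined when singular values repeat, and every singular value of the identity matrix equals 1. A gradient that avoids SVD derivatives has no such problem. The NumPy sketch below uses the full-column-rank closed form \(\nabla_A(\mathbf{1}^\top x) = r w^\top - (Aw)x^\top\), with \(x = A^{\dagger}b\), \(r = b - Ax\) and \(w = (A^\top A)^{-1}\mathbf{1}\). It checks this formula against finite differences and evaluates it at the identity, where it reproduces the Lineax gradient printed above.

```python
import numpy as np

def lstsq_sum_grad(a, b):
    """Gradient of sum(pinv(a) @ b) with respect to a, assuming a has full column rank."""
    gram = a.T @ a
    x = np.linalg.solve(gram, a.T @ b)  # normal equations: fine for a well-conditioned sketch
    r = b - a @ x                       # least-squares residual
    w = np.linalg.solve(gram, np.ones(a.shape[1]))
    return np.outer(r, w) - np.outer(a @ w, x)

def lstsq_sum(a, b):
    x, *_ = np.linalg.lstsq(a, b, rcond=None)
    return x.sum()

# The identity's singular values coincide, so every gap sigma_i^2 - sigma_j^2 is zero.
s = np.linalg.svd(np.eye(3), compute_uv=False)
gaps = s[:, None] ** 2 - s[None, :] ** 2
print("singular value gaps:\n", gaps)

# The closed form is perfectly finite at the identity.
grad_identity = lstsq_sum_grad(np.eye(3), np.arange(3.0))
print("gradient at identity:\n", grad_identity)

# Central finite-difference check on a random tall matrix.
rng = np.random.default_rng(0)
a = rng.normal(size=(5, 3))
b = rng.normal(size=5)
step = 1e-6
fd = np.zeros_like(a)
for i in range(a.shape[0]):
    for j in range(a.shape[1]):
        e = np.zeros_like(a)
        e[i, j] = step
        fd[i, j] = (lstsq_sum(a + e, b) - lstsq_sum(a - e, b)) / (2 * step)
print("max finite-difference error:", np.abs(fd - lstsq_sum_grad(a, b)).max())
```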