Linear least squares¤
The solution to a well-posed linear system \(Ax = b\) is given by \(x = A^{-1}b\). If the matrix is rectangular or not invertible, then we may generalise the notion of solution to \(x = A^{\dagger}b\), where \(A^{\dagger}\) denotes the Moore-Penrose pseudoinverse.
Lineax can handle problems of this type too.
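For concreteness (a standard fact about the pseudoinverse, not something specific to Lineax): \(A^{\dagger}b\) is the least-squares solution of minimum norm,
\[
A^{\dagger}b = \mathop{\mathrm{arg\,min}}_{x} \|x\|_2 \quad\text{subject to}\quad \|Ax - b\|_2 = \min_{z}\|Az - b\|_2.
\]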
Info
For reference: in core JAX, problems of this type are handled using jax.numpy.linalg.lstsq.
Picking a solver¤
By default, the linear solve will fail. This will be a compile-time failure if using a rectangular matrix:
import jax.random as jr
import lineax as lx
vector = jr.normal(jr.PRNGKey(1), (3,))
rectangular_matrix = jr.normal(jr.PRNGKey(0), (3, 4))
rectangular_operator = lx.MatrixLinearOperator(rectangular_matrix)
lx.linear_solve(rectangular_operator, vector)
Or it will happen at run time if using a rank-deficient matrix:
deficient_matrix = jr.normal(jr.PRNGKey(0), (3, 3)).at[0].set(0)
deficient_operator = lx.MatrixLinearOperator(deficient_matrix)
lx.linear_solve(deficient_operator, vector)
Whilst linear least squares and pseudoinverses are a strict generalisation of linear solves and inverses (respectively), Lineax will not attempt to handle the ill-posed case automatically. This is because the algorithms for handling this case are much more computationally expensive.
If your matrix may be rectangular, but is still known to be full rank, then you can set the solver to allow this case like so:
rectangular_solution = lx.linear_solve(
    rectangular_operator, vector, solver=lx.AutoLinearSolver(well_posed=None)
)
print("rectangular_solution: ", rectangular_solution.value)
If your matrix may be either rectangular or rank-deficient, then you can set the solver to allow this case like so:
deficient_solution = lx.linear_solve(
    deficient_operator, vector, solver=lx.AutoLinearSolver(well_posed=False)
)
print("deficient_solution: ", deficient_solution.value)
Most users will want to use lineax.AutoLinearSolver, and not think about the details of which algorithm is selected.
If you want to pick a particular algorithm, then that can be done too. lineax.QR is capable of handling rectangular full-rank operators, and lineax.SVD is capable of handling rank-deficient operators. (And in fact these are the algorithms that AutoLinearSolver is selecting in the examples above.)
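For example, here is a minimal sketch of picking those solvers explicitly for the two operators defined earlier:
# lx.QR requires the operator to have full rank; lx.SVD additionally handles
# rank deficiency, at greater computational cost.
qr_solution = lx.linear_solve(rectangular_operator, vector, solver=lx.QR())
svd_solution = lx.linear_solve(deficient_operator, vector, solver=lx.SVD())
print("QR solution:", qr_solution.value)
print("SVD solution:", svd_solution.value)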
Differences from jax.numpy.linalg.lstsq?¤
Lineax offers both speed and correctness advantages over the built-in algorithm. (This is partly because the built-in function has to have the same API as NumPy, so JAX is constrained in how it can be implemented.)
Speed (forward)¤
First, in the rectangular (full-rank) case, the QR algorithm is much faster than the SVD algorithm:
import timeit
import jax
import jax.numpy as jnp
import numpy as np
matrix = jr.normal(jr.PRNGKey(0), (500, 200))
vector = jr.normal(jr.PRNGKey(1), (500,))
@jax.jit
def solve_jax(matrix, vector):
    out, *_ = jnp.linalg.lstsq(matrix, vector)
    return out

@jax.jit
def solve_lineax(matrix, vector):
    operator = lx.MatrixLinearOperator(matrix)
    solver = lx.QR()  # or lx.AutoLinearSolver(well_posed=None)
    solution = lx.linear_solve(operator, vector, solver)
    return solution.value
solution_jax = solve_jax(matrix, vector)
solution_lineax = solve_lineax(matrix, vector)
with np.printoptions(threshold=10):
    print("JAX solution:", solution_jax)
    print("Lineax solution:", solution_lineax)
    print()
time_jax = timeit.repeat(lambda: solve_jax(matrix, vector), number=1, repeat=10)
time_lineax = timeit.repeat(lambda: solve_lineax(matrix, vector), number=1, repeat=10)
print("JAX time:", min(time_jax))
print("Lineax time:", min(time_lineax))
Speed (gradients)¤
Lineax also uses a slightly more efficient autodifferentiation implementation, which makes it faster even when both libraries use the SVD algorithm.
@jax.jit
@jax.grad
def grad_jax(matrix):
    out, *_ = jnp.linalg.lstsq(matrix, vector)
    return out.sum()

@jax.jit
@jax.grad
def grad_lineax(matrix):
    operator = lx.MatrixLinearOperator(matrix)
    solution = lx.linear_solve(operator, vector, lx.SVD())
    return solution.value.sum()
gradients_jax = grad_jax(matrix)
gradients_lineax = grad_lineax(matrix)
with np.printoptions(threshold=10, edgeitems=2):
    print("JAX gradients:", gradients_jax)
    print("Lineax gradients:", gradients_lineax)
    print()
time_jax = timeit.repeat(lambda: grad_jax(matrix), number=1, repeat=10)
time_lineax = timeit.repeat(lambda: grad_lineax(matrix), number=1, repeat=10)
print("JAX time:", min(time_jax))
print("Lineax time:", min(time_lineax))
Correctness (gradients)¤
Core JAX unfortunately has a bug that means it sometimes produces NaN gradients. Lineax does not:
@jax.jit
@jax.grad
def grad_jax(matrix):
    out, *_ = jnp.linalg.lstsq(matrix, jnp.arange(3.0))
    return out.sum()

@jax.jit
@jax.grad
def grad_lineax(matrix):
    operator = lx.MatrixLinearOperator(matrix)
    solution = lx.linear_solve(operator, jnp.arange(3.0), lx.SVD())
    return solution.value.sum()
print("JAX gradients:", grad_jax(jnp.eye(3)))
print("Lineax gradients:", grad_lineax(jnp.eye(3)))