Structured matrices

Lineax can also be used with matrices known to exhibit special structure, e.g. tridiagonal matrices or positive definite matrices.

Typically, that means using a particular operator type:

import jax.numpy as jnp
import jax.random as jr
import lineax as lx


diag = jnp.array([4.0, -0.5, 7.0, 1.0])
lower_diag = jnp.array([1.0, 3.0, -0.7])
upper_diag = jnp.array([2.0, -1.0, -5.0])

operator = lx.TridiagonalLinearOperator(diag, lower_diag, upper_diag)
print(operator.as_matrix())
[[ 4.   2.   0.   0. ]
 [ 1.  -0.5 -1.   0. ]
 [ 0.   3.   7.  -5. ]
 [ 0.   0.  -0.7  1. ]]

vector = jnp.array([1.0, -0.5, 2.0, 0.8])
# Will automatically dispatch to a tridiagonal solver.
solution = lx.linear_solve(operator, vector)

If you're uncertain which solver is being dispatched to, then you can check:

default_solver = lx.AutoLinearSolver(well_posed=True)
print(default_solver.select_solver(operator))
Tridiagonal()

If you want to enforce that a particular solver is used, then it can be passed manually:

solution = lx.linear_solve(operator, vector, solver=lx.Tridiagonal())

Trying to use a solver with an unsupported operator will raise an error:

not_tridiagonal_matrix = jr.normal(jr.PRNGKey(0), (4, 4))
not_tridiagonal_operator = lx.MatrixLinearOperator(not_tridiagonal_matrix)
solution = lx.linear_solve(not_tridiagonal_operator, vector, solver=lx.Tridiagonal())
ValueError: `Tridiagonal` may only be used for linear solves with tridiagonal matrices

Besides using a particular operator type, the structure of the matrix can also be expressed by adding particular tags. These tags act as a manual override mechanism, and the values of the matrix are not checked.

For example, let's construct a positive definite matrix:

matrix = jr.normal(jr.PRNGKey(0), (4, 4))
operator = lx.MatrixLinearOperator(matrix.T @ matrix)

Unfortunately, Lineax has no way of knowing that this matrix is positive definite. It can solve the system, but it will not use a solver that is adapted to exploit the extra structure:

solution = lx.linear_solve(operator, vector)
print(default_solver.select_solver(operator))
LU()

But if we add a tag:

operator = lx.MatrixLinearOperator(matrix.T @ matrix, lx.positive_semidefinite_tag)
solution2 = lx.linear_solve(operator, vector)
print(default_solver.select_solver(operator))
Cholesky()

Then a more efficient solver can be selected. We can check that the two approaches return the same solution, up to floating-point rounding:

print(solution.value)
print(solution2.value)
[ 1.400575   -0.41042092  0.5313305   0.28422552]
[ 1.4005749  -0.41042086  0.53133047  0.2842255 ]