Structured matrices
Lineax can also be used with matrices known to exhibit special structure, e.g. tridiagonal matrices or positive definite matrices.
Typically, that means using a particular operator type:
```python
import jax.numpy as jnp
import jax.random as jr
import lineax as lx
diag = jnp.array([4.0, -0.5, 7.0, 1.0])
lower_diag = jnp.array([1.0, 3.0, -0.7])
upper_diag = jnp.array([2.0, -1.0, -5.0])
operator = lx.TridiagonalLinearOperator(diag, lower_diag, upper_diag)
print(operator.as_matrix())
vector = jnp.array([1.0, -0.5, 2.0, 0.8])
# Will automatically dispatch to a tridiagonal solver.
solution = lx.linear_solve(operator, vector)
```
If you're unsure which solver will be dispatched to, you can check:
```python
default_solver = lx.AutoLinearSolver(well_posed=True)
print(default_solver.select_solver(operator))
```
To force a particular solver to be used, pass it explicitly via the `solver` argument:
```python
solution = lx.linear_solve(operator, vector, solver=lx.Tridiagonal())
```
Trying to use a solver with an unsupported operator will raise an error:
```python
not_tridiagonal_matrix = jr.normal(jr.PRNGKey(0), (4, 4))
not_tridiagonal_operator = lx.MatrixLinearOperator(not_tridiagonal_matrix)
try:
    lx.linear_solve(not_tridiagonal_operator, vector, solver=lx.Tridiagonal())
except ValueError as e:  # the operator is not marked as tridiagonal
    print(e)
```
Besides using a particular operator type, you can also express the structure of a matrix by attaching tags to the operator. Tags act as a manual override: Lineax trusts them and does not check that the matrix values actually have the claimed structure.
For example, let's construct a positive definite matrix:
```python
matrix = jr.normal(jr.PRNGKey(0), (4, 4))
operator = lx.MatrixLinearOperator(matrix.T @ matrix)
```
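Why is this positive definite? For any vector x, x^T (A^T A) x = ||A x||^2 >= 0, with equality only when A x = 0; a Gaussian random matrix is invertible with probability one, so A^T A is positive definite almost surely. A quick numerical check of that claim, using `jnp.linalg.eigvalsh` (which returns the eigenvalues of a symmetric matrix in ascending order):

```python
import jax.numpy as jnp
import jax.random as jr

matrix = jr.normal(jr.PRNGKey(0), (4, 4))
gram = matrix.T @ matrix  # A^T A is symmetric positive semidefinite by construction

# eigvalsh assumes a symmetric input and returns real eigenvalues, smallest first.
eigenvalues = jnp.linalg.eigvalsh(gram)
print(eigenvalues)
print("symmetric:", bool(jnp.allclose(gram, gram.T)))
print("positive definite:", bool(eigenvalues[0] > 0))
```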
Unfortunately, Lineax has no way of knowing that this matrix is positive definite. It can still solve the system, but it will fall back to a general-purpose solver instead of one that exploits the extra structure:
```python
solution = lx.linear_solve(operator, vector)
print(default_solver.select_solver(operator))
```
But if we add a tag:
```python
operator = lx.MatrixLinearOperator(matrix.T @ matrix, lx.positive_semidefinite_tag)
solution2 = lx.linear_solve(operator, vector)
print(default_solver.select_solver(operator))
```
Then the more efficient `Cholesky` solver is selected. The solutions returned by the two approaches agree, up to floating-point roundoff:
```python
print(solution.value)
print(solution2.value)
```