```python
import jax
import jax.numpy as jnp
import jax.random as jr
import lineax as lx
import numpy as np

np.set_printoptions(precision=3)

matrix = jnp.zeros((5, 5))
matrix = matrix.at[0, 4].set(3)  # top right corner
sparse_operator = lx.MatrixLinearOperator(matrix)

# Use a separate key for each array, so that the three diagonals are independent.
key0, key1, key2 = jr.split(jr.PRNGKey(0), 3)
diag = jr.normal(key0, (5,))
lower_diag = jr.normal(key1, (4,))
upper_diag = jr.normal(key2, (4,))
tridiag_operator = lx.TridiagonalLinearOperator(diag, lower_diag, upper_diag)

identity_operator = lx.IdentityLinearOperator(jax.ShapeDtypeStruct((5,), jnp.float32))

print((sparse_operator + tridiag_operator).as_matrix())
print((tridiag_operator - 100 * identity_operator).as_matrix())
```
Or they can be composed together (i.e. matrix multiplication):

```python
print((tridiag_operator @ sparse_operator).as_matrix())
```
Or they can be transposed:
```python
print(sparse_operator.transpose().as_matrix())  # or sparse_operator.T will work
```
Different operator types
Lineax has many different operator types:

- We've already seen some general examples above, like `lineax.MatrixLinearOperator`.
- We've already seen some structured examples above, like `lineax.TridiagonalLinearOperator`.
- Given a function \(f \colon \mathbb{R}^n \to \mathbb{R}^m\) and a point \(x \in \mathbb{R}^n\), `lineax.JacobianLinearOperator` represents the Jacobian \(\frac{\mathrm{d}f}{\mathrm{d}x}(x) \in \mathbb{R}^{m \times n}\).
- Given a linear function \(g \colon \mathbb{R}^n \to \mathbb{R}^m\), `lineax.FunctionLinearOperator` represents the matrix corresponding to this linear function, i.e. the unique matrix \(A\) for which \(g(x) = Ax\).
- etc!
See the operators page for details on all supported operators.
As above, these can be freely combined:
```python
from jaxtyping import Array, Float  # https://github.com/google/jaxtyping


def f(y: Float[Array, "3"], args) -> Float[Array, "3"]:
    y0, y1, y2 = y
    f0 = 5 * y0 + y1**2
    f1 = y1 - y2 + 5
    f2 = y0 / (1 + 5 * y2**2)
    return jnp.stack([f0, f1, f2])


def g(y: Float[Array, "3"]) -> Float[Array, "3"]:
    # Must be linear!
    y0, y1, y2 = y
    f0 = y0 - y2
    f1 = 0.0
    f2 = 5 * y1
    return jnp.stack([f0, f1, f2])


y = jnp.array([1.0, 2.0, 3.0])
in_structure = jax.eval_shape(lambda: y)
jac_operator = lx.JacobianLinearOperator(f, y, args=None)
fn_operator = lx.FunctionLinearOperator(g, in_structure)
identity_operator = lx.IdentityLinearOperator(in_structure)

operator = jac_operator @ fn_operator + 0.9 * identity_operator
```
This composition does not materialise any matrix by default. (Avoiding this can matter for efficiency, for example when working with many operators, or with operators that are only ever needed through matrix-vector products.) Instead, the composition is stored as another linear operator:
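```python
import equinox as eqx  # https://github.com/patrick-kidger/equinox

# Pretty-print the operator's tree structure, without expanding the three leaf operators.
truncate_leaf = lambda x: x in (jac_operator, fn_operator, identity_operator)
eqx.tree_pprint(operator, truncate_leaf=truncate_leaf)
```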
If you do want to materialise the operator into a matrix, call `.as_matrix()`:
```python
operator.as_matrix()
```
This matrix can in turn be wrapped as another linear operator, if desired:
```python
operator_fully_materialised = lx.MatrixLinearOperator(operator.as_matrix())
eqx.tree_pprint(operator_fully_materialised, short_arrays=False)
```