# Manipulating linear operators

Lineax offers a sophisticated system of linear operators, supporting many operations.

## Arithmetic

To begin with, they support arithmetic: addition, subtraction, and multiplication by scalars.

```python
import jax
import jax.numpy as jnp
import jax.random as jr
import lineax as lx
import numpy as np

np.set_printoptions(precision=3)

matrix = jnp.zeros((5, 5))
matrix = matrix.at[0, 4].set(3)  # top right corner
sparse_operator = lx.MatrixLinearOperator(matrix)

key0, key1, key = jr.split(jr.PRNGKey(0), 3)
diag = jr.normal(key0, (5,))
# Note: reusing key0 makes lower_diag and upper_diag identical, so this
# tridiagonal operator happens to be symmetric (as the outputs below show).
lower_diag = jr.normal(key0, (4,))
upper_diag = jr.normal(key0, (4,))
tridiag_operator = lx.TridiagonalLinearOperator(diag, lower_diag, upper_diag)

identity_operator = lx.IdentityLinearOperator(jax.ShapeDtypeStruct((5,), jnp.float32))
print((sparse_operator + tridiag_operator).as_matrix())
```

```
[[-1.149  0.963  0.     0.     3.   ]
 [ 0.963 -2.007  0.155  0.     0.   ]
 [ 0.     0.155  0.988 -0.261  0.   ]
 [ 0.     0.    -0.261  0.931  0.899]
 [ 0.     0.     0.     0.899 -0.288]]
```

```python
print((tridiag_operator - 100 * identity_operator).as_matrix())
```

```
[[-101.149    0.963    0.       0.       0.   ]
 [   0.963 -102.007    0.155    0.       0.   ]
 [   0.       0.155  -99.012   -0.261    0.   ]
 [   0.       0.      -0.261  -99.069    0.899]
 [   0.       0.       0.       0.899 -100.288]]
```

They can also be composed together, which corresponds to matrix multiplication:

```python
print((tridiag_operator @ sparse_operator).as_matrix())
```

```
[[ 0.     0.     0.     0.    -3.447]
 [ 0.     0.     0.     0.     2.888]
 [ 0.     0.     0.     0.     0.   ]
 [ 0.     0.     0.     0.     0.   ]
 [ 0.     0.     0.     0.     0.   ]]
```

They can also be transposed:

```python
print(sparse_operator.transpose().as_matrix())  # or equivalently sparse_operator.T
```

```
[[0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0.]
 [3. 0. 0. 0. 0.]]
```

## Different operator types

Lineax has many different operator types:

• We've already seen some general examples above, like lineax.MatrixLinearOperator.
• We've already seen some structured examples above, like lineax.TridiagonalLinearOperator.
• Given a function $$f \colon \mathbb{R}^n \to \mathbb{R}^m$$ and a point $$x \in \mathbb{R}^n$$, then lineax.JacobianLinearOperator represents the Jacobian $$\frac{\mathrm{d}f}{\mathrm{d}x}(x) \in \mathbb{R}^{m \times n}$$.
• Given a linear function $$g \colon \mathbb{R}^n \to \mathbb{R}^m$$, then lineax.FunctionLinearOperator represents the matrix corresponding to this linear function, i.e. the unique matrix $$A$$ for which $$g(x) = Ax$$.
• etc!

See the operators page for details on all supported operators.

As above, these can be freely combined:

```python
from jaxtyping import Array, Float  # https://github.com/google/jaxtyping

def f(y: Float[Array, "3"], args) -> Float[Array, "3"]:
    y0, y1, y2 = y
    f0 = 5 * y0 + y1**2
    f1 = y1 - y2 + 5
    f2 = y0 / (1 + 5 * y2**2)
    return jnp.stack([f0, f1, f2])

def g(y: Float[Array, "3"]) -> Float[Array, "3"]:
    # Must be linear!
    y0, y1, y2 = y
    f0 = y0 - y2
    f1 = 0.0
    f2 = 5 * y1
    return jnp.stack([f0, f1, f2])

y = jnp.array([1.0, 2.0, 3.0])
in_structure = jax.eval_shape(lambda: y)
jac_operator = lx.JacobianLinearOperator(f, y, args=None)
fn_operator = lx.FunctionLinearOperator(g, in_structure)
identity_operator = lx.IdentityLinearOperator(in_structure)

operator = jac_operator @ fn_operator + 0.9 * identity_operator
```

By default, this composition does not materialise any matrices. (Avoiding this is sometimes important for efficiency, especially when working with many operators.) Instead, the composition is itself stored as another linear operator:

```python
import equinox as eqx  # https://github.com/patrick-kidger/equinox

truncate_leaf = lambda x: x in (jac_operator, fn_operator, identity_operator)
eqx.tree_pprint(operator, truncate_leaf=truncate_leaf)
```

```
AddLinearOperator(
  operator1=ComposedLinearOperator(
    operator1=JacobianLinearOperator(...),
    operator2=FunctionLinearOperator(...)
  ),
  operator2=MulLinearOperator(
    operator=IdentityLinearOperator(...),
    scalar=f32[]
  )
)
```

If you do want to materialise the operator into a matrix, then this can be done with `as_matrix`:

```python
operator.as_matrix()
```

```
Array([[ 5.9  ,  0.   , -5.   ],
       [ 0.   , -4.1  ,  0.   ],
       [ 0.022, -0.071,  0.878]], dtype=float32)
```

The resulting matrix can in turn be wrapped as another linear operator, if desired:

```python
operator_fully_materialised = lx.MatrixLinearOperator(operator.as_matrix())
eqx.tree_pprint(operator_fully_materialised, short_arrays=False)
```

```
MatrixLinearOperator(
  matrix=Array([[ 5.9  ,  0.   , -5.   ],
                [ 0.   , -4.1  ,  0.   ],
                [ 0.022, -0.071,  0.878]], dtype=float32),
  tags=frozenset()
)
```