Manipulating linear operators

Lineax offers a sophisticated system of linear operators, supporting many operations.

Arithmetic

To begin with, they support arithmetic, like addition and multiplication:

import jax
import jax.numpy as jnp
import jax.random as jr
import lineax as lx
import numpy as np


np.set_printoptions(precision=3)

matrix = jnp.zeros((5, 5))
matrix = matrix.at[0, 4].set(3)  # top right corner
sparse_operator = lx.MatrixLinearOperator(matrix)

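key0, key1, key = jr.split(jr.PRNGKey(0), 3)
diag = jr.normal(key0, (5,))
# key0 is reused below, so lower_diag and upper_diag are identical and the
# resulting tridiagonal operator is symmetric (as the printed matrices show).
lower_diag = jr.normal(key0, (4,))
upper_diag = jr.normal(key0, (4,))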
tridiag_operator = lx.TridiagonalLinearOperator(diag, lower_diag, upper_diag)

identity_operator = lx.IdentityLinearOperator(jax.ShapeDtypeStruct((5,), jnp.float32))
print((sparse_operator + tridiag_operator).as_matrix())
[[-1.149  0.963  0.     0.     3.   ]
 [ 0.963 -2.007  0.155  0.     0.   ]
 [ 0.     0.155  0.988 -0.261  0.   ]
 [ 0.     0.    -0.261  0.931  0.899]
 [ 0.     0.     0.     0.899 -0.288]]

print((tridiag_operator - 100 * identity_operator).as_matrix())
[[-101.149    0.963    0.       0.       0.   ]
 [   0.963 -102.007    0.155    0.       0.   ]
 [   0.       0.155  -99.012   -0.261    0.   ]
 [   0.       0.      -0.261  -99.069    0.899]
 [   0.       0.       0.       0.899 -100.288]]

Or they can be composed together, i.e. matrix multiplication:

print((tridiag_operator @ sparse_operator).as_matrix())
[[ 0.     0.     0.     0.    -3.447]
 [ 0.     0.     0.     0.     2.888]
 [ 0.     0.     0.     0.     0.   ]
 [ 0.     0.     0.     0.     0.   ]
 [ 0.     0.     0.     0.     0.   ]]

Or they can be transposed:

print(sparse_operator.transpose().as_matrix())  # or sparse_operator.T will work
[[0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0.]
 [3. 0. 0. 0. 0.]]

Different operator types

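Lineax has many different operator types. Besides the MatrixLinearOperator, TridiagonalLinearOperator and IdentityLinearOperator used above, examples include JacobianLinearOperator, which represents the Jacobian of a function at a given point, and FunctionLinearOperator, which wraps a function that is linear in its input.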

See the operators page for details on all supported operators.

As above, these can be freely combined:

from jaxtyping import Array, Float  # https://github.com/google/jaxtyping


def f(y: Float[Array, "3"], args) -> Float[Array, "3"]:
    y0, y1, y2 = y
    f0 = 5 * y0 + y1**2
    f1 = y1 - y2 + 5
    f2 = y0 / (1 + 5 * y2**2)
    return jnp.stack([f0, f1, f2])


def g(y: Float[Array, "3"]) -> Float[Array, "3"]:
    # Must be linear!
    y0, y1, y2 = y
    f0 = y0 - y2
    f1 = 0.0
    f2 = 5 * y1
    return jnp.stack([f0, f1, f2])


y = jnp.array([1.0, 2.0, 3.0])
in_structure = jax.eval_shape(lambda: y)
jac_operator = lx.JacobianLinearOperator(f, y, args=None)
fn_operator = lx.FunctionLinearOperator(g, in_structure)
identity_operator = lx.IdentityLinearOperator(in_structure)

operator = jac_operator @ fn_operator + 0.9 * identity_operator

By default, this composition does not materialise a matrix for any of these operators. (Avoiding materialisation is sometimes important for efficiency, for example when working with many operators or with large ones.) Instead, the composition is stored as another linear operator:

import equinox as eqx  # https://github.com/patrick-kidger/equinox


# Print these three operators as "(...)" instead of expanding their fields.
truncate_leaf = lambda x: x in (jac_operator, fn_operator, identity_operator)
eqx.tree_pprint(operator, truncate_leaf=truncate_leaf)
AddLinearOperator(
  operator1=ComposedLinearOperator(
    operator1=JacobianLinearOperator(...),
    operator2=FunctionLinearOperator(...)
  ),
  operator2=MulLinearOperator(
    operator=IdentityLinearOperator(...),
    scalar=f32[]
  )
)

If you do want to materialise the result into a matrix, this can be done with as_matrix():

operator.as_matrix()
Array([[ 5.9  ,  0.   , -5.   ],
       [ 0.   , -4.1  ,  0.   ],
       [ 0.022, -0.071,  0.878]], dtype=float32)

The resulting matrix can in turn be wrapped as another linear operator, if desired:

operator_fully_materialised = lx.MatrixLinearOperator(operator.as_matrix())
eqx.tree_pprint(operator_fully_materialised, short_arrays=False)
MatrixLinearOperator(
  matrix=Array([[ 5.9  ,  0.   , -5.   ],
       [ 0.   , -4.1  ,  0.   ],
       [ 0.022, -0.071,  0.878]], dtype=float32),
  tags=frozenset()
)