
Linear operators

We often talk about solving a linear system \(Ax = b\), where \(A \in \mathbb{R}^{n \times m}\) is a matrix, \(b \in \mathbb{R}^n\) is a vector, and \(x \in \mathbb{R}^m\) is our desired solution.

The linear operators described on this page are ways of describing the matrix \(A\). The simplest is lineax.MatrixLinearOperator, which holds the matrix \(A\) directly.

Meanwhile, if \(A\) is diagonal, then there is also lineax.DiagonalLinearOperator: for efficiency, this stores only the diagonal of \(A\).

Or, perhaps we only have a function \(F : \mathbb{R}^m \to \mathbb{R}^n\) such that \(F(x) = Ax\). Whilst we could use \(F\) to materialise the whole matrix \(A\) and then store it in a lineax.MatrixLinearOperator, that may be very memory intensive. Instead, we may prefer to use lineax.FunctionLinearOperator. Many linear solvers (e.g. lineax.CG) only use matrix-vector products, and this means we can avoid ever needing to materialise the whole matrix \(A\).

lineax.AbstractLinearOperator

lineax.AbstractLinearOperator

Abstract base class for all linear operators.

Linear operators can act between PyTrees. Each AbstractLinearOperator is thought of as a linear function X -> Y, where each element of X is a PyTree of floating-point JAX arrays, and each element of Y is a PyTree of floating-point JAX arrays.

Abstract linear operators support the following operations:

op1 + op2  # addition of two operators
op1 @ op2  # composition of two operators
op1 * 3.2  # multiplication by a scalar
op1 / 3.2  # division by a scalar

mv(self, vector: PyTree[Inexact[Array, '_b']]) -> PyTree[Inexact[Array, '_a']] abstractmethod

Computes a matrix-vector product between this operator and a vector.

Arguments:

  • vector: Should be some PyTree of floating-point arrays, whose structure should match self.in_structure().

Returns:

A PyTree of floating-point arrays, with structure that matches self.out_structure().

as_matrix(self) -> Inexact[Array, 'a b'] abstractmethod

Materialises this linear operator as a matrix.

Note that this can be a computationally (time and/or memory) expensive operation, as many linear operators are defined implicitly, e.g. in terms of their action on a vector.

Arguments: None.

Returns:

A 2-dimensional floating-point JAX array.

transpose(self) -> AbstractLinearOperator abstractmethod

Transposes this linear operator.

This can be called as either operator.T or operator.transpose().

Arguments: None.

Returns:

Another lineax.AbstractLinearOperator.

in_structure(self) -> PyTree[jax.ShapeDtypeStruct] abstractmethod

Returns the expected input structure of this linear operator.

Arguments: None.

Returns:

A PyTree of jax.ShapeDtypeStruct.

out_structure(self) -> PyTree[jax.ShapeDtypeStruct] abstractmethod

Returns the expected output structure of this linear operator.

Arguments: None.

Returns:

A PyTree of jax.ShapeDtypeStruct.

in_size(self) -> int

Returns the total number of scalars in the input of this linear operator.

That is, the dimensionality of its input space.

Arguments: None.

Returns: An integer.

out_size(self) -> int

Returns the total number of scalars in the output of this linear operator.

That is, the dimensionality of its output space.

Arguments: None.

Returns: An integer.


lineax.MatrixLinearOperator (AbstractLinearOperator)

Wraps a 2-dimensional JAX array into a linear operator.

If the matrix has shape (a, b) then matrix-vector multiplication (self.mv) is defined in the usual way: as performing a matrix-vector product that accepts a vector of shape (b,) and returns a vector of shape (a,).

__init__(self, matrix: Shaped[Array, 'a b'], tags: Union[object, frozenset[object]] = ())

Arguments:

  • matrix: a two-dimensional JAX array. For an array with shape (a, b), this operator can perform matrix-vector products on a vector of shape (b,) to return a vector of shape (a,).
  • tags: any tags indicating whether this matrix has any particular properties, like symmetry or positive-definite-ness. Note that these properties are unchecked and you may get incorrect values elsewhere if these tags are wrong.

lineax.DiagonalLinearOperator (AbstractLinearOperator)

As lineax.MatrixLinearOperator, but specifically for a diagonal matrix.

Only the diagonal of the matrix is stored (for memory efficiency). Matrix-vector products are computed by doing a pointwise diagonal * vector, rather than a full matrix @ vector (for speed).

__init__(self, diagonal: Shaped[Array, 'size'])

Arguments:

  • diagonal: A rank-one JAX array, i.e. of shape (a,) for some a. This is the diagonal of the matrix.

lineax.TridiagonalLinearOperator (AbstractLinearOperator)

As lineax.MatrixLinearOperator, but specifically for a tridiagonal matrix.

__init__(self, diagonal: Inexact[Array, 'size'], lower_diagonal: Inexact[Array, 'size-1'], upper_diagonal: Inexact[Array, 'size-1'])

Arguments:

  • diagonal: A rank-one JAX array. This is the diagonal of the matrix.
  • lower_diagonal: A rank-one JAX array. This is the lower diagonal of the matrix.
  • upper_diagonal: A rank-one JAX array. This is the upper diagonal of the matrix.

If diagonal has shape (a,) then lower_diagonal and upper_diagonal should both have shape (a - 1,).


lineax.PyTreeLinearOperator (AbstractLinearOperator)

Represents a PyTree of floating-point JAX arrays as a linear operator.

This is basically a generalisation of lineax.MatrixLinearOperator, from taking just a single array to taking a PyTree-of-arrays. (And likewise from returning a single array to returning a PyTree-of-arrays.)

Specifically, suppose we want this to be a linear operator X -> Y, for which elements of X are PyTrees with structure T whose ith leaf is a floating-point JAX array of shape x_shape_i, and elements of Y are PyTrees with structure S whose jth leaf is a floating-point JAX array of shape y_shape_j. Then the pytree argument should have structure S at the top level, with a copy of T in place of each of its leaves (as in the example below), and the leaf corresponding to output leaf j and input leaf i should be a floating-point JAX array of shape (*y_shape_j, *x_shape_i).

Example

import equinox as eqx
import jax.numpy as jnp

# Suppose `x` is a member of our input space, with the following pytree
# structure:
x = [jnp.zeros((5, 9)), jnp.zeros(3)]
eqx.tree_pprint(x)  # [f32[5, 9], f32[3]]

# Suppose `y` is a member of our output space, with the following pytree
# structure:
y = {"a": jnp.zeros((1, 2))}
eqx.tree_pprint(y)  # {"a": f32[1, 2]}

# Then `pytree` should be a pytree with the following structure:
pytree = {"a": [jnp.zeros((1, 2, 5, 9)), jnp.zeros((1, 2, 3))]}
eqx.tree_pprint(pytree)  # {"a": [f32[1, 2, 5, 9], f32[1, 2, 3]]}

__init__(self, pytree: PyTree[ArrayLike], output_structure: PyTree[jax.ShapeDtypeStruct], tags: Union[object, frozenset[object]] = ())

Arguments:

  • pytree: this should be a PyTree, with structure as specified in lineax.PyTreeLinearOperator.
  • output_structure: the structure of the output space. This should be a PyTree of jax.ShapeDtypeStructs. (The structure of the input space is then automatically derived from the structure of pytree.)
  • tags: any tags indicating whether this operator has any particular properties, like symmetry or positive-definite-ness. Note that these properties are unchecked and you may get incorrect values elsewhere if these tags are wrong.

lineax.JacobianLinearOperator (AbstractLinearOperator)

Given a function fn: X -> Y and a point x in X, this defines the linear operator (also a function X -> Y) given by the Jacobian (d(fn)/dx)(x).

For example, if the inputs and outputs are just arrays, then this is equivalent to MatrixLinearOperator(jax.jacfwd(fn)(x, args)).

The Jacobian is not materialised; matrix-vector products, which are in fact Jacobian-vector products, are computed using autodifferentiation, specifically jax.jvp. Thus, JacobianLinearOperator(fn, x, args).mv(v) is equivalent to the tangent output of jax.jvp, i.e. jax.jvp(lambda x: fn(x, args), (x,), (v,))[1].

See also lineax.linearise, which caches the primal computation; it is equivalent to _, lin = jax.linearize(lambda x: fn(x, args), x); FunctionLinearOperator(lin, ...)

See also lineax.materialise, which materialises the whole Jacobian in memory.

__init__(self, fn: Callable, x: PyTree[ArrayLike], args: PyTree[typing.Any] = None, tags: Union[object, Iterable[object]] = (), jac: Optional[Literal['fwd', 'bwd']] = None)

Arguments:

  • fn: A function (x, args) -> y. The Jacobian d(fn)/dx is used as the linear operator, and args are just any other arguments that should not be differentiated.
  • x: The point to evaluate d(fn)/dx at: (d(fn)/dx)(x, args).
  • args: As x; this is the point to evaluate d(fn)/dx at: (d(fn)/dx)(x, args).
  • tags: any tags indicating whether this operator has any particular properties, like symmetry or positive-definite-ness. Note that these properties are unchecked and you may get incorrect values elsewhere if these tags are wrong.
  • jac: selects a specific Jacobian computation method. jac="fwd" forces jax.jacfwd to be used, whilst jac="bwd" forces jax.jacrev. If jac is not specified, the method is chosen automatically according to the input and output shapes.

lineax.FunctionLinearOperator (AbstractLinearOperator)

Wraps a linear function fn: X -> Y into a linear operator. (So that self.mv(x) is defined by self.mv(x) == fn(x).)

See also lineax.materialise, which materialises the whole linear operator in memory. (Similar to .as_matrix().)

__init__(self, fn: Callable[[PyTree[Inexact[Array, '...']]], PyTree[Inexact[Array, '...']]], input_structure: PyTree[jax.ShapeDtypeStruct], tags: Union[object, Iterable[object]] = ())

Arguments:

  • fn: a linear function. Should accept a PyTree of floating-point JAX arrays, and return a PyTree of floating-point JAX arrays.
  • input_structure: A PyTree of jax.ShapeDtypeStructs specifying the structure of the input to the function. (When later calling self.mv(x) then this should match the structure of x, i.e. jax.eval_shape(lambda: x).)
  • tags: any tags indicating whether this operator has any particular properties, like symmetry or positive-definite-ness. Note that these properties are unchecked and you may get incorrect values elsewhere if these tags are wrong.

lineax.IdentityLinearOperator (AbstractLinearOperator)

Represents the identity transformation X -> X, where each x in X is some PyTree of floating-point JAX arrays.

__init__(self, input_structure: PyTree[jax.ShapeDtypeStruct], output_structure: PyTree[jax.ShapeDtypeStruct] = sentinel)

Arguments:

  • input_structure: A PyTree of jax.ShapeDtypeStructs specifying the structure of the input space. (When later calling self.mv(x) then this should match the structure of x, i.e. jax.eval_shape(lambda: x).)
  • output_structure: A PyTree of jax.ShapeDtypeStructs specifying the structure of the output space. If not passed then this defaults to the same as input_structure. If passed then it must have the same number of elements as input_structure, so that the operator is square.

lineax.TaggedLinearOperator (AbstractLinearOperator)

Wraps another linear operator and specifies that it has certain tags, e.g. representing symmetry.

Example

import jax.numpy as jnp
import lineax as lx

# Some other operator. (`some_jax_array` is any square floating-point matrix.)
some_jax_array = jnp.arange(9.0).reshape(3, 3)
operator = lx.MatrixLinearOperator(some_jax_array)

# Now symmetric! But the type system doesn't know this.
sym_operator = operator + operator.T
assert not lx.is_symmetric(sym_operator)

# We can declare that our operator has a particular property.
sym_operator = lx.TaggedLinearOperator(sym_operator, lx.symmetric_tag)
assert lx.is_symmetric(sym_operator)

__init__(self, operator: AbstractLinearOperator, tags: Union[object, Iterable[object]])

Arguments:

  • operator: some other linear operator to wrap.
  • tags: any tags indicating whether this operator has any particular properties, like symmetry or positive-definite-ness. Note that these properties are unchecked and you may get incorrect values elsewhere if these tags are wrong.