# Linear operators

We often talk about solving a linear system $$Ax = b$$, where $$A \in \mathbb{R}^{n \times m}$$ is a matrix, $$b \in \mathbb{R}^n$$ is a vector, and $$x \in \mathbb{R}^m$$ is our desired solution.

The linear operators described on this page are ways of describing the matrix $$A$$. The simplest is lineax.MatrixLinearOperator, which simply holds the matrix $$A$$ directly.

Meanwhile if $$A$$ is diagonal, then there is also lineax.DiagonalLinearOperator: for efficiency this only stores the diagonal of $$A$$.

Or, perhaps we only have a function $$F : \mathbb{R}^m \to \mathbb{R}^n$$ such that $$F(x) = Ax$$. Whilst we could use $$F$$ to materialise the whole matrix $$A$$ and then store it in a lineax.MatrixLinearOperator, that may be very memory intensive. Instead, we may prefer to use lineax.FunctionLinearOperator. Many linear solvers (e.g. lineax.CG) only use matrix-vector products, and this means we can avoid ever needing to materialise the whole matrix $$A$$.

#### lineax.AbstractLinearOperator

Abstract base class for all linear operators.

Linear operators can act between PyTrees. Each AbstractLinearOperator is thought of as a linear function X -> Y, where each element of X is a PyTree of floating-point JAX arrays, and each element of Y is a PyTree of floating-point JAX arrays.

Abstract linear operators support some operations:

```python
op1 + op2  # addition of two operators
op1 @ op2  # composition of two operators
op1 * 3.2  # multiplication by a scalar
op1 / 3.2  # division by a scalar
```


##### mv(self, vector: PyTree[Inexact[Array, '_b']]) -> PyTree[Inexact[Array, '_a']] abstractmethod

Computes a matrix-vector product between this operator and a vector.

Arguments:

• vector: Should be some PyTree of floating-point arrays, whose structure should match self.in_structure().

Returns:

A PyTree of floating-point arrays, with structure that matches self.out_structure().

##### as_matrix(self) -> Inexact[Array, 'a b'] abstractmethod

Materialises this linear operator as a matrix.

Note that this can be a computationally (time and/or memory) expensive operation, as many linear operators are defined implicitly, e.g. in terms of their action on a vector.

Arguments: None.

Returns:

A 2-dimensional floating-point JAX array.

##### transpose(self) -> AbstractLinearOperator abstractmethod

Transposes this linear operator.

This can be called as either operator.T or operator.transpose().

Arguments: None.

Returns:

Another linear operator, representing the transpose.

##### in_structure(self) -> PyTree[jax.ShapeDtypeStruct] abstractmethod

Returns the expected input structure of this linear operator.

Arguments: None.

Returns:

A PyTree of jax.ShapeDtypeStruct.

##### out_structure(self) -> PyTree[jax.ShapeDtypeStruct] abstractmethod

Returns the expected output structure of this linear operator.

Arguments: None.

Returns:

A PyTree of jax.ShapeDtypeStruct.

##### in_size(self) -> int

Returns the total number of scalars in the input of this linear operator.

That is, the dimensionality of its input space.

Arguments: None.

Returns: An integer.

##### out_size(self) -> int

Returns the total number of scalars in the output of this linear operator.

That is, the dimensionality of its output space.

Arguments: None.

Returns: An integer.

#### lineax.MatrixLinearOperator (AbstractLinearOperator)

Wraps a 2-dimensional JAX array into a linear operator.

If the matrix has shape (a, b) then matrix-vector multiplication (self.mv) is defined in the usual way: as performing a matrix-vector product that accepts a vector of shape (b,) and returns a vector of shape (a,).

##### __init__(self, matrix: Shaped[Array, 'a b'], tags: Union[object, frozenset[object]] = ())

Arguments:

• matrix: a two-dimensional JAX array. For an array with shape (a, b), this operator can perform matrix-vector products on a vector of shape (b,) to return a vector of shape (a,).
• tags: any tags indicating whether this matrix has any particular properties, like symmetry or positive-definite-ness. Note that these properties are unchecked and you may get incorrect values elsewhere if these tags are wrong.

#### lineax.DiagonalLinearOperator (AbstractLinearOperator)

As lineax.MatrixLinearOperator, but specifically for a diagonal matrix.

Only the diagonal of the matrix is stored (for memory efficiency). Matrix-vector products are computed by doing a pointwise diagonal * vector, rather than a full matrix @ vector (for speed).

##### __init__(self, diagonal: Shaped[Array, 'size'])

Arguments:

• diagonal: A rank-one JAX array, i.e. of shape (a,) for some a. This is the diagonal of the matrix.

#### lineax.TridiagonalLinearOperator (AbstractLinearOperator)

As lineax.MatrixLinearOperator, but specifically for a tridiagonal matrix.

##### __init__(self, diagonal: Inexact[Array, 'size'], lower_diagonal: Inexact[Array, 'size-1'], upper_diagonal: Inexact[Array, 'size-1'])

Arguments:

• diagonal: A rank-one JAX array. This is the diagonal of the matrix.
• lower_diagonal: A rank-one JAX array. This is the lower diagonal of the matrix.
• upper_diagonal: A rank-one JAX array. This is the upper diagonal of the matrix.

If diagonal has shape (a,) then lower_diagonal and upper_diagonal should both have shape (a - 1,).

#### lineax.PyTreeLinearOperator (AbstractLinearOperator)

Represents a PyTree of floating-point JAX arrays as a linear operator.

This is basically a generalisation of lineax.MatrixLinearOperator, from taking just a single array to taking a PyTree-of-arrays. (And likewise from returning a single array to returning a PyTree-of-arrays.)

Specifically, suppose we want this to be a linear operator X -> Y, for which elements of X are PyTrees with structure T whose ith leaf is a floating-point JAX array of shape x_shape_i, and elements of Y are PyTrees with structure S whose jth leaf is a floating-point JAX array of shape y_shape_j. Then the input PyTree should have structure S-compose-T (the output structure S, with a copy of the input structure T at each of its leaves), and its (j, i)-th leaf should be a floating-point JAX array of shape (*y_shape_j, *x_shape_i).

Example

```python
import equinox as eqx
import jax.numpy as jnp

# Suppose x is a member of our input space, with the following pytree structure:
x = [jnp.zeros((5, 9)), jnp.zeros(3)]
eqx.tree_pprint(x)  # [f32[5, 9], f32[3]]

# Suppose y is a member of our output space, with the following pytree structure:
y = {"a": jnp.zeros((1, 2))}
eqx.tree_pprint(y)  # {"a": f32[1, 2]}

# Then pytree should be a pytree with the following structure:
pytree = {"a": [jnp.zeros((1, 2, 5, 9)), jnp.zeros((1, 2, 3))]}
eqx.tree_pprint(pytree)  # {"a": [f32[1, 2, 5, 9], f32[1, 2, 3]]}
```

##### __init__(self, pytree: PyTree[ArrayLike], output_structure: PyTree[jax.ShapeDtypeStruct], tags: Union[object, frozenset[object]] = ())

Arguments:

• pytree: this should be a PyTree, with structure as specified in lineax.PyTreeLinearOperator.
• output_structure: the structure of the output space. This should be a PyTree of jax.ShapeDtypeStructs. (The structure of the input space is then automatically derived from the structure of pytree.)
• tags: any tags indicating whether this operator has any particular properties, like symmetry or positive-definite-ness. Note that these properties are unchecked and you may get incorrect values elsewhere if these tags are wrong.

#### lineax.JacobianLinearOperator (AbstractLinearOperator)

Given a function fn: X -> Y and a point x in X, this defines the linear operator (also a function X -> Y) given by the Jacobian (d(fn)/dx)(x).

For example, if the inputs and outputs are just 1-dimensional arrays, then this is equivalent to MatrixLinearOperator(jax.jacfwd(fn)(x, args)).

The Jacobian is not materialised; matrix-vector products, which are in fact Jacobian-vector products, are computed using autodifferentiation, specifically jax.jvp. Thus, JacobianLinearOperator(fn, x, args).mv(v) is equivalent to the tangent output of jax.jvp(lambda x: fn(x, args), (x,), (v,)).

See also lineax.linearise, which caches the primal computation: it computes _, lin = jax.linearize(fn, x) and returns FunctionLinearOperator(lin, ...).

See also lineax.materialise, which materialises the whole Jacobian in memory.

##### __init__(self, fn: Callable, x: PyTree[ArrayLike], args: PyTree[typing.Any] = None, tags: Union[object, Iterable[object]] = (), jac: Optional[Literal['fwd', 'bwd']] = None)

Arguments:

• fn: A function (x, args) -> y. The Jacobian d(fn)/dx is used as the linear operator, and args are just any other arguments that should not be differentiated.
• x: The point to evaluate d(fn)/dx at: (d(fn)/dx)(x, args).
• args: As x; this is the point to evaluate d(fn)/dx at: (d(fn)/dx)(x, args).
• tags: any tags indicating whether this operator has any particular properties, like symmetry or positive-definite-ness. Note that these properties are unchecked and you may get incorrect values elsewhere if these tags are wrong.
• jac: selects a specific Jacobian computation method. jac='fwd' forces jax.jacfwd to be used, and jac='bwd' forces jax.jacrev to be used. If not specified, the method is chosen automatically according to the input and output shapes.

#### lineax.FunctionLinearOperator (AbstractLinearOperator)

Wraps a linear function fn: X -> Y into a linear operator. (So that self.mv(x) is defined by self.mv(x) == fn(x).)

See also lineax.materialise, which materialises the whole linear operator in memory. (Similar to .as_matrix().)

##### __init__(self, fn: Callable[[PyTree[Inexact[Array, '...']]], PyTree[Inexact[Array, '...']]], input_structure: PyTree[jax.ShapeDtypeStruct], tags: Union[object, Iterable[object]] = ())

Arguments:

• fn: a linear function. Should accept a PyTree of floating-point JAX arrays, and return a PyTree of floating-point JAX arrays.
• input_structure: A PyTree of jax.ShapeDtypeStructs specifying the structure of the input to the function. (When later calling self.mv(x) then this should match the structure of x, i.e. jax.eval_shape(lambda: x).)
• tags: any tags indicating whether this operator has any particular properties, like symmetry or positive-definite-ness. Note that these properties are unchecked and you may get incorrect values elsewhere if these tags are wrong.

#### lineax.IdentityLinearOperator (AbstractLinearOperator)

Represents the identity transformation X -> X, where each x in X is some PyTree of floating-point JAX arrays.

##### __init__(self, input_structure: PyTree[jax.ShapeDtypeStruct], output_structure: PyTree[jax.ShapeDtypeStruct] = sentinel)

Arguments:

• input_structure: A PyTree of jax.ShapeDtypeStructs specifying the structure of the input space. (When later calling self.mv(x) then this should match the structure of x, i.e. jax.eval_shape(lambda: x).)
• output_structure: A PyTree of jax.ShapeDtypeStructs specifying the structure of the output space. If not passed then this defaults to the same as input_structure. If passed then it must have the same number of elements as input_structure, so that the operator is square.

#### lineax.TaggedLinearOperator (AbstractLinearOperator)

Wraps another linear operator and specifies that it has certain tags, e.g. representing symmetry.

Example

```python
import jax.numpy as jnp
import lineax as lx

# Some other operator.
some_jax_array = jnp.array([[1.0, 2.0], [3.0, 4.0]])
operator = lx.MatrixLinearOperator(some_jax_array)

# Now symmetric! But the type system doesn't know this.
sym_operator = operator + operator.T
assert not lx.is_symmetric(sym_operator)

# We can declare that our operator has a particular property.
sym_operator = lx.TaggedLinearOperator(sym_operator, lx.symmetric_tag)
assert lx.is_symmetric(sym_operator)
```

##### __init__(self, operator: AbstractLinearOperator, tags: Union[object, Iterable[object]])

Arguments:

• operator: some other linear operator to wrap.
• tags: any tags indicating whether this operator has any particular properties, like symmetry or positive-definite-ness. Note that these properties are unchecked and you may get incorrect values elsewhere if these tags are wrong.