# Linear operators

We often talk about solving a linear system \(Ax = b\), where \(A \in \mathbb{R}^{n \times m}\) is a matrix, \(b \in \mathbb{R}^n\) is a vector, and \(x \in \mathbb{R}^m\) is our desired solution.

The linear operators described on this page are ways of describing the matrix \(A\). The simplest is `lineax.MatrixLinearOperator`, which simply holds the matrix \(A\) directly.

Meanwhile, if \(A\) is diagonal, then there is also `lineax.DiagonalLinearOperator`: for efficiency this only stores the diagonal of \(A\).

Or perhaps we only have a function \(F : \mathbb{R}^m \to \mathbb{R}^n\) such that \(F(x) = Ax\). Whilst we could use \(F\) to materialise the whole matrix \(A\) and then store it in a `lineax.MatrixLinearOperator`, that may be very memory intensive. Instead, we may prefer to use `lineax.FunctionLinearOperator`. Many linear solvers (e.g. `lineax.CG`) only use matrix-vector products, and this means we can avoid ever needing to materialise the whole matrix \(A\).


#### `lineax.AbstractLinearOperator`

Abstract base class for all linear operators.

Linear operators can act between PyTrees. Each `AbstractLinearOperator` is thought of as a linear function `X -> Y`, where each element of `X` is a PyTree of floating-point JAX arrays, and each element of `Y` is a PyTree of floating-point JAX arrays.

Abstract linear operators support some operations:

```
op1 + op2 # addition of two operators
op1 @ op2 # composition of two operators.
op1 * 3.2 # multiplication by a scalar
op1 / 3.2 # division by a scalar
```

##### `mv(self, vector: PyTree[Inexact[Array, '_b']]) -> PyTree[Inexact[Array, '_a']]`

`abstractmethod`

Computes a matrix-vector product between this operator and a `vector`.

**Arguments:**

- `vector`: Should be some PyTree of floating-point arrays, whose structure should match `self.in_structure()`.

**Returns:**

A PyTree of floating-point arrays, with structure that matches `self.out_structure()`.

##### `as_matrix(self) -> Inexact[Array, 'a b']`

`abstractmethod`

Materialises this linear operator as a matrix.

Note that this can be a computationally (time and/or memory) expensive operation, as many linear operators are defined implicitly, e.g. in terms of their action on a vector.

**Arguments:** None.

**Returns:**

A 2-dimensional floating-point JAX array.

##### `transpose(self) -> AbstractLinearOperator`

`abstractmethod`

Transposes this linear operator.

This can be called as either `operator.T` or `operator.transpose()`.

**Arguments:** None.

**Returns:**

Another `lineax.AbstractLinearOperator`.

##### `in_structure(self) -> PyTree[jax.ShapeDtypeStruct]`

`abstractmethod`

Returns the expected input structure of this linear operator.

**Arguments:** None.

**Returns:**

A PyTree of `jax.ShapeDtypeStruct`s.

##### `out_structure(self) -> PyTree[jax.ShapeDtypeStruct]`

`abstractmethod`

Returns the expected output structure of this linear operator.

**Arguments:** None.

**Returns:**

A PyTree of `jax.ShapeDtypeStruct`s.

##### `in_size(self) -> int`

Returns the total number of scalars in the input of this linear operator.

That is, the dimensionality of its input space.

**Arguments:** None.

**Returns:** An integer.

##### `out_size(self) -> int`

Returns the total number of scalars in the output of this linear operator.

That is, the dimensionality of its output space.

**Arguments:** None.

**Returns:** An integer.


#### `lineax.MatrixLinearOperator (AbstractLinearOperator)`

Wraps a 2-dimensional JAX array into a linear operator.

If the matrix has shape `(a, b)` then matrix-vector multiplication (`self.mv`) is defined in the usual way: as performing a matrix-vector product that accepts a vector of shape `(b,)` and returns a vector of shape `(a,)`.

##### `__init__(self, matrix: Shaped[Array, 'a b'], tags: Union[object, frozenset[object]] = ())`

**Arguments:**

- `matrix`: a two-dimensional JAX array. For an array with shape `(a, b)`, this operator can perform matrix-vector products on a vector of shape `(b,)` to return a vector of shape `(a,)`.
- `tags`: any tags indicating whether this matrix has any particular properties, like symmetry or positive definiteness. Note that these properties are unchecked, and you may get incorrect values elsewhere if these tags are wrong.

#### `lineax.DiagonalLinearOperator (AbstractLinearOperator)`

As `lineax.MatrixLinearOperator`, but specifically for a diagonal matrix.

Only the diagonal of the matrix is stored (for memory efficiency). Matrix-vector products are computed by doing a pointwise `diagonal * vector`, rather than a full `matrix @ vector` (for speed).

##### `__init__(self, diagonal: Shaped[Array, 'size'])`

**Arguments:**

- `diagonal`: A rank-one JAX array, i.e. of shape `(a,)` for some `a`. This is the diagonal of the matrix.

#### `lineax.TridiagonalLinearOperator (AbstractLinearOperator)`

As `lineax.MatrixLinearOperator`, but specifically for a tridiagonal matrix.

##### `__init__(self, diagonal: Inexact[Array, 'size'], lower_diagonal: Inexact[Array, 'size-1'], upper_diagonal: Inexact[Array, 'size-1'])`

**Arguments:**

- `diagonal`: A rank-one JAX array. This is the diagonal of the matrix.
- `lower_diagonal`: A rank-one JAX array. This is the lower diagonal of the matrix, i.e. the entries immediately below the main diagonal.
- `upper_diagonal`: A rank-one JAX array. This is the upper diagonal of the matrix, i.e. the entries immediately above the main diagonal.

If `diagonal` has shape `(a,)` then `lower_diagonal` and `upper_diagonal` should both have shape `(a - 1,)`.

#### `lineax.PyTreeLinearOperator (AbstractLinearOperator)`

Represents a PyTree of floating-point JAX arrays as a linear operator.

This is basically a generalisation of `lineax.MatrixLinearOperator`, from taking just a single array to taking a PyTree-of-arrays. (And likewise from returning a single array to returning a PyTree-of-arrays.)

Specifically, suppose we want this to be a linear operator `X -> Y`, for which elements of `X` are PyTrees with structure `T` whose `i`th leaf is a floating-point JAX array of shape `x_shape_i`, and elements of `Y` are PyTrees with structure `S` whose `j`th leaf is a floating-point JAX array of shape `y_shape_j`. Then the `pytree` argument should have structure `S`-compose-`T` (each leaf of `S` replaced by a PyTree of structure `T`), and its `(j, i)`-th leaf should be a floating-point JAX array of shape `(*y_shape_j, *x_shape_i)`.

Example

```
# Suppose `x` is a member of our input space, with the following pytree
# structure:
eqx.tree_pprint(x) # [f32[5, 9], f32[3]]
# Suppose `y` is a member of our output space, with the following pytree
# structure:
eqx.tree_pprint(y) # {"a": f32[1, 2]}
# then `pytree` should be a pytree with the following structure:
eqx.tree_pprint(pytree) # {"a": [f32[1, 2, 5, 9], f32[1, 2, 3]]}
```

##### `__init__(self, pytree: PyTree[ArrayLike], output_structure: PyTree[jax.ShapeDtypeStruct], tags: Union[object, frozenset[object]] = ())`

**Arguments:**

- `pytree`: this should be a PyTree, with structure as specified in `lineax.PyTreeLinearOperator`.
- `output_structure`: the structure of the output space. This should be a PyTree of `jax.ShapeDtypeStruct`s. (The structure of the input space is then automatically derived from the structure of `pytree`.)
- `tags`: any tags indicating whether this operator has any particular properties, like symmetry or positive definiteness. Note that these properties are unchecked, and you may get incorrect values elsewhere if these tags are wrong.

#### `lineax.JacobianLinearOperator (AbstractLinearOperator)`

Given a function `fn: X -> Y` and a point `x in X`, this defines the linear operator (also a function `X -> Y`) given by the Jacobian `(d(fn)/dx)(x)`.

For example, if the inputs and outputs are just arrays, then this is equivalent to `MatrixLinearOperator(jax.jacfwd(fn)(x))`.

The Jacobian is not materialised; matrix-vector products, which are in fact Jacobian-vector products, are computed using autodifferentiation, specifically `jax.jvp`. Thus, `JacobianLinearOperator(fn, x, args).mv(v)` is equivalent to `jax.jvp(lambda x: fn(x, args), (x,), (v,))[1]`, i.e. the tangent output of `jax.jvp`.

See also `lineax.linearise`, which caches the primal computation, i.e. it returns `_, lin = jax.linearize(fn, x); FunctionLinearOperator(lin, ...)`.

See also `lineax.materialise`, which materialises the whole Jacobian in memory.

##### `__init__(self, fn: Callable, x: PyTree[ArrayLike], args: PyTree[typing.Any] = None, tags: Union[object, Iterable[object]] = (), _has_aux: bool = False)`

**Arguments:**

- `fn`: A function `(x, args) -> y`. The Jacobian `d(fn)/dx` is used as the linear operator, and `args` are just any other arguments that should not be differentiated.
- `x`: The point to evaluate `d(fn)/dx` at: `(d(fn)/dx)(x, args)`.
- `args`: As `x`; this is the point to evaluate `d(fn)/dx` at: `(d(fn)/dx)(x, args)`.
- `tags`: any tags indicating whether this operator has any particular properties, like symmetry or positive definiteness. Note that these properties are unchecked, and you may get incorrect values elsewhere if these tags are wrong.

#### `lineax.FunctionLinearOperator (AbstractLinearOperator)`

Wraps a *linear* function `fn: X -> Y` into a linear operator. (So that `self.mv(x)` is defined by `self.mv(x) == fn(x)`.)

See also `lineax.materialise`, which materialises the whole linear operator in memory. (Similar to `.as_matrix()`.)

##### `__init__(self, fn: collections.abc.Callable[[PyTree[Inexact[Array, '...']]], PyTree[Inexact[Array, '...']]], input_structure: PyTree[jax.ShapeDtypeStruct], tags: Union[object, Iterable[object]] = ())`

**Arguments:**

- `fn`: a linear function. Should accept a PyTree of floating-point JAX arrays, and return a PyTree of floating-point JAX arrays.
- `input_structure`: A PyTree of `jax.ShapeDtypeStruct`s specifying the structure of the input to the function. (When later calling `self.mv(x)`, this should match the structure of `x`, i.e. `jax.eval_shape(lambda: x)`.)
- `tags`: any tags indicating whether this operator has any particular properties, like symmetry or positive definiteness. Note that these properties are unchecked, and you may get incorrect values elsewhere if these tags are wrong.

#### `lineax.IdentityLinearOperator (AbstractLinearOperator)`

Represents the identity transformation `X -> X`, where each `x in X` is some PyTree of floating-point JAX arrays.

##### `__init__(self, input_structure: PyTree[jax.ShapeDtypeStruct], output_structure: PyTree[jax.ShapeDtypeStruct] = sentinel)`

**Arguments:**

- `input_structure`: A PyTree of `jax.ShapeDtypeStruct`s specifying the structure of the input space. (When later calling `self.mv(x)`, this should match the structure of `x`, i.e. `jax.eval_shape(lambda: x)`.)
- `output_structure`: A PyTree of `jax.ShapeDtypeStruct`s specifying the structure of the output space. If not passed then this defaults to the same as `input_structure`. If passed then it must have the same number of elements as `input_structure`, so that the operator is square.

#### `lineax.TaggedLinearOperator (AbstractLinearOperator)`

Wraps another linear operator and specifies that it has certain tags, e.g. representing symmetry.

Example

```
import jax.numpy as jnp
import lineax as lx

# Some other operator.
operator = lx.MatrixLinearOperator(jnp.array([[1.0, 2.0], [3.0, 4.0]]))
# Now symmetric! But the type system doesn't know this.
sym_operator = operator + operator.T
assert not lx.is_symmetric(sym_operator)
# We can declare that our operator has a particular property.
sym_operator = lx.TaggedLinearOperator(sym_operator, lx.symmetric_tag)
assert lx.is_symmetric(sym_operator)
```

##### `__init__(self, operator: AbstractLinearOperator, tags: Union[object, Iterable[object]])`

**Arguments:**

- `operator`: some other linear operator to wrap.
- `tags`: any tags indicating whether this operator has any particular properties, like symmetry or positive definiteness. Note that these properties are unchecked, and you may get incorrect values elsewhere if these tags are wrong.