Linear operators
We often talk about solving a linear system \(Ax = b\), where \(A \in \mathbb{R}^{n \times m}\) is a matrix, \(b \in \mathbb{R}^n\) is a vector, and \(x \in \mathbb{R}^m\) is our desired solution.
The linear operators described on this page are ways of describing the matrix \(A\). The simplest is lineax.MatrixLinearOperator, which simply holds the matrix \(A\) directly. Meanwhile, if \(A\) is diagonal, then there is also lineax.DiagonalLinearOperator: for efficiency this only stores the diagonal of \(A\).
Or, perhaps we only have a function \(F : \mathbb{R}^m \to \mathbb{R}^n\) such that \(F(x) = Ax\). Whilst we could use \(F\) to materialise the whole matrix \(A\) and then store it in a lineax.MatrixLinearOperator, that may be very memory intensive. Instead, we may prefer to use lineax.FunctionLinearOperator. Many linear solvers (e.g. lineax.CG) only use matrix-vector products, and this means we can avoid ever needing to materialise the whole matrix \(A\).
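For example, a minimal sketch of this pattern; the map \(F(x) = 4x\) (i.e. \(A = 4I\)) and the solver tolerances are arbitrary illustrative choices, and the operator is tagged positive semidefinite so that CG is applicable:

import jax
import jax.numpy as jnp
import lineax as lx

# The linear map F(x) = A @ x with A = 4 * identity, without ever building A.
def F(x):
    return 4.0 * x

in_structure = jax.eval_shape(lambda: jnp.zeros(3))
operator = lx.FunctionLinearOperator(F, in_structure, tags=lx.positive_semidefinite_tag)

b = jnp.array([1.0, 2.0, 3.0])
solution = lx.linear_solve(operator, b, solver=lx.CG(rtol=1e-6, atol=1e-6))
print(solution.value)  # approx. [0.25, 0.5, 0.75]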
lineax.AbstractLinearOperator
Abstract base class for all linear operators.
Linear operators can act between PyTrees. Each AbstractLinearOperator is thought of as a linear function X -> Y, where each element of X is a PyTree of floating-point JAX arrays, and each element of Y is a PyTree of floating-point JAX arrays.
Abstract linear operators support some operations:
op1 + op2 # addition of two operators
op1 @ op2 # composition of two operators
op1 * 3.2 # multiplication by a scalar
op1 / 3.2 # division by a scalar
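For example (the matrices here are arbitrary illustrative values):

import jax.numpy as jnp
import lineax as lx

A = lx.MatrixLinearOperator(jnp.array([[1.0, 2.0], [3.0, 4.0]]))
B = lx.MatrixLinearOperator(jnp.array([[0.0, 1.0], [1.0, 0.0]]))
v = jnp.array([1.0, -1.0])

(A + B).mv(v)    # equals A.mv(v) + B.mv(v)
(A @ B).mv(v)    # equals A.mv(B.mv(v))
(A * 3.2).mv(v)  # equals 3.2 * A.mv(v)
(A / 3.2).mv(v)  # equals A.mv(v) / 3.2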
mv(self, vector: PyTree[Inexact[Array, '_b']]) -> PyTree[Inexact[Array, '_a']]
abstractmethod
Computes a matrix-vector product between this operator and a vector.
Arguments:
vector: Should be some PyTree of floating-point arrays, whose structure should match self.in_structure().
Returns:
A PyTree of floating-point arrays, with structure that matches self.out_structure().
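For example, a small sketch using lineax.MatrixLinearOperator (the matrix values are illustrative):

import jax.numpy as jnp
import lineax as lx

operator = lx.MatrixLinearOperator(jnp.array([[1.0, 2.0], [3.0, 4.0]]))
operator.mv(jnp.array([1.0, 1.0]))  # Array([3., 7.], dtype=float32)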
as_matrix(self) -> Inexact[Array, 'a b']
abstractmethod
Materialises this linear operator as a matrix.
Note that this can be a computationally (time and/or memory) expensive operation, as many linear operators are defined implicitly, e.g. in terms of their action on a vector.
Arguments: None.
Returns:
A 2-dimensional floating-point JAX array.
transpose(self) -> AbstractLinearOperator
abstractmethod
Transposes this linear operator.
This can be called as either operator.T or operator.transpose().
Arguments: None.
Returns:
Another lineax.AbstractLinearOperator.
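A sketch showing both as_matrix and transpose on an illustrative operator:

import jax.numpy as jnp
import lineax as lx

operator = lx.MatrixLinearOperator(jnp.array([[1.0, 2.0], [3.0, 4.0]]))
operator.as_matrix()                  # the underlying 2x2 array
operator.T.mv(jnp.array([1.0, 0.0]))  # Array([1., 2.], dtype=float32)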
in_structure(self) -> PyTree[jax.ShapeDtypeStruct]
abstractmethod
Returns the expected input structure of this linear operator.
Arguments: None.
Returns:
A PyTree of jax.ShapeDtypeStruct.
out_structure(self) -> PyTree[jax.ShapeDtypeStruct]
abstractmethod
Returns the expected output structure of this linear operator.
Arguments: None.
Returns:
A PyTree of jax.ShapeDtypeStruct.
in_size(self) -> int
Returns the total number of scalars in the input of this linear operator.
That is, the dimensionality of its input space.
Arguments: None.
Returns: An integer.
out_size(self) -> int
Returns the total number of scalars in the output of this linear operator.
That is, the dimensionality of its output space.
Arguments: None.
Returns: An integer.
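A sketch with an illustrative non-square operator mapping \(\mathbb{R}^3 \to \mathbb{R}^2\), showing the structure and size methods together:

import jax.numpy as jnp
import lineax as lx

operator = lx.MatrixLinearOperator(jnp.zeros((2, 3)))
operator.in_structure()   # ShapeDtypeStruct(shape=(3,), dtype=float32)
operator.out_structure()  # ShapeDtypeStruct(shape=(2,), dtype=float32)
operator.in_size()        # 3
operator.out_size()       # 2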
lineax.MatrixLinearOperator (AbstractLinearOperator)
Wraps a 2-dimensional JAX array into a linear operator.
If the matrix has shape (a, b) then matrix-vector multiplication (self.mv) is defined in the usual way: as performing a matrix-vector product that accepts a vector of shape (b,) and returns a vector of shape (a,).
__init__(self, matrix: Shaped[Array, 'a b'], tags: Union[object, frozenset[object]] = ())
Arguments:
matrix: a two-dimensional JAX array. For an array of shape (a, b), this operator can perform matrix-vector products on a vector of shape (b,) to return a vector of shape (a,).
tags: any tags indicating whether this matrix has any particular properties, like symmetry or positive-definite-ness. Note that these properties are unchecked and you may get incorrect values elsewhere if these tags are wrong.
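A quick sketch (the matrix values and the symmetric tag are illustrative choices):

import jax.numpy as jnp
import lineax as lx

matrix = jnp.array([[2.0, 1.0], [1.0, 2.0]])
operator = lx.MatrixLinearOperator(matrix, tags=lx.symmetric_tag)
operator.mv(jnp.array([1.0, 1.0]))  # Array([3., 3.], dtype=float32)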
lineax.DiagonalLinearOperator (AbstractLinearOperator)
As lineax.MatrixLinearOperator, but for specifically a diagonal matrix.
Only the diagonal of the matrix is stored (for memory efficiency). Matrix-vector products are computed by doing a pointwise diagonal * vector, rather than a full matrix @ vector (for speed).
__init__(self, diagonal: Shaped[Array, 'size'])
Arguments:
diagonal: A rank-one JAX array, i.e. of shape (a,) for some a. This is the diagonal of the matrix.
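A quick illustrative sketch:

import jax.numpy as jnp
import lineax as lx

operator = lx.DiagonalLinearOperator(jnp.array([1.0, 2.0, 3.0]))
operator.mv(jnp.array([4.0, 5.0, 6.0]))  # Array([ 4., 10., 18.], dtype=float32)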
lineax.TridiagonalLinearOperator (AbstractLinearOperator)
As lineax.MatrixLinearOperator, but for specifically a tridiagonal matrix.
__init__(self, diagonal: Inexact[Array, 'size'], lower_diagonal: Inexact[Array, 'size-1'], upper_diagonal: Inexact[Array, 'size-1'])
Arguments:
diagonal: A rank-one JAX array. This is the diagonal of the matrix.
lower_diagonal: A rank-one JAX array. This is the lower diagonal of the matrix.
upper_diagonal: A rank-one JAX array. This is the upper diagonal of the matrix.
If diagonal has shape (a,) then lower_diagonal and upper_diagonal should both have shape (a - 1,).
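A quick sketch (the matrix values are arbitrary):

import jax.numpy as jnp
import lineax as lx

# Represents the 3x3 matrix
# [[2., 1., 0.],
#  [1., 2., 1.],
#  [0., 1., 2.]]
operator = lx.TridiagonalLinearOperator(
    diagonal=jnp.array([2.0, 2.0, 2.0]),
    lower_diagonal=jnp.array([1.0, 1.0]),
    upper_diagonal=jnp.array([1.0, 1.0]),
)
operator.mv(jnp.array([1.0, 1.0, 1.0]))  # Array([3., 4., 3.], dtype=float32)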
lineax.PyTreeLinearOperator (AbstractLinearOperator)
Represents a PyTree of floating-point JAX arrays as a linear operator.
This is basically a generalisation of lineax.MatrixLinearOperator, from taking just a single array to taking a PyTree-of-arrays. (And likewise from returning a single array to returning a PyTree-of-arrays.)
Specifically, suppose we want this to be a linear operator X -> Y, for which elements of X are PyTrees with structure T whose i-th leaf is a floating-point JAX array of shape x_shape_i, and elements of Y are PyTrees with structure S whose j-th leaf is a floating-point JAX array of shape y_shape_j. Then (as in the example below) the input PyTree should have the output structure S on the outside and the input structure T on the inside, and its (j, i)-th leaf should be a floating-point JAX array of shape (*y_shape_j, *x_shape_i).
Example
# Suppose `x` is a member of our input space, with the following pytree
# structure:
eqx.tree_pprint(x) # [f32[5, 9], f32[3]]
# Suppose `y` is a member of our output space, with the following pytree
# structure:
eqx.tree_pprint(y) # {"a": f32[1, 2]}
# then `pytree` should be a pytree with the following structure:
eqx.tree_pprint(pytree) # {"a": [f32[1, 2, 5, 9], f32[1, 2, 3]]}
__init__(self, pytree: PyTree[ArrayLike], output_structure: PyTree[jax.ShapeDtypeStruct], tags: Union[object, frozenset[object]] = ())
Arguments:
pytree: this should be a PyTree, with structure as specified in lineax.PyTreeLinearOperator.
output_structure: the structure of the output space. This should be a PyTree of jax.ShapeDtypeStructs. (The structure of the input space is then automatically derived from the structure of pytree.)
tags: any tags indicating whether this operator has any particular properties, like symmetry or positive-definite-ness. Note that these properties are unchecked and you may get incorrect values elsewhere if these tags are wrong.
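Continuing the structure example above, a construction sketch (the zero-valued leaves are placeholders, so the output below is all zeros):

import jax
import jax.numpy as jnp
import lineax as lx

# Input space: [f32[5, 9], f32[3]]; output space: {"a": f32[1, 2]}.
pytree = {"a": [jnp.zeros((1, 2, 5, 9)), jnp.zeros((1, 2, 3))]}
output_structure = jax.eval_shape(lambda: {"a": jnp.zeros((1, 2))})
operator = lx.PyTreeLinearOperator(pytree, output_structure)

x = [jnp.ones((5, 9)), jnp.ones(3)]
operator.mv(x)  # {"a": f32[1, 2]}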
lineax.JacobianLinearOperator (AbstractLinearOperator)
Given a function fn: X -> Y, and a point x in X, then this defines the linear operator (also a function X -> Y) given by the Jacobian (d(fn)/dx)(x).
For example if the inputs and outputs are just arrays, then this is equivalent to MatrixLinearOperator(jax.jacfwd(fn)(x)).
The Jacobian is not materialised; matrix-vector products, which are in fact Jacobian-vector products, are computed using autodifferentiation, specifically jax.jvp. Thus, JacobianLinearOperator(fn, x).mv(v) is equivalent to jax.jvp(fn, (x,), (v,)).
See also lineax.linearise, which caches the primal computation, i.e. it returns _, lin = jax.linearize(fn, x); FunctionLinearOperator(lin, ...).
See also lineax.materialise, which materialises the whole Jacobian in memory.
__init__(self, fn: Callable, x: PyTree[ArrayLike], args: PyTree[typing.Any] = None, tags: Union[object, Iterable[object]] = (), jac: Optional[Literal['fwd', 'bwd']] = None)
Arguments:
fn: A function (x, args) -> y. The Jacobian d(fn)/dx is used as the linear operator, and args are just any other arguments that should not be differentiated.
x: The point to evaluate d(fn)/dx at: (d(fn)/dx)(x, args).
args: As x; this is the point to evaluate d(fn)/dx at: (d(fn)/dx)(x, args).
tags: any tags indicating whether this operator has any particular properties, like symmetry or positive-definite-ness. Note that these properties are unchecked and you may get incorrect values elsewhere if these tags are wrong.
jac: allows selecting a specific Jacobian computation method: jac='fwd' forces jax.jacfwd to be used, and similarly jac='bwd' mandates the use of jax.jacrev. If not specified, the method is chosen by default according to the input and output shapes.
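A small sketch (the function and evaluation point are illustrative):

import jax.numpy as jnp
import lineax as lx

# fn(x, args) = x**2, so its Jacobian at x is diag(2 * x).
def fn(x, args):
    return x**2

x = jnp.array([1.0, 2.0, 3.0])
operator = lx.JacobianLinearOperator(fn, x)
operator.mv(jnp.array([1.0, 1.0, 1.0]))  # Array([2., 4., 6.], dtype=float32)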
lineax.FunctionLinearOperator (AbstractLinearOperator)
Wraps a linear function fn: X -> Y into a linear operator. (So that self.mv(x) is defined by self.mv(x) == fn(x).)
See also lineax.materialise, which materialises the whole linear operator in memory. (Similar to .as_matrix().)
__init__(self, fn: Callable[[PyTree[Inexact[Array, '...']]], PyTree[Inexact[Array, '...']]], input_structure: PyTree[jax.ShapeDtypeStruct], tags: Union[object, Iterable[object]] = ())
Arguments:
fn: a linear function. Should accept a PyTree of floating-point JAX arrays, and return a PyTree of floating-point JAX arrays.
input_structure: A PyTree of jax.ShapeDtypeStructs specifying the structure of the input to the function. (When later calling self.mv(x) then this should match the structure of x, i.e. jax.eval_shape(lambda: x).)
tags: any tags indicating whether this operator has any particular properties, like symmetry or positive-definite-ness. Note that these properties are unchecked and you may get incorrect values elsewhere if these tags are wrong.
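A sketch with a PyTree-valued input, building input_structure via jax.eval_shape (the function itself is an arbitrary linear map chosen for illustration):

import jax
import jax.numpy as jnp
import lineax as lx

# A linear function between PyTrees.
def fn(tree):
    return {"sum": tree["a"] + tree["b"], "double_a": 2.0 * tree["a"]}

x = {"a": jnp.ones(3), "b": jnp.arange(3.0)}
input_structure = jax.eval_shape(lambda: x)
operator = lx.FunctionLinearOperator(fn, input_structure)
operator.mv(x)  # {"double_a": f32[3], "sum": f32[3]}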
lineax.IdentityLinearOperator (AbstractLinearOperator)
Represents the identity transformation X -> X, where each x in X is some PyTree of floating-point JAX arrays.
__init__(self, input_structure: PyTree[jax.ShapeDtypeStruct], output_structure: PyTree[jax.ShapeDtypeStruct] = sentinel)
Arguments:
input_structure: A PyTree of jax.ShapeDtypeStructs specifying the structure of the input space. (When later calling self.mv(x) then this should match the structure of x, i.e. jax.eval_shape(lambda: x).)
output_structure: A PyTree of jax.ShapeDtypeStructs specifying the structure of the output space. If not passed then this defaults to the same as input_structure. If passed then it must have the same number of elements as input_structure, so that the operator is square.
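A quick sketch (the PyTree here is an arbitrary illustrative input):

import jax
import jax.numpy as jnp
import lineax as lx

x = {"a": jnp.ones(3), "b": jnp.zeros((2, 2))}
operator = lx.IdentityLinearOperator(jax.eval_shape(lambda: x))
operator.mv(x)  # returns x unchanged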
lineax.TaggedLinearOperator (AbstractLinearOperator)
Wraps another linear operator and specifies that it has certain tags, e.g. representing symmetry.
Example
# Some other operator.
operator = lx.MatrixLinearOperator(some_jax_array)
# Now symmetric! But the type system doesn't know this.
sym_operator = operator + operator.T
assert lx.is_symmetric(sym_operator) == False
# We can declare that our operator has a particular property.
sym_operator = lx.TaggedLinearOperator(sym_operator, lx.symmetric_tag)
assert lx.is_symmetric(sym_operator) == True
__init__(self, operator: AbstractLinearOperator, tags: Union[object, Iterable[object]])
Arguments:
operator: some other linear operator to wrap.
tags: any tags indicating whether this operator has any particular properties, like symmetry or positive-definite-ness. Note that these properties are unchecked and you may get incorrect values elsewhere if these tags are wrong.