
Functions on linear operators

We define a number of functions on linear operators.

Computational changes

These do not change the mathematical meaning of the operator; they only change how it is represented computationally (for example, by materialising the whole operator as a matrix).

lineax.linearise(operator: AbstractLinearOperator) -> AbstractLinearOperator

Linearises a linear operator. This returns another linear operator.

Mathematically speaking, this is just the identity function, and indeed most linear operators are returned unchanged.

For lineax.JacobianLinearOperator specifically, this caches the primal pass, so that it does not need to be recomputed on every matrix-vector product. That is, it trades some memory for speed. (This is precisely the same distinction as jax.jvp versus jax.linearize.)

Arguments:

  • operator: a linear operator.

Returns:

Another linear operator. Mathematically it performs matrix-vector products (operator.mv) that produce the same results as the input operator.


lineax.materialise(operator: AbstractLinearOperator) -> AbstractLinearOperator

Materialises a linear operator. This returns another linear operator.

Mathematically speaking, this is just the identity function, and indeed most linear operators are returned unchanged.

For lineax.JacobianLinearOperator and lineax.FunctionLinearOperator specifically, the linear operator is materialised in memory. That is, it becomes defined as a matrix (or pytree of arrays), rather than only through its matrix-vector product (lineax.AbstractLinearOperator.mv).

Materialisation sometimes improves compile time or run time. It usually increases memory usage.

For example:

import lineax as lx

large_function = ...
operator = lx.FunctionLinearOperator(large_function, ...)

# Option 1
out1 = operator.mv(vector1)  # Traces and compiles `large_function`
out2 = operator.mv(vector2)  # Traces and compiles `large_function` again!
out3 = operator.mv(vector3)  # Traces and compiles `large_function` a third time!
# All that compilation might lead to long compile times.
# If `large_function` takes a long time to run, then this might also lead to long
# run times.

# Option 2
operator = lx.materialise(operator)  # Traces and compiles `large_function` and
                                       # stores the result as a matrix.
out1 = operator.mv(vector1)  # Each of these just computes a matrix-vector product
out2 = operator.mv(vector2)  # against the stored matrix.
out3 = operator.mv(vector3)
# Now, `large_function` is only compiled once, and only run once.
# However, storing the matrix might take a lot of memory, and the initial
# computation may or may not take a long time to run.

Generally speaking, it is worth first setting up your problem without lx.materialise, and then trying it as an optional optimisation to see whether it helps your particular problem.

Arguments:

  • operator: a linear operator.

Returns:

Another linear operator. Mathematically it performs matrix-vector products (operator.mv) that produce the same results as the input operator.

Extract information from the operator

lineax.diagonal(operator: AbstractLinearOperator) -> Shaped[Array, 'size']

Extracts the diagonal of a linear operator and returns it as a vector.

Arguments:

  • operator: a linear operator.

Returns:

A rank-1 JAX array. (That is, it has shape (a,) for some integer a.)

For most operators this is just jnp.diag(operator.as_matrix()). Some operators (e.g. lineax.DiagonalLinearOperator) can have more efficient implementations. If you don't know what kind of operator you might have, then this function ensures that you always get the most efficient implementation.


lineax.tridiagonal(operator: AbstractLinearOperator) -> tuple[Shaped[Array, 'size'], Shaped[Array, 'size-1'], Shaped[Array, 'size-1']]

Extracts the diagonal, lower diagonal, and upper diagonal of a linear operator, returning three vectors.

Arguments:

  • operator: a linear operator.

Returns:

A 3-tuple, consisting of:

  • The diagonal of the matrix, represented as a vector.
  • The lower diagonal of the matrix, represented as a vector.
  • The upper diagonal of the matrix, represented as a vector.

If the diagonal has shape (a,) then the lower and upper diagonals will have shape (a - 1,).

For most operators these are computed by materialising the matrix and then extracting the relevant elements, e.g. getting the main diagonal via jnp.diag(operator.as_matrix()). Some operators (e.g. lineax.TridiagonalLinearOperator) can have more efficient implementations. If you don't know what kind of operator you might have, then this function ensures that you always get the most efficient implementation.

Test the operator to see if it exhibits a certain property

Note that these do not inspect the values of the operator; instead, they typically use tags. (Or in some cases, just the type of the operator: e.g. is_diagonal(DiagonalLinearOperator(...)) == True.)

lineax.has_unit_diagonal(operator: AbstractLinearOperator) -> bool

Returns whether an operator is marked as having unit diagonal.

See the documentation on linear operator tags for more information.

Arguments:

  • operator: a linear operator.

Returns:

Either True or False.


lineax.is_diagonal(operator: AbstractLinearOperator) -> bool

Returns whether an operator is marked as diagonal.

See the documentation on linear operator tags for more information.

Arguments:

  • operator: a linear operator.

Returns:

Either True or False.


lineax.is_tridiagonal(operator: AbstractLinearOperator) -> bool

Returns whether an operator is marked as tridiagonal.

See the documentation on linear operator tags for more information.

Arguments:

  • operator: a linear operator.

Returns:

Either True or False.


lineax.is_lower_triangular(operator: AbstractLinearOperator) -> bool

Returns whether an operator is marked as lower triangular.

See the documentation on linear operator tags for more information.

Arguments:

  • operator: a linear operator.

Returns:

Either True or False.


lineax.is_upper_triangular(operator: AbstractLinearOperator) -> bool

Returns whether an operator is marked as upper triangular.

See the documentation on linear operator tags for more information.

Arguments:

  • operator: a linear operator.

Returns:

Either True or False.


lineax.is_positive_semidefinite(operator: AbstractLinearOperator) -> bool

Returns whether an operator is marked as positive semidefinite.

See the documentation on linear operator tags for more information.

Arguments:

  • operator: a linear operator.

Returns:

Either True or False.


lineax.is_negative_semidefinite(operator: AbstractLinearOperator) -> bool

Returns whether an operator is marked as negative semidefinite.

See the documentation on linear operator tags for more information.

Arguments:

  • operator: a linear operator.

Returns:

Either True or False.


lineax.is_symmetric(operator: AbstractLinearOperator) -> bool

Returns whether an operator is marked as symmetric.

See the documentation on linear operator tags for more information.

Arguments:

  • operator: a linear operator.

Returns:

Either True or False.