
Terms

One of the advanced features of Diffrax is its term system. When we write down e.g. a stochastic differential equation

\(\mathrm{d}y(t) = f(t, y(t))\mathrm{d}t + g(t, y(t))\mathrm{d}w(t)\)

then we have two "terms": a drift and a diffusion. Each of these terms has two parts: a vector field (\(f\) or \(g\)) and a control (\(\mathrm{d}t\) or \(\mathrm{d}w(t)\)). In addition (often not represented in mathematical notation), there is also a choice of how the vector field and control interact: \(f\) and \(\mathrm{d}t\) interact as a vector-scalar product. \(g\) and \(\mathrm{d}w(t)\) interact as a matrix-vector product. (In general this interaction is always bilinear.)
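To make the two interactions concrete, here is a small plain-JAX sketch with arbitrary, illustrative values. The drift evaluation \(f\) is a vector scaled by the scalar \(\mathrm{d}t\), the diffusion evaluation \(g\) is a matrix applied to the vector \(\mathrm{d}w\), and both contributions land in the same state space. It does not call Diffrax.

```python
import jax.numpy as jnp

# Illustrative values: a 3-dimensional state driven by a 2-dimensional Brownian motion.
f_val = jnp.array([1.0, -2.0, 0.5])        # drift evaluation f(t, y), shape (3,)
g_val = jnp.array([[1.0, 0.0],
                   [0.0, 2.0],
                   [1.0, 1.0]])            # diffusion evaluation g(t, y), shape (3, 2)
dt = 0.1                                   # drift control: a scalar
dw = jnp.array([0.3, -0.2])                # diffusion control: a vector

drift_part = f_val * dt                    # vector-scalar product, shape (3,)
diffusion_part = g_val @ dw                # matrix-vector product, shape (3,)
dy = drift_part + diffusion_part           # both contributions live in the state's space

# Bilinearity: scaling the control scales the contribution by the same factor.
assert jnp.allclose(g_val @ (2.0 * dw), 2.0 * diffusion_part)
```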

"Terms" are thus the building blocks of differential equations.

Example

Consider the ODE \(\frac{\mathrm{d}{y}}{\mathrm{d}t} = f(t, y(t))\). Then this has vector field \(f\), control \(\mathrm{d}t\), and their interaction is a vector-scalar product. This can be described as a single diffrax.ODETerm.
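The three ingredients can be written out by hand. This plain-JAX sketch is not Diffrax's `ODETerm`; it assumes an illustrative vector field \(f(t, y) = -y\) and shows that one explicit Euler step is simply "evaluate the vector field, evaluate the control over the step, combine them with a vector-scalar product".

```python
import jax.numpy as jnp

def f(t, y, args):
    # Illustrative vector field: exponential decay.
    return -y

def control(t0, t1):
    # For an ODE the control is time, so its increment over [t0, t1] is t1 - t0.
    return t1 - t0

def prod(vf_value, control_value):
    # Vector-scalar product between the vector field evaluation and the control increment.
    return vf_value * control_value

def euler_step(t0, t1, y0, args=None):
    # One explicit Euler step: y1 = y0 + f(t0, y0, args) * (t1 - t0).
    return y0 + prod(f(t0, y0, args), control(t0, t1))

y0 = jnp.array([1.0, 2.0])
y1 = euler_step(0.0, 0.1, y0)
```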

Adding multiple terms, such as SDEs

We can add multiple terms together by grouping them into a single diffrax.MultiTerm.

Example

The SDE above would have its drift described by diffrax.ODETerm and the diffusion described by a diffrax.ControlTerm. As these affect the same evolving state variable, they should be passed to the solver as MultiTerm(ODETerm(...), ControlTerm(...)).
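Because both terms update the same state, a solver step adds their contributions. The following plain-JAX Euler-Maruyama sketch (not Diffrax's solver) assumes an illustrative drift \(f(t, y) = -y\), a constant diffusion \(0.5 I\), and a fixed random key; it only shows how the two contributions are summed at every step.

```python
import jax
import jax.numpy as jnp

def drift(t, y, args):
    # Illustrative drift f(t, y) = -y.
    return -y

def diffusion(t, y, args):
    # Illustrative diffusion g(t, y): a constant (2, 2) matrix.
    return 0.5 * jnp.eye(2)

def euler_maruyama(y0, t0, t1, num_steps, key):
    dt = (t1 - t0) / num_steps
    ts = t0 + dt * jnp.arange(num_steps)
    # Brownian increments dw ~ N(0, dt), one 2-vector per step.
    dws = jnp.sqrt(dt) * jax.random.normal(key, (num_steps, 2))
    y = y0
    for t, dw in zip(ts, dws):
        # Both terms act on the same state, so their contributions are summed;
        # this sum is what wrapping them in a single MultiTerm expresses.
        y = y + drift(t, y, None) * dt + diffusion(t, y, None) @ dw
    return y

y_final = euler_maruyama(jnp.array([1.0, 1.0]), 0.0, 1.0, 100, jax.random.PRNGKey(0))
```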

Independent terms, such as Hamiltonian systems

If terms affect different pieces of the state, then they should be placed in some PyTree structure.

Example

Consider the pair of equations (as commonly arising from Hamiltonian systems):

\(\frac{\mathrm{d}x}{\mathrm{d}t}(t) = f(t, y(t)),\qquad\frac{\mathrm{d}y}{\mathrm{d}t}(t) = g(t, x(t))\)

These would be passed to the solver as the 2-tuple of (ODETerm(...), ODETerm(...)).
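Keeping the two terms separate lets a solver treat each piece of the state differently. As a sketch in plain JAX (not Diffrax's own solver), a semi-implicit Euler step updates \(x\) first and then uses the new \(x\) to update \(y\). The choices \(f(t, y) = y\) and \(g(t, x) = -x\) (a harmonic oscillator) are illustrative.

```python
import jax.numpy as jnp

def f(t, y, args):
    # dx/dt = f(t, y): the "velocity" y drives the "position" x.
    return y

def g(t, x, args):
    # dy/dt = g(t, x): a linear restoring force.
    return -x

terms = (f, g)  # a PyTree (2-tuple) of vector fields, one per piece of the state

def semi_implicit_euler_step(terms, t0, t1, state, args=None):
    f, g = terms
    x, y = state
    dt = t1 - t0
    x_new = x + f(t0, y, args) * dt        # update x using the old y
    y_new = y + g(t0, x_new, args) * dt    # update y using the new x
    return (x_new, y_new)

state = (jnp.array(1.0), jnp.array(0.0))
dt = 0.02
for i in range(500):
    state = semi_implicit_euler_step(terms, i * dt, (i + 1) * dt, state)
energy = 0.5 * (state[0] ** 2 + state[1] ** 2)
```

This update order is what makes the scheme symplectic, so the energy stays close to its initial value of 0.5 instead of drifting.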

What each solver accepts

Each solver in Diffrax will specify what kinds of problems it can handle, as described by their .term_structure attribute. Not all solvers are able to handle all problems!

Some example term structures include:

  1. solver.term_structure = AbstractTerm

    In this case the solver can handle a simple ODE as described above: ODETerm is a subclass of AbstractTerm.

    It can also handle SDEs: MultiTerm(ODETerm(...), ControlTerm(...)) wraps both terms into a single term (the MultiTerm), which is itself an AbstractTerm and therefore presents an interface the solver knows how to handle.

    Most solvers in Diffrax have this term structure.

  2. solver.term_structure = MultiTerm[tuple[ODETerm, ControlTerm]]

    In this case the solver specifically handles just SDEs of the form MultiTerm(ODETerm(...), ControlTerm(...)); nothing else is compatible.

    Some SDE-specific solvers have this term structure.

  3. solver.term_structure = (AbstractTerm, AbstractTerm)

    In this case the solver is used to solve ODEs like the Hamiltonian system described above: we have a PyTree of terms, each of which is treated individually.
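To make the idea of a term structure concrete, here is a toy matcher: a class in the structure accepts any instance of that class (or a subclass), and a tuple must be matched element-by-element. The classes are hypothetical stand-ins, and this is not how Diffrax performs the check; a real implementation must also handle parametrized structures such as MultiTerm[tuple[ODETerm, ControlTerm]].

```python
# Hypothetical stand-ins for the term classes; not Diffrax's own classes.
class AbstractTerm: ...
class ODETerm(AbstractTerm): ...
class ControlTerm(AbstractTerm): ...

def matches(structure, terms):
    # A class in the structure accepts any instance of that class (or a subclass).
    if isinstance(structure, type):
        return isinstance(terms, structure)
    # A tuple in the structure must be matched element-by-element by a tuple of terms.
    if isinstance(structure, tuple):
        return (
            isinstance(terms, tuple)
            and len(terms) == len(structure)
            and all(matches(s, t) for s, t in zip(structure, terms))
        )
    raise TypeError(f"Unsupported structure: {structure!r}")

single = AbstractTerm                   # like structure 1 above
pair = (AbstractTerm, AbstractTerm)     # like structure 3 above
```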



diffrax.AbstractTerm

Abstract base class for all terms.

Let \(y\) solve some differential equation with vector field \(f\) and control \(x\).

Let \(y\) have PyTree structure \(T\), let the output of the vector field have PyTree structure \(S\), and let \(x\) have PyTree structure \(U\). Then \(f : T \to S\), whilst the interaction \((f, x) \mapsto f \mathrm{d}x\) is a function \((S, U) \to T\).
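For instance, the state can be a dictionary. In this plain-JAX sketch the vector field is illustrative: \(T\) and \(S\) are both `{"position": ..., "velocity": ...}`, and the control is a scalar time increment, so \(U\) is a single leaf. The interaction maps each leaf of \(S\), together with the control, back to the matching leaf of \(T\).

```python
import jax
import jax.numpy as jnp

y = {"position": jnp.array([0.0, 1.0]), "velocity": jnp.array([1.0, 0.0])}  # structure T

def vf(t, y, args):
    # Returns a PyTree with structure S (here equal to T).
    return {"position": y["velocity"], "velocity": -y["position"]}

def prod(vf_value, control):
    # (S, U) -> T: multiply every leaf of the vector field output by the scalar control.
    return jax.tree_util.tree_map(lambda v: v * control, vf_value)

dy = prod(vf(0.0, y, None), 0.1)
```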

vf(self, t: Union[float, int], y: PyTree[Shaped[ArrayLike, '?*y'], "Y"], args: PyTree[typing.Any]) -> ~_VF abstractmethod

The vector field.

Represents a function \(f(t, y(t), args)\).

Arguments:

  • t: the integration time.
  • y: the evolving state; a PyTree of structure \(T\).
  • args: any static arguments as passed to diffrax.diffeqsolve.

Returns:

A PyTree of structure \(S\).

contr(self, t0: Union[float, int], t1: Union[float, int], **kwargs) -> ~_Control abstractmethod

The control.

Represents the \(\mathrm{d}t\) in an ODE, or the \(\mathrm{d}w(t)\) in an SDE, etc.

Most numerical ODE solvers work by making a step of length \(\Delta t = t_1 - t_0\). Likewise, most numerical SDE solvers work by sampling a Brownian increment \(\Delta w \sim \mathcal{N}(0, t_1 - t_0)\).

Correspondingly a control is not defined at a point. Instead it is defined over an interval \([t_0, t_1]\).

Arguments:

  • t0: the start of the interval.
  • t1: the end of the interval.

Returns:

A PyTree of structure \(U\). For a control \(x\), the result should represent \(x(t_1) - x(t_0)\).
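Two illustrative controls, written as plain-JAX functions of an interval: time itself, whose increment is \(t_1 - t_0\), and a Brownian motion, whose increment over \([t_0, t_1]\) is drawn from \(\mathcal{N}(0, t_1 - t_0)\). The function names and the explicit `key` argument are assumptions of this sketch, not part of the documented signature. Note also that drawing a fresh sample on every call is only appropriate when each interval is queried once; a genuine Brownian path must return consistent increments for repeated or overlapping queries.

```python
import jax
import jax.numpy as jnp

def time_control(t0, t1):
    # Illustrative control for an ODE: the increment of t over [t0, t1].
    return t1 - t0

def brownian_control(t0, t1, key, shape=(2,)):
    # Illustrative Brownian increment w(t1) - w(t0): Gaussian with variance t1 - t0.
    return jnp.sqrt(t1 - t0) * jax.random.normal(key, shape)

dt = time_control(0.5, 0.75)
dw = brownian_control(0.5, 0.75, jax.random.PRNGKey(0))

# Sanity check on the variance using many independent increments.
many = brownian_control(0.0, 0.25, jax.random.PRNGKey(1), shape=(100_000,))
```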

prod(self, vf: ~_VF, control: ~_Control) -> PyTree[Shaped[ArrayLike, '?*y'], "Y"] abstractmethod

Determines the interaction between vector field and control.

With a solution \(y\) to a differential equation with vector field \(f\) and control \(x\), this computes \(f(t, y(t), args) \Delta x(t)\) given \(f(t, y(t), args)\) and \(\Delta x(t)\).

Note

This function must be bilinear.

Arguments:

  • vf: The vector field evaluation; a PyTree of structure \(S\).
  • control: The control evaluated over an interval; a PyTree of structure \(U\).

Returns:

The interaction between the vector field and control; a PyTree of structure \(T\).
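Bilinearity means that prod is linear in each argument when the other is held fixed. Here is a quick numerical check for the matrix-vector interaction used by SDE diffusions, in plain JAX with arbitrary random inputs:

```python
import jax
import jax.numpy as jnp

def prod(vf_value, control):
    # Matrix-vector interaction: vector field of shape (m, n), control of shape (n,).
    return vf_value @ control

k1, k2, k3, k4 = jax.random.split(jax.random.PRNGKey(0), 4)
A = jax.random.normal(k1, (3, 2))
B = jax.random.normal(k2, (3, 2))
u = jax.random.normal(k3, (2,))
v = jax.random.normal(k4, (2,))
a, b = 2.0, -0.5

# Linear in the control, with the vector field fixed ...
linear_in_control = jnp.allclose(prod(A, a * u + b * v), a * prod(A, u) + b * prod(A, v), atol=1e-5)
# ... and linear in the vector field, with the control fixed.
linear_in_vf = jnp.allclose(prod(a * A + b * B, u), a * prod(A, u) + b * prod(B, u), atol=1e-5)
```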

vf_prod(self, t: Union[float, int], y: PyTree[Shaped[ArrayLike, '?*y'], "Y"], args: PyTree[typing.Any], control: ~_Control) -> PyTree[Shaped[ArrayLike, '?*y'], "Y"]

The composition of diffrax.AbstractTerm.vf and diffrax.AbstractTerm.prod.

With a solution \(y\) to a differential equation with vector field \(f\) and control \(x\), this computes \(f(t, y(t), args) \Delta x(t)\) given \(t\), \(y(t)\), \(args\), and \(\Delta x(t)\).

Its default implementation is simply

self.prod(self.vf(t, y, args), control)

This is offered as a special case that can be overridden when it is more efficient to do so.

Example

Consider when vf computes a matrix-matrix product, and prod computes a matrix-vector product. Then doing a naive composition corresponds to a (matrix-matrix)-vector product, which is less efficient than the corresponding matrix-(matrix-vector) product. Overriding this method offers a way to reclaim that efficiency.
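Both orderings give the same result, but the cost differs. For an \(m \times k\) matrix \(A\), a \(k \times n\) matrix \(B\) and an \(n\)-vector \(v\), forming \((AB)v\) costs roughly \(mkn + mn\) multiplications, while \(A(Bv)\) costs roughly \(kn + mk\). A small plain-JAX check with arbitrary sizes:

```python
import jax
import jax.numpy as jnp

m, k, n = 50, 40, 30
k1, k2, k3 = jax.random.split(jax.random.PRNGKey(0), 3)
A = jax.random.normal(k1, (m, k))
B = jax.random.normal(k2, (k, n))
v = jax.random.normal(k3, (n,))

naive = (A @ B) @ v          # what vf followed by prod does if vf returns the matrix A @ B
efficient = A @ (B @ v)      # what an overridden vf_prod can do instead

# Approximate multiplication counts for each ordering.
naive_mults = m * k * n + m * n
efficient_mults = k * n + m * k
```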

Example

This is used extensively for efficiency when backpropagating via diffrax.BacksolveAdjoint.

Arguments:

  • t: the integration time.
  • y: the evolving state; a PyTree of structure \(T\).
  • args: any static arguments as passed to diffrax.diffeqsolve.
  • control: The control evaluated over an interval; a PyTree of structure \(U\).

Returns:

A PyTree of structure \(T\).

Note

This function must be linear in control.

is_vf_expensive(self, t0: Union[float, int], t1: Union[float, int], y: PyTree[Shaped[ArrayLike, '?*y'], "Y"], args: PyTree[typing.Any]) -> bool

Specifies whether evaluating the vector field is "expensive", in the specific sense that it is cheaper to evaluate vf_prod twice than vf once.

Some solvers use this to change their behaviour, so as to act more efficiently.

Defining your own term types

For advanced users: you can create your own terms if appropriate. For example, if your diffusion is a matrix that is itself computed as a matrix-matrix product, then you may wish to define a custom term and specify its diffrax.AbstractTerm.vf_prod method. By overriding this method you can express the contraction of the vector field and control as a matrix-(matrix-vector) product, which is more efficient than the default (matrix-matrix)-vector product.
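Here is a minimal sketch of this pattern. To keep it self-contained, it uses a hypothetical toy base class in plain JAX rather than diffrax.AbstractTerm itself. The base class provides the default `vf_prod`. The subclass, whose diffusion is the product \(AB\) of two illustrative matrices, overrides it to apply \(B\) to the control first. To apply the same idea in Diffrax, you would subclass diffrax.AbstractTerm and implement its abstract methods.

```python
import jax
import jax.numpy as jnp

class ToyTerm:
    # Hypothetical stand-in mirroring the vf / prod / vf_prod split described above.
    def vf(self, t, y, args):
        raise NotImplementedError

    def prod(self, vf_value, control):
        # Matrix-vector interaction between vector field and control.
        return vf_value @ control

    def vf_prod(self, t, y, args, control):
        # Default: evaluate the vector field, then contract it with the control.
        return self.prod(self.vf(t, y, args), control)

class ProductDiffusionTerm(ToyTerm):
    def __init__(self, A, B):
        self.A = A  # shape (m, k)
        self.B = B  # shape (k, n)

    def vf(self, t, y, args):
        # Materializes the full (m, n) diffusion matrix: a matrix-matrix product.
        return self.A @ self.B

    def vf_prod(self, t, y, args, control):
        # Override: matrix-(matrix-vector) product, never forming A @ B.
        return self.A @ (self.B @ control)

k1, k2, k3 = jax.random.split(jax.random.PRNGKey(0), 3)
term = ProductDiffusionTerm(jax.random.normal(k1, (4, 3)), jax.random.normal(k2, (3, 2)))
dw = jax.random.normal(k3, (2,))
fast = term.vf_prod(0.0, None, None, dw)
slow = ToyTerm.vf_prod(term, 0.0, None, None, dw)  # the default path, for comparison
```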


diffrax.ODETerm (AbstractTerm)

A term representing \(f(t, y(t), args) \mathrm{d}t\). That is to say, the term appearing on the right hand side of an ODE, in which the control is time.

vector_field should return some PyTree, with the same structure as the initial state y0, and with every leaf shape-broadcastable and dtype-upcastable to the equivalent leaf in y0.

Example

from diffrax import ODETerm, diffeqsolve

vector_field = lambda t, y, args: -y
ode_term = ODETerm(vector_field)
diffeqsolve(ode_term, ...)
__init__(self, vector_field: Callable[[Union[float, int], PyTree[Shaped[ArrayLike, '?*y'], "Y"], PyTree[Any]], ~_VF])

Arguments:

  • vector_field: A callable representing the vector field. This callable takes three arguments (t, y, args). t is a scalar representing the integration time. y is the evolving state of the system. args are any static arguments as passed to diffrax.diffeqsolve.
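Under the broadcasting rule, a vector field may, for example, return a scalar leaf for an array-valued state. This sketch applies one Euler-style update by hand in plain JAX to show the shapes involved. It does not call Diffrax, and the constant rate is illustrative.

```python
import jax.numpy as jnp

y0 = jnp.array([1.0, 2.0, 3.0])  # state leaf of shape (3,)

def vector_field(t, y, args):
    # Returns a shape-() leaf: the same rate for every component of y.
    return jnp.array(-1.0)

dt = 0.5
f = vector_field(0.0, y0, None)
y1 = y0 + f * dt  # the scalar leaf broadcasts against the (3,) state
```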

diffrax.ControlTerm (AbstractTerm)

A term representing the general case of \(f(t, y(t), args) \mathrm{d}x(t)\), in which the vector field (\(f\)) - control (\(\mathrm{d}x\)) interaction is a matrix-vector product.

This is typically used for either stochastic differential equations or for controlled differential equations.

ControlTerm can be used in two different ways.

  1. Simple way: directly return JAX arrays.

    vector_field and control should both return PyTrees, both with the same structure as the initial state y0. All leaves should be JAX arrays.

    If each leaf of y0 has shape (y1, ..., yN), and the corresponding leaf of control has shape (c1, ..., cM), then the corresponding leaf of vector_field should have shape (y1, ..., yN, c1, ..., cM). Leaf-by-leaf, the corresponding dimensions of vector_field and control are contracted against each other.

    This includes normal matrix-vector products as a special case: when y0 is an array with shape (m,), the control is an array with shape (n,), and the vector field is an array with shape (m, n).

  2. Advanced way: have the vector field return a Lineax linear operator.

    This is suitable for use cases in which you know that the vector field has special structure -- e.g. it is diagonal -- and you would like to use that structure for a more efficient implementation.

    In this case vector_field should return a Lineax linear operator, the control may return anything compatible with the .mv method of that operator, and the interaction is defined as vector_field(t0, y, args).mv(control(t0, t1)).

    In this case no special PyTree handling is done -- perform this inside the operator's .mv if required. (As you can see, this approach is basically about deferring the whole linear operation to Lineax.)
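The contraction rule of the simple way can be written as a `jnp.tensordot` over the trailing control dimensions. This plain-JAX sketch illustrates the shape rule described above for a single leaf; it is not Diffrax's internal code, and the shapes are arbitrary.

```python
import jax
import jax.numpy as jnp

def contract(vf_leaf, control_leaf):
    # Contract the trailing dimensions of the vector field, (y1, ..., yN, c1, ..., cM),
    # against all M dimensions of the control, (c1, ..., cM), leaving (y1, ..., yN).
    return jnp.tensordot(vf_leaf, control_leaf, axes=control_leaf.ndim)

k1, k2 = jax.random.split(jax.random.PRNGKey(0))

# General case: state leaf of shape (2, 3), control leaf of shape (4, 5).
vf_leaf = jax.random.normal(k1, (2, 3, 4, 5))
control_leaf = jax.random.normal(k2, (4, 5))
out = contract(vf_leaf, control_leaf)

# Special case: an ordinary matrix-vector product, (m, n) against (n,).
matrix = jnp.arange(6.0).reshape(3, 2)
vector = jnp.array([1.0, -1.0])
mv = contract(matrix, vector)
```

For PyTree-valued states, the same function would be applied leaf-by-leaf, for example with `jax.tree_util.tree_map(contract, vf_tree, control_tree)`.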

Example

In this example we consider an SDE with m-dimensional state \(y \in \mathbb{R}^m\), an n-dimensional Brownian motion \(W(t) \in \mathbb{R}^n\), and a constant diffusion of shape (m, n).

\(\mathrm{d}y(t) = \begin{bmatrix} 1 & ... & 1 \\ & ... & \\ 1 & ... & 1 \end{bmatrix} \mathrm{d}W(t)\)

import jax.numpy as jnp
from diffrax import ControlTerm, diffeqsolve, UnsafeBrownianPath

y0 = jnp.ones((m,))
control = UnsafeBrownianPath(shape=(n,), key=...)

def vector_field(t, y, args):
    return jnp.ones((m, n))

diffusion_term = ControlTerm(vector_field, control)
diffeqsolve(terms=diffusion_term, y0=y0, ...)
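Since every entry of the diffusion matrix is 1, each component of \(\mathrm{d}y\) equals the sum of the \(n\) Brownian increments, so all \(m\) components of \(y\) move together. A plain-JAX Euler-Maruyama sketch makes this visible. It is not Diffrax's solver, and the values of `m`, `n`, the step count and the key are illustrative.

```python
import jax
import jax.numpy as jnp

m, n = 3, 2
num_steps, dt = 100, 0.01
y = jnp.ones((m,))
dws = jnp.sqrt(dt) * jax.random.normal(jax.random.PRNGKey(0), (num_steps, n))

for dw in dws:
    # Constant diffusion of shape (m, n) contracted with the increment of shape (n,).
    y = y + jnp.ones((m, n)) @ dw

# With this constant diffusion, each component is y0 plus the sum of all Brownian components.
expected = 1.0 + jnp.sum(dws)
```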

Example

In this example we consider an SDE with a one-dimensional state \(y(t) \in \mathbb{R}\) and a two-dimensional Brownian motion \(W(t) \in \mathbb{R}^2\), given by:

\(\mathrm{d}y(t) = \begin{bmatrix} y(t) \\ y(t) + 1 \end{bmatrix} \mathrm{d}W(t)\)

We use the simple matrix-vector product way of combining things.

import jax.numpy as jnp
from diffrax import ControlTerm, diffeqsolve, UnsafeBrownianPath

control = UnsafeBrownianPath(shape=(2,), key=...)

def vector_field(t, y, args):
    return jnp.stack([y, y + 1], axis=-1)

diffusion_term = ControlTerm(vector_field, control)
diffeqsolve(diffusion_term, ...)
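To see the shapes in this example, evaluate the vector field by hand at a state of shape `(1,)`. This is an illustrative choice, since the example above leaves `y0` unspecified. The vector field returns shape `(1, 2)`, and contracting it with a 2-dimensional increment gives back shape `(1,)`, matching the state. Plain JAX, without Diffrax:

```python
import jax.numpy as jnp

def vector_field(t, y, args):
    return jnp.stack([y, y + 1], axis=-1)

y = jnp.array([2.0])            # one-dimensional state, shape (1,)
dw = jnp.array([0.1, -0.3])     # illustrative Brownian increment, shape (2,)

g = vector_field(0.0, y, None)  # shape (1, 2): [[y, y + 1]]
dy = g @ dw                     # y * dw1 + (y + 1) * dw2
```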

Example

In this example we consider an SDE with two-dimensional state \((y_1(t), y_2(t)) \in \mathbb{R}^2\) and a two-dimensional Brownian motion \(W(t) \in \mathbb{R}^2\) -- and for which the diffusion matrix is diagonal.

\(\mathrm{d}\begin{bmatrix} y_1 \\ y_2 \end{bmatrix}(t) = \begin{bmatrix} y_2(t) & 0 \\ 0 & y_1(t) \end{bmatrix} \mathrm{d}W(t)\)

As such we use the more-advanced approach of using Lineax's linear operators to represent the diffusion matrix.

import jax.numpy as jnp
import lineax
from diffrax import ControlTerm, diffeqsolve, UnsafeBrownianPath

control = UnsafeBrownianPath(shape=(2,), key=...)

def vector_field(t, y, args):
    # y is a JAX array of shape (2,)
    y1, y2 = y
    diagonal = jnp.array([y2, y1])
    return lineax.DiagonalLinearOperator(diagonal)

diffusion_term = ControlTerm(vector_field, control)
diffeqsolve(diffusion_term, ...)
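The benefit of the operator is that a diagonal operator's matrix-vector product is just an elementwise multiplication, so the \(2 \times 2\) matrix never needs to be formed. This sketch uses a hypothetical minimal `ToyDiagonalOperator` class with an `mv` method as a stand-in for `lineax.DiagonalLinearOperator`, so that it runs without Lineax. It checks the result against the dense matrix from the equation above.

```python
import jax.numpy as jnp

class ToyDiagonalOperator:
    # Hypothetical stand-in for lineax.DiagonalLinearOperator: stores only the diagonal.
    def __init__(self, diagonal):
        self.diagonal = diagonal

    def mv(self, vector):
        # Diagonal matrix times vector = elementwise product, O(n) instead of O(n^2).
        return self.diagonal * vector

def vector_field(t, y, args):
    y1, y2 = y
    return ToyDiagonalOperator(jnp.array([y2, y1]))

y = jnp.array([1.0, 3.0])
dw = jnp.array([0.2, -0.1])
dy = vector_field(0.0, y, None).mv(dw)

dense = jnp.array([[y[1], 0.0], [0.0, y[0]]])  # the same diffusion written out in full
```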

Example

In this example we consider a controlled differential equation, for which the control is given by an interpolation of some data. (See also the neural controlled differential equation example.)

import jax.numpy as jnp
from diffrax import ControlTerm, diffeqsolve, LinearInterpolation

ts = jnp.array([1., 2., 2.5, 3.])
data = jnp.array([[0.1, 2.0],
                  [0.3, 1.5],
                  [1.0, 1.6],
                  [0.2, 1.1]])
control = LinearInterpolation(ts, data)
vector_field = lambda t, y, args: jnp.stack([y, y], axis=-1)
cde_term = ControlTerm(vector_field, control)
diffeqsolve(cde_term, ...)
__init__(self, vector_field: Callable[[Union[float, int], PyTree[Shaped[ArrayLike, '?*y'], "Y"], PyTree[Any]], ~_VF], control: Union[AbstractPath[~_Control], Callable[[Union[float, int], Union[float, int]], ~_Control]])

Arguments:

  • vector_field: A callable representing the vector field. This callable takes three arguments (t, y, args). t is a scalar representing the integration time. y is the evolving state of the system. args are any static arguments as passed to diffrax.diffeqsolve. This vector_field can either be

    1. a function that returns a PyTree of JAX arrays, or
    2. a function that returns a Lineax linear operator, as described above.
  • control: The control. Should either be

    1. a diffrax.AbstractPath, in which case its .evaluate(t0, t1) method will be used to give the increment of the control over a time interval [t0, t1], or
    2. a callable (t0, t1) -> increment, which returns the increment directly.
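To illustrate the second option, a callable `(t0, t1) -> increment` can be built from the data of the controlled differential equation example above by linearly interpolating each channel. This plain-JAX sketch uses `jnp.interp` as a stand-in and makes no claim about how diffrax.LinearInterpolation is implemented.

```python
import jax.numpy as jnp

ts = jnp.array([1.0, 2.0, 2.5, 3.0])
data = jnp.array([[0.1, 2.0],
                  [0.3, 1.5],
                  [1.0, 1.6],
                  [0.2, 1.1]])

def path(t):
    # Piecewise-linear interpolation of each data channel at time t.
    return jnp.stack([jnp.interp(t, ts, data[:, i]) for i in range(data.shape[1])])

def control(t0, t1):
    # Increment of the control over [t0, t1], i.e. x(t1) - x(t0).
    return path(t1) - path(t0)

increment = control(1.0, 2.0)
```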
to_ode(self) -> ODETerm

If the control is differentiable, then \(f(t, y(t), args) \mathrm{d}x(t)\) may be thought of as the ODE term

\(f(t, y(t), args) \frac{\mathrm{d}x}{\mathrm{d}t}\mathrm{d}t\).

This method converts this ControlTerm into the corresponding diffrax.ODETerm in this way.
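For a control that is linear over a step, \(\mathrm{d}x/\mathrm{d}t\) is constant on that step, so the control form and the ODE form give the same increment. This plain-JAX sketch of that equivalence computes the slope by hand and is not Diffrax's `to_ode` implementation. The control and vector field are illustrative.

```python
import jax.numpy as jnp

# Control x(t) = x0 + slope * t over the step below (one linear piece).
x0 = jnp.array([0.1, 2.0])
slope = jnp.array([0.2, -0.5])

def x(t):
    return x0 + slope * t

def f(t, y, args):
    # Illustrative vector field of shape (state, control channels) = (1, 2).
    return jnp.stack([y, y], axis=-1)

y, t0, t1 = jnp.array([1.5]), 0.0, 0.1

control_form = f(t0, y, None) @ (x(t1) - x(t0))    # f dx
ode_form = (f(t0, y, None) @ slope) * (t1 - t0)     # f (dx/dt) dt
```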

diffrax.MultiTerm (AbstractTerm)

Accumulates multiple terms into a single term.

Consider the SDE

\(\mathrm{d}y(t) = f(t, y(t))\mathrm{d}t + g(t, y(t))\mathrm{d}w(t)\)

This has two terms on the right hand side. It may be represented with a single term as

\(\mathrm{d}y(t) = [f(t, y(t)), g(t, y(t))] \cdot [\mathrm{d}t, \mathrm{d}w(t)]\)

whose vector field -- control interaction is a dot product.

MultiTerm performs this transform. For simplicity, most differential equation solvers (at least those built into Diffrax) accept just a single term, so this transform is a necessary part of, e.g., solving an SDE with both drift and diffusion.
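The transform can be sketched in a few lines: collect the vector fields, collect the controls, and define the product as the sum of the individual products. The classes below are hypothetical plain-JAX stand-ins, not Diffrax's `MultiTerm`, and the drift, diffusion and fixed "Brownian" increment are illustrative.

```python
import jax.numpy as jnp

class ToyTerm:
    # Hypothetical stand-in for a single term: vector field, control, product.
    def __init__(self, vf, contr, prod):
        self.vf, self.contr, self.prod = vf, contr, prod

class ToyMultiTerm:
    def __init__(self, *terms):
        self.terms = terms

    def vf(self, t, y, args):
        # [f(t, y), g(t, y)]: a tuple of the individual vector field evaluations.
        return tuple(term.vf(t, y, args) for term in self.terms)

    def contr(self, t0, t1):
        # [dt, dw]: a tuple of the individual control increments.
        return tuple(term.contr(t0, t1) for term in self.terms)

    def prod(self, vfs, controls):
        # The "dot product": sum of the individual vector field / control products.
        return sum(term.prod(v, c) for term, v, c in zip(self.terms, vfs, controls))

drift = ToyTerm(lambda t, y, args: -y, lambda t0, t1: t1 - t0, lambda v, c: v * c)
diffusion = ToyTerm(
    lambda t, y, args: 0.5 * jnp.eye(2),
    lambda t0, t1: jnp.array([0.1, -0.2]),  # fixed illustrative "Brownian" increment
    lambda v, c: v @ c,
)
multi = ToyMultiTerm(drift, diffusion)

y = jnp.array([1.0, 2.0])
dy = multi.prod(multi.vf(0.0, y, None), multi.contr(0.0, 0.1))
```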

__init__(self, *terms: AbstractTerm)

Arguments: