Terms

One of the advanced features of Diffrax is its term system. When we write down e.g. a stochastic differential equation

\(\mathrm{d}y(t) = f(t, y(t))\mathrm{d}t + g(t, y(t))\mathrm{d}w(t)\)

then we have two "terms": a drift and a diffusion. Each of these terms has two parts: a vector field (\(f\) or \(g\)) and a control (\(\mathrm{d}t\) or \(\mathrm{d}w(t)\)). There is also an implicit assumption about how the vector field and control interact: \(f\) and \(\mathrm{d}t\) interact as a vector-scalar product. \(g\) and \(\mathrm{d}w(t)\) interact as a matrix-vector product. (This interaction is always linear.)
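The two interactions can be made concrete with a minimal NumPy sketch of a single Euler-Maruyama-style increment. The drift f, the diffusion g, and all shapes are hypothetical, and NumPy stands in for JAX. This illustrates only the arithmetic, not Diffrax's internals.

```python
import numpy as np

# Hypothetical drift and diffusion for a 3-dimensional state
# driven by a 2-dimensional Brownian motion.
def f(t, y):
    return -y                      # vector of shape (3,)

def g(t, y):
    return np.ones((3, 2)) * 0.1   # matrix of shape (3, 2)

t, y = 0.0, np.array([1.0, 2.0, 3.0])
dt = 0.01                                                        # drift control: a scalar
dw = np.random.default_rng(0).normal(0.0, np.sqrt(dt), size=2)   # diffusion control: a vector

drift_increment = f(t, y) * dt       # vector-scalar product
diffusion_increment = g(t, y) @ dw   # matrix-vector product
y_next = y + drift_increment + diffusion_increment
```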

"Terms" are thus the building blocks of differential equations.

Example

Consider the ODE \(\frac{\mathrm{d}y}{\mathrm{d}t}(t) = f(t, y(t))\). This has vector field \(f\) and control \(\mathrm{d}t\), and their interaction is a vector-scalar product. This can be described as a single diffrax.ODETerm.

If multiple terms affect the same evolving state, then they should be grouped into a single diffrax.MultiTerm.

Example

An SDE would have its drift described by diffrax.ODETerm and the diffusion described by a diffrax.ControlTerm. As these affect the same evolving state variable, they should be passed to the solver as MultiTerm(ODETerm(...), ControlTerm(...)).

If terms affect different pieces of the state, then they should be placed in some PyTree structure. (The exact structure will depend on what the solver accepts.)

Example

Consider the pair of equations (as commonly arising from Hamiltonian systems):

\(\frac{\mathrm{d}x}{\mathrm{d}t}(t) = f(t, y(t)),\qquad\frac{\mathrm{d}y}{\mathrm{d}t}(t) = g(t, x(t))\)

These would be passed to the solver as the 2-tuple of (ODETerm(...), ODETerm(...)).
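The point of the tuple is that each term updates its own piece of the state. The semi-implicit (symplectic) Euler method consumes a pair of vector fields in this way: it advances x using the current y, then advances y using the new x. The loop below is an illustration only, not how any Diffrax solver is implemented. The choice f(t, y) = y and g(t, x) = -x is a hypothetical example that gives the harmonic oscillator.

```python
def f(t, y):
    return y       # dx/dt depends only on y

def g(t, x):
    return -x      # dy/dt depends only on x

def semi_implicit_euler(f, g, t0, x, y, dt, num_steps):
    t = t0
    for _ in range(num_steps):
        x = x + dt * f(t, y)   # first piece of the state, driven by the first term
        y = y + dt * g(t, x)   # second piece, using the freshly updated x
        t = t + dt
    return x, y

x, y = semi_implicit_euler(f, g, 0.0, 1.0, 0.0, dt=0.01, num_steps=1000)
```

With dt = 0.01, the quantity x**2 + y**2 stays within about 0.005 of its starting value of 1 over all 1000 steps rather than drifting, which is a known property of this scheme on this problem.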

Each solver is capable of handling certain classes of problems, as described by their solver.term_structure.

diffrax.AbstractTerm

Abstract base class for all terms.

Let \(y\) solve some differential equation with vector field \(f\) and control \(x\).

Let \(y\) have PyTree structure \(T\), let the output of the vector field have PyTree structure \(S\), and let \(x\) have PyTree structure \(U\). Then \(f : T \to S\), whilst the interaction \((f, x) \mapsto f \mathrm{d}x\) is a function \((S, U) \to T\).
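The following sketch shows how the three structures fit together. Plain Python dicts stand in for PyTrees and NumPy arrays for leaves. The keys, shapes, and the functions vf and prod are hypothetical.

```python
import numpy as np

# Structure T: y is a dict with leaves of shape (3,) and (2,).
# Structure U: the control increment is a single vector of shape (2,).
y = {"a": np.ones(3), "b": np.ones(2)}
dx = np.array([0.1, -0.2])

def vf(t, y, args):
    # Structure S: same dict keys, but each leaf gains a trailing
    # dimension of size 2 to pair with the control.
    return {k: np.stack([v, 2 * v], axis=-1) for k, v in y.items()}

def prod(vf_out, control):
    # (S, U) -> T: contract each leaf against the control vector.
    return {k: v @ control for k, v in vf_out.items()}

increment = prod(vf(0.0, y, None), dx)   # has structure T again
```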

vf(self, t: Scalar, y: PyTree, args: PyTree) -> PyTree abstractmethod

The vector field.

Represents a function \(f(t, y(t), args)\).

Arguments:

  • t: the integration time.
  • y: the evolving state; a PyTree of structure \(T\).
  • args: any static arguments as passed to diffrax.diffeqsolve.

Returns:

A PyTree of structure \(S\).

contr(self, t0: Scalar, t1: Scalar) -> PyTree abstractmethod

The control.

Represents the \(\mathrm{d}t\) in an ODE, or the \(\mathrm{d}w(t)\) in an SDE, etc.

Most numerical ODE solvers work by making a step of length \(\Delta t = t_1 - t_0\). Likewise, most numerical SDE solvers work by sampling a Brownian increment \(\Delta w \sim \mathcal{N}(0, t_1 - t_0)\).

Correspondingly, a control is not defined at a point. Instead, it is defined over an interval \([t_0, t_1]\).

Arguments:

  • t0: the start of the interval.
  • t1: the end of the interval.

Returns:

A PyTree of structure \(U\). For a control \(x\), the result should represent \(x(t_1) - x(t_0)\).
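Two controls in this interval-based style are sketched below. The classes TimeControl and NaiveBrownianControl are hypothetical stand-ins that follow the shape of the contr method documented here. They are not Diffrax's implementation, and the Brownian one draws an independent sample on every call.

```python
import numpy as np

class TimeControl:
    """Hypothetical ODE control: the increment over [t0, t1] is t1 - t0."""
    def contr(self, t0, t1):
        return t1 - t0

class NaiveBrownianControl:
    """Hypothetical Brownian control that samples a fresh increment per call.

    Calls are independent of each other, so overlapping intervals are not
    guaranteed to be consistent with a single Brownian sample path.
    """
    def __init__(self, shape, seed):
        self.shape = shape
        self.rng = np.random.default_rng(seed)

    def contr(self, t0, t1):
        # Delta w ~ N(0, t1 - t0); NumPy takes the standard deviation.
        return self.rng.normal(0.0, np.sqrt(t1 - t0), size=self.shape)

dt_control = TimeControl().contr(0.5, 0.75)
dw = NaiveBrownianControl(shape=(2,), seed=0).contr(0.5, 0.75)
```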

prod(self, vf: PyTree, control: PyTree) -> PyTree abstractmethod

Determines the interaction between vector field and control.

With a solution \(y\) to a differential equation with vector field \(f\) and control \(x\), this computes \(f(t, y(t), args) \Delta x(t)\) given \(f(t, y(t), args)\) and \(\Delta x(t)\).

Arguments:

  • vf: The vector field evaluation; a PyTree of structure \(S\).
  • control: The control evaluated over an interval; a PyTree of structure \(U\).

Returns:

The interaction between the vector field and control; a PyTree of structure \(T\).

Note

This function must be bilinear.
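One way to sanity-check a hand-written prod is to test both linearity conditions numerically. The sketch below uses NumPy and assumes a matrix-vector product as the interaction, as for a diffusion term. All values are random test data.

```python
import numpy as np

def prod(vf, control):
    # Assumed interaction: matrix-vector product.
    return vf @ control

rng = np.random.default_rng(0)
A, B = rng.normal(size=(2, 3, 2))   # two vector field values, each (3, 2)
u, v = rng.normal(size=(2, 2))      # two control values, each (2,)
a, b = 2.0, -0.5

# Linear in the vector field argument...
lhs_vf = prod(a * A + b * B, u)
rhs_vf = a * prod(A, u) + b * prod(B, u)
# ...and linear in the control argument.
lhs_c = prod(A, a * u + b * v)
rhs_c = a * prod(A, u) + b * prod(A, v)
```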

vf_prod(self, t: Scalar, y: PyTree, args: PyTree, control: PyTree) -> PyTree

The composition of diffrax.AbstractTerm.vf and diffrax.AbstractTerm.prod.

With a solution \(y\) to a differential equation with vector field \(f\) and control \(x\), this computes \(f(t, y(t), args) \Delta x(t)\) given \(t\), \(y(t)\), \(args\), and \(\Delta x(t)\).

Its default implementation is simply

self.prod(self.vf(t, y, args), control)

This is offered as a special case that can be overridden when it is more efficient to do so.

Example

Consider when vf computes a matrix-matrix product, and prod computes a matrix-vector product. Then doing a naive composition corresponds to a (matrix-matrix)-vector product, which is less efficient than the corresponding matrix-(matrix-vector) product. Overriding this method offers a way to reclaim that efficiency.
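The gain from reassociating is shown in the sketch below, using NumPy with hypothetical matrices A and B. The default composition builds the full matrix A @ B at O(n^3) cost before applying it to the control. The overridden version uses two O(n^2) matrix-vector products and gives the same result up to floating-point rounding.

```python
import numpy as np

rng = np.random.default_rng(0)
A = rng.normal(size=(200, 200))
B = rng.normal(size=(200, 200))

def vf(t, y, args):
    # Hypothetical vector field whose value is a matrix-matrix product.
    return A @ B

def prod(vf_out, control):
    return vf_out @ control

def naive_vf_prod(t, y, args, control):
    # Default composition: (A @ B) @ control.
    return prod(vf(t, y, args), control)

def fast_vf_prod(t, y, args, control):
    # Reassociated: A @ (B @ control).
    return A @ (B @ control)

control = rng.normal(size=200)
```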

Example

This is used extensively for efficiency when backpropagating via diffrax.BacksolveAdjoint.

Arguments:

  • t: the integration time.
  • y: the evolving state; a PyTree of structure \(T\).
  • args: any static arguments as passed to diffrax.diffeqsolve.
  • control: The control evaluated over an interval; a PyTree of structure \(U\).

Returns:

A PyTree of structure \(T\).

Note

This function must be linear in control.

is_vf_expensive(self, t0: Scalar, t1: Scalar, y: Tuple[PyTree, PyTree, PyTree, PyTree], args: PyTree) -> bool

Specifies whether evaluating the vector field is "expensive", in the specific sense that it is cheaper to evaluate vf_prod twice than vf once.

Some solvers use this to change their behaviour, so as to act more efficiently.


Note

You can create your own terms if appropriate: e.g. if a diffusion matrix has some particular structure and you want to use a specialised, more efficient matrix-vector product algorithm in prod. For example, this is what diffrax.WeaklyDiagonalControlTerm does, as compared to diffrax.ControlTerm.

diffrax.ODETerm (AbstractTerm)

A term representing \(f(t, y(t), args) \mathrm{d}t\). That is to say, the term appearing on the right-hand side of an ODE, in which the control is time.

vector_field should return some PyTree, with the same structure as the initial state y0, and with every leaf broadcastable to the equivalent leaf in y0.
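Broadcastable leaves let a vector field return a scalar where the state has an array. The sketch below uses dicts as PyTrees and a hand-written explicit Euler step. The helper euler_step is hypothetical and is not a Diffrax function.

```python
import numpy as np

y0 = {"position": np.ones(3), "mass": np.array(2.0)}

def vector_field(t, y, args):
    # Same dict structure as y0; the "position" leaf is a Python float,
    # which broadcasts against the shape-(3,) leaf of y0.
    return {"position": -1.0, "mass": np.array(0.0)}

def euler_step(vector_field, t, y, args, dt):
    # Hypothetical helper: y + f(t, y, args) * dt, applied leaf by leaf.
    f = vector_field(t, y, args)
    return {k: y[k] + f[k] * dt for k in y}

y1 = euler_step(vector_field, 0.0, y0, None, 0.1)
```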

Example

from diffrax import ODETerm, diffeqsolve

vector_field = lambda t, y, args: -y
ode_term = ODETerm(vector_field)
diffeqsolve(ode_term, ...)

__init__(self, vector_field: Callable[[Scalar, PyTree, PyTree], PyTree])

Arguments:

  • vector_field: A callable representing the vector field. This callable takes three arguments (t, y, args). t is a scalar representing the integration time. y is the evolving state of the system. args are any static arguments as passed to diffrax.diffeqsolve.

diffrax.ControlTerm (AbstractTerm)

A term representing the general case of \(f(t, y(t), args) \mathrm{d}x(t)\), in which the interaction between vector field and control is a matrix-vector product.

vector_field and control should both return PyTrees, both with the same structure as the initial state y0. Every dimension of control is then contracted against the last dimensions of vector_field; that is to say if each leaf of y0 has shape (y1, ..., yN), and the corresponding leaf of control has shape (c1, ..., cM), then the corresponding leaf of vector_field should have shape (y1, ..., yN, c1, ..., cM).
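The contraction rule for a single leaf can be sketched with NumPy's tensordot, which sums over the last control.ndim axes of the vector field and all axes of the control. The shapes below are hypothetical: a state leaf of shape (3,) and a control leaf of shape (2, 4), so the vector field leaf has shape (3, 2, 4).

```python
import numpy as np

y = np.ones(3)                              # leaf of y0: shape (y1,) = (3,)
control = np.arange(8.0).reshape(2, 4)      # control leaf: shape (c1, c2) = (2, 4)
vf_value = np.ones((3, 2, 4))               # vector field leaf: (y1, c1, c2)

# Contract every dimension of the control against the trailing dimensions
# of the vector field; the result has the same shape as y.
increment = np.tensordot(vf_value, control, axes=control.ndim)
```

When the control is a vector (M = 1), this reduces to the ordinary matrix-vector product vf_value @ control.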

A common special case is when y0 and control are vector-valued, and vector_field is matrix-valued.

Example

import jax.numpy as jnp
from diffrax import ControlTerm, UnsafeBrownianPath, diffeqsolve

control = UnsafeBrownianPath(shape=(2,), key=...)
vector_field = lambda t, y, args: jnp.stack([y, y], axis=-1)
diffusion_term = ControlTerm(vector_field, control)
diffeqsolve(diffusion_term, ...)

Example

import jax.numpy as jnp
from diffrax import ControlTerm, LinearInterpolation, diffeqsolve

ts = jnp.array([1., 2., 2.5, 3.])
data = jnp.array([[0.1, 2.0],
                  [0.3, 1.5],
                  [1.0, 1.6],
                  [0.2, 1.1]])
control = LinearInterpolation(ts, data)
vector_field = lambda t, y, args: jnp.stack([y, y], axis=-1)
cde_term = ControlTerm(vector_field, control)
diffeqsolve(cde_term, ...)

__init__(self, vector_field: Callable[[Scalar, PyTree, PyTree], PyTree], control: AbstractPath)

Arguments:

  • vector_field: A callable representing the vector field. This callable takes three arguments (t, y, args). t is a scalar representing the integration time. y is the evolving state of the system. args are any static arguments as passed to diffrax.diffeqsolve.
  • control: The control, as an AbstractPath. Should have an evaluate(t0, t1) method. If using diffrax.ControlTerm.to_ode, then it should also have a derivative(t) method.

to_ode(self) -> ODETerm

If the control is differentiable then \(f(t, y(t), args) \mathrm{d}x(t)\) may be thought of as an ODE as

\(f(t, y(t), args) \frac{\mathrm{d}x}{\mathrm{d}t}\mathrm{d}t\).

This method converts this ControlTerm into the corresponding diffrax.ODETerm in this way.
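The equivalence can be checked with a small NumPy sketch. A hypothetical piecewise-linear control is defined through the points (ts[i], xs[i]), and the vector field is a constant f, chosen for simplicity. The sketch compares the increment f Δx with f (dx/dt) Δt over one step.

```python
import numpy as np

# Hypothetical piecewise-linear control through (ts[i], xs[i]).
ts = np.array([0.0, 1.0, 2.0])
xs = np.array([0.0, 2.0, 1.0])

def control_increment(t0, t1):
    return np.interp(t1, ts, xs) - np.interp(t0, ts, xs)

def control_derivative(t):
    # Slope of the segment containing t (valid strictly inside a segment).
    i = np.searchsorted(ts, t) - 1
    return (xs[i + 1] - xs[i]) / (ts[i + 1] - ts[i])

f = 3.0              # constant vector field
t0, t1 = 0.25, 0.75

cde_increment = f * control_increment(t0, t1)            # f dx
ode_increment = f * control_derivative(0.5) * (t1 - t0)  # f (dx/dt) dt
```

The two increments agree exactly here because f and dx/dt are both constant over [t0, t1]. In general the per-step numerical increments differ, but the converted ODE has the same solution as the original controlled equation.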

diffrax.WeaklyDiagonalControlTerm (AbstractTerm)

A term representing the case of \(f(t, y(t), args) \mathrm{d}x(t)\), in which the interaction between vector field and control is a matrix-vector product, and the matrix is square and diagonal. In this case we may represent the matrix as a vector of just its diagonal elements. The matrix-vector product may then be calculated by pointwise multiplying this vector with the control. This is more computationally efficient than writing out the full matrix and then doing a full matrix-vector product.

Correspondingly, vector_field and control should both return PyTrees, and both should have the same structure and leaf shape as the initial state y0. These are multiplied together pointwise.
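The following NumPy sketch checks that the pointwise product gives the same result as building the full diagonal matrix. The vector field diag_vf and the values are hypothetical.

```python
import numpy as np

y = np.array([1.0, 2.0, 3.0])
dw = np.array([0.1, -0.2, 0.3])

def diag_vf(t, y, args):
    # Hypothetical: only the diagonal of the diffusion matrix, shape (3,).
    return 0.5 * y

pointwise = diag_vf(0.0, y, None) * dw         # O(n) work
full = np.diag(diag_vf(0.0, y, None)) @ dw     # O(n^2) work, same answer
```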

Info

Why "weakly" diagonal? Consider the matrix representation of the vector field, as a square diagonal matrix. In general, the (i,i)-th element may depend upon any of the values of y. Only if each (i,i)-th element depends solely upon the i-th element of y is the vector field said to be "diagonal", without the "weak". (This stronger property is useful in some SDE solvers.)

__init__(self, vector_field: Callable[[Scalar, PyTree, PyTree], PyTree], control: AbstractPath)

Arguments:

  • vector_field: A callable representing the vector field. This callable takes three arguments (t, y, args). t is a scalar representing the integration time. y is the evolving state of the system. args are any static arguments as passed to diffrax.diffeqsolve.
  • control: The control, as an AbstractPath. Should have an evaluate(t0, t1) method. If using diffrax.WeaklyDiagonalControlTerm.to_ode, then it should also have a derivative(t) method.

to_ode(self) -> ODETerm

If the control is differentiable then \(f(t, y(t), args) \mathrm{d}x(t)\) may be thought of as an ODE as

\(f(t, y(t), args) \frac{\mathrm{d}x}{\mathrm{d}t}\mathrm{d}t\).

This method converts this WeaklyDiagonalControlTerm into the corresponding diffrax.ODETerm in this way.

diffrax.MultiTerm (AbstractTerm)

Accumulates multiple terms into a single term.

Consider the SDE

\(\mathrm{d}y(t) = f(t, y(t))\mathrm{d}t + g(t, y(t))\mathrm{d}w(t)\)

This has two terms on the right hand side. It may be represented with a single term as

\(\mathrm{d}y(t) = [f(t, y(t)), g(t, y(t))] \cdot [\mathrm{d}t, \mathrm{d}w(t)]\)

in which the vector field and control interact via a dot product.

MultiTerm performs this transform. For simplicity, most differential equation solvers (at least those built into Diffrax) accept just a single term, so this transform is a necessary part of, e.g., solving an SDE with both drift and diffusion.
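The following plain-Python and NumPy sketch illustrates the idea. The classes Term and Multi are hypothetical stand-ins, not Diffrax's MultiTerm. The combined control is a tuple of per-term controls, and the combined interaction is the sum of the per-term interactions, as in the bracketed dot product above. The sketch uses it to take one Euler-Maruyama step of a two-dimensional SDE.

```python
import numpy as np

class Term:
    """Hypothetical minimal term: a vector field, a control, and their product."""
    def __init__(self, vf, contr, prod):
        self.vf, self.contr, self.prod = vf, contr, prod

    def vf_prod(self, t, y, args, control):
        return self.prod(self.vf(t, y, args), control)

class Multi:
    """Hypothetical sketch: tuple-valued control, summed interactions."""
    def __init__(self, *terms):
        self.terms = terms

    def contr(self, t0, t1):
        return tuple(term.contr(t0, t1) for term in self.terms)

    def vf_prod(self, t, y, args, controls):
        return sum(term.vf_prod(t, y, args, c)
                   for term, c in zip(self.terms, controls))

rng = np.random.default_rng(0)
drift = Term(lambda t, y, a: -y, lambda t0, t1: t1 - t0, lambda v, c: v * c)
diffusion = Term(lambda t, y, a: 0.1 * np.eye(2),
                 lambda t0, t1: rng.normal(0.0, np.sqrt(t1 - t0), size=2),
                 lambda v, c: v @ c)

sde = Multi(drift, diffusion)
y0 = np.array([1.0, -1.0])
controls = sde.contr(0.0, 0.01)
y1 = y0 + sde.vf_prod(0.0, y0, None, controls)   # one Euler-Maruyama step
```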

__init__(self, *terms: AbstractTerm)

Arguments: