One of the advanced features of Diffrax is its term system. When we write down e.g. a stochastic differential equation
\(\mathrm{d}y(t) = f(t, y(t))\mathrm{d}t + g(t, y(t))\mathrm{d}w(t)\)
then we have two "terms": a drift and a diffusion. Each of these terms has two parts: a vector field (\(f\) or \(g\)) and a control (\(\mathrm{d}t\) or \(\mathrm{d}w(t)\)). In addition (often not represented in mathematical notation), there is also a choice of how the vector field and control interact: \(f\) and \(\mathrm{d}t\) interact as a vector-scalar product. \(g\) and \(\mathrm{d}w(t)\) interact as a matrix-vector product. (In general this interaction is always bilinear.)
"Terms" are thus the building blocks of differential equations.
Consider the ODE \(\frac{\mathrm{d}{y}}{\mathrm{d}t} = f(t, y(t))\). Then this has vector field \(f\), control \(\mathrm{d}t\), and their interaction is a vector-scalar product. This can be described as a single diffrax.ODETerm.
Adding multiple terms, such as SDEs
We can add multiple terms together by grouping them into a single diffrax.MultiTerm.
The SDE above would have its drift described by a diffrax.ODETerm and its diffusion described by a diffrax.ControlTerm. As these affect the same evolving state variable, they should be passed to the solver as MultiTerm(ODETerm(...), ControlTerm(...)).
Independent terms, such as Hamiltonian systems
If terms affect different pieces of the state, then they should be placed in some PyTree structure.
Consider the pair of equations (as commonly arising from Hamiltonian systems):
\(\frac{\mathrm{d}x}{\mathrm{d}t}(t) = f(t, y(t)),\qquad\frac{\mathrm{d}y}{\mathrm{d}t}(t) = g(t, x(t))\)
These would be passed to the solver as the 2-tuple of (ODETerm(...), ODETerm(...)).
What each solver accepts
Each solver in Diffrax will specify what kinds of problems it can handle, as described by its .term_structure attribute. Not all solvers are able to handle all problems!
Some example term structures include:
- solver.term_structure = AbstractTerm
  In this case the solver can handle a simple ODE as described above: ODETerm is a subclass of AbstractTerm. It can also handle SDEs: MultiTerm(ODETerm(...), ControlTerm(...)) includes everything wrapped into a single term (the MultiTerm), and at that point this defines an interface the solver knows how to handle. Most solvers in Diffrax have this term structure.
- solver.term_structure = MultiTerm[tuple[ODETerm, ControlTerm]]
  In this case the solver specifically handles just SDEs of the form MultiTerm(ODETerm(...), ControlTerm(...)); nothing else is compatible. Some SDE-specific solvers have this term structure.
- solver.term_structure = (AbstractTerm, AbstractTerm)
  In this case the solver is used to solve ODEs like the Hamiltonian system described above: we have a PyTree of terms, each of which is treated individually.
diffrax.AbstractTerm
Abstract base class for all terms.
Let \(y\) solve some differential equation with vector field \(f\) and control \(x\).
Let \(y\) have PyTree structure \(T\), let the output of the vector field have PyTree structure \(S\), and let \(x\) have PyTree structure \(U\). Then \(f : T \to S\) whilst the interaction \((f, x) \mapsto f \mathrm{d}x\) is a function \((S, U) \to T\).
vf(self, t: Real[ArrayLike, ''], y: PyTree[Shaped[ArrayLike, '?*y'], 'Y'], args: PyTree[Any]) -> ~_VF
The vector field.
Represents a function \(f(t, y(t), args)\).
Arguments:
- t: the integration time.
- y: the evolving state; a PyTree of structure \(T\).
- args: any static arguments as passed to diffrax.diffeqsolve.
Returns:
A PyTree of structure \(S\).
contr(self, t0: Real[ArrayLike, ''], t1: Real[ArrayLike, ''], **kwargs) -> ~_Control
The control.
Represents the \(\mathrm{d}t\) in an ODE, or the \(\mathrm{d}w(t)\) in an SDE, etc.
Most numerical ODE solvers work by making a step of length \(\Delta t = t_1 - t_0\). Likewise most numerical SDE solvers work by sampling some Brownian motion \(\Delta w \sim \mathcal{N}(0, t_1 - t_0)\).
Correspondingly a control is not defined at a point. Instead it is defined over an interval \([t_0, t_1]\).
Arguments:
- t0: the start of the interval.
- t1: the end of the interval.
Returns:
A PyTree of structure \(U\). For a control \(x\), the result should represent \(x(t_1) - x(t_0)\).
prod(self, vf: ~_VF, control: ~_Control) -> PyTree[Shaped[ArrayLike, '?*y'], 'Y']
Determines the interaction between vector field and control.
With a solution \(y\) to a differential equation with vector field \(f\) and control \(x\), this computes \(f(t, y(t), args) \Delta x(t)\) given \(f(t, y(t), args)\) and \(\Delta x(t)\).
This function must be bilinear.
Arguments:
- vf: The vector field evaluation; a PyTree of structure \(S\).
- control: The control evaluated over an interval; a PyTree of structure \(U\).
Returns:
The interaction between the vector field and control; a PyTree of structure \(T\).
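To make the roles of vf, contr, and prod concrete, here is a plain-Python sketch of how a solver might combine them into one explicit Euler step. ToyODETerm and euler_step are hypothetical names used only for illustration; this is not Diffrax code, and real terms operate on PyTrees of arrays rather than floats.

```python
class ToyODETerm:
    """Plain-Python stand-in for a term whose control is time (hypothetical)."""

    def __init__(self, vector_field):
        self.vector_field = vector_field

    def vf(self, t, y, args):
        # The vector field f(t, y, args).
        return self.vector_field(t, y, args)

    def contr(self, t0, t1):
        # The control over an interval: x(t1) - x(t0) with x(t) = t.
        return t1 - t0

    def prod(self, vf, control):
        # The bilinear interaction; here an ordinary scalar product.
        return vf * control


def euler_step(term, t0, t1, y0, args=None):
    # One explicit Euler step: y1 = y0 + f(t0, y0, args) * (x(t1) - x(t0)).
    return y0 + term.prod(term.vf(t0, y0, args), term.contr(t0, t1))


# Integrate dy/dt = -y from t = 0 to t = 1 with 100 steps.
term = ToyODETerm(lambda t, y, args: -y)
y, t, dt = 1.0, 0.0, 0.01
for _ in range(100):
    y = euler_step(term, t, t + dt, y)
    t += dt
```

The stepping function never needs to know whether the control is time or something else; it only relies on the three methods.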
vf_prod(self, t: Real[ArrayLike, ''], y: PyTree[Shaped[ArrayLike, '?*y'], 'Y'], args: PyTree[Any], control: ~_Control) -> PyTree[Shaped[ArrayLike, '?*y'], 'Y']
The composition of diffrax.AbstractTerm.vf and diffrax.AbstractTerm.prod.
With a solution \(y\) to a differential equation with vector field \(f\) and control \(x\), this computes \(f(t, y(t), args) \Delta x(t)\) given \(t\), \(y(t)\), \(args\), and \(\Delta x(t)\).
Its default implementation is simply
self.prod(self.vf(t, y, args), control)
This is offered as a special case that can be overridden when it is more efficient to do so.
Consider when vf computes a matrix-matrix product, and prod computes a matrix-vector product. Then doing a naive composition corresponds to a (matrix-matrix)-vector product, which is less efficient than the corresponding matrix-(matrix-vector) product. Overriding this method offers a way to reclaim that efficiency.
This is used extensively for efficiency when backpropagating via diffrax.BacksolveAdjoint.
Arguments:
- t: the integration time.
- y: the evolving state; a PyTree of structure \(T\).
- args: any static arguments as passed to diffrax.diffeqsolve.
- control: The control evaluated over an interval; a PyTree of structure \(U\).
Returns:
A PyTree of structure \(T\).
This function must be linear in control.
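The efficiency argument in the matrix example above can be checked by counting scalar multiplications. The following plain-Python sketch (the helper names, matrix size, and entries are all illustrative assumptions) compares the naive (matrix-matrix)-vector ordering with the matrix-(matrix-vector) ordering that an overridden vf_prod could use.

```python
def matvec(a, v, counter):
    # Dense matrix-vector product, counting scalar multiplications.
    counter[0] += len(a) * len(v)
    return [sum(a_ij * v_j for a_ij, v_j in zip(row, v)) for row in a]


def matmat(a, b, counter):
    # Dense matrix-matrix product, counting scalar multiplications.
    counter[0] += len(a) * len(b) * len(b[0])
    cols = list(zip(*b))
    return [[sum(x * y for x, y in zip(row, col)) for col in cols] for row in a]


n = 50
# Small integer-valued entries, so both orderings give identical floats.
a = [[(i + j) % 3 - 1.0 for j in range(n)] for i in range(n)]
b = [[(i * j) % 5 - 2.0 for j in range(n)] for i in range(n)]
v = [float(i % 4) for i in range(n)]

naive_count, fused_count = [0], [0]
# Naive composition: "vf" forms the matrix A B, then "prod" applies it to v.
naive = matvec(matmat(a, b, naive_count), v, naive_count)
# Fused composition: apply B to v first, then A to the result.
fused = matvec(a, matvec(b, v, fused_count), fused_count)

print(naive_count[0], fused_count[0])
```

Both orderings produce the same vector, but the naive one costs on the order of n^3 multiplications while the fused one costs on the order of n^2.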
is_vf_expensive(self, t0: Real[ArrayLike, ''], t1: Real[ArrayLike, ''], y: PyTree[Shaped[ArrayLike, '?*y'], 'Y'], args: PyTree[Any]) -> bool
Specifies whether evaluating the vector field is "expensive", in the specific sense that it is cheaper to evaluate vf_prod twice than vf once.
Some solvers use this to change their behaviour, so as to act more efficiently.
Defining your own term types
For advanced users, you can create your own terms if appropriate. See for example the underdamped Langevin terms, which have their own special set of solvers.
diffrax.ODETerm (AbstractTerm)
A term representing \(f(t, y(t), args) \mathrm{d}t\). That is to say, the term appearing on the right hand side of an ODE, in which the control is time.
vector_field should return some PyTree, with the same structure as the initial state y0, and with every leaf shape-broadcastable and dtype-upcastable to the equivalent leaf in y0.
from diffrax import ODETerm, diffeqsolve
# Vector field of the ODE dy/dt = -y.
vector_field = lambda t, y, args: -y
ode_term = ODETerm(vector_field)
diffeqsolve(ode_term, ...)
__init__(self, vector_field: Callable[[Real[ArrayLike, ''], PyTree[Shaped[ArrayLike, '?*y'], 'Y'], PyTree[Any]], ~_VF])
Arguments:
- vector_field: A callable representing the vector field. This callable takes three arguments (t, y, args). t is a scalar representing the integration time. y is the evolving state of the system. args are any static arguments as passed to diffrax.diffeqsolve.
diffrax.ControlTerm (AbstractTerm)
A term representing the general case of \(f(t, y(t), args) \mathrm{d}x(t)\), in which the vector field (\(f\)) - control (\(\mathrm{d}x\)) interaction is a matrix-vector product.
This is typically used for either stochastic differential equations or for controlled differential equations.
ControlTerm can be used in two different ways.
- Simple way: directly return JAX arrays.
  vector_field and control should both return PyTrees, both with the same structure as the initial state y0. All leaves should be JAX arrays.
  If each leaf of y0 has shape (y1, ..., yN), and the corresponding leaf of control has shape (c1, ..., cM), then the corresponding leaf of vector_field should have shape (y1, ..., yN, c1, ..., cM). Leaf-by-leaf, the corresponding dimensions of vector_field and control are contracted against each other.
  This includes normal matrix-vector products as a special case: when y0 is an array with shape (m,), the control is an array with shape (n,), and the vector field is an array with shape (m, n).
- Advanced way: have the vector field return a Lineax linear operator.
  This is suitable for use cases in which you know that the vector field has special structure -- e.g. it is diagonal -- and you would like to use that structure for a more efficient implementation.
  In this case, vector_field should return a Lineax linear operator, the control can return anything compatible with the .mv method of that operator, and the interaction is defined as vector_field(t0, y, args).mv(control(t0, t1)).
  In this case no special PyTree handling is done -- perform this inside the operator's .mv if required. (As you can see, this approach is basically about deferring the whole linear operation to Lineax.)
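The shape rule for the simple way can be illustrated for a single leaf. This sketch uses NumPy as a stand-in for JAX arrays and assumes the contraction behaves like a tensordot over the trailing control-shaped axes; the helper name contract is hypothetical.

```python
import numpy as np


def contract(vf_leaf, control_leaf):
    # Contract the trailing control-shaped axes of the vector field
    # against all axes of the control.
    return np.tensordot(vf_leaf, control_leaf, axes=np.ndim(control_leaf))


# Matrix-vector special case: state (m,), control (n,), vector field (m, n).
m, n = 3, 2
vf = np.arange(m * n, dtype=float).reshape(m, n)
control = np.array([1.0, -1.0])
out = contract(vf, control)

# General case: state (2, 3), control (4, 5), vector field (2, 3, 4, 5).
vf_big = np.ones((2, 3, 4, 5))
control_big = np.ones((4, 5))
out_big = contract(vf_big, control_big)

print(out.shape, out_big.shape)
```

In both cases the result has the shape of the state leaf, which is what allows it to be added to the evolving state.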
In this example we consider an SDE with m-dimensional state \(y \in \mathbb{R}^m\), an n-dimensional Brownian motion \(W(t) \in \mathbb{R}^n\), and a constant diffusion of shape (m, n):
\(\mathrm{d}y(t) = \begin{bmatrix} 1 & ... & 1 \\ & ... & \\ 1 & ... & 1 \end{bmatrix} \mathrm{d}W(t)\)
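import jax.numpy as jnp
from diffrax import ControlTerm, diffeqsolve, UnsafeBrownianPath
# m and n are the state and Brownian motion dimensions, chosen by the user.
y0 = jnp.ones((m,))
control = UnsafeBrownianPath(shape=(n,), key=...)
def vector_field(t, y, args):
    # Constant diffusion matrix of shape (m, n).
    return jnp.ones((m, n))
diffusion_term = ControlTerm(vector_field, control)
diffeqsolve(terms=diffusion_term, y0=y0, ...)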
In this example we consider an SDE with a one-dimensional state \(y(t) \in \mathbb{R}\) and a two-dimensional Brownian motion \(W(t) \in \mathbb{R}^2\), given by:
\(\mathrm{d}y(t) = \begin{bmatrix} y(t) \\ y(t) + 1 \end{bmatrix} \mathrm{d}W(t)\)
We use the simple matrix-vector product way of combining things.
import jax.numpy as jnp
from diffrax import ControlTerm, diffeqsolve, UnsafeBrownianPath
control = UnsafeBrownianPath(shape=(2,), key=...)
def vector_field(t, y, args):
    # y is a scalar, so this has shape (2,), matching the control.
    return jnp.stack([y, y + 1], axis=-1)
diffusion_term = ControlTerm(vector_field, control)
diffeqsolve(diffusion_term, ...)
In this example we consider an SDE with two-dimensional state \((y_1(t), y_2(t)) \in \mathbb{R}^2\) and a two-dimensional Brownian motion \(W(t) \in \mathbb{R}^2\) -- and for which the diffusion matrix is diagonal.
\(\mathrm{d}\begin{bmatrix} y_1 \\ y_2 \end{bmatrix}(t) = \begin{bmatrix} y_2(t) & 0 \\ 0 & y_1(t) \end{bmatrix} \mathrm{d}W(t)\)
As such we use the more-advanced approach of using Lineax's linear operators to represent the diffusion matrix.
import jax.numpy as jnp
import lineax
from diffrax import ControlTerm, diffeqsolve, UnsafeBrownianPath
control = UnsafeBrownianPath(shape=(2,), key=...)
def vector_field(t, y, args):
    # y is a JAX array of shape (2,)
    y1, y2 = y
    diagonal = jnp.array([y2, y1])
    return lineax.DiagonalLinearOperator(diagonal)
diffusion_term = ControlTerm(vector_field, control)
diffeqsolve(diffusion_term, ...)
In this example we consider a controlled differential equation, for which the control is given by an interpolation of some data. (See also the neural controlled differential equation example.)
import jax.numpy as jnp
from diffrax import ControlTerm, diffeqsolve, LinearInterpolation
# Observation times and the two-channel data observed at those times.
ts = jnp.array([1., 2., 2.5, 3.])
data = jnp.array([[0.1, 2.0],
                  [0.3, 1.5],
                  [1.0, 1.6],
                  [0.2, 1.1]])
control = LinearInterpolation(ts, data)
vector_field = lambda t, y, args: jnp.stack([y, y], axis=-1)
cde_term = ControlTerm(vector_field, control)
diffeqsolve(cde_term, ...)
__init__(self, vector_field: Callable[[Real[ArrayLike, ''], PyTree[Shaped[ArrayLike, '?*y'], 'Y'], PyTree[Any]], ~_VF], control: Union[AbstractPath[~_Control], Callable[[Real[ArrayLike, ''], Real[ArrayLike, '']], ~_Control]])
Arguments:
- vector_field: A callable representing the vector field. This callable takes three arguments (t, y, args). t is a scalar representing the integration time. y is the evolving state of the system. args are any static arguments as passed to diffrax.diffeqsolve. This vector_field can either be
    - a function that returns a PyTree of JAX arrays, or
    - a function that returns a Lineax linear operator, as described above.
- control: The control. Should either be
    - a diffrax.AbstractPath, in which case its .evaluate(t0, t1) method will be used to give the increment of the control over a time interval [t0, t1], or
    - a callable (t0, t1) -> increment, which returns the increment directly.
to_ode(self) -> ODETerm
If the control is differentiable then \(f(t, y(t), args) \mathrm{d}x(t)\) may be thought of as an ODE as
\(f(t, y(t), args) \frac{\mathrm{d}x}{\mathrm{d}t}\mathrm{d}t\).
This method converts this ControlTerm into the corresponding ODETerm in this way.
diffrax.MultiTerm (AbstractTerm)
Accumulates multiple terms into a single term.
Consider the SDE
\(\mathrm{d}y(t) = f(t, y(t))\mathrm{d}t + g(t, y(t))\mathrm{d}w(t)\)
This has two terms on the right hand side. It may be represented with a single term as
\(\mathrm{d}y(t) = [f(t, y(t)), g(t, y(t))] \cdot [\mathrm{d}t, \mathrm{d}w(t)]\)
whose vector field -- control interaction is a dot product.
MultiTerm performs this transform. For simplicity most differential equation solvers (at least those built-in to Diffrax) accept just a single term, so this transform is a necessary part of e.g. solving an SDE with both drift and diffusion.
__init__(self, *terms: AbstractTerm)
Arguments:
- *terms: Any number of diffrax.AbstractTerms to combine.
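The dot-product view of MultiTerm can be sketched in plain Python. ToyTerm and ToyMultiTerm are hypothetical stand-ins, not Diffrax classes, and a fixed smooth path is assumed in place of a sampled Brownian motion so that the numbers are reproducible.

```python
class ToyTerm:
    """Hypothetical scalar term: a vector field paired with a control path x(t)."""

    def __init__(self, vector_field, path):
        self.vector_field = vector_field
        self.path = path

    def vf(self, t, y, args):
        return self.vector_field(t, y, args)

    def contr(self, t0, t1):
        return self.path(t1) - self.path(t0)

    def prod(self, vf, control):
        return vf * control


class ToyMultiTerm:
    """Hypothetical combination of several terms into a single term."""

    def __init__(self, *terms):
        self.terms = terms

    def vf(self, t, y, args):
        # The stacked vector field [f(t, y), g(t, y)].
        return tuple(term.vf(t, y, args) for term in self.terms)

    def contr(self, t0, t1):
        # The stacked control increments [dt, dw].
        return tuple(term.contr(t0, t1) for term in self.terms)

    def prod(self, vf, control):
        # Dot product: the sum of each term's own interaction.
        return sum(term.prod(v, c) for term, v, c in zip(self.terms, vf, control))


drift = ToyTerm(lambda t, y, args: -y, path=lambda t: t)
# A fixed path x(t) = t^2 stands in for a Brownian motion sample.
diffusion = ToyTerm(lambda t, y, args: 0.5, path=lambda t: t * t)
multi = ToyMultiTerm(drift, diffusion)

t0, t1, y0 = 0.0, 0.5, 2.0
increment = multi.prod(multi.vf(t0, y0, None), multi.contr(t0, t1))
```

A solver written against the single-term interface can then treat multi exactly as it would treat either term alone.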
Underdamped Langevin terms
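These are special terms which describe the Underdamped Langevin diffusion (ULD), which takes the form
\(\mathrm{d}x(t) = v(t)\,\mathrm{d}t,\qquad \mathrm{d}v(t) = -\gamma\,v(t)\,\mathrm{d}t - u\,\nabla f(x(t))\,\mathrm{d}t + \sqrt{2\gamma u}\,\mathrm{d}w(t),\)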
where \(x(t), v(t) \in \mathbb{R}^d\) represent the position and velocity, \(w\) is a Brownian motion in \(\mathbb{R}^d\), \(f: \mathbb{R}^d \rightarrow \mathbb{R}\) is a potential function, and \(\gamma , u \in \mathbb{R}^{d \times d}\) are diagonal matrices governing the friction and the damping of the system.
These terms enable the use of ULD-specific solvers, which are documented separately. These ULD solvers expect terms with structure MultiTerm(UnderdampedLangevinDriftTerm(gamma, u, grad_f), UnderdampedLangevinDiffusionTerm(gamma, u, bm)), where bm is a diffrax.AbstractBrownianPath and the same values of gamma and u are passed to both terms.
diffrax.UnderdampedLangevinDriftTerm (AbstractTerm)
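Represents the drift term in the Underdamped Langevin Diffusion (ULD). The ULD SDE takes the form:
\(\mathrm{d}x(t) = v(t)\,\mathrm{d}t,\qquad \mathrm{d}v(t) = -\gamma\,v(t)\,\mathrm{d}t - u\,\nabla f(x(t))\,\mathrm{d}t + \sqrt{2\gamma u}\,\mathrm{d}w(t),\)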
where \(x(t), v(t) \in \mathbb{R}^d\) represent the position and velocity, \(w\) is a Brownian motion in \(\mathbb{R}^d\), \(f: \mathbb{R}^d \rightarrow \mathbb{R}\) is a potential function, and \(\gamma , u \in \mathbb{R}^{d \times d}\) are diagonal matrices governing the friction and the damping of the system.
__init__(self, gamma: PyTree[ArrayLike], u: PyTree[ArrayLike], grad_f: Callable[[PyTree[Shaped[Array, '?*underdamped_langevin'], 'UnderdampedLangevinX'], PyTree[Any]], PyTree[Shaped[Array, '?*underdamped_langevin'], 'UnderdampedLangevinX']])
Arguments:
- gamma: A vector containing the diagonal entries of the friction matrix; a scalar or a PyTree of the same shape as the position vector \(x\).
- u: A vector containing the diagonal entries of the damping matrix; a scalar or a PyTree of the same shape as the position vector \(x\).
- grad_f: A callable representing the gradient of the potential function \(f\). This callable should take a PyTree of the same shape as \(x\) and an optional args argument, returning a PyTree of the same shape.
diffrax.UnderdampedLangevinDiffusionTerm (AbstractTerm)
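Represents the diffusion term in the Underdamped Langevin Diffusion (ULD). The ULD SDE takes the form:
\(\mathrm{d}x(t) = v(t)\,\mathrm{d}t,\qquad \mathrm{d}v(t) = -\gamma\,v(t)\,\mathrm{d}t - u\,\nabla f(x(t))\,\mathrm{d}t + \sqrt{2\gamma u}\,\mathrm{d}w(t),\)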
where \(x(t), v(t) \in \mathbb{R}^d\) represent the position and velocity, \(w\) is a Brownian motion in \(\mathbb{R}^d\), \(f: \mathbb{R}^d \rightarrow \mathbb{R}\) is a potential function, and \(\gamma , u \in \mathbb{R}^{d \times d}\) are diagonal matrices governing the friction and the damping of the system.
__init__(self, gamma: PyTree[ArrayLike], u: PyTree[ArrayLike], bm: AbstractBrownianPath)
Arguments:
- gamma: A vector containing the diagonal entries of the friction matrix; a scalar or a PyTree of the same shape as the position vector \(x\).
- u: A vector containing the diagonal entries of the damping matrix; a scalar or a PyTree of the same shape as the position vector \(x\).
- bm: A Brownian path representing the Brownian motion \(w\).