Terms
One of the advanced features of Diffrax is its term system. When we write down e.g. a stochastic differential equation
\(\mathrm{d}y(t) = f(t, y(t))\mathrm{d}t + g(t, y(t))\mathrm{d}w(t)\)
then we have two "terms": a drift and a diffusion. Each of these terms has two parts: a vector field (\(f\) or \(g\)) and a control (\(\mathrm{d}t\) or \(\mathrm{d}w(t)\)). In addition (often not represented in mathematical notation), there is also a choice of how the vector field and control interact: \(f\) and \(\mathrm{d}t\) interact as a vector-scalar product. \(g\) and \(\mathrm{d}w(t)\) interact as a matrix-vector product. (In general this interaction is always bilinear.)
"Terms" are thus the building blocks of differential equations.
Example
Consider the ODE \(\frac{\mathrm{d}{y}}{\mathrm{d}t} = f(t, y(t))\). Then this has vector field \(f\), control \(\mathrm{d}t\), and their interaction is a vector-scalar product. This can be described as a single diffrax.ODETerm.
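For the SDE at the start of this section, the following NumPy sketch evaluates each term's vector field and control over one interval \([t_0, t_1]\) and applies the corresponding interaction. The particular \(f\), \(g\), dimensions and random seed are illustrative assumptions, and this is plain array arithmetic rather than Diffrax code.

```python
import numpy as np

rng = np.random.default_rng(0)

m, n = 3, 2        # illustrative: state dimension m, Brownian dimension n
t0, t1 = 0.0, 0.1
y = np.ones(m)

def f(t, y):       # drift vector field: a vector of shape (m,)
    return -y

def g(t, y):       # diffusion vector field: a matrix of shape (m, n)
    return 0.1 * np.ones((m, n))

dt = t1 - t0                                # drift control: a scalar
dw = rng.normal(0.0, np.sqrt(dt), size=n)   # diffusion control: a Brownian increment of shape (n,)

drift_increment = f(t0, y) * dt             # vector-scalar product -> shape (m,)
diffusion_increment = g(t0, y) @ dw         # matrix-vector product -> shape (m,)
print(drift_increment.shape, diffusion_increment.shape)  # (3,) (3,)
```

Both interactions return something with the shape of the state, which is what lets the two terms be added together.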
Adding multiple terms, such as SDEs
We can add multiple terms together by grouping them into a single diffrax.MultiTerm.
Example
The SDE above would have its drift described by diffrax.ODETerm and the diffusion described by a diffrax.ControlTerm. As these affect the same evolving state variable, they should be passed to the solver as MultiTerm(ODETerm(...), ControlTerm(...)).
Independent terms, such as Hamiltonian systems
If terms affect different pieces of the state, then they should be placed in some PyTree structure.
Example
Consider the pair of equations (as commonly arising from Hamiltonian systems):
\(\frac{\mathrm{d}x}{\mathrm{d}t}(t) = f(t, y(t)),\qquad\frac{\mathrm{d}y}{\mathrm{d}t}(t) = g(t, x(t))\)
These would be passed to the solver as the 2-tuple (ODETerm(...), ODETerm(...)).
What each solver accepts
Each solver in Diffrax will specify what kinds of problems it can handle, as described by their .term_structure attribute. Not all solvers are able to handle all problems!
Some example term structures include:
- solver.term_structure = AbstractTerm
  In this case the solver can handle a simple ODE as described above: ODETerm is a subclass of AbstractTerm. It can also handle SDEs: MultiTerm(ODETerm(...), ControlTerm(...)) wraps everything into a single term (the MultiTerm), and at that point this defines an interface the solver knows how to handle. Most solvers in Diffrax have this term structure.
- solver.term_structure = MultiTerm[tuple[ODETerm, ControlTerm]]
  In this case the solver specifically handles just SDEs of the form MultiTerm(ODETerm(...), ControlTerm(...)); nothing else is compatible. Some SDE-specific solvers have this term structure.
- solver.term_structure = (AbstractTerm, AbstractTerm)
  In this case the solver is used to solve ODEs like the Hamiltonian system described above: we have a PyTree of terms, each of which is treated individually.
diffrax.AbstractTerm
Abstract base class for all terms.
Let \(y\) solve some differential equation with vector field \(f\) and control \(x\).
Let \(y\) have PyTree structure \(T\), let the output of the vector field have PyTree structure \(S\), and let \(x\) have PyTree structure \(U\). Then \(f : T \to S\), whilst the interaction \((f, x) \mapsto f \mathrm{d}x\) is a function \((S, U) \to T\).
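To make \(T\), \(S\) and \(U\) concrete, here is a sketch of the ODE case in which the state is a dict of arrays, the vector field returns the same structure (so \(S = T\)) and the control is a scalar step \(\Delta t\). The state layout, the vector field and the `prod` helper are hypothetical illustrations written with JAX's tree utilities, not Diffrax's implementation.

```python
import jax
import jax.numpy as jnp

# State y with PyTree structure T: a dict of two arrays (illustrative).
y = {"position": jnp.array([0.0, 1.0]), "velocity": jnp.array([1.0, 0.0])}

def vf(t, y, args):
    # f : T -> S; here S = T.
    return {"position": y["velocity"], "velocity": -y["position"]}

def prod(vf_out, control):
    # (S, U) -> T: scale every leaf of the vector field by the scalar control.
    return jax.tree_util.tree_map(lambda leaf: leaf * control, vf_out)

dt = 0.1  # control with structure U: a single scalar leaf
dy = prod(vf(0.0, y, None), dt)

# The result has the same PyTree structure as the state, so it can be added to it.
assert jax.tree_util.tree_structure(dy) == jax.tree_util.tree_structure(y)
```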
vf(self, t: Union[float, int], y: PyTree[Shaped[ArrayLike, '?*y'], "Y"], args: PyTree[typing.Any]) -> ~_VF
abstractmethod
The vector field.
Represents a function \(f(t, y(t), args)\).
Arguments:
- t: the integration time.
- y: the evolving state; a PyTree of structure \(T\).
- args: any static arguments as passed to diffrax.diffeqsolve.
Returns:
A PyTree of structure \(S\).
contr(self, t0: Union[float, int], t1: Union[float, int], **kwargs) -> ~_Control
abstractmethod
The control.
Represents the \(\mathrm{d}t\) in an ODE, or the \(\mathrm{d}w(t)\) in an SDE, etc.
Most numerical ODE solvers work by making a step of length \(\Delta t = t_1 - t_0\). Likewise most numerical SDE solvers work by sampling some Brownian motion \(\Delta w \sim \mathcal{N}(0, t_1 - t_0)\).
Correspondingly a control is not defined at a point. Instead it is defined over an interval \([t_0, t_1]\).
Arguments:
- t0: the start of the interval.
- t1: the end of the interval.
Returns:
A PyTree of structure \(U\). For a control \(x\), the result should represent \(x(t_1) - x(t_0)\).
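As an illustration of controls being defined over intervals, the following NumPy sketch implements two hypothetical contr-style functions: one for the control \(x(t) = t\) of an ODE, whose increment is simply \(t_1 - t_0\), and one that draws Brownian increments \(\Delta w \sim \mathcal{N}(0, t_1 - t_0)\). The function names and sample size are assumptions for illustration.

```python
import numpy as np

rng = np.random.default_rng(0)

def ode_contr(t0, t1):
    # Control x(t) = t: the increment over [t0, t1] is just the step size.
    return t1 - t0

def brownian_contr(t0, t1, size):
    # Brownian increments w(t1) - w(t0) ~ N(0, t1 - t0), drawn independently on each call.
    # (A genuine Brownian path must also return consistent values for overlapping intervals.)
    return rng.normal(0.0, np.sqrt(t1 - t0), size=size)

print(ode_contr(0.5, 0.75))   # 0.25
samples = brownian_contr(0.0, 0.25, size=200_000)
print(samples.var())          # close to t1 - t0 = 0.25
```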
prod(self, vf: ~_VF, control: ~_Control) -> PyTree[Shaped[ArrayLike, '?*y'], "Y"]
abstractmethod
Determines the interaction between vector field and control.
With a solution \(y\) to a differential equation with vector field \(f\) and control \(x\), this computes \(f(t, y(t), args) \Delta x(t)\) given \(f(t, y(t), args)\) and \(\Delta x(t)\).
Note
This function must be bilinear.
Arguments:
- vf: The vector field evaluation; a PyTree of structure \(S\).
- control: The control evaluated over an interval; a PyTree of structure \(U\).
Returns:
The interaction between the vector field and control; a PyTree of structure \(T\).
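Bilinearity matters because solvers typically combine several vector field evaluations or several control increments linearly (for example the weighted stages of a Runge-Kutta method) and rely on those combinations passing through prod unchanged. The NumPy check below verifies both linearities for a matrix-vector interaction; the shapes, coefficients and seed are illustrative assumptions.

```python
import numpy as np

rng = np.random.default_rng(0)

def prod(vf, control):
    # Matrix-vector interaction, as for a diffusion term: (m, n) x (n,) -> (m,)
    return vf @ control

A, B = rng.normal(size=(2, 3, 2))   # two vector-field evaluations, each of shape (3, 2)
u, v = rng.normal(size=(2, 2))      # two controls, each of shape (2,)
a, b = 2.0, -0.5

# Linear in the control (vector field fixed) ...
linear_in_control = np.allclose(prod(A, a * u + b * v), a * prod(A, u) + b * prod(A, v))
# ... and linear in the vector field (control fixed).
linear_in_vf = np.allclose(prod(a * A + b * B, u), a * prod(A, u) + b * prod(B, u))
print(linear_in_control, linear_in_vf)  # True True
```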
vf_prod(self, t: Union[float, int], y: PyTree[Shaped[ArrayLike, '?*y'], "Y"], args: PyTree[typing.Any], control: ~_Control) -> PyTree[Shaped[ArrayLike, '?*y'], "Y"]
The composition of diffrax.AbstractTerm.vf and diffrax.AbstractTerm.prod.
With a solution \(y\) to a differential equation with vector field \(f\) and control \(x\), this computes \(f(t, y(t), args) \Delta x(t)\) given \(t\), \(y(t)\), \(args\), and \(\Delta x(t)\).
Its default implementation is simply
self.prod(self.vf(t, y, args), control)
This is offered as a special case that can be overridden when it is more efficient to do so.
Example
Consider when vf computes a matrix-matrix product, and prod computes a matrix-vector product. Then doing a naive composition corresponds to a (matrix-matrix)-vector product, which is less efficient than the corresponding matrix-(matrix-vector) product. Overriding this method offers a way to reclaim that efficiency.
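The efficiency argument can be checked with a small NumPy sketch: both bracketings give the same vector, but the multiply-add counts differ. The matrix sizes are illustrative assumptions, and the counts are the textbook dense-product costs rather than measured timings.

```python
import numpy as np

rng = np.random.default_rng(0)
m, k, n = 200, 200, 50             # illustrative sizes
A = rng.normal(size=(m, k))
B = rng.normal(size=(k, n))
control = rng.normal(size=n)

# Default vf_prod: prod(vf(...), control) with vf = A @ B
naive = (A @ B) @ control          # (matrix-matrix)-vector
# Overridden vf_prod: never form A @ B
fast = A @ (B @ control)           # matrix-(matrix-vector)

naive_cost = m * k * n + m * n     # multiply-adds for (A @ B), then the matrix-vector product
fast_cost = k * n + m * k          # multiply-adds for two matrix-vector products
print(np.allclose(naive, fast), naive_cost, fast_cost)  # True 2010000 50000
```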
Example
This is used extensively for efficiency when backpropagating via diffrax.BacksolveAdjoint.
Arguments:
- t: the integration time.
- y: the evolving state; a PyTree of structure \(T\).
- args: any static arguments as passed to diffrax.diffeqsolve.
- control: The control evaluated over an interval; a PyTree of structure \(U\).
Returns:
A PyTree of structure \(T\).
Note
This function must be linear in control.
is_vf_expensive(self, t0: Union[float, int], t1: Union[float, int], y: PyTree[Shaped[ArrayLike, '?*y'], "Y"], args: PyTree[typing.Any]) -> bool
Specifies whether evaluating the vector field is "expensive", in the specific sense that it is cheaper to evaluate vf_prod twice than vf once.
Some solvers use this to change their behaviour, so as to act more efficiently.
Defining your own term types
For advanced users: you can create your own terms if appropriate. For example, if your diffusion is a matrix that is itself computed as a matrix-matrix product, then you may wish to define a custom term and specify its diffrax.AbstractTerm.vf_prod method. By overriding this method you could express the contraction of the vector field with the control as a matrix-(matrix-vector) product, which is more efficient than the default (matrix-matrix)-vector product.
diffrax.ODETerm (AbstractTerm)
A term representing \(f(t, y(t), args) \mathrm{d}t\). That is to say, the term appearing on the right hand side of an ODE, in which the control is time.
vector_field should return some PyTree, with the same structure as the initial state y0, and with every leaf shape-broadcastable and dtype-upcastable to the equivalent leaf in y0.
Example
from diffrax import diffeqsolve, ODETerm

vector_field = lambda t, y, args: -y
ode_term = ODETerm(vector_field)
diffeqsolve(ode_term, ...)
__init__(self, vector_field: Callable[[Union[float, int], PyTree[Shaped[ArrayLike, '?*y'], "Y"], PyTree[Any]], ~_VF])
Arguments:
vector_field: A callable representing the vector field. This callable takes three arguments (t, y, args). t is a scalar representing the integration time. y is the evolving state of the system. args are any static arguments as passed to diffrax.diffeqsolve.
diffrax.ControlTerm (AbstractTerm)
A term representing the general case of \(f(t, y(t), args) \mathrm{d}x(t)\), in which the vector field (\(f\)) - control (\(\mathrm{d}x\)) interaction is a matrix-vector product.
This is typically used for either stochastic differential equations or for controlled differential equations.
ControlTerm can be used in two different ways.
- Simple way: directly return JAX arrays.
  vector_field and control should both return PyTrees, both with the same structure as the initial state y0. All leaves should be JAX arrays.
  If each leaf of y0 has shape (y1, ..., yN), and the corresponding leaf of control has shape (c1, ..., cM), then the corresponding leaf of vector_field should have shape (y1, ..., yN, c1, ..., cM). Leaf-by-leaf, the corresponding dimensions of vector_field and control are contracted against each other.
  This includes normal matrix-vector products as a special case: when y0 is an array with shape (m,), the control is an array with shape (n,), and the vector field is an array with shape (m, n).
- Advanced way: have the vector field return a Lineax linear operator.
  This is suitable for use cases in which you know that the vector field has special structure (e.g. it is diagonal) and you would like to use that structure for a more efficient implementation.
  In this case, vector_field should return a Lineax linear operator, the control can return anything compatible with the .mv method of that operator, and the interaction is defined as vector_field(t0, y, args).mv(control(t0, t1)).
  No special PyTree handling is done in this case; perform this inside the operator's .mv if required. (As you can see, this approach is basically about deferring the whole linear operation to Lineax.)
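The following NumPy sketch illustrates the shape rule of the simple way: the trailing control dimensions of the vector field are contracted against the control, leaving the shape of the state. It also illustrates why the advanced way can pay off for diagonal structure: multiplying by a diagonal matrix is just an elementwise product, which an operator such as lineax.DiagonalLinearOperator can exploit instead of building the full matrix. The shapes and values are assumptions, and this is not Diffrax's or Lineax's internal code.

```python
import numpy as np

rng = np.random.default_rng(0)

# Simple way: a leaf of y0 with shape (y1, ..., yN) = (2, 3) and a leaf of the
# control with shape (c1, ..., cM) = (4,), so the vector-field leaf has shape (2, 3, 4).
vf = rng.normal(size=(2, 3, 4))
control = rng.normal(size=(4,))
# Contract the trailing control dimensions of vf against all dimensions of control.
increment = np.tensordot(vf, control, axes=control.ndim)
print(increment.shape)  # (2, 3)

# Diagonal structure: the full (n, n) matrix never needs to be materialised,
# because a diagonal matrix-vector product is an elementwise product.
diagonal = np.array([2.0, -1.0, 0.5])
dw = rng.normal(size=3)
assert np.allclose(np.diag(diagonal) @ dw, diagonal * dw)
```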
Example
In this example we consider an SDE with m-dimensional state \(y \in \mathbb{R}^m\), an n-dimensional Brownian motion \(W(t) \in \mathbb{R}^n\), and a constant diffusion of shape (m, n).
\(\mathrm{d}y(t) = \begin{bmatrix} 1 & ... & 1 \\ & ... & \\ 1 & ... & 1 \end{bmatrix} \mathrm{d}W(t)\)
from diffrax import ControlTerm, diffeqsolve, UnsafeBrownianPath
import jax.numpy as jnp

y0 = jnp.ones((m,))
control = UnsafeBrownianPath(shape=(n,), key=...)

def vector_field(t, y, args):
    return jnp.ones((m, n))

diffusion_term = ControlTerm(vector_field, control)
diffeqsolve(diffusion_term, ..., y0=y0)
Example
In this example we consider an SDE with a one-dimensional state \(y(t) \in \mathbb{R}\) and a two-dimensional Brownian motion \(W(t) \in \mathbb{R}^2\), given by:
\(\mathrm{d}y(t) = \begin{bmatrix} y(t) \\ y(t) + 1 \end{bmatrix} \mathrm{d}W(t)\)
We use the simple matrix-vector product way of combining things.
from diffrax import ControlTerm, diffeqsolve, UnsafeBrownianPath
import jax.numpy as jnp

control = UnsafeBrownianPath(shape=(2,), key=...)

def vector_field(t, y, args):
    return jnp.stack([y, y + 1], axis=-1)

diffusion_term = ControlTerm(vector_field, control)
diffeqsolve(diffusion_term, ...)
Example
In this example we consider an SDE with two-dimensional state \((y_1(t), y_2(t)) \in \mathbb{R}^2\) and a two-dimensional Brownian motion \(W(t) \in \mathbb{R}^2\) -- and for which the diffusion matrix is diagonal.
\(\mathrm{d}\begin{bmatrix} y_1 \\ y_2 \end{bmatrix}(t) = \begin{bmatrix} y_2(t) & 0 \\ 0 & y_1(t) \end{bmatrix} \mathrm{d}W(t)\)
As such we use the more-advanced approach of using Lineax's linear operators to represent the diffusion matrix.
from diffrax import ControlTerm, diffeqsolve, UnsafeBrownianPath
import jax.numpy as jnp
import lineax

control = UnsafeBrownianPath(shape=(2,), key=...)

def vector_field(t, y, args):
    # y is a JAX array of shape (2,)
    y1, y2 = y
    diagonal = jnp.array([y2, y1])
    return lineax.DiagonalLinearOperator(diagonal)

diffusion_term = ControlTerm(vector_field, control)
diffeqsolve(diffusion_term, ...)
Example
In this example we consider a controlled differential equation, for which the control is given by an interpolation of some data. (See also the neural controlled differential equation example.)
from diffrax import ControlTerm, diffeqsolve, LinearInterpolation
import jax.numpy as jnp

ts = jnp.array([1., 2., 2.5, 3.])
data = jnp.array([[0.1, 2.0],
                  [0.3, 1.5],
                  [1.0, 1.6],
                  [0.2, 1.1]])
control = LinearInterpolation(ts, data)
vector_field = lambda t, y, args: jnp.stack([y, y], axis=-1)
cde_term = ControlTerm(vector_field, control)
diffeqsolve(cde_term, ...)
__init__(self, vector_field: Callable[[Union[float, int], PyTree[Shaped[ArrayLike, '?*y'], "Y"], PyTree[Any]], ~_VF], control: Union[AbstractPath[~_Control], Callable[[Union[float, int], Union[float, int]], ~_Control]])
Arguments:
- vector_field: A callable representing the vector field. This callable takes three arguments (t, y, args). t is a scalar representing the integration time. y is the evolving state of the system. args are any static arguments as passed to diffrax.diffeqsolve. This vector_field can either
  - return a PyTree of JAX arrays, or
  - return a Lineax linear operator, as described above.
- control: The control. Should either be
  - a diffrax.AbstractPath, in which case its .evaluate(t0, t1) method will be used to give the increment of the control over a time interval [t0, t1], or
  - a callable (t0, t1) -> increment, which returns the increment directly.
to_ode(self) -> ODETerm
If the control is differentiable then \(f(t, y(t), args) \mathrm{d}x(t)\) may be thought of as an ODE as \(f(t, y(t), args) \frac{\mathrm{d}x}{\mathrm{d}t}\mathrm{d}t\). This method converts this ControlTerm into the corresponding diffrax.ODETerm in this way.
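A quick numerical check of the idea behind to_ode: over a short interval of length \(\Delta t\), \(f\,\Delta x\) and \(f\,\frac{\mathrm{d}x}{\mathrm{d}t}\Delta t\) agree up to an error of order \(\Delta t^2\). The control \(x(t) = (\sin t, t^2)\), the fixed vector-field value and the step are illustrative assumptions, and the sketch does not call to_ode itself.

```python
import numpy as np

def x(t):
    # Illustrative differentiable control.
    return np.array([np.sin(t), t**2])

def dxdt(t):
    # Its derivative.
    return np.array([np.cos(t), 2 * t])

f = np.array([[1.0, 2.0], [0.5, -1.0]])   # a fixed vector-field value of shape (2, 2)
t0, dt = 0.3, 1e-4

as_cde = f @ (x(t0 + dt) - x(t0))          # f dx: uses the control increment
as_ode = (f @ dxdt(t0)) * dt               # f (dx/dt) dt: uses the time increment
print(np.max(np.abs(as_cde - as_ode)))     # small, of order dt**2
```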
diffrax.MultiTerm (AbstractTerm)
Accumulates multiple terms into a single term.
Consider the SDE
\(\mathrm{d}y(t) = f(t, y(t))\mathrm{d}t + g(t, y(t))\mathrm{d}w(t)\)
This has two terms on the right hand side. It may be represented with a single term as
\(\mathrm{d}y(t) = [f(t, y(t)), g(t, y(t))] \cdot [\mathrm{d}t, \mathrm{d}w(t)]\)
whose vector field -- control interaction is a dot product.
MultiTerm performs this transform. For simplicity, most differential equation solvers (at least those built into Diffrax) accept just a single term, so this transform is a necessary part of e.g. solving an SDE with both drift and diffusion.
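The dot-product view can be sketched in NumPy: each sub-term's vector field is paired with its own control through its own interaction, and the results are summed. The helper `multi_prod` and all values are hypothetical illustrations of the idea, not MultiTerm's actual implementation.

```python
import numpy as np

rng = np.random.default_rng(0)

# Vector-field values and controls of the two sub-terms (illustrative).
f_val = np.array([-1.0, 0.5])              # drift f(t, y): shape (2,)
g_val = 0.1 * np.ones((2, 3))              # diffusion g(t, y): shape (2, 3)
dt = 0.01                                  # drift control: a scalar
dw = rng.normal(0.0, np.sqrt(dt), size=3)  # diffusion control: shape (3,)

def multi_prod(vfs, controls, prods):
    # Hypothetical helper: pair each vector field with its control, then sum,
    # i.e. the "dot product" [f, g] . [dt, dw].
    return sum(p(v, c) for p, v, c in zip(prods, vfs, controls))

prods = (lambda v, c: v * c, lambda v, c: v @ c)   # vector-scalar, matrix-vector
dy = multi_prod((f_val, g_val), (dt, dw), prods)
assert np.allclose(dy, f_val * dt + g_val @ dw)
```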
__init__(self, *terms: AbstractTerm)
Arguments:
*terms: Any number of diffrax.AbstractTerms to combine.