Terms
One of the advanced features of Diffrax is its term system. When we write down e.g. a stochastic differential equation
\(\mathrm{d}y(t) = f(t, y(t))\mathrm{d}t + g(t, y(t))\mathrm{d}w(t)\)
then we have two "terms": a drift and a diffusion. Each of these terms has two parts: a vector field (\(f\) or \(g\)) and a control (\(\mathrm{d}t\) or \(\mathrm{d}w(t)\)). There is also an implicit assumption about how the vector field and control interact: \(f\) and \(\mathrm{d}t\) interact as a vector-scalar product. \(g\) and \(\mathrm{d}w(t)\) interact as a matrix-vector product. (This interaction is always linear.)
"Terms" are thus the building blocks of differential equations.
Example
Consider the ODE \(\frac{\mathrm{d}y}{\mathrm{d}t} = f(t, y(t))\). This has vector field \(f\), control \(\mathrm{d}t\), and their interaction is a vector-scalar product. This can be described as a single diffrax.ODETerm.
If multiple terms affect the same evolving state, then they should be grouped into a single diffrax.MultiTerm.
Example
An SDE would have its drift described by diffrax.ODETerm and its diffusion described by diffrax.ControlTerm. As these affect the same evolving state variable, they should be passed to the solver as MultiTerm(ODETerm(...), ControlTerm(...)).
If terms affect different pieces of the state, then they should be placed in some PyTree structure. (The exact structure will depend on what the solver accepts.)
Example
Consider the pair of equations (as commonly arising from Hamiltonian systems):
\(\frac{\mathrm{d}x}{\mathrm{d}t}(t) = f(t, y(t)),\qquad\frac{\mathrm{d}y}{\mathrm{d}t}(t) = g(t, x(t))\)
These would be passed to the solver as the 2-tuple (ODETerm(...), ODETerm(...)).
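This is a minimal NumPy sketch of the tuple-of-terms idea. It does not use Diffrax's API. It takes the harmonic oscillator \(f(t, y) = y\), \(g(t, x) = -x\), and each element of the hypothetical terms tuple drives its own piece of the state. The step function is a semi-implicit (symplectic) Euler step, chosen only for illustration. It uses the split to update \(x\) first and then \(y\) with the new \(x\).

```python
import numpy as np

# Hypothetical stand-ins: one vector field per piece of the state.
def f(t, y):  # dx/dt = f(t, y)
    return y

def g(t, x):  # dy/dt = g(t, x)
    return -x

terms = (f, g)  # a 2-tuple of terms, mirroring (ODETerm(...), ODETerm(...))

def symplectic_euler_step(terms, t0, t1, state):
    f, g = terms
    x, y = state
    dt = t1 - t0
    x_new = x + f(t0, y) * dt      # update x using the current y
    y_new = y + g(t0, x_new) * dt  # update y using the freshly updated x
    return x_new, y_new

state = (np.array(1.0), np.array(0.0))
for k in range(1000):
    state = symplectic_euler_step(terms, k * 0.01, (k + 1) * 0.01, state)
```

Because the two terms act on separate pieces of the state, the solver can interleave their updates. This is not possible when both terms are merged into one.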
Each solver is capable of handling certain classes of problems, as described by its solver.term_structure.
diffrax.AbstractTerm
Abstract base class for all terms.
Let \(y\) solve some differential equation with vector field \(f\) and control \(x\).
Let \(y\) have PyTree structure \(T\), let the output of the vector field have PyTree structure \(S\), and let \(x\) have PyTree structure \(U\). Then \(f : T \to S\), whilst the interaction \((f, x) \mapsto f \mathrm{d}x\) is a function \((S, U) \to T\).
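This sketch shows how the three structures fit together, using plain NumPy arrays as single-leaf PyTrees. The shapes are arbitrary choices for illustration and are not fixed by Diffrax. The function names f and prod are hypothetical stand-ins.

```python
import numpy as np

# Illustrative structures (arbitrary choices):
#   T: the state y is an array of shape (3,)
#   S: the vector field output is an array of shape (3, 2)
#   U: the control increment is an array of shape (2,)
y = np.array([1.0, 2.0, 3.0])
dx = np.array([0.5, 0.25])

def f(t, y):
    # f : T -> S
    return np.outer(y, np.array([1.0, -1.0]))

def prod(vf, control):
    # (S, U) -> T, here a matrix-vector product
    return vf @ control

out = prod(f(0.0, y), dx)  # has structure T again, i.e. shape (3,)
```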
vf(self, t: Scalar, y: PyTree, args: PyTree) -> PyTree
abstractmethod
The vector field.
Represents a function \(f(t, y(t), args)\).
Arguments:
t: the integration time.
y: the evolving state; a PyTree of structure \(T\).
args: any static arguments as passed to diffrax.diffeqsolve.
Returns:
A PyTree of structure \(S\).
contr(self, t0: Scalar, t1: Scalar) -> PyTree
abstractmethod
The control.
Represents the \(\mathrm{d}t\) in an ODE, or the \(\mathrm{d}w(t)\) in an SDE, etc.
Most numerical ODE solvers work by making a step of length \(\Delta t = t_1 - t_0\). Likewise most numerical SDE solvers work by sampling some Brownian motion \(\Delta w \sim \mathcal{N}(0, t_1 - t_0)\).
Correspondingly a control is not defined at a point. Instead it is defined over an interval \([t_0, t_1]\).
Arguments:
t0: the start of the interval.
t1: the end of the interval.
Returns:
A PyTree of structure \(U\). For a control \(x\), the result should represent \(x(t_1) - x(t_0)\).
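The two controls mentioned above can be sketched in NumPy as functions of an interval \([t_0, t_1]\). Both names are hypothetical, and this is not how Diffrax implements its paths. One caveat applies to the sketch: a genuine Brownian path must return consistent increments when queried on overlapping intervals. This version samples independently on every call, so it only suits non-overlapping steps.

```python
import numpy as np

def time_control(t0, t1):
    # The control x(t) = t, so x(t1) - x(t0) = t1 - t0.
    return t1 - t0

def brownian_control(t0, t1, rng, shape=(2,)):
    # Brownian increment: w(t1) - w(t0) ~ N(0, t1 - t0), independently per entry.
    return rng.normal(loc=0.0, scale=np.sqrt(t1 - t0), size=shape)

rng = np.random.default_rng(0)
dt = time_control(0.0, 0.5)
dw = brownian_control(0.0, 0.5, rng)
```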
prod(self, vf: PyTree, control: PyTree) -> PyTree
abstractmethod
Determines the interaction between vector field and control.
With a solution \(y\) to a differential equation with vector field \(f\) and control \(x\), this computes \(f(t, y(t), args) \Delta x(t)\) given \(f(t, y(t), args)\) and \(\Delta x(t)\).
Arguments:
vf: the vector field evaluation; a PyTree of structure \(S\).
control: the control evaluated over an interval; a PyTree of structure \(U\).
Returns:
The interaction between the vector field and control; a PyTree of structure \(T\).
Note
This function must be bilinear.
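Bilinearity means prod is linear in each argument when the other is held fixed. This NumPy sketch checks that property numerically for a matrix-vector interaction. The prod here is a hypothetical stand-in for the method, not Diffrax's implementation.

```python
import numpy as np

def prod(vf, control):
    # Matrix-vector interaction, as for a diffusion term.
    return vf @ control

rng = np.random.default_rng(0)
A, B = rng.normal(size=(2, 3, 2))  # two vector field evaluations, each (3, 2)
u, v = rng.normal(size=(2, 2))     # two control increments, each (2,)
a, b = 2.0, -0.5

# Linear in the control, with the vector field held fixed ...
lin_control = np.allclose(prod(A, a * u + b * v), a * prod(A, u) + b * prod(A, v))
# ... and linear in the vector field, with the control held fixed.
lin_vf = np.allclose(prod(a * A + b * B, u), a * prod(A, u) + b * prod(B, u))
```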
vf_prod(self, t: Scalar, y: PyTree, args: PyTree, control: PyTree) -> PyTree
The composition of diffrax.AbstractTerm.vf and diffrax.AbstractTerm.prod.
With a solution \(y\) to a differential equation with vector field \(f\) and control \(x\), this computes \(f(t, y(t), args) \Delta x(t)\) given \(t\), \(y(t)\), \(args\), and \(\Delta x(t)\).
Its default implementation is simply
self.prod(self.vf(t, y, args), control)
This is offered as a special case that can be overridden when it is more efficient to do so.
Example
Consider when vf computes a matrix-matrix product, and prod computes a matrix-vector product. Then doing a naive composition corresponds to a (matrix-matrix)-vector product, which is less efficient than the corresponding matrix-(matrix-vector) product. Overriding this method offers a way to reclaim that efficiency.
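This NumPy sketch illustrates the reassociation in the example above. The vector field is a hypothetical one that returns \(M N \operatorname{diag}(y)\). The naive composition forms that \(n \times n\) matrix at \(O(n^3)\) cost. The override computes \(M(N(y \odot \Delta x))\) using only \(O(n^2)\) matrix-vector products. Both give the same result.

```python
import numpy as np

n = 50
rng = np.random.default_rng(0)
M = rng.normal(size=(n, n))
N = rng.normal(size=(n, n))

def vf(t, y, args):
    # Matrix-matrix product M @ N @ diag(y): O(n^3) work.
    M, N = args
    return M @ (N * y[None, :])

def prod(vf_value, control):
    # Matrix-vector product: O(n^2) work.
    return vf_value @ control

def vf_prod_naive(t, y, args, control):
    # Mirrors the default: prod(vf(t, y, args), control).
    return prod(vf(t, y, args), control)

def vf_prod_fast(t, y, args, control):
    # Reassociated as M @ (N @ (y * control)): only O(n^2) work.
    M, N = args
    return M @ (N @ (y * control))

y = rng.normal(size=n)
control = rng.normal(size=n)
naive = vf_prod_naive(0.0, y, (M, N), control)
fast = vf_prod_fast(0.0, y, (M, N), control)
```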
Example
This is used extensively for efficiency when backpropagating via diffrax.BacksolveAdjoint.
Arguments:
t: the integration time.
y: the evolving state; a PyTree of structure \(T\).
args: any static arguments as passed to diffrax.diffeqsolve.
control: the control evaluated over an interval; a PyTree of structure \(U\).
Returns:
A PyTree of structure \(T\).
Note
This function must be linear in control.
is_vf_expensive(self, t0: Scalar, t1: Scalar, y: Tuple[PyTree, PyTree, PyTree, PyTree], args: PyTree) -> bool
Specifies whether evaluating the vector field is "expensive", in the specific sense that it is cheaper to evaluate vf_prod twice than vf once.
Some solvers use this to change their behaviour, so as to act more efficiently.
Note
You can create your own terms if appropriate: e.g. if a diffusion matrix has some particular structure, and you want to use a specialised, more efficient matrix-vector product algorithm in prod. For example, this is what diffrax.WeaklyDiagonalControlTerm does, as compared to just diffrax.ControlTerm.
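This is a NumPy sketch of a custom term, mimicking the vf / contr / prod / vf_prod interface described above. It does not subclass diffrax.AbstractTerm, and the class name LowRankControlTerm is hypothetical. It assumes the diffusion matrix is rank one, \(g(t, y) = u(t, y) v^\top\). The vector field returns the pair \((u, v)\) rather than the dense matrix, so prod costs \(O(n + m)\) instead of \(O(nm)\).

```python
import numpy as np

class LowRankControlTerm:
    """Hypothetical term whose diffusion matrix is rank one: g = u v^T."""

    def __init__(self, vector_field, increments):
        self.vector_field = vector_field  # (t, y, args) -> (u, v)
        self.increments = increments      # (t0, t1) -> control increment

    def vf(self, t, y, args):
        return self.vector_field(t, y, args)

    def contr(self, t0, t1):
        return self.increments(t0, t1)

    def prod(self, vf, control):
        u, v = vf
        return u * np.dot(v, control)  # (u v^T) dw = u (v . dw)

    def vf_prod(self, t, y, args, control):
        return self.prod(self.vf(t, y, args), control)

# Usage with a fixed, made-up increment in place of a Brownian sample.
term = LowRankControlTerm(
    lambda t, y, args: (y, np.array([1.0, 2.0])),
    lambda t0, t1: np.array([0.1, -0.2]),
)
y = np.array([1.0, 2.0, 3.0])
dw = term.contr(0.0, 0.1)
out = term.vf_prod(0.0, y, None, dw)
dense = np.outer(y, [1.0, 2.0]) @ dw  # the same product via the full matrix
```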
diffrax.ODETerm (AbstractTerm)
A term representing \(f(t, y(t), args) \mathrm{d}t\). That is to say, the term appearing on the right-hand side of an ODE, in which the control is time.
vector_field should return some PyTree, with the same structure as the initial state y0, and with every leaf broadcastable to the equivalent leaf in y0.
Example
vector_field = lambda t, y, args: -y
ode_term = ODETerm(vector_field)
diffeqsolve(ode_term, ...)
__init__(self, vector_field: Callable[[Scalar, PyTree, PyTree], PyTree])
Arguments:
vector_field: A callable representing the vector field. This callable takes three arguments (t, y, args). t is a scalar representing the integration time. y is the evolving state of the system. args are any static arguments as passed to diffrax.diffeqsolve.
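This minimal NumPy stand-in (hypothetical, not diffrax.ODETerm) shows what the decomposition means for an ODE term. The control over \([t_0, t_1]\) is just \(t_1 - t_0\), and prod is a vector-scalar product. Stepping vf, contr and prod repeatedly gives the explicit Euler method. It is applied here to the vector_field from the example above, whose exact solution is \(e^{-t} y_0\).

```python
import numpy as np

class SketchODETerm:
    """Hypothetical NumPy stand-in for an ODE term."""

    def __init__(self, vector_field):
        self.vector_field = vector_field

    def vf(self, t, y, args):
        return self.vector_field(t, y, args)

    def contr(self, t0, t1):
        return t1 - t0  # the control is time

    def prod(self, vf, control):
        return vf * control  # vector-scalar product

def euler_solve(term, t0, t1, y0, n_steps, args=None):
    ts = np.linspace(t0, t1, n_steps + 1)
    y = y0
    for ta, tb in zip(ts[:-1], ts[1:]):
        y = y + term.prod(term.vf(ta, y, args), term.contr(ta, tb))
    return y

term = SketchODETerm(lambda t, y, args: -y)
y1 = euler_solve(term, 0.0, 1.0, np.array([1.0]), n_steps=1000)
```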
diffrax.ControlTerm (AbstractTerm)
A term representing the general case of \(f(t, y(t), args) \mathrm{d}x(t)\), in which the vector field - control interaction is a matrix-vector product.
vector_field and control should both return PyTrees, both with the same structure as the initial state y0. Every dimension of control is then contracted against the last dimensions of vector_field; that is to say, if each leaf of y0 has shape (y1, ..., yN), and the corresponding leaf of control has shape (c1, ..., cM), then the corresponding leaf of vector_field should have shape (y1, ..., yN, c1, ..., cM).
A common special case is when y0 and control are vector-valued, and vector_field is matrix-valued.
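For a single leaf, the contraction rule above amounts to a tensordot over the control's dimensions. This NumPy sketch uses arbitrary shapes with N = 2 and M = 2. It assumes np.tensordot's convention that an integer axes=k sums over the last k axes of the first argument and the first k axes of the second. It is an illustration of the rule, not Diffrax's code.

```python
import numpy as np

y_shape = (3, 4)        # (y1, ..., yN) with N = 2
control_shape = (2, 5)  # (c1, ..., cM) with M = 2

rng = np.random.default_rng(0)
vf_value = rng.normal(size=y_shape + control_shape)  # (y1, y2, c1, c2)
dx = rng.normal(size=control_shape)                  # (c1, c2)

# Contract every dimension of the control against the last M dimensions
# of the vector field output; the result has the shape of the y0 leaf.
update = np.tensordot(vf_value, dx, axes=dx.ndim)
```

In the common special case (M = N = 1), this reduces to the ordinary matrix-vector product vf_value @ dx.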
Example
control = UnsafeBrownianPath(shape=(2,), key=...)
vector_field = lambda t, y, args: jnp.stack([y, y], axis=-1)
diffusion_term = ControlTerm(vector_field, control)
diffeqsolve(diffusion_term, ...)
Example
ts = jnp.array([1., 2., 2.5, 3.])
data = jnp.array([[0.1, 2.0],
[0.3, 1.5],
[1.0, 1.6],
[0.2, 1.1]])
control = LinearInterpolation(ts, data)
vector_field = lambda t, y, args: jnp.stack([y, y], axis=-1)
cde_term = ControlTerm(vector_field, control)
diffeqsolve(cde_term, ...)
__init__(self, vector_field: Callable[[Scalar, PyTree, PyTree], PyTree], control: AbstractPath)
Arguments:
vector_field: A callable representing the vector field. This callable takes three arguments (t, y, args). t is a scalar representing the integration time. y is the evolving state of the system. args are any static arguments as passed to diffrax.diffeqsolve.
control: A callable representing the control. Should have an evaluate(t0, t1) method. If using diffrax.ControlTerm.to_ode then it should have a derivative(t) method.
to_ode(self) -> ODETerm
If the control is differentiable then \(f(t, y(t), args) \mathrm{d}x(t)\) may be thought of as an ODE as
\(f(t, y(t), args) \frac{\mathrm{d}x}{\mathrm{d}t}\mathrm{d}t\).
This method converts this ControlTerm into the corresponding diffrax.ODETerm in this way.
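This NumPy sketch illustrates the conversion using the ts and data from the LinearInterpolation example above. The helper names control_derivative and to_ode_vector_field are hypothetical and are not Diffrax's implementation. A linear interpolant has a piecewise-constant derivative. So over one interpolation interval, with \(y\) frozen, \(f \frac{\mathrm{d}x}{\mathrm{d}t} \Delta t\) matches \(f \Delta x\) exactly.

```python
import numpy as np

ts = np.array([1.0, 2.0, 2.5, 3.0])
data = np.array([[0.1, 2.0], [0.3, 1.5], [1.0, 1.6], [0.2, 1.1]])

def control_derivative(t):
    # Slope of the linear interpolant of (ts, data) on the interval containing t.
    i = np.clip(np.searchsorted(ts, t, side="right") - 1, 0, len(ts) - 2)
    return (data[i + 1] - data[i]) / (ts[i + 1] - ts[i])

def vector_field(t, y, args):
    return np.stack([y, y], axis=-1)  # matrix-valued, shape (2, 2)

def to_ode_vector_field(t, y, args):
    # f(t, y, args) dx/dt: an ordinary vector field, to be paired with dt.
    return vector_field(t, y, args) @ control_derivative(t)

y = np.array([1.0, -1.0])
t0, t1 = 2.0, 2.5  # one interpolation interval
cde_update = vector_field(t0, y, None) @ (data[2] - data[1])
ode_update = to_ode_vector_field(t0, y, None) * (t1 - t0)
```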
diffrax.WeaklyDiagonalControlTerm (AbstractTerm)
A term representing the case of \(f(t, y(t), args) \mathrm{d}x(t)\), in which the vector field - control interaction is a matrix-vector product, and the matrix is square and diagonal. In this case we may represent the matrix as a vector of just its diagonal elements. The matrix-vector product may be calculated by pointwise multiplying this vector with the control; this is more computationally efficient than writing out the full matrix and then doing a full matrix-vector product.
Correspondingly, vector_field and control should both return PyTrees, and both should have the same structure and leaf shape as the initial state y0. These are multiplied together pointwise.
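The saving described above can be checked with a small NumPy sketch. The diagonal vector field diag_vf is hypothetical. Its \((i,i)\) entry depends on all of \(y\) through the sum, so it is only weakly diagonal. Multiplying the diagonal pointwise with the control gives the same result as building the full diagonal matrix and doing a matrix-vector product.

```python
import numpy as np

def diag_vf(t, y, args):
    # Diagonal entries only; each depends on all of y, so "weakly" diagonal.
    return np.sin(y) + np.sum(y)

y = np.array([0.1, 0.2, 0.3])
dw = np.array([0.05, -0.02, 0.01])

d = diag_vf(0.0, y, None)
pointwise = d * dw        # O(n) work and memory
dense = np.diag(d) @ dw   # O(n^2): materialises the full matrix
```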
Info
Why "weakly" diagonal? Consider the matrix representation of the vector field, as a square diagonal matrix. In general, the (i,i)-th element may depend upon any of the values of y. Only if the (i,i)-th element depends only upon the i-th element of y is the vector field said to be "diagonal", without the "weak". (This stronger property is useful in some SDE solvers.)
__init__(self, vector_field: Callable[[Scalar, PyTree, PyTree], PyTree], control: AbstractPath)
Arguments:
vector_field: A callable representing the vector field. This callable takes three arguments (t, y, args). t is a scalar representing the integration time. y is the evolving state of the system. args are any static arguments as passed to diffrax.diffeqsolve.
control: A callable representing the control. Should have an evaluate(t0, t1) method. If using diffrax.WeaklyDiagonalControlTerm.to_ode then it should have a derivative(t) method.
to_ode(self) -> ODETerm
If the control is differentiable then \(f(t, y(t), args) \mathrm{d}x(t)\) may be thought of as an ODE as
\(f(t, y(t), args) \frac{\mathrm{d}x}{\mathrm{d}t}\mathrm{d}t\).
This method converts this WeaklyDiagonalControlTerm into the corresponding diffrax.ODETerm in this way.
diffrax.MultiTerm (AbstractTerm)
Accumulates multiple terms into a single term.
Consider the SDE
\(\mathrm{d}y(t) = f(t, y(t))\mathrm{d}t + g(t, y(t))\mathrm{d}w(t)\)
This has two terms on the right hand side. It may be represented with a single term as
\(\mathrm{d}y(t) = [f(t, y(t)), g(t, y(t))] \cdot [\mathrm{d}t, \mathrm{d}w(t)]\)
whose vector field -- control interaction is a dot product.
MultiTerm performs this transform. For simplicity, most differential equation solvers (at least those built into Diffrax) accept just a single term, so this transform is a necessary part of e.g. solving an SDE with both drift and diffusion.
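The dot product \([f, g] \cdot [\mathrm{d}t, \mathrm{d}w]\) amounts to summing each term's own vector field and control interaction. This NumPy sketch shows that idea. The (vf, prod) pairs and multi_vf_prod are hypothetical stand-ins, not MultiTerm's actual implementation. The combined control is the tuple (dt, dw).

```python
import numpy as np

def drift_vf(t, y):
    return -y

def diffusion_vf(t, y):
    return 0.1 * np.stack([y, y], axis=-1)  # shape (2, 2)

terms = [
    (drift_vf, lambda vf, dt: vf * dt),      # vector-scalar product
    (diffusion_vf, lambda vf, dw: vf @ dw),  # matrix-vector product
]

def multi_vf_prod(t, y, controls):
    # [f, g] . [dt, dw] = f dt + g dw
    return sum(prod(vf(t, y), c) for (vf, prod), c in zip(terms, controls))

y = np.array([1.0, 2.0])
update = multi_vf_prod(0.0, y, (0.1, np.array([0.3, -0.1])))
```

A single-term solver can then treat the combined object like any other term, calling one vf_prod per step.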
__init__(self, *terms: AbstractTerm)
Arguments:
*terms: Any number of diffrax.AbstractTerms to combine.