# Terms

One of the advanced features of Diffrax is its *term* system. When we write down e.g. a stochastic differential equation

\(\mathrm{d}y(t) = f(t, y(t))\mathrm{d}t + g(t, y(t))\mathrm{d}w(t)\)

then we have two "terms": a drift and a diffusion. Each of these terms has two parts: a *vector field* (\(f\) or \(g\)) and a *control* (\(\mathrm{d}t\) or \(\mathrm{d}w(t)\)). There is also an implicit assumption about how vector field and control interact: \(f\) and \(\mathrm{d}t\) interact as a vector-scalar product. \(g\) and \(\mathrm{d}w(t)\) interact as a matrix-vector product. (This interaction is always linear.)

"Terms" are thus the building blocks of differential equations. In Diffrax, the above SDE has its drift described by `diffrax.ODETerm` and the diffusion described by a `diffrax.ControlTerm`.
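As a rough illustration of how the two terms combine, a single Euler-Maruyama step can be sketched in plain NumPy. The names `drift`, `diffusion`, and `euler_maruyama_step` are hypothetical and chosen for this sketch only; this is not how Diffrax itself is implemented.

```python
import numpy as np

def drift(t, y):
    # Hypothetical vector field f: returns a vector shaped like y.
    return -y

def diffusion(t, y):
    # Hypothetical vector field g: returns a matrix of shape (len(y), 2).
    return np.stack([y, 0.5 * y], axis=-1)

def euler_maruyama_step(t0, t1, y0, dw):
    dt = t1 - t0                              # control of the drift term (a scalar)
    drift_part = drift(t0, y0) * dt           # vector-scalar product
    diffusion_part = diffusion(t0, y0) @ dw   # matrix-vector product
    return y0 + drift_part + diffusion_part

rng = np.random.default_rng(0)
t0, t1 = 0.0, 0.01
y0 = np.array([1.0, 2.0, 3.0])
# Brownian increment over [t0, t1]: zero mean, variance t1 - t0.
dw = rng.normal(scale=np.sqrt(t1 - t0), size=2)
y1 = euler_maruyama_step(t0, t1, y0, dw)
```

Each term contributes its own (vector field, control, product) triple, and the step simply adds the two contributions.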

Example

As a simpler example, consider the ODE \(\frac{\mathrm{d}{y}}{\mathrm{d}t} = f(t, y(t))\). Then this has vector field \(f\), control \(\mathrm{d}t\), and their interaction is a vector-scalar product. This can be described as a single `diffrax.ODETerm`.

Example

Consider the pair of equations (as commonly arising from Hamiltonian systems):

\(\frac{\mathrm{d}x}{\mathrm{d}t}(t) = f(t, y(t)),\qquad\frac{\mathrm{d}y}{\mathrm{d}t}(t) = g(t, x(t))\)

These can be described as a 2-tuple of `diffrax.ODETerm`s.

The very first argument to `diffrax.diffeqsolve` should be some PyTree of terms. This is interpreted by the solver in the appropriate way.

- For example `diffrax.Euler` expects a single term: it solves an ODE represented via `ODETerm(...)`, or an SDE represented via `MultiTerm(ODETerm(...), ControlTerm(...))`.
- Meanwhile `diffrax.SemiImplicitEuler` solves the paired (Hamiltonian) system given in the example above, and expects a 2-tuple of terms representing each piece.
- Some SDE-specific solvers (e.g. `diffrax.StratonovichMilstein`) need to be able to see the distinction between the drift and diffusion, and expect a 2-tuple of terms representing the drift and diffusion respectively.
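To see why a solver for the paired system wants its two pieces kept separate, here is a minimal sketch of one common form of the semi-implicit (symplectic) Euler update, written in plain Python. The vector fields `f` and `g` and the function `semi_implicit_euler_step` are hypothetical, and the update shown is an assumption about the general technique rather than a description of Diffrax's `SemiImplicitEuler`.

```python
def f(t, y):
    # Hypothetical vector field for dx/dt; depends only on y.
    return y

def g(t, x):
    # Hypothetical vector field for dy/dt; depends only on x.
    return -x

def semi_implicit_euler_step(t0, t1, x0, y0):
    dt = t1 - t0
    x1 = x0 + f(t0, y0) * dt   # first term: update x using the old y
    y1 = y0 + g(t0, x1) * dt   # second term: update y using the new x
    return x1, y1

# Integrate a harmonic oscillator for 1000 steps.
x, y = 1.0, 0.0
t, dt = 0.0, 0.01
for _ in range(1000):
    x, y = semi_implicit_euler_step(t, t + dt, x, y)
    t += dt
energy = 0.5 * (x**2 + y**2)
```

The two updates are applied one after the other, each with its own vector field, which is why the solver needs to address the two terms individually.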

#### `diffrax.AbstractTerm`

Abstract base class for all terms.

Let \(y\) solve some differential equation with vector field \(f\) and control \(x\).

Let \(y\) have PyTree structure \(T\), let the output of the vector field have PyTree structure \(S\), and let \(x\) have PyTree structure \(U\). Then \(f : T \to S\) whilst the interaction \((f, x) \mapsto f \mathrm{d}x\) is a function \((S, U) \to T\).

##### `vf(self, t: Scalar, y: PyTree, args: PyTree) -> PyTree` `abstractmethod`

The vector field.

Represents a function \(f(t, y(t), args)\).

**Arguments:**

- `t`: the integration time.
- `y`: the evolving state; a PyTree of structure \(T\).
- `args`: any static arguments as passed to `diffrax.diffeqsolve`.

**Returns:**

A PyTree of structure \(S\).

##### `contr(self, t0: Scalar, t1: Scalar) -> PyTree` `abstractmethod`

The control.

Represents the \(\mathrm{d}t\) in an ODE, or the \(\mathrm{d}w(t)\) in an SDE, etc.

Most numerical ODE solvers work by making a step of length \(\Delta t = t_1 - t_0\). Likewise most numerical SDE solvers work by sampling an increment of Brownian motion \(\Delta w \sim \mathcal{N}(0, t_1 - t_0)\).

Correspondingly a control is *not* defined at a point. Instead it is defined
over an interval \([t_0, t_1]\).

**Arguments:**

- `t0`: the start of the interval.
- `t1`: the end of the interval.

**Returns:**

A PyTree of structure \(U\). For a control \(x\), the result should represent \(x(t_1) - x(t_0)\).

##### `prod(self, vf: PyTree, control: PyTree) -> PyTree` `abstractmethod`

Determines the interaction between vector field and control.

With a solution \(y\) to a differential equation with vector field \(f\) and control \(x\), this computes \(f(t, y(t), args) \Delta x(t)\) given \(f(t, y(t), args)\) and \(\Delta x(t)\).

**Arguments:**

- `vf`: The vector field evaluation; a PyTree of structure \(S\).
- `control`: The control evaluated over an interval; a PyTree of structure \(U\).

**Returns:**

The interaction between the vector field and control; a PyTree of structure \(T\).

Note

This function must be bilinear.
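The three abstract methods can be pictured with a toy class in plain NumPy. `ToyControlTerm` is a hypothetical stand-in written for this sketch; it does not subclass `diffrax.AbstractTerm`, and the choice of a control \(x(t) = \text{slope} \cdot t\) is an assumption made to keep the example small.

```python
import numpy as np

class ToyControlTerm:
    """Hypothetical illustration of the vf / contr / prod interface."""

    def __init__(self, slope):
        self.slope = np.asarray(slope, dtype=float)  # control x(t) = slope * t

    def vf(self, t, y, args):
        # Vector field: a matrix of shape (len(y), len(slope)); structure S.
        return np.outer(y, np.ones_like(self.slope))

    def contr(self, t0, t1):
        # Increment of the control over [t0, t1], i.e. x(t1) - x(t0); structure U.
        return self.slope * (t1 - t0)

    def prod(self, vf, control):
        # Bilinear interaction: a matrix-vector product; structure T.
        return vf @ control

term = ToyControlTerm(slope=[1.0, 2.0])
y0 = np.array([1.0, -1.0, 0.5])
control = term.contr(0.0, 0.1)
increment = term.prod(term.vf(0.0, y0, None), control)
```

Here \(T\) is an array of shape `(3,)`, \(S\) an array of shape `(3, 2)`, and \(U\) an array of shape `(2,)`.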

##### `vf_prod(self, t: Scalar, y: PyTree, args: PyTree, control: PyTree) -> PyTree`

The composition of `diffrax.AbstractTerm.vf` and `diffrax.AbstractTerm.prod`.

With a solution \(y\) to a differential equation with vector field \(f\) and control \(x\), this computes \(f(t, y(t), args) \Delta x(t)\) given \(t\), \(y(t)\), \(args\), and \(\Delta x(t)\).

Its default implementation is simply

```
self.prod(self.vf(t, y, args), control)
```

This is offered as a special case that can be overridden when it is more efficient to do so.

Example

Consider when `vf` computes a matrix-matrix product, and `prod` computes a matrix-vector product. Then doing a naive composition corresponds to a (matrix-matrix)-vector product, which is less efficient than the corresponding matrix-(matrix-vector) product. Overriding this method offers a way to reclaim that efficiency.
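The difference between the two orderings can be sketched in plain NumPy. The matrices `A` and `B` and the functions `vf_prod_naive` and `vf_prod_fast` are hypothetical, chosen only to illustrate the reordering; they are not part of Diffrax.

```python
import numpy as np

rng = np.random.default_rng(0)
A = rng.normal(size=(50, 40))
B = rng.normal(size=(40, 30))
control = rng.normal(size=30)

def vf(t, y, args):
    # Matrix-matrix product: roughly 50 * 40 * 30 multiplications.
    return A @ B

def prod(vf_value, control):
    # Matrix-vector product.
    return vf_value @ control

def vf_prod_naive(t, y, args, control):
    # Default composition: (A @ B) @ control.
    return prod(vf(t, y, args), control)

def vf_prod_fast(t, y, args, control):
    # Overridden version: A @ (B @ control), two matrix-vector products.
    return A @ (B @ control)

naive = vf_prod_naive(0.0, None, None, control)
fast = vf_prod_fast(0.0, None, None, control)
```

Both orderings give the same result up to floating point error, but the second never forms the 50-by-30 matrix.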

Example

This is used extensively for efficiency when backpropagating via `diffrax.BacksolveAdjoint`.

**Arguments:**

- `t`: the integration time.
- `y`: the evolving state; a PyTree of structure \(T\).
- `args`: any static arguments as passed to `diffrax.diffeqsolve`.
- `control`: The control evaluated over an interval; a PyTree of structure \(U\).

**Returns:**

A PyTree of structure \(T\).

Note

This function must be linear in `control`.

##### `is_vf_expensive(self, t0: Scalar, t1: Scalar, y: Tuple[PyTree, PyTree, PyTree, PyTree], args: PyTree) -> bool`

Specifies whether evaluating the vector field is "expensive", in the specific sense that it is cheaper to evaluate `vf_prod` twice than `vf` once.

Some solvers use this to change their behaviour, so as to act more efficiently.

Note

You can create your own terms if appropriate: e.g. if a diffusion matrix has some particular structure, and you want to use a specialised more efficient matrix-vector product algorithm in `prod`. For example this is what `diffrax.WeaklyDiagonalControlTerm` does, as compared to just `diffrax.ControlTerm`.

#### `diffrax.ODETerm (AbstractTerm)`

A term representing \(f(t, y(t), args) \mathrm{d}t\). That is to say, the term appearing on the right hand side of an ODE, in which the control is time.

`vector_field` should return some PyTree, with the same structure as the initial state `y0`, and with every leaf broadcastable to the equivalent leaf in `y0`.

Example

```
from diffrax import ODETerm, diffeqsolve

# Vector field for exponential decay: dy/dt = -y.
vector_field = lambda t, y, args: -y
ode_term = ODETerm(vector_field)
diffeqsolve(ode_term, ...)
```

##### `__init__(self, vector_field: Callable[[Scalar, PyTree, PyTree], PyTree])`

**Arguments:**

- `vector_field`: A callable representing the vector field. This callable takes three arguments `(t, y, args)`. `t` is a scalar representing the integration time. `y` is the evolving state of the system. `args` are any static arguments as passed to `diffrax.diffeqsolve`.

#### `diffrax.ControlTerm (AbstractTerm)`

A term representing the general case of \(f(t, y(t), args) \mathrm{d}x(t)\), in which the vector field - control interaction is a matrix-vector product.

`vector_field` and `control` should both return PyTrees, both with the same structure as the initial state `y0`. Every dimension of `control` is then contracted against the last dimensions of `vector_field`; that is to say if each leaf of `y0` has shape `(y1, ..., yN)`, and the corresponding leaf of `control` has shape `(c1, ..., cM)`, then the corresponding leaf of `vector_field` should have shape `(y1, ..., yN, c1, ..., cM)`.

A common special case is when `y0` and `control` are vector-valued, and `vector_field` is matrix-valued.
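The shape rule can be checked on a single leaf with `np.tensordot`, which contracts the trailing axes of its first argument against the leading axes of its second. This is a sketch of the contraction described above using made-up shapes; it is an assumption for illustration and does not show how Diffrax performs the product internally.

```python
import numpy as np

# General case: y0 leaf of shape (3, 4), control leaf of shape (2, 5).
y0 = np.zeros((3, 4))
control = np.ones((2, 5))
vector_field = np.ones((3, 4, 2, 5))   # shape (y1, y2, c1, c2)

# Contract every dimension of `control` against the last dimensions of `vector_field`.
out = np.tensordot(vector_field, control, axes=control.ndim)

# Special case: matrix-valued vector field, vector-valued control.
matrix = np.arange(6.0).reshape(3, 2)
vector = np.array([1.0, -1.0])
special = np.tensordot(matrix, vector, axes=vector.ndim)
```

In the special case the contraction reduces to an ordinary matrix-vector product, `matrix @ vector`.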

Example

```
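import jax.numpy as jnp
from diffrax import ControlTerm, UnsafeBrownianPath, diffeqsolve

# Two-dimensional Brownian motion driving a matrix-valued vector field.
control = UnsafeBrownianPath(shape=(2,), key=...)
vector_field = lambda t, y, args: jnp.stack([y, y], axis=-1)
diffusion_term = ControlTerm(vector_field, control)
diffeqsolve(diffusion_term, ...)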
```

Example

```
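import jax.numpy as jnp
from diffrax import ControlTerm, LinearInterpolation, diffeqsolve

# Observation times, and the two-channel data observed at each time.
ts = jnp.array([1., 2., 2.5, 3.])
data = jnp.array([[0.1, 2.0],
                  [0.3, 1.5],
                  [1.0, 1.6],
                  [0.2, 1.1]])
# The control is the piecewise linear path through the data.
control = LinearInterpolation(ts, data)
vector_field = lambda t, y, args: jnp.stack([y, y], axis=-1)
cde_term = ControlTerm(vector_field, control)
diffeqsolve(cde_term, ...)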
```

##### `__init__(self, vector_field: Callable[[Scalar, PyTree, PyTree], PyTree], control: AbstractPath)`

**Arguments:**

- `vector_field`: A callable representing the vector field. This callable takes three arguments `(t, y, args)`. `t` is a scalar representing the integration time. `y` is the evolving state of the system. `args` are any static arguments as passed to `diffrax.diffeqsolve`.
- `control`: A callable representing the control. Should have an `evaluate(t0, t1)` method. If using `diffrax.ControlTerm.to_ode` then it should have a `derivative(t)` method.

##### `to_ode(self) -> ODETerm`

If the control is differentiable then \(f(t, y(t), args) \mathrm{d}x(t)\) may be thought of as an ODE as

\(f(t, y(t), args) \frac{\mathrm{d}x}{\mathrm{d}t}\mathrm{d}t\).

This method converts this `ControlTerm` into the corresponding `diffrax.ODETerm` in this way.
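The identity behind this conversion can be checked numerically for a control that is linear in time, where the increment \(x(t_1) - x(t_0)\) equals \(\frac{\mathrm{d}x}{\mathrm{d}t}(t_1 - t_0)\) exactly. The free functions `evaluate`, `derivative`, and `vector_field` below are hypothetical stand-ins written in plain NumPy; they mirror the method names mentioned above but are not Diffrax objects.

```python
import numpy as np

slope = np.array([1.0, -2.0])

def evaluate(t0, t1):
    # Increment x(t1) - x(t0) of the hypothetical control x(t) = slope * t.
    return slope * (t1 - t0)

def derivative(t):
    # dx/dt, constant for a linear control.
    return slope

def vector_field(t, y, args):
    # Matrix-valued vector field of shape (len(y), 2).
    return np.stack([y, 2.0 * y], axis=-1)

t0, t1 = 0.0, 0.25
y = np.array([1.0, 3.0])

# Control-term view: f(t, y) times the increment of x.
as_control_term = vector_field(t0, y, None) @ evaluate(t0, t1)
# ODE-term view: (f(t, y) dx/dt) times dt.
as_ode_term = (vector_field(t0, y, None) @ derivative(t0)) * (t1 - t0)
```

For a general differentiable control the two views agree only in the limit of small steps; the linear case is used here because it makes them agree exactly.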

#### `diffrax.WeaklyDiagonalControlTerm (AbstractTerm)`

A term representing the case of \(f(t, y(t), args) \mathrm{d}x(t)\), in which the vector field - control interaction is a matrix-vector product, and the matrix is square and diagonal. In this case we may represent the matrix as a vector of just its diagonal elements. The matrix-vector product may be calculated by pointwise multiplying this vector with the control; this is more computationally efficient than writing out the full matrix and then doing a full matrix-vector product.

Correspondingly, `vector_field` and `control` should both return PyTrees, and both should have the same structure and leaf shape as the initial state `y0`. These are multiplied together pointwise.
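A small NumPy check of the equivalence between the pointwise product and the full diagonal matrix-vector product, using made-up values. The variable names are hypothetical and this sketch is independent of Diffrax.

```python
import numpy as np

diag = np.array([2.0, -1.0, 0.5])      # vector field output: just the diagonal
control = np.array([0.1, 0.2, 0.3])    # control increment, same shape as y0

pointwise = diag * control             # O(n) work, no matrix is formed
full = np.diag(diag) @ control         # O(n^2) work on a mostly-zero matrix
```

Both expressions give the same vector; the pointwise form avoids building and multiplying by the dense matrix.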

Info

Why "weakly" diagonal? Consider the matrix representation of the vector field, as a square diagonal matrix. In general, the (i,i)-th element may depend upon any of the values of `y`. It is only if the (i,i)-th element depends only upon the i-th element of `y` that the vector field is said to be "diagonal", without the "weak". (This stronger property is useful in some SDE solvers.)

##### `__init__(self, vector_field: Callable[[Scalar, PyTree, PyTree], PyTree], control: AbstractPath)`

**Arguments:**

- `vector_field`: A callable representing the vector field. This callable takes three arguments `(t, y, args)`. `t` is a scalar representing the integration time. `y` is the evolving state of the system. `args` are any static arguments as passed to `diffrax.diffeqsolve`.
- `control`: A callable representing the control. Should have an `evaluate(t0, t1)` method. If using `diffrax.ControlTerm.to_ode` then it should have a `derivative(t)` method.

##### `to_ode(self) -> ODETerm`

If the control is differentiable then \(f(t, y(t), args) \mathrm{d}x(t)\) may be thought of as an ODE as

\(f(t, y(t), args) \frac{\mathrm{d}x}{\mathrm{d}t}\mathrm{d}t\).

This method converts this `ControlTerm` into the corresponding `diffrax.ODETerm` in this way.

#### `diffrax.MultiTerm (AbstractTerm)`

Accumulates multiple terms into a single term.

Consider the SDE

\(\mathrm{d}y(t) = f(t, y(t))\mathrm{d}t + g(t, y(t))\mathrm{d}w(t)\)

This has two terms on the right hand side. It may be represented with a single term as

\(\mathrm{d}y(t) = [f(t, y(t)), g(t, y(t))] \cdot [\mathrm{d}t, \mathrm{d}w(t)]\)

whose vector field -- control interaction is a dot product.

`MultiTerm` performs this transform. For simplicity most differential equation solvers (at least those built-in to Diffrax) accept just a single term, so this transform is a necessary part of e.g. solving an SDE with both drift and diffusion.
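The "dot product" of vector fields against controls can be sketched in plain NumPy as a sum of per-term products. The functions `multi_vf`, `multi_contr`, and `multi_prod` are hypothetical names for this illustration, and passing the Brownian increment `dw` in by hand is an assumption made to keep the sketch deterministic; none of this is the `MultiTerm` implementation.

```python
import numpy as np

def multi_vf(t, y, args):
    # Tuple of vector fields: (drift f, diffusion g).
    return (-y, np.stack([y, y], axis=-1))

def multi_contr(t0, t1, dw):
    # Tuple of controls: (dt, dw), with dw supplied by the caller.
    return (t1 - t0, dw)

def multi_prod(vfs, controls):
    # Dot product over terms: f * dt + g @ dw.
    f, g = vfs
    dt, dw = controls
    return f * dt + g @ dw

y = np.array([1.0, 2.0])
controls = multi_contr(0.0, 0.5, dw=np.array([0.1, 0.2]))
increment = multi_prod(multi_vf(0.0, y, None), controls)
```

From the solver's point of view this behaves like a single term whose vector field and control are both tuples.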

##### `__init__(self, *terms)`

**Arguments:**

- `*terms`: Any number of `diffrax.AbstractTerm`s to combine.