# Interpolations

When solving controlled differential equations, it is relatively common for the control to be an interpolation of discrete data.

The following interpolation routines may be used to perform this interpolation.

Note

Missing data, represented as `NaN`, can be handled here as well. (And if you are familiar with the problem of informative missingness, note that this can be handled as well: see Sections 3.5 and 3.6 of this paper.)

## References

The main two references for using interpolation with controlled differential equations are as follows.

Original neural CDE paper:

```
@article{kidger2020neuralcde,
author={Kidger, Patrick and Morrill, James and Foster, James and Lyons, Terry},
title={{N}eural {C}ontrolled {D}ifferential {E}quations for {I}rregular {T}ime {S}eries},
journal={Neural Information Processing Systems},
year={2020},
}
```

Investigating specifically the choice of interpolation scheme for CDEs:

```
@article{morrill2021cdeonline,
title={{N}eural {C}ontrolled {D}ifferential {E}quations for {O}nline {P}rediction {T}asks},
author={Morrill, James and Kidger, Patrick and Yang, Lingyi and Lyons, Terry},
journal={arXiv:2106.11028},
year={2021}
}
```

## How to pick an interpolation scheme

There are a few main types of interpolation provided here. For almost all applications you will want either rectilinear or cubic interpolation, chosen as follows.

- Do you need to make online predictions at inference time?
    - Yes: Do you need to make a prediction continuously, or just every time you get the next piece of data?
        - Continuously: Use rectilinear interpolation.
        - At data: Might there be missing values in the data?
            - Yes: Use rectilinear interpolation.
            - No: Use Hermite cubic splines with backward differences.
    - No: Use Hermite cubic splines with backward differences.

Rectilinear interpolation can be obtained by combining `diffrax.rectilinear_interpolation` and `diffrax.LinearInterpolation`.

Hermite cubic splines with backward differences can be obtained by combining `diffrax.backward_hermite_coefficients` and `diffrax.CubicInterpolation`.
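The decision tree above can be written as a small function. This is purely illustrative. `choose_interpolation` and `ROUTINES` are hypothetical names and are not part of diffrax. The strings only name the routines mentioned in this section.

```python
def choose_interpolation(online: bool, continuous: bool = False, missing: bool = False) -> str:
    """Encode the decision tree above (hypothetical helper, for illustration only)."""
    if online and (continuous or missing):
        # Online predictions made continuously, or made at data that may be missing.
        return "rectilinear"
    # Offline predictions, or online predictions at fully observed data.
    return "hermite_cubic"


# The diffrax routines to combine for each scheme, as listed above.
ROUTINES = {
    "rectilinear": ("diffrax.rectilinear_interpolation", "diffrax.LinearInterpolation"),
    "hermite_cubic": ("diffrax.backward_hermite_coefficients", "diffrax.CubicInterpolation"),
}
```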

## Interpolation classes

The following are the main interpolation classes. Instances of these classes are suitable controls to pass to `diffrax.ControlTerm`.

#### `diffrax.LinearInterpolation (AbstractPath)`

Linearly interpolates some data `ys` over the interval \([t_0, t_1]\) with knots at `ts`.

Warning

If using `LinearInterpolation` as part of a `diffrax.ControlTerm`, then the vector field will make a jump every time one of the knots `ts` is passed. If using an adaptive step size controller such as `diffrax.PIDController`, then this means the controller should be informed about the jumps, so that it can handle them appropriately:

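```
from diffrax import ControlTerm, LinearInterpolation, PIDController

ts = ...
interp = LinearInterpolation(ts=ts, ys=...)
term = ControlTerm(..., control=interp)
# Pass the knots as jump times, so the controller knows where the discontinuities are.
stepsize_controller = PIDController(..., jump_ts=ts)
```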

##### `t0: Scalar`

`dataclass-field`

##### `t1: Scalar`

`dataclass-field`

##### `__init__(self, ts: Array['times'], ys: PyTree[Array['times', ...]])`

**Arguments:**

- `ts`: Some increasing collection of times.
- `ys`: The value of the data at those times.

Note that if `ys` has any missing data then you may wish to use `diffrax.linear_interpolation` or `diffrax.rectilinear_interpolation` first to interpolate over these.

##### `evaluate(self, t0: Scalar, t1: Optional[Scalar] = None, left: bool = True) -> PyTree`

Evaluate the linear interpolation.

**Arguments:**

- `t0`: Any point in \([t_0, t_1]\) to evaluate the interpolation at.
- `t1`: If passed, then the increment from `t0` to `t1` (the value at `t1` minus the value at `t0`) is evaluated instead.
- `left`: Across jump points: whether to treat the path as left-continuous or right-continuous. [In practice linear interpolation is always continuous except around `NaN`s.]

FAQ

Note that we use \(t_0\) and \(t_1\) to refer to the overall interval, as obtained via `instance.t0` and `instance.t1`. We use `t0` and `t1` to refer to some subinterval of \([t_0, t_1]\). This is an API that is used for consistency with the rest of the package, and just happens to be a little confusing here.

**Returns:**

If `t1` is not passed:

The interpolation of the data. Suppose \(t_j < t < t_{j+1}\), where \(t\) is `t0`, \(t_j\) and \(t_{j+1}\) are consecutive elements of `ts` as passed in `__init__`, and \(y_j\) and \(y_{j+1}\) are the corresponding entries of `ys`. Then the value returned is
\(y_j + (y_{j+1} - y_j)\frac{t - t_j}{t_{j+1} - t_j}\).

If `t1` is passed:

As above, with \(t\) taken to be both `t0` and `t1`, and the increment between them returned.
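The formula above can be checked with a minimal plain-Python sketch for scalar data. This is not diffrax's implementation, which works on JAX arrays and PyTrees. `linear_evaluate` is a hypothetical helper, and the sketch assumes `ts` is strictly increasing.

```python
from bisect import bisect_right
from typing import Optional, Sequence


def linear_evaluate(
    ts: Sequence[float], ys: Sequence[float], t0: float, t1: Optional[float] = None
) -> float:
    """Piecewise-linear interpolation of scalar `ys` with knots `ts` (sketch).

    If `t1` is given, return the increment: value at `t1` minus value at `t0`.
    """
    if t1 is not None:
        return linear_evaluate(ts, ys, t1) - linear_evaluate(ts, ys, t0)
    if not ts[0] <= t0 <= ts[-1]:
        raise ValueError("t0 must lie in [ts[0], ts[-1]]")
    # Index j with ts[j] <= t0 <= ts[j + 1]; t0 == ts[-1] uses the last interval.
    j = min(bisect_right(ts, t0) - 1, len(ts) - 2)
    frac = (t0 - ts[j]) / (ts[j + 1] - ts[j])
    return ys[j] + (ys[j + 1] - ys[j]) * frac
```

For example, with `ts = [0., 1., 3.]` and `ys = [0., 2., 6.]`, the value at `2.0` is `4.0`, and the increment from `0.5` to `2.0` is `4.0 - 1.0 = 3.0`.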

##### `derivative(self, t: Scalar, left: bool = True) -> PyTree`

Evaluate the derivative of the linear interpolation. Essentially equivalent to `jax.jvp(self.evaluate, (t,), (jnp.ones_like(t),))`.

**Arguments:**

- `t`: Any point in \([t_0, t_1]\) to evaluate the derivative at.
- `left`: Whether to obtain the left-derivative or right-derivative at that point.

**Returns:**

The derivative of the interpolation of the data. Suppose \(t_j < t < t_{j+1}\), where \(t_j\) and \(t_{j+1}\) are consecutive elements of `ts` passed in `__init__`, and \(y_j\) and \(y_{j+1}\) are the corresponding entries of `ys`. Then the value returned is \(\frac{y_{j+1} - y_j}{t_{j+1} - t_j}\).
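The left/right distinction only matters at the knots, where the slope jumps. Below is a minimal plain-Python sketch, again for scalar data and strictly increasing `ts`. `linear_derivative` is a hypothetical helper, not diffrax's implementation.

```python
from bisect import bisect_left, bisect_right
from typing import Sequence


def linear_derivative(ts: Sequence[float], ys: Sequence[float], t: float, left: bool = True) -> float:
    """Slope of the piecewise-linear interpolant at `t` (sketch).

    Assumes `ts` is strictly increasing and ts[0] <= t <= ts[-1].
    """
    if left:
        # Interval (ts[j], ts[j + 1]] containing t; at ts[0] fall back to the first interval.
        j = max(bisect_left(ts, t) - 1, 0)
    else:
        # Interval [ts[j], ts[j + 1]) containing t; at ts[-1] fall back to the last interval.
        j = min(bisect_right(ts, t) - 1, len(ts) - 2)
    return (ys[j + 1] - ys[j]) / (ts[j + 1] - ts[j])
```

With `ts = [0., 1., 3.]` and `ys = [0., 2., 3.]`, the knot `t = 1.0` has left-derivative `2.0` and right-derivative `0.5`.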

#### `diffrax.CubicInterpolation (AbstractPath)`

Piecewise cubic spline interpolation over the interval \([t_0, t_1]\).

##### `t0: Scalar`

`dataclass-field`

##### `t1: Scalar`

`dataclass-field`

##### `__init__(self, ts: Array['times'], coeffs: Tuple[PyTree['times - 1', ...], PyTree['times - 1', ...], PyTree['times - 1', ...], PyTree['times - 1', ...]])`

**Arguments:**

- `ts`: Some increasing collection of times.
- `coeffs`: The coefficients at all but the last time.

Any kind of spline (natural, ...) may be used; simply pass the appropriate coefficients.

In practice a good choice is typically "cubic Hermite splines with backward differences", introduced in Morrill et al. (2021), *Neural Controlled Differential Equations for Online Prediction Tasks* (see References above). Such coefficients can be obtained using `diffrax.backward_hermite_coefficients`.

Letting `d, c, b, a = coeffs`, then for all `t` in the interval from `ts[i]` to `ts[i + 1]` the interpolation is defined as

```
d[i] * (t - ts[i]) ** 3 + c[i] * (t - ts[i]) ** 2 + b[i] * (t - ts[i]) + a[i]
```
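For scalar data and a single time `t`, the definition above can be sketched in plain Python. `cubic_evaluate` is a hypothetical helper, not diffrax's implementation, and it assumes `ts` is strictly increasing. The increment form of `evaluate` would then be `cubic_evaluate(ts, coeffs, t1) - cubic_evaluate(ts, coeffs, t0)`.

```python
from bisect import bisect_right
from typing import Sequence, Tuple

# (d, c, b, a), each with one entry per interval, i.e. len(ts) - 1 entries.
Coeffs = Tuple[Sequence[float], Sequence[float], Sequence[float], Sequence[float]]


def cubic_evaluate(ts: Sequence[float], coeffs: Coeffs, t: float) -> float:
    """Evaluate the piecewise cubic at a single time `t` (sketch)."""
    d, c, b, a = coeffs
    # Piece i with ts[i] <= t < ts[i + 1]; t == ts[-1] uses the last piece.
    i = min(bisect_right(ts, t) - 1, len(ts) - 2)
    s = t - ts[i]
    return d[i] * s**3 + c[i] * s**2 + b[i] * s + a[i]


# The cubic t**3 on [0, 2], split into two pieces at t = 1:
# first piece s**3, second piece (s + 1)**3 = s**3 + 3 s**2 + 3 s + 1.
ts = [0.0, 1.0, 2.0]
coeffs = ([1.0, 1.0], [0.0, 3.0], [0.0, 3.0], [0.0, 1.0])
value = cubic_evaluate(ts, coeffs, 1.5)  # 1.5**3 = 3.375
```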

##### `evaluate(self, t0: Scalar, t1: Optional[Scalar] = None, left: bool = True) -> PyTree`

Evaluate the cubic interpolation.

**Arguments:**

- `t0`: Any point in \([t_0, t_1]\) to evaluate the interpolation at.
- `t1`: If passed, then the increment from `t0` to `t1` (the value at `t1` minus the value at `t0`) is evaluated instead.
- `left`: Across jump points: whether to treat the path as left-continuous or right-continuous. [In practice cubic interpolation is always continuous except around `NaN`s.]

FAQ

Note that we use \(t_0\) and \(t_1\) to refer to the overall interval, as obtained via `instance.t0` and `instance.t1`. We use `t0` and `t1` to refer to some subinterval of \([t_0, t_1]\). This is an API that is used for consistency with the rest of the package, and just happens to be a little confusing here.

**Returns:**

If `t1` is not passed:

The interpolation of the data at `t0`.

If `t1` is passed:

The increment between `t0` and `t1`.

##### `derivative(self, t: Scalar, left: bool = True) -> PyTree`

Evaluate the derivative of the cubic interpolation. Essentially equivalent to `jax.jvp(self.evaluate, (t,), (jnp.ones_like(t),))`.

**Arguments:**

- `t`: Any point in \([t_0, t_1]\) to evaluate the derivative at.
- `left`: Whether to obtain the left-derivative or right-derivative at that point. [In practice cubic interpolation is always continuously differentiable except around `NaN`s.]

**Returns:**

The derivative of the interpolation of the data.
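Differentiating the piecewise definition given under `__init__` gives \(3 d_i s^2 + 2 c_i s + b_i\), with \(s = t - t_i\), on each piece. Below is a plain-Python sketch for scalar data. `cubic_derivative` is a hypothetical helper, not diffrax's implementation, and `left` selects which piece is used at a knot.

```python
from bisect import bisect_left, bisect_right
from typing import Sequence, Tuple

# (d, c, b, a), each with one entry per interval, i.e. len(ts) - 1 entries.
Coeffs = Tuple[Sequence[float], Sequence[float], Sequence[float], Sequence[float]]


def cubic_derivative(ts: Sequence[float], coeffs: Coeffs, t: float, left: bool = True) -> float:
    """Derivative of the piecewise cubic at `t` (sketch; `ts` strictly increasing)."""
    d, c, b, _ = coeffs  # the constant term a does not affect the derivative
    if left:
        # Piece ending at (or containing) t; at ts[0] fall back to the first piece.
        i = max(bisect_left(ts, t) - 1, 0)
    else:
        # Piece starting at (or containing) t; at ts[-1] fall back to the last piece.
        i = min(bisect_right(ts, t) - 1, len(ts) - 2)
    s = t - ts[i]
    return 3 * d[i] * s**2 + 2 * c[i] * s + b[i]
```

For the cubic \(t^3\) split at \(t = 1\) (coefficients `([1., 1.], [0., 3.], [0., 3.], [0., 1.])` with `ts = [0., 1., 2.]`), both one-sided derivatives at the knot equal `3.0`, as expected for a continuously differentiable spline.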

## Handling missing data

We would like `diffrax.LinearInterpolation` to be able to handle missing data (represented as `NaN`). The following can be used for this purpose.

#### `diffrax.linear_interpolation(ts: Array['times'], ys: PyTree['times', ...], *, fill_forward_nans_at_end: bool = False, replace_nans_at_start: Optional[PyTree[...]] = None) -> PyTree['times', ...]`

Fill in any missing values via linear interpolation.

Any missing values in `ys` (represented as `NaN`) are filled in by looking at the nearest non-`NaN` values either side, and linearly interpolating.

This is often useful prior to using `diffrax.LinearInterpolation` to create a continuous path from discrete observations.

**Arguments:**

- `ts`: The time of each observation.
- `ys`: The observations themselves. Should use `NaN` to indicate those missing observations to interpolate over.
- `fill_forward_nans_at_end`: By default `NaN` values at the end (with no non-`NaN` value after them) are left as `NaN`s. If this is set then they will instead be filled in using the last non-`NaN` value.
- `replace_nans_at_start`: By default `NaN` values at the start (with no non-`NaN` value before them) are left as `NaN`s. If this is passed then it will be used to fill in such `NaN` values.

**Returns:**

As `ys`, but with `NaN` values filled in.
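The filling rule can be sketched in plain Python for a 1-D series. This is an illustration, not diffrax's implementation. `fill_nans_linear` is a hypothetical helper. `replace_nans_at_start` is deliberately left out, because the text above does not say exactly how the replacement value is combined with later observations.

```python
import math
from typing import List, Sequence


def fill_nans_linear(
    ts: Sequence[float],
    ys: Sequence[float],
    *,
    fill_forward_nans_at_end: bool = False,
) -> List[float]:
    """Fill NaNs in a 1-D series by linear interpolation (sketch).

    Leading NaNs (no observation before them) are left as NaN. Trailing NaNs
    are left as NaN unless `fill_forward_nans_at_end` is set, in which case
    they take the last observed value.
    """
    out = list(ys)
    observed = [k for k, y in enumerate(ys) if not math.isnan(y)]
    # Interior NaNs: interpolate between the nearest observations on either side.
    for lo, hi in zip(observed, observed[1:]):
        for k in range(lo + 1, hi):
            frac = (ts[k] - ts[lo]) / (ts[hi] - ts[lo])
            out[k] = ys[lo] + (ys[hi] - ys[lo]) * frac
    # Trailing NaNs: optionally carry the last observation forward.
    if fill_forward_nans_at_end and observed:
        for k in range(observed[-1] + 1, len(out)):
            out[k] = ys[observed[-1]]
    return out
```

Note that the interpolation weights use the observation times `ts`, not the indices. So with irregular sampling, a gap is filled proportionally to elapsed time.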

#### `diffrax.rectilinear_interpolation(ts: Array['times'], ys: PyTree['times', ...], replace_nans_at_start: Optional[PyTree[...]] = None) -> Tuple[Array['2 * times - 1'], PyTree['2 * times - 1', ...]]`

Rectilinearly interpolates the input. This is a variant of linear interpolation that is particularly useful when using neural CDEs in a real-time scenario.

This is often useful prior to using `diffrax.LinearInterpolation` to create a continuous path from discrete observations, in real-time scenarios.

If you are unfamiliar with rectilinear interpolation, it is strongly recommended to read the reference below.

Reference

```
@article{morrill2021cdeonline,
title={{N}eural {C}ontrolled {D}ifferential {E}quations for {O}nline
{P}rediction {T}asks},
author={Morrill, James and Kidger, Patrick and Yang, Lingyi and
Lyons, Terry},
journal={arXiv:2106.11028},
year={2021}
}
```

Example

Suppose `ts = [t0, t1, t2, t3]` and `ys = [y0, y1, y2, y3]`. Then rectilinearly interpolating these produces `new_ts = [t0, t1, t1, t2, t2, t3, t3]` and `new_ys = [y0, y0, y1, y1, y2, y2, y3]`.

This can be thought of as advancing time whilst keeping the data fixed, then keeping time fixed whilst advancing the data.
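The construction in the example can be sketched in plain Python for fully observed 1-D data, with no `NaN` handling. `rectilinear` is a hypothetical helper, not diffrax's implementation.

```python
from typing import List, Sequence, Tuple


def rectilinear(ts: Sequence[float], ys: Sequence[float]) -> Tuple[List[float], List[float]]:
    """Rectilinear interpolation of fully observed 1-D data (sketch).

    Produces 2 * len(ts) - 1 points. Each step first advances time with the
    data held fixed, then holds time fixed while the data moves to its new value.
    """
    new_ts, new_ys = [ts[0]], [ys[0]]
    for k in range(1, len(ts)):
        new_ts += [ts[k], ts[k]]      # time advances to ts[k], then stays there...
        new_ys += [ys[k - 1], ys[k]]  # ...while the data first holds, then updates
    return new_ts, new_ys
```

Applied to the data from the neural CDE example below (`ts = [0., 1., 1.5, 2.]`, `ys = [5., 6., 5., 6.]`), this gives `[0., 1., 1., 1.5, 1.5, 2., 2.]` and `[5., 5., 6., 6., 5., 5., 6.]`.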

**Arguments:**

- `ts`: The time of each observation.
- `ys`: The observations themselves. Should use `NaN` to indicate those missing observations to interpolate over.
- `replace_nans_at_start`: By default `NaN` values at the start (with no non-`NaN` value before them) are left as `NaN`s. If this is passed then it will be used to fill in such `NaN` values.

**Returns:**

A new version of both `ts` and `ys`, subject to rectilinear interpolation.

Example

Suppose we wish to use a rectilinearly interpolated control to drive a neural CDE. This can be done along the following lines:

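```
import jax.numpy as jnp
from diffrax import LinearInterpolation, rectilinear_interpolation

ts = jnp.array([0., 1., 1.5, 2.])
ys = jnp.array([5., 6., 5., 6.])
# Both outputs have 2 * 4 - 1 = 7 entries.
ts, ys = rectilinear_interpolation(ts, ys)
# Time and observations become the channels of the control.
data = jnp.stack([ts, ys], axis=-1)
# The knots of the path are our own choice: here, one per rectilinear point.
interp_ts = jnp.arange(7)
interp = LinearInterpolation(interp_ts, data)
```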

Note how time and observations are stacked together as the data of the interpolation (as usual for a neural CDE), and how the interpolation times are something we are free to pick.

## Calculating coefficients

#### `diffrax.backward_hermite_coefficients(ts: Array['times'], ys: PyTree['times', ...], *, deriv0: Optional[PyTree[...]] = None, fill_forward_nans_at_end: bool = False, replace_nans_at_start: Optional[PyTree[...]] = None) -> Tuple[PyTree['times - 1', ...], PyTree['times - 1', ...], PyTree['times - 1', ...], PyTree['times - 1', ...]]`

Interpolates the data with a cubic spline. Specifically, this calculates the coefficients for Hermite cubic splines with backward differences.

This is most useful prior to using `diffrax.CubicInterpolation` to create a smooth path from discrete observations.

Reference

Hermite cubic splines with backward differences were introduced in this paper:

```
@article{morrill2021cdeonline,
title={{N}eural {C}ontrolled {D}ifferential {E}quations for {O}nline
{P}rediction {T}asks},
author={Morrill, James and Kidger, Patrick and Yang, Lingyi and
Lyons, Terry},
journal={arXiv:2106.11028},
year={2021}
}
```

**Arguments:**

- `ts`: The time of each observation.
- `ys`: The observations themselves. Should use `NaN` to indicate missing data.
- `deriv0`: The derivative at `ts[0]`. If not passed then a forward difference of `(ys[i] - ys[0]) / (ts[i] - ts[0])` is used, where `i` is the index of the first non-`NaN` element of `ys` after `ys[0]`.
- `fill_forward_nans_at_end`: By default `NaN` values at the end (with no non-`NaN` value after them) are left as `NaN`s. If this is set then they will instead be filled in using the last non-`NaN` value prior to fitting the cubic spline.
- `replace_nans_at_start`: By default `NaN` values at the start (with no non-`NaN` value before them) are left as `NaN`s. If this is passed then it will be used to fill in such `NaN` values.

**Returns:**

The coefficients of the Hermite cubic spline. If `ts` has length \(T\) then the coefficients will be of length \(T - 1\), covering each of the intervals from `ts[0]` to `ts[1]`, and `ts[1]` to `ts[2]` etc.
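Below is a minimal plain-Python sketch of the scheme, assuming scalar, fully observed data with no `NaN` handling. The derivative at each knot \(t_{i+1}\) is the backward difference \((y_{i+1} - y_i)/(t_{i+1} - t_i)\). The derivative at `ts[0]` is `deriv0`, which defaults to the first forward difference. Each piece is the unique cubic matching its two end values and two end derivatives, written in the `(d, c, b, a)` convention used by `diffrax.CubicInterpolation`. `backward_hermite` is a hypothetical helper that illustrates the scheme; it is not diffrax's code.

```python
from typing import List, Optional, Sequence, Tuple


def backward_hermite(
    ts: Sequence[float], ys: Sequence[float], deriv0: Optional[float] = None
) -> Tuple[List[float], List[float], List[float], List[float]]:
    """Hermite cubic coefficients with backward differences (sketch).

    Returns (d, c, b, a), each of length len(ts) - 1, such that on
    [ts[i], ts[i + 1]] the spline is d[i]*s**3 + c[i]*s**2 + b[i]*s + a[i]
    with s = t - ts[i].
    """
    n = len(ts) - 1
    # Secant slope of each interval = backward difference at its right-hand knot.
    slopes = [(ys[i + 1] - ys[i]) / (ts[i + 1] - ts[i]) for i in range(n)]
    # Derivative at ts[0]: default to the first forward difference.
    start_deriv = slopes[0] if deriv0 is None else deriv0
    d, c, b, a = [], [], [], []
    for i in range(n):
        h = ts[i + 1] - ts[i]
        gap = slopes[i] - start_deriv
        # Cubic with value ys[i] and slope start_deriv at s = 0,
        # and value ys[i + 1] and slope slopes[i] at s = h.
        a.append(ys[i])
        b.append(start_deriv)
        c.append(2 * gap / h)
        d.append(-gap / h**2)
        # The next piece starts with the slope this piece ends with.
        start_deriv = slopes[i]
    return d, c, b, a
```

Each piece ends with a slope equal to the backward difference at its right knot, and the next piece starts with that same slope. The resulting spline is therefore continuously differentiable. When `deriv0` is left at its default, the first piece is simply linear.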