Interpolations

When solving controlled differential equations, it is relatively common for the control to be an interpolation of discrete data.

The following interpolation routines may be used to perform this interpolation.

Note

Missing data, represented as NaN, can also be handled here. (If you are familiar with the problem of informative missingness, note that this can be handled too: see Sections 3.5 and 3.6 of this paper.)

References

The two main references for using interpolation with controlled differential equations are as follows.

Original neural CDE paper:

@article{kidger2020neuralcde,
        author={Kidger, Patrick and Morrill, James and Foster, James and Lyons, Terry},
        title={{N}eural {C}ontrolled {D}ifferential {E}quations for {I}rregular {T}ime {S}eries},
        journal={Neural Information Processing Systems},
        year={2020},
}

Investigating specifically the choice of interpolation scheme for CDEs:

@article{morrill2021cdeonline,
        title={{N}eural {C}ontrolled {D}ifferential {E}quations for {O}nline {P}rediction {T}asks},
        author={Morrill, James and Kidger, Patrick and Yang, Lingyi and Lyons, Terry},
        journal={arXiv:2106.11028},
        year={2021}
}

How to pick an interpolation scheme

There are a few main types of interpolation provided here. For 99% of applications you will want either rectilinear or cubic interpolation, as follows.

  • Do you need to make online predictions at inference time?
    • Yes: Do you need to make a prediction continuously, or just every time you get the next piece of data?
      • Continuously: Use rectilinear interpolation.
      • At data: Might there be missing values in the data?
        • Yes: Use rectilinear interpolation.
        • No: Use Hermite cubic splines with backward differences.
    • No: Use Hermite cubic splines with backward differences.

Rectilinear interpolation can be obtained by combining diffrax.rectilinear_interpolation and diffrax.LinearInterpolation.

Hermite cubic splines with backward differences can be obtained by combining diffrax.backward_hermite_coefficients and diffrax.CubicInterpolation.


Interpolation classes

The following are the main interpolation classes. Instances of these classes are suitable controls to pass to diffrax.ControlTerm.

diffrax.LinearInterpolation (AbstractPath)

Linearly interpolates some data ys over the interval \([t_0, t_1]\) with knots at ts.

Warning

If using LinearInterpolation as part of a diffrax.ControlTerm, then the vector field will make a jump every time one of the knots ts is passed. If using an adaptive step size controller such as diffrax.PIDController, then this means the controller should be informed about the jumps, so that it can handle them appropriately:

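from diffrax import ControlTerm, LinearInterpolation, PIDController

ts = ...  # knot times of the interpolation
interp = LinearInterpolation(ts=ts, ...)
term = ControlTerm(..., control=interp)
# Tell the adaptive controller where the vector field jumps.
stepsize_controller = PIDController(..., jump_ts=ts)

t0: Scalar dataclass-field
t1: Scalar dataclass-field
__init__(self, ts: Array['times'], ys: PyTree[Array['times', ...]])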

Arguments:

  • ts: Some increasing collection of times.
  • ys: The value of the data at those times.

Note that if ys has any missing data then you may wish to use diffrax.linear_interpolation or diffrax.rectilinear_interpolation first to interpolate over these.

evaluate(self, t0: Scalar, t1: Optional[Scalar] = None, left: bool = True) -> PyTree

Evaluate the linear interpolation.

Arguments:

  • t0: Any point in \([t_0, t_1]\) to evaluate the interpolation at.
  • t1: If passed, then the increment from t0 to t1 is evaluated instead.
  • left: Across jump points: whether to treat the path as left-continuous or right-continuous. [In practice linear interpolation is always continuous except around NaNs.]

FAQ

Note that we use \(t_0\) and \(t_1\) to refer to the overall interval, as obtained via instance.t0 and instance.t1. We use t0 and t1 to refer to some subinterval of \([t_0, t_1]\). This is an API that is used for consistency with the rest of the package, and just happens to be a little confusing here.

Returns:

If t1 is not passed:

The interpolation of the data. Suppose \(t_j < t < t_{j+1}\), where \(t\) is t0 and \(t_j\) and \(t_{j+1}\) are consecutive elements of ts as passed in __init__, with corresponding values \(y_j\) and \(y_{j+1}\) in ys. Then the value returned is \(y_j + (y_{j+1} - y_j)\frac{t - t_j}{t_{j+1} - t_j}\).

If t1 is passed:

As above, but the interpolation is evaluated at both t0 and t1, and the increment between the two values is returned.

derivative(self, t: Scalar, left: bool = True) -> PyTree

Evaluate the derivative of the linear interpolation. Essentially equivalent to jax.jvp(self.evaluate, (t,), (jnp.ones_like(t),)).

Arguments:

  • t: Any point in \([t_0, t_1]\) to evaluate the derivative at.
  • left: Whether to obtain the left-derivative or right-derivative at that point.

Returns:

The derivative of the interpolation of the data. Suppose \(t_j < t < t_{j+1}\), where \(t_j\) and \(t_{j+1}\) are consecutive elements of ts passed in __init__, with corresponding values \(y_j\) and \(y_{j+1}\) in ys. Then the value returned is \(\frac{y_{j+1} - y_j}{t_{j+1} - t_j}\).

diffrax.CubicInterpolation (AbstractPath)

Piecewise cubic spline interpolation over the interval \([t_0, t_1]\).

t0: Scalar dataclass-field
t1: Scalar dataclass-field
__init__(self, ts: Array['times'], coeffs: Tuple[PyTree['times - 1', ...], PyTree['times - 1', ...], PyTree['times - 1', ...], PyTree['times - 1', ...]])

Arguments:

  • ts: Some increasing collection of times.
  • coeffs: The coefficients at all but the last time.

Any kind of spline (natural, ...) may be used; simply pass the appropriate coefficients.

In practice a good choice is typically "cubic Hermite splines with backward differences", introduced in Morrill et al. (2021), "Neural Controlled Differential Equations for Online Prediction Tasks". Such coefficients can be obtained using diffrax.backward_hermite_coefficients.

Letting d, c, b, a = coeffs, then for all t in the interval from ts[i] to ts[i + 1] the interpolation is defined as

d[i] * (t - ts[i]) ** 3 + c[i] * (t - ts[i]) ** 2 + b[i] * (t - ts[i]) + a[i]
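To make the coefficient layout concrete, here is a plain NumPy sketch of the formula above. It is not the library's implementation: `eval_piecewise_cubic` is a hypothetical helper, and the coefficients are hand-picked so that both pieces reproduce f(t) = t**3.

```python
import numpy as np

def eval_piecewise_cubic(ts, coeffs, t):
    """Hypothetical helper: evaluate the documented formula at a scalar t."""
    d, c, b, a = coeffs
    # Index i of the interval [ts[i], ts[i + 1]] containing t.
    i = int(np.clip(np.searchsorted(ts, t, side="right") - 1, 0, len(ts) - 2))
    s = t - ts[i]
    return d[i] * s**3 + c[i] * s**2 + b[i] * s + a[i]

ts = np.array([0.0, 1.0, 2.0])
# Each coefficient array has length len(ts) - 1, one entry per interval.
# On [1, 2], with s = t - 1: (s + 1)**3 = s**3 + 3 s**2 + 3 s + 1.
coeffs = (
    np.array([1.0, 1.0]),  # d
    np.array([0.0, 3.0]),  # c
    np.array([0.0, 3.0]),  # b
    np.array([0.0, 1.0]),  # a
)

first_piece = eval_piecewise_cubic(ts, coeffs, 0.5)   # 0.5 ** 3
second_piece = eval_piecewise_cubic(ts, coeffs, 1.5)  # 1.5 ** 3
```

Note that the offset t - ts[i] is measured from the left end of each interval, so a[i] is the value of the spline at ts[i].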

evaluate(self, t0: Scalar, t1: Optional[Scalar] = None, left: bool = True) -> PyTree

Evaluate the cubic interpolation.

Arguments:

  • t0: Any point in \([t_0, t_1]\) to evaluate the interpolation at.
  • t1: If passed, then the increment from t0 to t1 is evaluated instead.
  • left: Across jump points: whether to treat the path as left-continuous or right-continuous. [In practice cubic interpolation is always continuous except around NaNs.]

FAQ

Note that we use \(t_0\) and \(t_1\) to refer to the overall interval, as obtained via instance.t0 and instance.t1. We use t0 and t1 to refer to some subinterval of \([t_0, t_1]\). This is an API that is used for consistency with the rest of the package, and just happens to be a little confusing here.

Returns:

If t1 is not passed:

The interpolation of the data at t0.

If t1 is passed:

The increment between t0 and t1.

derivative(self, t: Scalar, left: bool = True) -> PyTree

Evaluate the derivative of the cubic interpolation. Essentially equivalent to jax.jvp(self.evaluate, (t,), (jnp.ones_like(t),)).

Arguments:

  • t: Any point in \([t_0, t_1]\) to evaluate the derivative at.
  • left: Whether to obtain the left-derivative or right-derivative at that point. [In practice cubic interpolation is always continuously differentiable except around NaNs.]

Returns:

The derivative of the interpolation of the data.


Handling missing data

We would like diffrax.LinearInterpolation to be able to handle missing data (represented as NaN). The following can be used for this purpose.

diffrax.linear_interpolation(ts: Array['times'], ys: PyTree['times', ...], *, fill_forward_nans_at_end: bool = False, replace_nans_at_start: Optional[PyTree[...]] = None) -> PyTree['times', ...]

Fill in any missing values via linear interpolation.

Any missing values in ys (represented as NaN) are filled in by looking at the nearest non-NaN values on either side, and linearly interpolating between them.

This is often useful prior to using diffrax.LinearInterpolation to create a continuous path from discrete observations.

Arguments:

  • ts: The time of each observation.
  • ys: The observations themselves. Should use NaN to indicate those missing observations to interpolate over.
  • fill_forward_nans_at_end: By default NaN values at the end (with no non-NaN value after them) are left as NaNs. If this is set then they will instead be filled in using the last non-NaN value.
  • replace_nans_at_start: By default NaN values at the start (with no non-NaN value before them) are left as NaNs. If this is passed then it will be used to fill in such NaN values.

Returns:

As ys, but with NaN values filled in.
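The filling rule can be sketched in plain NumPy for a single 1D channel. This is an illustration of the documented default behaviour, not the library's code, and `fill_nans_linearly` is a hypothetical helper.

```python
import numpy as np

def fill_nans_linearly(ts, ys):
    """Hypothetical 1D helper: fill interior NaNs by linear interpolation."""
    observed = ~np.isnan(ys)
    # left/right = NaN leaves leading and trailing gaps as NaN, matching
    # the documented defaults (no replace_nans_at_start, no fill-forward).
    return np.interp(ts, ts[observed], ys[observed], left=np.nan, right=np.nan)

ts = np.array([0.0, 1.0, 2.0, 3.0, 5.0, 6.0])
ys = np.array([np.nan, 1.0, np.nan, np.nan, 9.0, np.nan])

filled = fill_nans_linearly(ts, ys)
# Interior gaps use the neighbours (t=1, y=1) and (t=5, y=9):
# t=2 gives 1 + 8 * 1/4 = 3, and t=3 gives 1 + 8 * 2/4 = 5.
```

The NaNs at the first and last positions are left untouched here, which is where fill_forward_nans_at_end and replace_nans_at_start come in.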

diffrax.rectilinear_interpolation(ts: Array['times'], ys: PyTree['times', ...], replace_nans_at_start: Optional[PyTree[...]] = None) -> Tuple[Array['2 * times - 1'], PyTree['2 * times - 1', ...]]

Rectilinearly interpolates the input. This is a variant of linear interpolation that is particularly useful when using neural CDEs in a real-time scenario.

This is often useful prior to using diffrax.LinearInterpolation to create a continuous path from discrete observations, in real-time scenarios.

If you are unfamiliar with rectilinear interpolation, reading the reference below is strongly recommended.

Reference
@article{morrill2021cdeonline,
        title={{N}eural {C}ontrolled {D}ifferential {E}quations for {O}nline
               {P}rediction {T}asks},
        author={Morrill, James and Kidger, Patrick and Yang, Lingyi and
                Lyons, Terry},
        journal={arXiv:2106.11028},
        year={2021}
}

Example

Suppose ts = [t0, t1, t2, t3] and ys = [y0, y1, y2, y3]. Then rectilinearly interpolating these produces new_ts = [t0, t1, t1, t2, t2, t3, t3] and new_ys = [y0, y0, y1, y1, y2, y2, y3].

This can be thought of as advancing time whilst keeping the data fixed, then keeping time fixed whilst advancing the data.
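For data without missing values, the index pattern in this example can be reproduced in plain NumPy. This is only a sketch of the pattern, using made-up numbers; it does not handle NaNs, which the library function does.

```python
import numpy as np

ts = np.array([0.0, 1.0, 1.5, 2.0])
ys = np.array([5.0, 6.0, 5.0, 6.0])

# Duplicate every entry, then drop the first time and the last value,
# so that time and data take turns to advance.
new_ts = np.repeat(ts, 2)[1:]   # [t0, t1, t1, t2, t2, t3, t3]
new_ys = np.repeat(ys, 2)[:-1]  # [y0, y0, y1, y1, y2, y2, y3]
```

Both outputs have length 2 * 4 - 1 = 7, matching the return type in the signature.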

Arguments:

  • ts: The time of each observation.
  • ys: The observations themselves. Should use NaN to indicate those missing observations to interpolate over.
  • replace_nans_at_start: By default NaN values at the start (with no non-NaN value before them) are left as NaNs. If this is passed then it will be used to fill in such NaN values.

Returns:

A new version of both ts and ys, subject to rectilinear interpolation.

Example

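Suppose we wish to use a rectilinearly interpolated control to drive a neural CDE. This can be done along the following lines:

import jax.numpy as jnp
from diffrax import LinearInterpolation, rectilinear_interpolation

ts = jnp.array([0., 1., 1.5, 2.])
ys = jnp.array([5., 6., 5., 6.])
ts, ys = rectilinear_interpolation(ts, ys)  # both now have length 2 * 4 - 1 = 7
data = jnp.stack([ts, ys], axis=-1)  # time becomes a channel of the data
interp_ts = jnp.arange(7)  # interpolation times: ours to choose
interp = LinearInterpolation(interp_ts, data)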

Note how time and observations are stacked together as the data of the interpolation (as usual for a neural CDE), and how the interpolation times are something we are free to pick.


Calculating coefficients

diffrax.backward_hermite_coefficients(ts: Array['times'], ys: PyTree['times', ...], *, deriv0: Optional[PyTree[...]] = None, fill_forward_nans_at_end: bool = False, replace_nans_at_start: Optional[PyTree[...]] = None) -> Tuple[PyTree['times - 1', ...], PyTree['times - 1', ...], PyTree['times - 1', ...], PyTree['times - 1', ...]]

Interpolates the data with a cubic spline. Specifically, this calculates the coefficients for Hermite cubic splines with backward differences.

This is most useful prior to using diffrax.CubicInterpolation to create a smooth path from discrete observations.

Reference

Hermite cubic splines with backward differences were introduced in this paper:

@article{morrill2021cdeonline,
        title={{N}eural {C}ontrolled {D}ifferential {E}quations for {O}nline
               {P}rediction {T}asks},
        author={Morrill, James and Kidger, Patrick and Yang, Lingyi and
                Lyons, Terry},
        journal={arXiv:2106.11028},
        year={2021}
}

Arguments:

  • ts: The time of each observation.
  • ys: The observations themselves. Should use NaN to indicate missing data.
  • deriv0: The derivative at ts[0]. If not passed then a forward difference of (ys[i] - ys[0]) / (ts[i] - ts[0]) is used, where i is the index of the first non-NaN element of ys.
  • fill_forward_nans_at_end: By default NaN values at the end (with no non-NaN value after them) are left as NaNs. If this is set then they will instead be filled in using the last non-NaN value prior to fitting the cubic spline.
  • replace_nans_at_start: By default NaN values at the start (with no non-NaN value before them) are left as NaNs. If this is passed then it will be used to fill in such NaN values.

Returns:

The coefficients of the Hermite cubic spline. If ts has length \(T\) then the coefficients will be of length \(T - 1\), covering each of the intervals from ts[0] to ts[1], ts[1] to ts[2], and so on.