Brownian controls
SDEs are simulated using a Brownian motion as a control. (See the neural SDE example.)
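To make "Brownian motion as a control" concrete, here is a minimal hand-rolled Euler-Maruyama sketch in JAX. It is not diffrax's solver, and the function name `euler_maruyama` and its arguments are hypothetical. Over each fixed step of size dt, the drift is multiplied by dt and the diffusion by a Brownian increment dW ~ N(0, dt). The Brownian motion therefore drives the stochastic term in the same way that time drives the deterministic one.

```python
import jax
import jax.numpy as jnp

def euler_maruyama(key, y0, drift, diffusion, t0, t1, num_steps):
    """Integrate dy = drift(t, y) dt + diffusion(t, y) dW on a fixed grid (hypothetical helper).

    Each step consumes one Brownian increment dW ~ N(0, dt): the Brownian
    motion is the control that drives the stochastic term.
    """
    dt = (t1 - t0) / num_steps
    # One independent N(0, dt) increment per step; adequate for a fixed step size.
    dws = jnp.sqrt(dt) * jax.random.normal(key, (num_steps,) + jnp.shape(y0))

    def step(y, inputs):
        n, dw = inputs
        t = t0 + n * dt
        # Diagonal noise: the diffusion multiplies dW elementwise.
        y_next = y + drift(t, y) * dt + diffusion(t, y) * dw
        return y_next, y_next

    _, ys = jax.lax.scan(step, y0, (jnp.arange(num_steps), dws))
    return ys  # shape (num_steps,) + y0.shape: the state after each step
```

Because every step uses a fresh, independent increment on a fixed grid, this sketch never needs increments over overlapping intervals to be consistent with one another. Adaptive step sizes would need exactly that consistency.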
diffrax.AbstractBrownianPath (AbstractPath)
Abstract base class for all Brownian paths.
evaluate(self, t0: Scalar, t1: Scalar, left: bool = True) -> PyTree[Array]
abstractmethod
Samples a Brownian increment \(w(t_1) - w(t_0)\).
Each increment has distribution \(\mathcal{N}(0, t_1 - t_0)\).
Arguments:
- t0: Start of interval.
- t1: End of interval.
- left: Ignored. (This determines whether to treat the path as left-continuous or right-continuous at any jump points, but Brownian motion has no jump points.)
Returns:
A pytree of JAX arrays corresponding to the increment \(w(t_1) - w(t_0)\).
Some subclasses may allow t1=None, in which case just the value \(w(t_0)\) is returned.
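The contract of evaluate can be illustrated with a toy path. The sketch below is hypothetical and not a diffrax class: `GridBrownianPath` fixes a time grid in advance, sets w(ts[0]) = 0, builds w by summing independent N(0, dt) grid increments, and interpolates linearly in between. Between grid points, any increment w(t1) - w(t0) is then N(0, t1 - t0). Passing t1=None returns the value w(t0), mirroring the optional behaviour described above.

```python
import jax
import jax.numpy as jnp

class GridBrownianPath:
    """Toy Brownian path on a fixed grid (hypothetical; not part of diffrax).

    w(ts[0]) = 0, and each grid increment is an independent N(0, dt) sample,
    so an increment between grid points t0 < t1 has distribution N(0, t1 - t0).
    """

    def __init__(self, ts, key):
        self.ts = jnp.asarray(ts)
        dts = jnp.diff(self.ts)
        dws = jnp.sqrt(dts) * jax.random.normal(key, dts.shape)
        # Cumulative sum of the increments, starting from w(ts[0]) = 0.
        self.ws = jnp.concatenate([jnp.zeros(1), jnp.cumsum(dws)])

    def _value(self, t):
        # Linear interpolation between grid points.
        return jnp.interp(t, self.ts, self.ws)

    def evaluate(self, t0, t1=None, left=True):
        # `left` is ignored: Brownian motion has no jump points.
        if t1 is None:
            return self._value(t0)  # the value w(t0)
        return self._value(t1) - self._value(t0)  # the increment w(t1) - w(t0)
```

Increments over adjacent intervals add up to the increment over their union. This is the consistency a true Brownian path needs, and it is what a solver relies on when it splits or rejects steps.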
diffrax.UnsafeBrownianPath (AbstractBrownianPath)
Brownian simulation that is only suitable for certain cases.
This is a very quick way to simulate Brownian motion, but can only be used when all of the following are true:

- You are using a fixed step size controller. (Not an adaptive one.)
- You do not need to backpropagate through the differential equation.
- You do not need deterministic solutions with respect to key. (This implementation will produce different results based on fluctuations in floating-point arithmetic.)
Internally, this works by sampling a fresh normal random variable for every interval, ignoring the correlation between samples over overlapping intervals that true Brownian motion exhibits. Hence the restrictions above: they describe the general case in which this correlation structure is not needed.
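The per-interval sampling idea can be sketched as follows. This assumes the per-interval key is derived from the interval endpoints themselves, by folding the bit patterns of t0 and t1 into the base key. That is an illustration of the idea, not a copy of diffrax's source, and `unsafe_increment` is a hypothetical name. Querying the same interval twice gives the same sample. Adjacent intervals, however, do not add up to their union, and nudging t1 by a single floating-point ulp yields an unrelated sample. This matches the restrictions listed above.

```python
import jax
import jax.numpy as jnp

def unsafe_increment(key, t0, t1, shape):
    """Sample w(t1) - w(t0) ~ N(0, t1 - t0) independently for each interval (sketch).

    The key depends on the exact bit patterns of t0 and t1, so identical
    intervals reproduce their sample, but overlapping intervals are uncorrelated.
    """
    t0 = jnp.asarray(t0, dtype=jnp.float32)
    t1 = jnp.asarray(t1, dtype=jnp.float32)
    # Reinterpret the float bits as int32 and fold them into the key.
    key = jax.random.fold_in(key, jax.lax.bitcast_convert_type(t0, jnp.int32))
    key = jax.random.fold_in(key, jax.lax.bitcast_convert_type(t1, jnp.int32))
    return jnp.sqrt(t1 - t0) * jax.random.normal(key, shape)
```

Because nothing links the sample on [0, 1] to the samples on [0, 0.5] and [0.5, 1], an adaptive controller that rejects a step and retries with a smaller one would see an inconsistent path.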
__init__(self, shape: Union[Tuple[int, ...], PyTree[ShapeDtypeStruct]], key: jax.random.PRNGKey)
Arguments:
- shape: Should be a PyTree of jax.ShapeDtypeStruct objects, representing the shape, dtype, and PyTree structure of the output. For simplicity, shape can also just be a tuple of integers, describing the shape of a single JAX array. In that case the dtype is chosen to be the default floating-point dtype.
- key: A random key.
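The two accepted forms of shape can be illustrated with a small sketch. `sample_increment` is a hypothetical helper, not diffrax's internals. It accepts either a tuple of ints or a PyTree whose leaves are jax.ShapeDtypeStruct, and returns an N(0, t1 - t0) increment with the same structure and dtypes. The dict keys "x" and "m" are illustrative only.

```python
import jax
import jax.numpy as jnp

def sample_increment(key, t0, t1, shape):
    """Sample w(t1) - w(t0) ~ N(0, t1 - t0) with the structure described by `shape` (sketch).

    `shape` is either a tuple of ints (one array with the default floating-point
    dtype) or a PyTree whose leaves are jax.ShapeDtypeStruct.
    """
    if isinstance(shape, tuple) and all(isinstance(d, int) for d in shape):
        # Tuple shorthand: a single array with the default floating-point dtype.
        shape = jax.ShapeDtypeStruct(shape, jnp.zeros(()).dtype)
    # ShapeDtypeStruct is not a registered PyTree node, so each one is a leaf.
    leaves, treedef = jax.tree_util.tree_flatten(shape)
    keys = jax.random.split(key, len(leaves))
    samples = [
        jnp.sqrt(jnp.asarray(t1 - t0, leaf.dtype)) * jax.random.normal(k, leaf.shape, leaf.dtype)
        for k, leaf in zip(keys, leaves)
    ]
    return jax.tree_util.tree_unflatten(treedef, samples)

# A structured output: a 3-vector channel and a 2x2 matrix channel.
shape = {
    "x": jax.ShapeDtypeStruct((3,), jnp.float32),
    "m": jax.ShapeDtypeStruct((2, 2), jnp.float32),
}
dw = sample_increment(jax.random.PRNGKey(0), 0.0, 0.5, shape)
```

Following the __init__ signature above, the same shape object would be passed as diffrax.UnsafeBrownianPath(shape=shape, key=jax.random.PRNGKey(0)).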
evaluate(self, t0: Scalar, t1: Scalar, left: bool = True) -> PyTree[Array]
Implements diffrax.AbstractBrownianPath.evaluate.
diffrax.VirtualBrownianTree (AbstractBrownianPath)
Brownian simulation that discretises the interval [t0, t1] to tolerance tol, and is piecewise quadratic at that discretisation.
Reference
@article{li2020scalable,
title={Scalable gradients for stochastic differential equations},
author={Li, Xuechen and Wong, Ting-Kam Leonard and Chen, Ricky T. Q. and
Duvenaud, David},
journal={International Conference on Artificial Intelligence and Statistics},
year={2020}
}
(The implementation here is a slight improvement on the reference implementation, being piecewise quadratic rather than piecewise linear. This corrects a small bias in the generated samples.)
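The idea behind a virtual Brownian tree can be sketched as a bisection over Brownian-bridge midpoints. A query for w(t) walks down a binary tree of subintervals of [t0, t1] until the current interval is no longer than tol. At each node, the midpoint value is drawn from the bridge distribution w(m) | w(s), w(u) ~ N((w(s) + w(u)) / 2, (u - s) / 4). Node keys are derived deterministically from the root key, so repeated queries agree and no path is stored. The function name, the key-splitting scheme, and the final linear interpolation (as in the reference, not diffrax's piecewise-quadratic scheme) are assumptions of this sketch.

```python
import jax
import jax.numpy as jnp

def virtual_brownian_value(key, t0, t1, tol, t):
    """Approximate w(t) on [t0, t1], with w(t0) = 0, by Brownian-bridge bisection (sketch).

    Uses linear interpolation at the finest level, not a piecewise-quadratic scheme.
    """
    end_key, key = jax.random.split(key)
    s, u = t0, t1
    w_s = 0.0
    w_u = float(jnp.sqrt(t1 - t0) * jax.random.normal(end_key))  # w(t1) ~ N(0, t1 - t0)
    while u - s > tol:
        # Each node key depends only on the left/right choices made so far.
        mid_key, left_key, right_key = jax.random.split(key, 3)
        m = 0.5 * (s + u)
        # Brownian bridge: w(m) | w(s), w(u) ~ N((w(s) + w(u)) / 2, (u - s) / 4).
        w_m = 0.5 * (w_s + w_u) + 0.5 * (u - s) ** 0.5 * float(jax.random.normal(mid_key))
        if t < m:
            u, w_u, key = m, w_m, left_key
        else:
            s, w_s, key = m, w_m, right_key
    # Interpolate linearly inside the final interval of length <= tol.
    return w_s + (t - s) / (u - s) * (w_u - w_s)
```

Refining tol only adds deeper levels. Values already fixed at coarser nodes, such as w(0.5) here, stay the same, which is what makes the tree usable by an adaptive solver.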
__init__(self, t0: Scalar, t1: Scalar, tol: Scalar, shape: Union[Tuple[int, ...], PyTree[ShapeDtypeStruct]], key: jax.random.PRNGKey)
Arguments:
- t0: The start of the interval the Brownian motion is defined over.
- t1: The end of the interval the Brownian motion is defined over.
- tol: The tolerance to which [t0, t1] is discretised.
- shape: Should be a PyTree of jax.ShapeDtypeStruct objects, representing the shape, dtype, and PyTree structure of the output. For simplicity, shape can also just be a tuple of integers, describing the shape of a single JAX array. In that case the dtype is chosen to be the default floating-point dtype.
- key: A random key.
Info
If using this as part of an SDE solver, and you know (or have an estimate of) the step sizes made in the solver, then you can optimise the computational efficiency of the Virtual Brownian Tree by setting tol to be just slightly smaller than the step size of the solver.
The Brownian motion is defined to equal 0 at t0.
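The efficiency advice in the Info note can be quantified under one assumption: that the tree halves its current interval at each level until the interval is no longer than tol. In that case, the number of levels a single evaluation walks through is ceil(log2((t1 - t0) / tol)). The helper name `bisection_depth` and the step size of 0.01 are hypothetical.

```python
import math

def bisection_depth(t0, t1, tol):
    """Halvings needed until a subinterval of [t0, t1] has length <= tol (assumes bisection)."""
    return max(0, math.ceil(math.log2((t1 - t0) / tol)))

# Hypothetical fixed solver step of 0.01 on [0, 1]: compare tol just below the
# step size with much smaller tolerances.
for tol in (0.009, 1e-3, 1e-5):
    print(tol, bisection_depth(0.0, 1.0, tol))
```

Under this assumption, a tol just below a step of 0.01 needs 7 levels, while tol=1e-5 needs 17. The extra levels resolve structure finer than any step the solver takes.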
evaluate(self, t0: Scalar, t1: Optional[Scalar] = None, left: bool = True) -> PyTree[Array]
Implements diffrax.AbstractBrownianPath.evaluate.