Brownian controls
SDEs are simulated using a Brownian motion as a control. (See the neural SDE example.)
diffrax.AbstractBrownianPath (AbstractPath)
Abstract base class for all Brownian paths.
evaluate(self, t0: Union[float, int], t1: Union[float, int] = None, left: bool = True, use_levy: bool = False) -> Union[PyTree[Array], diffrax._custom_types.LevyVal]
abstractmethod
Samples a Brownian increment \(w(t_1) - w(t_0)\).
Each increment has distribution \(\mathcal{N}(0, t_1 - t_0)\).
Arguments:
t0: Any point in \([t_0, t_1]\) to evaluate the path at.
t1: If passed, then the increment from t0 to t1 is evaluated instead.
left: Ignored. (This determines whether to treat the path as left-continuous or right-continuous at any jump points, but Brownian motion has no jump points.)
use_levy: If True, the return type will be a LevyVal, which contains PyTrees of Brownian increments and their Lévy areas.
Returns:
If t1 is not passed: The value of the Brownian motion at t0.
If t1 is passed: The increment of the Brownian motion between t0 and t1.
diffrax.UnsafeBrownianPath (AbstractBrownianPath)
Brownian simulation that is only suitable for certain cases.
This is a very quick way to simulate Brownian motion, but can only be used when all of the following are true:

You are using a fixed step size controller. (Not an adaptive one.)

You do not need to backpropagate through the differential equation.

You do not need deterministic solutions with respect to key. (This implementation will produce different results based on fluctuations in floating-point arithmetic.)
Internally this operates by just sampling a fresh normal random variable over every interval, ignoring the correlation between samples exhibited in true Brownian motion. Hence the restrictions above. (They describe the general case for which the correlation structure isn't needed.)
Depending on the levy_area argument, this can also be used to generate Lévy area.
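The idea of sampling a fresh normal per interval can be sketched in plain JAX. This is not diffrax's implementation: fresh_increment is a hypothetical helper, and the explicit key splitting is an assumption made only to show why independent samples over overlapping intervals do not fit together the way true Brownian increments do.

```python
import jax.numpy as jnp
import jax.random as jr


def fresh_increment(key, t0, t1, shape):
    """Hypothetical helper: an independent N(0, t1 - t0) sample for one interval."""
    return jnp.sqrt(t1 - t0) * jr.normal(key, shape)


# A separate key per interval, so the three samples are independent.
k_whole, k_left, k_right = jr.split(jr.PRNGKey(0), 3)

whole = fresh_increment(k_whole, 0.0, 1.0, (3,))
left = fresh_increment(k_left, 0.0, 0.5, (3,))
right = fresh_increment(k_right, 0.5, 1.0, (3,))

# For a true Brownian path, whole == left + right. Independent samples
# almost surely violate this.
consistent = bool(jnp.allclose(whole, left + right))
```

Each sample has the right marginal distribution, but the pieces do not add up. That is harmless when every interval is queried exactly once, as with a fixed step size, and becomes a problem once sub-intervals are re-queried, for example after an adaptive controller rejects a step.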
__init__(self, shape: Union[tuple[int, ...], PyTree[jax._src.api.ShapeDtypeStruct]], key: PRNGKeyArray, levy_area: Literal['', 'space-time'] = '')
Arguments:
shape: Should be a PyTree of jax.ShapeDtypeStructs, representing the shape, dtype, and PyTree structure of the output. For simplicity, shape can also just be a tuple of integers, describing the shape of a single JAX array. In that case the dtype is chosen to be the default floating-point dtype.
key: A random key.
levy_area: Whether to additionally generate Lévy area. This is required by some SDE solvers.
diffrax.VirtualBrownianTree (AbstractBrownianPath)
Brownian simulation that discretises the interval [t0, t1] to tolerance tol.
Can be initialised with levy_area set to "" or "space-time".
If levy_area="space-time", then it also computes space-time Lévy area H. This will impact the Brownian path, so even with the same key, the trajectory will be different depending on the value of levy_area.
Reference
Virtual Brownian trees were proposed in
@article{li2020scalable,
title={Scalable gradients for stochastic differential equations},
author={Li, Xuechen and Wong, Ting-Kam Leonard and Chen, Ricky T. Q. and
Duvenaud, David},
journal={International Conference on Artificial Intelligence and Statistics},
year={2020}
}
The implementation here is an improvement on the above, in that it additionally simulates space-time Lévy area. This is due to Section 6.1 and Theorem 6.1.6 of
@phdthesis{foster2020a,
publisher = {University of Oxford},
school = {University of Oxford},
title = {Numerical approximations for stochastic differential equations},
author = {Foster, James M.},
year = {2020}
}
In addition, the implementation here further improves on both by using an interpolation method that ensures the conditional second moments are correct.
__init__(self, t0: Union[float, int], t1: Union[float, int], tol: Union[float, int], shape: Union[tuple[int, ...], PyTree[jax._src.api.ShapeDtypeStruct]], key: PRNGKeyArray, levy_area: Literal['', 'space-time'] = '')