
Brownian controls

SDEs are simulated using a Brownian motion as a control. (See the neural SDE example.)

diffrax.AbstractBrownianPath

diffrax.AbstractBrownianPath (AbstractPath)

Abstract base class for all Brownian paths.

evaluate(self, t0: Union[float, int], t1: Union[float, int] = None, left: bool = True, use_levy: bool = False) -> _Control (abstract method)

Samples a Brownian increment \(w(t_1) - w(t_0)\).

Each increment has distribution \(\mathcal{N}(0, t_1 - t_0)\).

Arguments:

  • t0: Any point in \([t_0, t_1]\) to evaluate the path at.
  • t1: If passed, then the increment from t0 to t1 is evaluated instead.
  • left: Ignored. (This determines whether to treat the path as left-continuous or right-continuous at any jump points, but Brownian motion has no jump points.)
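  • use_levy: If True, the return type will be an AbstractBrownianIncrement (such as a BrownianIncrement or SpaceTimeLevyArea, depending on the path's levy_area), which contains PyTrees of Brownian increments and their Lévy areas.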

Returns:

If t1 is not passed:

The value of the Brownian motion at t0.

If t1 is passed:

The increment of the Brownian motion between t0 and t1.


diffrax.UnsafeBrownianPath (AbstractBrownianPath)

Brownian simulation that is only suitable for certain cases.

This is a very quick way to simulate Brownian motion, but can only be used when all of the following are true:

  1. You are using a fixed step size controller. (Not an adaptive one.)

  2. You do not need to backpropagate through the differential equation.

  3. You do not need deterministic solutions with respect to key. (This implementation will produce different results based on fluctuations in floating-point arithmetic.)

Internally this operates by just sampling a fresh normal random variable over every interval, ignoring the correlation between samples exhibited in true Brownian motion. Hence the restrictions above. (They describe the general case for which the correlation structure isn't needed.)
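To illustrate the consequence, here is a hypothetical NumPy sketch of the "fresh sample per interval" strategy (the function fresh_increment is ours, not part of diffrax, and does not reproduce its internals). Each call gives a correctly distributed increment, but overlapping queries are not tied together, which is exactly the correlation structure the restrictions above allow one to ignore:

```python
import numpy as np


def fresh_increment(rng, t0, t1, shape=()):
    """Sample w(t1) - w(t0) ~ N(0, t1 - t0), independently of every other call."""
    return rng.normal(loc=0.0, scale=np.sqrt(t1 - t0), size=shape)


rng = np.random.default_rng(0)

# Each increment on its own has the correct N(0, dt) distribution...
samples = fresh_increment(rng, 0.0, 0.25, shape=(200_000,))
print(samples.var())  # close to 0.25

# ...but the increment over [0, 1] is not the sum of the increments over
# [0, 0.5] and [0.5, 1], as it would be for a genuine Brownian path.
whole = fresh_increment(rng, 0.0, 1.0)
halves = fresh_increment(rng, 0.0, 0.5) + fresh_increment(rng, 0.5, 1.0)
print(whole, halves)  # two unrelated numbers
```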

Lévy Area

Can be initialised with levy_area set to either diffrax.BrownianIncrement or diffrax.SpaceTimeLevyArea. If levy_area=diffrax.SpaceTimeLevyArea, then it also computes the space-time Lévy area H. This is an additional source of randomness required for certain stochastic Runge–Kutta solvers; see diffrax.AbstractSRK for more information.

An error will be thrown during tracing if Lévy area is required but is not available.

The choice here will impact the Brownian path, so even with the same key, the trajectory will be different depending on the value of levy_area.

__init__(self, shape: Union[tuple[int, ...], PyTree[jax.ShapeDtypeStruct]], key: PRNGKeyArray, levy_area: type[Union[BrownianIncrement, SpaceTimeLevyArea, SpaceTimeTimeLevyArea]] = BrownianIncrement)

Arguments:

  • shape: Should be a PyTree of jax.ShapeDtypeStructs, representing the shape, dtype, and PyTree structure of the output. For simplicity, shape can also just be a tuple of integers, describing the shape of a single JAX array. In that case the dtype is chosen to be the default floating-point dtype.
  • key: A random key.
  • levy_area: The type of Lévy area to additionally generate, given as a class. The default, BrownianIncrement, generates only the increment. Lévy area is required by some SDE solvers.

diffrax.VirtualBrownianTree (AbstractBrownianPath)

Brownian simulation that discretises the interval [t0, t1] to tolerance tol.
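For intuition, the basic bisection idea from the original virtual Brownian tree of Li et al. can be sketched in NumPy. This is a simplified, hypothetical illustration (virtual_brownian_value and _node_rng are our names), not diffrax's implementation. Each midpoint is drawn from the Brownian bridge conditioned on its interval's endpoints, using a generator seeded only by the node's position in the tree, so repeated queries agree without storing the path:

```python
import numpy as np


def _node_rng(seed, path):
    # Deterministic generator for one tree node, identified by `path`:
    # the sequence of left (0) / right (1) choices taken from the root.
    return np.random.default_rng(np.random.SeedSequence(seed, spawn_key=(1, *path)))


def virtual_brownian_value(t, t0, t1, tol, seed):
    """Approximate w(t) on [t0, t1] by bisection with Brownian-bridge samples."""
    # Endpoints: w(t0) = 0 and w(t1) ~ N(0, t1 - t0).
    endpoint_rng = np.random.default_rng(np.random.SeedSequence(seed, spawn_key=(0,)))
    s, u = t0, t1
    w_s, w_u = 0.0, endpoint_rng.normal(0.0, np.sqrt(t1 - t0))
    path = ()
    while u - s > tol:
        mid = 0.5 * (s + u)
        # Brownian bridge: w(mid) | w(s), w(u) ~ N((w_s + w_u) / 2, (u - s) / 4).
        noise = _node_rng(seed, path).normal(0.0, 0.5 * np.sqrt(u - s))
        w_mid = 0.5 * (w_s + w_u) + noise
        if t <= mid:
            u, w_u, path = mid, w_mid, path + (0,)
        else:
            s, w_s, path = mid, w_mid, path + (1,)
    # Simplification: linear interpolation inside the final interval (width <= tol).
    return w_s + (w_u - w_s) * (t - s) / (u - s)


print(virtual_brownian_value(0.3, 0.0, 1.0, 1e-3, seed=7))
```

The final linear interpolation is where this sketch is only approximate; the implementation described below instead matches the distribution of the Brownian motion and its Lévy areas exactly at all query times.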

Lévy Area

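The parameter levy_area can be set to one of:

  • diffrax.BrownianIncrement: computes only the Brownian increment W.
  • diffrax.SpaceTimeLevyArea: computes W and the space-time Lévy area H.
  • diffrax.SpaceTimeTimeLevyArea: computes W, H, and the space-time-time Lévy area K.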

The choice of levy_area will impact the Brownian path, so even with the same key, the trajectory will be different depending on the value of levy_area.

Reference

Virtual Brownian trees were proposed in

@article{li2020scalable,
  title={Scalable gradients for stochastic differential equations},
  author={Li, Xuechen and Wong, Ting-Kam Leonard and Chen, Ricky T. Q. and
          Duvenaud, David},
  journal={International Conference on Artificial Intelligence and Statistics},
  year={2020}
}

The implementation here is an improvement on the above, in that it additionally simulates space-time and space-time-time Lévy areas, and exactly matches the distribution of the Brownian motion and its Lévy areas at all query times. This is due to the paper

@misc{jelinčič2024singleseed,
  title={Single-seed generation of Brownian paths and integrals
  for adaptive and high order SDE solvers},
  author={Andraž Jelinčič and James Foster and Patrick Kidger},
  year={2024},
  eprint={2405.06464},
  archivePrefix={arXiv},
  primaryClass={math.NA}
}

and Theorem 6.1.6 of

@phdthesis{foster2020a,
  publisher = {University of Oxford},
  school = {University of Oxford},
  title = {Numerical approximations for stochastic differential equations},
  author = {Foster, James M.},
  year = {2020}
}

__init__(self, t0: Union[float, int], t1: Union[float, int], tol: Union[float, int], shape: Union[tuple[int, ...], PyTree[jax.ShapeDtypeStruct]], key: PRNGKeyArray, levy_area: type[Union[BrownianIncrement, SpaceTimeLevyArea, SpaceTimeTimeLevyArea]] = BrownianIncrement)

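Arguments:

  • t0: The start of the interval [t0, t1] over which the Brownian motion is simulated.
  • t1: The end of that interval.
  • tol: The tolerance to which the interval is discretised.
  • shape: As for diffrax.UnsafeBrownianPath: a PyTree of jax.ShapeDtypeStructs, or a tuple of integers describing the shape of a single JAX array.
  • key: A random key.
  • levy_area: The type of Lévy area to generate; see "Lévy Area" above.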


Lévy areas

Brownian controls can return certain types of Lévy area. These are iterated integrals of the Brownian motion, and are used by some SDE solvers. A solver that requires Lévy area has a minimal_levy_area attribute, which always returns an abstract Lévy area type; the solver accepts any subclass of that type. The inheritance hierarchy is as follows:

AbstractBrownianIncrement
├── BrownianIncrement
└── AbstractSpaceTimeLevyArea
    ├── SpaceTimeLevyArea
    └── AbstractSpaceTimeTimeLevyArea
        └── SpaceTimeTimeLevyArea

For example, if solver.minimal_levy_area returns AbstractSpaceTimeLevyArea, then the Brownian motion (either an UnsafeBrownianPath or a VirtualBrownianTree) should be initialized with levy_area=SpaceTimeLevyArea or levy_area=SpaceTimeTimeLevyArea. Note that the Brownian motion must be given a concrete class, not its abstract parent.

diffrax.AbstractBrownianIncrement

AbstractBrownianIncrement()


diffrax.BrownianIncrement (AbstractBrownianIncrement)

BrownianIncrement(dt: PyTree[float, "BM"], W: PyTree[Shaped[ArrayLike, '?*bm'], "BM"])

Fields:

  • dt: PyTree[float, "BM"]. The length of the time interval.
  • W: PyTree[Shaped[ArrayLike, '?*bm'], "BM"]. The Brownian increment over the interval.

diffrax.AbstractSpaceTimeLevyArea (AbstractBrownianIncrement)

AbstractSpaceTimeLevyArea()


diffrax.SpaceTimeLevyArea (AbstractSpaceTimeLevyArea)

SpaceTimeLevyArea(dt: PyTree[float, "BM"], W: PyTree[Shaped[ArrayLike, '?*bm'], "BM"], H: PyTree[Shaped[ArrayLike, '?*bm'], "BM"])

Fields:

  • dt: PyTree[float, "BM"]. The length of the time interval.
  • W: PyTree[Shaped[ArrayLike, '?*bm'], "BM"]. The Brownian increment over the interval.
  • H: PyTree[Shaped[ArrayLike, '?*bm'], "BM"]. The space-time Lévy area over the interval.
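To make H concrete, the following NumPy sketch estimates it by Monte Carlo, assuming the convention from Foster's thesis that H = (1/h) ∫_0^h (W_u - (u/h) W_h) du over an interval of length h. Under that convention H ~ N(0, h/12) and is independent of the increment W_h. The function space_time_levy_area is hypothetical and is not how diffrax samples H:

```python
import numpy as np


def space_time_levy_area(h, n_steps, n_paths, seed=0):
    """Monte Carlo samples of (W_h, H) from finely discretised Brownian paths."""
    rng = np.random.default_rng(seed)
    du = h / n_steps
    # Brownian paths at u = du, 2 du, ..., h (with W_0 = 0).
    W = np.cumsum(rng.normal(0.0, np.sqrt(du), size=(n_paths, n_steps)), axis=1)
    W_h = W[:, -1]
    u = du * np.arange(1, n_steps + 1)
    bridge = W - (u / h) * W_h[:, None]  # Brownian bridge, zero at u = 0 and u = h
    # Trapezoidal rule; the bridge vanishes at both endpoints.
    H = bridge[:, :-1].sum(axis=1) * du / h
    return W_h, H


W_h, H = space_time_levy_area(h=1.0, n_steps=200, n_paths=20_000)
print(H.var())                    # close to h / 12
print(np.corrcoef(W_h, H)[0, 1])  # close to 0
```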

diffrax.AbstractSpaceTimeTimeLevyArea (AbstractSpaceTimeLevyArea)

AbstractSpaceTimeTimeLevyArea()


diffrax.SpaceTimeTimeLevyArea (AbstractSpaceTimeTimeLevyArea)

SpaceTimeTimeLevyArea(dt: PyTree[float, "BM"], W: PyTree[Shaped[ArrayLike, '?*bm'], "BM"], H: PyTree[Shaped[ArrayLike, '?*bm'], "BM"], K: PyTree[Shaped[ArrayLike, '?*bm'], "BM"])

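Fields:

  • dt: PyTree[float, "BM"]. The length of the time interval.
  • W: PyTree[Shaped[ArrayLike, '?*bm'], "BM"]. The Brownian increment over the interval.
  • H: PyTree[Shaped[ArrayLike, '?*bm'], "BM"]. The space-time Lévy area over the interval.
  • K: PyTree[Shaped[ArrayLike, '?*bm'], "BM"]. The space-time-time Lévy area over the interval.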
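Similarly, K can be estimated by Monte Carlo in NumPy, assuming the convention K = (1/h²) ∫_0^h (W_u - (u/h) W_h)(h/2 - u) du from Foster's thesis. Under that convention K ~ N(0, h/720), independent of both W_h and H. The function space_time_time_levy_area is hypothetical and does not reflect how diffrax generates K:

```python
import numpy as np


def space_time_time_levy_area(h, n_steps, n_paths, seed=0):
    """Monte Carlo samples of (H, K) from finely discretised Brownian paths.

    Uses H = (1/h) int_0^h B_u du and K = (1/h^2) int_0^h B_u (h/2 - u) du,
    where B_u = W_u - (u/h) W_h is the Brownian bridge.
    """
    rng = np.random.default_rng(seed)
    du = h / n_steps
    W = np.cumsum(rng.normal(0.0, np.sqrt(du), size=(n_paths, n_steps)), axis=1)
    u = du * np.arange(1, n_steps + 1)
    bridge = W - (u / h) * W[:, -1:]  # zero at u = 0 (implicitly) and u = h
    H = (bridge * du).sum(axis=1) / h
    K = (bridge * (h / 2 - u) * du).sum(axis=1) / h**2
    return H, K


h = 0.5
H, K = space_time_time_levy_area(h=h, n_steps=200, n_paths=20_000)
print(K.var(), h / 720)         # these should be close
print(np.corrcoef(H, K)[0, 1])  # close to 0
```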