
Abstract solvers¤

All of the solvers (both ODE and SDE solvers) implement the following interface specified by diffrax.AbstractSolver.

The exact details of this interface are only really useful if you're using the Manual stepping interface or defining your own solvers; otherwise this is all just internal to the library.

Also see Extending Diffrax for more information on defining your own solvers.

In addition, diffrax.AbstractSolver has several subclasses that you can use to mark your custom solver as exhibiting particular behaviour.


diffrax.AbstractSolver ¤

Abstract base class for all differential equation solvers.

Subclasses should have a class-level attribute terms, specifying the PyTree structure of terms in diffeqsolve(terms, ...).

order(self, terms: PyTree[AbstractTerm]) -> Optional[int] ¤

Order of the solver for solving ODEs.

strong_order(self, terms: PyTree[AbstractTerm]) -> Union[float, int] ¤

Strong order of the solver for solving SDEs.

error_order(self, terms: PyTree[AbstractTerm]) -> Union[float, int] ¤

Order of the error estimate used for adaptive stepping.

The default (slightly heuristic) implementation is as follows.

The error estimate is assumed to come from the difference of two methods. If these two methods have orders p and q then the local order of the error estimate is min(p, q) + 1 for an ODE and min(p, q) + 0.5 for an SDE.

  • In the SDE case we assume p == q == solver.strong_order().
  • In the ODE case we assume p == q + 1 == solver.order().
  • We assume that non-SDE/ODE cases do not arise.

This is imperfect, as these assumptions may not be true. In addition, in the SDE case solvers will sometimes exhibit higher orders of convergence for specific noise types (see issue #47).
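
For example, diffrax.Tsit5 is a 5th-order ODE solver with an embedded 4th-order error estimate, so the heuristic gives min(5, 4) + 1 = 5. A minimal illustration (the vector field here is purely hypothetical):

```python
import diffrax

terms = diffrax.ODETerm(lambda t, y, args: -y)  # illustrative vector field
solver = diffrax.Tsit5()

solver.order(terms)        # 5
solver.error_order(terms)  # 5, i.e. min(p, q) + 1 with p = 5 and q = 4
```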

init(self, terms: PyTree[AbstractTerm], t0: Union[float, int], t1: Union[float, int], y0: PyTree[Shaped[ArrayLike, '?*y'], "Y"], args: PyTree[typing.Any]) -> ~_SolverState abstractmethod ¤

Initialises any hidden state for the solver.

Arguments as diffrax.diffeqsolve.

Returns:

The initial solver state, which should be used the first time step is called.

step(self, terms: PyTree[AbstractTerm], t0: Union[float, int], t1: Union[float, int], y0: PyTree[Shaped[ArrayLike, '?*y'], "Y"], args: PyTree[typing.Any], solver_state: ~_SolverState, made_jump: bool) -> tuple[PyTree[Shaped[ArrayLike, '?*y'], "Y"], Optional[PyTree[Shaped[ArrayLike, '?*y'], "Y"]], dict[str, PyTree[Array]], ~_SolverState, RESULTS] abstractmethod ¤

Make a single step of the solver.

Each step is made over the specified interval \([t_0, t_1]\).

Arguments:

  • terms: The PyTree of terms representing the vector fields and controls.
  • t0: The start of the interval that the step is made over.
  • t1: The end of the interval that the step is made over.
  • y0: The current value of the solution at t0.
  • args: Any extra arguments passed to the vector field.
  • solver_state: Any evolving state for the solver itself, at t0.
  • made_jump: Whether there was a discontinuity in the vector field at t0. Some solvers (notably FSAL Runge--Kutta solvers) usually assume that there are no jumps and, for efficiency, re-use information between steps; this flag indicates that a jump has just occurred, so that assumption does not hold.

Returns:

A tuple of several objects:

  • The value of the solution at t1.
  • A local error estimate made during the step. (Used by adaptive step size controllers to change the step size.) May be None if no estimate was made.
  • Some dictionary of information that is passed to the solver's interpolation routine to calculate dense output. (Used with SaveAt(ts=...) or SaveAt(dense=...).)
  • The value of the solver state at t1.
  • An integer (corresponding to diffrax.RESULTS) indicating whether the step happened successfully, or if (unusually) it failed for some reason.

func(self, terms: PyTree[AbstractTerm], t0: Union[float, int], y0: PyTree[Shaped[ArrayLike, '?*y'], "Y"], args: PyTree[typing.Any]) -> PyTree[Shaped[ArrayLike, '?*vf'], "VF"] abstractmethod ¤

Evaluate the vector field at a point. (This is unlike diffrax.AbstractSolver.step, which operates over an interval.)

For the most part differential equation solvers are interval-based, so this operation should be used sparingly. It is needed for things like selecting an initial step size.

Arguments: as diffrax.diffeqsolve.

Returns:

The evaluation of the vector field at t0, y0.
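
As a sketch of how these methods fit together, here is a minimal Euler-like solver in the style of the Extending Diffrax guide. The class attributes term_structure and interpolation_cls, and the use of diffrax.LocalLinearInterpolation, follow recent Diffrax versions and are assumptions here (in particular, the attribute described above as terms may be named term_structure depending on the version):

```python
import diffrax


class EulerSketch(diffrax.AbstractSolver):
    # PyTree structure of `terms` accepted by this solver.
    term_structure = diffrax.ODETerm
    # Interpolation used to construct output from `dense_info`.
    interpolation_cls = diffrax.LocalLinearInterpolation

    def order(self, terms):
        return 1

    def init(self, terms, t0, t1, y0, args):
        return None  # Euler's method keeps no solver state

    def step(self, terms, t0, t1, y0, args, solver_state, made_jump):
        del solver_state, made_jump  # no state, and no FSAL-style reuse
        control = terms.contr(t0, t1)  # for an ODE this is just t1 - t0
        y1 = y0 + terms.vf_prod(t0, y0, args, control)
        dense_info = dict(y0=y0, y1=y1)
        # (solution at t1, error estimate, dense info, solver state, result)
        return y1, None, dense_info, None, diffrax.RESULTS.successful

    def func(self, terms, t0, y0, args):
        return terms.vf(t0, y0, args)
```

Such a solver can then be passed to diffrax.diffeqsolve in the same way as any built-in solver.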


diffrax.AbstractImplicitSolver (AbstractSolver) ¤

Indicates that this is an implicit differential equation solver, and as such that it should take a root finder as an argument.



diffrax.AbstractAdaptiveSolver (AbstractSolver) ¤

Indicates that this solver provides error estimates, and that as such it may be used with an adaptive step size controller.


diffrax.AbstractItoSolver (AbstractSolver) ¤

Indicates that, when used as an SDE solver, this solver will converge to the Itô solution.


diffrax.AbstractStratonovichSolver (AbstractSolver) ¤

Indicates that, when used as an SDE solver, this solver will converge to the Stratonovich solution.


diffrax.AbstractWrappedSolver (AbstractSolver) ¤

Wraps another solver "transparently", in the sense that all isinstance checks will be forwarded on to the wrapped solver, e.g. when testing whether the solver is implicit/adaptive/SDE-compatible/etc.

Inherit from this class if that is desired behaviour. (Do not inherit from this class if that is not desired behaviour.)

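For example, diffrax.HalfSolver (which wraps another solver and obtains an error estimate by additionally making two half-steps) is one such wrapped solver. An illustrative sketch; the isinstance results in the comments are what the forwarding behaviour described above implies:

```python
import diffrax

wrapped = diffrax.HalfSolver(diffrax.Kvaerno5())  # Kvaerno5 is an implicit (ESDIRK) solver

isinstance(wrapped, diffrax.AbstractWrappedSolver)   # True
isinstance(wrapped, diffrax.AbstractImplicitSolver)  # True: forwarded to the wrapped Kvaerno5
```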


Abstract Runge--Kutta solvers¤

diffrax.AbstractRungeKutta (AbstractAdaptiveSolver) ¤

Abstract base class for all Runge--Kutta solvers. (Other than fully-implicit Runge--Kutta methods, which have a different computational structure.)

Whilst this class can be subclassed directly, when defining your own Runge--Kutta methods it is usually better to work with diffrax.AbstractERK, diffrax.AbstractDIRK, diffrax.AbstractSDIRK, or diffrax.AbstractESDIRK directly.

Subclasses should specify two class-level attributes. The first is tableau, an instance of diffrax.ButcherTableau. The second is calculate_jacobian, an instance of diffrax.CalculateJacobian.

diffrax.AbstractERK (AbstractRungeKutta) ¤

Abstract base class for all Explicit Runge--Kutta solvers.

Subclasses should include a class-level attribute tableau, an instance of diffrax.ButcherTableau.

diffrax.AbstractDIRK (AbstractRungeKutta, AbstractImplicitSolver) ¤

Abstract base class for all Diagonal Implicit Runge--Kutta solvers.

Subclasses should include a class-level attribute tableau, an instance of diffrax.ButcherTableau.

diffrax.AbstractSDIRK (AbstractDIRK) ¤

Abstract base class for all Singular Diagonal Implicit Runge--Kutta solvers.

Subclasses should include a class-level attribute tableau, an instance of diffrax.ButcherTableau.

diffrax.AbstractESDIRK (AbstractDIRK) ¤

Abstract base class for all Explicit Singular Diagonal Implicit Runge--Kutta solvers.

Subclasses should include a class-level attribute tableau, an instance of diffrax.ButcherTableau.

diffrax.ButcherTableau ¤

The Butcher tableau for an explicit or diagonal Runge--Kutta method.

__init__(self, c: ndarray, b_sol: ndarray, b_error: ndarray, a_lower: tuple[numpy.ndarray, ...], a_diagonal: Optional[numpy.ndarray] = None, a_predictor: Optional[tuple[numpy.ndarray, ...]] = None, c1: float = 0.0) ¤

Arguments:

Let k denote the number of stages of the solver.

  • a_lower: the lower triangle (without the diagonal) of the Butcher tableau. Should be a tuple of NumPy arrays, corresponding to the rows of this lower triangle. The first array should be of shape (1,). Each subsequent array should be of shape (2,), (3,), etc. The final array should have shape (k - 1,).
  • b_sol: the linear combination of stages to take to produce the output at each step. Should be a NumPy array of shape (k,).
  • b_error: the linear combination of stages to take to produce the error estimate at each step. Should be a NumPy array of shape (k,). Note that this is not differenced against b_sol prior to evaluation. (i.e. b_error gives the linear combination for producing the error estimate directly, not for producing some alternate solution that is compared against the main solution).
  • c: the time increments used in the Butcher tableau.
  • a_diagonal: optional. The diagonal of the Butcher tableau. Should be None or a NumPy array of shape (k,). Used for diagonal implicit Runge--Kutta methods only.
  • a_predictor: optional. Used in a similar way to a_lower; specifies the linear combination of previous stages to use as a predictor for the solution to the implicit problem at that stage. See the developer documentation. Used for diagonal implicit Runge--Kutta methods only.

Whether the solver exhibits either the FSAL or SSAL properties is determined automatically.
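
For example, Heun's method (two stages, with an embedded Euler step providing the error estimate) corresponds to the following tableau. This is an illustrative sketch of the arguments above, not Diffrax's own definition:

```python
import numpy as np
import diffrax

heun_tableau = diffrax.ButcherTableau(
    c=np.array([1.0]),              # time increment of the second stage
    a_lower=(np.array([1.0]),),     # rows of the strictly lower triangle
    b_sol=np.array([0.5, 0.5]),     # second-order (Heun) output
    b_error=np.array([-0.5, 0.5]),  # Heun output minus the embedded Euler output,
                                    # i.e. the error estimate directly
)
```

A concrete solver would then set this as its class-level tableau attribute, as described for diffrax.AbstractERK above.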

diffrax.CalculateJacobian ¤

An enumeration of possible ways a Runge--Kutta method may wish to calculate a Jacobian.

never: used for explicit Runge--Kutta methods.

every_stage: the Jacobian is calculated once per stage. Used for DIRK methods.

first_stage: the Jacobian is calculated once per step; in particular it is calculated in the first stage and re-used for every subsequent stage in the step. Used for SDIRK methods.

second_stage: the Jacobian is calculated once per step; in particular it is calculated in the second stage and re-used for every subsequent stage in the step. Used for ESDIRK methods.


Abstract Stochastic Runge--Kutta (SRK) solvers¤

diffrax.AbstractSRK (AbstractSolver) ¤

A general Stochastic Runge-Kutta method.

This accepts terms of the form MultiTerm(ODETerm(drift), ControlTerm(diffusion, brownian_motion)). Depending on the solver, the Brownian motion might need to generate different types of Lévy areas, specified by the minimal_levy_area attribute.

For example, the diffrax.ShARK solver requires space-time Lévy area, so it will have minimal_levy_area = AbstractSpaceTimeLevyArea and the Brownian motion must be initialised with levy_area=SpaceTimeLevyArea.

Given the Stratonovich SDE \(dy(t) = f(t, y(t)) dt + g(t, y(t)) \circ dw(t)\), we construct the SRK with \(s\) stages as follows:

\(y_{n+1} = y_n + h \Big(\sum_{j=1}^s b_j f_j \Big) + W_n \Big(\sum_{j=1}^s b^W_j g_j \Big) + H_n \Big(\sum_{j=1}^s b^H_j g_j \Big)\)

\(f_j = f(t_n + c_j h , z_j)\)

\(g_j = g(t_n + c_j h , z_j)\)

\(z_j = y_n + h \Big(\sum_{i=1}^{j-1} a_{j,i} f_i \Big) + W_n \Big(\sum_{i=1}^{j-1} a^W_{j,i} g_i \Big) + H_n \Big(\sum_{i=1}^{j-1} a^H_{j,i} g_i \Big)\)

where \(W_n = W_{t_n, t_{n+1}}\) is the increment of the Brownian motion and \(H_n = H_{t_n, t_{n+1}}\) is its corresponding space-time Lévy Area, defined as \(H_{s,t} = \frac{1}{t-s} \int_s^t (W_{s,r} - \frac{r-s}{t-s} W_{s,t}) \, dr\). A similar term can also be added for the space-time-time Lévy area, K, defined as \(K_{s,t} = \frac{1}{(t-s)^2} \int_s^t (W_{s,r} - \frac{r-s}{t-s} W_{s,t}) \left( \frac{t+s}{2} - r \right) \, dr\).

In the special case when the SDE has additive noise, i.e. when g is independent of y (but may still depend on t), the SDE can be written as \(dy(t) = f(t, y(t)) dt + g(t) \, dw(t)\), and we can simplify the above to

\(y_{n+1} = y_n + h \Big(\sum_{j=1}^s b_j f_j \Big) + g(t_n) \, (b^W \, W_n + b^H \, H_n)\)

\(f_j = f(t_n + c_j h , z_j)\)

\(z_j = y_n + h \Big(\sum_{i=1}^{j-1} a_{j,i} f_i \Big) + g(t_n) \, (a^W_j W_n + a^H_j H_n)\)

When g depends on t, we need to add a correction term to \(y_{n+1}\) of the form \((g(t_{n+1}) - g(t_n)) \, (\frac{1}{2} W_n - H_n)\).

The coefficients are provided in the diffrax.StochasticButcherTableau. In particular, the coefficients \(b^W\) and \(a^W\) are provided in tableau.cfs_bm, as are \(b^H\), \(a^H\), \(b^K\), and \(a^K\) if needed.
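
For example, here is an illustrative sketch of solving an additive-noise SDE with diffrax.ShARK; the drift and diffusion functions are hypothetical, and the Brownian motion is asked to generate the space-time Lévy area that ShARK requires:

```python
import jax.random as jr
import diffrax

t0, t1 = 0.0, 1.0
drift = lambda t, y, args: -y           # f(t, y)
diffusion = lambda t, y, args: 0.1 * t  # g(t): independent of y, i.e. additive noise

bm = diffrax.VirtualBrownianTree(
    t0, t1, tol=1e-3, shape=(), key=jr.PRNGKey(0),
    levy_area=diffrax.SpaceTimeLevyArea,  # ShARK needs space-time Lévy area
)
terms = diffrax.MultiTerm(diffrax.ODETerm(drift), diffrax.ControlTerm(diffusion, bm))
sol = diffrax.diffeqsolve(terms, diffrax.ShARK(), t0, t1, dt0=0.05, y0=1.0)
```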

diffrax.StochasticButcherTableau ¤

A Butcher Tableau for Stochastic Runge-Kutta methods.

__init__(self, a: list[numpy.ndarray], b_sol: ndarray, b_error: Optional[numpy.ndarray], c: ndarray, coeffs_w: ~_Coeffs, coeffs_hh: Optional[~_Coeffs], coeffs_kk: Optional[~_Coeffs], ignore_stage_f: Optional[numpy.ndarray], ignore_stage_g: Optional[numpy.ndarray]) ¤

The coefficients of a diffrax.AbstractSRK method.

See also the documentation for diffrax.AbstractSRK for additional details on the mathematical meaning of each of these arguments.

Arguments:

Let s denote the number of stages of the solver.

  • a: The lower triangle (without the diagonal) of the Butcher tableau for the drift term. Should be a tuple of NumPy arrays, corresponding to the rows of this lower triangle. The first array should be of shape (1,). Each subsequent array should be of shape (2,), (3,) etc. The final array should have shape (s - 1,).
  • b_sol: The linear combination of drift stages to take to produce the output at each step. Should be a NumPy array of shape (s,).
  • b_error: The linear combination of stages to take to produce the error estimate at each step. Should be a NumPy array of shape (s,). Note that this is not differenced against b_sol prior to evaluation. (i.e. b_error gives the linear combination for producing the error estimate directly, not for producing some alternate solution that is compared against the main solution).
  • c: The time increments used in the Butcher tableau. Should be a NumPy array of shape (s-1,), as the first stage has time increment 0.
  • coeffs_w: An instance of AdditiveCoeffs or GeneralCoeffs, providing the coefficients of the Brownian motion increments.
  • coeffs_hh: An instance of AdditiveCoeffs or GeneralCoeffs, providing the coefficients of the space-time Lévy area.
  • coeffs_kk: An instance of AdditiveCoeffs or GeneralCoeffs, providing the coefficients of the space-time-time Lévy area.
  • ignore_stage_f: Optional. A NumPy array of length s of booleans. If True at stage j, the vector field of the drift term will not be evaluated at stage j.
  • ignore_stage_g: Optional. A NumPy array of length s of booleans. If True at stage j, the diffusion vector field will not be evaluated at stage j.