Abstract solvers
All of the solvers (both ODE and SDE solvers) implement the following interface specified by diffrax.AbstractSolver.
The exact details of this interface are only really useful if you're using the Manual stepping interface or defining your own solvers; otherwise this is all just internal to the library.
Also see Extending Diffrax for more information on defining your own solvers.
In addition, diffrax.AbstractSolver has several subclasses that you can use to mark your custom solver as exhibiting particular behaviour.
diffrax.AbstractSolver
Abstract base class for all differential equation solvers.
Subclasses should have a class-level attribute terms, specifying the PyTree structure of terms in diffeqsolve(terms, ...).
order(self, terms: PyTree[AbstractTerm]) -> Optional[int]
Order of the solver for solving ODEs.
strong_order(self, terms: PyTree[AbstractTerm]) -> Union[float, int]
Strong order of the solver for solving SDEs.
error_order(self, terms: PyTree[AbstractTerm]) -> Union[float, int]
Order of the error estimate used for adaptive stepping.
The default (slightly heuristic) implementation is as follows.
The error estimate is assumed to come from the difference of two methods. If these two methods have orders p and q, then the local order of the error estimate is min(p, q) + 1 for an ODE and min(p, q) + 0.5 for an SDE.
- In the SDE case, we assume p == q == solver.strong_order().
- In the ODE case, we assume p == q + 1 == solver.order().
- We assume that non-SDE/ODE cases do not arise.
This is imperfect, as these assumptions may not hold. In addition, in the SDE case, solvers will sometimes exhibit higher orders of convergence for specific noise types (see issue #47).
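The heuristic above fits in a few lines of plain Python. This is a minimal sketch: heuristic_error_order is a hypothetical helper, not part of diffrax, and it takes the solver's (strong) order as a number rather than calling order or strong_order on a solver object.

```python
from typing import Union

Number = Union[int, float]


def heuristic_error_order(order: Number, is_sde: bool) -> Number:
    """Local order of an error estimate formed from the difference of two methods.

    Hypothetical helper mirroring the default heuristic described above.
    `order` stands for solver.order() (ODE) or solver.strong_order() (SDE).
    """
    if is_sde:
        # SDE case: assume p == q == strong order.
        p = q = order
        return min(p, q) + 0.5
    # ODE case: assume p == q + 1 == order.
    p = order
    q = order - 1
    return min(p, q) + 1


# A 5(4) embedded pair such as Dormand--Prince: error estimate of local order 5.
print(heuristic_error_order(5, is_sde=False))  # 5
# A strong order 1 SDE solver: error estimate of local order 1.5.
print(heuristic_error_order(1, is_sde=True))  # 1.5
```

In the ODE case the heuristic therefore reduces to the solver's own order, which is why a 5(4) pair is treated as having an error estimate of order 5.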
init(self, terms: PyTree[AbstractTerm], t0: Union[float, int], t1: Union[float, int], y0: PyTree[Shaped[ArrayLike, '?*y'], "Y"], args: PyTree[typing.Any]) -> ~_SolverState
abstractmethod
Initialises any hidden state for the solver.
Arguments: As diffrax.diffeqsolve.
Returns: The initial solver state, which should be used the first time step is called.
step(self, terms: PyTree[AbstractTerm], t0: Union[float, int], t1: Union[float, int], y0: PyTree[Shaped[ArrayLike, '?*y'], "Y"], args: PyTree[typing.Any], solver_state: ~_SolverState, made_jump: bool) -> tuple[PyTree[Shaped[ArrayLike, '?*y'], "Y"], Optional[PyTree[Shaped[ArrayLike, '?*y'], "Y"]], dict[str, PyTree[Array]], ~_SolverState, RESULTS]
abstractmethod
Make a single step of the solver.
Each step is made over the specified interval \([t_0, t_1]\).
Arguments:
- terms: The PyTree of terms representing the vector fields and controls.
- t0: The start of the interval that the step is made over.
- t1: The end of the interval that the step is made over.
- y0: The current value of the solution at t0.
- args: Any extra arguments passed to the vector field.
- solver_state: Any evolving state for the solver itself, at t0.
- made_jump: Whether there was a discontinuity in the vector field at t0. Some solvers (notably FSAL Runge--Kutta solvers) usually assume that there are no jumps and for efficiency re-use information between steps; this indicates that a jump has just occurred and this assumption is not true.
Returns:
A tuple of several objects:
- The value of the solution at t1.
- A local error estimate made during the step. (Used by adaptive step size controllers to change the step size.) May be None if no estimate was made.
- Some dictionary of information that is passed to the solver's interpolation routine to calculate dense output. (Used with SaveAt(ts=...) or SaveAt(dense=...).)
- The value of the solver state at t1.
- An integer (corresponding to diffrax.RESULTS) indicating whether the step happened successfully, or if (unusually) it failed for some reason.
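The interface can be exercised end to end with a toy solver. Below is a NumPy-only sketch of Heun's method with an embedded Euler error estimate, exposing init, step and func with the argument order documented above, followed by a manual stepping loop. It is not a diffrax.AbstractSolver subclass: terms is just a callable f(t, y, args), RESULT_SUCCESSFUL is a hypothetical stand-in for a successful diffrax.RESULTS value, and the keys of the dense-output dictionary are made up.

```python
import numpy as np

# Hypothetical stand-in for a successful diffrax.RESULTS value.
RESULT_SUCCESSFUL = 0


class ToyHeun:
    """NumPy-only sketch of the init/step/func interface (not a diffrax class).

    `terms` is simply a callable f(t, y, args) rather than a PyTree of
    diffrax terms, and the "dense info" dictionary uses made-up keys.
    """

    def init(self, terms, t0, t1, y0, args):
        # Heun's method carries no hidden state between steps.
        return None

    def func(self, terms, t0, y0, args):
        # Evaluate the vector field at a single point.
        return terms(t0, y0, args)

    def step(self, terms, t0, t1, y0, args, solver_state, made_jump):
        h = t1 - t0
        k1 = terms(t0, y0, args)
        k2 = terms(t1, y0 + h * k1, args)
        y1 = y0 + 0.5 * h * (k1 + k2)
        # Embedded Euler solution minus Heun solution: the local error estimate.
        y_error = 0.5 * h * (k1 - k2)
        dense_info = {"y0": y0, "y1": y1}
        return y1, y_error, dense_info, solver_state, RESULT_SUCCESSFUL


# Manual stepping over [0, 1] for dy/dt = -y with 20 equal steps.
vector_field = lambda t, y, args: -y
solver = ToyHeun()
ts = np.linspace(0.0, 1.0, 21)
y = np.array([1.0])
state = solver.init(vector_field, ts[0], ts[1], y, None)
for tprev, tnext in zip(ts[:-1], ts[1:]):
    y, y_error, dense_info, state, result = solver.step(
        vector_field, tprev, tnext, y, None, state, made_jump=False
    )
    assert result == RESULT_SUCCESSFUL
print(y, np.exp(-1.0))
```

With an actual diffrax solver the loop keeps the same shape: call solver.init(terms, t0, t1, y0, args) once, then solver.step(terms, t0, t1, y0, args, solver_state, made_jump) on each interval, unpacking the five returned values in the order listed above.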
func(self, terms: PyTree[AbstractTerm], t0: Union[float, int], y0: PyTree[Shaped[ArrayLike, '?*y'], "Y"], args: PyTree[typing.Any]) -> PyTree[Shaped[ArrayLike, '?*vf'], "VF"]
abstractmethod
Evaluate the vector field at a point. (This is unlike diffrax.AbstractSolver.step, which operates over an interval.)
For most operations, differential equation solvers are interval-based, so this operation should be used sparingly. It is needed for things like selecting an initial step size.
Arguments: As diffrax.diffeqsolve.
Returns: The evaluation of the vector field at t0, y0.
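As an illustration of why a point-wise evaluation is useful, here is a sketch of the well-known starting-step heuristic from Hairer, Nørsett and Wanner's Solving Ordinary Differential Equations I, which needs the vector field at (t0, y0) plus one further point. The function names and default tolerances are assumptions, and this is not claimed to be diffrax's own routine.

```python
import numpy as np


def rms_norm(x):
    return float(np.sqrt(np.mean(np.square(x))))


def select_initial_step(func, t0, y0, args, order, rtol=1e-3, atol=1e-6):
    """Hairer--Norsett--Wanner starting step heuristic (sketch).

    `func(t, y, args)` plays the role of evaluating the vector field at a point.
    """
    scale = atol + np.abs(y0) * rtol
    f0 = func(t0, y0, args)
    d0 = rms_norm(y0 / scale)
    d1 = rms_norm(f0 / scale)
    h0 = 1e-6 if (d0 < 1e-5 or d1 < 1e-5) else 0.01 * d0 / d1
    # One explicit Euler step of size h0, then a second point-wise evaluation.
    y1 = y0 + h0 * f0
    f1 = func(t0 + h0, y1, args)
    d2 = rms_norm((f1 - f0) / scale) / h0
    if max(d1, d2) <= 1e-15:
        h1 = max(1e-6, h0 * 1e-3)
    else:
        h1 = (0.01 / max(d1, d2)) ** (1.0 / (order + 1))
    return min(100 * h0, h1)


dt0 = select_initial_step(lambda t, y, args: -y, 0.0, np.array([1.0]), None, order=5)
print(dt0)
```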
diffrax.AbstractImplicitSolver (AbstractSolver)
Indicates that this is an implicit differential equation solver, and as such it should take a root finder as an argument.
diffrax.AbstractAdaptiveSolver (AbstractSolver)
Indicates that this solver provides error estimates, and that as such it may be used with an adaptive step size controller.
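To illustrate how an adaptive controller might consume the error estimate returned by step, here is a sketch of a basic integral ("I") step size controller. The function name, tolerances, safety factor and clipping bounds are assumptions for illustration; diffrax's own controllers (for example diffrax.PIDController) are more elaborate.

```python
import numpy as np


def rms_norm(x):
    return float(np.sqrt(np.mean(np.square(x))))


def propose_step(y0, y1, y_error, h, error_order, rtol=1e-3, atol=1e-6,
                 safety=0.9, factormin=0.2, factormax=10.0):
    """Accept/reject a step and propose the next step size (I-controller sketch)."""
    scale = atol + rtol * np.maximum(np.abs(y0), np.abs(y1))
    err = rms_norm(y_error / scale)  # err <= 1 means the step meets the tolerance
    accept = err <= 1.0
    if err == 0.0:
        factor = factormax
    else:
        # The local error scales like h**error_order, hence the 1/error_order exponent.
        factor = safety * err ** (-1.0 / error_order)
    factor = min(factormax, max(factormin, factor))
    return accept, h * factor


accept, h_next = propose_step(np.array([1.0]), np.array([0.9]), np.array([1e-4]), 0.1, error_order=5)
print(accept, h_next)
```

The exponent is where the error_order method described earlier comes in: an estimate of higher local order warrants a gentler change in step size for the same error ratio.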
diffrax.AbstractItoSolver (AbstractSolver)
Indicates that, when used as an SDE solver, this solver will converge to the Itô solution.
diffrax.AbstractStratonovichSolver (AbstractSolver)
Indicates that, when used as an SDE solver, this solver will converge to the Stratonovich solution.
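The difference between the two interpretations shows up even for the scalar SDE with no drift and diffusion g(y) = y, started at y(0) = 1: its Itô solution is y(t) = exp(W_t - t/2), while its Stratonovich solution is y(t) = exp(W_t). The sketch below drives the Euler-Maruyama scheme (which converges to the Itô solution) and a stochastic Heun predictor-corrector scheme (which converges to the Stratonovich solution) with the same Brownian increments. The function names are hypothetical and this is plain NumPy, not diffrax code.

```python
import numpy as np


def euler_maruyama(g, y0, dws):
    """Converges to the Ito solution of dy = g(y) dW."""
    y = y0
    for dw in dws:
        y = y + g(y) * dw
    return y


def stochastic_heun(g, y0, dws):
    """Predictor-corrector (trapezoidal) scheme; converges to the Stratonovich solution."""
    y = y0
    for dw in dws:
        y_pred = y + g(y) * dw
        y = y + 0.5 * (g(y) + g(y_pred)) * dw
    return y


rng = np.random.default_rng(0)
n = 2**14  # number of steps over [0, 1]
dws = rng.normal(0.0, np.sqrt(1.0 / n), size=n)
w1 = dws.sum()  # Brownian motion at t = 1

y_ito = euler_maruyama(lambda y: y, 1.0, dws)
y_strat = stochastic_heun(lambda y: y, 1.0, dws)
print(y_ito, np.exp(w1 - 0.5))  # numerical vs exact Ito solution
print(y_strat, np.exp(w1))      # numerical vs exact Stratonovich solution
```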
diffrax.AbstractWrappedSolver (AbstractSolver)
Wraps another solver "transparently", in the sense that all isinstance checks will be forwarded on to the wrapped solver, e.g. when testing whether the solver is implicit/adaptive/SDE-compatible/etc.
Inherit from this class if that is desired behaviour. (Do not inherit from this class if that is not desired behaviour.)
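One way to get this kind of transparent isinstance forwarding in plain Python is a metaclass that overrides __instancecheck__. The sketch below illustrates that general technique under assumptions of our own: the class names and the solver attribute holding the wrapped solver are hypothetical, and it is not diffrax's actual implementation.

```python
class _ForwardingMeta(type):
    """Metaclass whose isinstance check also looks inside wrapped solvers."""

    def __instancecheck__(cls, obj):
        # Ordinary isinstance check first (type.__instancecheck__ avoids recursion).
        if type.__instancecheck__(cls, obj):
            return True
        # If obj wraps another solver, forward the check to the wrapped solver.
        if type.__instancecheck__(WrappedSolver, obj):
            return isinstance(obj.solver, cls)
        return False


class Solver(metaclass=_ForwardingMeta):
    pass


class AdaptiveSolver(Solver):
    pass


class ImplicitSolver(Solver):
    pass


class WrappedSolver(Solver):
    def __init__(self, solver):
        self.solver = solver


class ToyAdaptive(AdaptiveSolver):
    pass


class ToyWrapper(WrappedSolver):
    pass


wrapped = ToyWrapper(ToyAdaptive())
print(isinstance(wrapped, AdaptiveSolver))  # True (forwarded to the inner solver)
print(isinstance(wrapped, ImplicitSolver))  # False
```

Because the forwarded check calls isinstance again, wrappers of wrappers are handled as well.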
Abstract Runge--Kutta solvers
diffrax.AbstractRungeKutta (AbstractAdaptiveSolver)
Abstract base class for all Runge--Kutta solvers. (Other than fully-implicit Runge--Kutta methods, which have a different computational structure.)
Whilst this class can be subclassed directly, when defining your own Runge--Kutta methods it is usually better to work with diffrax.AbstractERK, diffrax.AbstractDIRK, diffrax.AbstractSDIRK or diffrax.AbstractESDIRK directly.
Subclasses should specify two class-level attributes. The first is tableau, an instance of diffrax.ButcherTableau. The second is calculate_jacobian, an instance of diffrax.CalculateJacobian.
diffrax.AbstractERK (AbstractRungeKutta)
Abstract base class for all Explicit Runge--Kutta solvers.
Subclasses should include a class-level attribute tableau, an instance of diffrax.ButcherTableau.
diffrax.AbstractDIRK (AbstractRungeKutta, AbstractImplicitSolver)
Abstract base class for all Diagonal Implicit Runge--Kutta solvers.
Subclasses should include a class-level attribute tableau, an instance of diffrax.ButcherTableau.
diffrax.AbstractSDIRK (AbstractDIRK)
Abstract base class for all Singly Diagonal Implicit Runge--Kutta solvers.
Subclasses should include a class-level attribute tableau, an instance of diffrax.ButcherTableau.
diffrax.AbstractESDIRK (AbstractDIRK)
Abstract base class for all Explicit Singly Diagonal Implicit Runge--Kutta solvers.
Subclasses should include a class-level attribute tableau, an instance of diffrax.ButcherTableau.
diffrax.ButcherTableau
The Butcher tableau for an explicit or diagonal Runge--Kutta method.
__init__(self, c: ndarray, b_sol: ndarray, b_error: ndarray, a_lower: tuple[numpy.ndarray, ...], a_diagonal: Optional[numpy.ndarray] = None, a_predictor: Optional[tuple[numpy.ndarray, ...]] = None, c1: float = 0.0)
Arguments:
Let k denote the number of stages of the solver.
- a_lower: the lower triangle (without the diagonal) of the Butcher tableau. Should be a tuple of NumPy arrays, corresponding to the rows of this lower triangle. The first array should be of shape (1,). Each subsequent array should be of shape (2,), (3,) etc. The final array should have shape (k - 1,).
- b_sol: the linear combination of stages to take to produce the output at each step. Should be a NumPy array of shape (k,).
- b_error: the linear combination of stages to take to produce the error estimate at each step. Should be a NumPy array of shape (k,). Note that this is not differenced against b_sol prior to evaluation. (i.e. b_error gives the linear combination for producing the error estimate directly, not for producing some alternate solution that is compared against the main solution.)
- c: the time increments used in the Butcher tableau.
- a_diagonal: optional. The diagonal of the Butcher tableau. Should be None or a NumPy array of shape (k,). Used for diagonal implicit Runge--Kutta methods only.
- a_predictor: optional. Used in a similar way to a_lower; specifies the linear combination of previous stages to use as a predictor for the solution to the implicit problem at that stage. See the developer documentation. Used for diagonal implicit Runge--Kutta methods only.
Whether the solver exhibits either the FSAL or SSAL properties is determined automatically.
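To make the shapes concrete, here is Heun's method (order 2) with an embedded Euler method (order 1) written in this layout, together with a generic explicit Runge--Kutta step. This is a NumPy-only sketch: ToyTableau and erk_step are hypothetical stand-ins rather than diffrax.ButcherTableau, and we assume (as the c1 argument in the signature suggests) that c lists the time increments of stages 2 to k, with the first stage's increment given by c1.

```python
from dataclasses import dataclass

import numpy as np


@dataclass(frozen=True)
class ToyTableau:
    """Hypothetical stand-in mirroring the ButcherTableau field names."""

    c: np.ndarray        # time increments of stages 2..k (assumed layout)
    b_sol: np.ndarray    # shape (k,)
    b_error: np.ndarray  # shape (k,), applied directly to produce the error estimate
    a_lower: tuple       # rows of shapes (1,), (2,), ..., (k - 1,)
    c1: float = 0.0      # time increment of the first stage

    def __post_init__(self):
        k = self.b_sol.shape[0]
        assert self.b_error.shape == (k,)
        assert self.c.shape == (k - 1,)
        for i, row in enumerate(self.a_lower):
            assert row.shape == (i + 1,)


def erk_step(tableau, f, t0, y0, h, args=None):
    """One explicit Runge--Kutta step; returns (y1, y_error)."""
    c = np.concatenate([[tableau.c1], tableau.c])
    ks = [f(t0 + c[0] * h, y0, args)]
    for i, row in enumerate(tableau.a_lower, start=1):
        y_stage = y0 + h * sum(a_ij * k_j for a_ij, k_j in zip(row, ks))
        ks.append(f(t0 + c[i] * h, y_stage, args))
    ks = np.stack(ks)  # shape (k, *y0.shape)
    y1 = y0 + h * np.tensordot(tableau.b_sol, ks, axes=1)
    # b_error is applied to the stages directly, not differenced against b_sol.
    y_error = h * np.tensordot(tableau.b_error, ks, axes=1)
    return y1, y_error


b_sol = np.array([0.5, 0.5])
heun = ToyTableau(
    c=np.array([1.0]),
    b_sol=b_sol,
    b_error=np.array([1.0, 0.0]) - b_sol,  # Euler weights minus Heun weights
    a_lower=(np.array([1.0]),),
)
y1, y_error = erk_step(heun, lambda t, y, args: -y, 0.0, np.array([1.0]), 0.1)
print(y1, y_error)
```

Here the error estimate equals the Euler solution minus the Heun solution; the differencing has already been folded into b_error, which is what "not differenced against b_sol" means in practice.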
diffrax.CalculateJacobian
An enumeration of possible ways a Runge--Kutta method may wish to calculate a Jacobian.
- never: used for explicit Runge--Kutta methods.
- every_stage: the Jacobian is calculated once per stage. Used for DIRK methods.
- first_stage: the Jacobian is calculated once per step; in particular it is calculated in the first stage and re-used for every subsequent stage in the step. Used for SDIRK methods.
- second_stage: the Jacobian is calculated once per step; in particular it is calculated in the second stage and re-used for every subsequent stage in the step. Used for ESDIRK methods.
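The difference between every_stage and first_stage is where a simplified Newton iteration gets its Jacobian. The sketch below computes the stages of a diagonal implicit Runge--Kutta method with NumPy, either re-evaluating the Jacobian at every stage or evaluating it once at the first stage and reusing it; second_stage would be analogous, skipping the explicit first stage of an ESDIRK method. The function, its arguments (including a c that lists all stages) and the fixed number of Newton iterations are assumptions; diffrax itself delegates the implicit solve to a root finder.

```python
import numpy as np


def dirk_stages(f, jac, t0, y0, h, a_lower, a_diagonal, c, strategy, newton_iters=10):
    """Stage derivatives k_i of a DIRK method, plus the number of Jacobian evaluations.

    a_lower uses the ButcherTableau layout (rows of shape (1,), (2,), ...); here c
    (hypothetically) lists the time increments of *all* stages, and y0 is 1-D.
    """
    ks, J, jac_calls = [], None, 0
    for i, a_ii in enumerate(a_diagonal):
        t_i = t0 + c[i] * h
        row = a_lower[i - 1] if i > 0 else ()
        # Explicit part of the stage: y0 + h * sum_j a_ij k_j.
        base = y0 + h * sum((a_ij * k_j for a_ij, k_j in zip(row, ks)), np.zeros_like(y0))
        if strategy == "every_stage" or J is None:
            J = jac(t_i, base)
            jac_calls += 1
        newton_matrix = np.eye(y0.shape[0]) - h * a_ii * J
        z = base.copy()
        # Simplified Newton iteration on G(z) = z - base - h * a_ii * f(t_i, z) = 0,
        # reusing the same Newton matrix for every iteration of this stage.
        for _ in range(newton_iters):
            residual = z - base - h * a_ii * f(t_i, z)
            z = z - np.linalg.solve(newton_matrix, residual)
        ks.append(f(t_i, z))
    return np.stack(ks), jac_calls


# Two-stage SDIRK tableau (both diagonal entries equal to gamma) on a linear problem.
gamma = 1.0 - 1.0 / np.sqrt(2.0)
A = np.array([[-1.0, 0.5], [0.0, -2.0]])
common = dict(
    t0=0.0,
    y0=np.array([1.0, 1.0]),
    h=0.1,
    a_lower=(np.array([1.0 - gamma]),),
    a_diagonal=np.array([gamma, gamma]),
    c=np.array([gamma, 1.0]),
)
k_every, n_every = dirk_stages(lambda t, y: A @ y, lambda t, y: A, strategy="every_stage", **common)
k_first, n_first = dirk_stages(lambda t, y: A @ y, lambda t, y: A, strategy="first_stage", **common)
print(n_every, n_first)  # 2 1
```

Because every stage of a singly diagonal tableau shares the same diagonal entry, the Newton matrix I - h a_ii J built at the first stage has the right form for all later stages, which is what makes first_stage a natural fit for SDIRK methods. For this linear test problem the Jacobian is constant, so both strategies produce the same stages.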
Abstract Stochastic Runge--Kutta (SRK) solvers
diffrax.AbstractSRK (AbstractSolver)
A general Stochastic Runge-Kutta method.
This accepts terms of the form MultiTerm(ODETerm(drift), ControlTerm(diffusion, brownian_motion)).
Depending on the solver, the Brownian motion might need to generate different types of Lévy areas, specified by the minimal_levy_area attribute.
For example, the diffrax.ShARK solver requires space-time Lévy area, so it will have minimal_levy_area = AbstractSpaceTimeLevyArea and the Brownian motion must be initialised with levy_area=SpaceTimeLevyArea.
Given the Stratonovich SDE \(dy(t) = f(t, y(t)) dt + g(t, y(t)) \circ dw(t)\), we construct the SRK with \(s\) stages as follows:
\(y_{n+1} = y_n + h \Big(\sum_{j=1}^s b_j f_j \Big) + W_n \Big(\sum_{j=1}^s b^W_j g_j \Big) + H_n \Big(\sum_{j=1}^s b^H_j g_j \Big)\)
\(f_j = f(t_n + c_j h , z_j)\)
\(g_j = g(t_n + c_j h , z_j)\)
\(z_j = y_n + h \Big(\sum_{i=1}^{j-1} a_{j,i} f_i \Big) + W_n \Big(\sum_{i=1}^{j-1} a^W_{j,i} g_i \Big) + H_n \Big(\sum_{i=1}^{j-1} a^H_{j,i} g_i \Big)\)
where \(W_n = W_{t_n, t_{n+1}}\) is the increment of the Brownian motion and \(H_n = H_{t_n, t_{n+1}}\) is its corresponding space-time Lévy area, defined as \(H_{s,t} = \frac{1}{t-s} \int_s^t (W_{s,r} - \frac{r-s}{t-s} W_{s,t}) \, dr\). A similar term can also be added for the space-time-time Lévy area, \(K\), defined as \(K_{s,t} = \frac{1}{(t-s)^2} \int_s^t (W_{s,r} - \frac{r-s}{t-s} W_{s,t}) \left( \frac{t+s}{2} - r \right) \, dr\).
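The definition of the space-time Lévy area can be checked numerically. The sketch below is a NumPy Monte Carlo experiment with hypothetical function names: it discretises Brownian paths on a fine grid, evaluates the integral defining \(H_{s,t}\) with the trapezoidal rule, and compares against the known fact that \(H_{s,t}\) is Gaussian with variance \((t-s)/12\) and independent of \(W_{s,t}\).

```python
import numpy as np


def sample_w_and_h(s, t, n_fine, n_paths, rng):
    """Sample (W_{s,t}, H_{s,t}) by discretising Brownian paths on [s, t]."""
    h = t - s
    dt = h / n_fine
    increments = rng.normal(0.0, np.sqrt(dt), size=(n_paths, n_fine))
    # W_{s,r} on the grid r = s, s + dt, ..., t (each path starts at 0).
    w = np.concatenate([np.zeros((n_paths, 1)), np.cumsum(increments, axis=1)], axis=1)
    r = np.linspace(s, t, n_fine + 1)
    w_st = w[:, -1]
    # Integrand W_{s,r} - (r - s)/(t - s) * W_{s,t}: a Brownian bridge, zero at both ends.
    bridge = w - ((r - s) / h)[None, :] * w_st[:, None]
    integral = dt * (0.5 * bridge[:, 0] + bridge[:, 1:-1].sum(axis=1) + 0.5 * bridge[:, -1])
    return w_st, integral / h


rng = np.random.default_rng(0)
w, hh = sample_w_and_h(0.0, 1.0, n_fine=100, n_paths=20_000, rng=rng)
print(hh.var(), 1.0 / 12.0)       # sample variance of H vs (t - s)/12
print(np.corrcoef(w, hh)[0, 1])   # sample correlation between W and H
```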
In the special case when the SDE has additive noise, i.e. when g is independent of y (but may still depend on t), the SDE can be written as \(dy(t) = f(t, y(t)) dt + g(t) \, dw(t)\), and we can simplify the above to
\(y_{n+1} = y_n + h \Big(\sum_{j=1}^s b_j f_j \Big) + g(t_n) \, (b^W \, W_n + b^H \, H_n)\)
\(f_j = f(t_n + c_j h , z_j)\)
\(z_j = y_n + h \Big(\sum_{i=1}^{j-1} a_{j,i} f_i \Big) + g(t_n) \, (a^W_j W_n + a^H_j H_n)\)
When g depends on t, we need to add a correction term to \(y_{n+1}\) of the form \((g(t_{n+1}) - g(t_n)) \, (\frac{1}{2} W_n - H_n)\).
The coefficients are provided in the diffrax.StochasticButcherTableau. In particular, the coefficients \(b^W\) and \(a^W\) are provided in tableau.coeffs_w, and \(b^H\), \(a^H\), \(b^K\), and \(a^K\), if needed, in tableau.coeffs_hh and tableau.coeffs_kk respectively.
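The additive-noise scheme above can be sketched in a few lines of NumPy. This is illustrative only: additive_srk_step and its argument names are hypothetical rather than the diffrax.StochasticButcherTableau layout, the noise is taken to be diagonal (so g(t) multiplies \(W_n\) and \(H_n\) elementwise rather than as a matrix), and c follows the StochasticButcherTableau convention of omitting the first stage, whose time increment is 0.

```python
import numpy as np


def additive_srk_step(f, g, t_n, y_n, h, W_n, H_n, a, b, c, a_w, b_w, a_h, b_h):
    """One step of the additive-noise SRK scheme, for diagonal noise g(t).

    a: rows of the drift tableau, shapes (1,), (2,), ..., (s - 1,)
    b: shape (s,); c: shape (s - 1,); a_w, a_h: shape (s,); b_w, b_h: scalars.
    """
    c_full = np.concatenate([[0.0], c])  # the first stage has time increment 0
    g_n = g(t_n)
    fs = []
    for j in range(len(b)):
        row = a[j - 1] if j > 0 else ()
        drift = sum((a_ji * f_i for a_ji, f_i in zip(row, fs)), np.zeros_like(y_n))
        z_j = y_n + h * drift + g_n * (a_w[j] * W_n + a_h[j] * H_n)
        fs.append(f(t_n + c_full[j] * h, z_j))
    drift = sum((b_j * f_j for b_j, f_j in zip(b, fs)), np.zeros_like(y_n))
    y_next = y_n + h * drift + g_n * (b_w * W_n + b_h * H_n)
    # Correction term for a time-dependent diffusion g(t).
    return y_next + (g(t_n + h) - g_n) * (0.5 * W_n - H_n)


# A single-stage (Euler-type) tableau, purely for illustration.
one_stage = dict(a=(), b=np.array([1.0]), c=np.array([]), a_w=np.array([0.0]),
                 b_w=1.0, a_h=np.array([0.0]), b_h=0.0)
y1 = additive_srk_step(lambda t, y: -y, lambda t: np.array([0.5]), 0.0,
                       np.array([1.0]), 0.1, 0.2, 0.05, **one_stage)
print(y1)
```

In practice \(W_n\) and \(H_n\) would come from a Brownian motion that can generate space-time Lévy area, as described above; here they are fixed numbers so that the arithmetic is easy to follow.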
diffrax.StochasticButcherTableau
A Butcher tableau for Stochastic Runge-Kutta methods.
__init__(self, a: list[numpy.ndarray], b_sol: ndarray, b_error: Optional[numpy.ndarray], c: ndarray, coeffs_w: ~_Coeffs, coeffs_hh: Optional[~_Coeffs], coeffs_kk: Optional[~_Coeffs], ignore_stage_f: Optional[numpy.ndarray], ignore_stage_g: Optional[numpy.ndarray])
The coefficients of a diffrax.AbstractSRK method.
See also the documentation for diffrax.AbstractSRK for additional details on the mathematical meaning of each of these arguments.
Arguments:
Let s denote the number of stages of the solver.
- a: The lower triangle (without the diagonal) of the Butcher tableau for the drift term. Should be a tuple of NumPy arrays, corresponding to the rows of this lower triangle. The first array should be of shape (1,). Each subsequent array should be of shape (2,), (3,) etc. The final array should have shape (s - 1,).
- b_sol: The linear combination of drift stages to take to produce the output at each step. Should be a NumPy array of shape (s,).
- b_error: The linear combination of stages to take to produce the error estimate at each step. Should be a NumPy array of shape (s,). Note that this is not differenced against b_sol prior to evaluation. (i.e. b_error gives the linear combination for producing the error estimate directly, not for producing some alternate solution that is compared against the main solution.)
- c: The time increments used in the Butcher tableau. Should be a NumPy array of shape (s - 1,), as the first stage has time increment 0.
- coeffs_w: An instance of AdditiveCoeffs or GeneralCoeffs, providing the coefficients of the Brownian motion increments.
- coeffs_hh: An instance of AdditiveCoeffs or GeneralCoeffs, providing the coefficients of the space-time Lévy area.
- coeffs_kk: An instance of AdditiveCoeffs or GeneralCoeffs, providing the coefficients of the space-time-time Lévy area.
- ignore_stage_f: Optional. A NumPy array of length s of booleans. If True at stage j, the vector field of the drift term will not be evaluated at stage j.
- ignore_stage_g: Optional. A NumPy array of length s of booleans. If True at stage j, the diffusion vector field will not be evaluated at stage j.
diffrax.AbstractFosterLangevinSRK (AbstractStratonovichSolver)
Abstract class for Stochastic Runge--Kutta methods specifically designed for Underdamped Langevin Diffusion of the form
\(dx(t) = v(t) \, dt\)
\(dv(t) = -\gamma \, v(t) \, dt - u \, \nabla f(x(t)) \, dt + \sqrt{2 \gamma u} \, dw(t)\)
where \(x(t), v(t) \in \mathbb{R}^d\) represent the position and velocity, \(w\) is a Brownian motion in \(\mathbb{R}^d\), \(f: \mathbb{R}^d \rightarrow \mathbb{R}\) is a potential function, and \(\gamma , u \in \mathbb{R}^{d \times d}\) are diagonal matrices governing the friction and the damping of the system.
Solvers which inherit from this class include diffrax.ALIGN, diffrax.ShOULD, and diffrax.QUICSORT.
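To see what these equations describe, here is a simulation of underdamped Langevin dynamics with the plain Euler-Maruyama scheme. This is deliberately not one of ALIGN, ShOULD or QUICSORT; it is a minimal NumPy sketch with hypothetical names, storing the diagonal matrices γ and u as vectors of their diagonal entries. For the quadratic potential f(x) = |x|²/2 with γ = u = 1, the stationary density is proportional to exp(-f(x) - |v|²/(2u)), so after a burn-in both x and v should have variance close to 1 (up to discretisation and sampling error).

```python
import numpy as np


def simulate_uld_euler_maruyama(grad_f, x0, v0, gamma, u, h, n_steps, rng):
    """Euler-Maruyama for dx = v dt, dv = -gamma v dt - u grad_f(x) dt + sqrt(2 gamma u) dw.

    x0, v0: arrays of shape (n_particles, d); gamma, u: diagonal entries, shape (d,).
    """
    x, v = x0.copy(), v0.copy()
    noise_scale = np.sqrt(2.0 * gamma * u * h)  # sqrt(2 gamma u) times sqrt(h)
    for _ in range(n_steps):
        dw_term = noise_scale * rng.standard_normal(v.shape)
        x_next = x + v * h
        v = v - gamma * v * h - u * grad_f(x) * h + dw_term
        x = x_next
    return x, v


rng = np.random.default_rng(0)
n_particles, d = 5000, 1
x, v = simulate_uld_euler_maruyama(
    grad_f=lambda x: x,  # potential f(x) = |x|^2 / 2
    x0=np.zeros((n_particles, d)),
    v0=np.zeros((n_particles, d)),
    gamma=np.ones(d),
    u=np.ones(d),
    h=0.01,
    n_steps=2000,
    rng=rng,
)
print(x.var(), v.var())
```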