diffrax.diffeqsolve(terms: PyTree[AbstractTerm], solver: AbstractSolver, t0: Scalar, t1: Scalar, dt0: Optional[Scalar], y0: PyTree, args: Optional[PyTree] = None, *, saveat: SaveAt = SaveAt(t0=False,t1=True,ts=None,steps=False,dense=False,solver_state=False,controller_state=False,made_jump=False), stepsize_controller: AbstractStepSizeController = ConstantStepSize(compile_steps=False), adjoint: AbstractAdjoint = RecursiveCheckpointAdjoint(), discrete_terminating_event: Optional[AbstractDiscreteTerminatingEvent] = None, max_steps: Optional[int] = 4096, throw: bool = True, solver_state: Optional[PyTree] = None, controller_state: Optional[PyTree] = None, made_jump: Optional[bool] = None) -> Solution
Solves a differential equation.
This function is the main entry point for solving all kinds of initial value problems, whether they are ODEs, SDEs, or CDEs.
The differential equation is integrated from t0 to t1.
See the Getting started page for example usage.
These are the arguments most commonly used day-to-day.
terms: The terms of the differential equation. This specifies the vector field. (For non-ordinary differential equations (SDEs, CDEs), this also specifies the Brownian motion or the control.)
solver: The solver for the differential equation. See the guide on how to choose a solver.
t0: The start of the region of integration.
t1: The end of the region of integration.
dt0: The step size to use for the first step. If using fixed step sizes then this will also be the step size for all other steps. (Except the last one, which may be slightly smaller and clipped to t1.) If set as None then the initial step size will be determined automatically if possible. This requires an adaptive step size controller; ConstantStepSize needs an explicit dt0.
y0: The initial value. This can be any PyTree of JAX arrays. (Or types that can be coerced to JAX arrays, like Python floats.)
args: Any additional arguments to pass to the vector field.
saveat: What times to save the solution of the differential equation. See diffrax.SaveAt. Defaults to just the last time t1. (Keyword-only argument.)
stepsize_controller: How to change the step size as the integration progresses. See the list of stepsize controllers. Defaults to using a fixed constant step size. (Keyword-only argument.)
These arguments are infrequently used, and for most purposes you shouldn't need to understand these. All of these are keyword-only arguments.
adjoint: How to backpropagate through (and compute forward-mode autoderivatives of) diffeqsolve. Defaults to discretise-then-optimise with recursive checkpointing, which is the best option for most problems. See the page on Adjoints for more information.
discrete_terminating_event: A discrete event at which to terminate the solve early. See the page on Events for more information.
max_steps: The maximum number of steps to take before quitting the computation unconditionally.
Can also be set to None to allow an arbitrary number of steps, although this is incompatible with adjoint=RecursiveCheckpointAdjoint() (the default) and is incompatible with
Note that (a) compile times and (b) backpropagation run times will increase as max_steps increases (specifically, each time max_steps passes a power of 16). You can reduce these times by using the smallest value of max_steps that is reasonable for your problem.
throw: Whether to raise an exception if the integration fails for any reason.
If True then an integration failure will either raise a ValueError (when not using jax.jit) or print a warning message (when using jax.jit).
If False then the returned solution object will have a result field indicating whether any failures occurred.
Possible failures include, for example, hitting max_steps or the problem becoming too stiff to integrate. (For most purposes these failures are unusual.)
Note that when jax.vmap-ing a differential equation solve, throw=True means that an exception will be raised if any batch element fails. You may prefer to set throw=False and inspect the result field of the returned solution object, to determine which batch elements succeeded and which failed.
solver_state: Some initial state for the solver. Generally obtained by SaveAt(solver_state=True) from a previous solve.
controller_state: Some initial state for the step size controller. Generally obtained by SaveAt(controller_state=True) from a previous solve.
made_jump: Whether a jump has just been made at t0. Used to update solver_state (if passed). Generally obtained by SaveAt(made_jump=True) from a previous solve.
Returns a diffrax.Solution object specifying the solution to the differential equation.
Raises ValueError for bad inputs.
Raises an exception if throw=True and the integration fails (e.g. hitting the maximum number of steps).
It is possible to have t1 < t0, in which case integration proceeds backwards in time.