
diffeqsolve

diffrax.diffeqsolve(terms: PyTree[AbstractTerm], solver: AbstractSolver, t0: Scalar, t1: Scalar, dt0: Optional[Scalar], y0: PyTree, args: Optional[PyTree] = None, *, saveat: SaveAt = SaveAt(t0=False,t1=True,ts=None,steps=False,dense=False,solver_state=False,controller_state=False,made_jump=False), stepsize_controller: AbstractStepSizeController = ConstantStepSize(compile_steps=False), adjoint: AbstractAdjoint = RecursiveCheckpointAdjoint(), discrete_terminating_event: Optional[AbstractDiscreteTerminatingEvent] = None, max_steps: Optional[int] = 4096, throw: bool = True, solver_state: Optional[PyTree] = None, controller_state: Optional[PyTree] = None, made_jump: Optional[bool] = None) -> Solution

Solves a differential equation.

This function is the main entry point for solving all kinds of initial value problems, whether they are ODEs, SDEs, or CDEs.

The differential equation is integrated from t0 to t1.

See the Getting started page for example usage.

Main arguments:

These are the arguments most commonly used day-to-day.

  • terms: The terms of the differential equation. This specifies the vector field. (For non-ordinary differential equations (SDEs, CDEs), this also specifies the Brownian motion or the control.)
  • solver: The solver for the differential equation. See the guide on how to choose a solver.
  • t0: The start of the region of integration.
  • t1: The end of the region of integration.
  • dt0: The step size to use for the first step. If using fixed step sizes then this will also be the step size for all other steps. (Except the last one, which may be slightly smaller and clipped to t1.) If set to None then the initial step size will be determined automatically if possible.
  • y0: The initial value. This can be any PyTree of JAX arrays. (Or types that can be coerced to JAX arrays, like Python floats.)
  • args: Any additional arguments to pass to the vector field.
  • saveat: At which times to save the solution of the differential equation. See diffrax.SaveAt. Defaults to saving only at the final time t1. (Keyword-only argument.)
  • stepsize_controller: How to change the step size as the integration progresses. See the list of stepsize controllers. Defaults to using a fixed constant step size. (Keyword-only argument.)

Other arguments:

These arguments are infrequently used, and for most purposes you shouldn't need to understand these. All of these are keyword-only arguments.

  • adjoint: How to differentiate through diffeqsolve, both for backpropagation and for forward-mode autodifferentiation. Defaults to discretise-then-optimise with recursive checkpointing, which is usually the best option. See the page on Adjoints for more information.

  • discrete_terminating_event: A discrete event at which to terminate the solve early. See the page on Events for more information.

  • max_steps: The maximum number of steps to take before quitting the computation unconditionally.

    Can also be set to None to allow an arbitrary number of steps. This is incompatible with adjoint=RecursiveCheckpointAdjoint() (the default), and with saveat=SaveAt(steps=True) or saveat=SaveAt(dense=True).

    Both compile times and backpropagation run times increase as max_steps increases (specifically, each time max_steps passes a power of 16). You can reduce these times by using the smallest value of max_steps that is reasonable for your problem.

  • throw: Whether to raise an exception if the integration fails for any reason.

    If True then an integration failure will either raise a RuntimeError (when not using jax.jit) or print a warning message (when using jax.jit).

    If False then the returned solution object will have a result field indicating whether any failures occurred.

    Possible failures include, for example, hitting max_steps or the problem becoming too stiff to integrate. (For most purposes these failures are unusual.)

    Note

    When jax.vmap-ing a differential equation solve, throw=True means that an exception will be raised if any batch element fails. You may prefer to set throw=False and inspect the result field of the returned solution object to determine which batch elements succeeded and which failed.

  • solver_state: Some initial state for the solver. Generally obtained by SaveAt(solver_state=True) from a previous solve.

  • controller_state: Some initial state for the step size controller. Generally obtained by SaveAt(controller_state=True) from a previous solve.

  • made_jump: Whether a jump has just been made at t0. Used to update solver_state (if passed). Generally obtained by SaveAt(made_jump=True) from a previous solve.

Returns:

Returns a diffrax.Solution object specifying the solution to the differential equation.

Raises:

  • ValueError for bad inputs.
  • RuntimeError if throw=True, not under jax.jit, and the integration fails (e.g. hitting the maximum number of steps).

Note

It is possible to have t1 < t0, in which case integration proceeds backwards in time.