Solution

diffrax.Solution (AbstractPath)

The solution to a differential equation.

Attributes:

  • t0: The start of the interval that the differential equation was solved over.
  • t1: The end of the interval that the differential equation was solved over.
  • ts: Some ordered collection of times. Might be None if no values were saved (i.e. if just diffeqsolve(..., saveat=SaveAt(dense=True)) is used).
  • ys: The value of the solution at each of the times in ts. Might be None if no values were saved.
  • stats: Statistics for the solve (number of steps etc.).
  • result: A diffrax.RESULTS specifying the success or cause of failure of the solve. A human-readable message is displayed if printed. No message means success!
  • solver_state: If saved, the final internal state of the numerical solver.
  • controller_state: If saved, the final internal state for the step size controller.
  • made_jump: If saved, the final internal state for the jump tracker.
  • event_mask: If using events, a boolean mask indicating which event triggered. This is a PyTree of bools, with the same PyTree structure as the event condition functions. It will be all False if no events triggered; otherwise it will have precisely one True, corresponding to the event that triggered.

Note

If diffeqsolve(..., saveat=SaveAt(steps=True)) is used, then the ts and ys in the solution object will be padded with NaNs, out to the value of max_steps passed to diffrax.diffeqsolve.

This is because JAX demands that shapes be known statically ahead-of-time. As we do not know how many steps we will take until the solve is performed, we must allocate enough space for the maximum possible number of steps.

t0: Union[float, int] dataclass-field
t1: Union[float, int] dataclass-field
ts: Optional[PyTree[Real[Array, '?times'], "S"]] dataclass-field
ys: Optional[PyTree[Shaped[Array, '?times ?*shape'], "S ..."]] dataclass-field
stats: dict[str, Any] dataclass-field
result: RESULTS dataclass-field
solver_state: Optional[PyTree] dataclass-field
controller_state: Optional[PyTree] dataclass-field
made_jump: Optional[bool] dataclass-field

evaluate(self, t0: Union[float, int], t1: Union[float, int] = None, left: bool = True) -> PyTree[Shaped[Array, '?*shape'], "Y"]

If dense output was saved, then evaluate the solution at any point in the region of integration self.t0 to self.t1.

Arguments:

  • t0: The point to evaluate the solution at.
  • t1: If passed, then the increment from t0 to t1 is returned, i.e. evaluate(t1) - evaluate(t0).
  • left: When evaluating at a jump in the solution, whether to return the left-limit or the right-limit at that point.
derivative(self, t: Union[float, int], left: bool = True) -> PyTree[Shaped[Array, '?*shape'], "Y"]

If dense output was saved, then calculate an approximation to the derivative of the solution at any point in the region of integration self.t0 to self.t1.

That is, letting \(y\) denote the solution over the interval [t0, t1], this calculates an approximation to \(\frac{\mathrm{d}y}{\mathrm{d}t}\).

(This is not backpropagating through the differential equation -- that typically corresponds to e.g. \(\frac{\mathrm{d}y(t_1)}{\mathrm{d}y(t_0)}\).)

Example

For an ODE satisfying

\(\frac{\mathrm{d}y}{\mathrm{d}t} = f(t, y(t))\)

then this value is approximately equal to \(f(t, y(t))\).

Warning

This value is generally not very accurate. Differential equation solvers are usually designed to produce splines whose value is close to the true solution, not splines whose derivative is close to the derivative of the true solution.

If you need accurate derivatives for the solution of an ODE, it is usually best to calculate vector_field(t, sol.evaluate(t), args). That is, to pay the extra computational cost of another vector field evaluation, in order to get a more accurate value.

Put precisely: this derivative method returns the derivative of the numerical solution, and not an approximation to the derivative of the true solution.

Arguments:

  • t: The point to calculate the derivative of the solution at.
  • left: When evaluating at a jump in the solution, whether to return the left-limit or the right-limit at that point.

diffrax.RESULTS

An enumeration, with the following entries:

  • successful

  • max_steps_reached: The maximum number of solver steps was reached. Try increasing max_steps.

  • singular: The linear solver returned non-finite (NaN or inf) output. This usually means that the operator was not well-posed, and that the solver does not support this.

    If you are trying to solve a linear least-squares problem then you should pass solver=AutoLinearSolver(well_posed=False). By default lineax.linear_solve assumes that the operator is square and nonsingular.

    If you were expecting this solver to work with this operator, then it may be because:

    (a) the operator is singular, and your code has a bug; or

    (b) the operator was nearly singular (i.e. it had a high condition number: jnp.linalg.cond(operator.as_matrix()) is large), and the solver suffered from numerical instability issues; or

    (c) the operator is declared to exhibit a certain property (e.g. positive definiteness) that it does not actually satisfy.

  • breakdown: A form of iterative breakdown has occurred in the linear solve. Try using a different solver for this problem, or increasing restart if using GMRES.

  • stagnation: A stagnation in an iterative linear solve has occurred. Try increasing stagnation_iters or restart.

  • nonlinear_max_steps_reached: The maximum number of steps was reached in the nonlinear solver. The problem may not be solvable (e.g., a root-find on a function that has no roots), or you may need to increase max_steps.

  • nonlinear_divergence: Nonlinear solve diverged.

  • dt_min_reached: The minimum step size was reached in the differential equation solver.

  • event_occurred: Terminating differential equation solve because an event occurred.