Solution

diffrax.Solution (AbstractPath)
The solution to a differential equation.
Attributes:

- t0: The start of the interval that the differential equation was solved over.
- t1: The end of the interval that the differential equation was solved over.
- ts: Some ordered collection of times. Might be None if no values were saved. (i.e. just diffeqsolve(..., saveat=SaveAt(dense=True)) is used.)
- ys: The value of the solution at each of the times in ts. Might be None if no values were saved.
- stats: Statistics for the solve (number of steps etc.).
- result: A diffrax.RESULTS specifying the success or cause of failure of the solve. A human-readable message is displayed if printed. No message means success!
- solver_state: If saved, the final internal state of the numerical solver.
- controller_state: If saved, the final internal state for the step size controller.
- made_jump: If saved, the final internal state for the jump tracker.
- event_mask: If using events, a boolean mask indicating which event triggered. This is a PyTree of bools, with the same PyTree structure as the event condition functions. It will be all False if no events triggered; otherwise it will have precisely one True, corresponding to the event that triggered.
Note

If diffeqsolve(..., saveat=SaveAt(steps=True)) is set, then the ts and ys in the solution object will be padded with NaNs, out to the value of max_steps passed to diffrax.diffeqsolve.

This is because JAX demands that shapes be known statically ahead-of-time. As we do not know how many steps we will take until the solve is performed, we must allocate enough space for the maximum possible number of steps.
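As a quick illustration of the attributes described above, here is a minimal sketch of a solve and of inspecting the resulting Solution. The particular vector field, solver, and SaveAt choices are illustrative assumptions, not part of the API reference.

```python
import diffrax
import jax.numpy as jnp

def vector_field(t, y, args):
    return -y  # dy/dt = -y: simple exponential decay (illustrative choice)

term = diffrax.ODETerm(vector_field)
solver = diffrax.Tsit5()
saveat = diffrax.SaveAt(ts=jnp.linspace(0.0, 1.0, 11))

sol = diffrax.diffeqsolve(
    term, solver, t0=0.0, t1=1.0, dt0=0.1, y0=jnp.array(1.0), saveat=saveat
)

print(sol.t0, sol.t1)  # interval the equation was solved over
print(sol.ts.shape)    # (11,): the saved times
print(sol.ys.shape)    # (11,): the solution value at each saved time
print(sol.stats)       # statistics, e.g. number of steps taken
print(sol.result)      # success or cause of failure; no message means success
```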
Dataclass fields:

- t0: Union[float, int]
- t1: Union[float, int]
- ts: Optional[PyTree[Real[Array, '?times'], "S"]]
- ys: Optional[PyTree[Shaped[Array, '?times ?*shape'], "S ..."]]
- stats: dict[str, Any]
- result: RESULTS
- solver_state: Optional[PyTree]
- controller_state: Optional[PyTree]
- made_jump: Optional[bool]
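If the final solver and controller states are saved (via SaveAt(solver_state=True, controller_state=True)), they can be fed back into a later diffeqsolve call. A minimal sketch of this warm-restart pattern follows; the pattern itself is an assumption about typical usage rather than something stated above.

```python
import diffrax
import jax.numpy as jnp

term = diffrax.ODETerm(lambda t, y, args: -y)  # illustrative vector field
solver = diffrax.Tsit5()
saveat = diffrax.SaveAt(t1=True, solver_state=True, controller_state=True)

sol1 = diffrax.diffeqsolve(
    term, solver, t0=0.0, t1=1.0, dt0=0.1, y0=jnp.array(1.0), saveat=saveat
)

# Continue integrating from t=1 to t=2, reusing the saved internal state
# rather than letting the solver re-initialise it.
sol2 = diffrax.diffeqsolve(
    term,
    solver,
    t0=1.0,
    t1=2.0,
    dt0=0.1,
    y0=sol1.ys[-1],
    solver_state=sol1.solver_state,
    controller_state=sol1.controller_state,
)
```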
evaluate(self, t0: Union[float, int], t1: Optional[Union[float, int]] = None, left: bool = True) -> PyTree[Shaped[Array, '?*shape'], "Y"]
If dense output was saved, then evaluate the solution at any point in the region of integration self.t0 to self.t1.

Arguments:

- t0: The point to evaluate the solution at.
- t1: If passed, then the increment from t0 to t1 is returned. (= evaluate(t1) - evaluate(t0))
- left: When evaluating at a jump in the solution, whether to return the left-limit or the right-limit at that point.
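A minimal sketch of using dense output, assuming the same toy ODE as in the earlier sketch; saving with SaveAt(dense=True) is what makes evaluate available.

```python
import diffrax
import jax.numpy as jnp

term = diffrax.ODETerm(lambda t, y, args: -y)  # illustrative vector field
sol = diffrax.diffeqsolve(
    term,
    diffrax.Tsit5(),
    t0=0.0,
    t1=1.0,
    dt0=0.1,
    y0=jnp.array(1.0),
    saveat=diffrax.SaveAt(dense=True),  # save the dense interpolation
)

y_half = sol.evaluate(0.5)            # solution at t = 0.5
increment = sol.evaluate(0.25, 0.75)  # = sol.evaluate(0.75) - sol.evaluate(0.25)
```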
derivative(self, t: Union[float, int], left: bool = True) -> PyTree[Shaped[Array, '?*shape'], "Y"]
If dense output was saved, then calculate an approximation to the derivative of the solution at any point in the region of integration self.t0 to self.t1.

That is, letting \(y\) denote the solution over the interval [t0, t1], then this calculates an approximation to \(\frac{\mathrm{d}y}{\mathrm{d}t}\).
(This is not backpropagating through the differential equation -- that typically corresponds to e.g. \(\frac{\mathrm{d}y(t_1)}{\mathrm{d}y(t_0)}\).)
Example
For an ODE satisfying
\(\frac{\mathrm{d}y}{\mathrm{d}t} = f(t, y(t))\)
then this value is approximately equal to \(f(t, y(t))\).
Warning
This value is generally not very accurate. Differential equation solvers are usually designed to produce splines whose value is close to the true solution; not to produce splines whose derivative is close to the derivative of the true solution.
If you need accurate derivatives for the solution of an ODE, it is usually best to calculate vector_field(t, sol.evaluate(t), args). That is, to pay the extra computational cost of another vector field evaluation, in order to get a more accurate value.

Put precisely: this derivative method returns the derivative of the numerical solution, and not an approximation to the derivative of the true solution.
Arguments:

- t: The point to calculate the derivative of the solution at.
- left: When evaluating at a jump in the solution, whether to return the left-limit or the right-limit at that point.
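A minimal sketch contrasting the two approaches discussed in the warning above: the spline derivative returned by this method, versus the extra vector-field evaluation that is usually more accurate. The toy ODE is an illustrative assumption.

```python
import diffrax
import jax.numpy as jnp

def vector_field(t, y, args):
    return -y  # illustrative vector field

term = diffrax.ODETerm(vector_field)
sol = diffrax.diffeqsolve(
    term,
    diffrax.Tsit5(),
    t0=0.0,
    t1=1.0,
    dt0=0.1,
    y0=jnp.array(1.0),
    saveat=diffrax.SaveAt(dense=True),
)

t = 0.5
dy_spline = sol.derivative(t)                         # derivative of the numerical spline
dy_accurate = vector_field(t, sol.evaluate(t), None)  # extra vector field evaluation
```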
diffrax.RESULTS
An enumeration, with the following entries:
- successful
- max_steps_reached: The maximum number of solver steps was reached. Try increasing max_steps.
- singular: The linear solver returned non-finite (NaN or inf) output. This usually means that the operator was not well-posed, and that the solver does not support this. If you are trying to solve a linear least-squares problem then you should pass solver=AutoLinearSolver(well_posed=False). By default lineax.linear_solve assumes that the operator is square and nonsingular. If you were expecting this solver to work with this operator, then it may be because: (a) the operator is singular, and your code has a bug; or (b) the operator was nearly singular (i.e. it had a high condition number: jnp.linalg.cond(operator.as_matrix()) is large), and the solver suffered from numerical instability issues; or (c) the operator is declared to exhibit a certain property (e.g. positive definiteness) that it does not actually satisfy.
- breakdown: A form of iterative breakdown has occurred in the linear solve. Try using a different solver for this problem, or increase restart if using GMRES.
- stagnation: A stagnation in an iterative linear solve has occurred. Try increasing stagnation_iters or restart.
- nonlinear_max_steps_reached: The maximum number of steps was reached in the nonlinear solver. The problem may not be solvable (e.g. a root-find on a function that has no roots), or you may need to increase max_steps.
- nonlinear_divergence: Nonlinear solve diverged.
- dt_min_reached: The minimum step size was reached in the differential equation solver.
- event_occurred: Terminating differential equation solve because an event occurred.
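A minimal sketch of inspecting the result of a solve. Passing throw=False (so that a failed solve does not raise) and the deliberately small max_steps are illustrative assumptions; comparing against diffrax.RESULTS.successful is assumed to produce a boolean (possibly as a JAX array).

```python
import diffrax
import jax.numpy as jnp

term = diffrax.ODETerm(lambda t, y, args: -y)  # illustrative vector field
sol = diffrax.diffeqsolve(
    term,
    diffrax.Tsit5(),
    t0=0.0,
    t1=1.0,
    dt0=0.1,
    y0=jnp.array(1.0),
    max_steps=5,   # deliberately too few steps, so this solve fails
    throw=False,   # do not raise; inspect sol.result instead
)

ok = sol.result == diffrax.RESULTS.successful
print(ok)          # False here: the solve ran out of steps
print(sol.result)  # human-readable message, here for max_steps_reached
```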