Root finders
Some differential equation solvers, in particular implicit solvers, have to solve an implicit root-finding problem at every step. Such solvers therefore rely on a particular choice of root-finding subroutine, which is passed as a root_finder argument.
Optimistix is the JAX library for solving root-finding problems, so all root finders are subclasses of optimistix.AbstractRootFinder.
Here's a quick example:
Example
import diffrax as dfx
import optimistix as optx
root_finder = optx.Newton(rtol=1e-8, atol=1e-8)
solver = dfx.Kvaerno5(root_finder=root_finder)
dfx.diffeqsolve(..., solver, ...)
In addition to the root finders provided by Optimistix, Diffrax provides some differential-equation-specific functionality, namely diffrax.VeryChord and diffrax.with_stepsize_controller_tols. The former is a variation of the chord method that is slightly more efficient for most differential equation solvers. The latter sets the convergence tolerances of the root-finding algorithm to whatever tolerances were used with the adaptive stepsize controller (i.e. diffeqsolve(..., stepsize_controller=diffrax.PIDController(rtol=..., atol=...))).
As such, the default root-finding algorithm for most solvers in Diffrax is with_stepsize_controller_tols(VeryChord)().
diffrax.VeryChord
The Chord method of root finding.
As optimistix.Chord, except that in Runge-Kutta methods the linearisation point is recomputed per-step and not per-stage. (This is computationally cheaper.)
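The NumPy sketch below illustrates why per-step linearisation saves work. It is not Diffrax's implementation. For simplicity it uses a 2-stage SDIRK method (Alexander's order-2 method). Because every stage shares the same diagonal coefficient, one iteration matrix formed at the start of the step can drive the chord iterations of both stages. A per-stage variant would instead recompute the Jacobian before each stage. The test problem and the names GAMMA, f, jac_f and sdirk2_step are hypothetical.

```python
import numpy as np

# Diagonal coefficient of Alexander's 2-stage, order-2, stiffly accurate SDIRK
# method. Both stages share this coefficient.
GAMMA = 1.0 - 1.0 / np.sqrt(2.0)


def f(y):
    # Hypothetical test problem: y' = -y**2, exact solution y(t) = 1 / (1 + t) for y(0) = 1.
    return -(y**2)


def jac_f(y):
    # Jacobian of f.
    return np.diag(-2.0 * y)


def sdirk2_step(y0, h, tol=1e-12, max_iters=50):
    n = y0.shape[0]
    # Linearise once per step, around y0. Since both stages share GAMMA,
    # the same iteration matrix M = I - h*GAMMA*J(y0) serves every stage.
    M = np.eye(n) - h * GAMMA * jac_f(y0)

    def solve_stage(rhs, guess):
        # Chord iteration for Y - rhs - h*GAMMA*f(Y) = 0, with M held fixed.
        Y = guess
        for _ in range(max_iters):
            delta = np.linalg.solve(M, Y - rhs - h * GAMMA * f(Y))
            Y = Y - delta
            if np.max(np.abs(delta)) < tol:
                break
        return Y

    Y1 = solve_stage(y0, y0)
    Y2 = solve_stage(y0 + h * (1.0 - GAMMA) * f(Y1), Y1)
    return Y2  # stiffly accurate: the final stage is the step's output


y = np.array([1.0])
h = 0.01
for _ in range(100):
    y = sdirk2_step(y, h)
print(y[0])  # close to the exact value 1 / (1 + 1) = 0.5
```

A production solver would typically factorise M once per step and reuse the factorisation, rather than calling np.linalg.solve on every iteration.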
Advanced notes
In terms of how this matches the Optimistix API, this is done by supporting the option self.init(..., options=dict(init_state=...)), in which case it will directly return the provided state instead of computing it. This makes it possible to manually call self.init at an earlier point around the desired linearisation point.
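The init_state mechanism can be mimicked with a small, hypothetical chord solver in NumPy. ToyChord, fd_jacobian and the stage equations below are illustrative stand-ins, not the Optimistix API. The method init returns a caller-supplied state unchanged when options contains init_state. This means the linearisation can be computed once, up front, and then shared by later solves.

```python
import numpy as np


def fd_jacobian(fn, y, eps=1e-7):
    # Forward-difference Jacobian of fn at y (hypothetical helper).
    fy = fn(y)
    J = np.empty((fy.size, y.size))
    for j in range(y.size):
        dy = np.zeros_like(y)
        dy[j] = eps
        J[:, j] = (fn(y + dy) - fy) / eps
    return J


class ToyChord:
    """Hypothetical chord solver illustrating the init/init_state pattern.
    Not the Optimistix or Diffrax implementation."""

    def __init__(self, tol=1e-10, max_iters=50):
        self.tol = tol
        self.max_iters = max_iters
        self.num_linearisations = 0  # how many times a Jacobian was computed

    def init(self, fn, y, options):
        # Return a caller-supplied state unchanged; otherwise linearise fn at y.
        if "init_state" in options:
            return options["init_state"]
        self.num_linearisations += 1
        return fd_jacobian(fn, y)

    def solve(self, fn, y, options):
        J = self.init(fn, y, options)
        for _ in range(self.max_iters):
            delta = np.linalg.solve(J, fn(y))
            y = y - delta
            if np.max(np.abs(delta)) < self.tol:
                break
        return y


H, GAMMA = 0.1, 0.25  # illustrative step size and diagonal coefficient
y0 = np.array([1.0])


def f(y):
    # Illustrative right-hand side: y' = -y**3.
    return -(y**3)


def stage1(Y):
    return Y - y0 - H * GAMMA * f(Y)


solver = ToyChord()
# Linearise once, early, around y0. Both stage equations have Jacobian
# I - H*GAMMA*df/dy, so this one state can be shared by both solves.
state = solver.init(stage1, y0, {})
Y1 = solver.solve(stage1, y0, {"init_state": state})


def stage2(Y):
    return Y - (y0 + H * (1.0 - GAMMA) * f(Y1)) - H * GAMMA * f(Y)


Y2 = solver.solve(stage2, Y1, {"init_state": state})
print(solver.num_linearisations)  # 1
```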
__init__(self, rtol: float, atol: float, norm: Callable[[PyTree], Shaped[Array, '']] = <function max_norm>, kappa: float = 0.01, linear_solver: AbstractLinearSolver = AutoLinearSolver(well_posed=None))
Arguments:
- rtol: Relative tolerance for terminating the solve.
- atol: Absolute tolerance for terminating the solve.
- norm: The norm used to determine the difference between two iterates in the convergence criteria. Should be any function PyTree -> Scalar, for example optimistix.max_norm.
- kappa: A tolerance for the early convergence check.
- linear_solver: The linear solver used to compute the Newton step.
diffrax.with_stepsize_controller_tols(cls: type[optimistix._root_find.AbstractRootFinder])
Wraps a root finding class to indicate that it should use the same tolerances as were provided to an adaptive stepsize controller.
Arguments:
- cls: A subclass of optimistix.AbstractRootFinder.
Returns:
A wrapped version of cls that no longer accepts the atol, rtol or norm arguments, and will instead copy them from the adaptive stepsize controller.
Example
import diffrax as dfx
import optimistix as optx
root_finder = dfx.with_stepsize_controller_tols(optx.Chord)()
solver = dfx.Kvaerno5(root_finder=root_finder)
stepsize_controller = dfx.PIDController(rtol=1e-8, atol=1e-8)
dfx.diffeqsolve(..., solver=solver, stepsize_controller=stepsize_controller)