
Root finders

Some differential equation solvers (in particular, implicit solvers) have to solve an implicit root-finding problem at every step. These solvers therefore rely on a particular choice of root-finding subroutine, which is passed as the root_finder argument.

Optimistix is the JAX library for solving root-finding problems, so all root finders are subclasses of optimistix.AbstractRootFinder. Here's a quick example:

Example

import diffrax as dfx
import optimistix as optx

root_finder = optx.Newton(rtol=1e-8, atol=1e-8)
solver = dfx.Kvaerno5(root_finder=root_finder)
dfx.diffeqsolve(..., solver, ...)

In addition to the solvers provided by Optimistix, Diffrax provides some differential-equation-specific functionality, namely diffrax.VeryChord and diffrax.with_stepsize_controller_tols. The former is a variation of the chord method that is slightly more efficient for most differential equation solvers. The latter sets the convergence tolerances (and norm) of the root-finding algorithm to whatever was used with the adaptive stepsize controller (i.e. diffeqsolve(..., stepsize_controller=diffrax.PIDController(rtol=..., atol=...))).

As such, the default root-finding algorithm for most solvers in Diffrax is with_stepsize_controller_tols(VeryChord)().


diffrax.VeryChord

The Chord method of root finding.

Like optimistix.Chord, except that in Runge–Kutta methods the linearisation point is recomputed once per step rather than once per stage. (This is computationally cheaper, as only one linearisation has to be computed per step.)
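The idea can be illustrated with a toy chord iteration written in plain JAX. This is a hypothetical sketch, not Diffrax's implementation: the vector field F, the step size h, the diagonal coefficient gamma and the stage "bases" are all made up for illustration. The Jacobian of the stage residual is evaluated and LU-factorised once, at the start of the step, and that single factorisation is then reused to solve the root-finding problem of every stage.

```python
import jax
import jax.numpy as jnp
from jax.scipy.linalg import lu_factor, lu_solve

# Toy illustration only: this is not Diffrax's implementation of VeryChord.
h, gamma = 0.01, 0.25  # hypothetical step size and diagonal coefficient


def F(y):
    # Hypothetical stiff, mildly nonlinear vector field.
    return jnp.array([-50.0 * y[0] + y[1] ** 2, -y[1]])


def stage_residual(z, base):
    # Residual of one implicit stage: z = h * gamma * F(base + z).
    return z - h * gamma * F(base + z)


def chord(residual, z0, lu_and_piv, num_iters=10):
    # Chord iteration: every update reuses the same factorised Jacobian,
    # whereas Newton would re-evaluate the Jacobian at each new z.
    z = z0
    for _ in range(num_iters):
        z = z - lu_solve(lu_and_piv, residual(z))
    return z


y = jnp.array([1.0, 1.0])
# Linearise once per step, at the start of the step (base = y, z = 0).
jac = jax.jacfwd(lambda z: stage_residual(z, y))(jnp.zeros(2))
lu_and_piv = lu_factor(jac)

# Hypothetical stage bases; in a real Runge-Kutta step they depend on earlier stages.
bases = [y, y + 0.01, y + 0.02]
stages = [
    chord(lambda z, b=b: stage_residual(z, b), jnp.zeros(2), lu_and_piv)
    for b in bases
]
residual_norms = [
    float(jnp.max(jnp.abs(stage_residual(z, b)))) for z, b in zip(stages, bases)
]
print(residual_norms)
```

Since the stage bases stay close to the linearisation point, the frozen Jacobian is still a good approximation and each chord iteration converges quickly, while the Jacobian evaluation and factorisation are paid for only once.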

Advanced notes

Within the Optimistix API, this is implemented by supporting the option self.init(..., options=dict(init_state=...)), in which case init directly returns the provided state instead of computing it. This makes it possible to call self.init manually at an earlier point, around the desired linearisation point.

__init__(self, rtol: float, atol: float, norm: Callable[[PyTree], Shaped[Array, '']] = optimistix.max_norm, kappa: float = 0.01, linear_solver: AbstractLinearSolver = AutoLinearSolver(well_posed=None))

Arguments:

  • rtol: Relative tolerance for terminating the solve.
  • atol: Absolute tolerance for terminating the solve.
  • norm: The norm used to determine the difference between two iterates in the convergence criteria. Should be any function PyTree -> Scalar, for example optimistix.max_norm.
  • kappa: A tolerance for the early convergence check.
  • linear_solver: The linear solver used to compute the Newton step.

diffrax.with_stepsize_controller_tols(cls: type[optimistix.AbstractRootFinder])

Wraps a root-finding class to indicate that it should use the same tolerances as were provided to an adaptive stepsize controller.

Arguments:

  • cls: A subclass of optimistix.AbstractRootFinder.

Returns:

A wrapped version of cls that no longer accepts the atol, rtol or norm arguments, and instead copies them from the adaptive stepsize controller.

Example

import diffrax as dfx
import optimistix as optx

root_finder = dfx.with_stepsize_controller_tols(optx.Chord)()
solver = dfx.Kvaerno5(root_finder=root_finder)
stepsize_controller = dfx.PIDController(rtol=1e-8, atol=1e-8)

dfx.diffeqsolve(..., solver=solver, stepsize_controller=stepsize_controller)