# Nonlinear solvers
Some differential equation solvers -- in particular implicit solvers -- have to solve an implicit nonlinear problem at every step. Such differential equation solvers take an instance of a nonlinear solver as an argument.
## `diffrax.AbstractNonlinearSolver`

Abstract base class for all nonlinear root-finding algorithms.

Subclasses will be differentiable via the implicit function theorem.
### `__call__(self, fn: Callable, x: PyTree, args: PyTree, jac: Optional[~LU_Jacobian] = None) -> NonlinearSolution`

Find `z` such that `fn(z, args) = 0`.

Gradients will be computed with respect to `args`. (And in particular not with respect to either `fn` or `x` -- the latter has zero derivative by definition anyway.)
**Arguments:**

- `fn`: A function `PyTree -> PyTree` to find the root of. (With input and output PyTrees of the same structure.)
- `x`: An initial guess for the location of the root.
- `args`: Arbitrary PyTree parameterising `fn`.
- `jac`: As returned by `self.jac(...)`. Many root-finding algorithms use the Jacobian `d(fn)/dx` as part of their iteration. Often they will recompute a Jacobian at every step (for example this is done in the "standard" Newton solver). In practice computing the Jacobian may be expensive, and it may be enough to use a single value for the Jacobian held constant throughout the iteration. (This is a quasi-Newton method known as the chord method.) For the former behaviour, do not pass `jac`. To get the latter behaviour, do pass `jac`.
**Returns:**

A `NonlinearSolution` object, with attributes `root`, `num_steps`, `result`. `root` (hopefully) solves `fn(root, args) = 0`. `num_steps` is the number of steps taken in the nonlinear solver. `result` is a status code indicating whether the solver managed to converge or not.
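The distinction drawn above -- recomputing the Jacobian at every step (standard Newton) versus holding a single Jacobian fixed (the chord method) -- can be sketched for a scalar problem. This is a hypothetical plain-Python illustration, not diffrax's implementation, which operates on PyTrees and LU-decomposed Jacobians:

```python
def newton(fn, dfn, x, tol=1e-10, max_steps=50, chord=False):
    """Find a root of `fn`, with derivative `dfn`.

    If `chord=True`, the derivative is evaluated once at the initial guess
    and reused on every step (the chord method); otherwise it is recomputed
    at every step (standard Newton).
    """
    d = dfn(x)  # evaluated once up front; this is the only evaluation if chord=True
    for step in range(max_steps):
        fx = fn(x)
        if abs(fx) < tol:
            return x, step
        if not chord:
            d = dfn(x)  # standard Newton: fresh derivative each step
        x = x - fx / d
    return x, max_steps

# Example: find the root of f(z) = z**2 - 2, i.e. sqrt(2).
root_newton, steps_newton = newton(lambda z: z * z - 2, lambda z: 2 * z, 1.0)
root_chord, steps_chord = newton(lambda z: z * z - 2, lambda z: 2 * z, 1.0, chord=True)
```

Both variants find the root; the chord method trades derivative evaluations for a larger number of (cheaper) iterations, which is the trade-off the `jac` argument exposes.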
### `jac(fn: Callable, x: PyTree, args: PyTree) -> ~LU_Jacobian` *(staticmethod)*

Computes the LU decomposition of the Jacobian `d(fn)/dx`.

Arguments as `diffrax.AbstractNonlinearSolver.__call__`.
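The reason for returning an LU decomposition rather than the raw Jacobian is that each Newton step solves a linear system against the same matrix: factorise once, then every step needs only cheap triangular solves. A minimal illustration for the 2x2 case (plain Python, no pivoting, purely for exposition -- a real implementation would use a pivoted LU routine from a linear algebra library):

```python
def lu_2x2(a):
    """Doolittle LU factorisation of a 2x2 matrix, stored compactly.

    Returns [[u11, u12], [l21, u22]], i.e. U's rows with L's sole
    below-diagonal entry packed into the lower-left slot.
    """
    (a11, a12), (a21, a22) = a
    l21 = a21 / a11
    return [[a11, a12], [l21, a22 - l21 * a12]]

def lu_solve_2x2(lu, b):
    """Solve A x = b given the compact LU factorisation of A."""
    (u11, u12), (l21, u22) = lu
    y1 = b[0]
    y2 = b[1] - l21 * y1        # forward substitution through L
    x2 = y2 / u22               # back substitution through U
    x1 = (y1 - u12 * x2) / u11
    return [x1, x2]

# Factorise once...
lu = lu_2x2([[2.0, 1.0], [1.0, 3.0]])
# ...then reuse for as many right-hand sides as the iteration produces.
x = lu_solve_2x2(lu, [5.0, 10.0])   # solves 2x + y = 5, x + 3y = 10
```

This is exactly the pattern the chord method exploits: one factorisation amortised over many iterations.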
## `diffrax.NewtonNonlinearSolver(AbstractNonlinearSolver)`

Newton's method for root-finding. (Also known as Newton--Raphson.)

Also supports the quasi-Newton chord method.
!!! info

    If using this as part of an implicit ODE solver, then:

    - An adaptive step size controller should be used (e.g. `diffrax.PIDController`). This will allow smaller steps to be made if the nonlinear solver fails to converge.
    - As a general rule, the values for `rtol` and `atol` should be set to the same values as used for the adaptive step size controller. (And this will happen automatically by default.)
    - The value for `kappa` should usually be left alone.
!!! warning

    Note that backpropagation through `__call__` may not produce accurate values if `tolerate_nonconvergence=True`, as the backpropagation calculation implicitly assumes that the forward pass converged.
### `__init__(self, rtol: Optional[Scalar] = None, atol: Optional[Scalar] = None, max_steps: Optional[int] = 10, kappa: Scalar = 0.01, norm: Callable = <function rms_norm>, tolerate_nonconvergence: bool = False)`
**Arguments:**

- `rtol`: The relative tolerance for determining convergence. Defaults to the same `rtol` as passed to an adaptive step controller if one is used.
- `atol`: The absolute tolerance for determining convergence. Defaults to the same `atol` as passed to an adaptive step controller if one is used.
- `max_steps`: The maximum number of steps allowed. If more than this are required then the iteration fails. Set to `None` to allow an arbitrary number of steps.
- `kappa`: The kappa value for determining convergence.
- `norm`: A function `PyTree -> Scalar`, which is called to determine the size of the current value. (Used in determining convergence.)
- `tolerate_nonconvergence`: Whether to return an error code if the iteration fails to converge (or to silently pretend it was successful).
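The `rtol`, `atol`, `kappa`, and `norm` arguments combine into a single stopping criterion: each component of the Newton update is scaled by its per-component tolerance, the `norm` reduces that to a scalar, and the result is compared against `kappa`. The sketch below (plain Python, over flat lists rather than PyTrees) shows one plausible form of such a criterion; the exact formula diffrax uses may differ:

```python
def rms_norm(xs):
    """Root-mean-square norm, the default `norm` in the signature above."""
    return (sum(v * v for v in xs) / len(xs)) ** 0.5

def converged(step_diff, x, rtol, atol, kappa, norm=rms_norm):
    """Hypothetical tolerance-based convergence check.

    `step_diff` is the latest Newton update and `x` the current iterate.
    Each component of the update is scaled by `atol + rtol * |x_i|`, so the
    check is relative for large values and absolute near zero; convergence
    is declared when the scaled update is smaller than `kappa`.
    """
    scaled = [d / (atol + rtol * abs(xi)) for d, xi in zip(step_diff, x)]
    return norm(scaled) < kappa
```

This also makes the info box above concrete: using the same `rtol`/`atol` here as in the step size controller means the nonlinear solve is resolved to the same accuracy the controller demands of each step.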
### `__call__(self, fn: Callable, x: PyTree, args: PyTree, jac: Optional[~LU_Jacobian] = None) -> NonlinearSolution`

Inherited from `diffrax.AbstractNonlinearSolver.__call__`.

### `jac(fn: Callable, x: PyTree, args: PyTree) -> ~LU_Jacobian`

Inherited from `diffrax.AbstractNonlinearSolver.jac`.