
Nonlinear solvers

Some differential equation solvers -- in particular implicit solvers -- have to solve an implicit nonlinear problem at every step. Such differential equation solvers take an instance of a nonlinear solver as an argument.


diffrax.AbstractNonlinearSolver

Abstract base class for all nonlinear root-finding algorithms.

Subclasses will be differentiable via the implicit function theorem.

__call__(self, fn: Callable, x: PyTree, args: PyTree, jac: Optional[~LU_Jacobian] = None) -> NonlinearSolution

Find z such that fn(z, args) = 0.

Gradients will be computed with respect to args. (And in particular not with respect to either fn or x -- the latter has zero derivative by definition anyway.)


Arguments:

  • fn: A function PyTree -> PyTree to find the root of. (With input and output PyTrees of the same structure.)
  • x: An initial guess for the location of the root.
  • args: Arbitrary PyTree parameterising fn.
  • jac: As returned by self.jac(...). Many root finding algorithms use the Jacobian d(fn)/dx as part of their iteration. Often they will recompute a Jacobian at every step (for example this is done in the "standard" Newton solver). In practice computing the Jacobian may be expensive, and it may be enough to use a single value for the Jacobian held constant throughout the iteration. (This is a quasi-Newton method known as the chord method.) For the former behaviour, do not pass jac. To get the latter behaviour, do pass jac.


Returns:

A NonlinearSolution object, with attributes root, num_steps, and result. root (hopefully) solves fn(root, args) = 0; num_steps is the number of steps taken by the nonlinear solver; result is a status code indicating whether or not the solver converged.
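To make the fn(z, args) = 0 contract concrete, here is a minimal scalar Newton iteration in plain Python. This is an illustration only, not diffrax's implementation: the finite-difference Jacobian, the tolerance test, and the (root, num_steps, converged) return tuple are all simplifications of the PyTree-based NonlinearSolution API described above.

```python
def newton_root(fn, x, args, rtol=1e-8, atol=1e-8, max_steps=10):
    """Find z with fn(z, args) == 0; return (root, num_steps, converged)."""
    h = 1e-7  # finite-difference step for the scalar Jacobian d(fn)/dx
    for step in range(1, max_steps + 1):
        fx = fn(x, args)
        jac = (fn(x + h, args) - fx) / h  # recomputed every step ("standard" Newton)
        delta = fx / jac
        x = x - delta
        # Converged once the update is small relative to the current value.
        if abs(delta) < atol + rtol * abs(x):
            return x, step, True
    return x, max_steps, False

# Example: solve z**2 - c = 0 with args c = 2, i.e. find sqrt(2).
root, num_steps, ok = newton_root(lambda z, c: z * z - c, 1.0, 2.0)
```

Newton's method converges quadratically near the root, so only a handful of steps are needed here.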

jac(fn: Callable, x: PyTree, args: PyTree) -> ~LU_Jacobian (staticmethod)

Computes the LU decomposition of the Jacobian d(fn)/dx.

Arguments as diffrax.AbstractNonlinearSolver.__call__.
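The chord method mentioned above, in which a single precomputed Jacobian is held constant throughout the iteration, can be sketched in scalar form as follows. Again this is a hedged illustration: diffrax works with PyTrees and an LU-factorised Jacobian, whereas here the "Jacobian" is just one fixed number.

```python
def chord_root(fn, x, args, jac, rtol=1e-8, atol=1e-8, max_steps=50):
    """Quasi-Newton chord iteration: reuse one fixed Jacobian `jac`."""
    for step in range(1, max_steps + 1):
        delta = fn(x, args) / jac  # no Jacobian recomputation per step
        x = x - delta
        if abs(delta) < atol + rtol * abs(x):
            return x, step, True
    return x, max_steps, False

# Jacobian evaluated once at the initial guess x0 = 1.0:
# d/dz (z**2 - 2) = 2*z  ->  2.0
root, num_steps, ok = chord_root(lambda z, c: z * z - c, 1.0, 2.0, jac=2.0)
```

The trade-off is visible here: each step is cheaper (no Jacobian evaluation), but convergence is only linear rather than quadratic, so more steps are typically required.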



diffrax.NewtonNonlinearSolver (AbstractNonlinearSolver)

Newton's method for root-finding. (Also known as Newton--Raphson.)

Also supports the quasi-Newton chord method.


If using this as part of an implicit ODE solver, then:

  • An adaptive step size controller should be used (e.g. diffrax.PIDController). This will allow smaller steps to be made if the nonlinear solver fails to converge.
  • As a general rule, the values for rtol and atol should be set to the same values as used for the adaptive step size controller. (And this will happen automatically by default.)
  • The value for kappa should usually be left alone.


Note that backpropagation through __call__ may not produce accurate values if tolerate_nonconvergence=True, as the backpropagation calculation implicitly assumes that the forward pass converged.

__init__(self, rtol: Optional[Scalar] = None, atol: Optional[Scalar] = None, max_steps: Optional[int] = 10, kappa: Scalar = 0.01, norm: Callable = rms_norm, tolerate_nonconvergence: bool = False)


Arguments:

  • rtol: The relative tolerance for determining convergence. Defaults to the same rtol as passed to an adaptive step size controller if one is used.
  • atol: The absolute tolerance for determining convergence. Defaults to the same atol as passed to an adaptive step controller if one is used.
  • max_steps: The maximum number of steps allowed. If more than this are required then the iteration fails. Set to None to allow an arbitrary number of steps.
  • kappa: The kappa value for determining convergence.
  • norm: A function PyTree -> Scalar, which is called to determine the size of the current value. (Used in determining convergence.)
  • tolerate_nonconvergence: Whether to return an error code if the iteration fails to converge (or to silently pretend it was successful).
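As an illustration of the norm argument, a root-mean-square norm over a flat list of floats might look like the following. This is a simplified stand-in: diffrax's default rms_norm operates on arbitrary PyTrees, not plain lists.

```python
def rms_norm(values):
    """RMS norm of a flat list of floats: sqrt of the mean squared entry."""
    return (sum(v * v for v in values) / len(values)) ** 0.5

size = rms_norm([3.0, 4.0])  # sqrt((9 + 16) / 2) = sqrt(12.5)
```

The solver compares such a size against the tolerances to decide whether the iteration has converged.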
__call__(self, fn: Callable, x: PyTree, args: PyTree, jac: Optional[~LU_Jacobian] = None) -> NonlinearSolution
jac(fn: Callable, x: PyTree, args: PyTree) -> ~LU_Jacobian

As diffrax.AbstractNonlinearSolver.__call__ and diffrax.AbstractNonlinearSolver.jac above.