```python
import jax
import jax.numpy as jnp
import optimistix as optx

# Often important when doing scientific work
jax.config.update("jax_enable_x64", True)


def fn(y, args):
    a, b = y
    c = jnp.tanh(jnp.sum(b)) - a
    d = a**2 - jnp.sinh(b + 1)
    return c, d


solver = optx.Newton(rtol=1e-8, atol=1e-8)
y0 = (jnp.array(0.0), jnp.zeros((2, 2)))
sol = optx.root_find(fn, solver, y0)
```
This has the following solution, stored in `sol.value`:

```
(Array(-0.85650715, dtype=float64),
 Array([[-0.32002086, -0.32002086],
        [-0.32002086, -0.32002086]], dtype=float64))
```
This is indeed a root of `fn`, up to floating-point round-off. Evaluating `fn(sol.value, None)` gives:

```
(Array(0., dtype=float64),
 Array([[1.11022302e-16, 1.11022302e-16],
        [1.11022302e-16, 1.11022302e-16]], dtype=float64))
```
Especially on trickier or misspecified problems, the solve may fail. If Optimistix is unable to find a solution to a problem, it raises an error, like so:
```python
def does_not_have_root(y, _):
    # there is no value of y for which this equals zero.
    return 1 + y**2


y0 = jnp.array(1.0)
optx.root_find(does_not_have_root, solver, y0)
```
```
---------------------------------------------------------------------------
XlaRuntimeError                           Traceback (most recent call last)
Cell In, line 6
      3     return 1 + y**2
      5 y0 = jnp.array(1.)
----> 6 optx.root_find(does_not_have_root, solver, y0)

XlaRuntimeError: The linear solver returned non-finite (NaN or inf) output. This usually means that the
operator was not well-posed, and that the solver does not support this.

If you are trying solve a linear least-squares problem then you should pass
`solver=AutoLinearSolver(well_posed=False)`. By default `lineax.linear_solve`
assumes that the operator is square and nonsingular.

If you *were* expecting this solver to work with this operator, then it may be because:

(a) the operator is singular, and your code has a bug; or

(b) the operator was nearly singular (i.e. it had a high condition number:
    `jnp.linalg.cond(operator.as_matrix())` is large), and the solver suffered from
    numerical instability issues; or

(c) the operator is declared to exhibit a certain property (e.g. positive definiteness)
    that is does not actually satisfy.
-------
This error occurred during the runtime of your JAX program. Setting the environment
variable `EQX_ON_ERROR=breakpoint` is usually the most useful way to debug such errors.
(This can be navigated using most of the the usual commands for the Python debugger:
`u` and `d` to move through stack frames, the name of a variable to print its value,
etc.) See also `https://docs.kidger.site/equinox/api/errors/#equinox.error_if` for more
information.
```
If this happens, don't panic! It might be that your problem is misspecified (as it is here). Or it might be that you need to try a different solver: some solvers are designed to get ultra-fast convergence on relatively "nice" problems, but don't try to handle messier problems. See the guide on how to choose a solver for more details.
(For the advanced user: in this case the precise error message reflects the fact that the solver will have descended to $y = 0$, and then found that the Jacobian $\mathrm{d}\,\texttt{does\_not\_have\_root} / \mathrm{d}y$ is zero at that point, so it cannot solve the linear system described by the Newton step.)
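The failure can be reproduced by hand. The sketch below is a plain scalar Newton iteration written for illustration (it is not Optimistix's internal implementation), using `jax.grad` for the derivative $2y$. Starting from $y_0 = 1$, one step $y_1 = y_0 - f(y_0)/f'(y_0) = 1 - 2/2$ lands exactly on $y = 0$, where the derivative vanishes, so the following step divides by zero and produces a non-finite value.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)


def does_not_have_root(y, _):
    # There is no value of y for which this equals zero.
    return 1 + y**2


# d(does_not_have_root)/dy = 2y, computed with autodiff (w.r.t. the first argument).
dfdy = jax.grad(does_not_have_root)

y0 = jnp.array(1.0)

# First Newton step: y1 = 1 - f(1) / f'(1) = 1 - 2 / 2 = 0.
y1 = y0 - does_not_have_root(y0, None) / dfdy(y0, None)
print(y1)  # 0.0

# The derivative (the 1x1 "Jacobian") is zero at y1 ...
print(dfdy(y1, None))  # 0.0

# ... so the second Newton step divides 1 by 0 and the iterate becomes infinite.
y2 = y1 - does_not_have_root(y1, None) / dfdy(y1, None)
print(jnp.isinf(y2))  # True
```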
If you want to handle this error as part of your JAX program (instead of surfacing it as a Python exception), then you can pass `throw=False`, as in `sol = optx.root_find(..., throw=False)`, and then access `sol.result` to check the success or failure of the solve.