import jax
import jax.numpy as jnp
import optimistix as optx
# Often important when doing scientific work
jax.config.update("jax_enable_x64", True)
def fn(y, args):
    a, b = y
    c = jnp.tanh(jnp.sum(b)) - a
    d = a**2 - jnp.sinh(b + 1)
    return c, d
solver = optx.Newton(rtol=1e-8, atol=1e-8)
y0 = (jnp.array(0.0), jnp.zeros((2, 2)))
sol = optx.root_find(fn, solver, y0)
This has the following solution:
print(sol.value)
Which is indeed a root of fn:
print(fn(sol.value, args=None))
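For intuition, the basic Newton iteration behind a solver like optx.Newton can be sketched in plain JAX: flatten the pytree y into one vector, compute the Jacobian of fn with jax.jacfwd, and repeatedly solve J δ = −F(y). This is a minimal sketch under stated assumptions, not Optimistix's implementation: newton_sketch is a hypothetical helper, it runs a fixed number of steps instead of stopping on rtol/atol, and it uses a dense jnp.linalg.solve for the linear system.

```python
import jax
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree

jax.config.update("jax_enable_x64", True)


def fn(y, args):
    a, b = y
    c = jnp.tanh(jnp.sum(b)) - a
    d = a**2 - jnp.sinh(b + 1)
    return c, d


def newton_sketch(f, y0, args=None, num_steps=50):
    """Hypothetical helper: plain (undamped) Newton's method on a pytree-valued f."""
    # Flatten the input pytree (here a scalar plus a 2x2 array) into one vector.
    flat_y0, unravel = ravel_pytree(y0)

    def flat_f(flat_y):
        # Evaluate f on the unflattened input, then flatten its output as well.
        out, _ = ravel_pytree(f(unravel(flat_y), args))
        return out

    def step(flat_y, _):
        residual = flat_f(flat_y)
        jac = jax.jacfwd(flat_f)(flat_y)  # square 5x5 Jacobian for this fn
        # Newton step: solve J @ delta = -F(y), then take the full step.
        delta = jnp.linalg.solve(jac, -residual)
        return flat_y + delta, None

    # Fixed iteration count for simplicity (no rtol/atol stopping test).
    flat_y, _ = jax.lax.scan(step, flat_y0, None, length=num_steps)
    return unravel(flat_y)


y0 = (jnp.array(0.0), jnp.zeros((2, 2)))
y_root = newton_sketch(fn, y0)
print(y_root)
print(fn(y_root, None))
```

Flattening keeps the sketch generic over any pytree structure, at the cost of materialising the full Jacobian, which is only sensible for small problems like this one.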
Handling errors
Especially on trickier or misspecified problems, the optimisation may fail. If Optimistix is unable to find the solution to a problem, it will raise an error, like so:
def does_not_have_root(y, _):
    # There is no value of y for which this equals zero.
    return 1 + y**2
y0 = jnp.array(1.0)
optx.root_find(does_not_have_root, solver, y0)  # raises an error
If this happens, don't panic! It might be that your problem is misspecified (like here). Or it might be that you need to try a different solver: some solvers are designed to get ultra-fast convergence on relatively "nice" problems, but don't try to handle messier ones. See the "how to choose a solver" guide for more details.
(For the advanced user: in this case the precise error message reflects the fact that the solver will have descended to $y = 0$, and then found that the Jacobian of does_not_have_root, $\mathrm{d}(1 + y^2)/\mathrm{d}y = 2y$, is zero at that point, so it cannot solve the linear system described by the Newton step.)
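The arithmetic behind this is short. The scalar Newton update is $y_{n+1} = y_n - f(y_n)/f'(y_n)$; with $f(y) = 1 + y^2$ and $y_0 = 1$ it gives $y_1 = 1 - 2/2 = 0$, where $f'(y_1) = 2y_1 = 0$. The plain-JAX check below reproduces those steps; newton_step is a hypothetical helper for illustration, not part of Optimistix.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)


def does_not_have_root(y, _):
    # There is no value of y for which this equals zero.
    return 1 + y**2


# Derivative with respect to y (the first argument); here it equals 2 * y.
d_does_not_have_root = jax.grad(does_not_have_root)


def newton_step(y):
    """Hypothetical helper: one scalar Newton update y - f(y) / f'(y)."""
    return y - does_not_have_root(y, None) / d_does_not_have_root(y, None)


y0 = jnp.array(1.0)
y1 = newton_step(y0)  # 1 - 2 / 2 = 0
print(y1, d_does_not_have_root(y1, None))  # the derivative vanishes at y1
# A further step divides by a zero derivative and gives a non-finite value.
y2 = newton_step(y1)
print(y2)
```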
If you want to handle this error as part of your JAX program (instead of surfacing it as a Python exception), then you can call sol = optx.root_find(..., throw=False), and then access sol.result to check the success or failure of the solve.