Getting started

Optimistix is a JAX library for nonlinear solvers: root finding, minimisation, fixed points, and least squares.

Features include:

  • interoperable solvers: e.g. automatically convert root find problems to least squares problems, then solve using a minimisation algorithm.
  • modular optimisers: e.g. use a BFGS model function with a dogleg descent path and a trust region update.
  • using a PyTree as the state.
  • fast compilation and runtimes.
  • interoperability with Optax.
  • all the benefits of working with JAX: autodiff, autoparallelism, GPU/TPU support etc.


pip install optimistix

Requires Python 3.9+, JAX 0.4.14+, and Equinox 0.11.0+.

Quick example

import jax.numpy as jnp
import optimistix as optx

# Let's solve the ODE dy/dt=tanh(y(t)) with the implicit Euler method.
# We need to find y1 s.t. y1 = y0 + tanh(y1)dt.

y0 = jnp.array(1.)
dt = jnp.array(0.1)

def fn(y, args):
    return y0 + jnp.tanh(y) * dt

solver = optx.Newton(rtol=1e-5, atol=1e-5)
sol = optx.fixed_point(fn, solver, y0)
y1 = sol.value  # satisfies y1 == fn(y1, None), up to the solver tolerances

Next steps

JAX ecosystem

