The ability to create custom solvers is one of the most powerful things about Optimistix. This is a great playground for the advanced user.
## Approach 1: mix-and-match using existing APIs
Many abstract solvers have `search` and `descent` fields. The first is a choice of line search, trust region, or learning rate. (Optimistix uses a generalised notion of all three.) The second is a choice of what it means to "move downhill": for example this could be
`optimistix.SteepestDescent` to use a local linear approximation, or
`optimistix.NewtonDescent` to use a local quadratic approximation.
See the searches and descents page for more details on this idea.
Here's a quick demo of how to create a novel minimiser using this approach. This example uses a BFGS quasi-Newton approximation to the Hessian of a minimisation problem. This approximation is used to build a piecewise-linear dogleg-shaped descent path (interpolating between steepest descent and Newton descent). How far we move along this path is then determined by a trust-region algorithm.
```python
from collections.abc import Callable

import optimistix as optx


class MyNewMinimiser(optx.AbstractBFGS):
    rtol: float
    atol: float
    norm: Callable = optx.max_norm
    use_inverse: bool = False
    descent: optx.AbstractDescent = optx.DoglegDescent()
    search: optx.AbstractSearch = optx.ClassicalTrustRegion()


solver = MyNewMinimiser(rtol=1e-4, atol=1e-4)
```
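For reference, here is a minimal sketch of how such a solver might be used. The objective function and initial guess here are illustrative (a standard two-dimensional Rosenbrock problem), not part of the original example:

```python
import jax.numpy as jnp


def rosenbrock(y, args):
    # Classic test objective; its global minimum is at (1, 1).
    x1, x2 = y
    return (1.0 - x1) ** 2 + 100.0 * (x2 - x1**2) ** 2


y0 = jnp.array([2.0, 3.0])
sol = optx.minimise(rosenbrock, solver, y0)
# sol.value should be close to [1.0, 1.0].
```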
New descents can also be written from scratch, by subclassing `optimistix.AbstractDescent`. For example, here is an implementation of a Newton descent:

```python
import equinox as eqx  # https://github.com/patrick-kidger/equinox
import jax.tree_util as jtu
import lineax as lx  # https://github.com/google/lineax
from jaxtyping import Array, PyTree  # https://github.com/google/jaxtyping


class NewtonDescentState(eqx.Module):
    newton: PyTree[Array]
    result: optx.RESULTS


class NewtonDescent(optx.AbstractDescent):
    def init(self, y, f_info_struct):
        del f_info_struct
        # Dummy values of the right shape; unused.
        return NewtonDescentState(y, optx.RESULTS.successful)

    def query(self, y, f_info, state):
        del state
        if isinstance(f_info, optx.FunctionInfo.EvalGradHessianInv):
            # The solver maintains an approximation to the inverse Hessian, so the
            # Newton step is just a matrix-vector product.
            newton = f_info.hessian_inv.mv(f_info.grad)
            result = optx.RESULTS.successful
        else:
            if isinstance(f_info, optx.FunctionInfo.EvalGradHessian):
                operator = f_info.hessian
                vector = f_info.grad
            elif isinstance(f_info, optx.FunctionInfo.ResidualJac):
                operator = f_info.jac
                vector = f_info.residual
            else:
                raise ValueError(
                    "Cannot use a Newton descent with a solver that only evaluates the "
                    "gradient, or only the function itself."
                )
            # Solve the linear system `operator @ newton = vector` for the Newton step.
            out = lx.linear_solve(operator, vector)
            newton = out.value
            result = optx.RESULTS.promote(out.result)
        return NewtonDescentState(newton, result)

    def step(self, step_size, state):
        # Move distance `step_size` along the (downhill) Newton direction.
        return jtu.tree_map(lambda x: -step_size * x, state.newton), state.result
```
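As a rough sketch of how this descent might be combined with the mix-and-match pattern from above: the solver class name and the choice of `optx.BacktrackingArmijo` as the search here are illustrative assumptions, not taken from the original.

```python
class MyNewtonMinimiser(optx.AbstractBFGS):
    rtol: float
    atol: float
    norm: Callable = optx.max_norm
    # Build the Hessian itself (not its inverse), so `query` solves a linear system.
    use_inverse: bool = False
    descent: optx.AbstractDescent = NewtonDescent()
    search: optx.AbstractSearch = optx.BacktrackingArmijo()


newton_solver = MyNewtonMinimiser(rtol=1e-5, atol=1e-5)
sol = optx.minimise(rosenbrock, newton_solver, y0)
```

The only change from `MyNewMinimiser` is swapping in the custom descent (and a backtracking line search in place of the trust region), which is the point of the search/descent abstraction: the surrounding quasi-Newton machinery is reused unchanged.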