# Hybridising solvers
The ability to create custom solvers is one of the most powerful things about Optimistix. This is a great playground for the advanced user.
## Approach 1: mix-and-match using existing APIs.
Many abstract solvers have `search` and `descent` fields. The first is a choice of line search, trust region, or learning rate. (Optimistix uses a generalised notion of all three.) The second is a choice of what it means to "move downhill": for example, this could be `optimistix.SteepestDescent` to use a local linear approximation, or `optimistix.NewtonDescent` to use a local quadratic approximation.
See the searches and descents page for more details on this idea.
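To make this concrete, here is a small sketch of some of the searches and descents that Optimistix provides. Each search is an instance of `optx.AbstractSearch` and each descent is an instance of `optx.AbstractDescent`, so they can be swapped into a solver's `search` and `descent` fields.

```python
import optimistix as optx

# Three interchangeable "searches": each decides how far to move along a
# proposed direction. A fixed learning rate, a line search, and a trust region
# are all just instances of `optx.AbstractSearch`.
learning_rate = optx.LearningRate(0.1)
line_search = optx.BacktrackingArmijo()
trust_region = optx.ClassicalTrustRegion()

# Two "descents": each decides which direction counts as "downhill". Both are
# instances of `optx.AbstractDescent`.
steepest = optx.SteepestDescent()
newton = optx.NewtonDescent()
```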
Here's a quick demo of how to create a novel minimiser using this. This example uses a BFGS quasi-Newton approximation to the Hessian of a minimisation problem. This approximation is used to build a piecewise-linear dogleg-shaped descent path (interpolating between steepest descent and Newton descent). How far we move along this path is then determined by a trust region algorithm.
```python
from collections.abc import Callable

import optimistix as optx


class MyNewMinimiser(optx.AbstractBFGS):
    rtol: float
    atol: float
    norm: Callable = optx.max_norm
    use_inverse: bool = False
    descent: optx.AbstractDescent = optx.DoglegDescent()
    search: optx.AbstractSearch = optx.ClassicalTrustRegion()


solver = MyNewMinimiser(rtol=1e-4, atol=1e-4)
```
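Once constructed, `solver` can be passed to `optx.minimise` like any built-in minimiser. Here's a minimal usage sketch; the quadratic objective and initial value below are made up for illustration.

```python
import jax.numpy as jnp


def loss(y, args):
    # A simple quadratic bowl with its minimum at y = [1, 1, 1].
    return jnp.sum((y - 1.0) ** 2)


y0 = jnp.zeros(3)
sol = optx.minimise(loss, solver, y0)
# `sol.value` should be close to [1, 1, 1].
```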
## Approach 2: create whole-new solvers, searches, and descents.
This can be done by subclassing the relevant object; see the page on abstract base classes. For example, here's how we might describe a Newton descent direction.
```python
import equinox as eqx  # https://github.com/patrick-kidger/equinox
import jax.tree_util as jtu
import lineax as lx  # https://github.com/google/lineax
from jaxtyping import Array, PyTree  # https://github.com/google/jaxtyping


class NewtonDescentState(eqx.Module):
    newton: PyTree[Array]
    result: optx.RESULTS


class NewtonDescent(optx.AbstractDescent):
    def init(self, y, f_info_struct):
        del f_info_struct
        # Dummy values of the right shape; unused.
        return NewtonDescentState(y, optx.RESULTS.successful)

    def query(self, y, f_info, state):
        # Compute the Newton direction from whatever derivative information
        # the solver provides.
        del state
        if isinstance(f_info, optx.FunctionInfo.EvalGradHessianInv):
            newton = f_info.hessian_inv.mv(f_info.grad)
            result = optx.RESULTS.successful
        else:
            if isinstance(f_info, optx.FunctionInfo.EvalGradHessian):
                operator = f_info.hessian
                vector = f_info.grad
            elif isinstance(f_info, optx.FunctionInfo.ResidualJac):
                operator = f_info.jac
                vector = f_info.residual
            else:
                raise ValueError(
                    "Cannot use a Newton descent with a solver that only evaluates the "
                    "gradient, or only the function itself."
                )
            out = lx.linear_solve(operator, vector)
            newton = out.value
            result = optx.RESULTS.promote(out.result)
        return NewtonDescentState(newton, result)

    def step(self, step_size, state):
        # Move downhill along the Newton direction, scaled by the step size
        # chosen by the search.
        return jtu.tree_map(lambda x: -step_size * x, state.newton), state.result
```
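Having defined a descent, it can be dropped into a solver in the same mix-and-match style as Approach 1. The class below is a hypothetical sketch (the name is made up for illustration), assuming the `NewtonDescent` defined above is in scope.

```python
from collections.abc import Callable

import optimistix as optx


class MyNewtonMinimiser(optx.AbstractBFGS):
    rtol: float
    atol: float
    norm: Callable = optx.max_norm
    use_inverse: bool = False
    descent: optx.AbstractDescent = NewtonDescent()  # the custom descent above
    search: optx.AbstractSearch = optx.BacktrackingArmijo()


newton_solver = MyNewtonMinimiser(rtol=1e-4, atol=1e-4)
```

Here `use_inverse=False` means the BFGS solver passes its Hessian approximation to the descent as `optx.FunctionInfo.EvalGradHessian`, which is one of the cases handled by the `query` method above.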