# Hybridising solvers

The ability to create custom solvers is one of the most powerful things about Optimistix. This is a great playground for the advanced user.

## Approach 1: mix-and-match using existing APIs.

Many abstract solvers have `search` and `descent` fields. The first is a choice of line search, trust region, or learning rate; Optimistix uses a generalised notion that covers all three. The second is a choice of what it means to "move downhill": for example, `optimistix.SteepestDescent` uses a local linear approximation, and `optimistix.NewtonDescent` uses a local quadratic approximation.

See the searches and descents page for more details on this idea.
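As a rough illustration of this split, independent of Optimistix's actual classes, the plain-JAX sketch below treats a descent as a function that turns the gradient (and Hessian) into a direction, and uses the simplest possible search, a fixed learning rate, to decide how far to move. The helper names `steepest_direction`, `newton_direction` and `minimise_with` are hypothetical; Optimistix's real searches and descents are objects with richer interfaces.

```python
import jax
import jax.numpy as jnp


def steepest_direction(grad, hessian):
    # Local linear model: move straight down the gradient.
    del hessian
    return -grad


def newton_direction(grad, hessian):
    # Local quadratic model: the direction p solving H p = -g.
    return -jnp.linalg.solve(hessian, grad)


def minimise_with(fn, y0, direction_fn, learning_rate, steps):
    # The "search" here is a fixed learning rate. A line search or trust
    # region would instead adapt this scalar at every iteration.
    grad_fn = jax.jit(jax.grad(fn))
    # The Hessian is computed every step for simplicity, even when the
    # chosen descent ignores it.
    hess_fn = jax.jit(jax.hessian(fn))
    y = y0
    for _ in range(steps):
        y = y + learning_rate * direction_fn(grad_fn(y), hess_fn(y))
    return y


def quadratic(y):
    # A convex bowl with its minimum at (1, -2).
    return (y[0] - 1.0) ** 2 + 10.0 * (y[1] + 2.0) ** 2


y0 = jnp.array([0.0, 0.0])
y_newton = minimise_with(quadratic, y0, newton_direction, learning_rate=1.0, steps=1)
y_steepest = minimise_with(
    quadratic, y0, steepest_direction, learning_rate=0.05, steps=100
)
```

On this quadratic, a single Newton step with a learning rate of 1 lands on the minimiser (1, -2), while steepest descent needs many smaller steps because the bowl is ten times more sharply curved in one coordinate than in the other.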

Here's a quick demo of how to create a novel minimiser using this. This example uses a BFGS quasi-Newton approximation to the Hessian of a minimisation problem. This approximation is used to build a piecewise-linear, dogleg-shaped descent path (interpolating between steepest descent and Newton descent). How far we move along this path is then determined by a trust region algorithm.

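```python
from collections.abc import Callable

import optimistix as optx


class MyNewMinimiser(optx.AbstractBFGS):
    rtol: float
    atol: float
    norm: Callable = optx.max_norm
    use_inverse: bool = False
    descent: optx.AbstractDescent = optx.DoglegDescent()
    # Note the trailing `()`: the search must be an instance, not the class.
    search: optx.AbstractSearch = optx.ClassicalTrustRegion()
    # Recent Optimistix versions also expect a `verbose` field on BFGS solvers;
    # on versions that do not, this is simply an unused extra field.
    verbose: frozenset[str] = frozenset()


solver = MyNewMinimiser(rtol=1e-4, atol=1e-4)
```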
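The dogleg path itself can be written down directly. Below is a minimal plain-JAX sketch of the textbook dogleg step for one iteration, assuming the Euclidean norm and a symmetric positive-definite Hessian approximation (which BFGS updates preserve under the usual curvature condition). It is not Optimistix's `DoglegDescent` implementation, which may differ in details, and the name `dogleg_step` is hypothetical.

```python
import jax.numpy as jnp


# Uses Python `if`, so this runs eagerly; under `jax.jit` the branches would
# need to become `jnp.where` or `jax.lax.cond`.
def dogleg_step(grad, hessian, radius):
    # Newton point: minimiser of the local quadratic model.
    newton = -jnp.linalg.solve(hessian, grad)
    if jnp.linalg.norm(newton) <= radius:
        # The full Newton step fits inside the trust region.
        return newton
    # Cauchy point: minimiser of the quadratic model along -grad
    # (assumes grad @ hessian @ grad > 0, i.e. positive curvature).
    cauchy = -(grad @ grad) / (grad @ hessian @ grad) * grad
    if jnp.linalg.norm(cauchy) >= radius:
        # Even the steepest-descent leg leaves the region: truncate it.
        return -radius * grad / jnp.linalg.norm(grad)
    # Otherwise walk from the Cauchy point towards the Newton point and stop
    # where ||cauchy + tau * diff|| = radius, with tau in [0, 1].
    diff = newton - cauchy
    a = diff @ diff
    b = 2.0 * (cauchy @ diff)
    c = cauchy @ cauchy - radius**2
    tau = (-b + jnp.sqrt(b**2 - 4.0 * a * c)) / (2.0 * a)
    return cauchy + tau * diff
```

For example, with `grad = jnp.array([-2.0, 40.0])` and `hessian = jnp.diag(jnp.array([2.0, 20.0]))`, a large radius returns the Newton point (1, -2); shrinking the radius first bends the step towards the Cauchy point and then, once the radius is shorter than the Cauchy step, truncates it to a scaled steepest-descent step. In the hybrid solver above, this radius plays the role of the step size chosen by the trust region search.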


## Approach 2: create whole-new solvers, searches, and descents.

This can be done by subclassing the relevant object; see the page on abstract base classes. For example, here's how we might describe a Newton descent direction.

```python
import equinox as eqx  # https://github.com/patrick-kidger/equinox
import jax.tree_util as jtu
import lineax as lx  # https://github.com/google/lineax
import optimistix as optx
from jaxtyping import Array, PyTree  # https://github.com/google/jaxtyping


class NewtonDescentState(eqx.Module):
    newton: PyTree[Array]
    result: optx.RESULTS


class NewtonDescent(optx.AbstractDescent):
    def init(self, y, f_info_struct):
        del f_info_struct
        # Dummy values of the right shape; unused.
        return NewtonDescentState(y, optx.RESULTS.successful)

    def query(self, y, f_info, state):
        del state
        if isinstance(f_info, optx.FunctionInfo.EvalGradHessianInv):
            # We already have an (approximate) inverse Hessian: just apply it.
            newton = f_info.hessian_inv.mv(f_info.grad)
            result = optx.RESULTS.successful
        else:
            if isinstance(f_info, optx.FunctionInfo.EvalGradHessian):
                # Minimisation: solve H p = g.
                operator = f_info.hessian
                vector = f_info.grad
            elif isinstance(f_info, optx.FunctionInfo.ResidualJac):
                # Least squares: solve J p = r, in the least-squares sense if
                # J is not square. This gives the Gauss-Newton direction.
                operator = f_info.jac
                vector = f_info.residual
            else:
                raise ValueError(
                    "Cannot use a Newton descent with a solver that only evaluates the "
                    "gradient, or only the function itself."
                )
            # `well_posed=None` lets Lineax pick a solver that also handles
            # non-square operators, such as a Jacobian.
            linear_solver = lx.AutoLinearSolver(well_posed=None)
            out = lx.linear_solve(operator, vector, linear_solver)
            newton = out.value
            result = optx.RESULTS.promote(out.result)
        return NewtonDescentState(newton, result)

    def step(self, step_size, state):
        # Scale the cached Newton direction by the step size from the search.
        return jtu.tree_map(lambda x: -step_size * x, state.newton), state.result
```
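One benefit of splitting `query` from `step` is that the expensive linear solve happens once per iteration, while each trial step requested by a search is just a cheap rescaling. The plain-JAX sketch below mimics that contract outside Optimistix, using a dense `jnp.linalg.solve` on parameters flattened with `jax.flatten_util.ravel_pytree` in place of the Lineax operators. The names `newton_query`, `newton_step` and `loss` are hypothetical, not Optimistix API.

```python
import jax
import jax.numpy as jnp
import jax.tree_util as jtu
from jax.flatten_util import ravel_pytree


def newton_query(fn, y):
    # Expensive part, done once per iteration: solve H p = g on the
    # flattened parameters, then restore the original PyTree structure.
    flat_y, unravel = ravel_pytree(y)
    flat_fn = lambda z: fn(unravel(z))
    grad = jax.grad(flat_fn)(flat_y)
    hessian = jax.hessian(flat_fn)(flat_y)
    return unravel(jnp.linalg.solve(hessian, grad))


def newton_step(step_size, newton):
    # Cheap part, possibly called many times by a search: scale the cached
    # direction, mirroring `NewtonDescent.step` above.
    return jtu.tree_map(lambda x: -step_size * x, newton)


def loss(params):
    # Minimum at a = 3, b = -1.
    return (params["a"] - 3.0) ** 2 + 5.0 * (params["b"] + 1.0) ** 2


params = {"a": jnp.array(0.0), "b": jnp.array(0.0)}
newton = newton_query(loss, params)
# Try a few step sizes against the same cached direction, as a backtracking
# line search might.
trials = [
    jtu.tree_map(jnp.add, params, newton_step(s, newton)) for s in (1.0, 0.5, 0.25)
]
```

Because `loss` is quadratic, the full step (`step_size=1.0`) reaches the minimiser exactly, and the shorter trial steps reuse the same direction without another solve.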