Hybridising solvers

The ability to create custom solvers is one of the most powerful things about Optimistix. This is a great playground for the advanced user.

Approach 1: mix-and-match using existing APIs.

Many abstract solvers have search and descent fields. The first is a choice of line search, trust region, or learning rate. Optimistix uses a generalised notion of all three. The second is a choice of what it means to "move downhill": for example, this could be optimistix.SteepestDescent to use a local linear approximation, or optimistix.NewtonDescent to use a local quadratic approximation.

See the searches and descents page for more details on this idea.
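As a quick illustration of the mix-and-match idea, here is a minimal sketch of ordinary gradient descent assembled from parts: a steepest-descent direction paired with a backtracking Armijo line search. (The class name ArmijoGradientDescent is made up for this example; the rtol/atol/norm/descent/search fields are those exposed by the abstract solver classes.)

from collections.abc import Callable

import optimistix as optx


class ArmijoGradientDescent(optx.AbstractGradientDescent):
    rtol: float
    atol: float
    norm: Callable = optx.max_norm
    # "Downhill" means the negative-gradient direction...
    descent: optx.AbstractDescent = optx.SteepestDescent()
    # ...and how far to move along it is chosen by a backtracking line search.
    search: optx.AbstractSearch = optx.BacktrackingArmijo()


solver = ArmijoGradientDescent(rtol=1e-4, atol=1e-4)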

Here's a quick demo of how to create a novel minimiser using this. This example uses a BFGS quasi-Newton approximation to the Hessian of a minimisation problem. This approximation is used to build a piecewise-linear dogleg-shaped descent path (interpolating between steepest descent and Newton descent). How far we move along this path is then determined by a trust region algorithm.

from collections.abc import Callable

import optimistix as optx


class MyNewMinimiser(optx.AbstractBFGS):
    rtol: float
    atol: float
    norm: Callable = optx.max_norm
    use_inverse: bool = False
    descent: optx.AbstractDescent = optx.DoglegDescent()
    search: optx.AbstractSearch = optx.ClassicalTrustRegion()


solver = MyNewMinimiser(rtol=1e-4, atol=1e-4)
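And a quick usage sketch: the new solver plugs straight into optx.minimise like any built-in one. (The quadratic objective below is just a stand-in chosen for this example.)

import jax.numpy as jnp


def fn(y, args):
    # A simple quadratic bowl with its minimum at y = 1.
    return jnp.sum((y - 1.0) ** 2)


sol = optx.minimise(fn, solver, y0=jnp.zeros(3))
# sol.value should be close to an array of ones.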

Approach 2: create whole-new solvers, searches, and descents.

This can be done by subclassing the relevant object; see the page on abstract base classes. For example, here's how we might describe a Newton descent direction.

import equinox as eqx  # https://github.com/patrick-kidger/equinox
import jax.tree_util as jtu
import lineax as lx  # https://github.com/google/lineax
from jaxtyping import Array, PyTree  # https://github.com/google/jaxtyping


class NewtonDescentState(eqx.Module):
    newton: PyTree[Array]
    result: optx.RESULTS


class NewtonDescent(optx.AbstractDescent):
    def init(self, y, f_info_struct):
        del f_info_struct
        # Dummy values of the right shape; unused.
        return NewtonDescentState(y, optx.RESULTS.successful)

    def query(self, y, f_info, state):
        del state
        if isinstance(f_info, optx.FunctionInfo.EvalGradHessianInv):
            # The inverse Hessian (approximation) is available directly, so the
            # Newton step is just a matrix-vector product.
            newton = f_info.hessian_inv.mv(f_info.grad)
            result = optx.RESULTS.successful
        else:
            if isinstance(f_info, optx.FunctionInfo.EvalGradHessian):
                # Minimisation: solve `hessian @ newton = grad`.
                operator = f_info.hessian
                vector = f_info.grad
            elif isinstance(f_info, optx.FunctionInfo.ResidualJac):
                # Least squares: solve `jac @ newton = residual` (a Gauss-Newton step).
                operator = f_info.jac
                vector = f_info.residual
            else:
                raise ValueError(
                    "Cannot use a Newton descent with a solver that only evaluates the "
                    "gradient, or only the function itself."
                )
            out = lx.linear_solve(operator, vector)
            newton = out.value
            result = optx.RESULTS.promote(out.result)
        return NewtonDescentState(newton, result)

    def step(self, step_size, state):
        # Scale the (downhill) Newton direction by the step size chosen by the search.
        return jtu.tree_map(lambda x: -step_size * x, state.newton), state.result
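
Having defined the descent, it can be dropped into a solver in exactly the same way as in Approach 1. The combination below (our NewtonDescent with a backtracking Armijo line search, on top of the BFGS Hessian approximation) is just one illustrative pairing:

class MyNewtonSolver(optx.AbstractBFGS):
    rtol: float
    atol: float
    norm: Callable = optx.max_norm
    use_inverse: bool = False  # so `query` sees an `EvalGradHessian`, not its inverse
    descent: optx.AbstractDescent = NewtonDescent()
    search: optx.AbstractSearch = optx.BacktrackingArmijo()


solver = MyNewtonSolver(rtol=1e-4, atol=1e-4)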