Interactively step through a solve

Sometimes you might want to perform an optimisation just one step at a time (or a few steps at a time), and perhaps do some other computations in between. A common example is training a neural network whilst continually monitoring the model's performance on a validation set.

One option is to repeatedly call e.g. optx.minimise(..., throw=False, max_steps=1), feeding each solution back in as the next initial guess. However, if that seems inelegant or inefficient to you, then it is possible to use the solvers yourself directly.
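
For concreteness, here is a minimal sketch of that first option. The quadratic loss and the choice of optimistix.BFGS are our own assumptions, purely for illustration:

import jax.numpy as jnp
import optimistix as optx


def toy_loss(y, args):
    # Hypothetical objective, standing in for e.g. a real training loss.
    return jnp.sum((y - 1.0) ** 2)


bfgs = optx.BFGS(rtol=1e-5, atol=1e-5)
y0 = jnp.zeros(3)
for _ in range(10):
    sol = optx.minimise(toy_loss, bfgs, y0, throw=False, max_steps=1)
    y0 = sol.value  # feed the current iterate back in as the next initial guess
    # ...do any other computation here, e.g. print the current loss...

Note that each call above re-initialises the solver's internal state from scratch (for BFGS, its approximation to the Hessian), which is part of why this approach can be inefficient.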

Let's look at an example where we run an optimistix.Bisection search, and print the point considered at each step.

Info

This is a relatively advanced API surface. In particular, no default arguments are provided, and all functions are assumed to return auxiliary information (which, as in this example, may simply be None). See optimistix.AbstractRootFinder for details on each of the arguments.

import equinox as eqx  # https://github.com/patrick-kidger/equinox
import jax
import jax.numpy as jnp
import optimistix as optx


# Seek `y` such that `y - tanh(y + 1) = 0`.
@eqx.filter_jit
def fn(y, args):
    out = y - jnp.tanh(y + 1)
    aux = None
    return out, aux


solver = optx.Bisection(rtol=1e-3, atol=1e-3)
# The initial guess for the solution
y = jnp.array(0)
# Any auxiliary information to pass to `fn`.
args = None
# The interval to search over. Required for `optx.Bisection`.
options = dict(lower=-1, upper=1)
# The shape+dtype of the output of `fn`
f_struct = jax.ShapeDtypeStruct((), jnp.float32)
aux_struct = None
# Any Lineax tags describing the structure of the Jacobian matrix d(fn)/dy.
# (In this case it's just a 1x1 matrix, so these don't matter.)
tags = frozenset()


def solve(y, solver):
    # These arguments are always fixed throughout interactive solves.
    step = eqx.filter_jit(
        eqx.Partial(solver.step, fn=fn, args=args, options=options, tags=tags)
    )
    terminate = eqx.filter_jit(
        eqx.Partial(solver.terminate, fn=fn, args=args, options=options, tags=tags)
    )

    # Initial state before we start solving.
    state = solver.init(fn, y, args, options, f_struct, aux_struct, tags)
    aux = None  # in case the initial guess already satisfies the tolerances
    done, result = terminate(y=y, state=state)

    # Alright, enough setup. Let's do the solve!
    while not done:
        print(f"Evaluating point {y} with value {fn(y, args)[0]}.")
        y, state, aux = step(y=y, state=state)
        done, result = terminate(y=y, state=state)
    if result != optx.RESULTS.successful:
        print(f"Oh no! Got error {result}.")
    y, _, _ = solver.postprocess(fn, y, aux, args, options, state, tags, result)
    print(f"Found solution {y} with value {fn(y, args)[0]}.")


solve(y, solver)
Evaluating point 0 with value -0.7615941762924194.
Evaluating point 0.5 with value -0.4051482081413269.
Evaluating point 0.75 with value -0.19137555360794067.
Evaluating point 0.875 with value -0.07904523611068726.
Evaluating point 0.9375 with value -0.021835267543792725.
Evaluating point 0.96875 with value 0.006998121738433838.
Evaluating point 0.953125 with value -0.007436692714691162.
Evaluating point 0.9609375 with value -0.0002237558364868164.
Evaluating point 0.96484375 with value 0.0033860206604003906.
Evaluating point 0.962890625 with value 0.0015808343887329102.
Evaluating point 0.9619140625 with value 0.0006784796714782715.
Found solution 0.96142578125 with value 0.00022733211517333984.

This example also highlights a detail of how many solvers work: whilst they keep searching for a better solution, they don't necessarily keep a copy of the best-so-far value around. Keeping this copy around would require extra memory, after all.

In this case, notice how one of the earlier points got a value of -0.0002237558364868164, which is actually slightly smaller in magnitude than the value at the final solution. (The returned solution is only guaranteed to be something satisfying the tolerance conditions.)
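
Since we're driving the loop ourselves, one option is to track the best iterate by hand. Here's a minimal sketch of a drop-in replacement for the while-loop inside solve above, reusing the names defined there; using the absolute residual as the criterion is an assumption on our part:

    best_y = y
    best_val = jnp.abs(fn(y, args)[0])
    while not done:
        y, state, aux = step(y=y, state=state)
        done, result = terminate(y=y, state=state)
        val = jnp.abs(fn(y, args)[0])
        if val < best_val:
            best_y, best_val = y, val  # keep our own copy of the best iterate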

Alternatively, if we want Optimistix to keep this copy for us, then we can wrap the solver in optimistix.BestSoFarRootFinder:

best_so_far_solver = optx.BestSoFarRootFinder(solver)
solve(y, best_so_far_solver)
Evaluating point 0 with value -0.7615941762924194.
Evaluating point 0.5 with value -0.4051482081413269.
Evaluating point 0.75 with value -0.19137555360794067.
Evaluating point 0.875 with value -0.07904523611068726.
Evaluating point 0.9375 with value -0.021835267543792725.
Evaluating point 0.96875 with value 0.006998121738433838.
Evaluating point 0.953125 with value -0.007436692714691162.
Evaluating point 0.9609375 with value -0.0002237558364868164.
Evaluating point 0.96484375 with value 0.0033860206604003906.
Evaluating point 0.962890625 with value 0.0015808343887329102.
Evaluating point 0.9619140625 with value 0.0006784796714782715.
Found solution 0.9609375 with value -0.0002237558364868164.
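
Finally, this interactive pattern is not specific to root-finding: Optimistix's solvers share the same init/step/terminate/postprocess interface, so the neural-network use case mentioned at the start works the same way. Here is a hedged sketch using optimistix.GradientDescent; the toy loss and the learning rate are assumptions for illustration, standing in for a real model and dataset:

import equinox as eqx
import jax
import jax.numpy as jnp
import optimistix as optx


def train_loss(params, args):
    # Hypothetical training loss, standing in for a real model + data.
    return jnp.sum((params - 2.0) ** 2), None  # as before, `aux` may simply be None


gd = optx.GradientDescent(learning_rate=0.1, rtol=1e-4, atol=1e-4)
params = jnp.zeros(3)
gd_options = {}
gd_tags = frozenset()
gd_step = eqx.filter_jit(
    eqx.Partial(gd.step, fn=train_loss, args=None, options=gd_options, tags=gd_tags)
)
gd_terminate = eqx.filter_jit(
    eqx.Partial(gd.terminate, fn=train_loss, args=None, options=gd_options, tags=gd_tags)
)
loss_struct = jax.ShapeDtypeStruct((), jnp.float32)
state = gd.init(train_loss, params, None, gd_options, loss_struct, None, gd_tags)
done, result = gd_terminate(y=params, state=state)
while not done:
    params, state, _ = gd_step(y=params, state=state)
    done, result = gd_terminate(y=params, state=state)
    # ...evaluate a validation metric, log, or checkpoint here...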