Interactively step through a solve
Sometimes you might want to perform an optimisation just one step at a time (or a few steps at a time), and perhaps do some other computations in between. A common example is when training a neural network, and looking to continually monitor performance of the model on a validation set.
One option is to repeatedly call e.g. optx.minimise(..., throw=False, max_steps=1). However, if that seems inelegant or inefficient to you, then it is possible to use the solvers yourself directly.
Let's look at an example where we run an optimistix.Bisection search, and output the interval considered at each step.
Info
This is a relatively advanced API surface. In particular, no default arguments are provided, and all functions are assumed to return auxiliary information (which, as in this example, may be just None). See optimistix.AbstractRootFinder for details on each of the arguments.
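To make the division of labour concrete before the real example, here is a library-free sketch of the same idea: a bisection search split into init, step and terminate functions, driven by an ordinary Python loop. It is not how Optimistix implements Bisection. The BisectionState class, the function signatures and the tolerance test are all simplifications assumed for illustration.

```python
import math
from dataclasses import dataclass


@dataclass(frozen=True)
class BisectionState:
    # Hypothetical state: just the current bracketing interval.
    lower: float
    upper: float


def fn(y):
    # Same scalar function as on this page: seek y with y - tanh(y + 1) = 0.
    return y - math.tanh(y + 1)


def init(lower, upper):
    # Bisection needs the function to change sign across the interval.
    if fn(lower) * fn(upper) > 0:
        raise ValueError("fn must change sign over [lower, upper]")
    return BisectionState(lower, upper)


def step(y, state):
    # Keep whichever half of the interval still brackets the root.
    if fn(y) * fn(state.lower) > 0:
        lower, upper = y, state.upper
    else:
        lower, upper = state.lower, y
    new_y = 0.5 * (lower + upper)
    return new_y, BisectionState(lower, upper)


def terminate(y, state, rtol=1e-3, atol=1e-3):
    # Assumed stopping rule: the interval is narrower than the tolerance.
    return (state.upper - state.lower) < atol + rtol * abs(y)


state = init(-1.0, 1.0)
y = 0.0  # midpoint of the initial interval
history = []
while not terminate(y, state):
    # Any other computation could happen here between steps.
    history.append((state.lower, state.upper))
    y, state = step(y, state)

print(f"Stopped at y={y} after {len(history)} steps.")
```

Because the loop is plain Python, anything can go between calls to step, such as logging or a validation check.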
import equinox as eqx # https://github.com/patrick-kidger/equinox
import jax
import jax.numpy as jnp
import optimistix as optx
# Seek `y` such that `y - tanh(y + 1) = 0`.
@eqx.filter_jit
def fn(y, args):
    out = y - jnp.tanh(y + 1)
    aux = None
    return out, aux
solver = optx.Bisection(rtol=1e-3, atol=1e-3)
# The initial guess for the solution
y = jnp.array(0.0)  # a float, so its dtype matches the float32 `f_struct` below
# Any auxiliary information to pass to `fn`.
args = None
# The interval to search over. Required for `optx.Bisection`.
options = dict(lower=-1, upper=1)
# The shape+dtype of the output of `fn`
f_struct = jax.ShapeDtypeStruct((), jnp.float32)
aux_struct = None
# Any Lineax tags describing the structure of the Jacobian matrix d(fn)/dy.
# (In this case it's just a 1x1 matrix, so these don't matter.)
tags = frozenset()
def solve(y, solver):
    # These arguments are always fixed throughout interactive solves.
    step = eqx.filter_jit(
        eqx.Partial(solver.step, fn=fn, args=args, options=options, tags=tags)
    )
    terminate = eqx.filter_jit(
        eqx.Partial(solver.terminate, fn=fn, args=args, options=options, tags=tags)
    )
    # Initial state before we start solving.
    state = solver.init(fn, y, args, options, f_struct, aux_struct, tags)
    done, result = terminate(y=y, state=state)
    # Alright, enough setup. Let's do the solve!
    while not done:
        print(f"Evaluating point {y} with value {fn(y, args)[0]}.")
        y, state, aux = step(y=y, state=state)
        done, result = terminate(y=y, state=state)
    if result != optx.RESULTS.successful:
        print(f"Oh no! Got error {result}.")
    y, _, _ = solver.postprocess(fn, y, aux, args, options, state, tags, result)
    print(f"Found solution {y} with value {fn(y, args)[0]}.")
solve(y, solver)
This example also highlights a detail of how many solvers work: whilst they keep searching for a better solution, they don't necessarily keep a copy of the best-so-far value around. Keeping this copy around would require extra memory, after all.
In this case, notice how one of the earlier points got a loss of -0.000223755836486816, which is actually slightly smaller than the loss from the final solution. (The returned solution is only guaranteed to be something satisfying the tolerance conditions.)
If we want to be sure of having the best-so-far value, then we can make a copy of it by using optimistix.BestSoFarRootFinder:
best_so_far_solver = optx.BestSoFarRootFinder(solver)
solve(y, best_so_far_solver)
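For intuition about what a best-so-far wrapper does, here is a small library-free sketch. It is not the implementation of optimistix.BestSoFarRootFinder. The Bisection and BestSoFar classes below are hypothetical plain-Python stand-ins, and the choice of |fn(y)| as the measure of "best" is an assumption made for this illustration.

```python
import math


def fn(y):
    # Same scalar function as on this page.
    return y - math.tanh(y + 1)


class Bisection:
    """Hypothetical plain-Python bisection stepper (not optimistix.Bisection)."""

    def __init__(self, tol):
        self.tol = tol

    def init(self, lower, upper):
        return (lower, upper)

    def step(self, y, state):
        lower, upper = state
        # Keep whichever half of the interval still brackets the root.
        if fn(y) * fn(lower) > 0:
            lower = y
        else:
            upper = y
        return 0.5 * (lower + upper), (lower, upper)

    def terminate(self, y, state):
        lower, upper = state
        return upper - lower < self.tol

    def postprocess(self, y, state):
        # Returns the final iterate, whether or not it was the best one seen.
        return y


class BestSoFar:
    """Wraps a stepper and carries a copy of the best point seen so far."""

    def __init__(self, solver):
        self.solver = solver

    def init(self, lower, upper):
        # Extra state: the best point and its error. This is the extra memory.
        return (self.solver.init(lower, upper), None, math.inf)

    def step(self, y, state):
        inner, best_y, best_err = state
        err = abs(fn(y))
        if err < best_err:
            best_y, best_err = y, err
        new_y, inner = self.solver.step(y, inner)
        return new_y, (inner, best_y, best_err)

    def terminate(self, y, state):
        return self.solver.terminate(y, state[0])

    def postprocess(self, y, state):
        _, best_y, best_err = state
        # The final iterate has not been compared yet, so check it here.
        return y if abs(fn(y)) < best_err else best_y


def run(solver, lower=-1.0, upper=1.0):
    y = 0.5 * (lower + upper)
    state = solver.init(lower, upper)
    while not solver.terminate(y, state):
        y, state = solver.step(y, state)
    return solver.postprocess(y, state)


plain = run(Bisection(tol=1e-3))
best = run(BestSoFar(Bisection(tol=1e-3)))
print(f"final iterate: |fn| = {abs(fn(plain))}")
print(f"best so far:   |fn| = {abs(fn(best))}")
```

The wrapped run can never report a worse point than the plain one, at the cost of storing one extra copy of the iterate (and, in this sketch, one extra evaluation of fn per step).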