
Symbolic Regression

This example combines neural differential equations with regularised evolution to discover the equations

\(\frac{\mathrm{d} x}{\mathrm{d} t}(t) = \frac{y(t)}{1 + y(t)}\)

\(\frac{\mathrm{d} y}{\mathrm{d} t}(t) = \frac{-x(t)}{1 + x(t)}\)

directly from data.


This example appears in Section 6.1 of:

@phdthesis{kidger2021on,
    title={{O}n {N}eural {D}ifferential {E}quations},
    author={Patrick Kidger},
    year={2021},
    school={University of Oxford},
}
Whilst drawing heavy inspiration from:

@inproceedings{cranmer2020discovering,
    title={{D}iscovering {S}ymbolic {M}odels from {D}eep {L}earning with
           {I}nductive {B}iases},
    author={Cranmer, Miles and Sanchez Gonzalez, Alvaro and Battaglia, Peter and
            Xu, Rui and Cranmer, Kyle and Spergel, David and Ho, Shirley},
    booktitle={Advances in Neural Information Processing Systems},
    publisher={Curran Associates, Inc.},
    year={2020},
}

@software{pysr,
    title={PySR: Fast \& Parallelized Symbolic Regression in Python/Julia},
    author={Miles Cranmer},
    year={2020},
}

This example is available as a Jupyter notebook here.

import tempfile
from typing import List

import equinox as eqx
import jax
import jax.numpy as jnp
import optax
import pysr
import sympy

# Note that PySR, which we use for symbolic regression, uses Julia as a backend.
# You'll need to install a recent version of Julia if you don't have one already.
# (Too old a version of Julia can produce some confusing errors.)
# You may also need to restart Python after running `pysr.install()` for the
# first time.

Now for a bunch of helpers. We'll use these in a moment; skip over them for now.

def quantise(expr, quantise_to):
    if isinstance(expr, sympy.Float):
        return expr.func(round(float(expr) / quantise_to) * quantise_to)
    elif isinstance(expr, sympy.Symbol):
        return expr
    else:
        return expr.func(*[quantise(arg, quantise_to) for arg in expr.args])

class SymbolicFn(eqx.Module):
    fn: callable
    parameters: jnp.ndarray

    def __call__(self, x):
        # Dummy batch/unbatching. PySR assumes its JAX'd symbolic functions act on
        # tensors with a single batch dimension.
        return jnp.squeeze(self.fn(x[None], self.parameters))

class Stack(eqx.Module):
    modules: List[eqx.Module]

    def __call__(self, x):
        return jnp.stack([module(x) for module in self.modules], axis=-1)

def expr_size(expr):
    # Counts the number of nodes in the sympy expression tree.
    return sum(expr_size(v) for v in expr.args) + 1

def _replace_parameters(expr, parameters, i_ref):
    if isinstance(expr, sympy.Float):
        i_ref[0] += 1
        return expr.func(parameters[i_ref[0]])
    elif isinstance(expr, sympy.Symbol):
        return expr
    else:
        return expr.func(
            *[_replace_parameters(arg, parameters, i_ref) for arg in expr.args]
        )
def replace_parameters(expr, parameters):
    i_ref = [-1]  # Distinctly sketchy approach to making this conversion.
    return _replace_parameters(expr, parameters, i_ref)
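To see what `quantise` does in isolation, here is a small self-contained demonstration (restating the helper so the snippet runs on its own; the expression is a made-up example of the slightly-off constants gradient descent typically produces):

```python
import sympy


def quantise(expr, quantise_to):
    # As above: round every Float in the expression tree to a multiple of
    # `quantise_to`, recursing through all other nodes unchanged.
    if isinstance(expr, sympy.Float):
        return expr.func(round(float(expr) / quantise_to) * quantise_to)
    elif isinstance(expr, sympy.Symbol):
        return expr
    else:
        return expr.func(*[quantise(arg, quantise_to) for arg in expr.args])


x = sympy.Symbol("x")
# Constants that are nearly-but-not-quite round numbers.
expr = sympy.Float(0.9983) * x / (sympy.Float(1.0071) + x)
print(quantise(expr, 0.01))  # approximately 1.0*x/(x + 1.01)
```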

Okay, let's get started.

We start by running the Neural ODE example. Then we extract the learnt neural vector field, and symbolically regress across this. Finally we fine-tune the resulting symbolic expression.

def main(
    # The default values below are representative choices, not tuned settings.
    symbolic_dataset_size=2000,  # number of samples to symbolically regress over
    fine_tuning_steps=500,  # gradient steps when fine-tuning the constants
    fine_tuning_lr=3e-3,  # learning rate for fine-tuning
    quantise_to=0.01,  # round constants to the nearest multiple of this
):
    # First obtain a neural approximation to the dynamics.
    # We begin by running the previous example.

    # Runs the Neural ODE example.
    # This defines the variables `ts`, `ys`, `model`.
    print("Training neural differential equation.")
    %run neural_ode.ipynb

    # Now symbolically regress across the learnt vector field, to obtain a Pareto
    # frontier of symbolic equations, that trades loss against complexity of the
    # equation. Select the "best" from this frontier.

    print("Symbolically regressing across the vector field.")
    vector_field = model.func.mlp  # noqa: F821
    dataset_size, length_size, data_size = ys.shape  # noqa: F821
    in_ = ys.reshape(dataset_size * length_size, data_size)  # noqa: F821
    in_ = in_[:symbolic_dataset_size]
    out = jax.vmap(vector_field)(in_)
    with tempfile.TemporaryDirectory() as tempdir:
        # The PySR settings here are illustrative; adjust them for your problem.
        symbolic_regressor = pysr.PySRRegressor(
            niterations=20,
            binary_operators=["+", "-", "*", "/"],
            output_jax_format=True,
            tempdir=tempdir,
            temp_equation_file=True,
        )
        symbolic_regressor.fit(in_, out)
        best_equations = symbolic_regressor.get_best()
        expressions = [b.sympy_format for b in best_equations]
        symbolic_fns = [
            SymbolicFn(b.jax_format["callable"], b.jax_format["parameters"])
            for b in best_equations
        ]
    # Now the constants in this expression have been optimised for regressing across
    # the neural vector field. This was good enough to obtain the symbolic expression,
    # but won't quite be perfect -- some of the constants will be slightly off.
    # To fix this we now plug our symbolic function back into the original dataset
    # and apply gradient descent.

    print("Optimising symbolic expression.")

    symbolic_fn = Stack(symbolic_fns)
    flat, treedef = jax.tree_util.tree_flatten(
        model, is_leaf=lambda x: x is model.func.mlp  # noqa: F821
    )
    flat = [symbolic_fn if f is model.func.mlp else f for f in flat]  # noqa: F821
    symbolic_model = jax.tree_util.tree_unflatten(treedef, flat)

    @eqx.filter_grad
    def grad_loss(symbolic_model):
        vmap_model = jax.vmap(symbolic_model, in_axes=(None, 0))
        pred_ys = vmap_model(ts, ys[:, 0])  # noqa: F821
        return jnp.mean((ys - pred_ys) ** 2)  # noqa: F821

    optim = optax.adam(fine_tuning_lr)
    opt_state = optim.init(eqx.filter(symbolic_model, eqx.is_inexact_array))

    @eqx.filter_jit
    def make_step(symbolic_model, opt_state):
        grads = grad_loss(symbolic_model)
        updates, opt_state = optim.update(grads, opt_state)
        symbolic_model = eqx.apply_updates(symbolic_model, updates)
        return symbolic_model, opt_state

    for _ in range(fine_tuning_steps):
        symbolic_model, opt_state = make_step(symbolic_model, opt_state)

    # Finally we round each constant to the nearest multiple of `quantise_to`.

    trained_expressions = []
    for module, expression in zip(symbolic_model.func.mlp.modules, expressions):
        expression = replace_parameters(expression, module.parameters.tolist())
        expression = quantise(expression, quantise_to)
        trained_expressions.append(expression)
    print(f"Expressions found: {trained_expressions}")
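The model surgery in `main` — swapping the learnt MLP for the symbolic function — works by flattening the model whilst treating the whole MLP as a single leaf. Here is a minimal sketch of the same trick on a toy pytree (the dictionary below is a hypothetical stand-in for the real model):

```python
import jax

# A toy pytree standing in for the trained model; `mlp` is the sub-tree to swap out.
mlp = {"weights": [1.0, 2.0]}
model = {"solver": "tsit5", "func": {"mlp": mlp}}

# Flatten, treating the whole `mlp` sub-tree as one leaf rather than recursing
# into it.
flat, treedef = jax.tree_util.tree_flatten(model, is_leaf=lambda x: x is mlp)
# Swap that leaf for its replacement, leaving every other leaf untouched.
flat = ["symbolic_fn" if f is mlp else f for f in flat]
# Unflatten to recover a model with the same structure but the new sub-module.
new_model = jax.tree_util.tree_unflatten(treedef, flat)
assert new_model["func"]["mlp"] == "symbolic_fn"
```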
main()
Training neural differential equation.
Step: 0, Loss: 0.1665748506784439, Computation time: 24.18653130531311
Step: 100, Loss: 0.011155527085065842, Computation time: 0.09058809280395508
Step: 200, Loss: 0.006481727119535208, Computation time: 0.0928184986114502
Step: 300, Loss: 0.001382559770718217, Computation time: 0.09850335121154785
Step: 400, Loss: 0.001073717838153243, Computation time: 0.09830045700073242
Step: 499, Loss: 0.0007992316968739033, Computation time: 0.09975647926330566
Step: 0, Loss: 0.02832634374499321, Computation time: 24.61294913291931
Step: 100, Loss: 0.005440382286906242, Computation time: 0.40324854850769043
Step: 200, Loss: 0.004360489547252655, Computation time: 0.43680524826049805
Step: 300, Loss: 0.001799552352167666, Computation time: 0.4346010684967041
Step: 400, Loss: 0.0017023109830915928, Computation time: 0.437793493270874
Step: 499, Loss: 0.0011540694395080209, Computation time: 0.42920470237731934

Symbolically regressing across the vector field.
Optimising symbolic expression.
Expressions found: [x1/(x1 + 1.0), x0/(-x0 - 1.0)]
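As a sanity check (not part of the original run), we can confirm symbolically that the recovered right-hand sides match the target dynamics, writing the quantised `1.0` constants as exact integers:

```python
import sympy

x0, x1 = sympy.symbols("x0 x1")
# The expressions found above, with the quantised 1.0 constants written exactly.
found = [x1 / (x1 + 1), x0 / (-x0 - 1)]
# The true vector field: dx/dt = y/(1 + y), dy/dt = -x/(1 + x), with (x, y) = (x0, x1).
target = [x1 / (1 + x1), -x0 / (1 + x0)]
for f, t in zip(found, target):
    assert sympy.simplify(f - t) == 0
print("Recovered expressions match the target vector field.")
```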