# Getting started

Equinox is a JAX library based around a simple idea: represent parameterised functions (such as neural networks) as PyTrees.

In doing so:

- We get a PyTorch-like API...
- ...that's fully compatible with native JAX transformations...
- ...with no new concepts you have to learn. (It's all just PyTrees.)
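To see the underlying idea with no Equinox involved, here is a minimal plain-JAX sketch. The parameters live in an ordinary dict (which JAX already treats as a PyTree), and `jax.grad` returns gradients with exactly the same structure. The names `params`, `apply` and `loss` are illustrative only.

```python
import jax
import jax.numpy as jnp

# Parameters stored in an ordinary dict, which JAX treats as a PyTree.
wkey, bkey = jax.random.split(jax.random.PRNGKey(0))
params = {
    "weight": jax.random.normal(wkey, (3, 2)),
    "bias": jax.random.normal(bkey, (3,)),
}

def apply(params, x):
    # A linear map whose parameters are the PyTree `params`.
    return params["weight"] @ x + params["bias"]

def loss(params, x, y):
    return jnp.mean((apply(params, x) - y) ** 2)

x = jnp.ones(2)
y = jnp.zeros(3)
grads = jax.grad(loss)(params, x, y)

# The gradients mirror the structure of the parameters.
print(jax.tree_util.tree_structure(grads) == jax.tree_util.tree_structure(params))  # True
print(grads["weight"].shape)  # (3, 2)
```

Equinox bundles the parameters and the forward pass into a single callable object, and that object is itself the PyTree.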

The elegance of Equinox is its selling point in a world that already has Haiku, Flax and so on.

(In other words, why should you care? Because Equinox is really simple to learn, and really simple to use.)

## Installation

```bash
pip install equinox
```


Requires Python 3.7+ and JAX 0.3.4+.

## Quick example

Models are defined using PyTorch-like syntax:

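```python
import equinox as eqx
import jax

class Linear(eqx.Module):
    # Fields are declared as annotations, like a dataclass.
    weight: jax.numpy.ndarray
    bias: jax.numpy.ndarray

    def __init__(self, in_size, out_size, key):
        # Split the PRNG key so weight and bias get independent randomness.
        wkey, bkey = jax.random.split(key)
        self.weight = jax.random.normal(wkey, (out_size, in_size))
        self.bias = jax.random.normal(bkey, (out_size,))

    def __call__(self, x):
        # Acts on a single (unbatched) input vector of shape (in_size,).
        return self.weight @ x + self.bias
```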


and fully compatible with normal JAX operations:

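```python
@jax.jit
@jax.grad  # differentiate the loss with respect to its first argument, `model`
def loss_fn(model, x, y):
    # vmap maps the single-input model over the batch dimension.
    pred_y = jax.vmap(model)(x)
    return jax.numpy.mean((y - pred_y) ** 2)

batch_size, in_size, out_size = 32, 2, 3
model = Linear(in_size, out_size, key=jax.random.PRNGKey(0))
x = jax.numpy.zeros((batch_size, in_size))
y = jax.numpy.zeros((batch_size, out_size))
grads = loss_fn(model, x, y)  # gradients, with the same structure as `model`
```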


Finally, there's no magic behind the scenes: all `eqx.Module` does is register your class as a PyTree. From that point onwards, JAX already knows how to work with it.

## Next steps

If this quick start has got you interested, then have a read of *All of Equinox*, which introduces you to basically everything in Equinox. (It doesn't take very long! Equinox is simple because everything is a PyTree.)

## Citation

If you found this library to be useful in academic work, then please cite: (arXiv link)

```bibtex
@article{kidger2021equinox,
    author={Patrick Kidger and Cristian Garcia},
    title={{E}quinox: neural networks in {JAX} via callable {P}y{T}rees and filtered transformations},
    year={2021},
    journal={Differentiable Programming workshop at Neural Information Processing Systems 2021}
}
```


(Also consider starring the project on GitHub.)