# FAQ¤

## Optax is throwing an error.¤

Probably you're writing code that looks like

```
optim = optax.adam(learning_rate)
optim.init(model)
```

```
TypeError: zeros_like requires ndarray or scalar arguments, got <class 'jax._src.custom_derivatives.custom_jvp'> at position 0.
```

This can be fixed by doing

```
optim.init(eqx.filter(model, eqx.is_array))
```

`eqx.nn.MLP`

).
## A module saved in two places has become two independent copies.¤

Probably you're doing something like

```
class Module(eqx.Module):
linear1: eqx.nn.Linear
linear2: eqx.nn.Linear
def __init__(...):
shared_linear = eqx.nn.Linear(...)
self.linear1 = shared_linear
self.linear2 = shared_linear
```

Don't do this!

After making some gradient updates you'll find that `self.linear1`

and `self.linear2`

are now different.

Recall that in Equinox, models are PyTrees. Meanwhile, JAX treats all PyTrees as *trees*: that is, the same object does not appear more in the tree than once. (If it did, then it would be a *directed acyclic graph* instead.) If JAX ever encounters the same object multiple times then it will unwittingly make independent copies of the object whenever it transforms the overall PyTree.

The resolution is simple: just don't store the same object in multiple places in the PyTree.

## How do I input higher-order tensors (e.g. with batch dimensions) into my model?¤

Use `jax.vmap`

. This maps arbitrary JAX operations -- including any Equinox module -- over additional dimensions (such as batch dimensions).

For example if `x`

is an array/tensor of shape `(batch_size, input_size)`

, then the following PyTorch code:

```
import torch
linear = torch.nn.Linear(input_size, output_size)
y = linear(x)
```

is equivalent to the following Equinox code:

```
import jax
import equinox as eqx
key = jax.random.PRNGKey(seed=0)
linear = eqx.nn.Linear(input_size, output_size, key=key)
y = jax.vmap(linear)(x)
```

## TypeError: not a valid JAX type.¤

You might be getting an error like

```
TypeError: Argument '<function ...>' of type <class 'function'> is not a valid JAX type.
```

```
import jax
import equinox as eqx
def loss_fn(model, x, y):
return ((model(x) - y) ** 2).mean()
model = eqx.nn.Lambda(lambda x: x)
model = eqx.nn.MLP(2, 2, 2, 2, key=jax.random.PRNGKey(0))
x = jax.numpy.arange(2)
y = x * x
try:
jax.jit(loss_fn)(model, x, y) # error
except TypeError as e:
print(e)
eqx.filter_jit(loss_fn)(model, x, y) # ok
```

This error happens because a model, when treated as a PyTree, may have leaves that are not JAX types (such as functions). It only makes sense to trace arrays. Filtering is used to handle this.

Instead of `jax.jit`

, use `equinox.filter_jit`

. Likewise for other transformations.