# FAQ

## Optax is throwing an error.

Probably you're writing code that looks like

```python
optim = optax.adam(learning_rate)
optim.init(model)
```

and getting an error that looks like

```
TypeError: zeros_like requires ndarray or scalar arguments, got <class 'jax._src.custom_derivatives.custom_jvp'> at position 0.
```

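This can be fixed by doing

```python
optim.init(eqx.filter(model, eqx.is_array))
```

which keeps only the JAX arrays in your model. Optax can only optimise JAX arrays; it isn't meaningful to ask it to optimise the other arbitrary Python objects that may be part of your model (e.g. the activation function of an `eqx.nn.MLP`).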

## A module saved in two places has become two independent copies.

Probably you're doing something like

```python
import equinox as eqx

class Module(eqx.Module):
    linear1: eqx.nn.Linear
    linear2: eqx.nn.Linear

    def __init__(self, in_size, out_size, key):
        shared_linear = eqx.nn.Linear(in_size, out_size, key=key)
        self.linear1 = shared_linear
        self.linear2 = shared_linear
```

Don't do this!

After making some gradient updates you'll find that `self.linear1` and `self.linear2` are now different.

Recall that in Equinox, models are PyTrees. Meanwhile, JAX treats all PyTrees as *trees*: that is, the same object does not appear in the tree more than once. (If it did, then it would be a *directed acyclic graph* instead.) If JAX ever encounters the same object multiple times, it will unwittingly make independent copies of it whenever it transforms the overall PyTree.

The resolution is simple: just don't store the same object in multiple places in the PyTree.

## How do I input higher-order tensors (e.g. with batch dimensions) into my model?

Use `jax.vmap`. This maps arbitrary JAX operations (including any Equinox module) over additional dimensions, such as batch dimensions.

For example, if `x` is an array/tensor of shape `(batch_size, input_size)`, then the following PyTorch code:

```python
import torch
linear = torch.nn.Linear(input_size, output_size)
y = linear(x)
```

is equivalent to the following Equinox code:

```python
import jax
import equinox as eqx
key = jax.random.PRNGKey(seed=0)
linear = eqx.nn.Linear(input_size, output_size, key=key)
y = jax.vmap(linear)(x)
```