# Convolutional

#### `equinox.nn.Conv (Module)`

General N-dimensional convolution.

##### `__init__(self, num_spatial_dims: int, in_channels: int, out_channels: int, kernel_size: Union[int, Sequence[int]], stride: Union[int, Sequence[int]] = 1, padding: Union[str, int, Sequence[int], Sequence[tuple[int, int]]] = 0, dilation: Union[int, Sequence[int]] = 1, groups: int = 1, use_bias: bool = True, padding_mode: str = 'ZEROS', dtype = None, *, key: PRNGKeyArray)`

**Arguments:**

- `num_spatial_dims`: The number of spatial dimensions. For example, traditional convolutions for image processing have this set to `2`.
- `in_channels`: The number of input channels.
- `out_channels`: The number of output channels.
- `kernel_size`: The size of the convolutional kernel.
- `stride`: The stride of the convolution.
- `padding`: The padding of the convolution.
- `dilation`: The dilation of the convolution.
- `groups`: The number of input channel groups. At `groups=1`, all input channels contribute to all output channels. Values higher than `1` are equivalent to running `groups` independent `Conv` operations side-by-side, each having access only to `in_channels // groups` input channels, and concatenating the results along the output channel dimension. `in_channels` must be divisible by `groups`.
- `use_bias`: Whether to add a bias after the convolution.
- `padding_mode`: One of the following strings, specifying the values used to fill the padding. The examples show two elements of padding on each side of `1234`.
    - `'ZEROS'` (default): pads with zeros, `1234 -> 00123400`.
    - `'REFLECT'`: pads with the reflection about the boundary, `1234 -> 32123432`.
    - `'REPLICATE'`: pads by replicating the edge values, `1234 -> 11123444`.
    - `'CIRCULAR'`: pads with circular (wrap-around) values, `1234 -> 34123412`.
- `dtype`: The dtype to use for the weight and the bias in this layer. Defaults to either `jax.numpy.float32` or `jax.numpy.float64` depending on whether JAX is in 64-bit mode.
- `key`: A `jax.random.PRNGKey` used to provide randomness for parameter initialisation. (Keyword only argument.)
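The four `padding_mode` fills can be reproduced with `jax.numpy.pad`, which makes the examples above easy to check. The mapping below (`'constant'`, `'reflect'`, `'edge'`, `'wrap'`) is an assumption made for illustration; it describes the fill values only and is not a statement about how `Conv` pads internally.

```python
import jax.numpy as jnp

x = jnp.array([1, 2, 3, 4])

# Two elements of padding on each side, matching the examples above.
# Keys are Equinox padding_mode strings; values use the jnp.pad mode we
# assume produces the same fill (an illustrative mapping, not Equinox code).
padded = {
    "ZEROS": jnp.pad(x, 2, mode="constant"),
    "REFLECT": jnp.pad(x, 2, mode="reflect"),
    "REPLICATE": jnp.pad(x, 2, mode="edge"),
    "CIRCULAR": jnp.pad(x, 2, mode="wrap"),
}
for name, arr in padded.items():
    print(name, arr.tolist())
# ZEROS [0, 0, 1, 2, 3, 4, 0, 0]
# REFLECT [3, 2, 1, 2, 3, 4, 3, 2]
# REPLICATE [1, 1, 1, 2, 3, 4, 4, 4]
# CIRCULAR [3, 4, 1, 2, 3, 4, 1, 2]
```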

Info

All of `kernel_size`, `stride`, `padding`, `dilation` can be either an integer or a sequence of integers.

If an integer is given, the same kernel size / stride / padding / dilation is used along every spatial dimension.

If a sequence is given, it should have length equal to `num_spatial_dims`, and specifies the value of each property for each spatial dimension in turn.

In addition, `padding` can be:

- a sequence of 2-element tuples, each representing the padding to apply before and after each spatial dimension.
- the string `'VALID'`, which is the same as `padding=0` (no padding).
- one of the strings `'SAME'` or `'SAME_LOWER'`. This applies padding to produce an output whose spatial dimensions have the same size as the input's. The padding is split between the two sides equally or almost equally: if the total padding is odd, the extra element is added at the end for `'SAME'` and at the beginning for `'SAME_LOWER'`.
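As a sketch of what these padding options mean for shapes, the hypothetical helpers below (they are not part of Equinox) compute the `'SAME'`/`'SAME_LOWER'` split for a stride-1 convolution and the output length for explicit `(before, after)` padding, using the standard convolution output-size formula. Restricting `same_padding` to stride 1 is an assumption of this sketch.

```python
def same_padding(kernel_size: int, dilation: int = 1, lower: bool = False) -> tuple[int, int]:
    """Hypothetical helper: (before, after) padding keeping a stride-1 output the same size."""
    total = dilation * (kernel_size - 1)  # extent of the dilated kernel, minus one
    small, large = total // 2, total - total // 2
    # 'SAME' puts the extra element at the end, 'SAME_LOWER' at the beginning.
    return (large, small) if lower else (small, large)


def conv_output_size(n: int, kernel_size: int, stride: int = 1,
                     padding: tuple[int, int] = (0, 0), dilation: int = 1) -> int:
    """Hypothetical helper: standard output length along one spatial dimension."""
    before, after = padding
    effective_kernel = dilation * (kernel_size - 1) + 1
    return (n + before + after - effective_kernel) // stride + 1


print(same_padding(4))                                    # (1, 2)
print(same_padding(4, lower=True))                        # (2, 1)
print(conv_output_size(10, 4, padding=same_padding(4)))   # 10
print(conv_output_size(10, 4))                            # 7, i.e. 'VALID'
print(conv_output_size(32, 3, stride=2, padding=(1, 1)))  # 16
```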

##### `__call__(self, x: Array, *, key: Optional[PRNGKeyArray] = None) -> Array`

**Arguments:**

- `x`: The input. Should be a JAX array of shape `(in_channels, dim_1, ..., dim_N)`, where `N = num_spatial_dims`.
- `key`: Ignored; provided for compatibility with the rest of the Equinox API. (Keyword only argument.)

**Returns:**

A JAX array of shape `(out_channels, new_dim_1, ..., new_dim_N)`.

#### `equinox.nn.ConvTranspose (Module)`

General N-dimensional transposed convolution.

##### `__init__(self, num_spatial_dims: int, in_channels: int, out_channels: int, kernel_size: Union[int, Sequence[int]], stride: Union[int, Sequence[int]] = 1, padding: Union[str, int, Sequence[int], Sequence[tuple[int, int]]] = 0, output_padding: Union[int, Sequence[int]] = 0, dilation: Union[int, Sequence[int]] = 1, groups: int = 1, use_bias: bool = True, padding_mode: str = 'ZEROS', dtype = None, *, key: PRNGKeyArray)`

**Arguments:**

- `num_spatial_dims`: The number of spatial dimensions. For example, traditional convolutions for image processing have this set to `2`.
- `in_channels`: The number of input channels.
- `out_channels`: The number of output channels.
- `kernel_size`: The size of the transposed convolutional kernel.
- `stride`: The stride used on the equivalent `equinox.nn.Conv`.
- `padding`: The padding used on the equivalent `equinox.nn.Conv`.
- `output_padding`: Additional padding for the output shape, used to resolve the shape ambiguity that arises when `stride > 1` (see the Tip below).
- `dilation`: The spacing between kernel points.
- `groups`: The number of input channel groups. At `groups=1`, all input channels contribute to all output channels. Values higher than `1` are equivalent to running `groups` independent `ConvTranspose` operations side-by-side, each having access only to `in_channels // groups` input channels, and concatenating the results along the output channel dimension. `in_channels` must be divisible by `groups`.
- `use_bias`: Whether to add a bias after the transposed convolution.
- `padding_mode`: One of the following strings, specifying the padding values used on the equivalent `equinox.nn.Conv`:
    - `'ZEROS'` (default): pads with zeros, no extra connectivity.
    - `'CIRCULAR'`: pads with circular values, extra connectivity (see the Tip below).
- `dtype`: The dtype to use for the weight and the bias in this layer. Defaults to either `jax.numpy.float32` or `jax.numpy.float64` depending on whether JAX is in 64-bit mode.
- `key`: A `jax.random.PRNGKey` used to provide randomness for parameter initialisation. (Keyword only argument.)

Info

All of `kernel_size`, `stride`, `padding`, `dilation` can be either an integer or a sequence of integers.

If an integer is given, the same kernel size / stride / padding / dilation is used along every spatial dimension.

If a sequence is given, it should have length equal to `num_spatial_dims`, and specifies the value of each property for each spatial dimension in turn.

In addition, `padding` can be:

- a sequence of 2-element tuples, each representing the padding to apply before and after each spatial dimension.
- the string `'VALID'`, which is the same as `padding=0` (no padding).
- one of the strings `'SAME'` or `'SAME_LOWER'`. This applies padding to produce an output whose spatial dimensions have the same size as the input's. The padding is split between the two sides equally or almost equally: if the total padding is odd, the extra element is added at the end for `'SAME'` and at the beginning for `'SAME_LOWER'`.

Tip

Transposed convolutions are often used to go in the "opposite direction" to a normal convolution: from something with the shape of the output of a convolution to something with the shape of the input to a convolution, while keeping the same "connectivity", i.e. which inputs can affect which outputs.

Relative to an `equinox.nn.Conv` layer, this can be accomplished by switching the values of `in_channels` and `out_channels`, whilst keeping `kernel_size`, `stride`, `padding`, `dilation`, and `groups` the same.

When `stride > 1`, `equinox.nn.Conv` maps multiple input shapes to the same output shape. `output_padding` is provided to resolve this ambiguity.

- For `'SAME'` or `'SAME_LOWER'` padding, it reduces the calculated input shape.
- For other cases, it adds a little extra padding to the bottom or right edges of the input.
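To make the ambiguity concrete: with `kernel_size=3`, `stride=2` and `padding=1`, inputs of length 31 and 32 both give an output of length 16, so the transposed layer needs `output_padding` to know which length to rebuild. The hypothetical helpers below (not Equinox functions) use the standard output-size formulas; the transposed one is the formula documented for PyTorch's transposed convolutions, and that Equinox agrees with it for integer padding is an assumption of this sketch.

```python
def conv_length(n, kernel_size, stride=1, padding=0, dilation=1):
    """Hypothetical helper: convolution output length, `padding` on each side."""
    return (n + 2 * padding - dilation * (kernel_size - 1) - 1) // stride + 1


def conv_transpose_length(n, kernel_size, stride=1, padding=0,
                          output_padding=0, dilation=1):
    """Hypothetical helper: transposed-convolution output length.

    Assumes the PyTorch-style formula also describes Equinox for integer padding.
    """
    return ((n - 1) * stride - 2 * padding + dilation * (kernel_size - 1)
            + output_padding + 1)


# Two different input lengths collapse to the same output length...
print(conv_length(31, 3, stride=2, padding=1))  # 16
print(conv_length(32, 3, stride=2, padding=1))  # 16
# ...and output_padding selects which one the transposed layer recovers.
print(conv_transpose_length(16, 3, stride=2, padding=1, output_padding=0))  # 31
print(conv_transpose_length(16, 3, stride=2, padding=1, output_padding=1))  # 32
```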

The extra connectivity created in `'CIRCULAR'` padding is correctly taken into account. For instance, consider the equivalent `equinox.nn.Conv` with kernel size 3. Then:

`Input 1234 --(zero padding)--> 012340 --(conv)--> Output abcd`

`Input 1234 --(circular padding)--> 412341 --(conv)--> Output abcd`

so that `a` is connected with `1, 2` for zero padding, while connected with `1, 2, 4` for circular padding.
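One way to check the connectivity claim is to differentiate a toy 1D convolution and read off which inputs the first output `a` depends on. The sketch below builds the convolution from `jnp.pad` and `jnp.correlate` with an all-ones kernel of size 3; these are illustrative choices, not how Equinox implements its layers.

```python
import jax
import jax.numpy as jnp

kernel = jnp.ones(3)  # illustrative kernel of size 3


def conv1d(x, mode):
    # Pad one element on each side, then slide the kernel ("valid" correlation).
    return jnp.correlate(jnp.pad(x, 1, mode=mode), kernel, mode="valid")


x = jnp.array([1.0, 2.0, 3.0, 4.0])
# Row i of the Jacobian is nonzero exactly at the inputs that output i depends on.
jac_zero = jax.jacobian(lambda v: conv1d(v, "constant"))(x)
jac_circ = jax.jacobian(lambda v: conv1d(v, "wrap"))(x)
print(jac_zero[0].tolist())  # [1.0, 1.0, 0.0, 0.0]: `a` sees inputs 1, 2
print(jac_circ[0].tolist())  # [1.0, 1.0, 0.0, 1.0]: `a` sees inputs 1, 2, 4
```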

See these animations and this report for a nice reference.

FAQ

If you need to exactly transpose a convolutional layer, i.e. not just create an operation with similar inductive biases but compute the actual linear transpose of a specific CNN, you can reshape the weights of the forward convolution as follows:

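```python
import equinox as eqx
import jax.numpy as jnp

cnn = eqx.nn.Conv(...)
cnn_t = eqx.nn.ConvTranspose(...)
# Flip the kernel along every spatial axis and swap the in/out channel axes.
spatial_axes = tuple(range(2, cnn.weight.ndim))
new_weight = jnp.flip(cnn.weight, axis=spatial_axes).swapaxes(0, 1)
cnn_t = eqx.tree_at(lambda x: x.weight, cnn_t, new_weight)
```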

Warning

`padding_mode='CIRCULAR'` is only implemented for `output_padding=0` and `padding='SAME'` or `'SAME_LOWER'`.

##### `__call__(self, x: Array, *, key: Optional[PRNGKeyArray] = None) -> Array`

**Arguments:**

- `x`: The input. Should be a JAX array of shape `(in_channels, dim_1, ..., dim_N)`, where `N = num_spatial_dims`.
- `key`: Ignored; provided for compatibility with the rest of the Equinox API. (Keyword only argument.)

**Returns:**

A JAX array of shape `(out_channels, new_dim_1, ..., new_dim_N)`.

#### `equinox.nn.Conv1d (Conv)`

As `equinox.nn.Conv` with `num_spatial_dims=1`.

#### `equinox.nn.Conv2d (Conv)`

As `equinox.nn.Conv` with `num_spatial_dims=2`.

#### `equinox.nn.Conv3d (Conv)`

As `equinox.nn.Conv` with `num_spatial_dims=3`.

#### `equinox.nn.ConvTranspose1d (ConvTranspose)`

As `equinox.nn.ConvTranspose` with `num_spatial_dims=1`.

#### `equinox.nn.ConvTranspose2d (ConvTranspose)`

As `equinox.nn.ConvTranspose` with `num_spatial_dims=2`.

#### `equinox.nn.ConvTranspose3d (ConvTranspose)`

As `equinox.nn.ConvTranspose` with `num_spatial_dims=3`.