
Convolutional

equinox.nn.Conv (Module)

General N-dimensional convolution.

__init__(self, num_spatial_dims: int, in_channels: int, out_channels: int, kernel_size: Union[int, Sequence[int]], stride: Union[int, Sequence[int]] = 1, padding: Union[str, int, Sequence[int], Sequence[tuple[int, int]]] = 0, dilation: Union[int, Sequence[int]] = 1, groups: int = 1, use_bias: bool = True, padding_mode: str = 'ZEROS', dtype = None, *, key: PRNGKeyArray)

Arguments:

  • num_spatial_dims: The number of spatial dimensions. For example, traditional convolutions for image processing have this set to 2.
  • in_channels: The number of input channels.
  • out_channels: The number of output channels.
  • kernel_size: The size of the convolutional kernel.
  • stride: The stride of the convolution.
  • padding: The padding of the convolution.
  • dilation: The dilation of the convolution.
  • groups: The number of input channel groups. At groups=1, all input channels contribute to all output channels. Values higher than 1 are equivalent to running groups independent Conv operations side-by-side, each having access only to in_channels // groups input channels, and concatenating the results along the output channel dimension. in_channels must be divisible by groups.
  • use_bias: Whether to add on a bias after the convolution.
  • padding_mode: One of the following strings specifying the padding values.
    • 'ZEROS' (default): pads with zeros, 1234 -> 00123400.
    • 'REFLECT': pads with the reflection on boundary, 1234 -> 32123432.
    • 'REPLICATE': pads with the replication of edge values, 1234 -> 11123444.
    • 'CIRCULAR': pads with circular values, 1234 -> 34123412.
  • dtype: The dtype to use for the weight and the bias in this layer. Defaults to either jax.numpy.float32 or jax.numpy.float64 depending on whether JAX is in 64-bit mode.
  • key: A jax.random.PRNGKey used to provide randomness for parameter initialisation. (Keyword only argument.)
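The four padding_mode options describe the same edge behaviour as the 'constant', 'reflect', 'edge' and 'wrap' modes of jax.numpy.pad. As a sketch, the snippet below reproduces the 1234 examples from the list above with jnp.pad; the mapping between the two sets of names is our own reading of those examples, not a statement about how Equinox pads internally.

```python
import jax.numpy as jnp

x = jnp.array([1, 2, 3, 4])

# Our assumed correspondence between padding_mode strings and jnp.pad modes.
modes = {"ZEROS": "constant", "REFLECT": "reflect", "REPLICATE": "edge", "CIRCULAR": "wrap"}

# Pad two elements on each side, as in the examples above.
padded = {name: jnp.pad(x, 2, mode=mode).tolist() for name, mode in modes.items()}
for name, values in padded.items():
    print(name, values)
# ZEROS [0, 0, 1, 2, 3, 4, 0, 0]
# REFLECT [3, 2, 1, 2, 3, 4, 3, 2]
# REPLICATE [1, 1, 1, 2, 3, 4, 4, 4]
# CIRCULAR [3, 4, 1, 2, 3, 4, 1, 2]
```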

Info

All of kernel_size, stride, padding and dilation can be either an integer or a sequence of integers.

If they are an integer then the same kernel size / stride / padding / dilation will be used along every spatial dimension.

If they are a sequence then the sequence should be of length equal to num_spatial_dims, and specify the value of each property down each spatial dimension in turn.

In addition, padding can be:

  • a sequence of 2-element tuples, each representing the padding to apply before and after each spatial dimension.
  • the string 'VALID', which is the same as padding=0 (no padding).
  • one of the strings 'SAME' or 'SAME_LOWER'. With stride 1, this applies padding so that the output has the same spatial size as the input. The padding is split between the two sides equally or almost equally: if the total padding along a dimension is odd, the extra unit goes at the end for 'SAME' and at the beginning for 'SAME_LOWER'.

__call__(self, x: Array, *, key: Optional[PRNGKeyArray] = None) -> Array

Arguments:

  • x: The input. Should be a JAX array of shape (in_channels, dim_1, ..., dim_N), where N = num_spatial_dims.
  • key: Ignored; provided for compatibility with the rest of the Equinox API. (Keyword only argument.)

Returns:

A JAX array of shape (out_channels, new_dim_1, ..., new_dim_N).


equinox.nn.ConvTranspose (Module)

General N-dimensional transposed convolution.

__init__(self, num_spatial_dims: int, in_channels: int, out_channels: int, kernel_size: Union[int, Sequence[int]], stride: Union[int, Sequence[int]] = 1, padding: Union[str, int, Sequence[int], Sequence[tuple[int, int]]] = 0, output_padding: Union[int, Sequence[int]] = 0, dilation: Union[int, Sequence[int]] = 1, groups: int = 1, use_bias: bool = True, padding_mode: str = 'ZEROS', dtype = None, *, key: PRNGKeyArray)

Arguments:

  • num_spatial_dims: The number of spatial dimensions. For example, traditional convolutions for image processing have this set to 2.
  • in_channels: The number of input channels.
  • out_channels: The number of output channels.
  • kernel_size: The size of the transposed convolutional kernel.
  • stride: The stride used on the equivalent equinox.nn.Conv.
  • padding: The padding used on the equivalent equinox.nn.Conv.
  • output_padding: Additional padding for the output shape.
  • dilation: The spacing between kernel points.
  • groups: The number of input channel groups. At groups=1, all input channels contribute to all output channels. Values higher than 1 are equivalent to running groups independent ConvTranspose operations side-by-side, each having access only to in_channels // groups input channels, and concatenating the results along the output channel dimension. in_channels must be divisible by groups.
  • use_bias: Whether to add on a bias after the transposed convolution.
  • padding_mode: One of the following strings specifying the padding values used on the equivalent equinox.nn.Conv.
    • 'ZEROS' (default): pads with zeros, no extra connectivity.
    • 'CIRCULAR': pads with circular values, extra connectivity (see the Tip below).
  • dtype: The dtype to use for the weight and the bias in this layer. Defaults to either jax.numpy.float32 or jax.numpy.float64 depending on whether JAX is in 64-bit mode.
  • key: A jax.random.PRNGKey used to provide randomness for parameter initialisation. (Keyword only argument.)

Info

All of kernel_size, stride, padding and dilation can be either an integer or a sequence of integers.

If they are an integer then the same kernel size / stride / padding / dilation will be used along every spatial dimension.

If they are a sequence then the sequence should be of length equal to num_spatial_dims, and specify the value of each property down each spatial dimension in turn.

In addition, padding can be:

  • a sequence of 2-element tuples, each representing the padding to apply before and after each spatial dimension.
  • the string 'VALID', which is the same as padding=0 (no padding).
  • one of the strings 'SAME' or 'SAME_LOWER'. With stride 1, this applies padding so that the output has the same spatial size as the input. The padding is split between the two sides equally or almost equally: if the total padding along a dimension is odd, the extra unit goes at the end for 'SAME' and at the beginning for 'SAME_LOWER'.

Tip

Transposed convolutions are often used to go in the "opposite direction" to a normal convolution: from something with the shape of a convolution's output to something with the shape of its input. Moreover, they do so with the same "connectivity", i.e. the same pattern of which inputs can affect which outputs.

Relative to an equinox.nn.Conv layer, this can be accomplished by switching the values of in_channels and out_channels, whilst keeping kernel_size, stride, padding, dilation, and groups the same.

When stride > 1, equinox.nn.Conv maps multiple input shapes to the same output shape. output_padding is provided to resolve this ambiguity.

  • For 'SAME' or 'SAME_LOWER' padding, it reduces the calculated input shape.
  • For other cases, it adds a little extra padding to the bottom or right edges of the input.

The extra connectivity created by 'CIRCULAR' padding is correctly taken into account. For instance, consider the equivalent equinox.nn.Conv with kernel size 3. Then:

  • Input 1234 --(zero padding)--> 012340 --(conv)--> Output abcd
  • Input 1234 --(circular padding)--> 412341 --(conv)--> Output abcd

so that output a is connected to inputs 1 and 2 with zero padding, but to inputs 1, 2 and 4 with circular padding.

See these animations and this report for a nice reference.

FAQ

If you need to exactly transpose a convolutional layer, i.e. not just create an operation with similar inductive biases but compute the actual linear transpose of a specific CNN, you can reshape the weights of the forward convolution as follows:

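import equinox as eqx
import jax.numpy as jnp

cnn = eqx.nn.Conv(...)
cnn_t = eqx.nn.ConvTranspose(...)
# Flip the kernel along every spatial axis, then swap the in/out channel axes.
cnn_t = eqx.tree_at(lambda x: x.weight, cnn_t, jnp.flip(cnn.weight,
                    axis=tuple(range(2, cnn.weight.ndim))).swapaxes(0, 1))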

Warning

padding_mode='CIRCULAR' is only implemented for output_padding=0 and padding='SAME' or 'SAME_LOWER'.

__call__(self, x: Array, *, key: Optional[PRNGKeyArray] = None) -> Array

Arguments:

  • x: The input. Should be a JAX array of shape (in_channels, dim_1, ..., dim_N), where N = num_spatial_dims.
  • key: Ignored; provided for compatibility with the rest of the Equinox API. (Keyword only argument.)

Returns:

A JAX array of shape (out_channels, new_dim_1, ..., new_dim_N).


equinox.nn.Conv1d (Conv)

As equinox.nn.Conv with num_spatial_dims=1.


equinox.nn.Conv2d (Conv)

As equinox.nn.Conv with num_spatial_dims=2.


equinox.nn.Conv3d (Conv)

As equinox.nn.Conv with num_spatial_dims=3.


equinox.nn.ConvTranspose1d (ConvTranspose)

As equinox.nn.ConvTranspose with num_spatial_dims=1.


equinox.nn.ConvTranspose2d (ConvTranspose)

As equinox.nn.ConvTranspose with num_spatial_dims=2.


equinox.nn.ConvTranspose3d (ConvTranspose)

As equinox.nn.ConvTranspose with num_spatial_dims=3.