Convolutional

equinox.nn.Conv (Module)

General N-dimensional convolution.

__init__(self, num_spatial_dims: int, in_channels: int, out_channels: int, kernel_size: Union[int, Sequence[int]], stride: Union[int, Sequence[int]] = 1, padding: Union[int, Sequence[int], Sequence[Tuple[int, int]]] = 0, dilation: Union[int, Sequence[int]] = 1, groups: int = 1, use_bias: bool = True, *, key: jax.random.PRNGKey, **kwargs)

Arguments:

  • num_spatial_dims: The number of spatial dimensions. For example, traditional convolutions for image processing have this set to 2.
  • in_channels: The number of input channels.
  • out_channels: The number of output channels.
  • kernel_size: The size of the convolutional kernel.
  • stride: The stride of the convolution.
  • padding: The amount of padding to apply before and after each spatial dimension.
  • dilation: The dilation of the convolution.
  • groups: The number of input channel groups. At groups=1, all input channels contribute to all output channels. Values higher than 1 are equivalent to running groups independent Conv operations side-by-side, each having access only to in_channels // groups input channels, and concatenating the results along the output channel dimension. in_channels must be divisible by groups.
  • use_bias: Whether to add on a bias after the convolution.
  • key: A jax.random.PRNGKey used to provide randomness for parameter initialisation. (Keyword only argument.)
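
The description of groups can be checked numerically. The following is a minimal sketch that calls jax.lax.conv_general_dilated directly rather than equinox.nn.Conv, on the assumption that groups corresponds to its feature_group_count argument. The shapes, the names x, w, and conv, and the leading batch axis of size 1 are all illustrative choices, not part of the Equinox API.

```python
import jax
import jax.numpy as jnp

groups, in_channels, out_channels = 2, 4, 6
k_x, k_w = jax.random.split(jax.random.PRNGKey(0))

# jax.lax expects a leading batch axis, so a batch of size 1 is used here.
x = jax.random.normal(k_x, (1, in_channels, 8, 8))
# Each output channel only sees in_channels // groups input channels.
w = jax.random.normal(k_w, (out_channels, in_channels // groups, 3, 3))


def conv(lhs, rhs, feature_group_count=1):
    # Plain 2D convolution, stride 1, no padding (hypothetical helper).
    return jax.lax.conv_general_dilated(
        lhs,
        rhs,
        window_strides=(1, 1),
        padding="VALID",
        feature_group_count=feature_group_count,
    )


# One grouped convolution.
grouped = conv(x, w, feature_group_count=groups)

# The same thing as `groups` independent convolutions, run side by side
# and concatenated along the output channel dimension.
x_parts = jnp.split(x, groups, axis=1)
w_parts = jnp.split(w, groups, axis=0)
side_by_side = jnp.concatenate(
    [conv(xp, wp) for xp, wp in zip(x_parts, w_parts)], axis=1
)

same = bool(jnp.allclose(grouped, side_by_side, atol=1e-4))
```

Here same should be True, and the kernel holds out_channels * (in_channels // groups) * 3 * 3 weights rather than out_channels * in_channels * 3 * 3.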

Info

Each of kernel_size, stride, padding, and dilation can be either an integer or a sequence of integers. A sequence must have length equal to num_spatial_dims, and specifies the value of that property along each spatial dimension in turn.

An integer means that the same kernel size / stride / padding / dilation is used along every spatial dimension.

padding can alternatively be a sequence of 2-element tuples, each representing the padding to apply before and after the corresponding spatial dimension.

__call__(self, x: Array, *, key: Optional[jax.random.PRNGKey] = None) -> Array

Arguments:

  • x: The input. Should be a JAX array of shape (in_channels, dim_1, ..., dim_N), where N = num_spatial_dims.
  • key: Ignored; provided for compatibility with the rest of the Equinox API. (Keyword only argument.)

Returns:

A JAX array of shape (out_channels, new_dim_1, ..., new_dim_N).
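
The page does not spell out how each new_dim_i is obtained. The sketch below uses the standard convolution arithmetic, which is assumed (not confirmed by this page) to be what equinox.nn.Conv follows. The helper conv_output_size is hypothetical and not part of Equinox.

```python
def conv_output_size(dim, kernel_size, stride=1, padding=(0, 0), dilation=1):
    """Size of one spatial dimension after a convolution (hypothetical helper).

    `padding` is a (before, after) pair for this dimension.
    """
    pad_before, pad_after = padding
    # A dilated kernel spans dilation * (kernel_size - 1) + 1 input positions.
    effective_kernel = dilation * (kernel_size - 1) + 1
    return (dim + pad_before + pad_after - effective_kernel) // stride + 1


same_size = conv_output_size(32, kernel_size=3, padding=(1, 1))
halved = conv_output_size(32, kernel_size=3, stride=2, padding=(1, 1))
dilated = conv_output_size(32, kernel_size=3, dilation=2)
```

With these inputs, same_size is 32, halved is 16, and dilated is 28.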


equinox.nn.ConvTranspose (Module)

General N-dimensional transposed convolution.

__init__(self, num_spatial_dims: int, in_channels: int, out_channels: int, kernel_size: Union[int, Sequence[int]], stride: Union[int, Sequence[int]] = 1, padding: Union[int, Sequence[int], Sequence[Tuple[int, int]]] = 0, output_padding: Union[int, Sequence[int]] = 0, dilation: Union[int, Sequence[int]] = 1, groups: int = 1, use_bias: bool = True, *, key: jax.random.PRNGKey, **kwargs)

Arguments:

  • num_spatial_dims: The number of spatial dimensions. For example, traditional convolutions for image processing have this set to 2.
  • in_channels: The number of input channels.
  • out_channels: The number of output channels.
  • kernel_size: The size of the transposed convolutional kernel.
  • stride: The stride used on the equivalent equinox.nn.Conv.
  • padding: The amount of padding used on the equivalent equinox.nn.Conv.
  • output_padding: Additional padding for the output shape.
  • dilation: The spacing between kernel points.
  • groups: The number of input channel groups. At groups=1, all input channels contribute to all output channels. Values higher than 1 are equivalent to running groups independent ConvTranspose operations side-by-side, each having access only to in_channels // groups input channels, and concatenating the results along the output channel dimension. in_channels must be divisible by groups.
  • use_bias: Whether to add on a bias after the transposed convolution.
  • key: A jax.random.PRNGKey used to provide randomness for parameter initialisation. (Keyword only argument.)

Info

Each of kernel_size, stride, padding, output_padding, and dilation can be either an integer or a sequence of integers. A sequence must have length equal to num_spatial_dims, and specifies the value of that property along each spatial dimension in turn.

An integer means that the same kernel size / stride / padding / output padding / dilation is used along every spatial dimension.

padding can alternatively be a sequence of 2-element tuples, each representing the padding to apply before and after the corresponding spatial dimension.

Tip

Transposed convolutions are often used to go in the "opposite direction" to a normal convolution. That is, they map from something with the shape of the output of a convolution to something with the shape of the input to a convolution. They also do so with the same "connectivity", i.e. the same pattern of which inputs can affect which outputs.

Relative to an equinox.nn.Conv layer, this can be accomplished by switching the values of in_channels and out_channels, whilst keeping kernel_size, stride, padding, dilation, and groups the same.

When stride > 1, equinox.nn.Conv maps multiple input shapes to the same output shape. output_padding is provided to resolve this ambiguity, by adding a little extra padding to just the bottom/right edges of the input.

See these animations and this report for a nice reference.
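
The ambiguity that output_padding resolves can be seen with a little arithmetic. The sketch below uses the standard size formulas for convolutions and transposed convolutions, which are assumed (not confirmed by this page) to match equinox.nn.Conv and equinox.nn.ConvTranspose. The helpers conv_size and conv_transpose_size are hypothetical, and padding is taken to be symmetric for simplicity.

```python
def conv_size(dim, kernel_size, stride, padding, dilation=1):
    # Size of one spatial dimension after a convolution.
    return (dim + 2 * padding - dilation * (kernel_size - 1) - 1) // stride + 1


def conv_transpose_size(dim, kernel_size, stride, padding,
                        output_padding=0, dilation=1):
    # Size of one spatial dimension after a transposed convolution.
    return ((dim - 1) * stride - 2 * padding
            + dilation * (kernel_size - 1) + output_padding + 1)


# With stride 2, inputs of size 7 and 8 collapse to the same output size.
sizes_after_conv = [
    conv_size(d, kernel_size=3, stride=2, padding=1) for d in (7, 8)
]

# Going back, output_padding selects which of the two sizes is produced.
recovered = [
    conv_transpose_size(4, kernel_size=3, stride=2, padding=1, output_padding=p)
    for p in (0, 1)
]
```

Here sizes_after_conv is [4, 4] and recovered is [7, 8]: output_padding=0 returns the smaller candidate, output_padding=1 the larger.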

__call__(self, x: Array, *, key: Optional[jax.random.PRNGKey] = None) -> Array

Arguments:

  • x: The input. Should be a JAX array of shape (in_channels, dim_1, ..., dim_N), where N = num_spatial_dims.
  • key: Ignored; provided for compatibility with the rest of the Equinox API. (Keyword only argument.)

Returns:

A JAX array of shape (out_channels, new_dim_1, ..., new_dim_N).


equinox.nn.Conv1d (Conv)

As equinox.nn.Conv with num_spatial_dims=1.


equinox.nn.Conv2d (Conv)

As equinox.nn.Conv with num_spatial_dims=2.


equinox.nn.Conv3d (Conv)

As equinox.nn.Conv with num_spatial_dims=3.


equinox.nn.ConvTranspose1d (ConvTranspose)

As equinox.nn.ConvTranspose with num_spatial_dims=1.


equinox.nn.ConvTranspose2d (ConvTranspose)

As equinox.nn.ConvTranspose with num_spatial_dims=2.


equinox.nn.ConvTranspose3d (ConvTranspose)

As equinox.nn.ConvTranspose with num_spatial_dims=3.