Convolutional
equinox.nn.Conv (Module)
General N-dimensional convolution.
__init__(self, num_spatial_dims: int, in_channels: int, out_channels: int, kernel_size: Union[int, Sequence[int]], stride: Union[int, Sequence[int]] = 1, padding: Union[str, int, Sequence[int], Sequence[tuple[int, int]]] = 0, dilation: Union[int, Sequence[int]] = 1, groups: int = 1, use_bias: bool = True, padding_mode: str = 'ZEROS', dtype = None, *, key: PRNGKeyArray)
Arguments:
- num_spatial_dims: The number of spatial dimensions. For example, traditional convolutions for image processing have this set to 2.
- in_channels: The number of input channels.
- out_channels: The number of output channels.
- kernel_size: The size of the convolutional kernel.
- stride: The stride of the convolution.
- padding: The padding of the convolution.
- dilation: The dilation of the convolution.
- groups: The number of input channel groups. At groups=1, all input channels contribute to all output channels. Values higher than 1 are equivalent to running groups independent Conv operations side-by-side, each having access only to in_channels // groups input channels, and concatenating the results along the output channel dimension. in_channels must be divisible by groups.
- use_bias: Whether to add on a bias after the convolution.
- padding_mode: One of the following strings specifying the padding values.
    - 'ZEROS' (default): pads with zeros, 1234 -> 00123400.
    - 'REFLECT': pads with the reflection on boundary, 1234 -> 32123432.
    - 'REPLICATE': pads with the replication of edge values, 1234 -> 11123444.
    - 'CIRCULAR': pads with circular values, 1234 -> 34123412.
- dtype: The dtype to use for the weight and the bias in this layer. Defaults to either jax.numpy.float32 or jax.numpy.float64 depending on whether JAX is in 64-bit mode.
- key: A jax.random.PRNGKey used to provide randomness for parameter initialisation. (Keyword only argument.)
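The padding_mode examples above pad 1234 by two elements on each side. As an illustration only, the same strings can be reproduced with numpy.pad; the mapping from padding_mode names to NumPy mode names below is an assumption made for this sketch and says nothing about how Equinox implements padding internally.

```python
import numpy as np

# Assumed correspondence between padding_mode strings and numpy.pad modes,
# used purely to reproduce the 1234 examples.
MODE_TO_NUMPY = {
    "ZEROS": "constant",   # fill with 0
    "REFLECT": "reflect",  # mirror about the edge, edge value not repeated
    "REPLICATE": "edge",   # repeat the edge value
    "CIRCULAR": "wrap",    # wrap around to the other end
}


def pad_1d(values, amount, padding_mode):
    """Pad a 1D sequence by `amount` elements on each side."""
    return np.pad(np.asarray(values), amount, mode=MODE_TO_NUMPY[padding_mode])


def as_digits(array):
    """Render an integer array as a digit string, e.g. [0, 1] -> '01'."""
    return "".join(str(int(v)) for v in array)


for mode in MODE_TO_NUMPY:
    print(mode, as_digits(pad_1d([1, 2, 3, 4], 2, mode)))
```

Each printed string should match the corresponding example in the padding_mode list.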
Info
All of kernel_size, stride, padding and dilation can be either an integer or a sequence of integers. If they are an integer then the same kernel size / stride / padding / dilation will be used along every spatial dimension. If they are a sequence then the sequence should be of length equal to num_spatial_dims, and specify the value of each property down each spatial dimension in turn.
In addition, padding can be:
- a sequence of 2-element tuples, each representing the padding to apply before and after each spatial dimension.
- the string 'VALID', which is the same as zero padding.
- one of the strings 'SAME' or 'SAME_LOWER'. This will apply padding to produce an output with the same size spatial dimensions as the input. The padding is split between the two sides equally or almost equally. If the total padding is odd, the extra element is added at the end for 'SAME' and at the beginning for 'SAME_LOWER'.
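The 'SAME' / 'SAME_LOWER' split described above can be sketched for a single spatial dimension. This hypothetical helper assumes stride 1, where keeping the output length equal to the input length requires a total padding of dilation * (kernel_size - 1); it is an illustration of the rule, not Equinox's own code.

```python
def same_padding(kernel_size, dilation=1, lower=False):
    """Return (before, after) padding for 'SAME' or 'SAME_LOWER' at stride 1.

    Hypothetical helper: assumes stride 1, so the output length equals the
    input length whenever before + after == dilation * (kernel_size - 1).
    """
    total = dilation * (kernel_size - 1)
    small, large = total // 2, total - total // 2
    # 'SAME' puts the odd extra element at the end, 'SAME_LOWER' at the start.
    return (large, small) if lower else (small, large)


print(same_padding(3))              # (1, 1): even total, split equally
print(same_padding(4))              # (1, 2): odd total, extra at the end
print(same_padding(4, lower=True))  # (2, 1): odd total, extra at the start
print(same_padding(3, dilation=2))  # (2, 2): dilation widens the kernel span
```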
__call__(self, x: Array, *, key: Optional[PRNGKeyArray] = None) -> Array
Arguments:
- x: The input. Should be a JAX array of shape (in_channels, dim_1, ..., dim_N), where N = num_spatial_dims.
- key: Ignored; provided for compatibility with the rest of the Equinox API. (Keyword only argument.)
Returns:
A JAX array of shape (out_channels, new_dim_1, ..., new_dim_N).
equinox.nn.ConvTranspose (Module)
General N-dimensional transposed convolution.
__init__(self, num_spatial_dims: int, in_channels: int, out_channels: int, kernel_size: Union[int, Sequence[int]], stride: Union[int, Sequence[int]] = 1, padding: Union[str, int, Sequence[int], Sequence[tuple[int, int]]] = 0, output_padding: Union[int, Sequence[int]] = 0, dilation: Union[int, Sequence[int]] = 1, groups: int = 1, use_bias: bool = True, padding_mode: str = 'ZEROS', dtype = None, *, key: PRNGKeyArray)
Arguments:
- num_spatial_dims: The number of spatial dimensions. For example, traditional convolutions for image processing have this set to 2.
- in_channels: The number of input channels.
- out_channels: The number of output channels.
- kernel_size: The size of the transposed convolutional kernel.
- stride: The stride used on the equivalent equinox.nn.Conv.
- padding: The padding used on the equivalent equinox.nn.Conv.
- output_padding: Additional padding for the output shape.
- dilation: The spacing between kernel points.
- groups: The number of input channel groups. At groups=1, all input channels contribute to all output channels. Values higher than 1 are equivalent to running groups independent ConvTranspose operations side-by-side, each having access only to in_channels // groups input channels, and concatenating the results along the output channel dimension. in_channels must be divisible by groups.
- use_bias: Whether to add on a bias after the transposed convolution.
- padding_mode: One of the following strings specifying the padding values used on the equivalent equinox.nn.Conv.
    - 'ZEROS' (default): pads with zeros, no extra connectivity.
    - 'CIRCULAR': pads with circular values, extra connectivity (see the Tip below).
- dtype: The dtype to use for the weight and the bias in this layer. Defaults to either jax.numpy.float32 or jax.numpy.float64 depending on whether JAX is in 64-bit mode.
- key: A jax.random.PRNGKey used to provide randomness for parameter initialisation. (Keyword only argument.)
Info
All of kernel_size, stride, padding and dilation can be either an integer or a sequence of integers. If they are an integer then the same kernel size / stride / padding / dilation will be used along every spatial dimension. If they are a sequence then the sequence should be of length equal to num_spatial_dims, and specify the value of each property down each spatial dimension in turn.
In addition, padding can be:
- a sequence of 2-element tuples, each representing the padding to apply before and after each spatial dimension.
- the string 'VALID', which is the same as zero padding.
- one of the strings 'SAME' or 'SAME_LOWER'. This will apply padding to produce an output with the same size spatial dimensions as the input. The padding is split between the two sides equally or almost equally. If the total padding is odd, the extra element is added at the end for 'SAME' and at the beginning for 'SAME_LOWER'.
Tip
Transposed convolutions are often used to go in the "opposite direction" to a normal convolution: from something with the shape of the output of a convolution to something with the shape of its input. Moreover, they do so with the same "connectivity", i.e. the same pattern of which inputs can affect which outputs.
Relative to an equinox.nn.Conv layer, this can be accomplished by switching the values of in_channels and out_channels, whilst keeping kernel_size, stride, padding, dilation and groups the same.
When stride > 1, equinox.nn.Conv maps multiple input shapes to the same output shape. output_padding is provided to resolve this ambiguity:
- For 'SAME' or 'SAME_LOWER' padding, it reduces the calculated input shape.
- For other cases, it adds a little extra padding to the bottom or right edges of the input.
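The ambiguity can be made concrete with the standard one-dimensional shape arithmetic for explicit (integer or tuple) padding. These helpers are hypothetical and do not cover 'SAME' / 'SAME_LOWER'; the transposed formula is the commonly documented one (for example in PyTorch), and we assume, rather than guarantee, that equinox.nn.ConvTranspose follows it for explicit padding.

```python
def conv_out_len(n, kernel_size, stride=1, padding=(0, 0), dilation=1):
    """Output length of a convolution along one spatial dimension."""
    span = dilation * (kernel_size - 1) + 1  # extent of the dilated kernel
    return (n + padding[0] + padding[1] - span) // stride + 1


def conv_transpose_out_len(n, kernel_size, stride=1, padding=(0, 0),
                           output_padding=0, dilation=1):
    """Output length of the matching transposed convolution."""
    return ((n - 1) * stride - padding[0] - padding[1]
            + dilation * (kernel_size - 1) + output_padding + 1)


# With stride 2, inputs of length 7 and 8 both give an output of length 4.
print([conv_out_len(n, 3, stride=2, padding=(1, 1)) for n in (7, 8)])  # [4, 4]

# output_padding selects which of those input lengths is reconstructed.
print([conv_transpose_out_len(4, 3, stride=2, padding=(1, 1), output_padding=op)
       for op in (0, 1)])  # [7, 8]
```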
The extra connectivity created by 'CIRCULAR' padding is correctly taken into account. For instance, consider the equivalent equinox.nn.Conv with kernel size 3. Then:
Input 1234 --(zero padding)--> 012340 --(conv)--> Output abcd
Input 1234 --(circular padding)--> 412341 --(conv)--> Output abcd
so that a is connected with 1, 2 for zero padding, while it is connected with 1, 2, 4 for circular padding.
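The connectivity in the example above can be checked with a small hypothetical helper, written in plain Python for illustration: for a stride-1 convolution it lists which input positions fall inside the window of a given output, mapping taps that land in the padding either to nothing (zero padding) or back around the input (circular padding).

```python
def connected_inputs(out_index, n, kernel_size, pad, circular):
    """Input indices that feed output `out_index` of a stride-1 convolution
    whose length-n input is padded by `pad` elements on each side."""
    inputs = set()
    for j in range(out_index, out_index + kernel_size):  # window, padded coords
        i = j - pad  # position in the unpadded input
        if 0 <= i < n:
            inputs.add(i)
        elif circular:
            inputs.add(i % n)  # the tap wraps around to the other end
        # otherwise the tap reads a padding zero and connects to no input
    return sorted(inputs)


values = [1, 2, 3, 4]
for circular in (False, True):
    idx = connected_inputs(0, len(values), kernel_size=3, pad=1, circular=circular)
    print("circular" if circular else "zeros", [values[i] for i in idx])
# zeros [1, 2]
# circular [1, 2, 4]
```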
See these animations and this report for a nice reference.
FAQ
If you need to exactly transpose a convolutional layer, i.e. not just create an operation with similar inductive biases but compute the actual linear transpose of a specific CNN, you can reshape the weights of the forward convolution as follows:

```python
import equinox as eqx
import jax.numpy as jnp

cnn = eqx.nn.Conv(...)
cnn_t = eqx.nn.ConvTranspose(...)  # same hyperparameters, in/out channels swapped
# Flip the kernel along every spatial axis, then swap the out/in channel axes.
cnn_t = eqx.tree_at(lambda x: x.weight, cnn_t, jnp.flip(cnn.weight,
    axis=tuple(range(2, cnn.weight.ndim))).swapaxes(0, 1))
```
Warning
padding_mode='CIRCULAR' is only implemented for output_padding=0 and padding='SAME' or 'SAME_LOWER'.
__call__(self, x: Array, *, key: Optional[PRNGKeyArray] = None) -> Array
Arguments:
- x: The input. Should be a JAX array of shape (in_channels, dim_1, ..., dim_N), where N = num_spatial_dims.
- key: Ignored; provided for compatibility with the rest of the Equinox API. (Keyword only argument.)
Returns:
A JAX array of shape (out_channels, new_dim_1, ..., new_dim_N).
equinox.nn.Conv1d (Conv)
As equinox.nn.Conv with num_spatial_dims=1.
equinox.nn.Conv2d (Conv)
As equinox.nn.Conv with num_spatial_dims=2.
equinox.nn.Conv3d (Conv)
As equinox.nn.Conv with num_spatial_dims=3.
equinox.nn.ConvTranspose1d (ConvTranspose)
As equinox.nn.ConvTranspose with num_spatial_dims=1.
equinox.nn.ConvTranspose2d (ConvTranspose)
As equinox.nn.ConvTranspose with num_spatial_dims=2.
equinox.nn.ConvTranspose3d (ConvTranspose)
As equinox.nn.ConvTranspose with num_spatial_dims=3.