Pooling
equinox.nn.Pool(equinox.Module)
General N-dimensional downsampling over a sliding window.
__init__(init: int | float | Array, operation: Callable[[Array, Array], Array], num_spatial_dims: int, kernel_size: int | Sequence[int], stride: int | Sequence[int] = 1, padding: int | Sequence[int] | Sequence[tuple[int, int]] = 0, use_ceil: bool = False)
Arguments:
init
: The initial value for the reduction.
operation
: The binary operation used to combine the elements of each window.
num_spatial_dims
: The number of spatial dimensions.
kernel_size
: The size of the pooling window.
stride
: The stride of the pooling window.
padding
: The amount of padding to apply before and after each spatial dimension.
use_ceil
: If True, then ceil is used to compute the final output shape instead of floor. For ceil, extra padding is added if required. Defaults to False.
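The plain-Python sketch below writes out the conventional sliding-window size formula, to show how kernel_size, stride, padding and use_ceil combine. pool_output_size is a hypothetical helper and is not part of Equinox. The sketch assumes Pool follows the usual floor/ceil convention along each spatial dimension independently.

```python
def pool_output_size(size, kernel, stride=1, pad_lo=0, pad_hi=0, use_ceil=False):
    """Number of window positions along one spatial axis (hypothetical helper).

    Uses the conventional formula (size + padding - kernel) / stride + 1,
    rounded down by default or rounded up when use_ceil=True.
    """
    span = size + pad_lo + pad_hi - kernel
    if span < 0:
        raise ValueError("window is larger than the padded input")
    if use_ceil:
        steps = -(-span // stride)  # integer ceiling division
    else:
        steps = span // stride      # integer floor division
    return steps + 1


# Length 5, window 2, stride 2: floor gives 2 windows and drops the last
# element; ceil gives 3 windows, the last one reaching into extra padding.
print(pool_output_size(5, kernel=2, stride=2))                 # 2
print(pool_output_size(5, kernel=2, stride=2, use_ceil=True))  # 3
# When the windows tile the input exactly, floor and ceil agree.
print(pool_output_size(8, kernel=2, stride=2, use_ceil=True))  # 4
```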
Info
In order for Pool to be differentiable, operation(init, x) == x needs to hold for all finite x. For further details see https://www.tensorflow.org/xla/operation_semantics#reducewindow and https://github.com/google/jax/issues/7718.
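To sanity-check a candidate (init, operation) pair against this condition, evaluate operation(init, x) on a few finite values. The helper is_identity below is hypothetical. A spot-check on a handful of values is only a sanity check, not a proof.

```python
import jax.numpy as jnp
from jax import lax


def is_identity(init, operation, values):
    """Spot-check operation(init, x) == x on some finite values (hypothetical helper)."""
    inits = jnp.full_like(values, init)  # lax ops expect operands of matching shape
    return bool(jnp.all(operation(inits, values) == values))


xs = jnp.array([-3.0, -0.5, 0.0, 2.0])

print(is_identity(-jnp.inf, lax.max, xs))  # True: max(-inf, x) == x
print(is_identity(0.0, lax.add, xs))       # True: 0 + x == x
print(is_identity(0.0, lax.max, xs))       # False: max(0, -3) == 0, not -3
```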
__call__(x: Array, *, key: PRNGKeyArray | None = None) -> Array
Arguments:
x
: The input. Should be a JAX array of shape (channels, dim_1, ..., dim_N), where N = num_spatial_dims.
key
: Ignored; provided for compatibility with the rest of the Equinox API. (Keyword only argument.)
Returns:
A JAX array of shape (channels, new_dim_1, ..., new_dim_N).
equinox.nn.AvgPool1d(equinox.nn.Pool)
One-dimensional downsampling using the average over a sliding window.
__init__(kernel_size: int | Sequence[int], stride: int | Sequence[int] = 1, padding: int | Sequence[int] | Sequence[tuple[int, int]] = 0, use_ceil: bool = False)
Arguments:
kernel_size
: The size of the pooling window.
stride
: The stride of the pooling window.
padding
: The amount of padding to apply before and after each spatial dimension.
use_ceil
: If True, then ceil is used to compute the final output shape instead of floor. For ceil, extra padding is added if required. Defaults to False.
__call__(x: Array, *, key: PRNGKeyArray | None = None) -> Array
Arguments:
x
: The input. Should be a JAX array of shape (channels, dim).
key
: Ignored; provided for compatibility with the rest of the Equinox API. (Keyword only argument.)
Returns:
A JAX array of shape (channels, new_dim).
equinox.nn.AvgPool2d(equinox.nn.Pool)
Two-dimensional downsampling using the average over a sliding window.
__init__(kernel_size: int | Sequence[int], stride: int | Sequence[int] = 1, padding: int | Sequence[int] | Sequence[tuple[int, int]] = 0, use_ceil: bool = False)
Arguments:
kernel_size
: The size of the pooling window.
stride
: The stride of the pooling window.
padding
: The amount of padding to apply before and after each spatial dimension.
use_ceil
: If True, then ceil is used to compute the final output shape instead of floor. For ceil, extra padding is added if required. Defaults to False.
__call__(x: Array, *, key: PRNGKeyArray | None = None) -> Array
Arguments:
x
: The input. Should be a JAX array of shape (channels, dim_1, dim_2).
key
: Ignored; provided for compatibility with the rest of the Equinox API. (Keyword only argument.)
Returns:
A JAX array of shape (channels, new_dim_1, new_dim_2).
equinox.nn.AvgPool3d(equinox.nn.Pool)
Three-dimensional downsampling using the average over a sliding window.
__init__(kernel_size: int | Sequence[int], stride: int | Sequence[int] = 1, padding: int | Sequence[int] | Sequence[tuple[int, int]] = 0, use_ceil: bool = False)
Arguments:
kernel_size
: The size of the pooling window.
stride
: The stride of the pooling window.
padding
: The amount of padding to apply before and after each spatial dimension.
use_ceil
: If True, then ceil is used to compute the final output shape instead of floor. For ceil, extra padding is added if required. Defaults to False.
__call__(x: Array, *, key: PRNGKeyArray | None = None) -> Array
Arguments:
x
: The input. Should be a JAX array of shape (channels, dim_1, dim_2, dim_3).
key
: Ignored; provided for compatibility with the rest of the Equinox API. (Keyword only argument.)
Returns:
A JAX array of shape (channels, new_dim_1, new_dim_2, new_dim_3).
equinox.nn.MaxPool1d(equinox.nn.Pool)
One-dimensional downsampling using the maximum over a sliding window.
__init__(kernel_size: int | Sequence[int], stride: int | Sequence[int] = 1, padding: int | Sequence[int] | Sequence[tuple[int, int]] = 0, use_ceil: bool = False)
Arguments:
kernel_size
: The size of the pooling window.
stride
: The stride of the pooling window.
padding
: The amount of padding to apply before and after each spatial dimension.
use_ceil
: If True, then ceil is used to compute the final output shape instead of floor. For ceil, extra padding is added if required. Defaults to False.
__call__(x: Array, *, key: PRNGKeyArray | None = None) -> Array
Arguments:
x
: The input. Should be a JAX array of shape (channels, dim).
key
: Ignored; provided for compatibility with the rest of the Equinox API. (Keyword only argument.)
Returns:
A JAX array of shape (channels, new_dim).
equinox.nn.MaxPool2d(equinox.nn.Pool)
Two-dimensional downsampling using the maximum over a sliding window.
__init__(kernel_size: int | Sequence[int], stride: int | Sequence[int] = 1, padding: int | Sequence[int] | Sequence[tuple[int, int]] = 0, use_ceil: bool = False)
Arguments:
kernel_size
: The size of the pooling window.
stride
: The stride of the pooling window.
padding
: The amount of padding to apply before and after each spatial dimension.
use_ceil
: If True, then ceil is used to compute the final output shape instead of floor. For ceil, extra padding is added if required. Defaults to False.
__call__(x: Array, *, key: PRNGKeyArray | None = None) -> Array
Arguments:
x
: The input. Should be a JAX array of shape (channels, dim_1, dim_2).
key
: Ignored; provided for compatibility with the rest of the Equinox API. (Keyword only argument.)
Returns:
A JAX array of shape (channels, new_dim_1, new_dim_2).
equinox.nn.MaxPool3d(equinox.nn.Pool)
Three-dimensional downsampling using the maximum over a sliding window.
__init__(kernel_size: int | Sequence[int], stride: int | Sequence[int] = 1, padding: int | Sequence[int] | Sequence[tuple[int, int]] = 0, use_ceil: bool = False)
Arguments:
kernel_size
: The size of the pooling window.
stride
: The stride of the pooling window.
padding
: The amount of padding to apply before and after each spatial dimension.
use_ceil
: If True, then ceil is used to compute the final output shape instead of floor. For ceil, extra padding is added if required. Defaults to False.
__call__(x: Array, *, key: PRNGKeyArray | None = None) -> Array
Arguments:
x
: The input. Should be a JAX array of shape (channels, dim_1, dim_2, dim_3).
key
: Ignored; provided for compatibility with the rest of the Equinox API. (Keyword only argument.)
Returns:
A JAX array of shape (channels, new_dim_1, new_dim_2, new_dim_3).
equinox.nn.AdaptivePool(equinox.Module)
General N-dimensional adaptive downsampling to a target shape.
__init__(target_shape: int | Sequence[int], num_spatial_dims: int, operation: Callable)
Arguments:
target_shape
: The target output shape.
num_spatial_dims
: The number of spatial dimensions.
operation
: The reduction applied to each window, for example jnp.mean or jnp.max.
__call__(x: Array, *, key: PRNGKeyArray | None = None) -> Array
Arguments:
x
: The input. Should be a JAX array of shape (channels, dim_1, ..., dim_N), where N = num_spatial_dims.
key
: Ignored; provided for compatibility with the rest of the Equinox API. (Keyword only argument.)
Returns:
A JAX array of shape (channels,) + target_shape.
equinox.nn.AdaptiveAvgPool1d(equinox.nn.AdaptivePool)
Adaptive one-dimensional downsampling to a target shape, using the average over each window.
__init__(target_shape: int | Sequence[int])
Arguments:
target_shape
: The target output shape.
__call__(x: Array, *, key: PRNGKeyArray | None = None) -> Array
Arguments:
x
: The input. Should be a JAX array of shape (channels, dim).
key
: Ignored; provided for compatibility with the rest of the Equinox API. (Keyword only argument.)
Returns:
A JAX array of shape (channels,) + target_shape.
equinox.nn.AdaptiveAvgPool2d(equinox.nn.AdaptivePool)
Adaptive two-dimensional downsampling to a target shape, using the average over each window.
__init__(target_shape: int | Sequence[int])
Arguments:
target_shape
: The target output shape.
__call__(x: Array, *, key: PRNGKeyArray | None = None) -> Array
Arguments:
x
: The input. Should be a JAX array of shape (channels, dim_1, dim_2).
key
: Ignored; provided for compatibility with the rest of the Equinox API. (Keyword only argument.)
Returns:
A JAX array of shape (channels,) + target_shape.
equinox.nn.AdaptiveAvgPool3d(equinox.nn.AdaptivePool)
Adaptive three-dimensional downsampling to a target shape, using the average over each window.
__init__(target_shape: int | Sequence[int])
Arguments:
target_shape
: The target output shape.
__call__(x: Array, *, key: PRNGKeyArray | None = None) -> Array
Arguments:
x
: The input. Should be a JAX array of shape (channels, dim_1, dim_2, dim_3).
key
: Ignored; provided for compatibility with the rest of the Equinox API. (Keyword only argument.)
Returns:
A JAX array of shape (channels,) + target_shape.
equinox.nn.AdaptiveMaxPool1d(equinox.nn.AdaptivePool)
Adaptive one-dimensional downsampling to a target shape, using the maximum over each window.
__init__(target_shape: int | Sequence[int])
Arguments:
target_shape
: The target output shape.
__call__(x: Array, *, key: PRNGKeyArray | None = None) -> Array
Arguments:
x
: The input. Should be a JAX array of shape (channels, dim).
key
: Ignored; provided for compatibility with the rest of the Equinox API. (Keyword only argument.)
Returns:
A JAX array of shape (channels,) + target_shape.
equinox.nn.AdaptiveMaxPool2d(equinox.nn.AdaptivePool)
Adaptive two-dimensional downsampling to a target shape, using the maximum over each window.
__init__(target_shape: int | Sequence[int])
Arguments:
target_shape
: The target output shape.
__call__(x: Array, *, key: PRNGKeyArray | None = None) -> Array
Arguments:
x
: The input. Should be a JAX array of shape (channels, dim_1, dim_2).
key
: Ignored; provided for compatibility with the rest of the Equinox API. (Keyword only argument.)
Returns:
A JAX array of shape (channels,) + target_shape.
equinox.nn.AdaptiveMaxPool3d(equinox.nn.AdaptivePool)
Adaptive three-dimensional downsampling to a target shape, using the maximum over each window.
__init__(target_shape: int | Sequence[int])
Arguments:
target_shape
: The target output shape.
__call__(x: Array, *, key: PRNGKeyArray | None = None) -> Array
Arguments:
x
: The input. Should be a JAX array of shape (channels, dim_1, dim_2, dim_3).
key
: Ignored; provided for compatibility with the rest of the Equinox API. (Keyword only argument.)
Returns:
A JAX array of shape (channels,) + target_shape.