Pooling

equinox.nn.Pool (Module)

General N-dimensional downsampling over a sliding window.

__init__(self, init: Union[int, float, Array], operation: Callable[[Array, Array], Array], num_spatial_dims: int, kernel_size: Union[int, Sequence[int]], stride: Union[int, Sequence[int]] = 1, padding: Union[int, Sequence[int], Sequence[Tuple[int, int]]] = 0, use_ceil: bool = False, **kwargs)

Arguments:

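  • init: The initial value for the reduction, for example -inf for a maximum or 0 for a sum.
  • operation: The binary operation used to combine the elements of each window, for example jax.lax.max or jax.lax.add.
  • num_spatial_dims: The number of spatial dimensions.
  • kernel_size: The size of the sliding window, either a single int used for every spatial dimension or one int per spatial dimension.
  • stride: The stride of the sliding window, given in the same way as kernel_size. Defaults to 1.
  • padding: The amount of padding to apply before and after each spatial dimension: an int, one int per spatial dimension, or one (before, after) pair per spatial dimension. Defaults to 0.
  • use_ceil: If True, the output shape is computed with ceil instead of floor, adding extra padding where required. Defaults to False.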

Info

For Pool to be differentiable, operation(init, x) == x must hold for all finite x; that is, init must be an identity element of operation. For further details see https://www.tensorflow.org/xla/operation_semantics#reducewindow and https://github.com/google/jax/issues/7718.
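As a quick illustration of this condition, the sketch below uses plain JAX to test whether a candidate init behaves as an identity for a given operation on a handful of sample values. The helper is_identity is hypothetical and not part of Equinox, and checking finitely many points is only a sanity check, not a proof.

```python
import jax
import jax.numpy as jnp

def is_identity(init, operation, xs):
    """Hypothetical helper: True if operation(init, x) == x for every x in xs."""
    return bool(jnp.all(operation(jnp.full_like(xs, init), xs) == xs))

xs = jnp.array([-3.0, -0.5, 0.0, 2.0, 7.5])  # sample finite values

ok_max = is_identity(-jnp.inf, jax.lax.max, xs)  # max(-inf, x) == x
ok_add = is_identity(0.0, jax.lax.add, xs)       # 0 + x == x
bad_max = is_identity(0.0, jax.lax.max, xs)      # fails: max(0, -3) == 0, not -3
```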

__call__(self, x: Array, *, key: Optional[jax.random.PRNGKey] = None) -> Array

Arguments:

  • x: The input. Should be a JAX array of shape (channels, dim_1, ..., dim_N), where N = num_spatial_dims.
  • key: Ignored; provided for compatibility with the rest of the Equinox API. (Keyword only argument.)

Returns:

A JAX array of shape (channels, new_dim_1, ..., new_dim_N).


equinox.nn.AvgPool1d (Pool)

One-dimensional downsampling using the average over a sliding window.

__init__(self, kernel_size, stride, padding = 0, use_ceil = False, **kwargs)

Arguments:

  • kernel_size: The size of the sliding window.
  • stride: The stride of the sliding window.
  • padding: The amount of padding to apply before and after each spatial dimension.
  • use_ceil: If True, the output shape is computed with ceil instead of floor, adding extra padding where required. Defaults to False.

__call__(self, x: Array, *, key: Optional[jax.random.PRNGKey] = None) -> Array

Arguments:

  • x: The input. Should be a JAX array of shape (channels, dim).
  • key: Ignored; provided for compatibility with the rest of the Equinox API. (Keyword only argument.)

Returns:

A JAX array of shape (channels, new_dim).


equinox.nn.AvgPool2d (Pool)

Two-dimensional downsampling using the average over a sliding window.

__init__(self, kernel_size, stride, padding = 0, use_ceil = False, **kwargs)

Arguments:

  • kernel_size: The size of the sliding window.
  • stride: The stride of the sliding window.
  • padding: The amount of padding to apply before and after each spatial dimension.
  • use_ceil: If True, the output shape is computed with ceil instead of floor, adding extra padding where required. Defaults to False.

__call__(self, x: Array, *, key: Optional[jax.random.PRNGKey] = None) -> Array

Arguments:

  • x: The input. Should be a JAX array of shape (channels, dim_1, dim_2).
  • key: Ignored; provided for compatibility with the rest of the Equinox API. (Keyword only argument.)

Returns:

A JAX array of shape (channels, new_dim_1, new_dim_2).


equinox.nn.AvgPool3d (Pool)

Three-dimensional downsampling using the average over a sliding window.

__init__(self, kernel_size, stride, padding = 0, use_ceil = False, **kwargs)

Arguments:

  • kernel_size: The size of the sliding window.
  • stride: The stride of the sliding window.
  • padding: The amount of padding to apply before and after each spatial dimension.
  • use_ceil: If True, the output shape is computed with ceil instead of floor, adding extra padding where required. Defaults to False.

__call__(self, x: Array, *, key: Optional[jax.random.PRNGKey] = None) -> Array

Arguments:

  • x: The input. Should be a JAX array of shape (channels, dim_1, dim_2, dim_3).
  • key: Ignored; provided for compatibility with the rest of the Equinox API. (Keyword only argument.)

Returns:

A JAX array of shape (channels, new_dim_1, new_dim_2, new_dim_3).


equinox.nn.MaxPool1d (Pool)

One-dimensional downsampling using the maximum over a sliding window.

__init__(self, kernel_size, stride, padding = 0, use_ceil = False, **kwargs)

Arguments:

  • kernel_size: The size of the sliding window.
  • stride: The stride of the sliding window.
  • padding: The amount of padding to apply before and after each spatial dimension.
  • use_ceil: If True, the output shape is computed with ceil instead of floor, adding extra padding where required. Defaults to False.

__call__(self, x: Array, *, key: Optional[jax.random.PRNGKey] = None) -> Array

Arguments:

  • x: The input. Should be a JAX array of shape (channels, dim).
  • key: Ignored; provided for compatibility with the rest of the Equinox API. (Keyword only argument.)

Returns:

A JAX array of shape (channels, new_dim).


equinox.nn.MaxPool2d (Pool)

Two-dimensional downsampling using the maximum over a sliding window.

__init__(self, kernel_size, stride, padding = 0, use_ceil = False, **kwargs)

Arguments:

  • kernel_size: The size of the sliding window.
  • stride: The stride of the sliding window.
  • padding: The amount of padding to apply before and after each spatial dimension.
  • use_ceil: If True, the output shape is computed with ceil instead of floor, adding extra padding where required. Defaults to False.

__call__(self, x: Array, *, key: Optional[jax.random.PRNGKey] = None) -> Array

Arguments:

  • x: The input. Should be a JAX array of shape (channels, dim_1, dim_2).
  • key: Ignored; provided for compatibility with the rest of the Equinox API. (Keyword only argument.)

Returns:

A JAX array of shape (channels, new_dim_1, new_dim_2).


equinox.nn.MaxPool3d (Pool)

Three-dimensional downsampling using the maximum over a sliding window.

__init__(self, kernel_size, stride, padding = 0, use_ceil = False, **kwargs)

Arguments:

  • kernel_size: The size of the sliding window.
  • stride: The stride of the sliding window.
  • padding: The amount of padding to apply before and after each spatial dimension.
  • use_ceil: If True, the output shape is computed with ceil instead of floor, adding extra padding where required. Defaults to False.

__call__(self, x: Array, *, key: Optional[jax.random.PRNGKey] = None) -> Array

Arguments:

  • x: The input. Should be a JAX array of shape (channels, dim_1, dim_2, dim_3).
  • key: Ignored; provided for compatibility with the rest of the Equinox API. (Keyword only argument.)

Returns:

A JAX array of shape (channels, new_dim_1, new_dim_2, new_dim_3).


equinox.nn.AdaptivePool (Module)

General N-dimensional adaptive downsampling to a target shape.

__init__(self, target_shape: Union[int, Sequence[int]], num_spatial_dims: int, operation: Callable, **kwargs)

Arguments:

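  • target_shape: The target output shape, excluding the channel dimension. If an int, the same size is used for every spatial dimension.
  • num_spatial_dims: The number of spatial dimensions.
  • operation: The reduction used to reduce each region of the input to a single value, for example jnp.mean or jnp.max.

__call__(self, x: Array, *, key: Optional[jax.random.PRNGKey] = None) -> Array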

Arguments:

  • x: The input. Should be a JAX array of shape (channels, dim_1, ..., dim_N), where N = num_spatial_dims.
  • key: Ignored; provided for compatibility with the rest of the Equinox API. (Keyword only argument.)

Returns:

A JAX array of shape (channels,) + target_shape.


equinox.nn.AdaptiveAvgPool1d (AdaptivePool)

One-dimensional adaptive downsampling to a target shape, using the average over each region.

__call__(self, x: Array, *, key: Optional[jax.random.PRNGKey] = None) -> Array (inherited from AdaptivePool)

__init__(self, target_shape: Union[int, Sequence[int]], **kwargs)

Arguments:

  • target_shape: The target output shape.

equinox.nn.AdaptiveAvgPool2d (AdaptivePool)

Two-dimensional adaptive downsampling to a target shape, using the average over each region.

__call__(self, x: Array, *, key: Optional[jax.random.PRNGKey] = None) -> Array (inherited from AdaptivePool)

__init__(self, target_shape: Union[int, Sequence[int]], **kwargs)

Arguments:

  • target_shape: The target output shape.

equinox.nn.AdaptiveAvgPool3d (AdaptivePool)

Three-dimensional adaptive downsampling to a target shape, using the average over each region.

__call__(self, x: Array, *, key: Optional[jax.random.PRNGKey] = None) -> Array (inherited from AdaptivePool)

__init__(self, target_shape: Union[int, Sequence[int]], **kwargs)

Arguments:

  • target_shape: The target output shape.

equinox.nn.AdaptiveMaxPool1d (AdaptivePool)

One-dimensional adaptive downsampling to a target shape, using the maximum over each region.

__call__(self, x: Array, *, key: Optional[jax.random.PRNGKey] = None) -> Array (inherited from AdaptivePool)

__init__(self, target_shape: Union[int, Sequence[int]], **kwargs)

Arguments:

  • target_shape: The target output shape.

equinox.nn.AdaptiveMaxPool2d (AdaptivePool)

Two-dimensional adaptive downsampling to a target shape, using the maximum over each region.

__call__(self, x: Array, *, key: Optional[jax.random.PRNGKey] = None) -> Array (inherited from AdaptivePool)

__init__(self, target_shape: Union[int, Sequence[int]], **kwargs)

Arguments:

  • target_shape: The target output shape.

equinox.nn.AdaptiveMaxPool3d (AdaptivePool)

Three-dimensional adaptive downsampling to a target shape, using the maximum over each region.

__call__(self, x: Array, *, key: Optional[jax.random.PRNGKey] = None) -> Array (inherited from AdaptivePool)

__init__(self, target_shape: Union[int, Sequence[int]], **kwargs)

Arguments:

  • target_shape: The target output shape.