

equinox.nn.Dropout (Module)

Applies dropout.
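For readers unfamiliar with the operation, a minimal pure-JAX sketch of what dropout computes may help. The helper `dropout_sketch` is hypothetical and not part of Equinox. It assumes 0 <= p < 1 and the common "inverted dropout" convention of rescaling kept entries by 1/(1 - p); that rescaling is an assumption of this sketch, not something this page states about `equinox.nn.Dropout`.

```python
import jax
import jax.numpy as jnp


def dropout_sketch(x: jax.Array, p: float, key: jax.Array) -> jax.Array:
    """Hypothetical helper (not Equinox API): zero each entry of x with probability p.

    Assumes 0 <= p < 1.
    """
    keep_prob = 1.0 - p
    # Draw an independent keep/drop decision for every entry of x.
    mask = jax.random.bernoulli(key, keep_prob, x.shape)
    # Assumed inverted-dropout convention: rescale kept entries by 1 / keep_prob
    # so each output entry has the same expected value as the input entry.
    return jnp.where(mask, x / keep_prob, 0.0)


x = jnp.ones((10_000,))
y = dropout_sketch(x, p=0.25, key=jax.random.PRNGKey(0))
print(float(jnp.mean(y == 0)))  # close to 0.25, since p is the expected fraction zeroed
print(float(jnp.mean(y)))       # close to 1.0, the mean of x
```

Because the mask is drawn from `key`, the same key always drops the same entries, which is why a key must be supplied whenever dropout is actually applied.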

__init__(self, p: float = 0.5, inference: bool = False, *, deterministic: Optional[bool] = None)


  • p: The fraction of entries to set to zero, on average.
  • inference: Whether the layer is in inference mode. If True, dropout is not applied and the input is returned unchanged; if False, dropout is applied. This may be toggled with equinox.tree_inference or overridden when calling equinox.nn.Dropout.__call__.
  • deterministic: Deprecated alternative to inference.
__call__(self, x: Array, *, key: Optional[jax.random.PRNGKey] = None, inference: Optional[bool] = None, deterministic: Optional[bool] = None) -> Array


  • x: An any-dimensional JAX array to apply dropout to.
  • key: A jax.random.PRNGKey used to randomly choose which elements to drop. Required whenever dropout is actually applied. (Keyword-only argument.)
  • inference: As per equinox.nn.Dropout.__init__. If True or False, it takes priority over self.inference; if None, the value of self.inference is used.
  • deterministic: Deprecated alternative to inference.