# Dropout

`equinox.nn.Dropout` (Module)
Applies dropout.
Note that this layer behaves differently during training and inference. During
training, dropout is randomly applied; during inference, this layer does nothing.
Whether the model is in training or inference mode should be toggled using
equinox.nn.inference_mode
.
__init__(self, p: float = 0.5, inference: bool = False, *, deterministic: Optional[bool] = None)
**Arguments:**

- `p`: The fraction of entries to set to zero. (On average.)
- `inference`: Whether to actually apply dropout at all. If `True` then dropout is not applied. If `False` then dropout is applied. This may be toggled with `equinox.nn.inference_mode` or overridden during `equinox.nn.Dropout.__call__`.
- `deterministic`: Deprecated alternative to `inference`.
__call__(self, x: Array, *, key: Optional[PRNGKeyArray] = None, inference: Optional[bool] = None, deterministic: Optional[bool] = None) -> Array
**Arguments:**

- `x`: An any-dimensional JAX array to dropout.
- `key`: A `jax.random.PRNGKey` used to provide randomness for calculating which elements to dropout. (Keyword only argument.)
- `inference`: As per `equinox.nn.Dropout.__init__`. If `True` or `False` then it will take priority over `self.inference`. If `None` then the value from `self.inference` will be used.
- `deterministic`: Deprecated alternative to `inference`.