Randomness
Perfectly controlling all randomness offers a great way to debug failing tests, or to reproduce errors on remote machines. For this reason we have structional.PRNGKey.
This only really works if you control your whole tech stack (if you use a library with its own random behaviour then this can't control that), but as a practical matter that tends to be true throughout my work.
structional.PRNGKey
An immutable PRNG key for sampling randomness. Available methods:
- split, fold_in, duplicate: these manipulate the keys themselves. split is by far the most commonly used; the others are needed only rarely.
- permutation, integers, ...: all the distributions available on np.random.Generator.
This is an 'affine type', meaning that it can only be used once. This is important to prevent accidental key re-use, which would produce the same randomness.
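To make the single-use rule concrete, here is a toy class that refuses to hand out its randomness twice. It is a minimal sketch, not structional's implementation: the name ToyAffineKey, its consume method, and raising RuntimeError on reuse are all hypothetical, and for brevity the toy tracks use with a mutable flag.

```python
class ToyAffineKey:
    """Hypothetical single-use key: a toy, NOT structional's implementation."""

    def __init__(self, seed: int):
        self._seed = seed
        self._used = False

    def consume(self) -> int:
        # Hand out the underlying seed at most once. A second call is an error,
        # so two call sites can never silently share the same randomness.
        if self._used:
            raise RuntimeError("this key has already been used")
        self._used = True
        return self._seed


key = ToyAffineKey(seed=0)
first = key.consume()
try:
    key.consume()  # accidental re-use is caught here...
    reused = True
except RuntimeError:
    reused = False  # ...instead of silently repeating the randomness behind `first`
```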
How does this compare to numpy.random.{Generator, SeedSequence}, or to setting a seed?
PRNGKey is immutable and affinely-typed (it can only be used once). This makes it very easy to reason about.
In contrast:
- numpy.random.Generator and numpy.random.SeedSequence are mutable:

      def foo(y, rng: numpy.random.Generator):
          return rng.do_something(y)

      x = foo(y, your_generator)
      # `your_generator` has been mutated here!

  which prevents any useful attempt at reproducing a function call based on its inputs.
- Setting a seed makes it completely unclear whether a function is random or not, so it is hard to write reproducible code if this is all you do. Plus, every framework seems to have its own random state that you need to remember to set.
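The mutability problem is easy to see with plain numpy. In this sketch, foo is an illustrative stand-in for any function that samples from a Generator it is given: calling it twice with the same arguments gives different results, because the first call advanced the generator's internal state.

```python
import numpy as np


def foo(y, rng: np.random.Generator):
    # Sampling advances `rng`'s internal state as a side effect.
    return y + rng.normal()


your_generator = np.random.default_rng(0)
x1 = foo(1.0, your_generator)
x2 = foo(1.0, your_generator)  # same arguments, but the generator has moved on
# Reproducing x1 needs a freshly seeded generator, not the one that was passed in.
x1_again = foo(1.0, np.random.default_rng(0))
```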
split(num_splits: int) -> tuple[structional.PRNGKey, ...]
Consumes this key, and returns num_splits new keys, each a deterministic function of the original key, but with randomness that is statistically independent of the other new keys and of the original key.
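One way to get this behaviour from plain numpy is to derive each child from the parent's entropy plus an index, mirroring how numpy.random.SeedSequence.spawn builds its children but without spawn's internal counter. This is a sketch of the idea only: toy_split and the (entropy, spawn_key) representation of a key are assumptions, not structional's internals.

```python
import numpy as np


def toy_split(entropy: int, spawn_key: tuple, num_splits: int) -> list:
    # Hypothetical: child i is the parent's (entropy, spawn_key) extended by i.
    # Nothing is mutated, so identical inputs always yield identical children.
    return [
        np.random.SeedSequence(entropy, spawn_key=spawn_key + (i,))
        for i in range(num_splits)
    ]


mkey, dkey = toy_split(entropy=1234, spawn_key=(), num_splits=2)
a = np.random.default_rng(mkey).random(3)
b = np.random.default_rng(dkey).random(3)
# Re-running the split from the same parent reproduces the first child exactly.
a_again = np.random.default_rng(toy_split(1234, (), 2)[0]).random(3)
```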
Example
This is typically called at the start of a function, splitting the key into as many keys as the rest of the function needs.
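from structional import PRNGKey

def train_model(..., key: PRNGKey):
    # One key per consumer, so the model and the dataloader never share randomness.
    mkey, dkey = key.split(num_splits=2)
    model = Model(..., key=mkey)
    dataloader = DataLoader(..., key=dkey)
    ...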
fold_in(value: int) -> structional.PRNGKey
Consumes this key (except that further .fold_in calls on it remain allowed), and returns a new key, which is a deterministic function of the original key and of the input value, but whose randomness is statistically independent of the original key.
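A fold_in-style derivation can be sketched the same way, by appending the folded-in value to the parent's spawn_key. The function toy_fold_in and its FOLD_TAG marker are hypothetical; structional may derive keys quite differently. Because the parent is never modified, it can be folded in with many different values, matching the exception described above.

```python
import numpy as np

FOLD_TAG = 2**32 - 1  # hypothetical marker value


def toy_fold_in(parent: np.random.SeedSequence, value: int) -> np.random.SeedSequence:
    # The tag means a folded-in key never has the same spawn_key as a split-style
    # child (i,). `value` must be a non-negative integer for SeedSequence.
    return np.random.SeedSequence(parent.entropy, spawn_key=parent.spawn_key + (FOLD_TAG, value))


parent = np.random.SeedSequence(42)
x0 = np.random.default_rng(toy_fold_in(parent, 0)).random()
x0_again = np.random.default_rng(toy_fold_in(parent, 0)).random()
x1 = np.random.default_rng(toy_fold_in(parent, 1)).random()
```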
Example
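from structional import PRNGKey

def solve_stochastic_differential_equation(y0, drift, diffusion, num_steps, key: PRNGKey):
    y = y0
    dt = 0.1
    for i in range(num_steps):
        # fold_in may be called repeatedly on the same key, giving one key per step.
        # Brownian increments have variance dt, so the noise is scaled by sqrt(dt).
        dw = dt ** 0.5 * key.fold_in(i).normal(...)
        y = y + dt * drift(y) + dw * diffusion(y)
    return y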
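The same Euler–Maruyama loop can be sketched with numpy alone. Here each step's noise comes from a generator built from (seed, i) only, standing in for key.fold_in(i); euler_maruyama and its arguments are illustrative names, and the per-step SeedSequence construction is an assumption rather than how structional derives keys.

```python
import numpy as np


def euler_maruyama(y0: float, drift, diffusion, num_steps: int, seed: int, dt: float = 0.1) -> float:
    y = y0
    for i in range(num_steps):
        # A generator derived from (seed, i) alone plays the role of key.fold_in(i).
        step_rng = np.random.default_rng(np.random.SeedSequence(seed, spawn_key=(i,)))
        # Brownian increment dW ~ Normal(0, dt), i.e. sqrt(dt) * standard normal.
        dw = np.sqrt(dt) * step_rng.standard_normal()
        y = y + dt * drift(y) + dw * diffusion(y)
    return y


# Ornstein-Uhlenbeck-style example: mean-reverting drift, constant diffusion.
a = euler_maruyama(1.0, lambda y: -y, lambda y: 0.5, num_steps=20, seed=1)
b = euler_maruyama(1.0, lambda y: -y, lambda y: 0.5, num_steps=20, seed=1)
```

Because the noise at step i depends only on the seed and i, running for more steps reproduces the same path prefix and then extends it.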
duplicate() -> structional.PRNGKey
Returns a copy of this key. Does not consume the original key. (This should be used very rarely.)
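Presumably a copy carries identical state, so drawing from both the original and the copy yields the same numbers; that is exactly the key re-use the affine typing exists to prevent. The sketch below models this with plain numpy SeedSequence objects as an analogy, not structional's implementation.

```python
import numpy as np

original = np.random.SeedSequence(7)
# Analogy for duplicate(): a second SeedSequence with identical entropy and spawn_key.
copy_of_original = np.random.SeedSequence(original.entropy, spawn_key=original.spawn_key)

a = np.random.default_rng(original).random(3)
b = np.random.default_rng(copy_of_original).random(3)  # identical to `a`
```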
integers(low, high=None, size=None, dtype=numpy.int64, endpoint=False)
random(size=None, dtype=numpy.float64, out=None)
choice(a, size=None, replace=True, p=None, axis=0, shuffle=True)
bytes(length)
shuffle(x, axis=0)
permutation(x, axis=0)
permuted(x, *, axis=None, out=None)
beta(a, b, size=None)
binomial(n, p, size=None)
chisquare(df, size=None)
dirichlet(alpha, size=None)
exponential(scale=1.0, size=None)
f(dfnum, dfden, size=None)
gamma(shape, scale=1.0, size=None)
geometric(p, size=None)
gumbel(loc=0.0, scale=1.0, size=None)
hypergeometric(ngood, nbad, nsample, size=None)
laplace(loc=0.0, scale=1.0, size=None)
logistic(loc=0.0, scale=1.0, size=None)
lognormal(mean=0.0, sigma=1.0, size=None)
logseries(p, size=None)
multinomial(n, pvals, size=None)
multivariate_hypergeometric(colors, nsample, size=None, method='marginals')
multivariate_normal(mean, cov, size=None, check_valid='warn', tol=1e-08, *, method='svd')
negative_binomial(n, p, size=None)
noncentral_chisquare(df, nonc, size=None)
noncentral_f(dfnum, dfden, nonc, size=None)
normal(loc=0.0, scale=1.0, size=None)
pareto(a, size=None)
poisson(lam=1.0, size=None)
power(a, size=None)
rayleigh(scale=1.0, size=None)
standard_cauchy(size=None)
standard_exponential(size=None, dtype=numpy.float64, method='zig', out=None)
standard_gamma(shape, size=None, dtype=numpy.float64, out=None)
standard_normal(size=None, dtype=numpy.float64, out=None)
standard_t(df, size=None)
triangular(left, mode, right, size=None)
uniform(low=0.0, high=1.0, size=None)
vonmises(mu, kappa, size=None)
wald(mean, scale, size=None)
weibull(a, size=None)
zipf(a, size=None)
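One way a single key could expose every np.random.Generator distribution while staying single-use is to forward attribute lookups to a freshly seeded Generator. This is a minimal sketch under that assumption; ToyKey and its forwarding scheme are hypothetical and not taken from structional.

```python
import numpy as np


class ToyKey:
    """Hypothetical key forwarding sampling calls to np.random.Generator (NOT structional's code)."""

    def __init__(self, seed_seq: np.random.SeedSequence):
        self._seed_seq = seed_seq
        self._used = False

    def __getattr__(self, name):
        # Only reached for names not found normally, e.g. "normal" or "integers".
        if name.startswith("_") or not hasattr(np.random.Generator, name):
            raise AttributeError(name)

        def sample(*args, **kwargs):
            # The single-use check happens at sampling time, not at lookup time.
            if self._used:
                raise RuntimeError("this key has already been used")
            self._used = True
            rng = np.random.default_rng(self._seed_seq)
            return getattr(rng, name)(*args, **kwargs)

        return sample


k = ToyKey(np.random.SeedSequence(3))
x = k.integers(0, 10, size=5)
expected = np.random.default_rng(np.random.SeedSequence(3)).integers(0, 10, size=5)
```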
structional.prngkey_fixture() -> structional.PRNGKey
Designed for use as a fixture in tests: generates a random PRNGKey.
Do not use this in any other context: its seed is deliberately chosen non-deterministically.
When writing a test, consider whether you want it to be deterministic (to check a particular edge case specifically) or random (to increase coverage of your system across multiple runs). For a random test, using this fixture is best practice, as pytest will display the randomly generated seed, allowing any test failure to be debugged reproducibly.
The seed used by this fixture will be set to the value of the STRUCTIONAL_PRNGKEY_SEED environment variable if that is available (i.e. when reproducing a failing test), else a random number will be used.
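The seed-selection rule just described can be sketched as a plain function. This is an assumption about how such logic might look, not structional's source; choose_seed and the 64-bit range for fresh seeds are hypothetical. Reproducing a failure would then mean re-running pytest with STRUCTIONAL_PRNGKEY_SEED set to the reported seed.

```python
import os
import secrets


def choose_seed(env_var: str = "STRUCTIONAL_PRNGKEY_SEED") -> int:
    # Reuse the seed from the environment when reproducing a failing test...
    value = os.environ.get(env_var)
    if value is not None:
        return int(value)
    # ...otherwise draw a fresh one (hypothetical choice of a 64-bit range).
    return secrets.randbits(64)
```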
Example
# tests/conftest.py
from structional import prngkey_fixture as prngkey

# tests/some_test.py
from structional import PRNGKey

def test_foo(prngkey: PRNGKey):
    ...