add experimental RngBitGenerator ("RBG") PRNG #8067
Merged
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Builds on the awesome #6899 upgrade.
Not only is this an experimental API, but also because the RngBitGenerator is not guaranteed to be stable across compiler versions (or backends or shardings), let's assert that this JAX PRNG implementation may not be stable across JAX versions.
Even without that kind of stability, this PRNG is still useful because unlike the effectful RNG primitives, such as lax.rng_uniform, at a given jax/jaxlib version this RBG PRNG will still work correctly with lax.scan and jax.checkpoint (while still potentially being more performant on some platforms than JAX's standard PRNG).
Thanks to #6899, it's super easy to use, just by adapting key creation to use
prng.seed_with_impl
and then using thejax.random
API as usual:cc @levskaya @jekbradbury @zhangqiaorjc
Here's an illustration of how things can go horribly wrong with the effectful
lax.rng_uniform
:Here's RBG helping us out:
Ah, precisely the kind of reproducibility we want: different across iterations, the same across the forward and backward passes (and repeatable too!).