Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

make threefry split and fold_in symmetric #13341

Merged
merged 2 commits into from
Nov 22, 2022

Conversation

froystig
Copy link
Member

@froystig froystig commented Nov 21, 2022

Namely, make it so that split(key, n)[i] equals fold_in(key, i) for any key and for 0 <= i < n.

This change affects the observed random bits for a fixed key (indirectly through splits and folds), so here we guard it behind jax.config.jax_threefry_partitionable. It's not described very well by the flag name, but it makes for a simple way to bundle together several random-bit-altering changes as part of the same upgrade cycle.

Fixes #7708, up to the jax_threefry_partitionable upgrade, with a test to confirm.

@froystig froystig self-assigned this Nov 21, 2022
@froystig froystig force-pushed the rng-split-fold-symmetry branch from f615d62 to db041ed Compare November 21, 2022 22:52
@froystig froystig requested a review from sharadmv November 21, 2022 22:53
@google-ml-butler google-ml-butler bot added kokoro:force-run pull ready Ready for copybara import and testing labels Nov 21, 2022
Namely, make it so that `split(key, n)[i]` equals `fold_in(key, i)`
for any key and for `0 <= i < n`.

This change affects the observed random bits for a fixed key (indirectly
through splits and folds), so here we guard it behind
`jax.config.jax_threefry_partitionable`. It's not described very well
by the flag name, but it makes for a simple way to bundle together
several random-bit-altering changes as part of the same upgrade cycle.
@froystig froystig force-pushed the rng-split-fold-symmetry branch from db041ed to a412d27 Compare November 21, 2022 23:25
@copybara-service copybara-service bot merged commit 92ee87d into jax-ml:main Nov 22, 2022
@froystig froystig deleted the rng-split-fold-symmetry branch November 22, 2022 00:44
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
pull ready Ready for copybara import and testing
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Split keys over named batch axis
3 participants