Design suggestion: consider adding jax.random.Generator #2294
It might also help to have |
Looks like NumPy already has splittable RNGs (https://numpy.org/devdocs/reference/random/parallel.html). Their design looks pretty outstanding, so it might be best to copy it. You might not need to have split at all in JAX if your RNG accepts a seed provided by a SeedSequence. |
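For comparison, the NumPy pattern referred to here can be sketched as follows: a root `SeedSequence` spawns child sequences, and each child seeds its own `Generator`. The helper name `make_streams`, the seed value, and the number of streams are illustrative choices, not part of NumPy.

```python
import numpy as np

def make_streams(seed, n):
    """Create n independent NumPy Generators from one root seed (hypothetical helper)."""
    # SeedSequence mixes the user-provided seed into well-distributed entropy.
    root = np.random.SeedSequence(seed)
    # spawn() derives n child SeedSequences intended to give independent streams.
    children = root.spawn(n)
    # default_rng accepts a SeedSequence and builds a PCG64-backed Generator.
    return [np.random.default_rng(child) for child in children]

streams = make_streams(12345, 3)
draws = [g.random(2) for g in streams]
```

Note that each `Generator` here is stateful: calling `g.random` again advances its internal state and returns new numbers. That hidden state is the part that does not map directly onto JAX's functional transformations.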
Also, rkern has a very insightful comment about the RNG used in JAX. I don't understand all of it, but it seems fascinating. |
I agree that a class-based interface for random number generation in JAX would be convenient. One reason why this could be a little tricky (vs the current function-based approach) is that the class would need to be closed under JAX's transformation rules like batching. I don't think we want stateful RNGs in JAX (these make functional transformations very hard), which rules out copying NumPy's API exactly. But maybe there's some aspect of the |
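The current function-based approach can be sketched like this: the key is an explicit value that is split and passed around, so there is no hidden state and the same function can be jitted. The function name `sample_twice` and the sample shapes are illustrative.

```python
import jax
from jax import random

def sample_twice(key):
    """Draw two batches by explicitly threading the key (hypothetical helper)."""
    # split() returns fresh keys; the input key itself is never mutated.
    key, k1, k2 = random.split(key, 3)
    a = random.uniform(k1, (3,))
    b = random.uniform(k2, (3,))
    # Return the new key so the caller can keep drawing later.
    return key, a, b

key = random.PRNGKey(0)
_, a1, b1 = sample_twice(key)
_, a2, b2 = sample_twice(key)  # same input key, so the same samples
# Because the function is pure, jax.jit can compile it unchanged.
_, a3, _ = jax.jit(sample_twice)(key)
```

Reusing `key` reproduces the same numbers, which is exactly what lets `jit`, `vmap`, and `grad` treat sampling like any other pure function.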
@shoyer Thanks for the clear explanation. That sounds really unfortunate, not just for random number generators but for JAX objects in general. It feels like we're going back 25 years to when C code was all structures, and you'd have functions that accepted the structures as their first parameter. Objects are superior to this pattern for many reasons, including inheritance and polymorphism. It would be nice to keep those properties. I'm not sure if it's possible, but I've opened an issue to discuss making that possible: #2328. |
Classes work totally fine in JAX as long as you don't mutate state. Even your exact code example runs, e.g., see this notebook for a working example: The reason why classes are awkward is that JAX's transforms like

    import jax
    from jax import random

    # Batch key creation over three seeds, then batch sampling over the keys.
    keys = jax.vmap(random.PRNGKey)(jax.numpy.arange(3))
    samples = jax.vmap(random.uniform)(keys)

With a class, you need to write something like:
Which is fine, I guess, but not very convenient. Intuitively, one might expect that batching the creation of a class would pre-batch all of its methods, but that definitely doesn't work right now ( |
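One way a class can coexist with `vmap` is sketched below, assuming a hypothetical `Generator` class (not a JAX API) that is registered as a pytree and never mutates itself: its methods return a new `Generator` instead. You can either construct the object inside the vmapped function, or vmap the construction and then vmap again over the resulting batched object.

```python
import jax
import jax.numpy as jnp
from jax import random
from jax.tree_util import register_pytree_node_class

@register_pytree_node_class
class Generator:
    """Hypothetical immutable wrapper around a PRNG key (not a JAX API)."""

    def __init__(self, key):
        self.key = key

    def uniform(self, shape=()):
        # Split so the returned Generator never reuses the consumed key.
        new_key, subkey = random.split(self.key)
        return random.uniform(subkey, shape), Generator(new_key)

    # Pytree protocol: the key is the only leaf, so transforms can batch it.
    def tree_flatten(self):
        return (self.key,), None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*children)

def draw(seed):
    # Construct and use the object inside the function being vmapped.
    sample, _ = Generator(random.PRNGKey(seed)).uniform()
    return sample

samples = jax.vmap(draw)(jnp.arange(3))

# Because Generator is a pytree, a batched instance can also cross vmap boundaries.
gens = jax.vmap(lambda s: Generator(random.PRNGKey(s)))(jnp.arange(3))
batched, _ = jax.vmap(lambda g: g.uniform())(gens)
```

Both routes compute the same values; the second still needs an explicit `vmap` around each method call, which is the inconvenience described above.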
Just adding a comment in this thread, even though I gave a code example in the linked issue #2328. I think a better approach would be to make the classes all optionally vectorized from the start, so the RNG class would accept an optional shape parameter too. Then you don't need to vmap it, but rather do something like this:
Would having this jax shape in all mutable objects allow us to do this? Working with values and non-polymorphic functions seems unnecessarily frustrating compared with objects and methods. |
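The "optional shape parameter" idea could look roughly like the sketch below. `ShapedGenerator` and its `create` method are hypothetical names used only for illustration: the generator holds one key per batch element, uses `vmap` internally so callers don't have to, and returns a fresh generator rather than mutating itself.

```python
import math
import jax
from jax import random

class ShapedGenerator:
    """Hypothetical batched generator holding one key per element (not a JAX API)."""

    def __init__(self, keys, shape):
        self.keys = keys    # flat batch of raw keys, one per element of `shape`
        self.shape = shape

    @classmethod
    def create(cls, key, shape=()):
        shape = tuple(shape)
        # math.prod(()) == 1, so a scalar generator still holds one key.
        return cls(random.split(key, math.prod(shape)), shape)

    def uniform(self):
        # Split every key: index 0 is consumed now, index 1 seeds the next generator.
        pairs = jax.vmap(random.split)(self.keys)
        flat = jax.vmap(random.uniform)(pairs[:, 0])
        return flat.reshape(self.shape), ShapedGenerator(pairs[:, 1], self.shape)

x, gen2 = ShapedGenerator.create(random.PRNGKey(0), (2, 3)).uniform()
y, _ = gen2.uniform()  # new draws from the advanced generator
s, _ = ShapedGenerator.create(random.PRNGKey(0)).uniform()  # scalar case
```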
Thanks for the ideas! IIUC the main proposal is to have a distinguished PRNG key type, e.g. for mypy type checking. That sounds like a good idea! It wouldn't require any API additions or changes. In fact, in the side-effects branch we prototyped exactly that. We plan to merge something along those lines in the next couple months. There are also ideas around API changes or internal refactorings, but I find those less compelling:
Is my understanding of the main proposal (a distinguished PRNG key type, so keys aren't DeviceArrays but still work as they do now with JAX transformations) correct? If so, I propose we tweak the title of the issue, and leave it open until we land something along those lines. WDYT? |
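For reference, later JAX releases added a distinguished key type along these lines. The sketch below assumes a JAX version that provides `jax.random.key`, `jax.random.key_data`, and `jax.dtypes.prng_key` (all introduced well after this discussion). The helper `is_typed_key` is hypothetical. It shows that typed keys have their own dtype yet still work with `split` and `vmap`.

```python
import jax
import jax.numpy as jnp
from jax import random

def is_typed_key(x):
    """True if x is a typed PRNG key array rather than a raw uint32 array (hypothetical helper)."""
    return jax.dtypes.issubdtype(x.dtype, jax.dtypes.prng_key)

typed = random.key(0)        # typed key: scalar shape, dedicated key dtype
legacy = random.PRNGKey(0)   # legacy key: plain uint32 array

# Typed keys still compose with split and with transformations like vmap.
subkeys = random.split(typed, 3)
samples = jax.vmap(random.uniform)(subkeys)

# key_data() exposes the underlying uint32 representation when it is needed.
raw = random.key_data(typed)
```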
The full proposal is to have a class type for PRNGs. You're right that one benefit is type-checking. Another benefit is abstraction and the ability to call methods on objects, which enables polymorphism. I think you may be underestimating the benefits of OO in writing clean code, but okay, let's cross that bridge when we get there. I may soon have a more motivating example.
Awesome! I don't totally understand it yet, but looking forward to it.
Fair enough.
Right, the only benefit is polymorphism.
Please, go ahead and tweak the title. I am waiting to see if, in the development of my project, I run into any walls caused by not having polymorphism in JAX. If so, my plan was to flesh out a clearer proposal at that time. I agree with your philosophy of "crossing that bridge when we get there". |
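To make the polymorphism point concrete, here is a sketch under stated assumptions: `Sampler`, `Uniform`, `Normal`, and `draw` are hypothetical names, not an existing library API (and not tjax's). Immutable classes registered as pytrees can be passed into a single jitted function. Method dispatch happens in Python at trace time, so `jit` simply traces once per distinct class.

```python
from typing import Protocol

import jax
from jax import random
from jax.tree_util import register_pytree_node_class

class Sampler(Protocol):
    def sample(self, key: jax.Array) -> jax.Array: ...

@register_pytree_node_class
class Uniform:
    def __init__(self, low, high):
        self.low, self.high = low, high

    def sample(self, key):
        return random.uniform(key, (), minval=self.low, maxval=self.high)

    def tree_flatten(self):
        return (self.low, self.high), None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*children)

@register_pytree_node_class
class Normal:
    def __init__(self, mean, std):
        self.mean, self.std = mean, std

    def sample(self, key):
        return self.mean + self.std * random.normal(key)

    def tree_flatten(self):
        return (self.mean, self.std), None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*children)

@jax.jit
def draw(dist: Sampler, key):
    # Dispatch on the Python type of `dist` happens while tracing;
    # the pytree structure (including the class) is part of jit's cache key.
    return dist.sample(key)

key = random.PRNGKey(0)
u = draw(Uniform(2.0, 3.0), key)
n = draw(Normal(0.0, 1.0), key)
```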
@NeilGirdhar the value returned by |
@shoyer After some more thought, you're right. I should avoid using static arguments. Sorry for the noise. |
(Closing because Haiku appears to have a more elegant solution.) |
Edited in August 2020:
Jax's splittable RNGs are a really great idea! However, it might be a good idea to expose the key as a class so that
For example: tjax.generator.