Design suggestion: consider adding jax.random.Generator #2294

Closed
NeilGirdhar opened this issue Feb 23, 2020 · 12 comments

@NeilGirdhar
Contributor

NeilGirdhar commented Feb 23, 2020

Edited in August 2020:

JAX's splittable RNGs are a really great idea! However, it might be a good idea to expose the key as a class so that

  • if JAX ever generalizes to other splittable RNGs, the change can be made transparently,
  • type-annotation linters (like mypy) can verify key usage (currently DeviceArray is used for other things too),
  • code becomes a bit less verbose, and
  • the interface is closer to the elegant numpy interface.

For example: tjax.generator.
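
As a rough illustration of the type-checking point in the list above, a distinct key type can be approximated with `typing.NewType`. This is only a sketch: `KeyArray`, `make_key`, and `draw` are hypothetical names that are not part of JAX, and it assumes a JAX version that exposes `jax.Array`.

```python
from typing import NewType

import jax
from jax import random

# Hypothetical alias: a static checker treats it as distinct from a plain
# array, but at runtime it is still an ordinary JAX array.
KeyArray = NewType("KeyArray", jax.Array)


def make_key(seed: int) -> KeyArray:
    """Create a key and tag it with the hypothetical KeyArray type."""
    return KeyArray(random.PRNGKey(seed))


def draw(key: KeyArray) -> jax.Array:
    """Accept only values annotated as keys (checked by mypy, not at runtime)."""
    return random.uniform(key, (2,))


key = make_key(0)
sample = draw(key)
```

Because `NewType` adds no runtime wrapper, the key should keep working with JAX transformations exactly as a plain array does; only the linter sees the difference.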

@NeilGirdhar
Contributor Author

NeilGirdhar commented Feb 27, 2020

It might also help to have jnp.random.default_rng(seed) like numpy

@NeilGirdhar
Contributor Author

NeilGirdhar commented Feb 27, 2020

Looks like numpy already has splittable RNGs: https://numpy.org/devdocs/reference/random/parallel.html

Their design looks pretty outstanding, so it might be best to copy it. You might not need to have split at all in Jax, if your RNG accepts a seed provided by a SeedSequence.
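
For readers unfamiliar with the NumPy design mentioned above, here is a minimal sketch of how `SeedSequence.spawn` derives independent child streams. The seed value and variable names are arbitrary choices for illustration.

```python
import numpy as np

# A root seed sequence; the integer is an arbitrary example seed.
root = np.random.SeedSequence(12345)

# spawn() derives child seed sequences that are designed to be independent.
children = root.spawn(3)

# Each child seeds its own (stateful) Generator.
rngs = [np.random.default_rng(child) for child in children]
draws = [rng.uniform() for rng in rngs]

# Rebuilding from the same root seed reproduces the same child streams.
replay = [
    np.random.default_rng(child).uniform()
    for child in np.random.SeedSequence(12345).spawn(3)
]
```

Note that each `Generator` here still carries mutable state, which is the part that does not map directly onto JAX's functional key handling.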

@NeilGirdhar
Contributor Author

Also, rkern has a very insightful comment about the RNG used in Jax. I don't understand all of it, but it seems fascinating.

numpy/numpy#15656 (comment)

@shoyer
Collaborator

shoyer commented Feb 27, 2020

I agree that a class based interface for random number generation in JAX would be convenient. One reason why this could be a little tricky (vs the current function based approach) is that the class would need to be closed under JAX's transformation rules like batching.

I don't think we want stateful RNGs in JAX (these make functional transformations very hard), which rules out copying NumPy's API exactly. But maybe there's some aspect of the SeedSequence design or NumPy's new RNGs that could be helpful for us.

@NeilGirdhar
Contributor Author

NeilGirdhar commented Feb 28, 2020

@shoyer Thanks for the clear explanation. That sounds really unfortunate—not just for random number generators, but for JAX objects in general. It feels like we're going back 25 years to when C code would all be structures, and you'd have functions that accepted the structures as the first parameters. Objects are superior to this pattern for many reasons including inheritance and polymorphism. It would be nice to keep those properties.

I'm not sure if it's possible, but I've opened an issue to discuss making that possible: #2328.

@shoyer
Collaborator

shoyer commented Feb 28, 2020

Classes work totally fine in JAX as long as you don't mutate state. Even your exact code example runs, e.g., see this notebook for a working example:
https://gist.github.com/shoyer/0b3221ed0431befdfbfc9884e9353f8e

The reason why classes are awkward is that JAX's transforms like vmap don't necessarily work like you might expect if your functions return custom objects. For example, here's how you could auto-batch with JAX's functional interface:

import jax
from jax import random
keys = jax.vmap(random.PRNGKey)(jax.numpy.arange(3))  # one key per seed
samples = jax.vmap(random.uniform)(keys)  # one sample per key

With a class, you need to write something like:

rngs = jax.vmap(RNG)(jax.numpy.arange(3))
samples = jax.vmap(RNG.uniform)(rngs)

Which is fine, I guess, but not very convenient.

Intuitively, one might expect that batching the creation of a class would pre-batch all of its methods, but that definitely doesn't work right now (jax.vmap(RNG)(jax.numpy.arange(3)).uniform() results in an error) and I'm not even sure the behavior is entirely well defined.
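
One way the class-based variant above could be made to pass through `jax.vmap` is to register the class as a pytree. The following is a sketch under assumptions, not the code from the linked notebook: the `RNG` class, its `from_seed` constructor, and its methods are hypothetical, while `register_pytree_node_class` is the existing `jax.tree_util` helper.

```python
import jax
import jax.numpy as jnp
from jax import random
from jax.tree_util import register_pytree_node_class


@register_pytree_node_class
class RNG:
    """Hypothetical immutable wrapper around a JAX PRNG key."""

    def __init__(self, key):
        self.key = key

    @classmethod
    def from_seed(cls, seed):
        return cls(random.PRNGKey(seed))

    def split(self):
        # Returns two new RNG objects; self is never mutated.
        k1, k2 = random.split(self.key)
        return RNG(k1), RNG(k2)

    def uniform(self, shape=()):
        return random.uniform(self.key, shape)

    # Pytree protocol: the key is the only array leaf, with no static data.
    def tree_flatten(self):
        return (self.key,), None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*children)


# Batched construction and batched sampling, mirroring the snippet above.
rngs = jax.vmap(RNG.from_seed)(jnp.arange(3))
samples = jax.vmap(RNG.uniform)(rngs)

# The functional version, for comparison.
expected = jax.vmap(random.uniform)(jax.vmap(random.PRNGKey)(jnp.arange(3)))
```

Even with this registration, the methods of the batched `rngs` object are not batched automatically, so sampling still has to go through `jax.vmap`.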

@NeilGirdhar
Contributor Author

NeilGirdhar commented Mar 4, 2020

Just adding a comment in this thread, even though I gave a code example in the linked issue #2328.

I think a better approach would be to make the classes all optionally vectorized from the start, so the RNG class would accept an optional shape parameter too. Then you don't need to vmap it, but rather do something like this:

rngs = RNG(jax.numpy.arange(3), shape=(3,))
samples = rngs.uniform()

Would having this jax shape in all mutable objects allow us to do this?

Working with values and non-polymorphic functions seems unnecessarily frustrating compared with objects and methods.

@mattjj
Collaborator

mattjj commented Mar 10, 2020

Thanks for the ideas!

IIUC the main proposal is to have a distinguished PRNG key type, e.g. for mypy type checking. That sounds like a good idea! It wouldn't require any API additions or changes. In fact, in the side-effects branch we prototyped exactly that. We plan to merge something along those lines in the next couple months.

There are also ideas around API changes or internal refactorings, but I find those less compelling:

  • We don't have any foreseeable-future plans to add other splittable PRNG hash functions yet, and IMO we shouldn't build anything motivated by that abstract goal until we know exactly the problem we're trying to solve.
  • I don't think rng, sub_rng = split(rng) is more verbose than rng, sub_rng = rng.split(), and IMO by being more OOP-y the latter is less clear about PRNG key splitting having value semantics.
  • Being closer to the NumPy PRNG interface is not a goal for JAX's PRNG, and NumPy's PRNG interface was not designed with the same considerations we have (especially around functional purity). Users who really want a NumPy interface might be able to build one on top.

Is my understanding of the main proposal (a distinguished PRNG key type, so keys aren't DeviceArrays but still work as they do now with JAX transformations) correct? If so, I propose we tweak the title of the issue, and leave it open until we land something along those lines.

WDYT?
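
The value-semantics point about `split` in the list above can be illustrated with a short sketch using the existing functional API. The variable names are arbitrary.

```python
import jax.numpy as jnp
from jax import random

key = random.PRNGKey(0)
key_before = jnp.array(key)  # copy, to show that split leaves key untouched

# split returns new keys; the input key is a value and is not modified.
new_key, sub_key = random.split(key)

# Sampling is a pure function of the key: the same key gives the same draw.
a = random.normal(sub_key, (2,))
b = random.normal(sub_key, (2,))
```

Since nothing is mutated, advancing the stream means rebinding the name explicitly, as in `rng, sub_rng = split(rng)`.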

@mattjj mattjj self-assigned this Mar 10, 2020
@NeilGirdhar
Contributor Author

NeilGirdhar commented Mar 10, 2020

> IIUC the main proposal is to have a distinguished PRNG key type, e.g. for mypy type checking.

The full proposal is to have a class type for PRNGs. You're right that one benefit is type-checking. Another benefit is abstraction and the ability to call methods on objects, which enables polymorphism. I think you may be underestimating the benefits of OO in writing clean code, but okay, let's cross that bridge when we get there. I may soon have a more motivating example.

> the side-effects branch we prototyped exactly that. We plan to merge something along those lines in the next couple months.

Awesome! I don't totally understand it yet, but looking forward to it.

> IMO we shouldn't build anything motivated by that abstract goal until we know exactly the problem we're trying to solve.

Fair enough.

> I don't think rng, sub_rng = split(rng) is more verbose than rng, sub_rng = rng.split(), and IMO by being more OOP-y the latter is less clear about PRNG key splitting having value semantics.

Right, the only benefit is polymorphism.

> If so, I propose we tweak the title of the issue, and leave it open until we land something along those lines. WDYT?

Please, go ahead and tweak the title. I am waiting to see if in the development of my project I run into any walls caused by not having polymorphism in JAX. If so, my plan was to flesh out a clearer proposal at that time. I agree with your philosophy of "crossing that bridge when we get there".

@NeilGirdhar NeilGirdhar changed the title from "Design suggestion: consider adding the class, jax.random.RandomState" to "Design suggestion: consider adding jax.random.Generator" on Apr 27, 2020
@shoyer
Collaborator

shoyer commented Apr 27, 2020

@NeilGirdhar the value returned by jax.random.PRNGKey is just a JAX array (though the exact form is an implementation detail), so the simple way to avoid recompilation is to avoid using static_argnums.
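
A small sketch of that advice, assuming a jitted function that takes the key as an ordinary (traced) argument. The `sample` function and the `trace_count` list are hypothetical names used only to observe how often tracing happens.

```python
import jax
from jax import random

trace_count = []  # appended to only while JAX traces the function


@jax.jit
def sample(key):
    # Python side effects run at trace time, not on cached calls.
    trace_count.append(1)
    return random.normal(key, (3,))


# Different keys share the same shape and dtype, so the compiled
# function is reused instead of being retraced for each key.
a = sample(random.PRNGKey(0))
b = sample(random.PRNGKey(1))
```

Marking the key with `static_argnums` would instead make its value part of the compilation cache key, which is what the comment above recommends avoiding.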

@NeilGirdhar
Contributor Author

NeilGirdhar commented Apr 27, 2020

@shoyer After some more thought, you're right. I should avoid using static arguments. Sorry for the noise.

@NeilGirdhar
Contributor Author

(Closing because Haiku appears to have a more elegant solution.)
