support custom PRNG implementations #6899
fyi @LenaMartens: this change prompted lifting PRNG keys to a custom type.
Summarizing some discussion with @LenaMartens:
|
I'm sorry to disappoint. The former can be implemented by only syncing the seed, but I was lazy and was synchronising the whole I'd be grateful if this use-case could be considered. |
Thanks for pointing these out. I think the Looking at the first expression you linked: key, _ = mpi.mpi_bcast_jax(key, root=root, comm=comm) If I understand correctly, The PRNG key class is registered as a pytree node in this PR, meaning that its instances can be flattened into arrays with Looking at the second expression: to generate "a PRNGKey depending on rank number and key," could you use |
Indeed that might be a good idea, thanks. (cc @dionhaefner )
Hmm. So you mean that instead of Still, if I can Thanks! |
666114f
to
9cebe89
Compare
@PhilipVinc – thanks, and yes, that sounds right. To clarify two things:
|
Can we have a canonical test for testing a custom prng function (either in api_test or random_test)? The intent would be to capture a small snippet describing how a user would supply a custom_rng generation function. We could even have the function just be a jax function that generates a bunch of zeros, or xors the input key with itself n number of times.
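One possible shape for such a snippet, as a rough sketch only. The four function names (`toy_seed`, `toy_split`, `toy_fold_in`, `toy_random_bits`) are hypothetical, keys are plain tuples of two 32-bit ints rather than JAX arrays, and the "generator" just XORs key words with a counter, along the lines suggested above. It is not a usable PRNG and not the test that this PR adds.

```python
from typing import List, Tuple

Key = Tuple[int, int]  # hypothetical stand-in for a uint32[2] key
MASK = 0xFFFFFFFF      # keep every word within 32 bits


def toy_seed(seed: int) -> Key:
    """Build a key from an integer seed as (high word, low word)."""
    return ((seed >> 32) & MASK, seed & MASK)


def toy_split(key: Key, num: int) -> List[Key]:
    """Derive `num` child keys by XOR-ing the low word with the child index."""
    hi, lo = key
    return [(hi, (lo ^ i) & MASK) for i in range(num)]


def toy_fold_in(key: Key, data: int) -> Key:
    """Mix an integer into the high word of the key."""
    hi, lo = key
    return ((hi ^ data) & MASK, lo)


def toy_random_bits(key: Key, count: int) -> List[int]:
    """Produce `count` deterministic 32-bit words from the key."""
    hi, lo = key
    return [(hi ^ lo ^ i) & MASK for i in range(count)]
```

A test built on something like this could check determinism and shapes without depending on threefry at all.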
fc7ab85
to
791158d
Compare
If the design is clear, please add some notes in the PR description, so that people can then know how to review.
jax/_src/prng.py
Outdated
    return self.key.dtype


def make_prng_key(seed: int) -> PRNGKey:
I'm really looking forward to this change!
Just curious, but would it make more sense for this to be a class factory?
class PRNGKey:
  @classmethod
  def from_seed(cls: Type[T], seed: int) -> T:
    return cls(_threefry_prng_key(seed))
This has two possible advantages:
- It would be one fewer thing to expose, and
- If you choose to add it to the abstract RNG base class, it would be part of the interface.
Also, thanks for taking care to add most of the type annotations. (There are a couple missing though; `PRNGKey.__iter__` is probably going to cause MyPy errors.)
`def __iter__(self) -> Iterator['PRNGKey']:`
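To make the class-factory suggestion concrete, here is a self-contained sketch. `ToyKey` and `ToySubKey` are hypothetical names, and the seed-to-words arithmetic is a placeholder for `_threefry_prng_key`, which is not reproduced here. The point is only that a `cls`-typed classmethod returns the subclass type when called on a subclass.

```python
from typing import Tuple, Type, TypeVar

T = TypeVar("T", bound="ToyKey")


class ToyKey:
    """Hypothetical key class; `data` stands in for the uint32[2] buffer."""

    def __init__(self, data: Tuple[int, int]) -> None:
        self.data = data

    @classmethod
    def from_seed(cls: Type[T], seed: int) -> T:
        # Placeholder seeding: split the seed into high and low 32-bit words.
        return cls(((seed >> 32) & 0xFFFFFFFF, seed & 0xFFFFFFFF))


class ToySubKey(ToyKey):
    """A subclass inherits the factory and gets instances of its own type."""
```

Calling `ToySubKey.from_seed(1)` returns a `ToySubKey`, which is the interface benefit the comment describes.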
jax/_src/prng.py
Outdated
  # TODO(frostig): remove if possible, otherwise declare necessary
  @property
  def shape(self):
    return self.key.shape
If you decide to keep this, then shouldn't the shape of the key be `self.key.shape[:-1]`? After all, this is what you've split the key into, and it gives the shape of the vmapped function that this key works with. E.g., if you have `vmap(f, in_axes=(0, 0))(rng, x)`, then the leading dimension of `rng.shape` should equal the leading dimension of `x.shape`. The final dimension of 2 seems like an implementation detail that a user of the class shouldn't have access to.
In my version of the `Generator` class, I never ended up needing this function though 😄
36479bd
to
134a9da
Compare
Thanks to everyone for the early comments. Previous versions of this PR were an experimental draft that bundled in several shortcuts to avoid breaking existing tests. I found time to implement the design I had in mind, together with a config flag for backwards compatibility and various warnings about future changes. I've updated the PR description with design notes, including some remarks on temporary limitations as well. Many early suggestions are out of date or addressed by now. I'll go back and reply to those that still stand out.
@PhilipVinc - You might notice that another workaround now for turning keys into arrays for use as primitive operands is by reading |
  ``random_bits``, ``fold_in``).
  """

  impl: PRNGImpl
I don't want to intrude, but what's the point of using the Pimpl idiom in Python? In C++, the benefits are a compilation firewall and binary compatibility, neither of which applies to Python. Wouldn't it be simpler, more idiomatic, and more efficient to just have abstract methods?
This would
- simplify your flatten and unflatten methods by removing `self.impl`,
- simplify your implementation by removing `PRNGImpl`,
- simplify type annotations,
- make it easier to have a hierarchy of behavior if the need arises,
- and make it easier for users to define their own `PRNGKeyArray`.
It's a natural question! I kicked around an idea like this earlier on.
I'm choosing to wrap-and-delegate because it directly expresses what I'm trying to enforce. `PRNGKeyArray` is a class made to adapt a PRNG implementation, which is expressed on an individual key, into an "array" of keys. It's an internal definition, not part of the API, and we have no need in jax for a hierarchy of behaviors around it.
If you're writing custom PRNGs from outside of jax, you're welcome to supply the PRNG interface however you like. It can be a member of whatever fun type hierarchy you've set up in your code, or just a plain record of four functions. You'll notice it need not even be a proper `prng.PRNGImpl`, but rather anything that duck-types as one. Hopefully that's plenty of flexibility, less for us to document, and less for anyone to learn.
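As a toy illustration of wrap-and-delegate with a duck-typed implementation (not the PR's code): `KeyArraySketch`, `seed_with`, and `counter_impl` are hypothetical names, keys are plain ints, and `split` flattens its results into a list, whereas the real class would add an array axis.

```python
from types import SimpleNamespace
from typing import Any, List


class KeyArraySketch:
    """Holds a flat list of keys plus the implementation that operates on them."""

    def __init__(self, impl: Any, keys: List[int]) -> None:
        self.impl = impl
        self.keys = keys

    def split(self, num: int) -> "KeyArraySketch":
        # Delegate to the per-key `split`, then flatten the results.
        children = [c for k in self.keys for c in self.impl.split(k, num)]
        return KeyArraySketch(self.impl, children)

    def random_bits(self, count: int) -> List[List[int]]:
        # Delegate to the per-key bit generator, one row per key.
        return [self.impl.random_bits(k, count) for k in self.keys]


def seed_with(impl: Any, seed: int) -> KeyArraySketch:
    """Wrap whatever object provides `seed`, `split`, `random_bits`, `fold_in`."""
    return KeyArraySketch(impl, [impl.seed(seed)])


# Any object with the right attributes works; no base class is required.
counter_impl = SimpleNamespace(
    key_shape=(),
    seed=lambda s: s & 0xFFFFFFFF,
    split=lambda k, n: [(k * 31 + i) & 0xFFFFFFFF for i in range(n)],
    random_bits=lambda k, c: [(k ^ i) & 0xFFFFFFFF for i in range(c)],
    fold_in=lambda k, d: (k ^ d) & 0xFFFFFFFF,
)
```

Here `counter_impl` is just a namespace of four functions, so the wrapper never needs to know its type.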
To your specific points:
> - simplify your flatten and unflatten methods by removing `self.impl`,
> - simplify your implementation by removing `PRNGImpl`,

These seem like minimal savings, so I'm not concerned.

> - simplify type annotations,

How so?

> - make it easier to have a hierarchy of behavior if the need arises,

I don't foresee such a need. We can always revisit if/when it comes up.

> - and make it easier for users to define their own `PRNGKeyArray`.

I don't want them to!
What do you think? As someone who might use this, do these choices get in your way at all?
Thanks for the exhaustive reply.
> `PRNGKeyArray` is a class made to adapt a PRNG implementation, which is expressed on an individual key, into an "array" of keys. It's an internal definition, not part of the API, and we have no need in jax for a hierarchy of behaviors around it.

`PRNGKeyArray` is a member of the union `KeyArray`, which is passed to the public functions in `random.py`. So, `PRNGKeyArray` is part of the API, isn't it? How should a user annotate the generator that's produced by `seed_with_impl` and that's passed around to user functions?
About `PRNGImpl`, my point is that it could just be folded into `PRNGKeyArray` since it doesn't appear that you use it on its own anywhere except to produce a `PRNGKeyArray`. I do understand your point that you want to treat `PRNGKeyArray` as a wrapper, but it only ever wraps one thing, and the thing it wraps is only ever used within the wrapper.
> we have no need in jax for a hierarchy of behaviors around it.

Fair enough.

> > simplify type annotations,
>
> How so?

The functions in `PRNGImpl` are annotated as `Callable`. It would be nice to just define them as abstract methods so that they have full signatures. I guess I don't see the point of using function-valued members that you set instead of abstract methods that you implement in subclasses. Are you avoiding inheritance for some reason?
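For comparison, a sketch of the abstract-method alternative described here. `AbstractPRNG` and `XorPRNG` are hypothetical names, keys are plain ints, and nothing below comes from the PR; it only shows how each method carries a full signature instead of a bare `Callable`.

```python
from abc import ABC, abstractmethod
from typing import List

MASK = 0xFFFFFFFF


class AbstractPRNG(ABC):
    """Hypothetical base class; each method has a complete signature."""

    @abstractmethod
    def seed(self, seed: int) -> int: ...

    @abstractmethod
    def split(self, key: int, num: int) -> List[int]: ...

    @abstractmethod
    def random_bits(self, key: int, count: int) -> List[int]: ...

    @abstractmethod
    def fold_in(self, key: int, data: int) -> int: ...


class XorPRNG(AbstractPRNG):
    """Toy implementation; not statistically meaningful."""

    def seed(self, seed: int) -> int:
        return seed & MASK

    def split(self, key: int, num: int) -> List[int]:
        return [(key * 31 + i) & MASK for i in range(num)]

    def random_bits(self, key: int, count: int) -> List[int]:
        return [(key ^ i) & MASK for i in range(count)]

    def fold_in(self, key: int, data: int) -> int:
        return (key ^ data) & MASK
```

The trade-off discussed in the thread still applies: this form requires inheritance, while a record of functions accepts anything that duck-types.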
> > make it easier for users to define their own PRNGKeyArray.
>
> I don't want them to!

One possible user-defined `PRNGKeyArray` subclass you might find is a hypothetical Haiku `HaikuKeyArray`. It would simplify Haiku code to not have to extract keys from a `PRNGKeySequence`, and instead use the `HaikuKeyArray`, which would be passable to all the JAX random functions. It would be mutable, and it would hold a `PRNGKeyArray` internally, which it would split, just like the `PRNGKeySequence` currently does. Why make it more difficult to implement such a subclass?
> do these choices get in your way at all?

They do not! My real motivation is to throw away my `Generator` class here in favor of a superior solution from the JAX codebase 😄 I just thought it would be nicer to have one class rather than two, and to use abstract methods rather than function-valued members. In my code, I pass around `Generator` everywhere. I had hoped for a similar solution like `JaxGenerator` or something like that 😄
Anyway, I think I've expressed my points. Thanks for the reply. I'm sure whatever you choose will work great.
> Thanks for the exhaustive reply.

Of course – thanks for the same. I better understand your points now. It seems like there are pros and cons either way. I went ahead with this PR as is, but we can always adapt in the future, and it will be useful to have this discussion to refer back to.

> `PRNGKeyArray` is a member of the union `KeyArray`, which is passed to the public functions in `random.py`. So, `PRNGKeyArray` is part of the API, isn't it? How should a user annotate the generator that's produced by `seed_with_impl` and that's passed around to user functions?

You're right – the type is publicly visible. What I had in mind was: users are not meant to instantiate this class directly, nor to interact with it other than by (a) the array-like methods (which I want no one to override) and (b) passing it to `random` functions. It's supposed to simply look like an array for the most part. Good correction to my unqualified remark regardless.
> About `PRNGImpl`, my point is that it could just be folded into `PRNGKeyArray` since it doesn't appear that you use it on its own anywhere except to produce a `PRNGKeyArray`. I do understand your point that you want to treat `PRNGKeyArray` as a wrapper, but it only ever wraps one thing, and the thing it wraps is only ever used within the wrapper.

I see now. Yes, that would save some indirection. On the other hand, PRNG implementations would involve slightly more wrapping/unwrapping boilerplate that the class currently handles.
> They do not!

Excellent. That's the most important part to me. Everything else can be revised later if/as needed.
LGTM! Some small comments/request for clarification.
  def dtype(self):
    # TODO(frostig): remove after deprecation window
    if config.jax_enable_custom_prng:
      raise AttributeError("'PRNGKeyArray' has no attribute 'dtype'")
You might have already explained this to me, but why do we want to remove the ability to do `key.dtype`? Seems like in the current implementation this can only be `uint32`, regardless of the `PRNGImpl`.
I'm not sure what to return as the `dtype`? If the purpose is to be an array of opaque keys as elements, then `uint32` seems incorrect (that's the dtype of the data buffer representation of an individual key). I figured that `dtype` is one of the ways in which this array type is not a numpy or jax device array. Perhaps this is too cautious, and if so I'd be happy to revisit it.
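A small sketch of the flag-guarded attribute pattern in the diff above. The `_Config` object and `KeyArraySketch` are hypothetical stand-ins for `config.jax_enable_custom_prng` and `PRNGKeyArray`. One consequence worth noting is that raising `AttributeError` from a property makes `hasattr` report the attribute as absent.

```python
class _Config:
    # Hypothetical stand-in for config.jax_enable_custom_prng; defaults off.
    enable_custom_prng = False


config = _Config()


class KeyArraySketch:
    """Toy key array that hides `dtype` once the upgrade flag is on."""

    def __init__(self, data_dtype: str = "uint32") -> None:
        self._data_dtype = data_dtype

    @property
    def dtype(self) -> str:
        if config.enable_custom_prng:
            # Behave as though the attribute does not exist at all.
            raise AttributeError("'KeyArraySketch' has no attribute 'dtype'")
        return self._data_dtype
```

With the flag off, existing code that reads `key.dtype` keeps working; with it on, feature checks via `hasattr(key, "dtype")` return `False`.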
  @property
  def _shape(self):
    base_ndim = len(self.impl.key_shape)
    return self.keys.shape[:-base_ndim]
Should we assert that `self._shape` is always equal to `self.impl.key_shape`? Do we ever expect them to be different?
The two are different: `_shape` should be the shape of the encompassing "array of keys," whereas `impl.key_shape` is the shape of an individual key. The underlying buffer holding all the key data is an array of shape `(*self._shape, *self.impl.key_shape)`.
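The relationship between the three shapes can be checked with plain tuples. The helper name `batch_shape` is hypothetical. It slices from the front rather than with a negative index because, in Python, `shape[:-0]` is the same as `shape[:0]` and returns `()`, so the negative-index form would misbehave if a key shape were ever `()`; whether that case can occur is not stated in this thread.

```python
from typing import Tuple


def batch_shape(buffer_shape: Tuple[int, ...],
                key_shape: Tuple[int, ...]) -> Tuple[int, ...]:
    """Shape of the "array of keys", given the raw buffer and per-key shapes."""
    base_ndim = len(key_shape)
    # The trailing `base_ndim` dimensions of the buffer must be the key shape.
    assert buffer_shape[len(buffer_shape) - base_ndim:] == key_shape
    return buffer_shape[: len(buffer_shape) - base_ndim]


single = batch_shape((2,), (2,))          # one threefry-style key
batch = batch_shape((10, 2), (2,))        # ten keys
nested = batch_shape((3, 10, 4), (4,))    # 3x10 keys, each of shape (4,)
```

In each case the buffer shape is recovered as `(*batch, *key_shape)`.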
  elif _arraylike(key):
    if config.jax_enable_custom_prng:
      warnings.warn(
          'Raw arrays as random keys to jax.random functions are deprecated. '
Will this warning not trigger for every user if the default is `custom_prng=False`? Maybe there should be a hint on how to resolve the warning if that is the case?
More generally, do we expect users to turn on this `custom_prng` flag themselves? As I understand it, it's only a small group of users who rely on the keys being `jnp` arrays, and for most users this will be a no-op change, so I'm not sure they should be aware of the flag. I might be wrong though!
The `config.jax_enable_custom_prng` flag defaults to false and the warning is guarded by it, so (I think) it should only trigger when the flag is explicitly enabled.
In general, I tried to ensure that no new warnings would appear when the flag is at its default value (meaning no upgrade). I checked this visually by running `random_test.py` and `api_test.py`.
I don't expect end users to turn this flag on. I hope to work with some neighboring libraries to upgrade by turning it on, seeing some warnings (rather than errors), and making changes to handle them. Then we can deprecate the flag altogether and remove all these checks and warnings, though I will probably flip the default value of the flag to `True` for a while first.
Does that seem right?
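The guard described above can be checked mechanically rather than visually. This is a sketch under assumptions: `_Config`, `KeyArraySketch`, and `check_key` are hypothetical names, raw keys are modeled as lists or tuples, and the `FutureWarning` category is a guess, since the diff excerpt does not show which category the PR uses.

```python
import warnings
from typing import Sequence, Union


class _Config:
    # Hypothetical stand-in for config.jax_enable_custom_prng; defaults off.
    enable_custom_prng = False


config = _Config()


class KeyArraySketch:
    def __init__(self, data: Sequence[int]) -> None:
        self.data = list(data)


def check_key(key: Union[KeyArraySketch, Sequence[int]]) -> KeyArraySketch:
    """Accept wrapped keys as-is; wrap raw ones, warning only under the flag."""
    if isinstance(key, KeyArraySketch):
        return key
    if isinstance(key, (list, tuple)):
        if config.enable_custom_prng:
            warnings.warn(
                "Raw arrays as random keys are deprecated.",
                FutureWarning,  # assumed category
            )
        return KeyArraySketch(key)
    raise TypeError(f"unexpected key type: {type(key)}")
```

A unit test can wrap calls in `warnings.catch_warnings(record=True)` and assert that nothing is recorded while the flag is at its default.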
Seems good, thank you for trying to do this smoothly.
Speaking for NetKet, however, we enforce semver on our dependencies, so if you were to follow it too you could avoid doing all this.
We are upper bounding the minor version of `jax` to `0.2`, so if you were to interpret this change as a breaking change (which IMO it is) and bump the next version of jax to `0.3` instead of `0.2.X`, pip would avoid installing jax 0.3 with current versions of netket, and we would have the time to upgrade our code and release a new version that is compatible with jax `0.3`.
I don't know if jax is following semver or not, however.
It really does make things easier for dependent packages.
A PRNG implementation is determined by a key shape and a set of basic functions on such a key: seed, split, random_bits, and fold_in. A PRNG implementation can then be lifted to an array-of-keys-like object. Namely, a new internal pytree class PRNGKeyArray wraps the implementation and maintains an array of keys of the right shape. This array-like object is the new "key" that gets passed around the various functions in the public random API (e.g. `random.uniform`, `random.normal`, ...). So the PRNGKeyArray class really serves two purposes at once: 1. To adapt key implementations into "arrays" of such keys. 2. To carry a reference to the PRNG implementation around and delegate back to it from the functions in random.
Some tests check the behavior of the random bit generator---in particular the default threefry implementation---and some check the behavior of samplers. Separate them into different test classes.
Now that `LaxRandomTest` only tests random functions and not any specific PRNG, it can be reused to test random functions under different PRNGs.
…conditionally Introduce a config flag for upgrading to a world of custom PRNGs. The flag defaults off, so that we can introduce custom PRNGs into the codebase and allow downstream libraries time to upgrade. Backwards compatible behavior is meant in an external sense. This does not mean that our code is internally the same any longer.
Some call this, apparently.
Nice work on this!
Today jax offers a single splittable PRNG implementation (threefry 2x32) and models its key as a `uint32[2]`. A batch of keys is a `uint32` array whose trailing dimension comprises the length-2 keys.

With this PR, we have two goals:

`jax.random` to work even when seeding, splitting, and random bit generation is defined by user code.

An added challenge of this upgrade is that the previous "raw array" key model has been around for a long time, and many downstream uses have overfit to it. To varying extents, flax, haiku, trax, and TFP all validate user-provided random keys as `uint32[2]` arrays. We'd like to obviate the need for such checks. Other downstream code creates keys directly with hard-coded values (including a few of our own tests), reshapes key batches (including our `gamma` primitive implementation), or casts other program values to keys. For safety, we would like to disallow such operations.

I'm introducing this change gradually to allow time for downstream fixes. I've tried to do this by recovering existing behavior under a flag that is, for now, on by default.
design

Again, two goals:

`jax.random`.

Item 1 is scoped only to the JAX Python level; changing key representation in the Jaxpr IR is not a goal. This PR requires that keys remain backed by `uint32` arrays, now of arbitrary specified shape. We can broaden that later as needed. They will therefore appear as `uint32` arrays in Jaxpr.

A custom PRNG implementation amounts to a key type `K` (determined by a shape) plus a handful of functions operating on such a key:

An example, also used for unit tests introduced in this PR, is the following cartoon PRNG built from two threefry2x32 keys. To define it:

And to use it:
Note that the `shape` field of the `PRNGKeyArray` changes from the scalar `()` to `(10,)` when it is `split` 10 ways and to `(3,)` when it is broadcast by `vmap`.

Internally, this works as follows:

- `impl` with a key type `K` like the one above.
- `PRNGKeyArray` wraps the given PRNG implementation `impl` and lifts it to an array-like interface. This has roughly the semantics of an array of `K`-typed elements. Under the hood it is a `uint32` array whose trailing dimensions correspond to `K`. It is array-like only in select ways, for instance supporting indexing and iteration but not numerical operations. Expressions such as `key, subkey = jax.random.split(key)` remain valid but `key + 1` or `jnp.exp(key)` are no longer allowed.
- `PRNGKeyArray` carries the implementation `impl` around through `jax.random` functions, which then delegate to it for `seed`, `split`, `random_bits`, and `fold_in`.

Note that all new indirection happens at the Python level. This design makes no changes to Jaxpr, although the extra indirection could hopefully help make element-type changes in Jaxpr easier in the future. I discussed this with @LenaMartens and we think it will help clear the path for some Jaxpr-level RNG checking ideas in particular.
I've separately started thinking about whether to generalize the array-like wrapping concept to allow for lightweight custom-element-type arrays more broadly. If RNG keys are one special case, then arrays of `unit`, `float0`, enums, or arbitrary pytrees could be others. I mean "lightweight" both in the sense of "easy to set up" as well as "without necessarily extending Jaxpr."

temporary limitations

`random.gamma` and `random.poisson` are not yet implemented with custom PRNGs, since these have threefry2x32-specific implementations today that need to be generalized. I have them `raise NotImplementedError` when given a custom PRNG for now.

When fully enabled, this change breaks several of our neighboring libraries. To allow for smoother upgrades, I've retained existing behavior under a flag. Doing so introduces extra code and test complexity in the interim.

`PRNGKeyArray` may want to support more array-like methods in the future. I implemented an obvious initial handful for now.

Fixes #5081.
Relatedly addresses parts of #2294.