Adding random number generation API #431
I'd like to quote @rgommers from there (scikit-learn/scikit-learn#22352 (comment)) because I feel it clarified a lot:
> Yeah, was just summarizing the discussion up to that point. Stateless isn't really the right term for this, I think. A more correct description might be "functional", as many functional languages use a similar strategy, something like the sketch below. IOW the user is responsible for tracking the state in this model and threading it through subsequent random number generation function calls. The other model is object-oriented, like NumPy's, where state lives in some object that gets mutated each time a random number is generated.
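A minimal sketch of that functional, state-threading model (all names here are hypothetical, not any particular library's API):

```python
# Hypothetical toy functional PRNG: the caller owns the state and must
# thread the returned state into the next call; nothing is mutated.
import hashlib
import struct

def init_state(seed):
    return struct.pack('<Q', seed)

def random_bits(state):
    # Pure function: (state) -> (new_state, value)
    digest = hashlib.blake2b(state, digest_size=16).digest()
    return digest[:8], struct.unpack('<Q', digest[8:])[0]

state = init_state(123)
state, a = random_bits(state)  # thread the new state through...
state, b = random_bits(state)  # ...each subsequent call
```
|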
We had a previous discussion on PRNG APIs, and IIRC no one was attached to the legacy `RandomState` API. The new NumPy API seems a little more user-friendly than the JAX APIs, and offers more functionality too. And it's no less safe - I think it's just as well possible to shoot yourself in the foot with JAX: just forget to manually create a new sub-key once. This is more of a philosophical difference: there's only one way of doing things in JAX, which is more verbose but the same for serial and parallel execution - while NumPy has a more concise way for serial (its default mode), but the user must remember to use a second method (equivalent to JAX's `split`) for parallel execution. One advantage of the JAX API is that there's only one way to do things, while NumPy has multiple ways to deal with parallelism - so more functionality, but also harder to understand or standardize. That may be partly due to providing multiple PRNG algorithms though, while JAX provides only the Threefry algorithm (not entirely true, there's a second experimental one - XLA Random Bit Generator). PyTorch for example provides Philox (which JAX also could have used) and MT19937. Now a standard doesn't have to deal with exact reproducibility across libraries, but it should allow a design that allows libraries to choose their own algorithms. I think this works with either JAX's or NumPy's approach. The main issues for standardization I see are:
tl;dr this will not be easy to standardize |
Thanks for the detailed write-up Ralf! 🙏 FWIW Dask is already working on adopting the new NumPy API ( dask/dask#9038 ). It looks like CuPy already did this in 9.0.0 with PR ( cupy/cupy#4177 ) (though Leo should feel free to correct me). That all being said, maybe it is worth asking if a subset of NumPy's API might be easier to adopt, and if so, what that looks like. Also, the other important question here is what API is going to be most useful for downstream users. Started with scikit-learn as they make a lot of use of random number generation, and creating some usable API with scikit-learn would be a win. Though maybe there are other downstream libraries that would make sense as well (perhaps statsmodels? others?). |
@leofang While there is no stateless PRNG, HW-based true random number generators have empty state (e.g. MKL's NONDETERM basic random number generator), and these should be supported by the spec as well.
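As a sketch of what "empty state" means in practice (using `os.urandom` as a stand-in for a hardware entropy source like MKL's NONDETERM):

```python
# There is no seed and no reproducible state: each draw queries the
# OS/hardware entropy source directly.
import os
import struct

def nondeterministic_uniform():
    (bits,) = struct.unpack('<Q', os.urandom(8))
    return bits / 2**64

print(nondeterministic_uniform())
```
|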
To get a better feel for the tradeoffs, here is code for a couple of things:
For all the below code in a single gist, see here.

```python
import secrets
import multiprocessing

import numpy as np
import jax

USE_FIXED_SEED = False

if USE_FIXED_SEED:
    seed = 38968222334307
else:
    # Generate a random high-entropy seed for use in the below examples
    # jax.random.PRNGKey doesn't accept None to do this automatically
    seed = secrets.randbits(32)  # JAX can't deal with >32-bits

# NumPy serial
rng = np.random.default_rng(seed=seed)
vals = rng.uniform(size=3)
val = rng.uniform(size=1)

# NumPy parallel
sseq = np.random.SeedSequence(entropy=seed)
child_seeds = sseq.spawn(4)
rngs = [np.random.default_rng(seed=s) for s in child_seeds]

def use_rngs_numpy(rng):
    vals = rng.uniform(size=3)
    val = rng.uniform(size=1)
    print(vals, val)

def main_numpy():
    with multiprocessing.Pool(processes=4) as pool:
        pool.map(use_rngs_numpy, rngs)

# JAX serial (also auto-parallelizes fine by design)
key = jax.random.PRNGKey(seed)
key, subkey = jax.random.split(key)  # this one could be left out, but best practice is probably to always use `split` first
vals = jax.random.uniform(subkey, shape=(3,))
key, subkey = jax.random.split(key)  # don't forget this!
val = jax.random.uniform(subkey, shape=(1,))

# JAX parallel with multiprocessing
def use_rngs_jax(key):
    key, subkey = jax.random.split(key)
    vals = jax.random.uniform(subkey, shape=(3,))
    key, subkey = jax.random.split(key)
    val = jax.random.uniform(subkey, shape=(1,))
    print(vals, val)

def main_jax():
    key = jax.random.PRNGKey(seed)
    key, *subkeys = jax.random.split(key, 5)  # gotcha: "5" gives us 4 subkeys
    with multiprocessing.Pool(processes=4) as pool:
        pool.map(use_rngs_jax, subkeys)

if __name__ == '__main__':
    # JAX does not work with the default `fork` (due to internal threading)
    multiprocessing.set_start_method('forkserver')
    print('\nNumPy with multiprocessing:\n')
    main_numpy()
    print('\n\nJAX with multiprocessing:\n')
    main_jax()
```

JAX does seem to have a few gotchas with seed creation: it apparently can't deal with high-entropy (>32-bit) seeds (at least in 0.2.27, released 18 Jan 2022).
A JAX-style API implemented on top of NumPy (to compare them side by side):

```python
"""
Implement `jax.random` APIs with NumPy, and `numpy.random` APIs with JAX.

The purpose of this is to be able to compare APIs more easily, and clarify
where they are and aren't similar.
"""
import secrets
import multiprocessing

import numpy as np
import jax

USE_FIXED_SEED = False

if USE_FIXED_SEED:
    seed = 38968222334307
else:
    # Generate a random high-entropy seed for use in the below examples
    # jax.random.PRNGKey doesn't accept None to do this automatically
    seed = secrets.randbits(32)  # JAX can't deal with >32-bits

def PRNGKey(seed):
    """
    Create a key from a seed. `seed` must be a 32-bit (or 64-bit?) integer.
    """
    # Note: selecting a non-default PRNG algorithm is done via a global config
    # flag (not good, should be a keyword or similar ...)
    seed = np.random.SeedSequence(seed)
    rng = np.random.default_rng(seed)
    key = (seed, rng)
    return key

def split(key, num=2):
    """
    Parameters
    ----------
    key : tuple
        Size-2 tuple, the first element a `SeedSequence` instance, the second
        containing the algorithm selector.
    num : int, optional
        The number of keys to produce (default: 2).

    Returns
    -------
    keys : tuple of 2-tuples
        `num` number of keys (each key being a 2-tuple)
    """
    seed, rng = key
    child_seeds = seed.spawn(num)
    keys = ((s, rng) for s in child_seeds)
    return keys

def uniform(key, shape=(), dtype=np.float64, minval=0.0, maxval=1.0):
    seed, rng = key
    # Creating a new Generator instance from an old one with the same
    # underlying BitGenerator type requires using non-public API:
    rng = np.random.Generator(rng._bit_generator.__class__(seed))
    return rng.uniform(low=minval, high=maxval, size=shape).astype(dtype)

def use_jaxlike_api(key=None):
    if key is None:
        key = PRNGKey(seed)
    key, subkey = split(key)
    vals = uniform(subkey, shape=(3,))
    key, subkey = split(key)  # don't forget this!
    val = uniform(subkey, shape=(1,))
    print(vals, val)

def use_jaxlike_api_mp():
    key = PRNGKey(seed)
    key, *subkeys = split(key, 5)
    with multiprocessing.Pool(processes=4) as pool:
        pool.map(use_jaxlike_api, subkeys)

if __name__ == '__main__':
    # JAX does not work with the default `fork` (due to internal threading)
    multiprocessing.set_start_method('forkserver')
    print('\n\nUse JAX-like API (serial):\n')
    use_jaxlike_api()
    print('\n\nUse JAX-like API (multiprocessing):\n')
    use_jaxlike_api_mp()
```

A couple of thoughts:
One other thing to point out: the JAX docs comparing to NumPy are wildly outdated/unfair; they use the non-recommended (global state) way of using the legacy NumPy API.
My tentative conclusions based on the above:
This is all a little nontrivial, so let me ping a few folks for input: @rkern for design/implementation thoughts and whether I missed anything important related to the NumPy implementation. @shoyer, @apaszke for thoughts from the JAX side. |
I'll need to spend more time to read the whole thread, but I will add here that we have always had a plan to lift `SeedSequence.spawn` functionality up to `Generator` itself. |
Specifically, this code:

```python
# NumPy parallel
sseq = np.random.SeedSequence(entropy=seed)
child_seeds = sseq.spawn(4)
rngs = [np.random.default_rng(seed=s) for s in child_seeds]
```

would become:

```python
rngs = rng.spawn(4)
```
|
Thanks - I think that would indeed be quite helpful! |
Sorry for a bit of a tangent here, but is calling `spawn` repeatedly, one child at a time, also valid?

```python
rng1 = rng.spawn(1)
rng2 = rng.spawn(1)
...
rngn = rng.spawn(1)
```

This can come up when new children processes are created/destroyed on-demand (IOW autoscaling). |
Yes, it is. More importantly, spawning from an already-spawned generator is also valid. You can spawn children from any of them to get more independent streams. There will never be a collision (at least not within reasonable probabilities).
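For reference, the same on-demand pattern is already expressible with today's public `SeedSequence` API (a sketch; `Generator.spawn` above is the proposed shortcut):

```python
import numpy as np

sseq = np.random.SeedSequence(12345)
rng1 = np.random.default_rng(sseq.spawn(1)[0])
rng2 = np.random.default_rng(sseq.spawn(1)[0])  # later, on demand

child = sseq.spawn(1)[0]
grandchild = child.spawn(1)[0]  # spawning from a spawned child is also fine
rng3 = np.random.default_rng(grandchild)
```
|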
I went searching through some older notes and found this from @alextp: "I think the direction in TensorFlow is to follow JAX. Have functions for stateless random generation. They compose well. But they are not ergonomic. Layer on top of this will be something stateful. Once you do this, you introduce checkpoints etc., for determinism. So I would be OK to add a stateless API, but wouldn't be okay to add a stateful one." That's from more than a year ago, so I don't know if it has been implemented like that, or anything changed in the meantime in TensorFlow. Maybe @edloper you can tell us? |
I'm perfectly content with having 0 PRNG APIs in the standard (and far prefer 0 in the standard to having 2 in the standard). It seems like there is significant variety in what different communities need and want. |
I think it's equally straightforward to write an explicitly stateful RNG system like `numpy.random.Generator` using JAX. Here's a prototype:

```python
import jax
import jax.numpy as jnp

class JaxGenerator:
    def __init__(self, state):
        self.state = state

    def uniform(self):
        self.state, key = jax.random.split(self.state)
        return jax.random.uniform(key)

    def spawn(self, count):
        self.state, *keys = jax.random.split(self.state, count + 1)
        return [JaxGenerator(key) for key in keys]

    def __repr__(self):
        return f'{type(self).__name__}(state={self.state})'

def jax_default_rng(seed):
    return JaxGenerator(jax.random.PRNGKey(seed))

rng = jax_default_rng(0)
print(rng.uniform())  # 0.10536897
print(rng.uniform())  # 0.2787192
print(rng)  # JaxGenerator(state=[2384771982 3928867769])

rng2 = JaxGenerator(rng.state)
print(rng2)  # JaxGenerator(state=[2384771982 3928867769])

rngs = rng.spawn(2)
print(rngs)  # [JaxGenerator(state=[1777981902 3244208681]), JaxGenerator(state=[ 669635267 2816531647])]
```

From a JAX design perspective, stateful RNGs like this are not encouraged, because JAX's function transforms will break if you pass stateful objects into them. But you can still safely use this sort of API with JAX, as long as you're careful to create Generator objects inside pure functions, e.g.:

```python
@jax.jit
@jax.vmap
def batched_random_uniform(seed):
    return jax_default_rng(seed).uniform()

print(batched_random_uniform(jnp.arange(5)))
# DeviceArray([0.10536897, 0.12568676, 0.4336251 , 0.47652578, 0.7844808 ], dtype=float32)
```

Or explicitly keeping track of updated RNG state:

```python
@jax.jit
def explicit_state_random_uniform(state):
    rng = JaxGenerator(state)
    sample = rng.uniform()  # must happen *before* calculating new_state
    new_state = rng.state
    return new_state, sample

state = jax.random.PRNGKey(0)
state2, sample = explicit_state_random_uniform(state)
print(state2, sample)
```

In fact, Haiku, which is one of the most popular neural net libraries in JAX, does something very similar with `haiku.next_rng_key`. These stateful interfaces are perhaps easier to misuse with JAX's functional transforms than pure functions, but are still a major improvement over global state. |
To bridge the gap between NumPy's and JAX's random number APIs, I would suggest slightly extending NumPy's API so it's easier to be explicit about state. Namely, we should add the `state` property (with a setter) and a `spawn` method. Here's my minimal wrapper of NumPy's RNG API with these extensions, to match the JAX API above:

```python
import numpy as np

class NumpyGenerator:
    def __init__(self, state):
        # TODO: avoid initializing this dummy Generator/BitGenerator?
        self._rng = np.random.default_rng()
        self.state = state

    def uniform(self):
        return self._rng.uniform()

    def spawn(self, count):
        entropy = self._rng.integers(2**63)  # TODO: better entropy?
        sseq = np.random.SeedSequence(entropy)
        child_seeds = sseq.spawn(count)
        return [numpy_default_rng(seed) for seed in child_seeds]

    @property
    def state(self):
        return self._rng.bit_generator.state

    @state.setter
    def state(self, value):
        self._rng.bit_generator.state = value

    def __repr__(self):
        return f'{type(self).__name__}(state={self.state})'

def numpy_default_rng(seed):
    return NumpyGenerator(np.random.default_rng(seed).bit_generator.state)

rng = numpy_default_rng(0)
print(rng.uniform())  # 0.6369616873214543
print(rng.uniform())  # 0.2697867137638703
print(rng)  # NumpyGenerator(state={'bit_generator': 'PCG64', 'state': {'state': 143609658456486183636066271097634410721, 'inc': 87136372517582989555478159403783844777}, 'has_uint32': 0, 'uinteger': 0})

rng2 = NumpyGenerator(rng.state)
print(rng2)  # NumpyGenerator(state={'bit_generator': 'PCG64', 'state': {'state': 143609658456486183636066271097634410721, 'inc': 87136372517582989555478159403783844777}, 'has_uint32': 0, 'uinteger': 0})

rngs = rng.spawn(2)
print(rngs)  # [NumpyGenerator(state={'bit_generator': 'PCG64', 'state': {'state': 164868448360684748498847325894109072011, 'inc': 241524822143570234404080558697197945801}, 'has_uint32': 0, 'uinteger': 0}), NumpyGenerator(state={'bit_generator': 'PCG64', 'state': {'state': 219573644955839246335449654370252341036, 'inc': 175512849395095609630857841553467033115}, 'has_uint32': 0, 'uinteger': 0})]
```
|
It's nice to see the alternative wrapping, for perspective.
Am curious how much work it would be to add these to NumPy? |
Not much. |
One note regarding JAX:
JAX disables 64-bit data types by default, but you can enable this if you wish to use them. If you enable 64-bit values, then 64-bit seeds will be fine.
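For instance (a sketch; `jax_enable_x64` is JAX's standard flag for this, also settable via the `JAX_ENABLE_X64` environment variable):

```python
import jax

# Must run before other JAX operations for the dtype change to take effect.
jax.config.update("jax_enable_x64", True)

key = jax.random.PRNGKey(2**40)  # a >32-bit seed now works
```
|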
I'm happy to update these if you think it would be helpful. In my experience, the global default rng is how most users use numpy's random APIs, despite more recent changes to the recommendations in numpy's docs, so it's a useful way to introduce how JAX differs. As for the JAX discussion: whether the seed is global or a mutated object attribute, similar considerations apply. |
I think the recommendation from the TF side would still be to use the stateless API. (As mentioned above, a stateful API can be layered on top of the stateless one, if desired; a sketch of that layering follows below.)
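For illustration, a minimal sketch of such a layering over TF's stateless ops (the `TfGenerator` class is hypothetical; TF's actual stateful layer is `tf.random.Generator`):

```python
import tensorflow as tf

class TfGenerator:
    """Hypothetical stateful wrapper over TF's stateless RNG ops."""
    def __init__(self, seed):
        # A stateless-op seed is a shape-[2] integer tensor.
        self._key = tf.constant([0, seed], dtype=tf.int64)

    def uniform(self, shape=()):
        # Split the key, keep one half as the new state, consume the other.
        self._key, subkey = tf.unstack(
            tf.random.experimental.stateless_split(self._key, num=2))
        return tf.random.stateless_uniform(shape, seed=subkey)

rng = TfGenerator(42)
print(rng.uniform((2,)))
print(rng.uniform((2,)))  # different values: the wrapper advanced its state
```
|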
We can build shim APIs across either implementation to get stateless or stateful APIs, but the way that each shim API needs to be implemented has its own costs, constraints, and tradeoffs. I would suggest exploring those implementation strategies and how well they actually support a set of use cases before relying on the mere existence proof of shim APIs as a reason to go ahead and standardize on one API instead of another. What would the shim APIs actually allow us to do? Let's say a stateless API becomes the standard. Vice versa, we can definitely make a stateless API on top of the current implementation of `Generator`. Again, I'm pretty happy for there to be no API standard on this subject. I would be interested in talking more on how to hand off from one style of API to the other, though. There's room in the |
I vote for presenting them here, please. |
It's, uh, getting long. Not because my list of issues is long, just that I am long-winded and am including a lot of background. |
**Apologies**

This is long, and everyone has my apologies for that, but I want to make sure the background is laid out.

**Introduction**

So first off, I want to say that the issues I am going to lay out don't make JAX's PRNG a bad one for all purposes. JAX has a fairly specific usage profile, and these are limitations that can be usefully lived within, if one is willing to. But these are limitations that I would hesitate to propagate up to general use amongst all Array API users, particularly in contexts where one doesn't get the benefits paid for by those limitations. I will mention a number of issues that I have with specific implementation choices that are in principle addressable. I don't really count those as determinative (they can indeed be fixed while still maintaining the style of API). But I do think they can be taken, as a whole, as a sign that imposing this style of API places some serious constraints on implementations, which is probably out of scope of this standardization effort.

**Reviewing JAX's Design**

To review, JAX's PRNG API is based around explicit splitting of the PRNG state and then pure functions to generate possibly-large arrays of data from the leaf keys. The core primitive is the Threefry-2x32 block cipher:

```python
cipher_text = threefry2x32(cipher_key, plain_text)
```

This primitive is a keyed bijection: given a fixed `cipher_key`, distinct `plain_text` blocks map to distinct `cipher_text` blocks. For the core PRNG algorithm to draw random bits, we use a simple incrementing counter as each plaintext block. This is nice for GPUs because we can create that counter array and then have the GPU run the cipher over all of the blocks in parallel. JAX's implementation of the counter has a quirk, though. In particular, instead of incrementing the counter in 64-bit blocks, it creates a 32-bit counter array, splits it in half to use the first half as the upper 32-bit word and the second half as the lower 32-bit word. That's a bit wasteful (you could instead just use a plain 64-bit counter).

So far, so good. Despite that quirk, I have only one qualm with the core PRNG scheme for drawing bits from a given leaf key. The mechanism that JAX uses to *split* keys is where my real concerns lie. For discussion, it's handy to define:

```python
def left(prng_key):
    return jax.random.split(prng_key)[0]

def right(prng_key):
    return jax.random.split(prng_key)[1]
```

None of the following really depends on restricting ourselves to that, but it's handy to talk about things. So we use the current key to derive new keys, and this key-to-key derivation is a random *non-invertible* mapping. It has nice properties for cryptographic purposes, but it has some significant drawbacks for PRNGs used for scientific and statistical purposes.

**Damn Statistics**

While there are few guarantees without knowing something about the structure of the mapping, you can calculate some useful statistics about the population of random non-invertible mappings. However, such state spaces are inherently biased and non-uniform, and this doesn't go away as the state space increases. The states that follow states which have in-degrees greater than 1 will be overrepresented if you repeat with different initial seeds. States on small cycles will also be overrepresented. Small cycles will have big trees attached to them. A large fraction of states (on the order of half) are not reachable from any other state; you have to start with them as the initial seed to ever observe them. This is particularly an issue when you reduce the initial seeding space to just 32 bits, as is the default configuration of JAX. I will never recommend a PRNG based on non-invertible mappings for scientific use.

So far, I've talked about iterating `left` and `right`. The recommended JAX practice is to convert each sequential stateful draw into a split-then-draw pair:

```python
x = rng.uniform(0, 1)
y = rng.uniform(10, 20)
z = rng.uniform(-1, 1)
# -->
key, subkey = split(key)
x = uniform(subkey, 0, 1)
key, subkey = split(key)
y = uniform(subkey, 10, 20)
key, subkey = split(key)
z = uniform(subkey, -1, 1)
```

So the iteration of `split` takes the role that the state-transition function plays in a conventional PRNG. So I did construct a PRNG out of repeated iteration of `split` and ran its output through PractRand.
|
If I understand Robert's post above correctly, you're trying to convince us this is one of the rare cases where API design is tightly coupled to the underlying algorithm. I'd like to add one more argument based on Robert's point (against random number standardization): in practice, the PRNG implementations available to different devices (CPUs, NVIDIA GPUs, AMD GPUs, Intel GPUs, Google TPUs, ...) are very likely different. (@emcastillo from CuPy/PFN has first-hand experience of this pain, as NumPy and CuPy have totally different PRNGs, and in this case CuPy isn't, strictly speaking, a drop-in replacement for NumPy.) It then follows that the API standardization must be either PRNG-ignorant (= just give me some random numbers, I don't care how you generate them), or exclude PRNGs from the standard (= define an API to accept vendor-specific PRNGs, but leave out which PRNGs the standard must cover). Either assumption must hold in order to proceed. |
It's probably also worthwhile to tie this back to how downstream users would leverage this API. Going back to scikit-learn, many APIs take a `random_state` argument (a seed or an RNG object). Given the current usage of the object API by these libraries, it is relatively straightforward to see how they would adopt an object API if added to the spec. What is less clear is how they would leverage a functional API. If we are interested in evaluating that option, it would be worth seeing how that API would be used in a downstream library. Should add that we probably want to look at something a bit more general than JAX's current API, given it is tied to a specific class of PRNGs, as pointed out in a few places in this issue already.
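For illustration, a simplified sketch of that pattern (scikit-learn's real helper is `sklearn.utils.check_random_state`, which returns a legacy `RandomState`; this sketch substitutes the new `Generator`):

```python
# Sketch only: sklearn-style `random_state` handling, using the new
# Generator API instead of sklearn's actual RandomState-based helper.
import numpy as np

def check_random_state(random_state):
    if random_state is None or isinstance(random_state, int):
        return np.random.default_rng(random_state)
    return random_state  # assume an already-constructed Generator

def shuffled_copy(X, random_state=None):
    rng = check_random_state(random_state)
    return X[rng.permutation(len(X))]  # reproducible given a fixed seed

X = np.arange(10).reshape(5, 2)
print(shuffled_copy(X, random_state=0))
```
|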
It's not coupled one-to-one, but that choice of API does eliminate a wide swathe of PRNG algorithm choices, and there are only some tricksy options left over. What I am particularly arguing against is the claim that if we standardize on JAX's style of stateless API, we can just build a stateful API on top of it again for those people that like that API. While this is true, it is not usefully true. The reconstructed stateful API (built on top of the standardized JAX-style stateless API) does not restore the full range of PRNG options that we had before. And I think that there are good reasons to avoid the constraints of the JAX-style stateless API in the cases where we're not getting the benefit of JAX's other capabilities along with it. The key issue is not so much the raw algorithm that is used inside of the distribution methods (`uniform`, `normal`, and so on). The key issue (so to speak) is how to handle the data flow of the PRNG state. JAX's needs place strong requirements on this data flow in order to get corresponding benefits. Other environments don't have those benefits and thus don't have those requirements. Embedding those strong requirements in the standard imposes those costs on everyone. It's not just the syntax sugar of how the state is carried through the program. |
Another way to think about it is that all PRNG schemes have a certain finite amount of safety margin, and it is a consumable resource. Just serially drawing arrays of numbers from it consumes a small amount of that safety margin. Splits, no matter how implemented, consume a lot of that safety margin. My general recommendation is to only split when you have to, at the place where you need parallelism. In JAX programs, that's everywhere, because JAX is awesome at enabling incredible amounts of parallelism without having to explicitly code it. A splitting-central PRNG scheme fits very well in that environment. The parallelism pays back what you spent in terms of the lowered safety margin. But for most of the other Array API implementations, forcing that amount of splitting is forcing everyone to live with that impaired safety margin for no corresponding benefit. The issue is not merely the ergonomics of the API. Simply wrapping a stateful API around a splitting-central stateless API does not restore the safety margin built into the more typical stateful PRNG implementations. |
Hi @rkern - thanks for those comments. Just to be clear, the doc you linked to is not really "documentation" per se, but rather a years-old design doc meant to lay out the motivation for JAX's initial PRNG design (which has evolved since then, and will continue to evolve). For that reason, I don't think I will be updating it, but rather will add a disclaimer at the top making its intent more clear. How does that sound? |
There might be a couple of other opportunities to standardize something, but it's not particularly clear to me what it would actually enable across implementations as different as numpy and JAX, say. So for example, we could say, informationally, that there are 3 basic flavors of PRNG state flow that an implementation could have, but the standard doesn't specify any one or any of the details about each.
But it could standardize on a basic list of methods/functions that an implementation ought to have, along with the semantics of the other arguments not related to the PRNG state flow, and allowing for extra implementation-specific arguments (dask adds `chunks`, for example). This is not likely to actually enable one to write significant backend-agnostic code. So the Array API standard might not be the best place for that. Maybe a SPEC is a better instrument? The Array API does have the advantage of having convened the right people.
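For concreteness, a hypothetical sketch of what one such standardized function could look like (all names illustrative, not proposal text):

```python
# Hypothetical: only the distribution arguments would be specified by the
# standard; how `state` flows (mutated object, threaded key, ...) is left
# implementation-defined, and extra implementation-specific keywords
# (e.g. dask's `chunks`) are allowed.
def uniform(state, low=0.0, high=1.0, shape=None, dtype=None, **impl_kwargs):
    raise NotImplementedError  # each library supplies its own state flow
```
|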
@jakevdp Sure, but it's also in the tutorial. |
Thanks for pointing that out. To be honest, it's news to me that numpy does not provide a sequential-equivalent guarantee for its pseudo-random values. I've spent the better part of the last couple decades assuming it did (I hope none of my code actually depends on that assumption...) |
Is there a place in the numpy docs that mentions the lack of a sequential equivalent guarantee in pseudo-random numbers? I'd like to link to it in the discussion. If not, I can link to this thread. |
This is the closest I found: https://numpy.org/neps/nep-0019-rng-policy.html
-Edward
|
The "Compatibility Guarantee" in the
This language was specifically about compatibility about different versions of numpy (though of course, it also holds true inside of a process). That's the only guarantee we've ever provided (and even that we disclaim for It seems reasonable to explicitly disclaim that not every combination that someone might think ought to be equivalent under some theory actually is actually implemented to be so. People do expect it from time to time and "confirm" it to themselves when they try something out, though we've never promised it. |
Thanks for the pointers - in that case I think I will link to this thread, because it seems more direct than linking to a long doc and pointing out the omission of the topic in question. |
Are there actually large reasons why the NumPy API (with `spawn` added) could not work for everyone? It seems that there are the 3 "ways" of using a PRNG that Robert also listed above:
Everyone provides 3; some PRNGs (as in NumPy) can probably not provide 2 but don't need it (NumPy never evaluates lazily – it might only be a curious addition for parallelization). JAX needs 2 (but is currently "missing" it – i.e. it is using the less optimal "always spawn" scheme). But I don't think it has to be a user-facing API? If you write, say, two consecutive draws (an illustrative sketch):
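```python
# hypothetical: two consecutive draws from one generator/key
vals1 = rng.uniform(size=10)
vals2 = rng.uniform(size=10)
```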
in JAX or NumPy, whether that draws the numbers directly sequentially, or does so lazily (with some advancing or splitting scheme), hardly matters? What matters is that the user-facing API is "sequential", because writing it differently would be bad for many PRNG schemes. What I don't quite understand yet is why a stateful API that provides:
is not a reasonable start. It would be bad for NumPy to provide a JAX-style stateless API, but why can't JAX provide a sequential stateful API (even if it is not necessarily sequential RNG generation internally)? For example (an illustrative sketch), whether:
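```python
# illustrative: one big draw ...
vals = rng.uniform(size=20)
```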
gives the same as:
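```python
# ... versus two smaller draws concatenated
vals = np.concatenate([rng.uniform(size=10), rng.uniform(size=10)])
```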
NumPy will do this (often or always?) but it doesn't seem necessary for the end-user API? The difficulty I see a bit is that some implementations may look like they provide guarantees like the above concatenation and others will not provide those same guarantees. |
It's a welcome surprise to have this thorough review of the JAX PRNG! Thanks @rkern for the really nice detailed look in particular. I'm only arriving to this thread now and still catching up, but as a quick note for now: we happen to be actively working on some changes to the `jax.random` implementation.
We indeed do not need to lay out the counter values exactly in the way that we do, so you'll see changes to that end. (I've observed in other contexts over time that the particular split/concat scheme we use isn't ideal, but I only picked up work on this again last week.)
As of recently, I'd say that the key size—and in fact the entire hash function or base generator—is not an essential or timeless choice for us. It is what we implemented initially, and it remains the default for now. But as of jax-ml/jax#6899 we have a means of replacing the underlying bit generator. We use this internally to experiment with and offer other bit generators. I opened jax-ml/jax#7676 a while back precisely to track the introduction of a 128-bit generator. We also have plans to make it possible for users to plug in an arbitrary generator of their own, and to involve differently-backed PRNGs in the same process (all tracked at jax-ml/jax#9263). We might always choose to change our default away from the current threefry2x32 hash at some point. So, where possible in the current discussion, it may be useful to assume that the base generator can change to meet the needs of the machine, application, etc. |
@froystig That's all good to hear! Like I said in my introduction, I didn't consider any of those details as immutable black marks against JAX or the overall design, but they did serve as a way to talk about the similarities with counter-based PRNG designs more generally. |
I don't know all of JAX's rewriting capabilities, but I don't think it can handle whatever it needs to do to make the stateful pattern work under its transforms:

```python
data1 = rng.uniform()
data2 = rng.uniform()
```

What it would have to do here is recognize that `rng` is a stateful object being mutated and rewrite that into explicit state threading. And even if JAX did have the technical capability of doing so, I'm not sure that they want their users to mix mental models like that. It's hard enough to teach people to use one consistently. |
@rkern Absolutely – I understood that and it makes sense. I only meant to share some related thoughts and work that's in progress, mostly as an aside. Your comments are super valuable, and this thread is a rare opportunity for us on JAX. Thank you! |
I'm the author of TF's RNG APIs. Unlike JAX, TF only supports counter-based RNG algorithms and has no plan to support non-counter-based ones. |
Yes, I can confirm that splitting is more profligate of the state space than the counter increment, especially when you constrain yourself to 64-bit keys.
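As a rough back-of-envelope illustration (standard birthday-bound arithmetic; the numbers are mine): with K-bit keys, generating n keys by splitting collides with probability of roughly n²/2^(K+1):

```python
# Birthday bound: probability of any collision among n random K-bit keys
# is approximately n**2 / 2**(K + 1).
K = 64
for n in (10**6, 10**9):
    print(f"n={n:.0e}: p ≈ {n**2 / 2**(K + 1):.2e}")
# n=1e+06: p ≈ 2.71e-08
# n=1e+09: p ≈ 2.71e-02
```
|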
From a user perspective, specifically someone writing libraries which want to support arbitrary array types, I don't see the problem with carrying the random state separately from the random namespace. It is not the slickest interface, but it can be implemented today:

```python
# file: random_numpy.py
def standard_normal(rng, shape=()):
    return rng, rng.standard_normal(size=shape)

def poisson(rng, lam, shape=()):
    return rng, rng.poisson(lam, size=shape)
```

```python
# file: random_jax.py
from jax import random

def standard_normal(key, shape=()):
    key, subkey = random.split(key)
    return key, random.normal(subkey, shape)

def poisson(key, lam, shape=()):
    key, subkey = random.split(key)
    return key, random.poisson(subkey, lam, shape)
```

```python
# file: test.py
def random_namespace(rs):
    import sys
    if 'numpy' in sys.modules:
        import numpy as np
        if isinstance(rs, np.random.Generator):
            return __import__('random_numpy')
    if 'jax' in sys.modules:
        import jax
        if isinstance(rs, jax.Array):
            return __import__('random_jax')

def my_function(lam, rs):
    random = random_namespace(rs)
    rs, rv1 = random.standard_normal(rs, shape=(4,))
    rs, rv2 = random.poisson(rs, lam, shape=(4,))
    return rv1 + rv2

import numpy as np
rng = np.random.default_rng()
print(my_function(1.0, rng))
# [ 1.25035394 -0.11511349  1.87203598  2.55088409]

import jax
key = jax.random.PRNGKey(42)
print(my_function(1.0, key))
# [0.43244982 0.8856307  0.06793922 1.4641162 ]
```

For anything more complicated than sampling random variates, particularly anything that requires knowledge of the implementation details of the random state, we accept that we will always have to handle each framework separately. |
The main difference is that the NumPy version advances the generator state on every call, while the JAX version returns the same values every time it is called with the same key:

```python
import numpy as np
rng = np.random.default_rng()
print(my_function(1.0, rng))
# [ 1.25035394 -0.11511349  1.87203598  2.55088409]
print(my_function(1.0, rng))
# [3.15990919 1.07582056 1.09202392 0.33987543]

import jax
key = jax.random.PRNGKey(42)
print(my_function(1.0, key))
# [-0.5675502   1.8856307   0.06793922  3.464116  ]
print(my_function(1.0, key))
# [-0.5675502   1.8856307   0.06793922  3.464116  ]
```
|
Yes, there is probably no universal API for situations where it matters whether or not we are dealing with a stateful RNG or a stateless key. My point is that it does not preclude us from having a universal (functional) API that can generate random numbers. In my experience, that's also the only situation where you want such an API. For everything more complicated, you need special cases anyway.

Edit: FWIW, here is how I handle your specific example in my wrapper.

```python
# file: random_numpy.py
def split(rng):
    return rng, rng
```

```python
# file: random_jax.py
from jax import random

def split(key):
    return random.split(key)
```

```python
# in the implementation, `rand` is an rng or a key, and `random` is the
# namespace returned by random_namespace(rand)
rand, subrand = random.split(rand)
my_function(1.0, subrand)
rand, subrand = random.split(rand)
my_function(1.0, subrand)
```
|
Given the lack of ecosystem agreement, we are not likely to forge a path forward for standardizing a PRNG API in the array API specification at this time. As such, I will go ahead and close this issue. We can reopen/revisit if and when we have greater community consensus. |
In conjunction with the Array API, users often need some form of random number generation as well. Since this would generate arrays and be used with arrays the user has, there is some benefit from the downstream users' perspective if there is a standard API for handling random number generation.
For example, in scikit-learn, the following discussion ( scikit-learn/scikit-learn#22352 (comment) ) may shed some light. A couple of points that came up were whether to have a stateful or stateless API, and NumPy's old and new APIs.