Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

support custom PRNG implementations #6899

Merged
merged 6 commits into from
Aug 20, 2021
Merged

support custom PRNG implementations #6899

merged 6 commits into from
Aug 20, 2021

Conversation

froystig
Copy link
Member

@froystig froystig commented Jun 4, 2021

Today jax offers a single splittable PRNG implementation (threefry 2x32) and models its key as a uint32[2]. A batch of keys is a uint32 array whose trailing dimension comprises the length-2 keys.

With this PR, we have two goals:

  1. To encapsulate the "raw array" key representation, in order to prevent unsafe or invalid key operations.
  2. To enable custom PRNG implementations with custom key representations. We want every function in the jax.random to work even when seeding, splitting, and random bit generation is defined by user code.

An added challenge of this upgrade is that the previous "raw array" key model has been around for a long time, and many downstream uses have overfit to it. To varying extents, flax, haiku, trax, and TFP all validate user-provided random keys as uint32[2] arrays. We'd like to obviate the need for such checks. Other downstream code creates keys directly with hard-coded values (including a few of our own tests), reshapes key batches (including our gamma primitive implementation), or casts other program values to keys. For safety, we would like to disallow such operations.

I'm introducing this change gradually to allow time for downstream fixes. I've tried to do this by recovering existing behavior under a flag that is, for now, on by default.

design

Again, two goals:

  1. to abstract away key representations, and
  2. to respect custom PRNGs throughout jax.random.

Item 1 is scoped only to the JAX Python level; changing key representation in the Jaxpr IR is not a goal. This PR requires that keys remain backed by uint32 arrays, now of arbitrary specified shape. We can broaden that later as needed. They will therefore appear as uint32 arrays in Jaxpr.

A custom PRNG implementation amounts to a key type K (determined by a shape) plus a handful of functions operating on such a key:

class PRNGImpl where:
  key_shape :: (int, ...)
  seed :: int[] -> K
  fold_in :: K -> int[] -> K
  split[n] :: K -> K[n]
  random_bits[shape, bit_width] :: K -> uint<bit_width>[shape]

An example, also used for unit tests introduced in this PR, is the following cartoon PRNG built from two threefry2x32 keys. To define it:

import jax
from jax import prng, numpy as jnp

threefry_seed = jax._src.prng.threefry_seed
threefry_split = jax._src.prng.threefry_split
threefry_random_bits = jax._src.prng.threefry_random_bits
threefry_fold_in = jax._src.prng.threefry_fold_in

def _double_threefry_seed(seed):
  return jnp.vstack([threefry_seed(seed + 1),
                     threefry_seed(seed + 2)])

def _double_threefry_split(key, num):
  split0 = threefry_split(key[0], num)
  split1 = threefry_split(key[1], num)
  merge = jnp.vstack([jnp.expand_dims(split0.T, axis=0),
                      jnp.expand_dims(split1.T, axis=0)])
  return merge.transpose((2, 0, 1))

def _double_threefry_random_bits(key, bit_width, shape):
  bits0 = threefry_random_bits(key[0], bit_width, shape)
  bits1 = threefry_random_bits(key[1], bit_width, shape)
  return bits0 * bits1

def _double_threefry_fold_in(key, data):
  return jnp.vstack([threefry_fold_in(key[0], data),
                     threefry_fold_in(key[1], data)])

double_threefry_prng_impl = prng.PRNGImpl(
    key_shape=(2, 2),
    seed=_double_threefry_seed,
    split=_double_threefry_split,
    random_bits=_double_threefry_random_bits,
    fold_in=_double_threefry_fold_in)

And to use it:

>>> seed_value = 73
>>> key = prng.seed_with_impl(double_threefry_prng_impl, seed_value)
>>> key
PRNGKeyArray:
  shape = ()
  impl = PRNGImpl:
           key_shape = (2, 2)
           seed = <function _double_threefry_seed at 0x7f465e44bee0>
           split = <function _double_threefry_split at 0x7f465e44bb80>
           random_bits = <function _double_threefry_random_bits at 0x7f465e44ba60>
           fold_in = <function _double_threefry_fold_in at 0x7f46227e1160>

>>> key, subkey = jax.random.split(key)
>>> jax.random.uniform(subkey)
DeviceArray(0.48681986, dtype=float32)

>>> jax.random.split(key, 10)
PRNGKeyArray:
  shape = (10,)
  impl = PRNGImpl:
           key_shape = (2, 2)
           seed = <function _double_threefry_seed at 0x7f465e44bee0>
           split = <function _double_threefry_split at 0x7f465e44bb80>
           random_bits = <function _double_threefry_random_bits at 0x7f465e44ba60>
           fold_in = <function _double_threefry_fold_in at 0x7f46227e1160>

>>> jax.vmap(lambda i: jax.random.fold_in(key, i))(jnp.arange(3))
PRNGKeyArray:
  shape = (3,)
  impl = PRNGImpl:
           key_shape = (2, 2)
           seed = <function _double_threefry_seed at 0x7f465e44bee0>
           split = <function _double_threefry_split at 0x7f465e44bb80>
           random_bits = <function _double_threefry_random_bits at 0x7f465e44ba60>
           fold_in = <function _double_threefry_fold_in at 0x7f46227e1160>

>>> key + 47
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
TypeError: unsupported operand type(s) for +: 'PRNGKeyArray' and 'int'

Note that the shape field of the PRNGKeyArray changes from the scalar () to (10,) when it is split 10 ways and to (3,) when it is broadcast by vmap.

Internally, this works as follows:

  • Users define a PRNG implementation impl with a key type K like the one above.
  • An internal pytree-node class PRNGKeyArray wraps the given PRNG implementation impl and lifts it to an array-like interface. This has roughly the semantics of an array of K-typed elements. Under the hood it is a uint32 array whose trailing dimensions correspond to K. It is array-like only in select ways, for instance supporting indexing and iteration but not numerical operations. Expressions such as key, subkey = jax.random.split(key) remain valid but key + 1 or jnp.exp(key) are no longer allowed.
  • The PRNGKeyArray carries the implementation impl around through jax.random functions, which then delegate to it for seed, split, random_bits, and fold_in.

Note that all new indirection happens at the Python level. This design makes no changes to Jaxpr, although the extra indirection could hopefully help make element-type changes in Jaxpr easier in the future. I discussed this with @LenaMartens and we think it will help clear the path for some Jaxpr-level RNG checking ideas in particular.

I've separately started thinking about whether to generalize the array-like wrapping concept to allow for lightweight custom-element-type arrays more broadly. If RNG keys are one special case, then arrays of unit, float0, enums, or arbitrary pytrees could be others. I mean "lightweight" both in the sense of "easy to set up" as well as "without necessarily extending Jaxpr."

temporary limitations

  • Both random.gamma, and jax.poisson are not yet implemented with custom PRNGs, since these have threefry2x32-specific implementations today that need to be generalized. I have them raise NotImplementedError when given a custom PRNG for now.

  • When fully enabled, this change breaks several of our neighboring libraries. To allow for smoother upgrades, I've retained existing behavior under a flag. Doing so introduces extra code and test complexity in the interim.

  • PRNGKeyArray may want to support more array-like methods in the future. I implemented an obvious initial handful for now.


Fixes #5081.

Relatedly addresses parts of #2294.

jax/_src/prng.py Outdated Show resolved Hide resolved
jax/_src/random.py Outdated Show resolved Hide resolved
jax/_src/prng.py Show resolved Hide resolved
jax/_src/random.py Outdated Show resolved Hide resolved
@froystig
Copy link
Member Author

froystig commented Jun 8, 2021

fyi @LenaMartens: this change prompted lifting PRNG keys to a custom type.

@froystig froystig added the pull ready Ready for copybara import and testing label Jun 9, 2021
@froystig
Copy link
Member Author

froystig commented Jun 10, 2021

Summarizing some discussion with @LenaMartens:

  • With the exception of our tests, we don't anticipate that raw keys today are manipulated or read outside of PRNG operations (e.g. in user code). Incidentally, this change helps make it a bit less likely to happen by accident, because the raw key is now a (private) attribute of the PRNGKey class.
  • This change does not defend against raw key use entirely, either. For example, the raw key can still be modified via tree_util.tree_map.
  • We think that this PR will help in setting up for future PRNG reuse checking work, by making it easy to swap the raw key attribute of a PRNG implementation (from a jnp.ndarray to whatever special key array type we design).

@PhilipVinc
Copy link
Contributor

PhilipVinc commented Jun 11, 2021

With the exception of our tests, we don't anticipate that raw keys today are manipulated or read outside of PRNG operations (e.g. in user code). Incidentally, this change helps make it a bit less likely to happen by accident, because the raw key is now a (private) attribute of the PRNGKey class.

I'm sorry to disappoint.
When running multi-process (MPI) code with mpi4jax we sometimes need to synchronise the seed among several jax processes and sometimes we need to split a key on a process and send the splitted-keys to different processes.

The former can be implemented by only syncing the seed, but I was lazy and was synchronising the whole PRNGKey.
The latter requires access to the memory underpinning the PRNGKey and I think would be impossible to achieve otherwise?

I'd be grateful if this use-case could be considered.

@froystig
Copy link
Member Author

I'm sorry to disappoint.
When running multi-process (MPI) code with mpi4jax we sometimes need to synchronise the seed among several jax processes and sometimes we need to split a key on a process and send the splitted-keys to different processes.

Thanks for pointing these out. I think the mpi4jax use cases survive this assumption, perhaps with minor changes.

Looking at the first expression you linked:

key, _ = mpi.mpi_bcast_jax(key, root=root, comm=comm)

If I understand correctly, mpi_bcast_jax seems to bottom out in a call to mpi4jax.bcast, which in turn binds a corresponding primitive mpi_bcast_p.

The PRNG key class is registered as a pytree node in this PR, meaning that its instances can be flattened into arrays with tree_util.tree_flatten. You could modify mpi4jax.bcast to follow the common pattern in many of JAX's own primitive-binding routines: first flatten the argument to a flat list of operand arrays, then bind the primitive on the result. See lax.cond for an example. Would that work?

Looking at the second expression: to generate "a PRNGKey depending on rank number and key," could you use fold_in(key, rank), instead of what this function currently does?

@PhilipVinc
Copy link
Contributor

You could modify mpi4jax.bcast to follow the common pattern in many of JAX's own primitive-binding routines: first flatten the argument to a flat list of operand arrays, then bind the primitive on the result.

Indeed that might be a good idea, thanks. (cc @dionhaefner )

Looking at the second expression: to generate "a PRNGKey depending on rank number and key," could you use fold_in(key, rank), instead of what this function currently does?

Hmm. So you mean that instead of splitting the key into n keys, one for each node, If I have the same key on all the nodes I could fold_in the rank inside of them? yes.
However that presupposes that I can have the same key everywhere.

Still, if I can bcast PRNGKeys around that should be sufficient to implement also the second use-case without assuming that keys are identical everywhere.

Thanks!

@froystig froystig marked this pull request as ready for review June 16, 2021 15:44
@froystig froystig force-pushed the custom-rng branch 2 times, most recently from 666114f to 9cebe89 Compare June 16, 2021 16:18
@froystig
Copy link
Member Author

@PhilipVinc – thanks, and yes, that sounds right. To clarify two things:

  1. Even if a pytree flattens only into a single array, tree_flatten/unflatten also handles reconstructing Python objects with the correct types. We don't have an immediate use case for multiple-array custom PRNGs. But with several PRNG implementations (each defined by a different class somewhere), mpi4jax will need to wrap raw key arrays in the right class instance after broadcasting. That's the main motivation; incidentally, it allows for multiple arrays too.

  2. Most of JAX's basic primitives do not flatten/unflatten because they accept only a fixed number of array operands (like dot). Flattening is more common in higher-order primitives (like cond and reduce) and in some parallel collectives like psum, since these accept arbitrary operands.

Copy link

@rahulpalamuttam rahulpalamuttam left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we have a canonical test for testing a custom prng function (either in api_test or random_test)? The intent would be to capture a small snippet describing how a user would supply a custom_rng generation function. We could even have the function just be a jax function that generates a bunch of zeros, or xors the input key with itself n number of times.

@froystig froystig force-pushed the custom-rng branch 6 times, most recently from fc7ab85 to 791158d Compare June 26, 2021 00:24
@gnecula
Copy link
Collaborator

gnecula commented Jul 1, 2021

If the design is clear, please add some notes in the PR description, so that people can then know how to review.

jax/_src/prng.py Outdated
return self.key.dtype


def make_prng_key(seed: int) -> PRNGKey:
Copy link
Contributor

@NeilGirdhar NeilGirdhar Jul 4, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm really looking forward to this change!

Just curious, but would it make more sense for this to be a class factory?

class PRNGKey:
  @classmethod
  def from_seed(cls: Type[T], seed: int) -> T:
    return cls(_threefry_prng_key(seed))

This has two possible advantages:

  • It would be one fewer thing to expose, and
  • If you choose to add it to the abstract RNG base class, it would be part of the interface.

Also, thanks for taking care to add most of the type annotations. (There are a couple missing though, PRNGKey.__iter__ is probably going to cause MyPy errors.)

  def __iter__(self) -> Iterator['PRNGKey']:

jax/_src/prng.py Outdated
# TODO(frostig): remove if possible, otherwise declare necessary
@property
def shape(self):
return self.key.shape
Copy link
Contributor

@NeilGirdhar NeilGirdhar Jul 4, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If you decide to keep this, then shouldn't the shape of the key be self.key.shape[:-1]? After all, this is what you've split the key into, and gives the shape of the vmapped function that this key works with. E.g., if you have vmap(f, in_args=(0,0))(rng, x), then rng.shape's leading coefficient should equal x.shape's leading coefficient. The "2" final coefficient seems like an implementation detail that a user of the class shouldn't have access to.

In my version of the Generator class, I never ended up needing this function though 😄

@froystig froystig force-pushed the custom-rng branch 2 times, most recently from 36479bd to 134a9da Compare August 15, 2021 17:41
@froystig
Copy link
Member Author

Thanks to everyone for the early comments. Previous versions of this PR were an experimental draft that bundled in several shortcuts to avoid breaking existing tests. I found time to implement the design I had in mind, together with a config flag for backwards compatibility and various warnings about future changes.

I've updated the PR description with design notes, including some remarks on temporary limitations as well.

Many early suggestions are out of date or addressed by now. I'll go back and reply to those that still stand out.

@froystig
Copy link
Member Author

@PhilipVinc - You might notice that another workaround now for turning keys into arrays for use as primitive operands is by reading key.keys instead of tree-mapping. Still, the latter is be safer and more future-proof. The keys attribute might still change.

``random_bits``, ``fold_in``).
"""

impl: PRNGImpl
Copy link
Contributor

@NeilGirdhar NeilGirdhar Aug 16, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't want to intrude, but what's the point of using the Pimpl idiom in Python? In C++, the benefits are a compilation firewall and binary compatibility—neither of which apply to Python. Wouldn't it be simpler, more idiomatic, and more efficient to just have abstract methods?

This would

  • simplify your flatten and unflatten methods by removing self.impl,
  • simplify your implementation by removing PRNGImpl,
  • simplify type annotations,
  • make it easier to have a hierarchy of behavior if the need arises,
  • and make it easier for users to define their own PRNGKeyArray.

Copy link
Member Author

@froystig froystig Aug 18, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's a natural question! I kicked around an idea like this earlier on.

I'm choosing to wrap-and-delegate because it directly expresses what I'm trying to enforce. PRNGKeyArray is a class made to adapt a PRNG implementation, which is expressed on an individual key, into an "array" of keys. It's an internal definition, not part of the API, and we have no need in jax for a hierarchy of behaviors around it.

If you're writing custom PRNGs from outside of jax, you're welcome to supply the PRNG interface however you like. It can be a member of whatever fun type hierarchy you've set up in your code, or just a plain record of four functions. You'll notice it need not even be a proper prng.PRNGImpl, but rather anything that duck-types as one. Hopefully that's plenty of flexibility, less for us to document, and less for anyone to learn.

To your specific points:

  • simplify your flatten and unflatten methods by removing self.impl,
  • simplify your implementation by removing PRNGImpl,

These seem like minimal savings, so I'm not concerned.

  • simplify type annotations,

How so?

  • make it easier to have a hierarchy of behavior if the need arises,

I don't foresee such a need. We can always revisit if/when it comes up.

  • and make it easier for users to define their own PRNGKeyArray.

I don't want them to!

What do you think? As someone who might use this, do these choices get in your way at all?

Copy link
Contributor

@NeilGirdhar NeilGirdhar Aug 18, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the exhaustive reply.

PRNGKeyArray is a class made to adapt a PRNG implementation, which is expressed on an individual key, into an "array" of keys. It's an internal definition, not part of the API, and we have no need in jax for a hierarchy of behaviors around it.

PRNGKeyArray is a member of the union KeyArray, which is passed to the public functions in random.py. So, PRNGKeyArray is part of the API, isn't it? How should a user annotate the generator that's produced by seed_with_impl and that's passed around to user functions?

About PRNGImpl, my point is that it could just be folded into PRNGKeyArray since it doesn't appear that you use it on its own anywhere except to produce a PRNGKeyArray. I do understand your point that you want to treat PRNGKeyArray as a wrapper, but it only ever wraps one thing, and the thing it wraps is only ever used within the wrapper.

we have no need in jax for a hierarchy of behaviors around it.

Fair enough.

simplify type annotations,
How so?

The functions in PRNGImpl are annotated as Callable. It would be nice to just define them as abstract methods so that they have full signatures. I guess I don't see the point of using function-valued members that you set instead of abstract methods that you implement in subclasses. Are you avoiding inheritance for some reason?

make it easier for users to define their own PRNGKeyArray.
I don't want them to!

One possible user-defined PRNGKeyArray subclasss you might find is a hypothetical Haiku HaikuKeyArray. It would simplify Haiku code to not have to extract keys from a PRNGKeySequence, and instead use the HaikuKeyArray, which would be passable to all the JAX random functions. It would be mutable, and it would hold a PRNGKeyArray internally, which it would split, just like the PRNGKeySequence currently does. Why make it more difficult to implement such a subclass?

do these choices get in your way at all?

They do not! My real motivation is to throw away my Generator class here with a superior solution from the JAX codebase 😄 I just thought it would be nicer to have one class rather than two, and to use abstract methods rather than function-valued members. In my code, I pass around Generator everywhere. I had hoped for a similar solution like JaxGenerator or something like that 😄

Anyway, I think I've expressed my points. Thanks for the reply. I'm sure whatever you choose will work great.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the exhaustive reply.

Of course – thanks for the same. I better understand your points now. It seems like there are pros and cons either way. I went ahead with this PR as is, but we can always adapt in the future, and it will be useful to have this discussion to refer back to.

PRNGKeyArray is a member of the union KeyArray, which is passed to the public functions in random.py. So, PRNGKeyArray is part of the API, isn't it? How should a user annotate the generator that's produced by seed_with_impl and that's passed around to user functions?

You're right – the type is publicly visible. What I had in mind was: users are not meant to instantiate this class directly, nor to interact with with it other than by (a) the array-like methods (which I want no one to override) and (b) passing it to random functions. It's supposed to simply look like an array for the most part. Good correction to my unqualified remark regardless.

About PRNGImpl, my point is that it could just be folded into PRNGKeyArray since it doesn't appear that you use it on its own anywhere except to produce a PRNGKeyArray. I do understand your point that you want to treat PRNGKeyArray as a wrapper, but it only ever wraps one thing, and the thing it wraps is only ever used within the wrapper.

I see now. Yes, that would save some indirection. On the other hand, PRNG implementations would involve slightly more wrapping/unwrapping boilerplate that the class currently handles.

They do not!

Excellent. That's the most important part to me. Everything else can be revised later if/as needed.

Copy link
Contributor

@LenaMartens LenaMartens left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM! Some small comments/request for clarification.

jax/_src/prng.py Show resolved Hide resolved
def dtype(self):
# TODO(frostig): remove after deprecation window
if config.jax_enable_custom_prng:
raise AttributeError("'PRNGKeyArray' has no attribute 'dtype'")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You might have already explained this to me, but why do we want to remove the ability to do key.dtype? Seems like in the current implementation this can only be unit32, regardless of the PRNGImpl.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure what to return as the dtype? If the purpose is to be an array of opaque keys as elements, then uint32 seems incorrect (that's the dtype of the data buffer representation of an individual key). I figured that dtype is one of the ways in which this array type is not a numpy or jax device array. Perhaps this is too cautious, and if so I'd be happy to revisit it.

@property
def _shape(self):
base_ndim = len(self.impl.key_shape)
return self.keys.shape[:-base_ndim]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we assert that self._shape is always equal to self.impl.key_shape? Do we ever expect them to be different?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The two are different: _shape should be the shape of the encompassing "array of keys," whereas impl.key_shape is the shape of an individual key. The underlying buffer holding all the key data is an array of shape (*self._shape, *self.impl.key_shape).

jax/_src/prng.py Outdated Show resolved Hide resolved
elif _arraylike(key):
if config.jax_enable_custom_prng:
warnings.warn(
'Raw arrays as random keys to jax.random functions are deprecated. '
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Will this warning not trigger for every user if the default is custom_prng=False? Maybe there should be a hint on how to resolve the warning if that is the case?

More generally, do we expect users to turn on this custom_prng flag themselves? As I understand it, it's only a small group of users who rely on the keys being jnp arrays, and for most users this will be a no-op change, so I'm not sure they should be aware of the flag. I might be wrong though!

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The config.jax_enable_custom_prng flag defaults to false and the warning is guarded by it, so that (I think) it should only trigger when the flag is explicitly enabled.

In general, I tried to ensure that no new warnings would appear when the flag is at its default value (meaning no upgrade). I checked this visually by running random_test.py and api_test.py.

I don't expect end users to turn this flag on. I hope to work with some neighboring libraries to upgrade by turning it on, seeing some warnings (rather than errors), and making changes to handle them. Then we can deprecate the flag altogether and remove all these checks and warnings, though I will probably flip the default value of the flag to True for a while first beforehand.

Does that seem right?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Seems good, thank you for trying to do this smoothly.

Speaking for NetKet, however, we enforce semver on our dependencies so if you were to follow it too you could also avoid doing all this.

We are upper bounding the minor version of jax to 0.2, so if you were to interpret this change as a breaking change (which IMO it is) and bump the next version of jax to 0.3 instead of 0.2.X, pip would avoid installing jax 0.3 with current versions of netket, and we would have the time to upgrade our code and release a new version that is compatible with jax 0.3.

I don't know if jax is following semver or not, however.
It really does make things easier for dependent packages.

A PRNG implementation is determined by a key shape and a set of basic
functions on such a key: seed, split, random_bits, and fold_in.

A PRNG implementation can then by lifted to an array-of-keys-like
object. Namely, a new internal pytree class PRNGKeyArray wraps the
implementation and maintains an array of keys of the right shape. This
array-like object is the new "key" that gets passed around the various
functions in the public random API (e.g. `random.uniform`,
`random.normal`, ...). So the PRNGKeyArray class really serves two
purposes at once:

1. To adapt key implementations into "arrays" of such keys.
2. To carry a reference to the PRNG implementation around and delegate
   back to it from the functions in random.
Some tests check the behavior of the random bit generator---in
particular the default threefry implementation---and some check the
behavior of samplers. Separate them into different test classes.
Now that `LaxRandomTest` only tests random functions and not any
specific PRNG, it can be reused to test random functions under
different PRNGs.
…conditionally

Introduce a config flag for upgrading to a world of custom PRNGs. The
flag defaults off, so that we can introduce custom PRNGs into the
codebase and allow downstream libraries time to upgrade.

Backwards compatible behavior is meant in an external sense. This does
not mean that our code is internally the same any longer.
@jakevdp
Copy link
Collaborator

jakevdp commented Aug 20, 2021

Nice work on this!

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
cla: yes pull ready Ready for copybara import and testing
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Allow pluggable PRNG implementations
9 participants