From 0e305da2a50ffc5f5c52833f0dee55fe41c37579 Mon Sep 17 00:00:00 2001 From: Roy Frostig Date: Tue, 8 Jun 2021 11:16:33 -0700 Subject: [PATCH] introduce an indirect PRNG key type to support customization, implement the default PRNG --- jax/_src/prng.py | 87 ++++++++++++++--------- jax/_src/random.py | 114 ++++++++++++++++++++++-------- jax/experimental/jax2tf/jax2tf.py | 2 +- jax/prng.py | 20 ++++++ jax/random.py | 11 +-- tests/api_test.py | 5 +- tests/pmap_test.py | 3 +- tests/random_test.py | 30 ++++---- 8 files changed, 184 insertions(+), 88 deletions(-) create mode 100644 jax/prng.py diff --git a/jax/_src/prng.py b/jax/_src/prng.py index 26018b0acae4..1f051bfb00f3 100644 --- a/jax/_src/prng.py +++ b/jax/_src/prng.py @@ -20,6 +20,7 @@ from jax import lax from jax import core from jax import numpy as jnp +from jax import tree_util from jax._src.api import jit from jax._src.numpy.lax_numpy import asarray from jax.lib import xla_bridge @@ -33,7 +34,50 @@ UINT_DTYPES = {8: jnp.uint8, 16: jnp.uint16, 32: jnp.uint32, 64: jnp.uint64} -def PRNGKey(seed: int) -> jnp.ndarray: +@tree_util.register_pytree_node_class +class PRNGKey: + """Represents a PRNG key or batch thereof.""" + + key: jnp.ndarray + + def __init__(self, key: jnp.ndarray): + # key might be a dummy object due to tree_unflatten + ndim = getattr(key, 'ndim', 1) + dtype = getattr(key, 'dtype', np.uint32) + if ndim < 1 or dtype != np.uint32: + raise TypeError( + f'invalid prng key or key batch: ndim = {ndim}, dtype = {dtype}') + self.key = key + + def tree_flatten(self): + return (self.key,), None + + @classmethod + def tree_unflatten(cls, _, key): + key, = key + return cls(key) + + def fold_in(self, data: int) -> 'PRNGKey': + return PRNGKey(_fold_in(self.key, data)) + + def random_bits(self, bit_width, shape) -> jnp.ndarray: + return _random_bits(self.key, bit_width, shape) + + def split(self, num: int) -> 'PRNGKey': + return PRNGKey(_split(self.key, num)) + + def __iter__(self): + assert self.key.ndim > 0 + if self.key.ndim == 1: + raise TypeError('iteration over a single PRNG key') + return (PRNGKey(k) for k in self.key.__iter__()) + + +def make_prng_key(seed: int) -> PRNGKey: + return PRNGKey(_threefry_prng_key(seed)) + + +def _threefry_prng_key(seed: int) -> jnp.ndarray: """Create a pseudo-random number generator (PRNG) key given an integer seed. Args: @@ -68,7 +112,6 @@ def _is_prng_key(key: jnp.ndarray) -> bool: except AttributeError: return False - def _make_rotate_left(dtype): if not jnp.issubdtype(dtype, np.integer): raise TypeError("_rotate_left only accepts integer dtypes.") @@ -238,47 +281,27 @@ def threefry_2x32(keypair, count): return lax.reshape(out[:-1] if odd_size else out, count.shape) -def split(key: jnp.ndarray, num: int = 2) -> jnp.ndarray: - """Splits a PRNG key into `num` new keys by adding a leading axis. - - Args: - key: a PRNGKey (an array with shape (2,) and dtype uint32). - num: optional, a positive integer indicating the number of keys to produce - (default 2). - - Returns: - An array with shape (num, 2) and dtype uint32 representing `num` new keys. 
- """ - return _split(key, int(num)) # type: ignore +def _split(key: jnp.ndarray, num: int) -> jnp.ndarray: + return _threefry_split(key, int(num)) # type: ignore @partial(jit, static_argnums=(1,)) -def _split(key, num) -> jnp.ndarray: +def _threefry_split(key, num) -> jnp.ndarray: counts = lax.iota(np.uint32, num * 2) return lax.reshape(threefry_2x32(key, counts), (num, 2)) -def fold_in(key: jnp.ndarray, data: int) -> jnp.ndarray: - """Folds in data to a PRNG key to form a new PRNG key. - - Args: - key: a PRNGKey (an array with shape (2,) and dtype uint32). - data: a 32bit integer representing data to be folded in to the key. - - Returns: - A new PRNGKey that is a deterministic function of the inputs and is - statistically safe for producing a stream of new pseudo-random values. - """ - return _fold_in(key, jnp.uint32(data)) +def _fold_in(key: jnp.ndarray, data: int) -> jnp.ndarray: + return _threefry_fold_in(key, jnp.uint32(data)) @jit -def _fold_in(key, data): - return threefry_2x32(key, PRNGKey(data)) +def _threefry_fold_in(key, data): + return threefry_2x32(key, _threefry_prng_key(data)) @partial(jit, static_argnums=(1, 2)) -def _random_bits(key, bit_width, shape): +def _random_bits(key: jnp.ndarray, bit_width, shape): """Sample uniform random bits of given width and shape using PRNG key.""" if not _is_prng_key(key): raise TypeError("_random_bits got invalid prng key.") @@ -291,7 +314,7 @@ def _random_bits(key, bit_width, shape): raise ValueError(f"The shape of axis {name} was specified as {size}, " f"but it really is {real_size}") axis_index = lax.axis_index(name) - key = fold_in(key, axis_index) + key = _fold_in(key, axis_index) size = prod(shape.positional) max_count = int(np.ceil(bit_width * size / 32)) @@ -299,7 +322,7 @@ def _random_bits(key, bit_width, shape): if not nblocks: bits = threefry_2x32(key, lax.iota(np.uint32, rem)) else: - *subkeys, last_key = split(key, nblocks + 1) + *subkeys, last_key = _split(key, nblocks + 1) blocks = [threefry_2x32(k, lax.iota(np.uint32, jnp.iinfo(np.uint32).max)) for k in subkeys] last = threefry_2x32(last_key, lax.iota(np.uint32, rem)) diff --git a/jax/_src/random.py b/jax/_src/random.py index 399b3ebb81b8..5336267f39e0 100644 --- a/jax/_src/random.py +++ b/jax/_src/random.py @@ -23,10 +23,10 @@ from jax import core from jax import numpy as jnp from jax._src import dtypes +from jax._src import prng from jax.core import NamedShape from jax._src.api import jit, vmap from jax._src.numpy.lax_numpy import _constant_like, _convert_and_clip_integer -from jax._src.prng import PRNGKey, fold_in, _random_bits, split, UINT_DTYPES from jax.lib import xla_bridge from jax.numpy.linalg import cholesky, svd, eigh from jax.interpreters import ad @@ -43,7 +43,7 @@ DTypeLikeInt = Any DTypeLikeFloat = Any - +UINT_DTYPES = prng.UINT_DTYPES ### utilities @@ -57,6 +57,57 @@ def _asarray(x): return jnp.asarray(x) +def _random_bits(key: prng.PRNGKey, bit_width, shape) -> jnp.ndarray: + return key.random_bits(bit_width, shape) + + +### key operations + + +def PRNGKey(seed: int) -> prng.PRNGKey: + """Create a pseudo-random number generator (PRNG) key given an integer seed. + + Args: + seed: a 64- or 32-bit integer used as the value of the key. + + Returns: + A PRNG key, which is modeled as an array of shape (2,) and dtype uint32. The + key is constructed from a 64-bit seed by effectively bit-casting to a pair + of uint32 values (or from a 32-bit seed by first padding out with zeros). 
+ """ + # TODO: update doc + return prng.make_prng_key(seed) + + +def fold_in(key: prng.PRNGKey, data: int) -> prng.PRNGKey: + """Folds in data to a PRNG key to form a new PRNG key. + + Args: + key: a PRNGKey (an array with shape (2,) and dtype uint32). + data: a 32bit integer representing data to be folded in to the key. + + Returns: + A new PRNGKey that is a deterministic function of the inputs and is + statistically safe for producing a stream of new pseudo-random values. + """ + # TODO: update doc + return key.fold_in(jnp.uint32(data)) + + +def split(key: prng.PRNGKey, num: int = 2) -> prng.PRNGKey: + """Splits a PRNG key into `num` new keys by adding a leading axis. + + Args: + key: a PRNGKey (an array with shape (2,) and dtype uint32). + num: optional, a positive integer indicating the number of keys to produce + (default 2). + + Returns: + An array with shape (num, 2) and dtype uint32 representing `num` new keys. + """ + # TODO: update doc + return key.split(num) + ### random samplers @@ -73,7 +124,7 @@ def _check_shape(name, shape: Union[Sequence[int], NamedShape], *param_shapes): raise ValueError(msg.format(name, shape_, shape)) -def uniform(key: jnp.ndarray, +def uniform(key: prng.PRNGKey, shape: Union[Sequence[int], NamedShape] = (), dtype: DTypeLikeFloat = dtypes.float_, minval: RealArray = 0., @@ -131,7 +182,7 @@ def _uniform(key, shape, dtype, minval, maxval) -> jnp.ndarray: lax.reshape(floats * (maxval - minval) + minval, shape.positional)) -def randint(key: jnp.ndarray, +def randint(key: prng.PRNGKey, shape: Sequence[int], minval: IntegerArray, maxval: IntegerArray, @@ -217,7 +268,7 @@ def _randint(key, shape, minval, maxval, dtype): return lax.add(minval, lax.convert_element_type(random_offset, dtype)) -def shuffle(key: jnp.ndarray, x: Array, axis: int = 0) -> jnp.ndarray: +def shuffle(key: prng.PRNGKey, x: Array, axis: int = 0) -> jnp.ndarray: """Shuffle the elements of an array uniformly at random along an axis. Args: @@ -234,7 +285,7 @@ def shuffle(key: jnp.ndarray, x: Array, axis: int = 0) -> jnp.ndarray: return _shuffle(key, x, axis) # type: ignore -def permutation(key: jnp.ndarray, x: Array) -> jnp.ndarray: +def permutation(key: prng.PRNGKey, x: Array) -> jnp.ndarray: """ Permute elements of an array along its first axis or return a permuted range. @@ -290,7 +341,7 @@ def _shuffle(key, x, axis) -> jnp.ndarray: return x -def choice(key: jnp.ndarray, +def choice(key: prng.PRNGKey, a: IntegerArray, shape: Sequence[int] = (), replace: bool = True, @@ -354,7 +405,7 @@ def choice(key: jnp.ndarray, return result.reshape(shape) -def normal(key: jnp.ndarray, +def normal(key: prng.PRNGKey, shape: Union[Sequence[int], NamedShape] = (), dtype: DTypeLikeFloat = dtypes.float_) -> jnp.ndarray: """Sample standard normal random values with given shape and float dtype. 
@@ -398,7 +449,7 @@ def _normal_real(key, shape, dtype) -> jnp.ndarray: return np.array(np.sqrt(2), dtype) * lax.erf_inv(u) -def multivariate_normal(key: jnp.ndarray, +def multivariate_normal(key: prng.PRNGKey, mean: RealArray, cov: RealArray, shape: Optional[Sequence[int]] = None, @@ -466,7 +517,7 @@ def _multivariate_normal(key, mean, cov, shape, dtype, method) -> jnp.ndarray: return mean + jnp.einsum('...ij,...j->...i', factor, normal_samples) -def truncated_normal(key: jnp.ndarray, +def truncated_normal(key: prng.PRNGKey, lower: RealArray, upper: RealArray, shape: Optional[Union[Sequence[int], NamedShape]] = None, @@ -523,7 +574,7 @@ def _truncated_normal(key, lower, upper, shape, dtype) -> jnp.ndarray: lax.nextafter(lax.stop_gradient(upper), np.array(-np.inf, dtype=dtype))) -def bernoulli(key: jnp.ndarray, +def bernoulli(key: prng.PRNGKey, p: RealArray = np.float32(0.5), shape: Optional[Union[Sequence[int], NamedShape]] = None) -> jnp.ndarray: """Sample Bernoulli random values with given shape and mean. @@ -560,7 +611,7 @@ def _bernoulli(key, p, shape) -> jnp.ndarray: return uniform(key, shape, lax.dtype(p)) < p -def beta(key: jnp.ndarray, +def beta(key: prng.PRNGKey, a: RealArray, b: RealArray, shape: Optional[Sequence[int]] = None, @@ -607,7 +658,7 @@ def _beta(key, a, b, shape, dtype): return gamma_a / (gamma_a + gamma_b) -def cauchy(key: jnp.ndarray, +def cauchy(key: prng.PRNGKey, shape: Sequence[int] = (), dtype: DTypeLikeFloat = dtypes.float_) -> jnp.ndarray: """Sample Cauchy random values with given shape and float dtype. @@ -637,7 +688,7 @@ def _cauchy(key, shape, dtype): return lax.tan(lax.mul(pi, lax.sub(u, _constant_like(u, 0.5)))) -def dirichlet(key: jnp.ndarray, +def dirichlet(key: prng.PRNGKey, alpha: RealArray, shape: Optional[Sequence[int]] = None, dtype: DTypeLikeFloat = dtypes.float_) -> jnp.ndarray: @@ -684,7 +735,7 @@ def _dirichlet(key, alpha, shape, dtype): return gamma_samples / jnp.sum(gamma_samples, axis=-1, keepdims=True) -def exponential(key: jnp.ndarray, +def exponential(key: prng.PRNGKey, shape: Sequence[int] = (), dtype: DTypeLikeFloat = dtypes.float_) -> jnp.ndarray: """Sample Exponential random values with given shape and float dtype. @@ -783,9 +834,9 @@ def _gamma_impl(key, a, use_vmap=False): a_shape = jnp.shape(a) # split key to match the shape of a key_ndim = jnp.ndim(key) - 1 - key = jnp.reshape(key, (-1, 2)) + key = prng.PRNGKey(jnp.reshape(key, (-1, 2))) key = vmap(split, in_axes=(0, None))(key, prod(a_shape[key_ndim:])) - keys = jnp.reshape(key, (-1, 2)) + keys = prng.PRNGKey(jnp.reshape(key.key, (-1, 2))) alphas = jnp.reshape(a, -1) if use_vmap: samples = vmap(_gamma_one)(keys, alphas) @@ -814,7 +865,7 @@ def _gamma_batching_rule(batched_args, batch_dims): multiple_results=False) batching.primitive_batchers[random_gamma_p] = _gamma_batching_rule -def gamma(key: jnp.ndarray, +def gamma(key: prng.PRNGKey, a: RealArray, shape: Optional[Sequence[int]] = None, dtype: DTypeLikeFloat = dtypes.float_) -> jnp.ndarray: @@ -834,13 +885,16 @@ def gamma(key: jnp.ndarray, A random array with the specified dtype and with shape given by ``shape`` if ``shape`` is not None, or else by ``a.shape``. 
""" + if not isinstance(key, prng.PRNGKey): + raise NotImplementedError( + f'`gamma` only supported for the default PRNG, not f{type(key)}') if not dtypes.issubdtype(dtype, np.floating): raise ValueError(f"dtype argument to `gamma` must be a float " f"dtype, got {dtype}") dtype = dtypes.canonicalize_dtype(dtype) if shape is not None: shape = core.canonicalize_shape(shape) - return _gamma(key, a, shape, dtype) + return _gamma(key.key, a, shape, dtype) @partial(jit, static_argnums=(2, 3)) def _gamma(key, a, shape, dtype): @@ -942,7 +996,7 @@ def _poisson(key, lam, shape, dtype): return lax.select(lam == 0, jnp.zeros_like(result), result) -def poisson(key: jnp.ndarray, +def poisson(key: prng.PRNGKey, lam: RealArray, shape: Sequence[int] = (), dtype: DTypeLikeInt = dtypes.int_) -> jnp.ndarray: @@ -967,7 +1021,7 @@ def poisson(key: jnp.ndarray, return _poisson(key, lam, shape, dtype) -def gumbel(key: jnp.ndarray, +def gumbel(key: prng.PRNGKey, shape: Sequence[int] = (), dtype: DTypeLikeFloat = dtypes.float_) -> jnp.ndarray: """Sample Gumbel random values with given shape and float dtype. @@ -996,7 +1050,7 @@ def _gumbel(key, shape, dtype): uniform(key, shape, dtype, minval=jnp.finfo(dtype).tiny, maxval=1.))) -def categorical(key: jnp.ndarray, +def categorical(key: prng.PRNGKey, logits: RealArray, axis: int = -1, shape: Optional[Sequence[int]] = None) -> jnp.ndarray: @@ -1030,7 +1084,7 @@ def categorical(key: jnp.ndarray, return jnp.argmax(gumbel(key, sample_shape + logits.shape, logits.dtype) + logits, axis=axis) -def laplace(key: jnp.ndarray, +def laplace(key: prng.PRNGKey, shape: Sequence[int] = (), dtype: DTypeLikeFloat = dtypes.float_) -> jnp.ndarray: """Sample Laplace random values with given shape and float dtype. @@ -1060,7 +1114,7 @@ def _laplace(key, shape, dtype): return lax.mul(lax.sign(u), lax.log1p(lax.neg(lax.abs(u)))) -def logistic(key: jnp.ndarray, +def logistic(key: prng.PRNGKey, shape: Sequence[int] = (), dtype: DTypeLikeFloat = dtypes.float_) -> jnp.ndarray: """Sample logistic random values with given shape and float dtype. @@ -1089,7 +1143,7 @@ def _logistic(key, shape, dtype): return lax.log(lax.div(x, lax.sub(lax._const(x, 1), x))) -def pareto(key: jnp.ndarray, +def pareto(key: prng.PRNGKey, b: RealArray, shape: Optional[Sequence[int]] = None, dtype: DTypeLikeFloat = dtypes.float_) -> jnp.ndarray: @@ -1129,7 +1183,7 @@ def _pareto(key, b, shape, dtype): return lax.exp(e / b) -def t(key: jnp.ndarray, +def t(key: prng.PRNGKey, df: RealArray, shape: Sequence[int] = (), dtype: DTypeLikeFloat = dtypes.float_) -> jnp.ndarray: @@ -1172,7 +1226,7 @@ def _t(key, df, shape, dtype): return n * jnp.sqrt(half_df / g) -def rademacher(key: jnp.ndarray, +def rademacher(key: prng.PRNGKey, shape: Sequence[int], dtype: DTypeLikeInt = dtypes.int_) -> jnp.ndarray: """Sample from a Rademacher distribution. @@ -1198,7 +1252,7 @@ def _rademacher(key, shape, dtype): return (2 * bernoulli_samples - 1).astype(dtype) -def maxwell(key: jnp.ndarray, +def maxwell(key: prng.PRNGKey, shape: Sequence[int] = (), dtype: DTypeLikeFloat = dtypes.float_) -> jnp.ndarray: """Sample from a one sided Maxwell distribution. 
@@ -1231,7 +1285,7 @@ def _maxwell(key, shape, dtype): return jnp.linalg.norm(norm_rvs, axis=-1) -def double_sided_maxwell(key: jnp.ndarray, +def double_sided_maxwell(key: prng.PRNGKey, loc: RealArray, scale: RealArray, shape: Sequence[int] = (), @@ -1276,7 +1330,7 @@ def _double_sided_maxwell(key, loc, scale, shape, dtype): return random_sign * maxwell_rvs * scale + loc -def weibull_min(key: jnp.ndarray, +def weibull_min(key: prng.PRNGKey, scale: RealArray, concentration: RealArray, shape: Sequence[int] = (), diff --git a/jax/experimental/jax2tf/jax2tf.py b/jax/experimental/jax2tf/jax2tf.py index 2f83224cec80..1a7b2a031060 100644 --- a/jax/experimental/jax2tf/jax2tf.py +++ b/jax/experimental/jax2tf/jax2tf.py @@ -2016,7 +2016,7 @@ def _threefry2x32_jax_impl(*args: TfVal, _in_avals, _out_aval): return res -tf_impl_with_avals[jax.random.threefry2x32_p] = _threefry2x32_jax_impl +tf_impl_with_avals[jax.prng.threefry2x32_p] = _threefry2x32_jax_impl # Use the vmap implementation, otherwise on TPU the performance is really bad # With use_vmap=True on, we get about the same performance for JAX and jax2tf. diff --git a/jax/prng.py b/jax/prng.py new file mode 100644 index 000000000000..dcb79b45cc37 --- /dev/null +++ b/jax/prng.py @@ -0,0 +1,20 @@ +# Copyright 2021 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +# flake8: noqa: F401 + +from jax._src.prng import ( + threefry2x32_p, + threefry_2x32, +) diff --git a/jax/random.py b/jax/random.py index 369679564755..316b8bd039dc 100644 --- a/jax/random.py +++ b/jax/random.py @@ -78,15 +78,8 @@ # flake8: noqa: F401 -from jax._src.prng import ( - PRNGKey, - fold_in, - split, - threefry2x32_p, - threefry_2x32, -) - from jax._src.random import ( + PRNGKey, bernoulli, beta, categorical, @@ -95,6 +88,7 @@ dirichlet, double_sided_maxwell, exponential, + fold_in, gamma, gumbel, laplace, @@ -109,6 +103,7 @@ randint, random_gamma_p, shuffle, + split, t, truncated_normal, uniform, diff --git a/tests/api_test.py b/tests/api_test.py index af5abe709b7d..9bd0c916af47 100644 --- a/tests/api_test.py +++ b/tests/api_test.py @@ -512,10 +512,11 @@ def init(): key_list[0] = key return jax.random.normal(subkey, ()) - key_list[0] = np.array([2384771982, 3928867769], dtype=np.uint32) + key_list[0] = jax._src.prng.PRNGKey( + np.array([2384771982, 3928867769], dtype=np.uint32)) init() self.jit(init)() - self.assertIsInstance(key_list[0], core.Tracer) + self.assertIsInstance(key_list[0].key, core.Tracer) def test_jit_wrapped_attributes(self): def f(x: int) -> int: diff --git a/tests/pmap_test.py b/tests/pmap_test.py index faae79413660..024dc705760c 100644 --- a/tests/pmap_test.py +++ b/tests/pmap_test.py @@ -1048,7 +1048,8 @@ def g(key): @vmap def s(keys): - keys = jnp.broadcast_to(keys, (N_DEVICES,) + keys.shape) + keys = jax._src.prng.PRNGKey( + jnp.broadcast_to(keys.key, (N_DEVICES,) + keys.key.shape)) return g(keys) ans = s(keys) # doesn't crash diff --git a/tests/random_test.py b/tests/random_test.py index c3c64a6a04ca..2d0ee1fa513d 100644 --- a/tests/random_test.py +++ b/tests/random_test.py @@ -30,6 +30,7 @@ from jax import grad from jax import lax from jax import numpy as jnp +from jax import prng from jax import random from jax import test_util as jtu from jax import vmap @@ -94,23 +95,23 @@ def result_to_hex(result): return tuple([hex(x.copy()).rstrip("L") for x in result]) expected = ("0x6b200159", "0x99ba4efe") - result = random.threefry_2x32(np.uint32([0, 0]), np.uint32([0, 0])) + result = prng.threefry_2x32(np.uint32([0, 0]), np.uint32([0, 0])) self.assertEqual(expected, result_to_hex(result)) expected = ("0x1cb996fc", "0xbb002be7") - result = random.threefry_2x32(np.uint32([-1, -1]), np.uint32([-1, -1])) + result = prng.threefry_2x32(np.uint32([-1, -1]), np.uint32([-1, -1])) self.assertEqual(expected, result_to_hex(result)) expected = ("0xc4923a9c", "0x483df7a0") - result = random.threefry_2x32( + result = prng.threefry_2x32( np.uint32([0x13198a2e, 0x03707344]), np.uint32([0x243f6a88, 0x85a308d3])) self.assertEqual(expected, result_to_hex(result)) def testThreefry2x32Large(self): n = 10000000 - result = random.threefry_2x32( + result = prng.threefry_2x32( (np.uint32(0x13198a2e), np.uint32(0x03707344)), jnp.concatenate([ jnp.full((n,), 0x243f6a88, jnp.uint32), @@ -122,7 +123,7 @@ def testThreefry2x32Large(self): def testThreefry2x32Empty(self): # Regression test for an op-by-op crash for empty arrays in CUDA mode. 
with api.disable_jit(): - result = random.threefry_2x32( + result = prng.threefry_2x32( (np.uint32(0x13198a2e), np.uint32(0x03707344)), jnp.ones((10, 0,), jnp.uint32)) np.testing.assert_equal(result, np.zeros((10, 0,), dtype=np.uint32)) @@ -721,13 +722,13 @@ def testIssue222(self): def testFoldIn(self): key = random.PRNGKey(0) - keys = [random.fold_in(key, i) for i in range(10)] + keys = [random.fold_in(key, i).key for i in range(10)] assert np.unique(np.ravel(keys)).shape == (20,) def testFoldInBig(self): key = random.PRNGKey(0) seeds = [2 ** 32 - 2, 2 ** 32 - 1] - keys = [random.fold_in(key, seed) for seed in seeds] + keys = [random.fold_in(key, seed).key for seed in seeds] assert np.unique(np.ravel(keys)).shape == (4,) def testStaticShapeErrors(self): @@ -765,7 +766,7 @@ def testNoOpByOpUnderHash(self): def fail(*args, **kwargs): assert False apply_primitive, xla.apply_primitive = xla.apply_primitive, fail try: - _ = random.threefry_2x32(np.zeros(2, np.uint32), np.arange(10, dtype=np.uint32)) + _ = prng.threefry_2x32(np.zeros(2, np.uint32), np.arange(10, dtype=np.uint32)) finally: xla.apply_primitive = apply_primitive @@ -787,14 +788,14 @@ def testPRNGValues(self): [6, 3, 4]], dtype='int32')) self.assertAllClose( - random.split(k, 4), + random.split(k, 4).key, np.array([[2285895361, 1501764800], [1518642379, 4090693311], [ 433833334, 4221794875], [ 839183663, 3740430601]], dtype='uint32')) self.assertAllClose( - random.fold_in(k, 4), + random.fold_in(k, 4).key, np.array([2285895361, 433833334], dtype='uint32')) def testDtypeErrorMessage(self): @@ -961,9 +962,9 @@ def test_prng_seeds_and_keys(self, seed, type, jit, key): self.skipTest("Expected failure: integer out of range for jit.") seed = type(seed) if jit: - actual = api.jit(random.PRNGKey)(seed) + actual = api.jit(random.PRNGKey)(seed).key else: - actual = random.PRNGKey(seed) + actual = random.PRNGKey(seed).key expected = jnp.array(key, dtype=jnp.uint32) self.assertArraysEqual(actual, expected) @@ -978,7 +979,8 @@ def test_prng_jit_invariance(self, seed, type): self.skipTest("Expected failure: Python int too large.") type = {"int": int, "np.array": np.array, "jnp.array": jnp.array}[type] args_maker = lambda: [type(seed)] - self._CompileAndCheck(random.PRNGKey, args_maker) + make_prng = lambda seed: random.PRNGKey(seed).key + self._CompileAndCheck(make_prng, args_maker) def test_prng_errors(self): seed = np.iinfo(np.int64).max + 1 @@ -988,7 +990,7 @@ def test_prng_errors(self): api.jit(random.PRNGKey)(seed) def test_random_split_doesnt_device_put_during_tracing(self): - key = random.PRNGKey(1).block_until_ready() + key = random.PRNGKey(1).key.block_until_ready() with jtu.count_device_put() as count: api.jit(random.split)(key) self.assertEqual(count[0], 1) # 1 for the argument device_put
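
Usage sketch: how the user-facing `jax.random` API reads after this patch, assuming the wrapper class and module paths exactly as introduced above (illustrative, not part of the diff). Keys returned by `jax.random.PRNGKey` are now `jax._src.prng.PRNGKey` instances rather than raw arrays; the underlying uint32 array is available as `.key`.

    import numpy as np
    import jax
    from jax import random

    key = random.PRNGKey(0)                 # wraps a shape-(2,) uint32 array
    assert key.key.shape == (2,) and key.key.dtype == np.uint32

    subkey1, subkey2 = random.split(key)    # unpacking goes through PRNGKey.__iter__
    x = random.normal(subkey1, (3,))        # samplers accept the wrapped key
    loop_key = random.fold_in(key, 7)       # new key, deterministic in (key, 7)

    # The wrapper is a registered pytree, so keys flow through jit/vmap.
    y = jax.jit(lambda k: random.normal(k, ()))(subkey2)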
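
Pytree behavior sketch: the wrapper registered in jax/_src/prng.py flattens to its underlying uint32 array, so transformations only ever see that array as a leaf. Illustrative, based on the tree_flatten/tree_unflatten methods added in this patch:

    from jax import tree_util
    from jax._src import prng

    k = prng.make_prng_key(42)
    leaves, treedef = tree_util.tree_flatten(k)
    assert len(leaves) == 1 and leaves[0].shape == (2,)

    k2 = tree_util.tree_unflatten(treedef, leaves)
    assert isinstance(k2, prng.PRNGKey)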