diff --git a/jax/_src/prng.py b/jax/_src/prng.py
index 26018b0acae4..66aaef789633 100644
--- a/jax/_src/prng.py
+++ b/jax/_src/prng.py
@@ -14,6 +14,7 @@
 
 from functools import partial
+from typing import Any
 
 import numpy as np
 
@@ -30,31 +31,23 @@ from jax._src.util import prod
 
+PRNG = Any
 UINT_DTYPES = {8: jnp.uint8, 16: jnp.uint16, 32: jnp.uint32, 64: jnp.uint64}
 
 
-def PRNGKey(seed: int) -> jnp.ndarray:
-  """Create a pseudo-random number generator (PRNG) key given an integer seed.
-
-  Args:
-    seed: a 64- or 32-bit integer used as the value of the key.
-
-  Returns:
-    A PRNG key, which is modeled as an array of shape (2,) and dtype uint32. The
-    key is constructed from a 64-bit seed by effectively bit-casting to a pair
-    of uint32 values (or from a 32-bit seed by first padding out with zeros).
-  """
+def threefry_init(seed: int) -> jnp.ndarray:
   # Avoid overflowerror in X32 mode by first converting ints to int64.
-  # This breaks JIT invariance of PRNGKey for large ints, but supports the
-  # common use-case of instantiating PRNGKey with Python hashes in X32 mode.
+  # This breaks JIT invariance of this init function for large ints,
+  # but supports the common use-case of calling it with Python hashes
+  # in X32 mode.
   if isinstance(seed, int):
     seed_arr = jnp.asarray(np.int64(seed))
   else:
     seed_arr = jnp.asarray(seed)
   if seed_arr.shape:
-    raise TypeError(f"PRNGKey seed must be a scalar; got {seed!r}.")
+    raise TypeError(f"PRNG seed must be a scalar; got {seed!r}.")
   if not np.issubdtype(seed_arr.dtype, np.integer):
-    raise TypeError(f"PRNGKey seed must be an integer; got {seed!r}")
+    raise TypeError(f"PRNG seed must be an integer; got {seed!r}")
 
   convert = lambda k: lax.reshape(lax.convert_element_type(k, np.uint32), [1])
   k1 = convert(lax.shift_right_logical(seed_arr, lax._const(seed_arr, 32)))
@@ -238,50 +231,30 @@ def threefry_2x32(keypair, count):
   return lax.reshape(out[:-1] if odd_size else out, count.shape)
 
 
-def split(key: jnp.ndarray, num: int = 2) -> jnp.ndarray:
-  """Splits a PRNG key into `num` new keys by adding a leading axis.
-
-  Args:
-    key: a PRNGKey (an array with shape (2,) and dtype uint32).
-    num: optional, a positive integer indicating the number of keys to produce
-      (default 2).
-
-  Returns:
-    An array with shape (num, 2) and dtype uint32 representing `num` new keys.
-  """
-  return _split(key, int(num))  # type: ignore
+def threefry_split(key: jnp.ndarray, num: int = 2) -> jnp.ndarray:
+  return _threefry_split(key, int(num))  # type: ignore
 
 
 @partial(jit, static_argnums=(1,))
-def _split(key, num) -> jnp.ndarray:
+def _threefry_split(key, num) -> jnp.ndarray:
   counts = lax.iota(np.uint32, num * 2)
   return lax.reshape(threefry_2x32(key, counts), (num, 2))
 
 
-def fold_in(key: jnp.ndarray, data: int) -> jnp.ndarray:
-  """Folds in data to a PRNG key to form a new PRNG key.
-
-  Args:
-    key: a PRNGKey (an array with shape (2,) and dtype uint32).
-    data: a 32bit integer representing data to be folded in to the key.
-
-  Returns:
-    A new PRNGKey that is a deterministic function of the inputs and is
-    statistically safe for producing a stream of new pseudo-random values.
- """ - return _fold_in(key, jnp.uint32(data)) +def threefry_fold_in(key: jnp.ndarray, data: int) -> jnp.ndarray: + return _threefry_fold_in(key, jnp.uint32(data)) @jit -def _fold_in(key, data): - return threefry_2x32(key, PRNGKey(data)) +def _threefry_fold_in(key, data): + return threefry_2x32(key, threefry_init(data)) @partial(jit, static_argnums=(1, 2)) -def _random_bits(key, bit_width, shape): +def threefry_random_bits(key, bit_width, shape): """Sample uniform random bits of given width and shape using PRNG key.""" if not _is_prng_key(key): - raise TypeError("_random_bits got invalid prng key.") + raise TypeError("random_bits got invalid prng key.") if bit_width not in (8, 16, 32, 64): raise TypeError("requires 8-, 16-, 32- or 64-bit field width.") shape = core.as_named_shape(shape) @@ -291,7 +264,7 @@ def _random_bits(key, bit_width, shape): raise ValueError(f"The shape of axis {name} was specified as {size}, " f"but it really is {real_size}") axis_index = lax.axis_index(name) - key = fold_in(key, axis_index) + key = threefry_fold_in(key, axis_index) size = prod(shape.positional) max_count = int(np.ceil(bit_width * size / 32)) @@ -299,7 +272,7 @@ def _random_bits(key, bit_width, shape): if not nblocks: bits = threefry_2x32(key, lax.iota(np.uint32, rem)) else: - *subkeys, last_key = split(key, nblocks + 1) + *subkeys, last_key = threefry_split(key, nblocks + 1) blocks = [threefry_2x32(k, lax.iota(np.uint32, jnp.iinfo(np.uint32).max)) for k in subkeys] last = threefry_2x32(last_key, lax.iota(np.uint32, rem)) @@ -324,3 +297,10 @@ def _random_bits(key, bit_width, shape): bits = lax.reshape(bits, (np.uint32(max_count * 32 // bit_width),), (1, 0)) bits = lax.convert_element_type(bits, dtype)[:size] return lax.reshape(bits, shape) + + +class threefry_prng: + init = threefry_init + fold_in = threefry_fold_in + random_bits = threefry_random_bits + split = threefry_split diff --git a/jax/_src/random.py b/jax/_src/random.py index 399b3ebb81b8..f18fe549f76c 100644 --- a/jax/_src/random.py +++ b/jax/_src/random.py @@ -26,7 +26,7 @@ from jax.core import NamedShape from jax._src.api import jit, vmap from jax._src.numpy.lax_numpy import _constant_like, _convert_and_clip_integer -from jax._src.prng import PRNGKey, fold_in, _random_bits, split, UINT_DTYPES +from jax._src.prng import PRNG, UINT_DTYPES, threefry_prng from jax.lib import xla_bridge from jax.numpy.linalg import cholesky, svd, eigh from jax.interpreters import ad @@ -43,7 +43,7 @@ DTypeLikeInt = Any DTypeLikeFloat = Any - +default_prng = threefry_prng ### utilities @@ -56,6 +56,52 @@ def _asarray(x): raise TypeError(f"Function requires array input, got {x} of type {type(x)}.") return jnp.asarray(x) +def _random_bits(key, bit_width, shape, prng: PRNG): + return prng.random_bits(key, bit_width, shape) + + +### prng key operations + +def PRNGKey(seed: int, prng: PRNG = default_prng) -> jnp.ndarray: + """Create a pseudo-random number generator (PRNG) key given an integer seed. + + Args: + seed: a 64- or 32-bit integer used as the value of the key. + + Returns: + A PRNG key, which is modeled as an array of shape (2,) and dtype uint32. The + key is constructed from a 64-bit seed by effectively bit-casting to a pair + of uint32 values (or from a 32-bit seed by first padding out with zeros). + """ + return prng.init(seed) + +def split(key: jnp.ndarray, num: int = 2, + prng: PRNG = default_prng) -> jnp.ndarray: + """Splits a PRNG key into `num` new keys by adding a leading axis. + + Args: + key: a PRNGKey (an array with shape (2,) and dtype uint32). 
+    num: optional, a positive integer indicating the number of keys to produce
+      (default 2).
+
+  Returns:
+    An array with shape (num, 2) and dtype uint32 representing `num` new keys.
+  """
+  return prng.split(key, num)
+
+def fold_in(key: jnp.ndarray, data: int,
+            prng: PRNG = default_prng) -> jnp.ndarray:
+  """Folds in data to a PRNG key to form a new PRNG key.
+
+  Args:
+    key: a PRNGKey (an array with shape (2,) and dtype uint32).
+    data: a 32-bit integer representing data to be folded into the key.
+
+  Returns:
+    A new PRNGKey that is a deterministic function of the inputs and is
+    statistically safe for producing a stream of new pseudo-random values.
+  """
+  return prng.fold_in(key, data)
 
 
 ### random samplers
@@ -77,7 +123,8 @@ def uniform(key: jnp.ndarray,
             shape: Union[Sequence[int], NamedShape] = (),
             dtype: DTypeLikeFloat = dtypes.float_,
             minval: RealArray = 0.,
-            maxval: RealArray = 1.) -> jnp.ndarray:
+            maxval: RealArray = 1.,
+            prng: PRNG = default_prng) -> jnp.ndarray:
   """Sample uniform random values in [minval, maxval) with given shape/dtype.
 
   Args:
@@ -97,10 +144,10 @@ def uniform(key: jnp.ndarray,
                      f"got {dtype}")
   dtype = dtypes.canonicalize_dtype(dtype)
   shape = core.as_named_shape(shape)
-  return _uniform(key, shape, dtype, minval, maxval)  # type: ignore
+  return _uniform(key, shape, dtype, minval, maxval, prng)  # type: ignore
 
 
-@partial(jit, static_argnums=(1, 2))
-def _uniform(key, shape, dtype, minval, maxval) -> jnp.ndarray:
+@partial(jit, static_argnums=(1, 2, 5))
+def _uniform(key, shape, dtype, minval, maxval, prng: PRNG) -> jnp.ndarray:
   _check_shape("uniform", shape)
   if not jnp.issubdtype(dtype, np.floating):
     raise TypeError("uniform only accepts floating point dtypes.")
@@ -116,7 +163,7 @@ def _uniform(key, shape, dtype, minval, maxval) -> jnp.ndarray:
   if nbits not in (16, 32, 64):
     raise TypeError("uniform only accepts 32- or 64-bit dtypes.")
 
-  bits = _random_bits(key, nbits, shape)
+  bits = _random_bits(key, nbits, shape, prng=prng)
 
   # The strategy here is to randomize only the mantissa bits with an exponent of
   # 1 (after applying the bias), then shift and scale to the desired range. The
@@ -135,7 +182,8 @@ def randint(key: jnp.ndarray,
             shape: Sequence[int],
             minval: IntegerArray,
             maxval: IntegerArray,
-            dtype: DTypeLikeInt = dtypes.int_):
+            dtype: DTypeLikeInt = dtypes.int_,
+            prng: PRNG = default_prng):
   """Sample uniform random values in [minval, maxval) with given shape/dtype.
 
   Args:
@@ -153,10 +201,10 @@ def randint(key: jnp.ndarray,
   """
   dtype = dtypes.canonicalize_dtype(dtype)
   shape = core.canonicalize_shape(shape)
-  return _randint(key, shape, minval, maxval, dtype)
+  return _randint(key, shape, minval, maxval, dtype, prng)
 
 
-@partial(jit, static_argnums=(1, 4))
-def _randint(key, shape, minval, maxval, dtype):
+@partial(jit, static_argnums=(1, 4, 5))
+def _randint(key, shape, minval, maxval, dtype, prng: PRNG):
   _check_shape("randint", shape, np.shape(minval), np.shape(maxval))
   if not jnp.issubdtype(dtype, np.integer):
     raise TypeError(f"randint only accepts integer dtypes, got {dtype}")
@@ -171,7 +219,8 @@ def _randint(key, shape, minval, maxval, dtype):
   # Flag where maxval is greater than the maximum value of dtype
   # in order to handle cases like randint(key, shape, 0, 256, 'uint8')
   maxval_out_of_range = lax.gt(
-    maxval, _convert_and_clip_integer(jnp.array(jnp.iinfo(dtype).max, dtype), maxval.dtype))
+    maxval, _convert_and_clip_integer(jnp.array(jnp.iinfo(dtype).max, dtype),
+                                      maxval.dtype))
 
   minval = _convert_and_clip_integer(minval, dtype)
   maxval = _convert_and_clip_integer(maxval, dtype)
@@ -185,8 +234,8 @@ def _randint(key, shape, minval, maxval, dtype):
   # This algorithm is biased whenever (maxval - minval) is not a power of 2.
   # We generate double the number of random bits required by the dtype so as to
   # reduce that bias.
-  k1, k2 = split(key)
-  rbits = lambda key: _random_bits(key, nbits, shape)
+  k1, k2 = split(key, prng=prng)
+  rbits = lambda key: _random_bits(key, nbits, shape, prng=prng)
   higher_bits, lower_bits = rbits(k1), rbits(k2)
 
   unsigned_dtype = UINT_DTYPES[nbits]
@@ -217,7 +266,8 @@ def _randint(key, shape, minval, maxval, dtype):
   return lax.add(minval, lax.convert_element_type(random_offset, dtype))
 
 
-def shuffle(key: jnp.ndarray, x: Array, axis: int = 0) -> jnp.ndarray:
+def shuffle(key: jnp.ndarray, x: Array, axis: int = 0,
+            prng: PRNG = default_prng) -> jnp.ndarray:
   """Shuffle the elements of an array uniformly at random along an axis.
 
   Args:
@@ -231,10 +281,11 @@ def shuffle(key: jnp.ndarray, x: Array, axis: int = 0) -> jnp.ndarray:
   msg = ("jax.random.shuffle is deprecated and will be removed in a future release. "
          "Use jax.random.permutation")
   warnings.warn(msg, FutureWarning)
-  return _shuffle(key, x, axis)  # type: ignore
+  return _shuffle(key, x, axis, prng)  # type: ignore
 
 
-def permutation(key: jnp.ndarray, x: Array) -> jnp.ndarray:
+def permutation(key: jnp.ndarray, x: Array,
+                prng: PRNG = default_prng) -> jnp.ndarray:
   """
   Permute elements of an array along its first axis or return a permuted range.
 
@@ -253,17 +304,17 @@ def permutation(key: jnp.ndarray, x: Array) -> jnp.ndarray:
     if not np.issubdtype(lax.dtype(x), np.integer):
       raise TypeError("x must be an integer or at least 1-dimensional")
     x = int(x)  # type: ignore[assignment]
-    return _shuffle(key, jnp.arange(x), 0)
+    return _shuffle(key, jnp.arange(x), 0, prng)
   elif np.ndim(x) == 1:
-    return _shuffle(key, x, 0)
+    return _shuffle(key, x, 0, prng)
   else:
     assert isinstance(x, jnp.ndarray)
-    ind = _shuffle(key, jnp.arange(x.shape[0]), 0)  # type: ignore[attribute-error]
+    ind = _shuffle(key, jnp.arange(x.shape[0]), 0, prng)  # type: ignore[attribute-error]
     return x[ind]
 
 
-@partial(jit, static_argnums=(2,))
-def _shuffle(key, x, axis) -> jnp.ndarray:
+@partial(jit, static_argnums=(2, 3))
+def _shuffle(key, x, axis, prng: PRNG) -> jnp.ndarray:
   # On parallel architectures, Fisher-Yates is more expensive than doing
   # multiple sorts.
   # This algorithm is based on one developed and analyzed by
   # tjablin@. We sort according to randomly-generated 32bit keys, but those keys
@@ -283,8 +334,8 @@ def _shuffle(key, x, axis) -> jnp.ndarray:
   num_rounds = int(np.ceil(exponent * np.log(max(1, x.size)) / np.log(uint32max)))
   for _ in range(num_rounds):
-    key, subkey = split(key)
-    sort_keys = _random_bits(subkey, 32, x.shape)
+    key, subkey = split(key, prng=prng)
+    sort_keys = _random_bits(subkey, 32, x.shape, prng=prng)
     _, x = lax.sort_key_val(sort_keys, x, axis)
   return x
 
@@ -294,7 +345,8 @@ def choice(key: jnp.ndarray,
            a: IntegerArray,
            shape: Sequence[int] = (),
            replace: bool = True,
-           p=None) -> jnp.ndarray:
+           p=None,
+           prng: PRNG = default_prng) -> jnp.ndarray:
   """Generates a random sample from a given 1-D array.
 
   Args:
@@ -334,10 +386,10 @@ def choice(key: jnp.ndarray,
 
   if p is None:
     if replace:
-      ind = randint(key, shape, 0, n_inputs)
+      ind = randint(key, shape, 0, n_inputs, prng=prng)
       result = ind if np.ndim(a) == 0 else a[ind]  # type: ignore[index]
     else:
-      result = permutation(key, a)[:n_draws]
+      result = permutation(key, a, prng=prng)[:n_draws]
   else:
     if p.shape != (n_inputs,):
       raise ValueError("p must be None or match the shape of a")
@@ -348,7 +400,7 @@ def choice(key: jnp.ndarray,
       result = ind if np.ndim(a) == 0 else a[ind]  # type: ignore[index]
     else:
      # Gumbel top-k trick: https://timvieira.github.io/blog/post/2019/09/16/algorithms-for-sampling-without-replacement/
-      g = -gumbel(key, (n_inputs,)) - jnp.log(p)
+      g = -gumbel(key, (n_inputs,), prng=prng) - jnp.log(p)
      ind = jnp.argsort(g)[:n_draws]
      result = ind if np.ndim(a) == 0 else a[ind]  # type: ignore[index]
   return result.reshape(shape)
 
@@ -356,7 +408,8 @@ def choice(key: jnp.ndarray,
 
 def normal(key: jnp.ndarray,
           shape: Union[Sequence[int], NamedShape] = (),
-          dtype: DTypeLikeFloat = dtypes.float_) -> jnp.ndarray:
+          dtype: DTypeLikeFloat = dtypes.float_,
+          prng: PRNG = default_prng) -> jnp.ndarray:
   """Sample standard normal random values with given shape and float dtype.
 
   Args:
@@ -374,27 +427,27 @@ def normal(key: jnp.ndarray,
                      f"got {dtype}")
   dtype = dtypes.canonicalize_dtype(dtype)
   shape = core.as_named_shape(shape)
-  return _normal(key, shape, dtype)  # type: ignore
+  return _normal(key, shape, dtype, prng)  # type: ignore
 
 
-@partial(jit, static_argnums=(1, 2))
-def _normal(key, shape, dtype) -> jnp.ndarray:
+@partial(jit, static_argnums=(1, 2, 3))
+def _normal(key, shape, dtype, prng: PRNG) -> jnp.ndarray:
   if dtypes.issubdtype(dtype, np.complexfloating):
     sqrt2 = np.array(np.sqrt(2), dtype)
-    key_re, key_im = split(key)
+    key_re, key_im = split(key, prng=prng)
     real_dtype = np.array(0, dtype).real.dtype
-    _re = _normal_real(key_re, shape, real_dtype)
-    _im = _normal_real(key_im, shape, real_dtype)
+    _re = _normal_real(key_re, shape, real_dtype, prng)
+    _im = _normal_real(key_im, shape, real_dtype, prng)
    return (_re + 1j * _im) / sqrt2
  else:
-    return _normal_real(key, shape, dtype)  # type: ignore
+    return _normal_real(key, shape, dtype, prng)  # type: ignore
 
 
-@partial(jit, static_argnums=(1, 2))
-def _normal_real(key, shape, dtype) -> jnp.ndarray:
+@partial(jit, static_argnums=(1, 2, 3))
+def _normal_real(key, shape, dtype, prng: PRNG) -> jnp.ndarray:
   _check_shape("normal", shape)
   lo = np.nextafter(np.array(-1., dtype), 0., dtype=dtype)
   hi = np.array(1., dtype)
-  u = uniform(key, shape, dtype, lo, hi)  # type: ignore[arg-type]
+  u = uniform(key, shape, dtype, lo, hi, prng=prng)  # type: ignore[arg-type]
   return np.array(np.sqrt(2), dtype) * lax.erf_inv(u)
 
 
@@ -403,7 +456,8 @@ def multivariate_normal(key: jnp.ndarray,
                         cov: RealArray,
                         shape: Optional[Sequence[int]] = None,
                         dtype: DTypeLikeFloat = dtypes.float_,
-                        method: str = 'cholesky') -> jnp.ndarray:
+                        method: str = 'cholesky',
+                        prng: PRNG = default_prng) -> jnp.ndarray:
   """Sample multivariate normal random values with given mean and covariance.
 
   Args:
@@ -433,10 +487,12 @@ def multivariate_normal(key: jnp.ndarray,
   dtype = dtypes.canonicalize_dtype(dtype)
   if shape is not None:
     shape = core.canonicalize_shape(shape)
-  return _multivariate_normal(key, mean, cov, shape, dtype, method)  # type: ignore
+  return _multivariate_normal(
+      key, mean, cov, shape, dtype, method, prng)  # type: ignore
 
 
-@partial(jit, static_argnums=(3, 4, 5))
-def _multivariate_normal(key, mean, cov, shape, dtype, method) -> jnp.ndarray:
+@partial(jit, static_argnums=(3, 4, 5, 6))
+def _multivariate_normal(key, mean, cov, shape, dtype, method,
+                         prng: PRNG) -> jnp.ndarray:
   if not np.ndim(mean) >= 1:
     msg = "multivariate_normal requires mean.ndim >= 1, got mean.ndim == {}"
     raise ValueError(msg.format(np.ndim(mean)))
@@ -462,7 +518,7 @@ def _multivariate_normal(key, mean, cov, shape, dtype, method) -> jnp.ndarray:
     factor = v * jnp.sqrt(w)
   else: # 'cholesky'
     factor = cholesky(cov)
-  normal_samples = normal(key, shape + mean.shape[-1:], dtype)
+  normal_samples = normal(key, shape + mean.shape[-1:], dtype, prng=prng)
   return mean + jnp.einsum('...ij,...j->...i', factor, normal_samples)
 
 
@@ -470,7 +526,8 @@ def truncated_normal(key: jnp.ndarray,
                      lower: RealArray,
                      upper: RealArray,
                      shape: Optional[Union[Sequence[int], NamedShape]] = None,
-                     dtype: DTypeLikeFloat = dtypes.float_) -> jnp.ndarray:
+                     dtype: DTypeLikeFloat = dtypes.float_,
+                     prng: PRNG = default_prng) -> jnp.ndarray:
   """Sample truncated standard normal random values with given shape and dtype.
 
   Args:
@@ -497,10 +554,12 @@ def truncated_normal(key: jnp.ndarray,
   dtype = dtypes.canonicalize_dtype(dtype)
   if shape is not None:
     shape = core.as_named_shape(shape)
-  return _truncated_normal(key, lower, upper, shape, dtype)  # type: ignore
+  return _truncated_normal(
+      key, lower, upper, shape, dtype, prng)  # type: ignore
 
 
-@partial(jit, static_argnums=(3, 4))
-def _truncated_normal(key, lower, upper, shape, dtype) -> jnp.ndarray:
+@partial(jit, static_argnums=(3, 4, 5))
+def _truncated_normal(key, lower, upper, shape, dtype,
+                      prng: PRNG) -> jnp.ndarray:
   if shape is None:
     shape = lax.broadcast_shapes(np.shape(lower), np.shape(upper))
   else:
     _check_shape("truncated_normal", shape, np.shape(lower), np.shape(upper))
@@ -513,7 +572,7 @@ def _truncated_normal(key, lower, upper, shape, dtype) -> jnp.ndarray:
   b = lax.erf(upper / sqrt2)
   if not jnp.issubdtype(dtype, np.floating):
     raise TypeError("truncated_normal only accepts floating point dtypes.")
-  u = uniform(key, shape, dtype, minval=a, maxval=b)
+  u = uniform(key, shape, dtype, minval=a, maxval=b, prng=prng)
   out = sqrt2 * lax.erf_inv(u)
   # Clamp the value to the open interval (lower, upper) to make sure that
   # rounding (or if we chose `a` for `u`) doesn't push us outside of the range.
@@ -525,7 +584,8 @@ def _truncated_normal(key, lower, upper, shape, dtype) -> jnp.ndarray:
 
 def bernoulli(key: jnp.ndarray,
               p: RealArray = np.float32(0.5),
-              shape: Optional[Union[Sequence[int], NamedShape]] = None) -> jnp.ndarray:
+              shape: Optional[Union[Sequence[int], NamedShape]] = None,
+              prng: PRNG = default_prng) -> jnp.ndarray:
   """Sample Bernoulli random values with given shape and mean.
 
   Args:
@@ -547,24 +607,25 @@ def bernoulli(key: jnp.ndarray,
     msg = "bernoulli probability `p` must have a floating dtype, got {}."
     raise TypeError(msg.format(dtype))
   p = lax.convert_element_type(p, dtype)
-  return _bernoulli(key, p, shape)  # type: ignore
+  return _bernoulli(key, p, shape, prng)  # type: ignore
 
 
-@partial(jit, static_argnums=(2,))
-def _bernoulli(key, p, shape) -> jnp.ndarray:
+@partial(jit, static_argnums=(2, 3))
+def _bernoulli(key, p, shape, prng: PRNG) -> jnp.ndarray:
   if shape is None:
     # TODO: Use the named part of `p` as well
     shape = np.shape(p)
   else:
     _check_shape("bernoulli", shape, np.shape(p))
-  return uniform(key, shape, lax.dtype(p)) < p
+  return uniform(key, shape, lax.dtype(p), prng=prng) < p
 
 
 def beta(key: jnp.ndarray,
         a: RealArray,
         b: RealArray,
         shape: Optional[Sequence[int]] = None,
-        dtype: DTypeLikeFloat = dtypes.float_) -> jnp.ndarray:
+        dtype: DTypeLikeFloat = dtypes.float_,
+        prng: PRNG = default_prng) -> jnp.ndarray:
   """Sample Beta random values with given shape and float dtype.
 
   Args:
@@ -589,9 +650,9 @@ def beta(key: jnp.ndarray,
   dtype = dtypes.canonicalize_dtype(dtype)
   if shape is not None:
     shape = core.canonicalize_shape(shape)
-  return _beta(key, a, b, shape, dtype)
+  return _beta(key, a, b, shape, dtype, prng)
 
 
-def _beta(key, a, b, shape, dtype):
+def _beta(key, a, b, shape, dtype, prng: PRNG):
   if shape is None:
     shape = lax.broadcast_shapes(np.shape(a), np.shape(b))
   else:
     _check_shape("beta", shape, np.shape(a), np.shape(b))
 
   a = lax.convert_element_type(a, dtype)
   b = lax.convert_element_type(b, dtype)
-  key_a, key_b = split(key)
+  key_a, key_b = split(key, prng=prng)
   a = jnp.broadcast_to(a, shape)
   b = jnp.broadcast_to(b, shape)
-  gamma_a = gamma(key_a, a, shape, dtype)
-  gamma_b = gamma(key_b, b, shape, dtype)
+  gamma_a = gamma(key_a, a, shape, dtype, prng=prng)
+  gamma_b = gamma(key_b, b, shape, dtype, prng=prng)
   return gamma_a / (gamma_a + gamma_b)
 
 
 def cauchy(key: jnp.ndarray,
           shape: Sequence[int] = (),
-          dtype: DTypeLikeFloat = dtypes.float_) -> jnp.ndarray:
+          dtype: DTypeLikeFloat = dtypes.float_,
+          prng: PRNG = default_prng) -> jnp.ndarray:
   """Sample Cauchy random values with given shape and float dtype.
 
   Args:
@@ -627,12 +689,13 @@ def cauchy(key: jnp.ndarray,
                      f"dtype, got {dtype}")
   dtype = dtypes.canonicalize_dtype(dtype)
   shape = core.canonicalize_shape(shape)
-  return _cauchy(key, shape, dtype)
+  return _cauchy(key, shape, dtype, prng)
 
 
-@partial(jit, static_argnums=(1, 2))
-def _cauchy(key, shape, dtype):
+@partial(jit, static_argnums=(1, 2, 3))
+def _cauchy(key, shape, dtype, prng: PRNG):
   _check_shape("cauchy", shape)
-  u = uniform(key, shape, dtype, minval=jnp.finfo(dtype).eps, maxval=1.)
+  u = uniform(key, shape, dtype, minval=jnp.finfo(dtype).eps, maxval=1.,
+              prng=prng)
   pi = _constant_like(u, np.pi)
   return lax.tan(lax.mul(pi, lax.sub(u, _constant_like(u, 0.5))))
 
 
@@ -640,7 +703,8 @@ def dirichlet(key: jnp.ndarray,
               alpha: RealArray,
               shape: Optional[Sequence[int]] = None,
-              dtype: DTypeLikeFloat = dtypes.float_) -> jnp.ndarray:
+              dtype: DTypeLikeFloat = dtypes.float_,
+              prng: PRNG = default_prng) -> jnp.ndarray:
   """Sample Dirichlet random values with given shape and float dtype.
 
   Args:
@@ -666,10 +730,10 @@ def dirichlet(key: jnp.ndarray,
   dtype = dtypes.canonicalize_dtype(dtype)
   if shape is not None:
     shape = core.canonicalize_shape(shape)
-  return _dirichlet(key, alpha, shape, dtype)
+  return _dirichlet(key, alpha, shape, dtype, prng)
 
 
-@partial(jit, static_argnums=(2, 3))
-def _dirichlet(key, alpha, shape, dtype):
+@partial(jit, static_argnums=(2, 3, 4))
+def _dirichlet(key, alpha, shape, dtype, prng: PRNG):
   if not np.ndim(alpha) >= 1:
     msg = "dirichlet requires alpha.ndim >= 1, got alpha.ndim == {}"
     raise ValueError(msg.format(np.ndim(alpha)))
@@ -680,13 +744,15 @@ def _dirichlet(key, alpha, shape, dtype):
     _check_shape("dirichlet", shape, np.shape(alpha)[:-1])
 
   alpha = lax.convert_element_type(alpha, dtype)
-  gamma_samples = gamma(key, alpha, shape + np.shape(alpha)[-1:], dtype)
+  gamma_samples = gamma(key, alpha, shape + np.shape(alpha)[-1:], dtype,
+                        prng=prng)
   return gamma_samples / jnp.sum(gamma_samples, axis=-1, keepdims=True)
 
 
 def exponential(key: jnp.ndarray,
                 shape: Sequence[int] = (),
-                dtype: DTypeLikeFloat = dtypes.float_) -> jnp.ndarray:
+                dtype: DTypeLikeFloat = dtypes.float_,
+                prng: PRNG = default_prng) -> jnp.ndarray:
   """Sample Exponential random values with given shape and float dtype.
 
   Args:
@@ -704,12 +770,12 @@ def exponential(key: jnp.ndarray,
                      f"dtype, got {dtype}")
   dtype = dtypes.canonicalize_dtype(dtype)
   shape = core.canonicalize_shape(shape)
-  return _exponential(key, shape, dtype)
+  return _exponential(key, shape, dtype, prng)
 
 
-@partial(jit, static_argnums=(1, 2))
-def _exponential(key, shape, dtype):
+@partial(jit, static_argnums=(1, 2, 3))
+def _exponential(key, shape, dtype, prng: PRNG):
   _check_shape("exponential", shape)
-  u = uniform(key, shape, dtype)
+  u = uniform(key, shape, dtype, prng=prng)
   # taking 1 - u to move the domain of log to (0, 1] instead of [0, 1)
   return lax.neg(lax.log1p(lax.neg(u)))
 
@@ -731,7 +797,8 @@ def _gamma_one(key, alpha):
   # Gamma(alpha) ~ Gamma(alpha+1) * Uniform()^(1 / alpha)
   boost = lax.select(lax.ge(alpha, one),
                      one,
-                     lax.pow(uniform(subkey, (), dtype=dtype), lax.div(one, alpha)))
+                     lax.pow(uniform(subkey, (), dtype=dtype),
+                             lax.div(one, alpha)))
   alpha = lax.select(lax.ge(alpha, one), alpha, lax.add(alpha, one))
 
   d = lax.sub(alpha, one_over_three)
@@ -817,7 +884,8 @@ def _gamma_batching_rule(batched_args, batch_dims):
 def gamma(key: jnp.ndarray,
           a: RealArray,
          shape: Optional[Sequence[int]] = None,
-          dtype: DTypeLikeFloat = dtypes.float_) -> jnp.ndarray:
+          dtype: DTypeLikeFloat = dtypes.float_,
+          prng: PRNG = default_prng) -> jnp.ndarray:
   """Sample Gamma random values with given shape and float dtype.
 
   Args:
@@ -834,6 +902,8 @@ def gamma(key: jnp.ndarray,
     A random array with the specified dtype and with shape given by ``shape`` if
     ``shape`` is not None, or else by ``a.shape``.
   """
+  if prng is not default_prng:
+    raise NotImplementedError(f"gamma with custom PRNG implementation: {prng}")
   if not dtypes.issubdtype(dtype, np.floating):
     raise ValueError(f"dtype argument to `gamma` must be a float "
                      f"dtype, got {dtype}")
@@ -855,17 +925,17 @@ def _gamma(key, a, shape, dtype):
   return random_gamma_p.bind(key, a)
 
 
-@partial(jit, static_argnums=(2, 3, 4))
-def _poisson_knuth(key, lam, shape, dtype, max_iters):
+@partial(jit, static_argnums=(2, 3, 4, 5))
+def _poisson_knuth(key, lam, shape, dtype, max_iters, prng: PRNG):
   # Knuth's algorithm for generating Poisson random variates.
   # Reference:
   # https://en.wikipedia.org/wiki/Poisson_distribution#Generating_Poisson-distributed_random_variables
 
   def body_fn(carry):
     i, k, rng, log_prod = carry
-    rng, subkey = split(rng)
+    rng, subkey = split(rng, prng=prng)
     k = lax.select(log_prod > -lam, k + 1, k)
-    u = uniform(subkey, shape, np.float32)
+    u = uniform(subkey, shape, np.float32, prng=prng)
     return i + 1, k, rng, log_prod + jnp.log(u)
 
   def cond_fn(carry):
@@ -878,8 +948,8 @@ def cond_fn(carry):
   return (k - 1).astype(dtype)
 
 
-@partial(jit, static_argnums=(2, 3, 4))
-def _poisson_rejection(key, lam, shape, dtype, max_iters):
+@partial(jit, static_argnums=(2, 3, 4, 5))
+def _poisson_rejection(key, lam, shape, dtype, max_iters, prng: PRNG):
   # Transformed rejection due to Hormann.
   # Reference:
   # http://citeseer.ist.psu.edu/viewdoc/citations;jsessionid=1BEB35946CC807879F55D42512E5490C?doi=10.1.1.48.3054.
@@ -891,10 +961,10 @@ def _poisson_rejection(key, lam, shape, dtype, max_iters):
 
   def body_fn(carry):
     i, k_out, accepted, key = carry
-    key, subkey_0, subkey_1 = split(key, 3)
+    key, subkey_0, subkey_1 = split(key, 3, prng=prng)
 
-    u = uniform(subkey_0, shape, lam.dtype) - 0.5
-    v = uniform(subkey_1, shape, lam.dtype)
+    u = uniform(subkey_0, shape, lam.dtype, prng=prng) - 0.5
+    v = uniform(subkey_1, shape, lam.dtype, prng=prng)
     u_shifted = 0.5 - abs(u)
 
     k = lax.floor((2 * a / u_shifted + b) * u + lam + 0.43)
@@ -921,8 +991,8 @@ def cond_fn(carry):
   return k.astype(dtype)
 
 
-@partial(jit, static_argnums=(2, 3))
-def _poisson(key, lam, shape, dtype):
+@partial(jit, static_argnums=(2, 3, 4))
+def _poisson(key, lam, shape, dtype, prng: PRNG):
   # The implementation matches TensorFlow and NumPy:
   # https://github.com/tensorflow/tensorflow/blob/v2.2.0-rc3/tensorflow/core/kernels/random_poisson_op.cc
   # https://github.com/numpy/numpy/blob/v1.18.3/numpy/random/src/distributions/distributions.c#L574
@@ -936,8 +1006,8 @@ def _poisson(key, lam, shape, dtype):
   max_iters = dtype.type(jnp.iinfo(dtype).max)  # insanely conservative
   result = lax.select(
       use_knuth,
-      _poisson_knuth(key, lam_knuth, shape, dtype, max_iters),
-      _poisson_rejection(key, lam_rejection, shape, dtype, max_iters),
+      _poisson_knuth(key, lam_knuth, shape, dtype, max_iters, prng),
+      _poisson_rejection(key, lam_rejection, shape, dtype, max_iters, prng),
   )
   return lax.select(lam == 0, jnp.zeros_like(result), result)
 
 
@@ -945,7 +1015,8 @@ def poisson(key: jnp.ndarray,
             lam: RealArray,
             shape: Sequence[int] = (),
-            dtype: DTypeLikeInt = dtypes.int_) -> jnp.ndarray:
+            dtype: DTypeLikeInt = dtypes.int_,
+            prng: PRNG = default_prng) -> jnp.ndarray:
   """Sample Poisson random values with given shape and integer dtype.
 
   Args:
@@ -964,12 +1035,13 @@ def poisson(key: jnp.ndarray,
   if np.shape(lam) != shape:
     lam = jnp.broadcast_to(lam, shape)
   lam = lax.convert_element_type(lam, np.float32)
-  return _poisson(key, lam, shape, dtype)
+  return _poisson(key, lam, shape, dtype, prng)
 
 
 def gumbel(key: jnp.ndarray,
           shape: Sequence[int] = (),
-          dtype: DTypeLikeFloat = dtypes.float_) -> jnp.ndarray:
+          dtype: DTypeLikeFloat = dtypes.float_,
+          prng: PRNG = default_prng) -> jnp.ndarray:
   """Sample Gumbel random values with given shape and float dtype.
 
   Args:
@@ -987,19 +1059,21 @@ def gumbel(key: jnp.ndarray,
                      f"dtype, got {dtype}")
   dtype = dtypes.canonicalize_dtype(dtype)
   shape = core.canonicalize_shape(shape)
-  return _gumbel(key, shape, dtype)
+  return _gumbel(key, shape, dtype, prng)
 
 
-@partial(jit, static_argnums=(1, 2))
-def _gumbel(key, shape, dtype):
+@partial(jit, static_argnums=(1, 2, 3))
+def _gumbel(key, shape, dtype, prng: PRNG):
   _check_shape("gumbel", shape)
   return -jnp.log(-jnp.log(
-      uniform(key, shape, dtype, minval=jnp.finfo(dtype).tiny, maxval=1.)))
+      uniform(key, shape, dtype, minval=jnp.finfo(dtype).tiny, maxval=1.,
+              prng=prng)))
 
 
 def categorical(key: jnp.ndarray,
                 logits: RealArray,
                 axis: int = -1,
-                shape: Optional[Sequence[int]] = None) -> jnp.ndarray:
+                shape: Optional[Sequence[int]] = None,
+                prng: PRNG = default_prng) -> jnp.ndarray:
   """Sample random values from categorical distributions.
 
   Args:
@@ -1027,12 +1101,15 @@ def categorical(key: jnp.ndarray,
     _check_shape("categorical", shape, batch_shape)
 
   sample_shape = shape[:len(shape)-len(batch_shape)]
-  return jnp.argmax(gumbel(key, sample_shape + logits.shape, logits.dtype) + logits, axis=axis)
+  draws = gumbel(
+      key, sample_shape + logits.shape, logits.dtype, prng=prng) + logits
+  return jnp.argmax(draws, axis=axis)
 
 
 def laplace(key: jnp.ndarray,
             shape: Sequence[int] = (),
-            dtype: DTypeLikeFloat = dtypes.float_) -> jnp.ndarray:
+            dtype: DTypeLikeFloat = dtypes.float_,
+            prng: PRNG = default_prng) -> jnp.ndarray:
   """Sample Laplace random values with given shape and float dtype.
 
   Args:
@@ -1050,19 +1127,21 @@ def laplace(key: jnp.ndarray,
                      f"dtype, got {dtype}")
   dtype = dtypes.canonicalize_dtype(dtype)
   shape = core.canonicalize_shape(shape)
-  return _laplace(key, shape, dtype)
+  return _laplace(key, shape, dtype, prng)
 
 
-@partial(jit, static_argnums=(1, 2))
-def _laplace(key, shape, dtype):
+@partial(jit, static_argnums=(1, 2, 3))
+def _laplace(key, shape, dtype, prng: PRNG):
   _check_shape("laplace", shape)
   u = uniform(
-      key, shape, dtype, minval=-1. + jnp.finfo(dtype).epsneg, maxval=1.)
+      key, shape, dtype, minval=-1. + jnp.finfo(dtype).epsneg, maxval=1.,
+      prng=prng)
   return lax.mul(lax.sign(u), lax.log1p(lax.neg(lax.abs(u))))
 
 
 def logistic(key: jnp.ndarray,
             shape: Sequence[int] = (),
-            dtype: DTypeLikeFloat = dtypes.float_) -> jnp.ndarray:
+            dtype: DTypeLikeFloat = dtypes.float_,
+            prng: PRNG = default_prng) -> jnp.ndarray:
   """Sample logistic random values with given shape and float dtype.
 
   Args:
@@ -1080,19 +1159,21 @@ def logistic(key: jnp.ndarray,
                      f"dtype, got {dtype}")
   dtype = dtypes.canonicalize_dtype(dtype)
   shape = core.canonicalize_shape(shape)
-  return _logistic(key, shape, dtype)
+  return _logistic(key, shape, dtype, prng)
 
 
-@partial(jit, static_argnums=(1, 2))
-def _logistic(key, shape, dtype):
+@partial(jit, static_argnums=(1, 2, 3))
+def _logistic(key, shape, dtype, prng: PRNG):
   _check_shape("logistic", shape)
-  x = uniform(key, shape, dtype, minval=jnp.finfo(dtype).eps, maxval=1.)
+  x = uniform(key, shape, dtype, minval=jnp.finfo(dtype).eps, maxval=1.,
+              prng=prng)
   return lax.log(lax.div(x, lax.sub(lax._const(x, 1), x)))
 
 
 def pareto(key: jnp.ndarray,
           b: RealArray,
          shape: Optional[Sequence[int]] = None,
-          dtype: DTypeLikeFloat = dtypes.float_) -> jnp.ndarray:
+          dtype: DTypeLikeFloat = dtypes.float_,
+          prng: PRNG = default_prng) -> jnp.ndarray:
   """Sample Pareto random values with given shape and float dtype.
 
   Args:
@@ -1115,24 +1196,25 @@ def pareto(key: jnp.ndarray,
   dtype = dtypes.canonicalize_dtype(dtype)
   if shape is not None:
     shape = core.canonicalize_shape(shape)
-  return _pareto(key, b, shape, dtype)
+  return _pareto(key, b, shape, dtype, prng)
 
 
-@partial(jit, static_argnums=(2, 3))
-def _pareto(key, b, shape, dtype):
+@partial(jit, static_argnums=(2, 3, 4))
+def _pareto(key, b, shape, dtype, prng: PRNG):
   if shape is None:
     shape = np.shape(b)
   else:
     _check_shape("pareto", shape)
 
   b = lax.convert_element_type(b, dtype)
-  e = exponential(key, shape, dtype)
+  e = exponential(key, shape, dtype, prng=prng)
   return lax.exp(e / b)
 
 
 def t(key: jnp.ndarray,
       df: RealArray,
      shape: Sequence[int] = (),
-      dtype: DTypeLikeFloat = dtypes.float_) -> jnp.ndarray:
+      dtype: DTypeLikeFloat = dtypes.float_,
+      prng: PRNG = default_prng) -> jnp.ndarray:
   """Sample Student's t random values with given shape and float dtype.
 
   Args:
@@ -1154,27 +1236,28 @@ def t(key: jnp.ndarray,
                      f"dtype, got {dtype}")
   dtype = dtypes.canonicalize_dtype(dtype)
   shape = core.canonicalize_shape(shape)
-  return _t(key, df, shape, dtype)
+  return _t(key, df, shape, dtype, prng)
 
 
-@partial(jit, static_argnums=(2, 3))
-def _t(key, df, shape, dtype):
+@partial(jit, static_argnums=(2, 3, 4))
+def _t(key, df, shape, dtype, prng: PRNG):
   if shape is None:
     shape = np.shape(df)
   else:
     _check_shape("t", shape, np.shape(df))
 
   df = lax.convert_element_type(df, dtype)
-  key_n, key_g = split(key)
-  n = normal(key_n, shape, dtype)
+  key_n, key_g = split(key, prng=prng)
+  n = normal(key_n, shape, dtype, prng=prng)
   two = _constant_like(n, 2)
   half_df = lax.div(df, two)
-  g = gamma(key_n, half_df, shape, dtype)
+  g = gamma(key_g, half_df, shape, dtype, prng=prng)
   return n * jnp.sqrt(half_df / g)
 
 
 def rademacher(key: jnp.ndarray,
               shape: Sequence[int],
-               dtype: DTypeLikeInt = dtypes.int_) -> jnp.ndarray:
+               dtype: DTypeLikeInt = dtypes.int_,
+               prng: PRNG = default_prng) -> jnp.ndarray:
   """Sample from a Rademacher distribution.
 
   Args:
@@ -1189,18 +1272,19 @@ def rademacher(key: jnp.ndarray,
   """
   dtype = dtypes.canonicalize_dtype(dtype)
   shape = core.canonicalize_shape(shape)
-  return _rademacher(key, shape, dtype)
+  return _rademacher(key, shape, dtype, prng)
 
 
-@partial(jit, static_argnums=(1, 2))
-def _rademacher(key, shape, dtype):
-  bernoulli_samples = bernoulli(key=key, p=0.5, shape=shape)
+@partial(jit, static_argnums=(1, 2, 3))
+def _rademacher(key, shape, dtype, prng: PRNG):
+  bernoulli_samples = bernoulli(key=key, p=0.5, shape=shape, prng=prng)
   return (2 * bernoulli_samples - 1).astype(dtype)
 
 
 def maxwell(key: jnp.ndarray,
             shape: Sequence[int] = (),
-            dtype: DTypeLikeFloat = dtypes.float_) -> jnp.ndarray:
+            dtype: DTypeLikeFloat = dtypes.float_,
+            prng: PRNG = default_prng) -> jnp.ndarray:
   """Sample from a one sided Maxwell distribution.
 
   The scipy counterpart is `scipy.stats.maxwell`.
 
@@ -1221,13 +1305,13 @@ def maxwell(key: jnp.ndarray,
              f"dtype, got {dtype}")
   dtype = dtypes.canonicalize_dtype(dtype)
   shape = core.canonicalize_shape(shape)
-  return _maxwell(key, shape, dtype)
+  return _maxwell(key, shape, dtype, prng)
 
 
-@partial(jit, static_argnums=(1, 2))
-def _maxwell(key, shape, dtype):
+@partial(jit, static_argnums=(1, 2, 3))
+def _maxwell(key, shape, dtype, prng: PRNG):
   shape = shape + (3,)
-  norm_rvs = normal(key=key, shape=shape, dtype=dtype)
+  norm_rvs = normal(key=key, shape=shape, dtype=dtype, prng=prng)
   return jnp.linalg.norm(norm_rvs, axis=-1)
 
 
@@ -1235,7 +1319,8 @@ def double_sided_maxwell(key: jnp.ndarray,
                         loc: RealArray,
                         scale: RealArray,
                         shape: Sequence[int] = (),
-                         dtype: DTypeLikeFloat = dtypes.float_) -> jnp.ndarray:
+                         dtype: DTypeLikeFloat = dtypes.float_,
+                         prng: PRNG = default_prng) -> jnp.ndarray:
   """Sample from a double sided Maxwell distribution.
 
   Samples using:
@@ -1257,20 +1342,20 @@ def double_sided_maxwell(key: jnp.ndarray,
                      f" dtype, got {dtype}")
   dtype = dtypes.canonicalize_dtype(dtype)
   shape = core.canonicalize_shape(shape)
-  return _double_sided_maxwell(key, loc, scale, shape, dtype)
+  return _double_sided_maxwell(key, loc, scale, shape, dtype, prng)
 
 
-@partial(jit, static_argnums=(3, 4))
-def _double_sided_maxwell(key, loc, scale, shape, dtype):
+@partial(jit, static_argnums=(3, 4, 5))
+def _double_sided_maxwell(key, loc, scale, shape, dtype, prng: PRNG):
   params_shapes = lax.broadcast_shapes(np.shape(loc), np.shape(scale))
   if not shape:
     shape = params_shapes
 
   shape = shape + params_shapes
-  maxwell_key, rademacher_key = split(key)
-  maxwell_rvs = maxwell(maxwell_key, shape=shape, dtype=dtype)
+  maxwell_key, rademacher_key = split(key, prng=prng)
+  maxwell_rvs = maxwell(maxwell_key, shape=shape, dtype=dtype, prng=prng)
   # Generate random signs for the symmetric variates.
-  random_sign = rademacher(rademacher_key, shape=shape, dtype=dtype)
+  random_sign = rademacher(rademacher_key, shape=shape, dtype=dtype, prng=prng)
   assert random_sign.shape == maxwell_rvs.shape
 
   return random_sign * maxwell_rvs * scale + loc
 
 
@@ -1280,7 +1365,8 @@ def weibull_min(key: jnp.ndarray,
                 scale: RealArray,
                 concentration: RealArray,
                 shape: Sequence[int] = (),
-                dtype: DTypeLikeFloat = dtypes.float_) -> jnp.ndarray:
+                dtype: DTypeLikeFloat = dtypes.float_,
+                prng: PRNG = default_prng) -> jnp.ndarray:
   """Sample from a Weibull distribution.
 
   The scipy counterpart is `scipy.stats.weibull_min`.
 
   Args:
@@ -1301,13 +1387,13 @@ def weibull_min(key: jnp.ndarray,
              f"dtype, got {dtype}")
   dtype = dtypes.canonicalize_dtype(dtype)
   shape = core.canonicalize_shape(shape)
-  return _weibull_min(key, scale, concentration, shape, dtype)
+  return _weibull_min(key, scale, concentration, shape, dtype, prng)
 
 
-@partial(jit, static_argnums=(3, 4))
-def _weibull_min(key, scale, concentration, shape, dtype):
+@partial(jit, static_argnums=(3, 4, 5))
+def _weibull_min(key, scale, concentration, shape, dtype, prng: PRNG):
   random_uniform = uniform(
-      key=key, shape=shape, minval=0, maxval=1, dtype=dtype)
+      key=key, shape=shape, minval=0, maxval=1, dtype=dtype, prng=prng)
   # Inverse weibull CDF.
   return jnp.power(-jnp.log1p(-random_uniform), 1.0/concentration) * scale
diff --git a/jax/random.py b/jax/random.py
index 369679564755..8e6d77a81cf1 100644
--- a/jax/random.py
+++ b/jax/random.py
@@ -79,14 +79,12 @@
 # flake8: noqa: F401
 
 from jax._src.prng import (
-  PRNGKey,
-  fold_in,
-  split,
   threefry2x32_p,
   threefry_2x32,
 )
 
 from jax._src.random import (
+  PRNGKey,
   bernoulli,
   beta,
   categorical,
@@ -95,6 +93,7 @@
   dirichlet,
   double_sided_maxwell,
   exponential,
+  fold_in,
   gamma,
   gumbel,
   laplace,
@@ -109,6 +108,7 @@
   randint,
   random_gamma_p,
   shuffle,
+  split,
   t,
   truncated_normal,
   uniform,
diff --git a/tests/random_test.py b/tests/random_test.py
index c3c64a6a04ca..3a4fcc35f7c9 100644
--- a/tests/random_test.py
+++ b/tests/random_test.py
@@ -133,7 +133,8 @@ def testRngRandomBitsViewProperty(self):
     N = 10
     key = random.PRNGKey(1701)
     nbits = [8, 16, 32]
-    rand_bits = [jax._src.random._random_bits(key, n, (N * 64 // n,))
+    prng = jax._src.random.default_prng
+    rand_bits = [jax._src.random._random_bits(key, n, (N * 64 // n,), prng)
                 for n in nbits]
     rand_bits_32 = np.array([np.array(r).view(np.uint32) for r in rand_bits])
     assert np.all(rand_bits_32 == rand_bits_32[0])
@@ -141,21 +142,22 @@ def testRngRandomBitsViewProperty(self):
   def testRngRandomBits(self):
     # Test specific outputs to ensure consistent random values between JAX versions.
     key = random.PRNGKey(1701)
+    prng = jax._src.random.default_prng
 
-    bits8 = jax._src.random._random_bits(key, 8, (3,))
+    bits8 = jax._src.random._random_bits(key, 8, (3,), prng)
     expected8 = np.array([216, 115, 43], dtype=np.uint8)
     self.assertArraysEqual(bits8, expected8)
 
-    bits16 = jax._src.random._random_bits(key, 16, (3,))
+    bits16 = jax._src.random._random_bits(key, 16, (3,), prng)
     expected16 = np.array([41682, 1300, 55017], dtype=np.uint16)
     self.assertArraysEqual(bits16, expected16)
 
-    bits32 = jax._src.random._random_bits(key, 32, (3,))
+    bits32 = jax._src.random._random_bits(key, 32, (3,), prng)
     expected32 = np.array([56197195, 4200222568, 961309823], dtype=np.uint32)
     self.assertArraysEqual(bits32, expected32)
 
     with jtu.ignore_warning(category=UserWarning,
                             message="Explicitly requested dtype.*"):
-      bits64 = jax._src.random._random_bits(key, 64, (3,))
+      bits64 = jax._src.random._random_bits(key, 64, (3,), prng)
     if config.x64_enabled:
       expected64 = np.array([3982329540505020460, 16822122385914693683,
                              7882654074788531506], dtype=np.uint64)
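
For readers tracing the key layout: the following plain-NumPy sketch (not part of the patch; toy_key_from_seed is a hypothetical name) restates what threefry_init above does with an integer seed, ignoring the int64-conversion caveat described in the patch's comment.

import numpy as np

def toy_key_from_seed(seed: int) -> np.ndarray:
  # Wrap the seed into 64 bits, then split it into its high and low 32-bit
  # words, mirroring the shift/mask pair (k1, k2) in threefry_init.
  seed64 = np.uint64(seed % (1 << 64))
  hi = np.uint32(seed64 >> np.uint64(32))
  lo = np.uint32(seed64 & np.uint64(0xFFFFFFFF))
  return np.array([hi, lo], dtype=np.uint32)

# A 32-bit seed pads out with zeros in the high word:
assert (toy_key_from_seed(42) == np.array([0, 42], dtype=np.uint32)).all()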
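
The prng parameter is duck-typed: any object exposing init, split, fold_in, and random_bits with the threefry signatures can be passed, and it must be hashable because the private samplers mark it as a static argument under jit (a class, like threefry_prng itself, satisfies this). A minimal sketch, assuming the diff above is applied; swapped_threefry_prng is a hypothetical implementation shown only to exercise the plumbing, not a statistically meaningful PRNG.

from jax import random
from jax._src import prng as prng_lib

class swapped_threefry_prng:
  # Delegates to the threefry functions, but derives its key with the two
  # 32-bit words swapped. Illustrative only.
  init = lambda seed: prng_lib.threefry_init(seed)[::-1]
  split = prng_lib.threefry_split
  fold_in = prng_lib.threefry_fold_in
  random_bits = prng_lib.threefry_random_bits

# Default path: the public API is unchanged, with threefry under the hood.
key = random.PRNGKey(0)
x = random.uniform(key, (2, 3))

# Custom path: pass the implementation to seeding and sampling alike.
key2 = random.PRNGKey(0, prng=swapped_threefry_prng)
y = random.normal(key2, (3,), prng=swapped_threefry_prng)

Note that gamma guards against non-default implementations with NotImplementedError, so the samplers built on it (beta, dirichlet, t) remain limited to the default threefry path under this diff.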