Skip to content

Commit

Permalink
parameterize random samplers by PRNG implementation
Browse files Browse the repository at this point in the history
  • Loading branch information
froystig committed Jun 4, 2021
1 parent f0c4912 commit 227b70b
Show file tree
Hide file tree
Showing 4 changed files with 265 additions and 197 deletions.
72 changes: 26 additions & 46 deletions jax/_src/prng.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@


from functools import partial
from typing import Any

import numpy as np

Expand All @@ -30,31 +31,23 @@
from jax._src.util import prod


PRNG = Any
UINT_DTYPES = {8: jnp.uint8, 16: jnp.uint16, 32: jnp.uint32, 64: jnp.uint64}


def PRNGKey(seed: int) -> jnp.ndarray:
"""Create a pseudo-random number generator (PRNG) key given an integer seed.
Args:
seed: a 64- or 32-bit integer used as the value of the key.
Returns:
A PRNG key, which is modeled as an array of shape (2,) and dtype uint32. The
key is constructed from a 64-bit seed by effectively bit-casting to a pair
of uint32 values (or from a 32-bit seed by first padding out with zeros).
"""
def threefry_init(seed: int) -> jnp.ndarray:
# Avoid overflowerror in X32 mode by first converting ints to int64.
# This breaks JIT invariance of PRNGKey for large ints, but supports the
# common use-case of instantiating PRNGKey with Python hashes in X32 mode.
# This breaks JIT invariance of this init function for large ints,
# but supports the common use-case of calling it with Python hashes
# in X32 mode.
if isinstance(seed, int):
seed_arr = jnp.asarray(np.int64(seed))
else:
seed_arr = jnp.asarray(seed)
if seed_arr.shape:
raise TypeError(f"PRNGKey seed must be a scalar; got {seed!r}.")
raise TypeError(f"PRNG seed must be a scalar; got {seed!r}.")
if not np.issubdtype(seed_arr.dtype, np.integer):
raise TypeError(f"PRNGKey seed must be an integer; got {seed!r}")
raise TypeError(f"PRNG seed must be an integer; got {seed!r}")

convert = lambda k: lax.reshape(lax.convert_element_type(k, np.uint32), [1])
k1 = convert(lax.shift_right_logical(seed_arr, lax._const(seed_arr, 32)))
Expand Down Expand Up @@ -238,50 +231,30 @@ def threefry_2x32(keypair, count):
return lax.reshape(out[:-1] if odd_size else out, count.shape)


def split(key: jnp.ndarray, num: int = 2) -> jnp.ndarray:
"""Splits a PRNG key into `num` new keys by adding a leading axis.
Args:
key: a PRNGKey (an array with shape (2,) and dtype uint32).
num: optional, a positive integer indicating the number of keys to produce
(default 2).
Returns:
An array with shape (num, 2) and dtype uint32 representing `num` new keys.
"""
return _split(key, int(num)) # type: ignore
def threefry_split(key: jnp.ndarray, num: int = 2) -> jnp.ndarray:
return _threefry_split(key, int(num)) # type: ignore


@partial(jit, static_argnums=(1,))
def _split(key, num) -> jnp.ndarray:
def _threefry_split(key, num) -> jnp.ndarray:
counts = lax.iota(np.uint32, num * 2)
return lax.reshape(threefry_2x32(key, counts), (num, 2))


def fold_in(key: jnp.ndarray, data: int) -> jnp.ndarray:
"""Folds in data to a PRNG key to form a new PRNG key.
Args:
key: a PRNGKey (an array with shape (2,) and dtype uint32).
data: a 32bit integer representing data to be folded in to the key.
Returns:
A new PRNGKey that is a deterministic function of the inputs and is
statistically safe for producing a stream of new pseudo-random values.
"""
return _fold_in(key, jnp.uint32(data))
def threefry_fold_in(key: jnp.ndarray, data: int) -> jnp.ndarray:
return _threefry_fold_in(key, jnp.uint32(data))


@jit
def _fold_in(key, data):
return threefry_2x32(key, PRNGKey(data))
def _threefry_fold_in(key, data):
return threefry_2x32(key, threefry_init(data))


@partial(jit, static_argnums=(1, 2))
def _random_bits(key, bit_width, shape):
def threefry_random_bits(key, bit_width, shape):
"""Sample uniform random bits of given width and shape using PRNG key."""
if not _is_prng_key(key):
raise TypeError("_random_bits got invalid prng key.")
raise TypeError("random_bits got invalid prng key.")
if bit_width not in (8, 16, 32, 64):
raise TypeError("requires 8-, 16-, 32- or 64-bit field width.")
shape = core.as_named_shape(shape)
Expand All @@ -291,15 +264,15 @@ def _random_bits(key, bit_width, shape):
raise ValueError(f"The shape of axis {name} was specified as {size}, "
f"but it really is {real_size}")
axis_index = lax.axis_index(name)
key = fold_in(key, axis_index)
key = threefry_fold_in(key, axis_index)
size = prod(shape.positional)
max_count = int(np.ceil(bit_width * size / 32))

nblocks, rem = divmod(max_count, jnp.iinfo(np.uint32).max)
if not nblocks:
bits = threefry_2x32(key, lax.iota(np.uint32, rem))
else:
*subkeys, last_key = split(key, nblocks + 1)
*subkeys, last_key = threefry_split(key, nblocks + 1)
blocks = [threefry_2x32(k, lax.iota(np.uint32, jnp.iinfo(np.uint32).max))
for k in subkeys]
last = threefry_2x32(last_key, lax.iota(np.uint32, rem))
Expand All @@ -324,3 +297,10 @@ def _random_bits(key, bit_width, shape):
bits = lax.reshape(bits, (np.uint32(max_count * 32 // bit_width),), (1, 0))
bits = lax.convert_element_type(bits, dtype)[:size]
return lax.reshape(bits, shape)


class threefry_prng:
init = threefry_init
fold_in = threefry_fold_in
random_bits = threefry_random_bits
split = threefry_split
Loading

0 comments on commit 227b70b

Please sign in to comment.