Skip to content

Commit

Permalink
introduce an indirect PRNG key type to support customization, impleme…
Browse files Browse the repository at this point in the history
…nt the default PRNG
  • Loading branch information
froystig committed Jun 16, 2021
1 parent 74b7c5f commit fbd2f6b
Show file tree
Hide file tree
Showing 8 changed files with 204 additions and 88 deletions.
92 changes: 60 additions & 32 deletions jax/_src/prng.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from jax import lax
from jax import core
from jax import numpy as jnp
from jax import tree_util
from jax._src.api import jit
from jax.lib import xla_bridge
from jax.lib import xla_client
Expand All @@ -32,7 +33,55 @@
UINT_DTYPES = {8: jnp.uint8, 16: jnp.uint16, 32: jnp.uint32, 64: jnp.uint64}


def PRNGKey(seed: int) -> jnp.ndarray:
@tree_util.register_pytree_node_class
class PRNGKey:
"""Represents a PRNG key or batch thereof."""

key: jnp.ndarray

def __init__(self, key: jnp.ndarray):
# key might be a dummy object due to tree_unflatten
ndim = getattr(key, 'ndim', 1)
dtype = getattr(key, 'dtype', np.uint32)
if ndim < 1 or dtype != np.uint32:
raise TypeError(
f'invalid prng key or key batch: ndim = {ndim}, dtype = {dtype}')
self.key = key

def tree_flatten(self):
return (self.key,), None

@classmethod
def tree_unflatten(cls, _, key):
key, = key
return cls(key)

def fold_in(self, data: int) -> 'PRNGKey':
return PRNGKey(_fold_in(self.key, data))

def random_bits(self, bit_width, shape) -> jnp.ndarray:
return _random_bits(self.key, bit_width, shape)

def split(self, num: int) -> 'PRNGKey':
return PRNGKey(_split(self.key, num))

def __iter__(self):
assert self.key.ndim > 0
if self.key.ndim == 1:
raise TypeError('iteration over a single PRNG key')
return (PRNGKey(k) for k in self.key.__iter__())

# TODO(frostig): remove if possible, otherwise declare it necessary
@property
def shape(self):
return self.key.shape


def make_prng_key(seed: int) -> PRNGKey:
return PRNGKey(_threefry_prng_key(seed))


def _threefry_prng_key(seed: int) -> jnp.ndarray:
"""Create a pseudo-random number generator (PRNG) key given an integer seed.
Args:
Expand Down Expand Up @@ -67,7 +116,6 @@ def _is_prng_key(key: jnp.ndarray) -> bool:
except AttributeError:
return False


def _make_rotate_left(dtype):
if not jnp.issubdtype(dtype, np.integer):
raise TypeError("_rotate_left only accepts integer dtypes.")
Expand Down Expand Up @@ -237,47 +285,27 @@ def threefry_2x32(keypair, count):
return lax.reshape(out[:-1] if odd_size else out, count.shape)


def split(key: jnp.ndarray, num: int = 2) -> jnp.ndarray:
"""Splits a PRNG key into `num` new keys by adding a leading axis.
Args:
key: a PRNGKey (an array with shape (2,) and dtype uint32).
num: optional, a positive integer indicating the number of keys to produce
(default 2).
Returns:
An array with shape (num, 2) and dtype uint32 representing `num` new keys.
"""
return _split(key, int(num)) # type: ignore
def _split(key: jnp.ndarray, num: int) -> jnp.ndarray:
return _threefry_split(key, int(num)) # type: ignore


@partial(jit, static_argnums=(1,))
def _split(key, num) -> jnp.ndarray:
def _threefry_split(key, num) -> jnp.ndarray:
counts = lax.iota(np.uint32, num * 2)
return lax.reshape(threefry_2x32(key, counts), (num, 2))


def fold_in(key: jnp.ndarray, data: int) -> jnp.ndarray:
"""Folds in data to a PRNG key to form a new PRNG key.
Args:
key: a PRNGKey (an array with shape (2,) and dtype uint32).
data: a 32bit integer representing data to be folded in to the key.
Returns:
A new PRNGKey that is a deterministic function of the inputs and is
statistically safe for producing a stream of new pseudo-random values.
"""
return _fold_in(key, jnp.uint32(data))
def _fold_in(key: jnp.ndarray, data: int) -> jnp.ndarray:
return _threefry_fold_in(key, jnp.uint32(data))


@jit
def _fold_in(key, data):
return threefry_2x32(key, PRNGKey(data))
def _threefry_fold_in(key, data):
return threefry_2x32(key, _threefry_prng_key(data))


@partial(jit, static_argnums=(1, 2))
def _random_bits(key, bit_width, shape):
def _random_bits(key: jnp.ndarray, bit_width, shape):
"""Sample uniform random bits of given width and shape using PRNG key."""
if not _is_prng_key(key):
raise TypeError("_random_bits got invalid prng key.")
Expand All @@ -290,15 +318,15 @@ def _random_bits(key, bit_width, shape):
raise ValueError(f"The shape of axis {name} was specified as {size}, "
f"but it really is {real_size}")
axis_index = lax.axis_index(name)
key = fold_in(key, axis_index)
key = _fold_in(key, axis_index)
size = prod(shape.positional)
max_count = int(np.ceil(bit_width * size / 32))

nblocks, rem = divmod(max_count, jnp.iinfo(np.uint32).max)
if not nblocks:
bits = threefry_2x32(key, lax.iota(np.uint32, rem))
else:
*subkeys, last_key = split(key, nblocks + 1)
*subkeys, last_key = _split(key, nblocks + 1)
blocks = [threefry_2x32(k, lax.iota(np.uint32, jnp.iinfo(np.uint32).max))
for k in subkeys]
last = threefry_2x32(last_key, lax.iota(np.uint32, rem))
Expand Down
Loading

0 comments on commit fbd2f6b

Please sign in to comment.