Skip to content

Commit

Permalink
factor PRNG routines from random module to prng
Browse files Browse the repository at this point in the history
  • Loading branch information
froystig committed Jun 4, 2021
1 parent 98e4e40 commit db0ccc7
Show file tree
Hide file tree
Showing 3 changed files with 338 additions and 293 deletions.
326 changes: 326 additions & 0 deletions jax/_src/prng.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,326 @@
# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from functools import partial

import numpy as np

from jax import lax
from jax import core
from jax import numpy as jnp
from jax._src.api import jit
from jax._src.numpy.lax_numpy import asarray
from jax.lib import xla_bridge
from jax.lib import xla_client
from jax.lib import cuda_prng
from jax.interpreters import batching
from jax.interpreters import xla
from jax._src.util import prod


UINT_DTYPES = {8: jnp.uint8, 16: jnp.uint16, 32: jnp.uint32, 64: jnp.uint64}


def PRNGKey(seed: int) -> jnp.ndarray:
"""Create a pseudo-random number generator (PRNG) key given an integer seed.
Args:
seed: a 64- or 32-bit integer used as the value of the key.
Returns:
A PRNG key, which is modeled as an array of shape (2,) and dtype uint32. The
key is constructed from a 64-bit seed by effectively bit-casting to a pair
of uint32 values (or from a 32-bit seed by first padding out with zeros).
"""
# Avoid overflowerror in X32 mode by first converting ints to int64.
# This breaks JIT invariance of PRNGKey for large ints, but supports the
# common use-case of instantiating PRNGKey with Python hashes in X32 mode.
if isinstance(seed, int):
seed_arr = jnp.asarray(np.int64(seed))
else:
seed_arr = jnp.asarray(seed)
if seed_arr.shape:
raise TypeError(f"PRNGKey seed must be a scalar; got {seed!r}.")
if not np.issubdtype(seed_arr.dtype, np.integer):
raise TypeError(f"PRNGKey seed must be an integer; got {seed!r}")

convert = lambda k: lax.reshape(lax.convert_element_type(k, np.uint32), [1])
k1 = convert(lax.shift_right_logical(seed_arr, lax._const(seed_arr, 32)))
k2 = convert(jnp.bitwise_and(seed_arr, np.uint32(0xFFFFFFFF)))
return lax.concatenate([k1, k2], 0)


def _is_prng_key(key: jnp.ndarray) -> bool:
try:
return key.shape == (2,) and key.dtype == np.uint32
except AttributeError:
return False


def _make_rotate_left(dtype):
if not jnp.issubdtype(dtype, np.integer):
raise TypeError("_rotate_left only accepts integer dtypes.")
nbits = np.array(jnp.iinfo(dtype).bits, dtype)

def _rotate_left(x, d):
if lax.dtype(d) != dtype:
d = lax.convert_element_type(d, dtype)
if lax.dtype(x) != dtype:
x = lax.convert_element_type(x, dtype)
return lax.shift_left(x, d) | lax.shift_right_logical(x, nbits - d)
return _rotate_left


def _bit_stats(bits):
"""This is a debugging function to compute the statistics of bit fields."""
return np.array([list(map(int, np.binary_repr(x, 64))) for x in bits]).mean(0)


### hash function and split

def _threefry2x32_abstract_eval(*args):
if any(a.dtype != jnp.uint32 for a in args):
raise TypeError("Arguments to threefry2x32 must have uint32 type, got {}"
.format(args))
if all(isinstance(arg, core.ShapedArray) for arg in args):
shape = lax._broadcasting_shape_rule(*args)
named_shape = core.join_named_shapes(*(a.named_shape for a in args))
aval = core.ShapedArray(shape, jnp.dtype(jnp.uint32), named_shape=named_shape)
else:
aval = core.UnshapedArray(jnp.dtype(jnp.uint32))
return (aval,) * 2


rotate_left = _make_rotate_left(np.uint32)


def apply_round(v, rot):
v = v[:]
v[0] = v[0] + v[1]
v[1] = rotate_left(v[1], rot)
v[1] = v[0] ^ v[1]
return v


def rotate_list(xs):
return xs[1:] + xs[:1]


def rolled_loop_step(i, state):
x, ks, rotations = state
for r in rotations[0]:
x = apply_round(x, r)
new_x = [x[0] + ks[0], x[1] + ks[1] + asarray(i + 1, dtype=np.uint32)]
return new_x, rotate_list(ks), rotate_list(rotations)


def _threefry2x32_lowering(key1, key2, x1, x2, use_rolled_loops=True):
"""Apply the Threefry 2x32 hash.
Args:
keypair: a pair of 32bit unsigned integers used for the key.
count: an array of dtype uint32 used for the counts.
Returns:
An array of dtype uint32 with the same shape as `count`.
"""
x = [x1, x2]

rotations = [np.array([13, 15, 26, 6], dtype=np.uint32),
np.array([17, 29, 16, 24], dtype=np.uint32)]
ks = [key1, key2, key1 ^ key2 ^ np.uint32(0x1BD11BDA)]

x[0] = x[0] + ks[0]
x[1] = x[1] + ks[1]

if use_rolled_loops:
x, _, _ = lax.fori_loop(0, 5, rolled_loop_step, (x, rotate_list(ks), rotations))

else:
for r in rotations[0]:
x = apply_round(x, r)
x[0] = x[0] + ks[1]
x[1] = x[1] + ks[2] + np.uint32(1)

for r in rotations[1]:
x = apply_round(x, r)
x[0] = x[0] + ks[2]
x[1] = x[1] + ks[0] + np.uint32(2)

for r in rotations[0]:
x = apply_round(x, r)
x[0] = x[0] + ks[0]
x[1] = x[1] + ks[1] + np.uint32(3)

for r in rotations[1]:
x = apply_round(x, r)
x[0] = x[0] + ks[1]
x[1] = x[1] + ks[2] + np.uint32(4)

for r in rotations[0]:
x = apply_round(x, r)
x[0] = x[0] + ks[2]
x[1] = x[1] + ks[0] + np.uint32(5)

return tuple(x)


def _threefry2x32_gpu_translation_rule(c, k1, k2, x1, x2):
shape = lax.broadcast_shapes(
c.get_shape(k1).dimensions(), c.get_shape(k2).dimensions(),
c.get_shape(x1).dimensions(), c.get_shape(x2).dimensions())
rank = len(shape)
if 0 in shape:
zeros = xla_client.ops.Broadcast(
xla_bridge.constant(c, np.array(0, np.uint32)), shape)
return xla_client.ops.Tuple(c, [zeros, zeros])
def _broadcast(x):
ndims = c.get_shape(x).rank()
return xla_client.ops.BroadcastInDim(x, shape,
tuple(range(rank - ndims, rank)))
return cuda_prng.threefry2x32(
c, (_broadcast(k1), _broadcast(k2)), (_broadcast(x1), _broadcast(x2)))


threefry2x32_p = core.Primitive("threefry2x32")
threefry2x32_p.multiple_results = True
threefry2x32_p.def_impl(partial(xla.apply_primitive, threefry2x32_p))
threefry2x32_p.def_abstract_eval(_threefry2x32_abstract_eval)
batching.defbroadcasting(threefry2x32_p)
xla.translations_with_avals[threefry2x32_p] = xla.lower_fun(
partial(_threefry2x32_lowering, use_rolled_loops=False),
multiple_results=True, with_avals=True)
xla.backend_specific_translations['cpu'][threefry2x32_p] = xla.lower_fun(
partial(_threefry2x32_lowering, use_rolled_loops=True),
multiple_results=True)
if cuda_prng:
xla.backend_specific_translations['gpu'][threefry2x32_p] = \
_threefry2x32_gpu_translation_rule


@jit
def threefry_2x32(keypair, count):
"""Apply the Threefry 2x32 hash.
Args:
keypair: a pair of 32bit unsigned integers used for the key.
count: an array of dtype uint32 used for the counts.
Returns:
An array of dtype uint32 with the same shape as `count`.
"""
key1, key2 = keypair
if not lax.dtype(key1) == lax.dtype(key2) == lax.dtype(count) == np.uint32:
msg = "threefry_2x32 requires uint32 arguments, got {}"
raise TypeError(msg.format([lax.dtype(x) for x in [key1, key2, count]]))

odd_size = count.size % 2
if odd_size:
x = list(jnp.split(jnp.concatenate([count.ravel(), np.uint32([0])]), 2))
else:
x = list(jnp.split(count.ravel(), 2))

x = threefry2x32_p.bind(key1, key2, x[0], x[1])
out = jnp.concatenate(x)
assert out.dtype == np.uint32
return lax.reshape(out[:-1] if odd_size else out, count.shape)


def split(key: jnp.ndarray, num: int = 2) -> jnp.ndarray:
"""Splits a PRNG key into `num` new keys by adding a leading axis.
Args:
key: a PRNGKey (an array with shape (2,) and dtype uint32).
num: optional, a positive integer indicating the number of keys to produce
(default 2).
Returns:
An array with shape (num, 2) and dtype uint32 representing `num` new keys.
"""
return _split(key, int(num)) # type: ignore


@partial(jit, static_argnums=(1,))
def _split(key, num) -> jnp.ndarray:
counts = lax.iota(np.uint32, num * 2)
return lax.reshape(threefry_2x32(key, counts), (num, 2))


def fold_in(key: jnp.ndarray, data: int) -> jnp.ndarray:
"""Folds in data to a PRNG key to form a new PRNG key.
Args:
key: a PRNGKey (an array with shape (2,) and dtype uint32).
data: a 32bit integer representing data to be folded in to the key.
Returns:
A new PRNGKey that is a deterministic function of the inputs and is
statistically safe for producing a stream of new pseudo-random values.
"""
return _fold_in(key, jnp.uint32(data))


@jit
def _fold_in(key, data):
return threefry_2x32(key, PRNGKey(data))


@partial(jit, static_argnums=(1, 2))
def _random_bits(key, bit_width, shape):
"""Sample uniform random bits of given width and shape using PRNG key."""
if not _is_prng_key(key):
raise TypeError("_random_bits got invalid prng key.")
if bit_width not in (8, 16, 32, 64):
raise TypeError("requires 8-, 16-, 32- or 64-bit field width.")
shape = core.as_named_shape(shape)
for name, size in shape.named_items:
real_size = lax.psum(1, name)
if real_size != size:
raise ValueError(f"The shape of axis {name} was specified as {size}, "
f"but it really is {real_size}")
axis_index = lax.axis_index(name)
key = fold_in(key, axis_index)
size = prod(shape.positional)
max_count = int(np.ceil(bit_width * size / 32))

nblocks, rem = divmod(max_count, jnp.iinfo(np.uint32).max)
if not nblocks:
bits = threefry_2x32(key, lax.iota(np.uint32, rem))
else:
*subkeys, last_key = split(key, nblocks + 1)
blocks = [threefry_2x32(k, lax.iota(np.uint32, jnp.iinfo(np.uint32).max))
for k in subkeys]
last = threefry_2x32(last_key, lax.iota(np.uint32, rem))
bits = lax.concatenate(blocks + [last], 0)

dtype = UINT_DTYPES[bit_width]
if bit_width == 64:
bits = [lax.convert_element_type(x, dtype) for x in jnp.split(bits, 2)]
bits = lax.shift_left(bits[0], dtype(32)) | bits[1]
elif bit_width in [8, 16]:
# this is essentially bits.view(dtype)[:size]
bits = lax.bitwise_and(
np.uint32(np.iinfo(dtype).max),
lax.shift_right_logical(
lax.broadcast(bits, (1,)),
lax.mul(
np.uint32(bit_width),
lax.broadcasted_iota(np.uint32, (32 // bit_width, 1), 0)
)
)
)
bits = lax.reshape(bits, (np.uint32(max_count * 32 // bit_width),), (1, 0))
bits = lax.convert_element_type(bits, dtype)[:size]
return lax.reshape(bits, shape)
Loading

0 comments on commit db0ccc7

Please sign in to comment.