Skip to content

Commit

Permalink
Merge pull request #8067 from google:rbg-prng
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 400341853
  • Loading branch information
jax authors committed Oct 2, 2021
2 parents e01dc5d + 980dfcf commit 07083e7
Show file tree
Hide file tree
Showing 4 changed files with 110 additions and 7 deletions.
30 changes: 25 additions & 5 deletions jax/_src/lax/lax.py
Original file line number Diff line number Diff line change
Expand Up @@ -6741,24 +6741,44 @@ def _rng_uniform_translation_rule(c, a, b, *, shape):


def _rng_bit_generator_shape_rule(key, *, shape, dtype, algorithm):
_ = dtype, algorithm
del dtype, algorithm
return (key.shape, tuple(shape))


def _rng_bit_generator_dtype_rule(key, *, shape, dtype, algorithm):
_ = key, shape, algorithm
del shape, algorithm
return (key.dtype, dtype)


def _rng_bit_generator_weak_type_rule(key, *, shape, dtype, algorithm):
_ = shape, dtype, algorithm
del shape, dtype, algorithm
return (key.weak_type, False)


def _rng_bit_generator_translation_rule(c, key, *, shape, dtype, algorithm):
_ = c
key_shape, key_dtype = c.get_shape(key).dimensions(), c.get_shape(key).numpy_dtype()
# While the RngBitGenerator HLO accepts a u64[2] key on all backends, we
# typically represent the key argument to this primitive as a u32[4] so as to
# sidestep issues with the jax_enable_x64=False configuration. As a result, we
# need to convert u32[4] -> u64[2] here in the translation rule. However, we
# also polymorphically allow a u64[2] for backward compatibility.
assert ((key_shape == (4,) and key_dtype == dtypes.dtype('uint32')) or
(key_shape == (2,) and key_dtype == dtypes.dtype('uint64')))
xla_shape = xc.Shape.array_shape(np.dtype(dtype), shape)
return xops.RngBitGenerator(algorithm, key, xla_shape)
if key_dtype == dtypes.dtype('uint32'):
u64_etype = xla_client.dtype_to_etype(dtypes.dtype('uint64'))
# TODO(mattjj): use BitcastConvertType implementation with newer jaxlib
# new_key = xops.BitcastConvertType(xops.Reshape(key, (2, 2)), u64_etype)
new_key = xla_bridge.constant(c, np.zeros(2, dtype=np.dtype('uint64')),
canonicalize_types=False)
for i in range(4):
elt = xops.ConvertElementType(xops.Slice(key, [i], [i+1], [1]), u64_etype)
if i % 2 == 0:
elt = xops.ShiftLeft(elt, xla_bridge.constant(c, np.uint64(32), canonicalize_types=False))
new_key = xops.DynamicUpdateSlice(new_key, elt, [xla_bridge.constant(c, i // 2)])
return xops.RngBitGenerator(algorithm, new_key, xla_shape)
else:
return xops.RngBitGenerator(algorithm, key, xla_shape)


def _rng_bit_generator_named_shape_rule(key, *, shape, dtype, algorithm):
Expand Down
40 changes: 38 additions & 2 deletions jax/_src/prng.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@


from functools import partial
from typing import Callable, Iterator, NamedTuple
from typing import Callable, Iterator, NamedTuple, Sequence
import warnings

import numpy as np
Expand Down Expand Up @@ -431,7 +431,7 @@ def _threefry_fold_in(key, data):
def threefry_random_bits(key: jnp.ndarray, bit_width, shape):
"""Sample uniform random bits of given width and shape using PRNG key."""
if not _is_threefry_prng_key(key):
raise TypeError("_random_bits got invalid prng key.")
raise TypeError("threefry_random_bits got invalid prng key.")
if bit_width not in (8, 16, 32, 64):
raise TypeError("requires 8-, 16-, 32- or 64-bit field width.")
shape = core.as_named_shape(shape)
Expand Down Expand Up @@ -490,3 +490,39 @@ def threefry_random_bits(key: jnp.ndarray, bit_width, shape):
split=threefry_split,
random_bits=threefry_random_bits,
fold_in=threefry_fold_in)


# -- RngBitGenerator PRNG implementation --

# This code is experimental!
# https://www.tensorflow.org/xla/operation_semantics#rngbitgenerator
# Notice that the RngBitGenerator operations are not guaranteed to be
# stable/deterministic across backends or compiler versions. Correspondingly, we
# reserve the right to change any of these implementations at any time!

def _rbg_seed(seed: int) -> jnp.ndarray:
halfkey = threefry_seed(seed)
return jnp.concatenate([halfkey, halfkey])

def _rbg_split(key: jnp.ndarray, num: int) -> jnp.ndarray:
_, keys = lax.rng_bit_generator(key, (num, 4), dtype='uint32')
return keys

def _rbg_random_bits(key: jnp.ndarray, bit_width: int, shape: Sequence[int]
) -> jnp.ndarray:
if not key.shape == (4,) and key.dtype == jnp.dtype('uint32'):
raise TypeError("_rbg_random_bits got invalid prng key.")
if bit_width not in (8, 16, 32, 64):
raise TypeError("requires 8-, 16-, 32- or 64-bit field width.")
_, bits = lax.rng_bit_generator(key, shape, dtype=UINT_DTYPES[bit_width])
return bits

def _rbg_fold_in(key: jnp.ndarray, data: int) -> jnp.ndarray:
return jnp.uint32(data) ^ key

rbg_prng_impl = PRNGImpl(
key_shape=(4,),
seed=_rbg_seed,
split=_rbg_split,
random_bits=_rbg_random_bits,
fold_in=_rbg_fold_in)
1 change: 1 addition & 0 deletions jax/prng.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,4 +20,5 @@
threefry2x32_p as threefry2x32_p,
threefry_2x32 as threefry_2x32,
threefry_prng_impl as threefry_prng_impl,
rbg_prng_impl as rbg_prng_impl,
)
46 changes: 46 additions & 0 deletions tests/random_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1120,6 +1120,49 @@ def test_grad_of_prng_key(self):
key = self.seed_prng(73)
jax.grad(lambda x: 1., allow_int=True)(key) # does not crash

@skipIf(not config.jax_enable_custom_prng,
'custom PRNG tests require config.jax_enable_custom_prng')
@jtu.with_config(jax_numpy_rank_promotion="raise")
class LaxRandomWithRBGPRNGTest(LaxRandomTest):
def seed_prng(self, seed):
return prng.seed_with_impl(prng.rbg_prng_impl, seed)

def test_split_shape(self):
key = self.seed_prng(73)
keys = random.split(key, 10)
self.assertEqual(keys.shape, (10,))

def test_vmap_fold_in_shape(self):
key = self.seed_prng(73)
keys = vmap(lambda i: random.fold_in(key, i))(jnp.arange(3))
self.assertEqual(keys.shape, (3,))

def test_cannot_add(self):
key = self.seed_prng(73)
self.assertRaisesRegex(
TypeError, r'unsupported operand type\(s\) for \+*',
lambda: key + 47)

@skipIf(np.__version__ == "1.21.0",
"https://github.com/numpy/numpy/issues/19305")
def test_grad_of_prng_key(self):
key = self.seed_prng(73)
jax.grad(lambda x: 1., allow_int=True)(key) # does not crash

def test_random_split_doesnt_device_put_during_tracing(self):
return # this test doesn't apply to the RBG PRNG

def test_randint_out_of_range(self):
# TODO(mattjj): enable this test if/when RngBitGenerator supports it
raise SkipTest('8-bit types not supported with RBG PRNG')

def _sampler_unimplemented_with_rbg(*args, **kwargs):
# TODO(mattjj): enable these tests if/when RngBitGenerator supports them
raise SkipTest('8- and 16-bit types not supported with RBG PRNG')

for attr in dir(LaxRandomWithRBGPRNGTest):
if 'int8' in attr or 'int16' in attr or 'float16' in attr:
setattr(LaxRandomWithRBGPRNGTest, attr, _sampler_unimplemented_with_rbg)

def _sampler_unimplemented_with_custom_prng(*args, **kwargs):
raise SkipTest('sampler only implemented for default RNG')
Expand All @@ -1142,6 +1185,9 @@ def _sampler_unimplemented_with_custom_prng(*args, **kwargs):
if attr.startswith(test_prefix):
setattr(LaxRandomWithCustomPRNGTest, attr,
_sampler_unimplemented_with_custom_prng)
setattr(LaxRandomWithRBGPRNGTest, attr,
_sampler_unimplemented_with_custom_prng)



if __name__ == "__main__":
Expand Down

0 comments on commit 07083e7

Please sign in to comment.