Skip to content

Commit

Permalink
test under custom prng
Browse files Browse the repository at this point in the history
  • Loading branch information
froystig committed Aug 14, 2021
1 parent 5d1bc1b commit 7993c0e
Showing 1 changed file with 58 additions and 0 deletions.
58 changes: 58 additions & 0 deletions tests/random_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1051,5 +1051,63 @@ def test_randint_out_of_range(self):
self.assertGreater((r == 255).sum(), 0)


threefry_seed = jax._src.prng.threefry_seed
threefry_split = jax._src.prng.threefry_split
threefry_random_bits = jax._src.prng.threefry_random_bits
threefry_fold_in = jax._src.prng.threefry_fold_in

def _double_threefry_seed(seed):
return jnp.vstack([threefry_seed(seed + 1),
threefry_seed(seed + 2)])

def _double_threefry_split(key, num):
split0 = threefry_split(key[0], num)
split1 = threefry_split(key[1], num)
merge = jnp.vstack([jnp.expand_dims(split0.T, axis=0),
jnp.expand_dims(split1.T, axis=0)])
return merge.transpose((2, 0, 1))

def _double_threefry_random_bits(key, bit_width, shape):
bits0 = threefry_random_bits(key[0], bit_width, shape)
bits1 = threefry_random_bits(key[1], bit_width, shape)
return bits0 * bits1

def _double_threefry_fold_in(key, data):
return jnp.vstack([threefry_fold_in(key[0], data),
threefry_fold_in(key[1], data)])

double_threefry_prng_impl = prng.PRNGImpl(
key_shape=(2, 2),
seed=_double_threefry_seed,
split=_double_threefry_split,
random_bits=_double_threefry_random_bits,
fold_in=_double_threefry_fold_in)

@jtu.with_config(jax_numpy_rank_promotion="raise")
class LaxRandomWithCustomPRNGTest(LaxRandomTest):
def seed_prng(self, seed):
return prng.seed_with_impl(double_threefry_prng_impl, seed)

def _sampler_unimplemented_with_custom_prng(*args, **kwargs):
raise SkipTest('sampler only implemented for default RNG')

for test_prefix in [
'testDirichlet',
'testGamma',
'testGammaGrad',
'testGammaGradType',
'testGammaShape',
'testIssue1789',
'testPoisson',
'testPoissonBatched',
'testPoissonShape',
'testPoissonZeros',
]:
for attr in dir(LaxRandomTest):
if attr.startswith(test_prefix):
setattr(LaxRandomWithCustomPRNGTest, attr,
_sampler_unimplemented_with_custom_prng)


if __name__ == "__main__":
absltest.main(testLoader=jtu.JaxTestLoader())

0 comments on commit 7993c0e

Please sign in to comment.