diff --git a/tests/random_test.py b/tests/random_test.py index 72fbdc0355f3..34a679e69484 100644 --- a/tests/random_test.py +++ b/tests/random_test.py @@ -1051,5 +1051,63 @@ def test_randint_out_of_range(self): self.assertGreater((r == 255).sum(), 0) +threefry_seed = jax._src.prng.threefry_seed +threefry_split = jax._src.prng.threefry_split +threefry_random_bits = jax._src.prng.threefry_random_bits +threefry_fold_in = jax._src.prng.threefry_fold_in + +def _double_threefry_seed(seed): + return jnp.vstack([threefry_seed(seed + 1), + threefry_seed(seed + 2)]) + +def _double_threefry_split(key, num): + split0 = threefry_split(key[0], num) + split1 = threefry_split(key[1], num) + merge = jnp.vstack([jnp.expand_dims(split0.T, axis=0), + jnp.expand_dims(split1.T, axis=0)]) + return merge.transpose((2, 0, 1)) + +def _double_threefry_random_bits(key, bit_width, shape): + bits0 = threefry_random_bits(key[0], bit_width, shape) + bits1 = threefry_random_bits(key[1], bit_width, shape) + return bits0 * bits1 + +def _double_threefry_fold_in(key, data): + return jnp.vstack([threefry_fold_in(key[0], data), + threefry_fold_in(key[1], data)]) + +double_threefry_prng_impl = prng.PRNGImpl( + key_shape=(2, 2), + seed=_double_threefry_seed, + split=_double_threefry_split, + random_bits=_double_threefry_random_bits, + fold_in=_double_threefry_fold_in) + +@jtu.with_config(jax_numpy_rank_promotion="raise") +class LaxRandomWithCustomPRNGTest(LaxRandomTest): + def seed_prng(self, seed): + return prng.seed_with_impl(double_threefry_prng_impl, seed) + +def _sampler_unimplemented_with_custom_prng(*args, **kwargs): + raise SkipTest('sampler only implemented for default RNG') + +for test_prefix in [ + 'testDirichlet', + 'testGamma', + 'testGammaGrad', + 'testGammaGradType', + 'testGammaShape', + 'testIssue1789', + 'testPoisson', + 'testPoissonBatched', + 'testPoissonShape', + 'testPoissonZeros', +]: + for attr in dir(LaxRandomTest): + if attr.startswith(test_prefix): + setattr(LaxRandomWithCustomPRNGTest, attr, + _sampler_unimplemented_with_custom_prng) + + if __name__ == "__main__": absltest.main(testLoader=jtu.JaxTestLoader())