
add experimental RngBitGenerator ("RBG") PRNG #8067

Merged (1 commit) · Oct 2, 2021

Conversation

@mattjj (Collaborator) commented Oct 2, 2021

Builds on the awesome #6899 upgrade.

This is an experimental API; moreover, because the RngBitGenerator is not guaranteed to be stable across compiler versions (or backends, or shardings), let's state up front that this JAX PRNG implementation may not be stable across JAX versions.

Even without that kind of stability, this PRNG is still useful: unlike effectful RNG primitives such as lax.rng_uniform, at a given jax/jaxlib version this RBG PRNG works correctly with lax.scan and jax.checkpoint (see the examples below), while still potentially being more performant on some platforms than JAX's standard PRNG.

Thanks to #6899, it's super easy to use: just adapt key creation to use prng.seed_with_impl and then use the jax.random API as usual:

In [1]: from jax import prng

In [2]: key = prng.seed_with_impl(prng.rbg_prng_impl, 87)

In [3]: import jax

In [4]: key, subkey = jax.random.split(key)

In [5]: jax.random.uniform(subkey)
Out[5]: DeviceArray(0.62572646, dtype=float32)
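
For reference, here's a minimal sketch (not from the PR) of the "as usual" part: once created, the RBG key should flow through the other jax.random samplers just like a standard key.

from jax import prng
from jax import random

# hedged sketch: an RBG-impl key acts as a drop-in key for the
# usual jax.random samplers, same as a standard threefry key
key = prng.seed_with_impl(prng.rbg_prng_impl, 87)
key, k1, k2 = random.split(key, 3)
print(random.normal(k1, (3,)))          # normal draws under the RBG impl
print(random.bernoulli(k2, 0.5, (3,)))  # other samplers work the same way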

cc @levskaya @jekbradbury @zhangqiaorjc

Here's an illustration of how things can go horribly wrong with the effectful lax.rng_uniform: because the draw doesn't depend on any loop-carried value, transformations like grad and remat can treat it as loop-invariant and hoist a single draw out of the scan.

import jax
from jax import lax
import jax.numpy as jnp

def scanned_fun(_, x):
  r = lax.rng_uniform(0., 1., ())
  return None, x * r

# without differentiation we're okay
_, ys = lax.scan(scanned_fun, None, jnp.ones(5))
print(ys)
# [0.59189475 0.00664103 0.17746115 0.23565137 0.69601536]

# grad-of-scan-of-effectful-rng = danger
gs = jax.grad(lambda xs: lax.scan(scanned_fun, None, xs)[1].sum())(jnp.ones(5))
print(gs)
# [0.90172756 0.90172756 0.90172756 0.90172756 0.90172756]

# remat doesn't save you
gs = jax.grad(lambda xs: lax.scan(
    jax.remat(scanned_fun), None, xs)[1].sum())(jnp.ones(5))
print(gs)
# [0.8394314 0.8394314 0.8394314 0.8394314 0.8394314]

# forward pass is messed up too (though only after differentiation)
ys, vjp = jax.vjp(lambda xs: lax.scan(scanned_fun, None, xs)[1], jnp.ones(5))
gs, = vjp(jnp.ones(5))
print(ys)
print(gs)
# [0.32732797 0.32732797 0.32732797 0.32732797 0.32732797]
# [0.32732797 0.32732797 0.32732797 0.32732797 0.32732797]
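
To see the effectfulness itself, here's a minimal sketch (not from the PR): lax.rng_uniform draws from hidden backend state rather than from an explicit key, so two textually identical calls within one computation typically yield different values, and results need not be reproducible across runs.

import jax
from jax import lax

@jax.jit
def two_draws():
  # each rng_uniform op advances hidden backend RNG state,
  # so the two draws are independent rather than equal
  return lax.rng_uniform(0., 1., ()), lax.rng_uniform(0., 1., ())

a, b = two_draws()
print(a, b)  # typically two different values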

Here's RBG helping us out:

import jax
from jax import lax
from jax import prng
from jax import random
import jax.numpy as jnp

def scanned_fun(key, x):
  key, subkey = random.split(key)
  r = random.uniform(subkey, ())
  return key, x * r

key = prng.seed_with_impl(prng.rbg_prng_impl, 87)
_, ys = lax.scan(scanned_fun, key, jnp.ones(5))
print(ys)
# [0.62572646 0.34407544 0.94847405 0.15406036 0.3797592 ]

gs = jax.grad(lambda xs: lax.scan(scanned_fun, key, xs)[1].sum())(jnp.ones(5))
print(gs)
# [0.62572646 0.34407544 0.94847405 0.15406036 0.3797592 ]

gs = jax.grad(lambda xs: lax.scan(jax.remat(scanned_fun), key, xs)[1].sum())(jnp.ones(5))
print(gs)
# [0.62572646 0.34407544 0.94847405 0.15406036 0.3797592 ]


ys, vjp = jax.vjp(lambda xs: lax.scan(scanned_fun, key, xs)[1], jnp.ones(5))
gs, = vjp(jnp.ones(5))
print(ys)
print(gs)
# [0.62572646 0.34407544 0.94847405 0.15406036 0.3797592 ]
# [0.62572646 0.34407544 0.94847405 0.15406036 0.3797592 ]

Ah, precisely the kind of reproducibility we want: different across iterations, the same across the forward and backward passes (and repeatable too!).
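
And to check the "repeatable" part, a minimal sketch (not from the PR): at a fixed jax/jaxlib version and backend, re-seeding with the same seed should reproduce the same stream.

import jax.numpy as jnp
from jax import lax
from jax import prng
from jax import random

def scanned_fun(key, x):
  key, subkey = random.split(key)
  return key, x * random.uniform(subkey, ())

# same seed, same stream: both scans should produce identical outputs
_, ys1 = lax.scan(scanned_fun, prng.seed_with_impl(prng.rbg_prng_impl, 87), jnp.ones(5))
_, ys2 = lax.scan(scanned_fun, prng.seed_with_impl(prng.rbg_prng_impl, 87), jnp.ones(5))
print(jnp.array_equal(ys1, ys2))  # True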

@mattjj requested a review from froystig on Oct 2, 2021
jax/_src/lax/lax.py: review thread (outdated, resolved)
@froystig (Member) left a comment:


nice!!

Labels: cla: yes · pull ready