Skip to content

Commit

Permalink
implement backwards-compatible behavior and enable custom PRNGs only …
Browse files Browse the repository at this point in the history
…conditionally

Introduce a config flag for upgrading to a world of custom PRNGs. The
flag defaults off, so that we can introduce custom PRNGs into the
codebase and allow downstream libraries time to upgrade.

Backwards compatible behavior is meant in an external sense. This does
not mean that our code is internally the same any longer.
  • Loading branch information
froystig committed Aug 18, 2021
1 parent 96b3b1e commit def2b8d
Show file tree
Hide file tree
Showing 7 changed files with 197 additions and 89 deletions.
2 changes: 1 addition & 1 deletion jax/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@
del _config_module

from ._src.config import (
config, enable_checks, check_tracer_leaks, checking_leaks,
config, enable_checks, check_tracer_leaks, checking_leaks, enable_custom_prng,
debug_nans, debug_infs, log_compiles, default_matmul_precision,
numpy_rank_promotion
)
Expand Down
9 changes: 9 additions & 0 deletions jax/_src/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -503,6 +503,15 @@ def update_thread_local_jit_state(**kw):
'computations. Logging is performed with `absl.logging` at WARNING '
'level.'))

enable_custom_prng = config.define_bool_state(
name='jax_enable_custom_prng',
default=False,
help=('Enables an internal upgrade that allows one to define custom '
'pseudo-random number generator implementations. This will '
'be enabled by default in future versions of JAX, at which point '
'disabling it will be considered deprecated. In a version '
'after that the flag will be removed altogether.'))

hlo_source_file_canonicalization_regex = config.define_string_state(
name='jax_hlo_source_file_canonicalization_regex',
default=None,
Expand Down
56 changes: 36 additions & 20 deletions jax/_src/prng.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from jax import numpy as jnp
from jax import tree_util
from jax._src.api import jit, vmap
from jax.config import config
from jax.lib import xla_bridge
from jax.lib import xla_client
from jax.lib import cuda_prng
Expand Down Expand Up @@ -98,27 +99,43 @@ def tree_unflatten(cls, impl, keys):
keys, = keys
return cls(impl, keys)

# TODO(frostig): Modify or remove after deprecation window. The
# change to consider is to return the leading shape, up to the
# individual key dimensions, i.e.:
# base_ndim = len(self.impl.key_shape)
# return self.keys.shape[:-base_ndim]
@property
def dtype(self):
# TODO(frostig): remove after deprecation window
if config.jax_enable_custom_prng:
raise AttributeError("'PRNGKeyArray' has no attribute 'dtype'")
else:
warnings.warn(
'deprecated `dtype` attribute of PRNG key arrays', FutureWarning)
return np.uint32

def shape(self):
warnings.warn(
'deprecated `shape` attribute of PRNG key arrays. In a future version '
'of JAX this attribute will be removed or its value may change.')
return self.keys.shape
# TODO(frostig): simplify once we always enable_custom_prng
if config.jax_enable_custom_prng:
return self._shape
else:
warnings.warn(
'deprecated `shape` attribute of PRNG key arrays. In a future version '
'of JAX this attribute will be removed or its value may change.',
FutureWarning)
return self.keys.shape

# TODO(frostig): remove after deprecation window
@property
def dtype(self):
warnings.warn('deprecated `dtype` attribute of PRNG key arrays')
return np.uint32
def _shape(self):
base_ndim = len(self.impl.key_shape)
return self.keys.shape[:-base_ndim]

def __iter__(self) -> Iterator['PRNGKeyArray']:
def _is_scalar(self):
base_ndim = len(self.impl.key_shape)
if self.keys.ndim == base_ndim:
return self.keys.ndim == base_ndim

def __len__(self):
if self._is_scalar():
raise TypeError('len() of unsized object')
return len(self.keys)

def __iter__(self) -> Iterator['PRNGKeyArray']:
if self._is_scalar():
raise TypeError('iteration over a 0-d single PRNG key')
return (PRNGKeyArray(self.impl, k) for k in iter(self.keys))

Expand All @@ -137,18 +154,17 @@ def __getitem__(self, idx) -> 'PRNGKeyArray':
f'but {len(idx)} were indexed')
return PRNGKeyArray(self.impl, self.keys[idx])

def fold_in(self, data: int) -> 'PRNGKeyArray':
def _fold_in(self, data: int) -> 'PRNGKeyArray':
return PRNGKeyArray(self.impl, self.impl.fold_in(self.keys, data))

def random_bits(self, bit_width, shape) -> jnp.ndarray:
def _random_bits(self, bit_width, shape) -> jnp.ndarray:
return self.impl.random_bits(self.keys, bit_width, shape)

def split(self, num: int) -> 'PRNGKeyArray':
def _split(self, num: int) -> 'PRNGKeyArray':
return PRNGKeyArray(self.impl, self.impl.split(self.keys, num))

def __repr__(self):
base_ndim = len(self.impl.key_shape)
arr_shape = self.keys.shape[:-base_ndim]
arr_shape = self._shape
pp_keys = pp('shape = ') >> pp(arr_shape)
if isinstance(self.impl, PRNGImpl):
pp_impl = self.impl.pprint()
Expand Down
Loading

0 comments on commit def2b8d

Please sign in to comment.