Skip to content

Commit

Permalink
Merge pull request #6899 from google:custom-rng
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 391913770
  • Loading branch information
jax authors committed Aug 20, 2021
2 parents c39c093 + 4eb437a commit 705d5ed
Show file tree
Hide file tree
Showing 11 changed files with 998 additions and 542 deletions.
2 changes: 1 addition & 1 deletion jax/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@
del _config_module

from ._src.config import (
config, enable_checks, check_tracer_leaks, checking_leaks,
config, enable_checks, check_tracer_leaks, checking_leaks, enable_custom_prng,
debug_nans, debug_infs, log_compiles, default_matmul_precision,
numpy_rank_promotion
)
Expand Down
9 changes: 9 additions & 0 deletions jax/_src/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -503,6 +503,15 @@ def update_thread_local_jit_state(**kw):
'computations. Logging is performed with `absl.logging` at WARNING '
'level.'))

enable_custom_prng = config.define_bool_state(
name='jax_enable_custom_prng',
default=False,
help=('Enables an internal upgrade that allows one to define custom '
'pseudo-random number generator implementations. This will '
'be enabled by default in future versions of JAX, at which point '
'disabling it will be considered deprecated. In a version '
'after that the flag will be removed altogether.'))

hlo_source_file_canonicalization_regex = config.define_string_state(
name='jax_hlo_source_file_canonicalization_regex',
default=None,
Expand Down
Loading

0 comments on commit 705d5ed

Please sign in to comment.