Skip to content

Commit

Permalink
Compile Triton kernels via XLA by default
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 609299269
  • Loading branch information
superbobry authored and jax authors committed Feb 22, 2024
1 parent f314f26 commit 0bf8ddd
Show file tree
Hide file tree
Showing 3 changed files with 10 additions and 5 deletions.
5 changes: 5 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,11 @@ Remember to align the itemized text with the first line of an item within a list
* {func}`jax.tree.transpose` (i.e. {func}`jax.tree_util.tree_transpose`) now accepts
`inner_treedef=None`, in which case the inner treedef will be automatically inferred.

* Changes
* Pallas now uses XLA instead of the Triton Python APIs to compile Triton
kernels. You can revert to the old behavior by setting the
`JAX_TRITON_COMPILE_VIA_XLA` environment variable to `"0"`.

* Deprecations & Removals
* {func}`jax.numpy.linalg.solve` now shows a deprecation warning for batched 1D
solves with `b.ndim > 1`. In the future these will be treated as batched 2D
Expand Down
4 changes: 2 additions & 2 deletions jax/_src/pallas/triton/lowering.py
Original file line number Diff line number Diff line change
Expand Up @@ -2609,8 +2609,8 @@ def _pallas_call_ttir_lowering(


_TRITON_COMPILE_VIA_XLA = config.DEFINE_bool(
"triton_compile_via_xla",
default=config.bool_env("JAX_TRITON_COMPILE_VIA_XLA", False),
"jax_triton_compile_via_xla",
default=config.bool_env("JAX_TRITON_COMPILE_VIA_XLA", True),
help="If True, Pallas delegates Triton kernel compilation to XLA.",
)

Expand Down
6 changes: 3 additions & 3 deletions tests/pallas/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,9 @@ jax_test(
"gpu_x32",
"gpu_a100_x32",
],
env = {
"JAX_TRITON_COMPILE_VIA_XLA": "0",
},
shard_count = 4,
deps = [
"//jax:pallas_gpu",
Expand Down Expand Up @@ -118,9 +121,6 @@ jax_test(
"gpu_x32",
"gpu_a100_x32",
],
env = {
"JAX_TRITON_COMPILE_VIA_XLA": "1",
},
shard_count = 4,
deps = [
"//jax:pallas_gpu",
Expand Down

0 comments on commit 0bf8ddd

Please sign in to comment.