Make Pallas/GPU easier to install #18603

Closed
hawkinsp opened this issue Nov 20, 2023 · 25 comments
Assignees
Labels
enhancement New feature or request NVIDIA GPU Issues specific to NVIDIA GPUs pallas Issues pertaining to Pallas (GPU or TPU)

Comments

@hawkinsp
Collaborator

Currently it's very difficult to install Pallas and jax_triton, since you have to get compatible versions of everything, and it's very finicky to work out which they are. We should make this easier!

@hawkinsp hawkinsp added enhancement New feature or request NVIDIA GPU Issues specific to NVIDIA GPUs labels Nov 20, 2023
@olupton
Contributor

olupton commented Dec 11, 2023

I thought it would be useful to post here that the following steps got me a working installation on top of the ghcr.io/nvidia/jax:nightly-2023-12-08 container (from https://github.com/NVIDIA/JAX-Toolbox):

# pip install --no-deps 'jax-triton@git+https://github.com/jax-ml/jax-triton.git@test_588045313' # e4fd5cb21f40c3991a204479c3a1a0e3f0194e91
# cd /opt
# git clone -b llvm-head https://github.com/openai/triton.git # ca78acaf1a6cf68e2af8a68762ec852534ff0610
# cd triton/
# pip install -e python # quite slow
# cd /opt/jax
# git checkout test_588045313 # 56c46f36df6cfaa40253e839203f950481ce97cd
# build-jax.sh # seems to rebuild more than I expected, might be improvable

the test_588045313 branches refer to these PRs:

and IIUC we would prefer to build Triton from https://github.com/openxla/triton. Hopefully once those PRs land the recipe could be simplified a little.

Big ➕ to the sentiment of this issue: this should be easier!

@mehdiataei
Contributor

Yes please!

@CoderPat

CoderPat commented Dec 15, 2023

+1!
I started to play with https://github.com/google/maxtext but immediately ran into Pallas/jax-triton errors, and after some effort I still haven't been able to get a working JAX build with Pallas/jax-triton.

@davisyoshida
Contributor

+1 as well, I spent quite a while installing different permutations of versions of things but I couldn't find one that worked.

@vvvm23

vvvm23 commented Jan 3, 2024

Another big +1 here, I would really love to use Pallas but I cannot find a set of commands that installs it correctly.

@sharadmv
Collaborator

sharadmv commented Jan 4, 2024

Hi everyone, I was able to get a set of mutually compatible version pins. I hope this unblocks folks temporarily while we push on the longer term solution (bundle Triton with jaxlib GPU).

Versions:

triton-nightly==2.1.0.post20231216005823
jax==0.4.24.dev20240104 -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html
jaxlib==0.4.24.dev20240103
jax-triton @ git+https://github.com/jax-ml/jax-triton.git@7778c47c0a27c0988c914dce640dec61e44bbe8c

Installation commands:

$ pip install --no-deps -IU --index-url https://aiinfra.pkgs.visualstudio.com/PublicPackages/_packaging/Triton-Nightly/pypi/simple/ triton-nightly==2.1.0.post20231216005823
$ pip install -IU --pre jax==0.4.24.dev20240104 -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html
$ pip install -IU --pre jaxlib[cuda12_pip]==0.4.24.dev20240103 -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_cuda12_releases.html
$ pip install --no-deps 'jax-triton @ git+https://github.com/jax-ml/jax-triton.git@7778c47c0a27c0988c914dce640dec61e44bbe8c'

I verified this worked with a simple kernel in colab but haven't yet run more extensive tests. Please let me know if these work for you.
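
Before debugging a kernel, it can help to confirm that the environment actually contains the pinned versions. The following is a minimal sketch, not part of the original recipe: the helper names `installed_version` and `check_pins` are hypothetical, and stripping a local version suffix (the part after `+`) is an assumption about how some jaxlib CUDA wheels label themselves. jax-triton is left out because it is pinned by commit rather than by version.

```python
from importlib import metadata

# Pins copied from the comment above; adjust them if you use another combination.
EXPECTED_VERSIONS = {
    "triton-nightly": "2.1.0.post20231216005823",
    "jax": "0.4.24.dev20240104",
    "jaxlib": "0.4.24.dev20240103",
}


def installed_version(dist_name):
    """Return the installed version of a distribution, or None if it is absent."""
    try:
        return metadata.version(dist_name)
    except metadata.PackageNotFoundError:
        return None


def check_pins(expected, lookup=installed_version):
    """Return {name: (expected, found)} for every distribution that mismatches."""
    mismatches = {}
    for name, wanted in expected.items():
        found = lookup(name)
        # Assumption: ignore a local version suffix ("+...") when comparing.
        public = found.split("+", 1)[0] if found is not None else None
        if public != wanted:
            mismatches[name] = (wanted, found)
    return mismatches


if __name__ == "__main__":
    problems = check_pins(EXPECTED_VERSIONS)
    for name, (wanted, found) in problems.items():
        print(f"{name}: expected {wanted}, found {found or 'not installed'}")
    if not problems:
        print("All pinned versions match.")
```

An empty result only says the versions match the pins; it does not prove that the GPU lowering works.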

@vvvm23

vvvm23 commented Jan 4, 2024

Thanks @sharadmv! Appreciate the quick workaround.

Here is a copy-pasteable version (note that you need to be using Python 3.9 or 3.10):

pip install --no-deps -IU --index-url https://aiinfra.pkgs.visualstudio.com/PublicPackages/_packaging/Triton-Nightly/pypi/simple/ triton-nightly==2.1.0.post20231216005823
pip install -IU --pre jax==0.4.24.dev20240104 -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html
pip install -IU --pre "jaxlib[cuda12_pip]==0.4.24.dev20240103" -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_cuda12_releases.html
pip install --no-deps 'jax-triton @ git+https://github.com/jax-ml/jax-triton.git@7778c47c0a27c0988c914dce640dec61e44bbe8c'

EDIT: Though for me, this still didn't work locally, perhaps some interference with my local CUDA?

@sharadmv
Collaborator

sharadmv commented Jan 4, 2024

What's your error?

@vvvm23

vvvm23 commented Jan 4, 2024

I get the error below when running the Hello World example in the Quickstart. This is in a fresh environment, after running the block above, with CUDA 12.3.1-2 installed locally:

2024-01-04 21:03:19.992205: W external/xla/xla/service/gpu/command_buffer_scheduling.cc:470] Removed command buffer support for CUBLAS as it's not supported with gpu toolkit version 12020 and driver version 12030. This might negatively impact peformance. To enable CUBLAS support in command buffers use cuda-compat package: https://docs.nvidia.com/deploy/cuda-compatibility/.

---------------------------------------------------------------------------
JaxStackTraceBeforeTransformation         Traceback (most recent call last)
File ~/.pyenv/versions/3.10.12/lib/python3.10/runpy.py:196, in _run_module_as_main()
    195     sys.argv[0] = mod_spec.origin
--> 196 return _run_code(code, main_globals, None,
    197                  "__main__", mod_spec)

File ~/.pyenv/versions/3.10.12/lib/python3.10/runpy.py:86, in _run_code()
     79 run_globals.update(__name__ = mod_name,
     80                    __file__ = fname,
     81                    __cached__ = cached,
   (...)
     84                    __package__ = pkg_name,
     85                    __spec__ = mod_spec)
---> 86 exec(code, run_globals)
     87 return run_globals

File ~/.pyenv/versions/3.10.12/envs/venv/lib/python3.10/site-packages/ipykernel_launcher.py:17
     15 from ipykernel import kernelapp as app
---> 17 app.launch_new_instance()

File ~/.pyenv/versions/3.10.12/envs/venv/lib/python3.10/site-packages/traitlets/config/application.py:1075, in launch_instance()
   1074 app.initialize(argv)
-> 1075 app.start()

File ~/.pyenv/versions/3.10.12/envs/venv/lib/python3.10/site-packages/ipykernel/kernelapp.py:701, in start()
    700 try:
--> 701     self.io_loop.start()
    702 except KeyboardInterrupt:

File ~/.pyenv/versions/3.10.12/envs/venv/lib/python3.10/site-packages/tornado/platform/asyncio.py:205, in start()
    204 def start(self) -> None:
--> 205     self.asyncio_loop.run_forever()

File ~/.pyenv/versions/3.10.12/lib/python3.10/asyncio/base_events.py:603, in run_forever()
    602 while True:
--> 603     self._run_once()
    604     if self._stopping:

File ~/.pyenv/versions/3.10.12/lib/python3.10/asyncio/base_events.py:1909, in _run_once()
   1908     else:
-> 1909         handle._run()
   1910 handle = None

File ~/.pyenv/versions/3.10.12/lib/python3.10/asyncio/events.py:80, in _run()
     79 try:
---> 80     self._context.run(self._callback, *self._args)
     81 except (SystemExit, KeyboardInterrupt):

File ~/.pyenv/versions/3.10.12/envs/venv/lib/python3.10/site-packages/ipykernel/kernelbase.py:534, in dispatch_queue()
    533 try:
--> 534     await self.process_one()
    535 except Exception:

File ~/.pyenv/versions/3.10.12/envs/venv/lib/python3.10/site-packages/ipykernel/kernelbase.py:523, in process_one()
    522         return
--> 523 await dispatch(*args)

File ~/.pyenv/versions/3.10.12/envs/venv/lib/python3.10/site-packages/ipykernel/kernelbase.py:429, in dispatch_shell()
    428     if inspect.isawaitable(result):
--> 429         await result
    430 except Exception:

File ~/.pyenv/versions/3.10.12/envs/venv/lib/python3.10/site-packages/ipykernel/kernelbase.py:767, in execute_request()
    766 if inspect.isawaitable(reply_content):
--> 767     reply_content = await reply_content
    769 # Flush output before sending the reply.

File ~/.pyenv/versions/3.10.12/envs/venv/lib/python3.10/site-packages/ipykernel/ipkernel.py:429, in do_execute()
    428 if accepts_params["cell_id"]:
--> 429     res = shell.run_cell(
    430         code,
    431         store_history=store_history,
    432         silent=silent,
    433         cell_id=cell_id,
    434     )
    435 else:

File ~/.pyenv/versions/3.10.12/envs/venv/lib/python3.10/site-packages/ipykernel/zmqshell.py:549, in run_cell()
    548 self._last_traceback = None
--> 549 return super().run_cell(*args, **kwargs)

File ~/.pyenv/versions/3.10.12/envs/venv/lib/python3.10/site-packages/IPython/core/interactiveshell.py:3051, in run_cell()
   3050 try:
-> 3051     result = self._run_cell(
   3052         raw_cell, store_history, silent, shell_futures, cell_id
   3053     )
   3054 finally:

File ~/.pyenv/versions/3.10.12/envs/venv/lib/python3.10/site-packages/IPython/core/interactiveshell.py:3106, in _run_cell()
   3105 try:
-> 3106     result = runner(coro)
   3107 except BaseException as e:

File ~/.pyenv/versions/3.10.12/envs/venv/lib/python3.10/site-packages/IPython/core/async_helpers.py:129, in _pseudo_sync_runner()
    128 try:
--> 129     coro.send(None)
    130 except StopIteration as exc:

File ~/.pyenv/versions/3.10.12/envs/venv/lib/python3.10/site-packages/IPython/core/interactiveshell.py:3311, in run_cell_async()
   3308 interactivity = "none" if silent else self.ast_node_interactivity
-> 3311 has_raised = await self.run_ast_nodes(code_ast.body, cell_name,
   3312        interactivity=interactivity, compiler=compiler, result=result)
   3314 self.last_execution_succeeded = not has_raised

File ~/.pyenv/versions/3.10.12/envs/venv/lib/python3.10/site-packages/IPython/core/interactiveshell.py:3493, in run_ast_nodes()
   3492     asy = compare(code)
-> 3493 if await self.run_code(code, result, async_=asy):
   3494     return True

File ~/.pyenv/versions/3.10.12/envs/venv/lib/python3.10/site-packages/IPython/core/interactiveshell.py:3553, in run_code()
   3552     else:
-> 3553         exec(code_obj, self.user_global_ns, self.user_ns)
   3554 finally:
   3555     # Reset our crash handler in place

Cell In[4], line 1
----> 1 add_vectors(jnp.arange(8), jnp.arange(8))

Cell In[3], line 7, in add_vectors()
      5 @jax.jit
      6 def add_vectors(x: jax.Array, y: jax.Array) -> jax.Array:
----> 7   return pl.pallas_call(add_vectors_kernel,
      8                         out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype))(x, y)

File ~/.pyenv/versions/3.10.12/envs/venv/lib/python3.10/site-packages/jax/_src/pallas/pallas_call.py:456, in wrapped()
    455 which_linear = (False,) * len(flat_args)
--> 456 out_flat = pallas_call_p.bind(
    457     *consts, *flat_args, jaxpr=jaxpr, name=name, which_linear=which_linear,
    458     in_shapes=tuple(jax.ShapeDtypeStruct(a.shape, a.dtype)
    459                     for a in flat_args),
    460     out_shapes=tuple(flat_out_shapes), debug=debug,
    461     interpret=interpret,
    462     grid_mapping=grid_mapping,
    463     input_output_aliases=tuple(input_output_aliases.items()),
    464     **compiler_params)
    465 out = tree_util.tree_unflatten(out_tree, out_flat)

JaxStackTraceBeforeTransformation: ValueError: Cannot lower pallas_call on platform: cuda. To use Pallas on GPU, please install Triton and JAX-Triton. To use Pallas on TPU, please install Jaxlib TPU and libtpu.

The preceding stack trace is the source of the JAX operation that, once transformed by JAX, triggered the following exception.

--------------------

The above exception was the direct cause of the following exception:

ValueError                                Traceback (most recent call last)
Cell In[4], line 1
----> 1 add_vectors(jnp.arange(8), jnp.arange(8))

    [... skipping hidden 23 frame]

File ~/.pyenv/versions/3.10.12/envs/venv/lib/python3.10/site-packages/jax/_src/pallas/pallas_call.py:415, in _pallas_call_default_lowering(ctx, interpret, *in_nodes, **params)
    411   raise ValueError("Only interpret mode is supported on CPU backend.")
    412 # If we are actually using a specific backend (GPU or TPU), we should have
    413 # already registered backend-specific lowerings. If we get this far, it means
    414 # those backends aren't present.
--> 415 raise ValueError(
    416     f"Cannot lower pallas_call on platform: {platform}. "
    417     "To use Pallas on GPU, please install Triton and JAX-Triton. "
    418     "To use Pallas on TPU, please install Jaxlib TPU and libtpu.")

ValueError: Cannot lower pallas_call on platform: cuda. To use Pallas on GPU, please install Triton and JAX-Triton. To use Pallas on TPU, please install Jaxlib TPU and libtpu.

@sharadmv
Collaborator

sharadmv commented Jan 4, 2024

Can you try running this Python snippet?

import triton
from triton.compiler import code_generator as code_gen
from triton.compiler import compiler as tc
import triton.language as tl
from triton.runtime import autotuner
import triton._C.libtriton.triton as _triton
from triton.common.backend import get_backend
import triton.compiler.backends.cuda as cb
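
If one of those imports fails, running them one at a time shows which module is the problem. The sketch below does that with `importlib`; the helper name `probe_imports` is hypothetical, and the module paths are simply the ones from the snippet above, so they may not exist in other Triton versions.

```python
import importlib

# Module paths taken from the snippet above; they match one Triton nightly
# and are not guaranteed to exist in other Triton versions.
TRITON_MODULES = [
    "triton",
    "triton.compiler.code_generator",
    "triton.compiler.compiler",
    "triton.language",
    "triton.runtime.autotuner",
    "triton._C.libtriton.triton",
    "triton.common.backend",
    "triton.compiler.backends.cuda",
]


def probe_imports(module_names):
    """Try each import and return {module_name: error message, or None on success}."""
    results = {}
    for name in module_names:
        try:
            importlib.import_module(name)
            results[name] = None
        except Exception as exc:  # ImportError, or any error raised at import time
            results[name] = f"{type(exc).__name__}: {exc}"
    return results


if __name__ == "__main__":
    for name, error in probe_imports(TRITON_MODULES).items():
        status = "ok" if error is None else f"FAILED ({error})"
        print(f"{name}: {status}")
```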

@vvvm23

vvvm23 commented Jan 4, 2024

Sorry, I missed a warning during the install:

jax-triton 0.1.4 requires absl-py>=1.4.0, which is not installed.

Installing the specified package, then reinstalling the packages worked :)

@sharadmv
Collaborator

sharadmv commented Jan 4, 2024

Ah so the commands don't work exactly?

@vvvm23

vvvm23 commented Jan 4, 2024

I think it depends on your environment and whether you have that package preinstalled or not. I think adding absl-py to the jax-triton install line should be sufficient, but I haven't time to validate that immediately.
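
Because `--no-deps` skips dependency resolution, a missing requirement such as absl-py only shows up as a warning. A rough way to catch it afterwards is to compare a package's declared requirements with what is installed. This is a sketch under stated assumptions: the helper names are hypothetical, requirements guarded by an environment marker are skipped, and only presence is checked, not version constraints.

```python
import re
from importlib import metadata


def requirement_name(requirement):
    """Extract the project name from a requirement string such as 'absl-py>=1.4.0'."""
    match = re.match(r"[A-Za-z0-9._-]+", requirement.strip())
    return match.group(0) if match else None


def is_installed(name):
    """Return True if a distribution with this name is installed."""
    try:
        metadata.version(name)
        return True
    except metadata.PackageNotFoundError:
        return False


def missing_from(requirements, installed=is_installed):
    """Return the requirement strings whose project is not installed."""
    missing = []
    for requirement in requirements:
        # Assumption: skip anything with an environment marker (e.g. extras),
        # since evaluating markers needs more than the standard library.
        if ";" in requirement:
            continue
        name = requirement_name(requirement)
        if name is not None and not installed(name):
            missing.append(requirement)
    return missing


def missing_requirements(dist_name):
    """Check the declared requirements of an installed distribution."""
    return missing_from(metadata.requires(dist_name) or [])


if __name__ == "__main__":
    try:
        print(missing_requirements("jax-triton"))
    except metadata.PackageNotFoundError:
        print("jax-triton is not installed")
```

Running `pip check` after the install reports the same kind of problem and also validates version constraints.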

@hawkinsp
Collaborator Author

hawkinsp commented Jan 4, 2024

I'll note that absl-py is not a dependency of JAX itself, but it is a dependency of JAX's tests, so it's likely that many JAX users don't have it installed while most JAX developers do.

@karan-dalal

This is great, thank you Sharad! I was able to run the quick start items on an A100. I notice that it's pretty slow to compile, is that normal? Also, I'm getting this warning:

2024-01-04 17:27:28.729972: W external/xla/xla/service/gpu/command_buffer_scheduling.cc:470] Removed command buffer support for CUBLAS as it's not supported with gpu toolkit version 12020 and driver version 12020. This might negatively impact peformance. To enable CUBLAS support in command buffers use cuda-compat package: https://docs.nvidia.com/deploy/cuda-compatibility/.

Is it okay to ignore or will it affect performance?

@davisyoshida
Contributor

Worked for me as well, thanks!

@hawkinsp
Collaborator Author

hawkinsp commented Jan 5, 2024

@karan-dalal It's safe to ignore that warning. It is fixed (its verbosity was turned down) at head.

@olupton
Contributor

olupton commented Jan 10, 2024

We are now publishing containers from https://github.com/NVIDIA/JAX-Toolbox that include the Pallas dependencies. You can find these as special tags of the jax container, for example the latest one is ghcr.io/nvidia/jax:latest-pallas; this is the same as ghcr.io/nvidia/jax:latest but with compatible versions of triton and jax-triton installed as well.

While we are still ironing out bugs here, it might be useful to use an older version such as https://github.com/NVIDIA/JAX-Toolbox/pkgs/container/jax/168056269?tag=nightly-pallas-2023-12-16.

Hopefully this is an easy way to get started with Pallas on GPU!

@olupton
Contributor

olupton commented Feb 9, 2024

Note that there has been a small reorganisation to how these containers are labelled.
The latest one right now is ghcr.io/nvidia/jax:pallas-2024-02-09, and they can be found under this link.

@superbobry
Collaborator

Hi everyone, the latest jaxlib version (0.4.25) no longer requires either the jax_triton or the triton package to compile Pallas kernels on GPU.

Please give it a try and let us know if you run into any issues.
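
A quick smoke test is the vector-add kernel from the Quickstart, the same `add_vectors` call that appears in the tracebacks earlier in this thread. The sketch below is a minimal version of it. The kernel body follows the usual Pallas pattern of reading and writing through refs, and the `INTERPRET` fallback for machines without an accelerator is an assumption added here so the script also runs on CPU.

```python
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl


def add_vectors_kernel(x_ref, y_ref, o_ref):
    # Read both inputs in full and write their elementwise sum to the output ref.
    o_ref[...] = x_ref[...] + y_ref[...]


# Assumption: fall back to interpret mode when only the CPU backend is available,
# since compiled Pallas kernels need a GPU or TPU.
INTERPRET = jax.default_backend() == "cpu"


@jax.jit
def add_vectors(x: jax.Array, y: jax.Array) -> jax.Array:
    return pl.pallas_call(
        add_vectors_kernel,
        out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
        interpret=INTERPRET,
    )(x, y)


if __name__ == "__main__":
    print("backend:", jax.default_backend())
    print(add_vectors(jnp.arange(8), jnp.arange(8)))
```

If this raises "Cannot lower pallas_call on platform: cuda", the installed jaxlib most likely predates the change described above.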

@sharadmv
Collaborator

Just wanted to commend the amazing effort Sergei put into this. He enabled a brand new lowering path for Pallas that purely goes through C++ by emitting the MLIR that Triton normally emits. It was a very tricky thing to get right!

@superbobry superbobry added the pallas Issues pertaining to Pallas (GPU or TPU) label Mar 12, 2024
@merrymercy
Contributor

merrymercy commented Mar 28, 2024

@superbobry, I am trying to debug a performance degradation, so I need to run some older versions (Jax v0.4.22-v0.4.24). I've tried several combinations of triton and jax_triton but haven't been able to make them work for Jax v0.4.24. I can make it work for Jax 0.4.25, 0.4.23, and 0.4.21, just not v0.4.24.

Could you inform me which versions of triton and jax_triton are compatible with Jax v0.4.24? I am running the example found at https://jax.readthedocs.io/en/latest/pallas/quickstart.html, and I encountered the following error:

Traceback (most recent call last):
  File "/root/test/test_pallas.py", line 18, in <module>
    add_vectors(jnp.arange(8), jnp.arange(8))
  File "/root/test/test_pallas.py", line 15, in add_vectors
    return pl.pallas_call(add_vectors_kernel,
  File "/usr/local/lib/python3.10/dist-packages/jax/_src/pallas/pallas_call.py", line 532, in wrapped
    out_flat = pallas_call_p.bind(
jax._src.source_info_util.JaxStackTraceBeforeTransformation: ValueError: Cannot lower pallas_call on platform: cuda. To use Pallas on GPU, please install Triton and JAX-Triton. To use Pallas on TPU, please install Jaxlib TPU and libtpu.

The preceding stack trace is the source of the JAX operation that, once transformed by JAX, triggered the following exception.

I have often found that the error reporting is not accurate and does not identify the actual source of the error.
I tried this commit of jax_triton (jax-ml/jax-triton@28ad476) and the triton commit mentioned in the commit message. It reports the error above.

@superbobry
Collaborator

Unfortunately, I'm not sure. I tried finding a working combination for 0.4.24 myself some time ago, and couldn't get it to work.

Do you suspect that the performance degradation is due to changes in Pallas?

@Sea-Snell

Sea-Snell commented May 18, 2024

I'm trying to test out the splash attention kernel on GPU and getting the error "NotImplementedError: Scalar prefetch not supported in Triton lowering." My test works fine on TPU-V4 however. I installed the latest version of jax for GPU.

here's the test:

from jax.experimental.pallas.ops.tpu.splash_attention import make_splash_mha, CausalMask, MultiHeadMask, SegmentIds
import jax.numpy as jnp
import jax

splash = make_splash_mha(
    mask=MultiHeadMask([CausalMask((128, 128)) for _ in range(8)]),
    head_shards=1,
    q_seq_shards=1,
)

qs = jax.random.normal(jax.random.PRNGKey(0), (8, 128, 256), dtype=jnp.float32)
ks = jax.random.normal(jax.random.PRNGKey(0), (8, 128, 256), dtype=jnp.float32)
vs = jax.random.normal(jax.random.PRNGKey(0), (8, 128, 256), dtype=jnp.float32)

segment_ids = SegmentIds(
    jnp.asarray(([0]*96)+([1]*32), dtype=jnp.int32),
    jnp.asarray(([0]*96)+([1]*32), dtype=jnp.int32),
)

output = splash(
    q=qs, k=ks, v=vs, segment_ids=segment_ids,
)

print(output.shape)

@superbobry
Collaborator

Hey @Sea-Snell, some features of Pallas TPU are indeed unsupported on GPU. Thus, the TPU-specific implementation of splash attention you are using does not currently work on GPU.
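
One way to surface this earlier than the lowering error is to check the active backend before building a TPU-only kernel. The helper below is a hypothetical sketch, not a JAX API: `require_backend` is an illustrative name, and it relies only on `jax.default_backend()` returning a string such as "cpu", "gpu" or "tpu".

```python
def require_backend(required, backend=None):
    """Raise a clear error unless the active JAX backend matches `required`.

    `backend` can be passed explicitly (useful in tests); otherwise it is
    read from jax.default_backend().
    """
    if backend is None:
        import jax  # imported lazily so the check can be exercised without JAX

        backend = jax.default_backend()
    if backend != required:
        raise RuntimeError(
            f"This kernel needs the {required!r} backend, but JAX is running on {backend!r}."
        )
    return backend


if __name__ == "__main__":
    # Guard the TPU-specific splash attention setup from the test above.
    try:
        require_backend("tpu")
        print("TPU backend found; the TPU-specific kernel can be built here.")
    except RuntimeError as exc:
        print(exc)
```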
