Make Pallas/GPU easier to install #18603
Comments
I thought it would be useful to post here that the following steps got me a working installation on top of the
the
and IIUC we would prefer to build Triton from https://github.com/openxla/triton. Hopefully once those PRs land the recipe could be simplified a little. Big ➕ to the sentiment of this issue: this should be easier! |
Yes please! |
+1! |
+1 as well, I spent quite a while installing different permutations of versions of things but I couldn't find one that worked. |
Another big +1 here, would really love to use Pallas but I cannot find a correct set of commands to install it correctly. |
Hi everyone, I was able to get a set of mutually compatible version pins. I hope this unblocks folks temporarily while we push on the longer-term solution (bundle Triton with jaxlib GPU). Versions:
Installation commands:
$ pip install --no-deps -IU --index-url https://aiinfra.pkgs.visualstudio.com/PublicPackages/_packaging/Triton-Nightly/pypi/simple/ triton-nightly==2.1.0.post20231216005823
$ pip install -IU --pre jax==0.4.24.dev20240104 -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html
$ pip install -IU --pre jaxlib[cuda12_pip]==0.4.24.dev20240103 -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_cuda12_releases.html
$ pip install --no-deps 'jax-triton @ git+https://github.com/jax-ml/jax-triton.git@7778c47c0a27c0988c914dce640dec61e44bbe8c'
I verified this worked with a simple kernel in Colab but haven't yet run more extensive tests. Please let me know if these work for you. |
Thanks @sharadmv! Appreciate the quick workaround. Here is a copy-pasteable version (note: you need to be using Python 3.9 or 3.10):
pip install --no-deps -IU --index-url https://aiinfra.pkgs.visualstudio.com/PublicPackages/_packaging/Triton-Nightly/pypi/simple/ triton-nightly==2.1.0.post20231216005823
pip install -IU --pre jax==0.4.24.dev20240104 -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html
pip install -IU --pre "jaxlib[cuda12_pip]==0.4.24.dev20240103" -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_cuda12_releases.html
pip install --no-deps 'jax-triton @ git+https://github.com/jax-ml/jax-triton.git@7778c47c0a27c0988c914dce640dec61e44bbe8c'
EDIT: Though for me, this still didn't work locally. Perhaps there is some interference with my local CUDA? |
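Since a mismatched pin is the most likely reason the commands above fail in a particular environment, it may help to compare what is actually installed against the pins. The following is a minimal sketch using only the standard library; the helper names `check_pins` and `python_supported` are hypothetical, and stripping a local version suffix such as `+cuda12...` before comparing is an assumption about how the installed jaxlib wheel reports its version.

```python
import sys
from importlib import metadata

# Pins copied from the install commands above.
PINS = {
    "triton-nightly": "2.1.0.post20231216005823",
    "jax": "0.4.24.dev20240104",
    "jaxlib": "0.4.24.dev20240103",
}


def check_pins(pins, get_version=metadata.version):
    """Return {package: (expected, found)} for each mismatch.

    `found` is None when the package is not installed.
    """
    mismatches = {}
    for name, expected in pins.items():
        try:
            found = get_version(name)
        except metadata.PackageNotFoundError:
            found = None
        # Assumption: a local suffix such as "+cuda12.cudnn89" should be
        # ignored when comparing against the pin.
        public = found.split("+")[0] if found is not None else None
        if public != expected:
            mismatches[name] = (expected, found)
    return mismatches


def python_supported(version_info=sys.version_info):
    """True for the interpreter versions mentioned above (3.9 or 3.10)."""
    return tuple(version_info[:2]) in {(3, 9), (3, 10)}


if __name__ == "__main__":
    print("Python 3.9/3.10:", python_supported())
    for name, (expected, found) in check_pins(PINS).items():
        print(f"{name}: expected {expected}, found {found}")
```

An empty result from `check_pins(PINS)` only means the versions match; it does not rule out problems with the local CUDA installation.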
What's your error? |
I get an error when running the Hello World example in the Quickstart. This is in a fresh environment, after running the block above, with CUDA 12.3.1-2 installed locally:
|
Can you try running this Python snippet?
import triton
from triton.compiler import code_generator as code_gen
from triton.compiler import compiler as tc
import triton.language as tl
from triton.runtime import autotuner
import triton._C.libtriton.triton as _triton
from triton.common.backend import get_backend
import triton.compiler.backends.cuda as cb |
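The same imports can be probed one module at a time, so that every failing import is reported instead of the script stopping at the first one. This is a minimal sketch: the `probe` helper is hypothetical, the module list is derived from the snippet above, and it only checks that each module imports (it does not check that `get_backend` exists inside `triton.common.backend`).

```python
import importlib

# Module paths taken from the import statements in the snippet above.
MODULES = [
    "triton",
    "triton.compiler.code_generator",
    "triton.compiler.compiler",
    "triton.language",
    "triton.runtime.autotuner",
    "triton._C.libtriton.triton",
    "triton.common.backend",
    "triton.compiler.backends.cuda",
]


def probe(modules, importer=importlib.import_module):
    """Try to import each module; map its name to None or an error string."""
    results = {}
    for name in modules:
        try:
            importer(name)
            results[name] = None
        except Exception as exc:  # report any failure, not only ImportError
            results[name] = f"{type(exc).__name__}: {exc}"
    return results


if __name__ == "__main__":
    for name, error in probe(MODULES).items():
        print(f"{name}: {'ok' if error is None else error}")
```

Posting the resulting output alongside the error message makes it easier to tell a missing package apart from a Triton build that lacks the CUDA backend.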
Sorry, I missed a warning during the install:
Installing the specified package, then reinstalling the packages worked :) |
Ah so the commands don't work exactly? |
I think it depends on your environment and whether you have that package preinstalled or not. I think adding |
I'll note that |
This is great, thank you Sharad! I was able to run the quick start items on an A100. I notice that it's pretty slow to compile, is that normal? Also, I'm getting this warning:
Is it okay to ignore or will it affect performance? |
Worked for me as well, thanks! |
@karan-dalal It's safe to ignore that warning. It is fixed (its verbosity was turned down) at head. |
We are now publishing containers from https://github.com/NVIDIA/JAX-Toolbox that include the Pallas dependencies. You can find these as special tags of the While we are still ironing out bugs here, it might be useful to use an older version such as https://github.com/NVIDIA/JAX-Toolbox/pkgs/container/jax/168056269?tag=nightly-pallas-2023-12-16. Hopefully this is an easy way to get started with Pallas on GPU! |
Note that there has been a small reorganisation to how these containers are labelled. |
Hi everyone, the latest jaxlib version (0.4.25) no longer requires either Please give it a try and let us know if you run into any issues. |
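For scripts that must support both older and newer setups, the version cutoff mentioned here can be encoded as a small check. This is a sketch under stated assumptions: the helper names are hypothetical, the assumption that the separately installed packages are Triton and jax-triton comes from the rest of this thread, and a `0.4.25.dev...` nightly is treated the same as the 0.4.25 release even though an early nightly may predate the change.

```python
import re


def release_tuple(version):
    """Extract the leading numeric release, e.g. '0.4.24.dev20240103' -> (0, 4, 24)."""
    match = re.match(r"(\d+)\.(\d+)\.(\d+)", version)
    if match is None:
        raise ValueError(f"unrecognized version string: {version!r}")
    return tuple(int(part) for part in match.groups())


def needs_separate_triton(jaxlib_version):
    """Assumption: jaxlib releases before 0.4.25 need separately installed
    Triton and jax-triton packages for Pallas on GPU."""
    return release_tuple(jaxlib_version) < (0, 4, 25)


if __name__ == "__main__":
    for v in ("0.4.23", "0.4.24.dev20240103", "0.4.25"):
        print(v, needs_separate_triton(v))
```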
Just wanted to commend the amazing effort Sergei put into this. He enabled a brand new lowering path for Pallas that purely goes through C++ by emitting the MLIR that Triton normally emits. It was a very tricky thing to get right! |
@superbobry, I am trying to debug a performance degradation, so I need to run some older versions (JAX v0.4.22 to v0.4.24). I've tried several combinations of triton and jax_triton but haven't been able to make them work for JAX v0.4.24. I can make it work for JAX 0.4.25, 0.4.23, and 0.4.21, just not v0.4.24. Could you tell me which versions of triton and jax_triton are compatible with JAX v0.4.24? I am running the example found at https://jax.readthedocs.io/en/latest/pallas/quickstart.html, and I encountered the following error:
I have often found that the error reporting is inaccurate and does not point to the actual source of the error. |
Unfortunately, I'm not sure. I tried finding a working combination for 0.4.24 myself some time ago, and couldn't get it to work. Do you suspect that the performance degradation is due to changes in Pallas? |
I'm trying to test out the splash attention kernel on GPU and getting the error "NotImplementedError: Scalar prefetch not supported in Triton lowering." However, my test works fine on TPU v4. I installed the latest version of JAX for GPU. Here's the test:
|
Hey @Sea-Snell, some features of Pallas TPU are indeed unsupported on GPU. Thus, the TPU-specific implementation of splash attention you are using does not currently work on GPU. |
Currently it's very difficult to install Pallas and jax_triton, since you have to get compatible versions of everything, and it's very finicky to work out which they are. We should make this easier!