ptxas unsupported version: Is CUDA 11.8 support broken in HEAD install? #208
Comments
The issue might be due to Triton depending on CUDA 12; see line 124 of triton's setup.py. On a cluster with CUDA 12-ready NVIDIA drivers, the issue can be fixed by installing a recent nvcc from conda:
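The conda command itself was not captured in this excerpt; a minimal sketch of what such an install could look like, assuming the cuda-nvcc package from the nvidia channel (this package/channel choice is my assumption, not the commenter's original command):

# Assumed reconstruction (not the original command): pull a recent nvcc/ptxas
# from conda's nvidia channel so it shadows the older system toolchain
conda install -c nvidia cuda-nvcc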
I was able to fix the error on an A100 with CUDA 11.8 with the following commands:
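The commands were likewise not captured here; one plausible, hedged reconstruction is to install a CUDA 11 build of the compiler tools from PyPI and put its ptxas first on PATH (the package name nvidia-cuda-nvcc-cu11 and the wheel layout below are assumptions, not the commenter's exact steps):

# Assumed reconstruction, not the commenter's exact commands
pip install nvidia-cuda-nvcc-cu11
# the wheel places ptxas under .../site-packages/nvidia/cuda_nvcc/bin (layout assumed);
# prepend that directory to PATH so it is found before the system ptxas
export PATH="$(python -c 'import site; print(site.getsitepackages()[0])')/nvidia/cuda_nvcc/bin:$PATH"
ptxas --version   # should now report the newer, pip-installed version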
Note that I encountered this issue as well, see jax-ml/jax#25344. Originally, this comes from the triton wheel (3.1.0), which ships a ptxas binary (built with llvm+nvptx against CUDA 12.4) that may differ from the system one or from an already-installed one, as Mark pointed out. Setting the environment variable mentioned in that thread works around it (the variable name is truncated here; a hedged sketch follows).
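Assuming the truncated variable is Triton's TRITON_PTXAS_PATH override (my assumption, based on the linked jax-ml/jax#25344 discussion), pointing it at the system toolchain would look roughly like this:

# Assumption: the truncated variable is TRITON_PTXAS_PATH, which makes Triton
# use the given ptxas instead of the one bundled inside its wheel
export TRITON_PTXAS_PATH="$CUDA_HOME/bin/ptxas"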
Hey everyone, thanks for developing this library. I'd like to use block-sparse matmul with jax, and this project seems to deliver just what we need 👍 Yet, I am having trouble getting examples/pallas/blocksparse_matmul.py to run. When installing from HEAD, I run into compatibility problems with ptxas; help with this would be much appreciated. As far as I understand, the ptxas shipped with CUDA 11.8 targets PTX ISA 7.8, while CUDA 12.0 targets 8.0. As noted below, I installed jaxlib against local CUDA 11.8. Considering the traceback below, I am wondering whether jax-triton requires CUDA 12 in its current form. If so, I would be happy to get a recommendation for jax and jax-triton commits to install from source.
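For context, this is roughly how one can compare the ptxas picked up from the environment with the one bundled in the installed triton wheel (the wheel-internal location is my assumption; it varies between triton releases):

# System toolchain (from the CUDA 11.8 module)
which ptxas
ptxas --version
# Locate the installed triton package, then look for a bundled ptxas under it
# (exact subdirectory is an assumption and differs across triton versions)
find "$(python -c 'import triton, os; print(os.path.dirname(triton.__file__))')" -name ptxas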
Traceback
Environment
Working on a university cluster with installed modules for Python 3.10, CUDA 11.8, and cuDNN 8.6. Upon loading the modules, they appear on $PATH, and $CUDA_HOME is properly set to the toolkit directory (e.g. nvcc and ptxas are located there).
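For completeness, verifying which toolchain the shell actually resolves can be done with standard commands like:

# Confirm the loaded module's CUDA 11.8 toolchain is the one on PATH
echo "$CUDA_HOME"
which nvcc ptxas
nvcc --version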
I installed jaxlib matching my CUDA version:
pip install "jaxlib @ https://storage.googleapis.com/jax-releases/nightly/cuda118/jaxlib-0.4.14.dev20230714+cuda11.cudnn86-cp310-cp310-manylinux2014_x86_64.whl"
Then I installed jax-triton from HEAD, as recommended:
pip install 'jax-triton @ git+https://github.com/jax-ml/jax-triton.git'
Pip list shows (selection):
Executing ptxas --version yields:
(Side note: using the stable jaxlib (0.4.13) and the stable jax-triton (0.1.3) instead yields the already reported import error #157.)