ptxas unsupported version: Is CUDA 11.8 support broken in HEAD install? #208

Open
markschoene opened this issue Jul 27, 2023 · 3 comments

@markschoene

Hey everyone, thanks for developing this library. I'd like to use block sparse matmul with jax, and this project seems to deliver just what we need 👍 Yet, I am having trouble getting examples/pallas/blocksparse_matmul.py to run. When installing from HEAD, I run into compatibility problems with ptxas. Help with this would be much appreciated.

As far as I understand, the ptxas shipped with CUDA 11.8 supports PTX version 7.8, while version 8.0 arrived with CUDA 12.0. As noted below, I installed jaxlib against a local CUDA 11.8. Given the traceback below, I am wondering whether jax-triton requires CUDA 12 in its current form. If so, I would be happy to get a recommendation for which jax and jax-triton commits to install from source.

Traceback

2023-07-27 13:17:05.422836: W external/xla/xla/service/gpu/runtime/support.cc:58] Intercepted XLA runtime error:
INTERNAL: ptxas exited with non-zero error code 65280, output: ptxas /tmp/tempfile-taurusi8032-f5aaecd1-56106-601761aa90836, line 5; fatal   : Unsupported .version 8.0; current version is '7.8'
ptxas fatal   : Ptx assembly aborted due to errors

2023-07-27 13:17:05.422953: E external/xla/xla/pjrt/pjrt_stream_executor_client.cc:2537] Execution of replica 0 failed: INTERNAL: Failed to execute XLA Runtime executable: run time error: custom call 'xla.gpu.custom_call' failed: ptxas exited with non-zero error code 65280, output: ptxas /tmp/tempfile-taurusi8032-f5aaecd1-56106-601761aa90836, line 5; fatal   : Unsupported .version 8.0; current version is '7.8'
ptxas fatal   : Ptx assembly aborted due to errors
; current tracing scope: custom-call.0; current profiling annotation: XlaModule:#hlo_module=jit_sdd_matmul,program_id=292#.
Traceback (most recent call last):
  File "/filesystem/my/workspace/Code/BlockSparse/jax_triton_blocksparse.py", line 227, in <module>
    app.run(main)
  File "/filesystem/my/workspace/pythonenv/jax/lib/python3.10/site-packages/absl/app.py", line 308, in run
    _run_main(main, args)
  File "/filesystem/my/workspace/pythonenv/jax/lib/python3.10/site-packages/absl/app.py", line 254, in _run_main
    sys.exit(main(argv))
  File "/filesystem/my/workspace/Code/BlockSparse/jax_triton_blocksparse.py", line 201, in main
    sdd_matmul(x, y, bn=bn, debug=True).block_until_ready()
  File "/filesystem/my/workspace/pythonenv/jax/lib/python3.10/site-packages/jax/_src/traceback_util.py", line 166, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/filesystem/my/workspace/pythonenv/jax/lib/python3.10/site-packages/jax/_src/pjit.py", line 250, in cache_miss
    outs, out_flat, out_tree, args_flat, jaxpr = _python_pjit_helper(
  File "/filesystem/my/workspace/pythonenv/jax/lib/python3.10/site-packages/jax/_src/pjit.py", line 163, in _python_pjit_helper
    out_flat = pjit_p.bind(*args_flat, **params)
  File "/filesystem/my/workspace/pythonenv/jax/lib/python3.10/site-packages/jax/_src/core.py", line 2578, in bind
    return self.bind_with_trace(top_trace, args, params)
  File "/filesystem/my/workspace/pythonenv/jax/lib/python3.10/site-packages/jax/_src/core.py", line 382, in bind_with_trace
    out = trace.process_primitive(self, map(trace.full_raise, args), params)
  File "/filesystem/my/workspace/pythonenv/jax/lib/python3.10/site-packages/jax/_src/core.py", line 814, in process_primitive
    return primitive.impl(*tracers, **params)
  File "/filesystem/my/workspace/pythonenv/jax/lib/python3.10/site-packages/jax/_src/pjit.py", line 1223, in _pjit_call_impl
    return xc._xla.pjit(name, f, call_impl_cache_miss, [], [], donated_argnums,
  File "/filesystem/my/workspace/pythonenv/jax/lib/python3.10/site-packages/jax/_src/pjit.py", line 1207, in call_impl_cache_miss
    out_flat, compiled = _pjit_call_impl_python(
  File "/filesystem/my/workspace/pythonenv/jax/lib/python3.10/site-packages/jax/_src/pjit.py", line 1163, in _pjit_call_impl_python
    return compiled.unsafe_call(*args), compiled
  File "/filesystem/my/workspace/pythonenv/jax/lib/python3.10/site-packages/jax/_src/profiler.py", line 314, in wrapper
    return func(*args, **kwargs)
  File "/filesystem/my/workspace/pythonenv/jax/lib/python3.10/site-packages/jax/_src/interpreters/pxla.py", line 1344, in __call__
    results = self.xla_executable.execute_sharded(input_bufs)
jax._src.traceback_util.UnfilteredStackTrace: jaxlib.xla_extension.XlaRuntimeError: INTERNAL: Failed to execute XLA Runtime executable: run time error: custom call 'xla.gpu.custom_call' failed: ptxas exited with non-zero error code 65280, output: ptxas /tmp/tempfile-taurusi8032-f5aaecd1-56106-601761aa90836, line 5; fatal   : Unsupported .version 8.0; current version is '7.8'
ptxas fatal   : Ptx assembly aborted due to errors
; current tracing scope: custom-call.0; current profiling annotation: XlaModule:#hlo_module=jit_sdd_matmul,program_id=292#.

The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.

--------------------

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/filesystem/my/workspace/Code/BlockSparse/jax_triton_blocksparse.py", line 227, in <module>
    app.run(main)
  File "/filesystem/my/workspace/pythonenv/jax/lib/python3.10/site-packages/absl/app.py", line 308, in run
    _run_main(main, args)
  File "/filesystem/my/workspace/pythonenv/jax/lib/python3.10/site-packages/absl/app.py", line 254, in _run_main
    sys.exit(main(argv))
  File "/filesystem/my/workspace/Code/BlockSparse/jax_triton_blocksparse.py", line 201, in main
    sdd_matmul(x, y, bn=bn, debug=True).block_until_ready()
jaxlib.xla_extension.XlaRuntimeError: INTERNAL: Failed to execute XLA Runtime executable: run time error: custom call 'xla.gpu.custom_call' failed: ptxas exited with non-zero error code 65280, output: ptxas /tmp/tempfile-taurusi8032-f5aaecd1-56106-601761aa90836, line 5; fatal   : Unsupported .version 8.0; current version is '7.8'
ptxas fatal   : Ptx assembly aborted due to errors
; current tracing scope: custom-call.0; current profiling annotation: XlaModule:#hlo_module=jit_sdd_matmul,program_id=292#.
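
The decisive detail in the traceback above is the ptxas message: the generated PTX declares .version 8.0, while the installed assembler accepts at most 7.8. A minimal sketch for extracting those two numbers from such an error string follows. The function name parse_ptx_version_mismatch and the regular expression are illustrative assumptions modeled on the message quoted above; they are not part of JAX, XLA, or Triton.

```python
import re
from typing import Optional, Tuple

Version = Tuple[int, int]

# Pattern modeled on the ptxas message quoted in the traceback
# (assumption: other ptxas releases may word the message differently).
_MISMATCH = re.compile(
    r"Unsupported \.version (\d+)\.(\d+); current version is '(\d+)\.(\d+)'"
)


def parse_ptx_version_mismatch(message: str) -> Optional[Tuple[Version, Version]]:
    """Return (requested, supported) PTX versions, or None if not found."""
    match = _MISMATCH.search(message)
    if match is None:
        return None
    req_major, req_minor, cur_major, cur_minor = map(int, match.groups())
    return (req_major, req_minor), (cur_major, cur_minor)


if __name__ == "__main__":
    # Fragment copied from the error message in the traceback.
    fragment = "line 5; fatal   : Unsupported .version 8.0; current version is '7.8'"
    print(parse_ptx_version_mismatch(fragment))
```

If the requested version is higher than the supported one, the PTX was emitted by a toolchain newer than the ptxas that ended up assembling it.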

Environment

I am working on a university cluster with modules installed for Python 3.10, CUDA 11.8, and cuDNN 8.6. Upon loading the modules, their binaries appear in $PATH, and $CUDA_HOME points to the CUDA directory (nvcc and ptxas are located there).

I installed the jaxlib build matching my CUDA and cuDNN versions:

pip install "jaxlib @ https://storage.googleapis.com/jax-releases/nightly/cuda118/jaxlib-0.4.14.dev20230714+cuda11.cudnn86-cp310-cp310-manylinux2014_x86_64.whl"

Then I installed jax-triton from HEAD, as recommended:

pip install 'jax-triton @ git+https://github.com/jax-ml/jax-triton.git'

Pip list shows (selection):

jax                               0.4.14
jax-triton                        0.1.4
jaxlib                            0.4.14.dev20230714+cuda11.cudnn86
triton-nightly                    2.1.0.dev20230714011643
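
For anyone comparing environments, the same selection can be collected from within Python. This is a minimal sketch using the standard library's importlib.metadata; the helper name installed_versions is hypothetical, and the package names are simply the ones from the listing above.

```python
from importlib import metadata
from typing import Dict, Optional, Sequence


def installed_versions(packages: Sequence[str]) -> Dict[str, Optional[str]]:
    """Map each distribution name to its installed version, or None if missing."""
    versions: Dict[str, Optional[str]] = {}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            # The distribution is not installed in this environment.
            versions[name] = None
    return versions


if __name__ == "__main__":
    names = ["jax", "jax-triton", "jaxlib", "triton-nightly"]
    for name, version in installed_versions(names).items():
        print(f"{name:<16} {version or 'not installed'}")
```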

Executing ptxas --version yields:

bash$ ptxas --version
ptxas: NVIDIA (R) Ptx optimizing assembler
Copyright (c) 2005-2022 NVIDIA Corporation
Built on Wed_Sep_21_10:31:59_PDT_2022
Cuda compilation tools, release 11.8, V11.8.89
Build cuda_11.8.r11.8/compiler.31833905_0
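
The same check can be scripted, which helps when several ptxas binaries may be on the system. The sketch below runs whichever ptxas is first on $PATH and reads the release number from its output. The function names are hypothetical, and the "release X.Y" pattern is an assumption taken from the output shown above.

```python
import re
import shutil
import subprocess
from typing import Optional, Tuple


def parse_ptxas_release(version_output: str) -> Optional[Tuple[int, int]]:
    """Extract (major, minor) from text such as 'release 11.8, V11.8.89'."""
    match = re.search(r"release (\d+)\.(\d+)", version_output)
    if match is None:
        return None
    return int(match.group(1)), int(match.group(2))


def ptxas_release(executable: str = "ptxas") -> Optional[Tuple[int, int]]:
    """Run `<executable> --version` and parse the CUDA release, if available."""
    path = shutil.which(executable)
    if path is None:
        return None  # no ptxas found on PATH
    result = subprocess.run(
        [path, "--version"], capture_output=True, text=True, check=False
    )
    return parse_ptxas_release(result.stdout)


if __name__ == "__main__":
    print(ptxas_release())
```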

(Side note: using stable jaxlib (0.4.13) with stable jax-triton (0.1.3) yields the already reported import error #157.)

@markschoene
Author

markschoene commented Aug 3, 2023

The issue might be caused by Triton depending on CUDA 12; see line 124 of Triton's setup.py. On a cluster whose NVIDIA drivers are CUDA 12-ready, the issue can be fixed by installing a recent nvcc from conda:
conda install cuda-nvcc -c nvidia
and then installing jaxlib, jax, and jax-triton.

@haoliuhl

haoliuhl commented Aug 4, 2023

I was able to fix the error on an A100 with CUDA 11.8 using the following commands:

conda install cuda-nvcc -c nvidia
pip install 'jax_triton[cuda11] @ git+https://github.com/jax-ml/jax-triton@75d47fb3bd320abfc22c98443138ef537404ea45'
pip install 'jax[cuda11_pip] @ git+https://github.com/google/jax@467adf35a615d3fd5635cd233e0a116bf65cc984' -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html

@ccoulombe

ccoulombe commented Dec 9, 2024

Note that I encountered this issue as well; see jax-ml/jax#25344.
In the end, the solution given there was to pass the correct ptxas to triton_lib.

The root cause is that the Triton wheel (3.1.0) ships its own ptxas binary (built with llvm+nvptx+cuda 12.4), which may differ from the system ptxas or from one that is already installed, as Mark pointed out.

Setting the environment variable TRITON_PTXAS_PATH=$CUDA_HOME/bin/ptxas fixes it, as it makes Triton prefer the installed ptxas over the bundled one.
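
The environment-variable fix can also be applied from inside a Python script. The sketch below assumes $CUDA_HOME is set as described in the original post and that the variable has to be in place before Triton looks for ptxas, so the helper is meant to be called before importing triton or jax_triton. The helper name use_system_ptxas is hypothetical; only TRITON_PTXAS_PATH comes from the comment above.

```python
import os
from pathlib import Path
from typing import Optional


def use_system_ptxas(cuda_home: Optional[str] = None) -> Optional[str]:
    """Point TRITON_PTXAS_PATH at <cuda_home>/bin/ptxas if that file exists.

    Falls back to $CUDA_HOME when no directory is passed. Returns the path
    that was set, or None if nothing was changed.
    """
    cuda_home = cuda_home or os.environ.get("CUDA_HOME")
    if not cuda_home:
        return None
    ptxas = Path(cuda_home) / "bin" / "ptxas"
    if not ptxas.is_file():
        return None
    os.environ["TRITON_PTXAS_PATH"] = str(ptxas)
    return str(ptxas)


if __name__ == "__main__":
    # Call this before importing triton or jax_triton (assumption: the
    # variable must be set before Triton resolves its ptxas binary).
    print(use_system_ptxas())
```

Exporting the variable in the shell or job script achieves the same thing and avoids any dependence on import order.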
