Error when lowering pallas kernel: 'jaxlib.triton.dialect' has no attribute 'permute' #19990
Comments
Thanks for reporting this, Davis. I sent a PR fixing this. In the meantime you can do
to work around the issue. One caveat, though: the Triton version we have internally crashes if any of the dot operands is transposed. It is possible that the jaxlib version did not pick up the problematic upstream changes; if it did, we'd have to wait until the issue is fixed in Triton. I started a discussion in their Slack.
See #19990 for a reproducer. PiperOrigin-RevId: 610683353
See #19990 for a reproducer. PiperOrigin-RevId: 610743180
Good to know, I'm sure that just saved me an hour of fruitless debugging. Thanks!
According to the folks on the Triton Slack channel, dot with transposed inputs is not supported on GPUs below Turing. I was testing on V100, so this seems to explain the crashes.
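For anyone who wants to guard against this before launching a kernel, here is a minimal sketch of a compute-capability check. It assumes "below Turing" means a CUDA compute capability lower than 7.5 (V100 is 7.0), and that the GPU device object exposes a `compute_capability` string such as `"7.0"`. The helper names are hypothetical and not part of JAX or Pallas.

```python
from typing import Optional, Tuple

# Assumption: Turing corresponds to CUDA compute capability 7.5.
TURING: Tuple[int, int] = (7, 5)


def parse_compute_capability(cc: str) -> Tuple[int, int]:
    """Turn a string such as "7.0" into a comparable (major, minor) tuple."""
    major, _, minor = cc.partition(".")
    return int(major), int(minor or 0)


def supports_transposed_dot(cc: str) -> bool:
    """Hypothetical helper: True if the GPU is Turing or newer."""
    return parse_compute_capability(cc) >= TURING


def first_device_supports_transposed_dot() -> Optional[bool]:
    """Check the first visible JAX device; None if no capability is reported."""
    import jax  # imported lazily so the helpers above work without JAX

    device = jax.devices()[0]
    # Assumption: CUDA devices expose `compute_capability`; other backends may not.
    cc = getattr(device, "compute_capability", None)
    return None if cc is None else supports_transposed_dot(str(cc))
```

On a V100 this check would report that transposed dot operands are not expected to work, which matches the crashes described above.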
@superbobry The monkeypatch works for me, and thanks for making it so much easier to get pallas up and running on GPU
Thanks for confirming @davisyoshida! I will close this bug as fixed.
Description
Running the following with JAX 0.4.25 causes an AttributeError:
The error:
This code does work when I use the version installed via the instructions here.
System info (python version, jaxlib version, accelerator, etc.)