
Error when lowering pallas kernel: 'jaxlib.triton.dialect' has no attribute 'permute' #19990

Closed
davisyoshida opened this issue Feb 27, 2024 · 5 comments
Labels
bug Something isn't working

Comments

davisyoshida (Contributor) commented Feb 27, 2024

Description

Running the following with JAX 0.4.25 causes an AttributeError:

import jax
import jax.numpy as jnp
import jax.experimental.pallas as pl

def do_dot(a, b, out_ref):
    out_ref[0, 0, :, :] = pl.dot(
        a[0],
        b[...],
        trans_b=True,
        allow_tf32=False
    )

@jax.jit
def run(a, b):
    N, block_size, in_dim = a.shape
    out_blocks = b.shape[0] // block_size
    return pl.pallas_call(
        do_dot,
        out_shape=jax.ShapeDtypeStruct(
            shape=(N, out_blocks, block_size, block_size),
            dtype=a.dtype
        ),
        grid=(N, out_blocks),
        in_specs=[
            pl.BlockSpec(
                lambda i, j: (i, 0, 0),
                (1, block_size, in_dim)
            ),
            pl.BlockSpec(
                lambda i, j: (j, 0),
                (block_size, in_dim)
            )
        ],
        out_specs=pl.BlockSpec(
            lambda i, j: (i, j, 0, 0),
            (1, 1, block_size, block_size)
        )
    )(a, b)

N = 32
block_size = 128
out_blocks = 4
in_dim = 32
a = jnp.ones((N, block_size, in_dim))
b = jnp.ones((block_size * out_blocks, in_dim))
print(run(a, b).shape)

The error:

Traceback (most recent call last):
  File ".../lib/python3.11/site-packages/jax/_src/pallas/triton/lowering.py", line 358, in lower_jaxpr_to_triton_ir
    outvals = rule(rule_ctx, *invals, **eqn.params)
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File ".../lib/python3.11/site-packages/jax/_src/pallas/triton/lowering.py", line 1855, in _dot_general_lowering
    b = tt_dialect.permute(b, (1, 0))
        ^^^^^^^^^^^^^^^^^^
AttributeError: module 'jaxlib.triton.dialect' has no attribute 'permute'

This code does work when I use the version installed via the instructions here.

System info (python version, jaxlib version, accelerator, etc.)

jax:    0.4.25
jaxlib: 0.4.25
numpy:  1.26.4
python: 3.11.6 (main, Nov 14 2023, 09:36:21) [GCC 13.2.1 20230801]
jax.devices (2 total, 2 local): [cuda(id=0) cuda(id=1)]
davisyoshida added the bug label on Feb 27, 2024
superbobry (Collaborator)

Thanks for reporting this, Davis. I've sent a PR with a fix. In the meantime, you can do

from jaxlib.triton import dialect
dialect.permute = dialect.trans  # alias the missing `permute` op to the existing `trans` op

to work around the issue.
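
For example, here is a minimal, untested sketch of applying the patch before anything is lowered, using a standalone kernel that hits the same dot lowering path as the reproducer above:

# Untested sketch: the alias must be in place before the first pallas_call is lowered.
from jaxlib.triton import dialect
dialect.permute = dialect.trans

import jax
import jax.numpy as jnp
import jax.experimental.pallas as pl

def kernel(x_ref, y_ref, o_ref):
    # Same dot-with-transposed-operand pattern as in the reproducer above.
    o_ref[...] = pl.dot(x_ref[...], y_ref[...], trans_b=True, allow_tf32=False)

x = jnp.ones((128, 32))
y = jnp.ones((128, 32))
out = pl.pallas_call(
    kernel,
    out_shape=jax.ShapeDtypeStruct((128, 128), x.dtype),
)(x, y)
print(out.shape)  # (128, 128)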

One caveat, though: the Triton version we have internally crashes if any of the dot operands is transposed. It is possible that the jaxlib version did not pick up the problematic upstream changes, but if it did, we'd have to wait until the issue is fixed in Triton. I started a discussion in their Slack.

copybara-service bot pushed a commit that referenced this issue Feb 27, 2024
See #19990 for a reproducer.

PiperOrigin-RevId: 610683353
copybara-service bot pushed a commit that referenced this issue Feb 27, 2024
See #19990 for a reproducer.

PiperOrigin-RevId: 610743180
davisyoshida (Contributor, Author)

> the Triton version we have internally crashes if any of the dot operands is transposed

Good to know; I'm sure that just saved me an hour of fruitless debugging. Thanks!

superbobry (Collaborator)

According to the folks on the Triton Slack channel, dot with transposed inputs is not supported on GPUs older than Turing. I was testing on a V100, which explains the crashes.
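
If you are on a pre-Turing GPU, one way to sidestep the transposed operand entirely is to transpose b on the host and drop trans_b. Here is an untested sketch of the reproducer rewritten that way (same shapes and BlockSpec API as above):

# Untested sketch: avoid the transposed dot operand by transposing b on the host
# and adjusting the BlockSpec, so the kernel only sees a plain (M, K) x (K, N) dot.
import jax
import jax.numpy as jnp
import jax.experimental.pallas as pl

def do_dot(a, b_t, out_ref):
    # The b_t block has shape (in_dim, block_size), so no trans_b is needed.
    out_ref[0, 0, :, :] = pl.dot(a[0], b_t[...], allow_tf32=False)

@jax.jit
def run(a, b):
    N, block_size, in_dim = a.shape
    out_blocks = b.shape[0] // block_size
    b_t = b.T  # (in_dim, block_size * out_blocks), transposed outside the kernel
    return pl.pallas_call(
        do_dot,
        out_shape=jax.ShapeDtypeStruct(
            shape=(N, out_blocks, block_size, block_size),
            dtype=a.dtype
        ),
        grid=(N, out_blocks),
        in_specs=[
            pl.BlockSpec(lambda i, j: (i, 0, 0), (1, block_size, in_dim)),
            pl.BlockSpec(lambda i, j: (0, j), (in_dim, block_size)),
        ],
        out_specs=pl.BlockSpec(
            lambda i, j: (i, j, 0, 0),
            (1, 1, block_size, block_size)
        )
    )(a, b_t)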

davisyoshida (Contributor, Author)

@superbobry The monkeypatch works for me, and thanks for making it so much easier to get Pallas up and running on GPU.

superbobry (Collaborator)

Thanks for confirming, @davisyoshida! I will close this bug as fixed.
