
[Pallas] Error Lowering to Triton #19222

Closed
karan-dalal opened this issue Jan 5, 2024 · 2 comments

@karan-dalal

Description

I'm attempting to test taking gradients inside a Pallas kernel. This is related to discussion #19184, and my Pallas installation follows #18603.

Here's my playground code:

from functools import partial

import jax
from jax.experimental import pallas as pl
import jax.numpy as jnp
import numpy as np

def mlp(x, w1, w2, y):
  # Two-layer MLP with a mean-squared-error loss.
  hidden1 = jnp.dot(x, w1)
  hidden2 = jax.nn.relu(hidden1)
  output = jnp.dot(hidden2, w2)
  loss = jnp.mean(jnp.square(output - y))
  return loss

def gradient_kernel(x_ref, w1_ref, w2_ref, y_ref, output_ref):
  # Load full blocks, compute the gradient of the MLP loss w.r.t. w1, write it out.
  x, w1, w2, y = x_ref[...], w1_ref[...], w2_ref[...], y_ref[...]
  grad_mlp = jax.grad(mlp, argnums=1)
  grad_w1 = grad_mlp(x, w1, w2, y)
  output_ref[...] = grad_w1

@jax.jit
def gradient_kernel_test(x: jax.Array, w1: jax.Array, w2: jax.Array, y: jax.Array) -> jax.Array:
  return pl.pallas_call(gradient_kernel,
                        out_shape=jax.ShapeDtypeStruct(w1.shape, w1.dtype)
                        )(x, w1, w2, y)

key = jax.random.PRNGKey(0)

x = jnp.arange(5).astype(float)
w1 = jax.random.normal(key, shape=(5,5))
w2 = jax.random.normal(key, shape=(5,5))
y = jax.random.normal(key, shape=(5,))

# Use Pallas kernel or standard JAX.
use_kernel = True

if use_kernel:
  grad_w1 = gradient_kernel_test(x, w1, w2, y)
else:
  grad_mlp = jax.grad(mlp, argnums=1)
  grad_w1 = grad_mlp(x, w1, w2, y)

print(grad_w1)

When I use standard JAX, I can successfully get the gradient with respect to w1. But when I use the Pallas Kernel, I get the following error:

The above exception was the direct cause of the following exception:

jax.errors.SimplifiedTraceback: For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/juice5/scr5/yusun/data/karan/ttt-gpt/kernels/pallas_playground.py", line 38, in <module>
    grad_w1 = gradient_kernel_test(x, w1, w2, y)
  File "/nlp/scr/yusun/miniconda3/envs/pallas/lib/python3.10/site-packages/jax/_src/pallas/triton/lowering.py", line 1621, in pallas_call_lowering
    compilation_result = compile_jaxpr(
  File "/nlp/scr/yusun/miniconda3/envs/pallas/lib/python3.10/site-packages/jax/_src/pallas/triton/lowering.py", line 1564, in compile_jaxpr
    lowering_result = lower_jaxpr_to_triton_module(
  File "/nlp/scr/yusun/miniconda3/envs/pallas/lib/python3.10/site-packages/jax/_src/pallas/triton/lowering.py", line 280, in lower_jaxpr_to_triton_module
    () = lower_jaxpr_to_triton_ir(ctx, jaxpr, block_infos, *args)
  File "/nlp/scr/yusun/miniconda3/envs/pallas/lib/python3.10/site-packages/jax/_src/pallas/triton/lowering.py", line 339, in lower_jaxpr_to_triton_ir
    raise TritonLoweringException(
jax._src.pallas.triton.lowering.TritonLoweringException: Exception while lowering eqn:
  a:f32[5] = dot_general[
  dimension_numbers=(([0], [0]), ([], []))
  preferred_element_type=float32
] b c
With context:
  TritonLoweringRuleContext(context=TritonModuleContext(name='gradient_kernel', ir_context=<triton._C.libtriton.triton.ir.context object at 0x7f6ee99b8f30>, builder=<triton._C.libtriton.triton.ir.builder object at 0x7f6ee007c900>, module=<triton._C.libtriton.triton.ir.module object at 0x7f6ee007f9c0>, grid_mapping=GridMapping(grid=(), block_mappings=(None, None, None, None, None), mapped_dims=(), num_index_operands=0, num_scratch_operands=0), program_ids=[]), avals_in=[ShapedArray(float32[5]), ShapedArray(float32[5,5])], avals_out=[ShapedArray(float32[5])], block_infos=[None, None])
With inval shapes=[[constexpr[5]], [constexpr[5], constexpr[5]]]
With inval types=[<[5], fp32>, <[5, 5], fp32>]
In jaxpr:
{ lambda ; a:Ref{float32[5]} b:Ref{float32[5,5]} c:Ref{float32[5,5]} d:Ref{float32[5]}
    e:Ref{float32[5,5]}. let
    f:f32[5] <- a[:]
    g:f32[5,5] <- b[:,:]
    h:f32[5,5] <- c[:,:]
    i:f32[5] <- d[:]
    j:f32[5] = dot_general[
      dimension_numbers=(([0], [0]), ([], []))
      preferred_element_type=float32
    ] f g
    k:f32[5] = custom_jvp_call[
      call_jaxpr={ lambda ; l:f32[5]. let
          m:f32[5] = pjit[
            name=relu
            jaxpr={ lambda ; n:f32[5]. let o:f32[5] = max n 0.0 in (o,) }
          ] l
        in (m,) }
      jvp_jaxpr_thunk=<function _memoize.<locals>.memoized at 0x7f6ee0094ca0>
      num_consts=0
      symbolic_zeros=False
    ] j
    p:bool[5] = gt j 0.0
    q:f32[5] = dot_general[
      dimension_numbers=(([0], [0]), ([], []))
      preferred_element_type=float32
    ] k h
    r:f32[5] = sub q i
    s:f32[5] = integer_pow[y=1] r
    t:f32[5] = mul 2.0 s
    u:f32[] = div 1.0 5.0
    v:f32[5] = broadcast_in_dim[broadcast_dimensions=() shape=(5,)] u
    w:f32[5] = mul v t
    x:f32[5] = dot_general[
      dimension_numbers=(([0], [1]), ([], []))
      preferred_element_type=float32
    ] w h
    y:f32[5] = broadcast_in_dim[broadcast_dimensions=() shape=(5,)] 0.0
    z:bool[5] = eq p True
    ba:f32[5] = select_n z y x
    bb:f32[5,5] = dot_general[
      dimension_numbers=(([], []), ([], []))
      preferred_element_type=float32
    ] ba f
    bc:f32[5,5] = transpose[permutation=(1, 0)] bb
    e[:,:] <- bc
  in () }

Am I implementing my Pallas kernel incorrectly? Or are there some operations that are not yet "lower-able" to Triton? I would assume that since gradients are just matrix multiplications and standard JAX operations, they should be able to lower.
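
For reference, here is the same gradient written out by hand with jnp operations, mirroring the dot_general / relu-mask / outer-product structure visible in the jaxpr above. This is only an illustrative sketch (the helper name grad_w1_manual is made up), not code taken from the kernel:

def grad_w1_manual(x, w1, w2, y):
  # Forward pass.
  h1 = jnp.dot(x, w1)                    # (5,)
  h2 = jax.nn.relu(h1)                   # (5,)
  out = jnp.dot(h2, w2)                  # (5,)
  # Backward pass for the mean-squared-error loss.
  d_out = 2.0 * (out - y) / y.shape[0]   # d loss / d out
  d_h2 = jnp.dot(d_out, w2.T)            # back through the second matmul
  d_h1 = jnp.where(h1 > 0, d_h2, 0.0)    # back through the relu
  return jnp.outer(x, d_h1)              # d loss / d w1, shape (5, 5)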

Thank you!

What jax/jaxlib version are you using?

jax 0.4.24.dev20240104, jaxlib 0.4.24.dev20240103

Which accelerator(s) are you using?

GPU

Additional system info?

NumPy 1.26.3; Python 3.10.13 (main, Sep 11 2023, 13:44:35) [GCC 11.2.0]

NVIDIA GPU info

[Screenshot of nvidia-smi output, 2024-01-05]
karan-dalal added the bug (Something isn't working) label on Jan 5, 2024
@superbobry (Collaborator) commented Feb 26, 2024

Could you try re-running this with the latest jaxlib version (0.4.25)? Pallas should now do a slightly better job of reporting lowering errors on GPU.

My guess is that the block dimensions are not powers of 2, but with the newer version you should get an error message saying so.
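
For what it's worth, one quick way to check that hypothesis would be to rerun the repro from the original post with power-of-two sizes. Untested sketch, reusing gradient_kernel_test from above; the size 16 is an arbitrary illustrative choice, not a documented requirement:

dim = 16  # power of two, chosen for illustration only
key = jax.random.PRNGKey(0)
x = jnp.arange(dim).astype(jnp.float32)
w1 = jax.random.normal(key, shape=(dim, dim))
w2 = jax.random.normal(key, shape=(dim, dim))
y = jax.random.normal(key, shape=(dim,))
print(gradient_kernel_test(x, w1, w2, y))  # may or may not lower; untested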

superbobry self-assigned this on Feb 26, 2024
@superbobry (Collaborator)
I will close the issue for now. Please reopen if you are able to reproduce with jaxlib 0.4.25+.
