I'm attempting to test out taking gradients inside a Pallas kernel. This is related to discussion 19184; my Pallas installation follows #18603.
Here's my playground code:
from functools import partial
import jax
from jax.experimental import pallas as pl
import jax.numpy as jnp
import numpy as np


def mlp(x, w1, w2, y):
    hidden1 = jnp.dot(x, w1)
    hidden2 = jax.nn.relu(hidden1)
    output = jnp.dot(hidden2, w2)
    loss = jnp.mean(jnp.square(output - y))
    return loss


def gradient_kernel(x_ref, w1_ref, w2_ref, y_ref, output_ref):
    x, w1, w2, y = x_ref[...], w1_ref[...], w2_ref[...], y_ref[...]
    grad_mlp = jax.grad(mlp, argnums=1)
    grad_w1 = grad_mlp(x, w1, w2, y)
    output_ref[...] = grad_w1


@jax.jit
def gradient_kernel_test(x: jax.Array, w1: jax.Array, w2: jax.Array, y: jax.Array) -> jax.Array:
    return pl.pallas_call(
        gradient_kernel,
        out_shape=jax.ShapeDtypeStruct(w1.shape, w1.dtype),
    )(x, w1, w2, y)


key = jax.random.PRNGKey(0)
x = jnp.arange(5).astype(float)
w1 = jax.random.normal(key, shape=(5, 5))
w2 = jax.random.normal(key, shape=(5, 5))
y = jax.random.normal(key, shape=(5,))

# Use Pallas kernel or standard JAX.
use_kernel = True
if use_kernel:
    grad_w1 = gradient_kernel_test(x, w1, w2, y)
else:
    grad_mlp = jax.grad(mlp, argnums=1)
    grad_w1 = grad_mlp(x, w1, w2, y)
print(grad_w1)
When I use standard JAX, I can successfully get the gradient with respect to w1. But when I use the Pallas kernel, I get the following error:
The above exception was the direct cause of the following exception:
jax.errors.SimplifiedTraceback: For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "/juice5/scr5/yusun/data/karan/ttt-gpt/kernels/pallas_playground.py", line 38, in <module>
grad_w1 = gradient_kernel_test(x, w1, w2, y)
File "/nlp/scr/yusun/miniconda3/envs/pallas/lib/python3.10/site-packages/jax/_src/pallas/triton/lowering.py", line 1621, in pallas_call_lowering
compilation_result = compile_jaxpr(
File "/nlp/scr/yusun/miniconda3/envs/pallas/lib/python3.10/site-packages/jax/_src/pallas/triton/lowering.py", line 1564, in compile_jaxpr
lowering_result = lower_jaxpr_to_triton_module(
File "/nlp/scr/yusun/miniconda3/envs/pallas/lib/python3.10/site-packages/jax/_src/pallas/triton/lowering.py", line 280, in lower_jaxpr_to_triton_module
() = lower_jaxpr_to_triton_ir(ctx, jaxpr, block_infos, *args)
File "/nlp/scr/yusun/miniconda3/envs/pallas/lib/python3.10/site-packages/jax/_src/pallas/triton/lowering.py", line 339, in lower_jaxpr_to_triton_ir
raise TritonLoweringException(
jax._src.pallas.triton.lowering.TritonLoweringException: Exception while lowering eqn:
a:f32[5] = dot_general[
dimension_numbers=(([0], [0]), ([], []))
preferred_element_type=float32
] b c
With context:
TritonLoweringRuleContext(context=TritonModuleContext(name='gradient_kernel', ir_context=<triton._C.libtriton.triton.ir.context object at 0x7f6ee99b8f30>, builder=<triton._C.libtriton.triton.ir.builder object at 0x7f6ee007c900>, module=<triton._C.libtriton.triton.ir.module object at 0x7f6ee007f9c0>, grid_mapping=GridMapping(grid=(), block_mappings=(None, None, None, None, None), mapped_dims=(), num_index_operands=0, num_scratch_operands=0), program_ids=[]), avals_in=[ShapedArray(float32[5]), ShapedArray(float32[5,5])], avals_out=[ShapedArray(float32[5])], block_infos=[None, None])
With inval shapes=[[constexpr[5]], [constexpr[5], constexpr[5]]]
With inval types=[<[5], fp32>, <[5, 5], fp32>]
In jaxpr:
{ lambda ; a:Ref{float32[5]} b:Ref{float32[5,5]} c:Ref{float32[5,5]} d:Ref{float32[5]}
e:Ref{float32[5,5]}. let
f:f32[5] <- a[:]
g:f32[5,5] <- b[:,:]
h:f32[5,5] <- c[:,:]
i:f32[5] <- d[:]
j:f32[5] = dot_general[
dimension_numbers=(([0], [0]), ([], []))
preferred_element_type=float32
] f g
k:f32[5] = custom_jvp_call[
call_jaxpr={ lambda ; l:f32[5]. let
m:f32[5] = pjit[
name=relu
jaxpr={ lambda ; n:f32[5]. let o:f32[5] = max n 0.0 in (o,) }
] l
in (m,) }
jvp_jaxpr_thunk=<function _memoize.<locals>.memoized at 0x7f6ee0094ca0>
num_consts=0
symbolic_zeros=False
] j
p:bool[5] = gt j 0.0
q:f32[5] = dot_general[
dimension_numbers=(([0], [0]), ([], []))
preferred_element_type=float32
] k h
r:f32[5] = sub q i
s:f32[5] = integer_pow[y=1] r
t:f32[5] = mul 2.0 s
u:f32[] = div 1.0 5.0
v:f32[5] = broadcast_in_dim[broadcast_dimensions=() shape=(5,)] u
w:f32[5] = mul v t
x:f32[5] = dot_general[
dimension_numbers=(([0], [1]), ([], []))
preferred_element_type=float32
] w h
y:f32[5] = broadcast_in_dim[broadcast_dimensions=() shape=(5,)] 0.0
z:bool[5] = eq p True
ba:f32[5] = select_n z y x
bb:f32[5,5] = dot_general[
dimension_numbers=(([], []), ([], []))
preferred_element_type=float32
] ba f
bc:f32[5,5] = transpose[permutation=(1, 0)] bb
e[:,:] <- bc
in () }
Am I implementing my Pallas kernel incorrectly? Or are there some operations that are not yet "lower-able" to Triton? I would assume that since the gradients are just matrix multiplications and other standard JAX operations, they should be able to lower.
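For reference, the gradient with respect to w1 can be written out with explicit matrix products. This is a minimal sketch that mirrors the jaxpr in the error above; the helper name manual_grad_w1 is just for illustration:

def manual_grad_w1(x, w1, w2, y):
    # Forward pass, mirroring mlp().
    hidden1 = jnp.dot(x, w1)
    hidden2 = jax.nn.relu(hidden1)
    output = jnp.dot(hidden2, w2)
    # Backward pass for loss = mean((output - y) ** 2).
    d_output = 2.0 * (output - y) / output.shape[0]
    d_hidden2 = jnp.dot(d_output, w2.T)                  # backprop through the second matmul
    d_hidden1 = jnp.where(hidden1 > 0, d_hidden2, 0.0)   # backprop through the ReLU
    return jnp.outer(x, d_hidden1)                       # gradient w.r.t. w1 is an outer product

Outside Pallas this should match jax.grad(mlp, argnums=1)(x, w1, w2, y), so dot products, a ReLU mask, and an outer product are essentially the only primitives the kernel body needs to lower.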
Thank you!
What jax/jaxlib version are you using?
jax 0.4.24.dev20240104, jaxlib 0.4.24.dev20240103
Which accelerator(s) are you using?
GPU
Additional system info?
numpy 1.26.3; Python 3.10.13 (main, Sep 11 2023, 13:44:35) [GCC 11.2.0]
NVIDIA GPU info