2025-02-02 nightly release (0b2c24f)
pytorchbot committed Feb 2, 2025
1 parent 7f517ec commit 61c52b6
Showing 9 changed files with 69 additions and 53 deletions.
5 changes: 4 additions & 1 deletion .gitmodules
@@ -10,9 +10,12 @@
[submodule "external/hipify_torch"]
path = external/hipify_torch
url = https://github.com/ROCmSoftwarePlatform/hipify_torch.git
+# TODO Using a private copy of cutlass is a temporary mitigation to enable grouped gemm.
+# Go back to main cutlass when possible.
[submodule "external/cutlass"]
path = external/cutlass
-url = https://github.com/NVIDIA/cutlass.git
+url = https://github.com/jwfromm/cutlass.git
+branch = FBGEMM
[submodule "external/json"]
path = external/json
url = https://github.com/nlohmann/json.git
2 changes: 1 addition & 1 deletion external/cutlass
5 changes: 0 additions & 5 deletions fbgemm_gpu/fbgemm_gpu/sll/__init__.py
@@ -39,7 +39,6 @@
    jagged2_softmax,
    jagged2_to_padded_dense,
    jagged_dense_bmm,
-    jagged_dense_elementwise_add,
    jagged_dense_elementwise_mul_jagged_out,
    jagged_flash_attention_basic,
    jagged_jagged_bmm,
@@ -316,10 +315,6 @@
"CUDA": jagged_flash_attention_basic,
"AutogradCUDA": jagged_flash_attention_basic,
},
"sll_jagged_dense_elementwise_add": {
"CUDA": jagged_dense_elementwise_add,
"AutogradCUDA": jagged_dense_elementwise_add,
},
}

for op_name, dispatches in sll_cpu_registrations.items():
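The entries removed above belong to a per-backend dispatch table: each `sll_*` op name maps dispatch keys such as "CUDA" and "AutogradCUDA" to a Python kernel, and a loop at the bottom of the file registers every pair. The sketch below only illustrates how such a table can be applied with `torch.library`; it is not FBGEMM's actual registration code, and the "fbgemm" namespace plus the assumption that the operator schemas are already defined elsewhere are mine.

import torch

# Sketch only: attach a {op_name: {dispatch_key: kernel}} table to an existing
# operator namespace. Assumes the schemas (e.g. "fbgemm::sll_*") were defined
# elsewhere; only the per-backend implementations are registered here.
_lib = torch.library.Library("fbgemm", "IMPL")

def register_ops(registrations: dict) -> None:
    for op_name, dispatches in registrations.items():
        for dispatch_key, kernel in dispatches.items():
            # dispatch_key is e.g. "CUDA" or "AutogradCUDA"
            _lib.impl(op_name, kernel, dispatch_key)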
12 changes: 10 additions & 2 deletions fbgemm_gpu/fbgemm_gpu/sll/triton/__init__.py
@@ -8,18 +8,26 @@
# pyre-strict


-from fbgemm_gpu.sll.triton.jagged_dense_flash_attention import (  # noqa F401
+from fbgemm_gpu.sll.triton.triton_jagged_dense_elementwise_add import (  # noqa F401
+    jagged_dense_elementwise_add,
+    JaggedDenseAdd,  # noqa F401
+)
+from fbgemm_gpu.sll.triton.triton_jagged_dense_flash_attention import (  # noqa F401
    jagged_dense_flash_attention,
    JaggedDenseFlashAttention,  # noqa F401
)

-from fbgemm_gpu.sll.triton.multi_head_jagged_flash_attention import (  # noqa F401
+from fbgemm_gpu.sll.triton.triton_multi_head_jagged_flash_attention import (  # noqa F401
    multi_head_jagged_flash_attention,
    MultiHeadJaggedFlashAttention,  # noqa F401
)

# pyre-ignore[5]
op_registrations = {
+    "sll_jagged_dense_elementwise_add": {
+        "CUDA": jagged_dense_elementwise_add,
+        "AutogradCUDA": jagged_dense_elementwise_add,
+    },
    "sll_jagged_dense_flash_attention": {
        "CUDA": jagged_dense_flash_attention,
        "AutogradCUDA": jagged_dense_flash_attention,
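With the re-export added above, `jagged_dense_elementwise_add` is importable from the `fbgemm_gpu.sll.triton` package itself rather than from `triton_sll`. A minimal import sketch, assuming an `fbgemm_gpu` build that includes this change:

# Both forms resolve to the new triton_jagged_dense_elementwise_add module.
from fbgemm_gpu.sll.triton import jagged_dense_elementwise_add
from fbgemm_gpu.sll.triton.triton_jagged_dense_elementwise_add import JaggedDenseAdd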
52 changes: 52 additions & 0 deletions fbgemm_gpu/fbgemm_gpu/sll/triton/triton_jagged_dense_elementwise_add.py
@@ -0,0 +1,52 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe

import torch

from fbgemm_gpu.triton.jagged.triton_jagged_tensor_ops import (
    dense_to_jagged,
    jagged_to_dense,
)


class JaggedDenseAdd(torch.autograd.Function):
    @staticmethod
    # pyre-fixme
    def forward(
        ctx, x: torch.Tensor, x_offsets: torch.Tensor, y: torch.Tensor, max_seq_len: int
    ):
        ctx.save_for_backward(x_offsets)
        ctx.max_seq_len = max_seq_len
        # TODO: what should be the correct behavior when jagged values has length > max seq len?
        # current behavior is to not truncate jagged values
        # similar for backward grad_output
        return dense_to_jagged(
            y, [x_offsets], operation_function="add", operation_jagged_values=x
        )[0]

    @staticmethod
    # pyre-fixme
    def backward(ctx, grad_output: torch.Tensor):
        (offsets,) = ctx.saved_tensors
        grad_dense = jagged_to_dense(grad_output, [offsets], [ctx.max_seq_len])
        return grad_output, None, grad_dense, None


def jagged_dense_elementwise_add(
    x: torch.Tensor,
    x_offsets: torch.Tensor,
    y: torch.Tensor,
    max_seq_len: int,
    use_fbgemm_kernel: bool = True,
):
    if use_fbgemm_kernel:
        return torch.ops.fbgemm.jagged_dense_elementwise_add_jagged_output(
            x, [x_offsets], y
        )[0]
    else:
        return JaggedDenseAdd.apply(x, x_offsets, y, max_seq_len)
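For context, `forward` adds the flattened jagged values `x` into the rows of the dense tensor `y` selected through `x_offsets` and returns a jagged result, while `backward` densifies the incoming jagged gradient for `y`. A small usage sketch follows; the shapes, the CUDA device, and the choice of the non-fbgemm path are illustrative assumptions, not part of this commit:

import torch
from fbgemm_gpu.sll.triton import jagged_dense_elementwise_add

# Two sequences of lengths 3 and 2 (5 jagged rows total), embedding dim 8.
x_offsets = torch.tensor([0, 3, 5], dtype=torch.int64, device="cuda")
x = torch.randn(5, 8, device="cuda", requires_grad=True)      # jagged values
y = torch.randn(2, 4, 8, device="cuda", requires_grad=True)   # dense, assumed [B, max_seq_len, D]

out = jagged_dense_elementwise_add(
    x, x_offsets, y, max_seq_len=4, use_fbgemm_kernel=False   # Triton/autograd path
)
out.sum().backward()  # gradients flow back to both x and y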
43 changes: 0 additions & 43 deletions fbgemm_gpu/fbgemm_gpu/sll/triton_sll.py
@@ -12,11 +12,6 @@
import triton
import triton.language as tl

-from fbgemm_gpu.triton.jagged.triton_jagged_tensor_ops import (
-    dense_to_jagged,
-    jagged_to_dense,
-)


def set_block_size(N: int) -> int:
    if N > 64:
@@ -2591,41 +2586,3 @@ def jagged_flash_attention_basic(
)

    return jagged_O


class JaggedDenseAdd(torch.autograd.Function):
    @staticmethod
    # pyre-fixme
    def forward(
        ctx, x: torch.Tensor, x_offsets: torch.Tensor, y: torch.Tensor, max_seq_len: int
    ):
        ctx.save_for_backward(x_offsets)
        ctx.max_seq_len = max_seq_len
        # TODO: what should be the correct behavior when jagged values has length > max seq len?
        # current behavior is to not truncate jagged values
        # similar for backward grad_output
        return dense_to_jagged(
            y, [x_offsets], operation_function="add", operation_jagged_values=x
        )[0]

    @staticmethod
    # pyre-fixme
    def backward(ctx, grad_output: torch.Tensor):
        (offsets,) = ctx.saved_tensors
        grad_dense = jagged_to_dense(grad_output, [offsets], [ctx.max_seq_len])
        return grad_output, None, grad_dense, None


def jagged_dense_elementwise_add(
    x: torch.Tensor,
    x_offsets: torch.Tensor,
    y: torch.Tensor,
    max_seq_len: int,
    use_fbgemm_kernel: bool = True,
):
    if use_fbgemm_kernel:
        return torch.ops.fbgemm.jagged_dense_elementwise_add_jagged_output(
            x, [x_offsets], y
        )[0]
    else:
        return JaggedDenseAdd.apply(x, x_offsets, y, max_seq_len)
3 changes: 2 additions & 1 deletion fbgemm_gpu/test/sll/jagged_dense_elementwise_add_test.py
@@ -9,9 +9,10 @@

import unittest

+import fbgemm_gpu.sll  # noqa F401
import hypothesis.strategies as st
import torch
-from fbgemm_gpu.sll.triton_sll import jagged_dense_elementwise_add  # noqa

from hypothesis import given, settings

from .common import open_source
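As a rough cross-check of the two code paths this test exercises, the fbgemm kernel can be compared against the Triton autograd path. The sketch below is illustrative only; the random shapes, default tolerances, CUDA device, and the helper name are assumptions, not the test's actual body:

import torch
from fbgemm_gpu.sll.triton import jagged_dense_elementwise_add

def compare_add_paths(B: int = 3, max_seq_len: int = 10, D: int = 16, device: str = "cuda") -> None:
    lengths = torch.randint(1, max_seq_len + 1, (B,), device=device)
    offsets = torch.cat(
        [torch.zeros(1, dtype=torch.int64, device=device), lengths.cumsum(0)]
    )
    x = torch.randn(int(lengths.sum()), D, device=device)   # jagged values
    y = torch.randn(B, max_seq_len, D, device=device)       # dense tensor
    ref = jagged_dense_elementwise_add(x, offsets, y, max_seq_len, use_fbgemm_kernel=True)
    out = jagged_dense_elementwise_add(x, offsets, y, max_seq_len, use_fbgemm_kernel=False)
    torch.testing.assert_close(ref, out)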
