Re-organize SLL ops, pt 2 (#3643)
Summary:
Pull Request resolved: #3643

X-link: facebookresearch/FBGEMM#719

- Re-organize `jagged_dense_flash_attention`

Reviewed By: sryap

Differential Revision: D68916405

fbshipit-source-id: 688f41ae4b96684697af8538611f9f6a800e7ff2
q10 authored and facebook-github-bot committed Feb 1, 2025
1 parent 477d4d8 commit b95fa22
Showing 4 changed files with 873 additions and 853 deletions.
fbgemm_gpu/fbgemm_gpu/sll/__init__.py (5 changes: 0 additions & 5 deletions)
@@ -41,7 +41,6 @@
     jagged_dense_bmm,
     jagged_dense_elementwise_add,
     jagged_dense_elementwise_mul_jagged_out,
-    jagged_dense_flash_attention,
     jagged_flash_attention_basic,
     jagged_jagged_bmm,
     jagged_jagged_bmm_jagged_out,
@@ -321,10 +320,6 @@
         "CUDA": jagged_dense_elementwise_add,
         "AutogradCUDA": jagged_dense_elementwise_add,
     },
-    "sll_jagged_dense_flash_attention": {
-        "CUDA": jagged_dense_flash_attention,
-        "AutogradCUDA": jagged_dense_flash_attention,
-    },
 }
 
 for op_name, dispatches in sll_cpu_registrations.items():
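For context, each registration dict above maps an operator name to its implementations keyed by dispatch key, and the loops at the bottom of `fbgemm_gpu/sll/__init__.py` attach those implementations through PyTorch's operator registration machinery. Below is a minimal sketch of one way such a dict can be consumed with `torch.library.Library`; the `lib` fragment and the `register_sll_ops` helper are illustrative, not FBGEMM's actual code:

```python
import torch

# Illustrative sketch (not FBGEMM's actual registration code): attach each
# dispatch-key -> function entry of a registration dict to an operator that
# is assumed to already be defined in the "fbgemm" namespace.
lib = torch.library.Library("fbgemm", "FRAGMENT")


def register_sll_ops(registrations: dict) -> None:
    for op_name, dispatches in registrations.items():
        for dispatch_key, fn in dispatches.items():
            # e.g. lib.impl("sll_jagged_dense_bmm", jagged_dense_bmm, "CUDA")
            lib.impl(op_name, fn, dispatch_key)
```

With this pattern, removing the `sll_jagged_dense_flash_attention` entry from the top-level dict simply means that entry must be registered from wherever it now lives (the Triton submodule's `op_registrations`, shown in the next file).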
fbgemm_gpu/fbgemm_gpu/sll/triton/__init__.py (12 changes: 11 additions & 1 deletion)
@@ -8,12 +8,22 @@
 # pyre-strict
 
+
+from fbgemm_gpu.sll.triton.jagged_dense_flash_attention import ( # noqa F401
+    jagged_dense_flash_attention,
+    JaggedDenseFlashAttention, # noqa F401
+)
+
 from fbgemm_gpu.sll.triton.multi_head_jagged_flash_attention import ( # noqa F401
     multi_head_jagged_flash_attention,
-    MultiHeadJaggedFlashAttention,
+    MultiHeadJaggedFlashAttention, # noqa F401
 )
 
 # pyre-ignore[5]
 op_registrations = {
+    "sll_jagged_dense_flash_attention": {
+        "CUDA": jagged_dense_flash_attention,
+        "AutogradCUDA": jagged_dense_flash_attention,
+    },
     "sll_multi_head_jagged_flash_attention": {
         "CUDA": multi_head_jagged_flash_attention,
         "AutogradCUDA": multi_head_jagged_flash_attention,
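After this move, the Triton submodule owns the `sll_jagged_dense_flash_attention` entry. This diff does not show how the top-level `fbgemm_gpu/sll/__init__.py` picks up the Triton `op_registrations`; the sketch below is one hypothetical way to fold it into the existing GPU registrations (the `sll_gpu_registrations` name is assumed, not confirmed by this diff):

```python
# Hypothetical sketch only; the actual wiring in fbgemm_gpu/sll/__init__.py
# is not part of this diff.
import torch

sll_gpu_registrations: dict = {}  # placeholder for the existing GPU entries

if torch.cuda.is_available():
    from fbgemm_gpu.sll.triton import op_registrations as sll_triton_registrations

    # Fold the Triton-backed entries, including
    # "sll_jagged_dense_flash_attention", into the GPU registration dict
    # before the registration loop attaches them to dispatch keys.
    sll_gpu_registrations.update(sll_triton_registrations)
```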
(Diffs for the remaining two changed files are not shown.)
