From 599a53dac894ea62ffa7a85256c5c5a5617d7bf2 Mon Sep 17 00:00:00 2001 From: Francisco Massa Date: Mon, 7 Feb 2022 07:55:50 -0800 Subject: [PATCH 1/3] Add nd support for Triton-based softmax --- xformers/triton/softmax.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/xformers/triton/softmax.py b/xformers/triton/softmax.py index e1ccd1dcfa..f6988f7ae3 100644 --- a/xformers/triton/softmax.py +++ b/xformers/triton/softmax.py @@ -40,6 +40,7 @@ def forward(ctx, x, mask, log_outputs, causal): # Handle 2D/3D tensors x_ = x.unsqueeze(0) if x.ndim == 2 else x + x_ = x_.flatten(0, -3) if not x_.is_contiguous(): x_ = x_.contiguous() @@ -92,6 +93,7 @@ def backward(ctx, grad_out): # Handle 2D/3D tensors grad_out_ = grad_out.unsqueeze(0) if grad_out.ndim == 2 else grad_out + grad_out_ = grad_out_.flatten(0, -3) # SPMD launch grid grid_2d = ( From 8b1f26c8fe1f91f95c231202dafde8b56a7712d5 Mon Sep 17 00:00:00 2001 From: Francisco Massa Date: Mon, 7 Feb 2022 08:45:24 -0800 Subject: [PATCH 2/3] Add tests --- tests/test_triton_softmax.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/test_triton_softmax.py b/tests/test_triton_softmax.py index c505c301a4..92788f1684 100644 --- a/tests/test_triton_softmax.py +++ b/tests/test_triton_softmax.py @@ -28,6 +28,8 @@ (1, 2048, 2048), (1, 3136, 3136), (1, 4096, 4096), + (2, 2, 384, 384), + (2, 2, 2, 384, 384), ] From a992c5be5678014f820af0ac3af8ea6709a0a69c Mon Sep 17 00:00:00 2001 From: Francisco Massa Date: Mon, 7 Feb 2022 09:28:57 -0800 Subject: [PATCH 3/3] Fix test --- tests/test_triton_softmax.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_triton_softmax.py b/tests/test_triton_softmax.py index 92788f1684..3c689f41bf 100644 --- a/tests/test_triton_softmax.py +++ b/tests/test_triton_softmax.py @@ -56,7 +56,7 @@ def test_softmax_parity(shape, amp, log, masking, causal, contiguous): X.requires_grad = True X_.requires_grad = True - seq = shape[1] + seq = shape[-1] mask = torch.zeros((seq, seq)).cuda() if masking: mask[torch.rand((seq, seq)) > 0.8] = -float("inf")