diff --git a/tests/test_memory_efficient_attention.py b/tests/test_memory_efficient_attention.py
new file mode 100644
index 0000000000..7adaa63692
--- /dev/null
+++ b/tests/test_memory_efficient_attention.py
@@ -0,0 +1,100 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
+#
+# This source code is licensed under the BSD license found in the
+# LICENSE file in the root directory of this source tree.
+
+import logging
+import math
+
+import pytest
+import torch
+
+try:
+    from xformers.triton import mem_efficient_attention
+    from xformers.triton.utils import gpu_capabilities_older_than_70
+
+    _triton_available = True
+except ImportError:
+    logging.warning("Triton is not available, some optimizations will not be tested.")
+    _triton_available = False
+
+
+# Testing odd shapes on purpose
+SHAPES = [
+    (384, 256),
+    (1, 384, 128),
+    (8, 384, 128),
+    (8, 784, 512),
+    (16, 1024, 1024),
+    # (2, 2048, 384),  # FIXME
+    # (4, 3136, 1024),
+]
+
+
+def attention_pytorch(q, k, v):
+    # attention matrix
+    q = q / math.sqrt(q.size(-1))
+    a = q @ k.transpose(-2, -1)
+
+    # softmax
+    a = torch.softmax(a, dim=-1)
+
+    # retrieval
+    return a @ v
+
+
+@pytest.mark.skipif(not _triton_available, reason="Triton is not available")
+@pytest.mark.skipif(
+    not _triton_available or gpu_capabilities_older_than_70(),
+    reason="Triton requires an SM70+ GPU",
+)
+@pytest.mark.parametrize("shape", SHAPES)
+@pytest.mark.parametrize("dtype", [torch.float32])
+def test_mem_efficient_attention_parity(shape, dtype):
+    q = torch.rand(shape, dtype=dtype, device=torch.device("cuda"))
+    k = torch.rand(shape, dtype=dtype, device=torch.device("cuda"))
+    v = torch.rand(shape, dtype=dtype, device=torch.device("cuda"))
+
+    res_pytorch = attention_pytorch(q, k, v)
+    res_me = mem_efficient_attention.apply(q, k, v, None)
+
+    assert torch.mean(torch.abs(res_pytorch - res_me)) < 0.2
+
+    # assert torch.allclose(res_pytorch, res_me, rtol=1e-1)  # FIXME
+    # TODO: test different sequence lengths for q and k
+    # TODO: check parity with normal attention
+
+
+@pytest.mark.skipif(not _triton_available, reason="Triton is not available")
+@pytest.mark.skipif(
+    not _triton_available or gpu_capabilities_older_than_70(),
+    reason="Triton requires an SM70+ GPU",
+)
+@pytest.mark.parametrize("shape", SHAPES)
+@pytest.mark.parametrize("dtype", [torch.float32])
+def test_mem_efficient_attention_memory_use(shape, dtype):
+    # Run the forward pass on a batch of random data
+    q = torch.rand(shape, dtype=dtype, device=torch.device("cuda"))
+    k = torch.rand(shape, dtype=dtype, device=torch.device("cuda"))
+    v = torch.rand(shape, dtype=dtype, device=torch.device("cuda"))
+
+    # Vanilla attention
+    torch.cuda.empty_cache()
+    torch.cuda.reset_peak_memory_stats()
+
+    _ = attention_pytorch(q, k, v)
+    torch.cuda.synchronize()
+    max_memory_torch = torch.cuda.max_memory_allocated() // 2 ** 20
+    print(f"Dense - Peak memory use: {max_memory_torch}MB")
+
+    # Memory efficient attention
+    torch.cuda.empty_cache()
+    torch.cuda.reset_peak_memory_stats()
+
+    _ = mem_efficient_attention.apply(q, k, v, None)
+    torch.cuda.synchronize()
+
+    max_memory_me = torch.cuda.max_memory_allocated() // 2 ** 20
+    print(f"Memory efficient - Peak memory use: {max_memory_me}MB")
+
+    assert max_memory_me <= max_memory_torch
diff --git a/xformers/benchmarks/benchmark_triton_mem_efficient_attention.py b/xformers/benchmarks/benchmark_triton_mem_efficient_attention.py
new file mode 100644
index 0000000000..6a63db2c39
--- /dev/null
+++ b/xformers/benchmarks/benchmark_triton_mem_efficient_attention.py
@@ -0,0 +1,113 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
+#
+# This source code is licensed under the BSD license found in the
+# LICENSE file in the root directory of this source tree.
+
+
+import math
+from typing import Any, Dict
+
+import torch
+import triton
+
+from xformers.benchmarks.utils import TestCase, pretty_plot, pretty_print
+from xformers.triton import mem_efficient_attention
+
+SHAPES = [
+    (8, 256, 512),
+    (8, 512, 1024),
+    (4, 1024, 1024),
+    # (2, 2048, 2048),
+    # (2, 4096, 4096),
+    # (1, 2048, 12288),
+]
+
+
+def attention_pytorch(q, k, v):
+    # attention matrix
+    q = q / math.sqrt(q.size(-1))
+    a = q @ k.transpose(-2, -1)
+
+    # softmax
+    a = torch.softmax(a, dim=-1)
+
+    # retrieval
+    return a @ v
+
+
+def to_flops_fw(
+    q: torch.Tensor,
+    k: torch.Tensor,
+    v: torch.Tensor,
+    ms,
+):
+    # q @ kt
+    flop = q.shape[0] * q.shape[1] * k.shape[1] * (2 * q.shape[2] - 1)
+
+    # normalization
+    att_shape = q.shape[1] * k.shape[1]
+    flop += 5 * att_shape  # max + subtraction + exp + sum + division
+
+    # exp(q @ kt) @ v
+    flop += v.shape[0] * att_shape * v.shape[2] * 2
+
+    return flop * 1e-12 / (ms * 1e-3)
+
+
+def bench_mem_efficient_attention(backward: bool):
+    device = torch.device("cuda")
+
+    for dtype in [
+        # torch.float16,
+        torch.float32,
+    ]:
+        results: Dict[str, Any] = {}
+
+        for B, M, K in SHAPES:
+            k = torch.rand(B, M, K, device=device, dtype=dtype, requires_grad=backward)
+            q = torch.rand(B, M, K, device=device, dtype=dtype, requires_grad=backward)
+            v = torch.rand(B, M, K, device=device, dtype=dtype, requires_grad=backward)
+
+            def torch_step(x, y, z):
+                a = attention_pytorch(x, y, z)
+                if backward:
+                    torch.norm(a).backward()
+                return a
+
+            def triton_step(x, y, z):
+                a = mem_efficient_attention.apply(x, y, z, None)
+                if backward:
+                    torch.norm(a).backward()
+                return a
+
+            for testcase in [
+                TestCase(
+                    torch_step,
+                    "pytorch - fw{}".format("+bw" if backward else ""),
+                ),
+                TestCase(
+                    triton_step,
+                    "triton - fw{}".format("+bw" if backward else ""),
+                ),
+            ]:
+                time = triton.testing.do_bench(lambda: testcase.function(q, k, v))[0]
+                key = f"B={B}, M={M}, K={K}"
+                if key not in results:
+                    results[key] = {}
+
+                # Record the throughput
+                throughput = to_flops_fw(q, k, v, time)
+                results[key][testcase.name] = f"{throughput:.1f}"
+
+        units = "TFlops/s"
+        pretty_print(results, title="\n --- Type: {} --- ".format(dtype), units=units)
+        pretty_plot(
+            results,
+            title="MemEfficientAttention-FW{}-{}".format("+BW" if backward else "", dtype),
+            units=units,
+            dash_key="pytorch",
+        )
+
+
+for bw in [False]:  # FIXME: needs BW eventually
+    bench_mem_efficient_attention(bw)
diff --git a/xformers/triton/__init__.py b/xformers/triton/__init__.py
index 4442204532..621f776dd5 100644
--- a/xformers/triton/__init__.py
+++ b/xformers/triton/__init__.py
@@ -12,6 +12,7 @@
     from .dropout import FusedDropoutBias, dropout  # noqa
     from .fused_linear_layer import FusedLinear  # noqa
     from .layer_norm import FusedLayerNorm, layer_norm  # noqa
+    from .mem_efficient_attention import mem_efficient_attention  # noqa
     from .softmax import log_softmax, softmax  # noqa
 
     __all__ = [
@@ -22,6 +23,7 @@
         "FusedLinear",
         "FusedLayerNorm",
         "layer_norm",
+        "mem_efficient_attention",
     ]
 except ImportError:
     __all__ = []
diff --git a/xformers/triton/k_mem_efficient_attention.py b/xformers/triton/k_mem_efficient_attention.py
new file mode 100644
index 0000000000..f4a54b02a3
--- /dev/null
+++ b/xformers/triton/k_mem_efficient_attention.py
@@ -0,0 +1,206 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
+#
+# This source code is licensed under the BSD license found in the
+# LICENSE file in the root directory of this source tree.
+
+import math
+from typing import Optional
+
+import torch
+import triton
+import triton.language as tl
+
+_DEBUG = 0  # 1 to see the kernel PTX assembly
+_FUSED_NORMALIZATION = True  # FIXME: rounding error, but should work eventually
+
+
+# fmt: off
+@triton.jit
+def k_me_attention_fw(
+    OUT, MAXES, WEIGHTS,  # out ptrs
+    Q, K, V,              # in ptrs
+    M, N, L,              # dims
+    stride_out_tile, stride_out_m,
+    stride_maxes, stride_weights,
+    **META,
+):
+    # fmt: on
+
+    # extract metaparameters
+    BLOCK_M = META["BLOCK_M"]
+    BLOCK_N, BLOCK_L = META["BLOCK_N"], META["BLOCK_L"]
+    FUSED_NORMALIZATION = META["FUSED_NORMALIZATION"]
+
+    scale = META["SCALE"]
+
+    # *within groups*, programs are ordered in a column-major order
+    # row-id / col-id of the program in the *launch grid*
+    pid_m = tl.program_id(axis=0)
+    pid_n = tl.program_id(axis=1)
+
+    # Compute QKt
+    # Block-level matrix multiplication:
+    # we fetch a block of memory from both inputs, matmul and accumulate, then repeat
+    qkt = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
+
+    i = 0
+    rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
+    rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
+    rl = tl.arange(0, BLOCK_L)
+
+    for _ in range(L, 0, -BLOCK_L):
+        rl_i = rl + i * BLOCK_L  # keep track of the masking
+        q_ptrs = Q + rm[:, None] * L + rl_i[None, :]  # (BLOCK_M, BLOCK_L)
+        k_ptrs = K + rn[None, :] * L + rl_i[:, None]  # (BLOCK_L, BLOCK_N)
+
+        q = tl.load(q_ptrs, mask=((rm[:, None] < M) & (rl_i[None, :] < L)), other=0.0)  # (BLOCK_M, BLOCK_L)
+        k = tl.load(k_ptrs, mask=((rl_i[:, None] < L) & (rn[None, :] < N)), other=0.0)  # (BLOCK_L, BLOCK_N)
+
+        q *= scale  # q /= sqrt(dim)
+        qkt += tl.dot(q, k).to(tl.float32)  # (BLOCK_M, BLOCK_N)
+
+        # Update the pointers and counter
+        i += 1
+
+    # Pick the local max per row, to keep the upcoming exponential numerically safe
+    max_qkt = tl.max(qkt, axis=1)  # (BLOCK_M)
+    max_ptrs = MAXES + pid_n * stride_maxes + rm  # (BLOCK_M)
+
+    # Save so that an eventual mismatch can be fixed post-hoc
+    if FUSED_NORMALIZATION is False:
+        tl.store(max_ptrs, max_qkt, mask=(rm < M))
+
+    # Exponentiate the max-shifted results
+    exp_qkt = tl.exp(qkt - max_qkt[:, None])  # (BLOCK_M, BLOCK_N)
+
+    # Softmax normalization constant
+    weights = tl.sum(exp_qkt, axis=1)  # (BLOCK_M)
+
+    if FUSED_NORMALIZATION:
+        exp_qkt = exp_qkt / weights[:, None]
+    else:
+        # Save, the global max will be fixed post-hoc
+        weights_ptrs = WEIGHTS + pid_n * stride_weights + rm
+        tl.store(weights_ptrs, weights, mask=(rm < M))
+
+    # Now pre-compute exp_qkt against V.
+    # We proceed per chunk over L, and save as we go
+    i = 0
+    rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
+    rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
+    rl = tl.arange(0, BLOCK_L)
+
+    v_ptrs = V + rn[:, None] * L + rl[None, :]  # (BLOCK_N, BLOCK_L)
+    out_ptrs = (
+        OUT + pid_n * stride_out_tile + rm[:, None] * stride_out_m + rl[None, :]
+    )  # (BLOCK_M, BLOCK_L)
+
+    for _ in range(L, 0, -BLOCK_L):
+        rl_i = rl + i * BLOCK_L  # Useful to keep track of the masking
+
+        v = tl.load(v_ptrs, mask=((rn[:, None] < N) & (rl_i[None, :] < L)), other=0.0)
+        qkv = tl.dot(exp_qkt, v).to(tl.float32)  # (BLOCK_M, BLOCK_L)
+
+        tl.store(out_ptrs, qkv, mask=(rm[:, None] < M) & (rl_i[None, :] < L))
+
+        i += 1
+        v_ptrs += BLOCK_L
+        out_ptrs += BLOCK_L
+
+
+def mem_efficient_fw(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, bias: Optional[torch.Tensor]) -> torch.Tensor:
+
+    assert q.shape[-1] == k.shape[-1]
+    assert v.shape[-1] == k.shape[-1]
+    assert k.shape[-2] == v.shape[-2]
+    assert q.is_contiguous() and k.is_contiguous() and v.is_contiguous()
+
+    q_shape = q.shape
+
+    if q.ndim == 2:
+        # no batch dimension
+        q_, k_, v_ = map(lambda x: x.unsqueeze(0), [q, k, v])
+    else:
+        q_, k_, v_ = q, k, v
+
+    B, M, L = q_.shape
+    B, N, L = k_.shape
+
+    BLOCK_M = 8
+    BLOCK_N = min(triton.next_power_of_2(N), 1024)  # a larger ceiling means fewer tiles over N, hence smaller partial-results buffers
+    BLOCK_L = 8
+
+    tiles_n = triton.cdiv(N, BLOCK_N)
+
+    def grid(META):
+        return (
+            triton.cdiv(M, META["BLOCK_M"]),
+            tiles_n
+        )
+
+    out_n = torch.empty((tiles_n, M, L), dtype=q.dtype, device=q.device)
+
+    if not _FUSED_NORMALIZATION:
+        maxes_n = torch.empty((tiles_n, M), dtype=q.dtype, device=q.device)
+        weights_n = torch.empty((tiles_n, M), dtype=q.dtype, device=q.device)
+    else:
+        assert BLOCK_N >= N, "N spans several tiles (BLOCK_N < N), the fused normalization cannot be used"
+        maxes_n = out_n  # placeholder, will not be used
+        weights_n = out_n  # placeholder, will not be used
+
+    # FIXME: handle bias
+    # FIXME: improve on the batch dimension handling?
+    qkvs = []
+    for i_b in range(B):
+
+        # Use a dedicated kernel to process the attention by blocks
+        # fmt: off
+        bin = k_me_attention_fw[grid](
+            out_n, maxes_n, weights_n,  # outputs
+            q_[i_b], k_[i_b], v_[i_b],  # inputs
+            M, N, L,                    # dimensions
+            out_n.stride(0), out_n.stride(1), maxes_n.stride(0), weights_n.stride(0),
+            BLOCK_M=BLOCK_M,
+            BLOCK_N=BLOCK_N,
+            BLOCK_L=BLOCK_L,
+            BIAS=False,
+            SCALE=1. / math.sqrt(L),
+            FUSED_NORMALIZATION=_FUSED_NORMALIZATION,
+            num_warps=1
+        )
+        # fmt: on
+
+        if _DEBUG:
+            print(bin.asm['ptx'])
+
+        # Epilogue
+        if tiles_n > 1:
+            # There were tiles over the N dimension,
+            # so the weights were not correct in real time.
+
+            # Let's fix that:
+            # - collect the real overall max per line
+            global_max, _ = maxes_n.max(dim=0)
+
+            # - compute the mistake that was made in real time
+            mismatch = torch.exp(maxes_n - global_max[None, :])
+
+            # - update the computations to take the consolidated max/weights into account
+            out_n *= mismatch.unsqueeze(-1)
+            weights_n *= mismatch
+
+            out = torch.sum(out_n, dim=0)
+            weights = torch.sum(weights_n, dim=0)
+
+            qkv = out / weights.unsqueeze(-1)
+
+        else:
+            # with fused normalization this should just work
+            if _FUSED_NORMALIZATION:
+                qkv = out_n.squeeze()
+            else:
+                qkv = out_n / weights_n.unsqueeze(-1)
+
+        qkvs.append(qkv)
+
+    return torch.cat(qkvs, dim=0).reshape(q_shape)
diff --git a/xformers/triton/mem_efficient_attention.py b/xformers/triton/mem_efficient_attention.py
new file mode 100644
index 0000000000..31a36d608e
--- /dev/null
+++ b/xformers/triton/mem_efficient_attention.py
@@ -0,0 +1,31 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
+#
+# This source code is licensed under the BSD license found in the
+# LICENSE file in the root directory of this source tree.
+
+
+import torch
+from torch.cuda.amp import custom_bwd, custom_fwd
+
+from xformers.triton.k_mem_efficient_attention import mem_efficient_fw
+
+
+class mem_efficient_attention(torch.autograd.Function):
+    """
+    Implementing memory efficient attention, from
+    "Self-attention Does Not Need O(n²) Memory", Rabe et al.
+
+    https://arxiv.org/abs/2112.05682v2
+    """
+
+    @staticmethod
+    @custom_fwd(cast_inputs=torch.float16)
+    def forward(ctx, q, k, v, bias):
+        res = mem_efficient_fw(q, k, v, bias)
+
+        return res
+
+    @staticmethod
+    @custom_bwd
+    def backward(ctx, grad_out):
+        return None, None, None, None
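
For reference, below is a minimal pure-PyTorch sketch (not part of the diff above) of the tiled-softmax bookkeeping that the kernel and its epilogue rely on: each tile over the key dimension keeps its own running max and normalizer, and the partial results are re-scaled once the global max is known, mirroring the `tiles_n > 1` branch. The `chunked_attention` helper and its chunk size are hypothetical, written only for illustration.

import math

import torch


def chunked_attention(q, k, v, chunk=256):
    # q: (M, L), k and v: (N, L); keys/values are processed in chunks over N,
    # so the full (M, N) attention matrix is never materialized at once
    q = q / math.sqrt(q.size(-1))
    outs, maxes, weights = [], [], []

    for k_c, v_c in zip(k.split(chunk, dim=0), v.split(chunk, dim=0)):
        a = q @ k_c.transpose(-2, -1)   # (M, chunk) tile of the attention matrix
        m = a.max(dim=-1).values        # per-row max of this tile
        e = torch.exp(a - m[:, None])   # max-shifted exponentials
        outs.append(e @ v_c)            # un-normalized partial output
        maxes.append(m)
        weights.append(e.sum(dim=-1))   # partial softmax normalizer

    # Epilogue: reconcile the per-tile maxes and re-scale the partial results
    maxes = torch.stack(maxes)  # (tiles, M)
    global_max = maxes.max(dim=0).values
    mismatch = torch.exp(maxes - global_max[None, :])
    out = (torch.stack(outs) * mismatch[:, :, None]).sum(dim=0)
    weight = (torch.stack(weights) * mismatch).sum(dim=0)
    return out / weight[:, None]


# Sanity check against the dense reference
q, k, v = (torch.rand(384, 64) for _ in range(3))
ref = torch.softmax((q / math.sqrt(q.size(-1))) @ k.transpose(-2, -1), dim=-1) @ v
assert torch.allclose(chunked_attention(q, k, v), ref, atol=1e-4)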