[DRAFT][Blocked] Mem efficient attention - FW pass #162

Closed. Wants to merge 6 commits.
100 changes: 100 additions & 0 deletions tests/test_memory_efficient_attention.py
@@ -0,0 +1,100 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

import logging
import math

import pytest
import torch

try:
    from xformers.triton import mem_efficient_attention
    from xformers.triton.utils import gpu_capabilities_older_than_70

    _triton_available = True
except ImportError:
    logging.warning("Triton is not available, some optimizations will not be tested.")
    _triton_available = False


# Testing odd shapes on purpose
SHAPES = [
    (384, 256),
    (1, 384, 128),
    (8, 384, 128),
    (8, 784, 512),
    (16, 1024, 1024),
    # (2, 2048, 384),  # FIXME
    # (4, 3136, 1024),
]


def attention_pytorch(q, k, v):
    # attention matrix
    q = q / math.sqrt(q.size(-1))
    a = q @ k.transpose(-2, -1)

    # softmax
    a = torch.softmax(a, dim=-1)

    # retrieval
    return a @ v


@pytest.mark.skipif(not _triton_available, reason="Triton is not available")
@pytest.mark.skipif(
    not _triton_available or gpu_capabilities_older_than_70(),
    reason="Triton requires a SM70+ GPU",
)
@pytest.mark.parametrize("shape", SHAPES)
@pytest.mark.parametrize("dtype", [torch.float32])
def test_mem_efficient_attention_parity(shape, dtype):
    q = torch.rand(shape, dtype=dtype, device=torch.device("cuda"))
    k = torch.rand(shape, dtype=dtype, device=torch.device("cuda"))
    v = torch.rand(shape, dtype=dtype, device=torch.device("cuda"))

    res_pytorch = attention_pytorch(q, k, v)
    res_me = mem_efficient_attention.apply(q, k, v, None)

    assert torch.mean(torch.abs(res_pytorch - res_me)) < 0.2

    # assert torch.allclose(res_pytorch, res_me, rtol=1e-1)  # FIXME
    # TODO: test different sequence lengths for q and k
    # TODO: check parity with normal attention


@pytest.mark.skipif(not _triton_available, reason="Triton is not available")
@pytest.mark.skipif(
    not _triton_available or gpu_capabilities_older_than_70(),
    reason="Triton requires a SM70+ GPU",
)
@pytest.mark.parametrize("shape", SHAPES)
@pytest.mark.parametrize("dtype", [torch.float32])
def test_mem_efficient_attention_memory_use(shape, dtype):
    # FW a random bunch of data
    q = torch.rand(shape, dtype=dtype, device=torch.device("cuda"))
    k = torch.rand(shape, dtype=dtype, device=torch.device("cuda"))
    v = torch.rand(shape, dtype=dtype, device=torch.device("cuda"))

    # Vanilla attention
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()

    _ = attention_pytorch(q, k, v)
    torch.cuda.synchronize()
    max_memory_torch = torch.cuda.max_memory_allocated() // 2 ** 20
    print(f"Dense - Peak memory use: {max_memory_torch}MB")

    # Mem efficient attention
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()

    _ = mem_efficient_attention.apply(q, k, v, None)
    torch.cuda.synchronize()

    max_memory_me = torch.cuda.max_memory_allocated() // 2 ** 20
    print(f"Memory efficient - Peak memory use: {max_memory_me}MB")

    assert max_memory_me <= max_memory_torch
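
The memory-use test above follows the standard PyTorch pattern for measuring peak device memory: empty the CUDA caching allocator, reset the peak-allocation counter, run the forward pass, synchronize, then read torch.cuda.max_memory_allocated(). A minimal sketch of that pattern factored into a reusable helper (illustrative only; measure_peak_mb is a name introduced here, not part of this PR):

import torch


def measure_peak_mb(fn, *args) -> int:
    # Run fn(*args) on the current CUDA device and return its peak memory use in MB
    torch.cuda.empty_cache()              # drop cached blocks so the peak reflects this call only
    torch.cuda.reset_peak_memory_stats()  # zero the running peak-allocation counter
    _ = fn(*args)
    torch.cuda.synchronize()              # make sure the queued kernels have actually executed
    return torch.cuda.max_memory_allocated() // 2 ** 20


# Usage mirroring the assertion in the test above:
# assert measure_peak_mb(mem_efficient_attention.apply, q, k, v, None) <= measure_peak_mb(attention_pytorch, q, k, v)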
113 changes: 113 additions & 0 deletions xformers/benchmarks/benchmark_triton_mem_efficient_attention.py
@@ -0,0 +1,113 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.


import math
from typing import Any, Dict

import torch
import triton

from xformers.benchmarks.utils import TestCase, pretty_plot, pretty_print
from xformers.triton import mem_efficient_attention

SHAPES = [
    (8, 256, 512),
    (8, 512, 1024),
    (4, 1024, 1024),
    # (2, 2048, 2048),
    # (2, 4096, 4096),
    # (1, 2048, 12288),
]


def attention_pytorch(q, k, v):
    # attention matrix
    q = q / math.sqrt(q.size(-1))
    a = q @ k.transpose(-2, -1)

    # softmax
    a = torch.softmax(a, dim=-1)

    # retrieval
    return a @ v


def to_flops_fw(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    ms,
):
    # q @ kt: (B, M, K) @ (B, K, M)
    flop = q.shape[0] * q.shape[1] * k.shape[1] * (2 * q.shape[2] - 1)

    # softmax normalization over the (M, M) attention matrix
    att_shape = q.shape[1] * k.shape[1]
    flop += 5 * att_shape  # max + subtraction + exp + sum + division

    # exp(q @ kt) @ v: (B, M, M) @ (B, M, K)
    flop += v.shape[0] * att_shape * v.shape[2] * 2

    return flop * 1e-12 / (ms * 1e-3)


def bench_mem_efficient_attention(backward: bool):
    device = torch.device("cuda")

    for dtype in [
        # torch.float16,
        torch.float32,
    ]:
        results: Dict[str, Any] = {}

        for B, M, K in SHAPES:
            k = torch.rand(B, M, K, device=device, dtype=dtype, requires_grad=backward)
            q = torch.rand(B, M, K, device=device, dtype=dtype, requires_grad=backward)
            v = torch.rand(B, M, K, device=device, dtype=dtype, requires_grad=backward)

            def torch_step(x, y, z):
                a = attention_pytorch(x, y, z)
                if backward:
                    torch.norm(a).backward()
                return a

            def triton_step(x, y, z):
                a = mem_efficient_attention.apply(x, y, z, None)
                if backward:
                    torch.norm(a).backward()
                return a

            for testcase in [
                TestCase(
                    torch_step,
                    "pytorch - fw{}".format("+bw" if backward else ""),
                ),
                TestCase(
                    triton_step,
                    "triton - fw{}".format("+bw" if backward else ""),
                ),
            ]:
                time = triton.testing.do_bench(lambda: testcase.function(q, k, v))[0]
                key = f"B={B}, M={M}, K={K}"
                if key not in results:
                    results[key] = {}

                # Record the achieved throughput
                throughput = to_flops_fw(q, k, v, time)
                results[key][testcase.name] = f"{throughput:.1f}"

        units = "TFlops/s"
        pretty_print(results, title="\n --- Type: {} --- ".format(dtype), units=units)
        pretty_plot(
            results,
            title="MemEfficientAttention-FW{}-{}".format("+BW" if backward else "", dtype),
            units=units,
            dash_key="pytorch",
        )


for bw in [False]:  # FIXME: needs BW eventually
    bench_mem_efficient_attention(bw)
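
For a sense of scale, the numbers below simply re-apply the arithmetic of to_flops_fw to the smallest benchmarked shape; they are an estimate derived from the formula above, not measured results:

# FLOP estimate for B=8, M=256, K=512, following to_flops_fw
B, M, K = 8, 256, 512

flop = B * M * M * (2 * K - 1)   # q @ kt           -> 536_346_624
flop += 5 * (M * M)              # softmax terms    ->     327_680
flop += B * (M * M) * K * 2      # exp(q @ kt) @ v  -> 536_870_912

print(flop)  # 1_073_545_216, i.e. roughly 1.07 GFLOP per forward pass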
2 changes: 2 additions & 0 deletions xformers/triton/__init__.py
@@ -12,6 +12,7 @@
    from .dropout import FusedDropoutBias, dropout  # noqa
    from .fused_linear_layer import FusedLinear  # noqa
    from .layer_norm import FusedLayerNorm, layer_norm  # noqa
    from .mem_efficient_attention import mem_efficient_attention  # noqa
    from .softmax import log_softmax, softmax  # noqa

    __all__ = [
@@ -22,6 +23,7 @@
"FusedLinear",
"FusedLayerNorm",
"layer_norm",
"mem_efficient_attention",
]
except ImportError:
__all__ = []
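
With the export above, the new operator is importable directly from xformers.triton. A minimal usage sketch based on the calls in the tests of this PR; the operator is invoked through .apply, which suggests a torch.autograd.Function, and the fourth argument is passed as None throughout the diff (its role is not documented here):

import torch

from xformers.triton import mem_efficient_attention

# Same layout as in the tests: (batch, sequence, embedding), float32, on CUDA
q = torch.rand(8, 384, 128, device="cuda")
k = torch.rand(8, 384, 128, device="cuda")
v = torch.rand(8, 384, 128, device="cuda")

# Mirrors the tests: the last argument is left as None
out = mem_efficient_attention.apply(q, k, v, None)
print(out.shape)  # torch.Size([8, 384, 128])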