mem-efficient impl for f16 (#352)
Co-authored-by: danthe3rd <danthe3rd>
danthe3rd authored and fmassa committed Aug 10, 2022
1 parent 93a75b7 commit 256f2d4
Showing 5 changed files with 257 additions and 126 deletions.
60 changes: 41 additions & 19 deletions tests/test_mem_eff_attention.py
@@ -11,15 +11,20 @@

import xformers.ops

torch.backends.cuda.matmul.allow_tf32 = False
cuda_only = pytest.mark.skipif(not torch.cuda.is_available(), reason="requires CUDA")
_devices = ["cpu", "cuda"] if torch.cuda.is_available() else ["cpu"]


def ref_attention(q, k, v, attn_bias=None, drop_mask=None, p=0.0):
q = q.float()
k = k.float()
v = v.float()

q = q * (1 / q.shape[-1] ** 0.5)
attn = q @ k.transpose(-2, -1)
if attn_bias is not None:
attn = attn + attn_bias
attn = attn + attn_bias.float()
attn = attn.softmax(-1)
if drop_mask is not None:
attn = attn * (drop_mask / (1 - p))
@@ -32,6 +37,7 @@ def ref_attention(q, k, v, attn_bias=None, drop_mask=None, p=0.0):
@pytest.mark.parametrize("kv_len", [3, 15, 32, 33, 64, 128])
@pytest.mark.parametrize("q_len", [2, 3, 5, 32, 128])
@pytest.mark.parametrize("device", _devices)
@pytest.mark.parametrize("dtype", [torch.float, torch.half])
@pytest.mark.parametrize(
"op",
[
@@ -46,25 +52,31 @@ def test_memory_efficient_attention(
batch_size,
k_len,
use_attn_bias,
dtype,
op: xformers.ops.MemoryEfficientAttentionOp,
):
if (
device not in op.SUPPORTED_DEVICES
or k_len > op.SUPPORTED_MAX_K
or (use_attn_bias and not op.SUPPORTS_ATTN_BIAS)
or dtype not in op.SUPPORTED_DTYPES
):
return # Or `pytest.xfail` ?

scale = 3
query = torch.randn((batch_size, q_len, k_len), device=device) * scale
key = torch.randn((batch_size, kv_len, k_len), device=device) * scale
value = torch.randn((batch_size, kv_len, k_len), device=device) * scale
query = torch.randn((batch_size, q_len, k_len), device=device, dtype=dtype) * scale
key = torch.randn((batch_size, kv_len, k_len), device=device, dtype=dtype) * scale
value = torch.randn((batch_size, kv_len, k_len), device=device, dtype=dtype) * scale
attn_bias = None
if use_attn_bias:
attn_bias = torch.randn((batch_size, 1, kv_len), device=device) * scale
attn_bias = (
torch.randn((batch_size, 1, kv_len), device=device, dtype=dtype) * scale
)
attn_bias = attn_bias.expand(batch_size, q_len, kv_len)

out = xformers.ops.memory_efficient_attention(query, key, value, attn_bias, op=op)
out = xformers.ops.memory_efficient_attention(
query, key, value, attn_bias, op=op
).float()
ref = ref_attention(query, key, value, attn_bias)

assert torch.allclose(out, ref, atol=2e-4)
@@ -93,6 +105,7 @@ def test_key_query_all_ones(device, q_len, kv_len, batch_size, k_len):
@pytest.mark.parametrize("kv_len", [3, 15, 32, 33])
@pytest.mark.parametrize("q_len", [2, 3, 5])
@pytest.mark.parametrize("device", _devices)
@pytest.mark.parametrize("dtype", [torch.float, torch.half])
@pytest.mark.parametrize(
"op",
[
@@ -106,18 +119,25 @@ def test_logsumexp(
kv_len,
batch_size,
k_len,
dtype,
op: xformers.ops.MemoryEfficientAttentionOp,
):
if device not in op.SUPPORTED_DEVICES or k_len > op.SUPPORTED_MAX_K:
if (
device not in op.SUPPORTED_DEVICES
or k_len > op.SUPPORTED_MAX_K
or dtype not in op.SUPPORTED_DTYPES
):
return

scale = 3
query = torch.randn((batch_size, q_len, k_len), device=device) * scale
key = torch.randn((batch_size, kv_len, k_len), device=device) * scale
value = torch.randn((batch_size, kv_len, k_len), device=device) * scale
query = torch.randn((batch_size, q_len, k_len), device=device, dtype=dtype) * scale
key = torch.randn((batch_size, kv_len, k_len), device=device, dtype=dtype) * scale
value = torch.randn((batch_size, kv_len, k_len), device=device, dtype=dtype) * scale

_, lse, _, _ = op.FORWARD_OPERATOR(query, key, value, True, None, 0.0)
ref_lse = ((query / k_len**0.5) @ key.transpose(-2, -1)).logsumexp(-1)
ref_lse = (
(query.float() / k_len**0.5) @ key.float().transpose(-2, -1)
).logsumexp(-1)

assert torch.allclose(lse, ref_lse, atol=2e-4)

@@ -129,6 +149,7 @@ def test_logsumexp(
@pytest.mark.parametrize("kv_len", [3, 15, 32, 33, 64, 128])
@pytest.mark.parametrize("q_len", [2, 3, 5, 32, 128])
@pytest.mark.parametrize("device", _devices)
@pytest.mark.parametrize("dtype", [torch.float])
@pytest.mark.parametrize(
"op",
[
@@ -144,23 +165,27 @@ def test_memory_efficient_attention_backward(
k_len,
grad_out_contiguous,
use_attn_bias,
dtype,
op: xformers.ops.MemoryEfficientAttentionOp,
):
if (
device not in op.SUPPORTED_DEVICES
or k_len > op.SUPPORTED_MAX_K
or (use_attn_bias and not op.SUPPORTS_ATTN_BIAS)
or dtype not in op.SUPPORTED_DTYPES
):
return

scale = 3
query = torch.randn((batch_size, q_len, k_len), device=device) * scale
key = torch.randn((batch_size, kv_len, k_len), device=device) * scale
value = torch.randn((batch_size, kv_len, k_len), device=device) * scale
query = torch.randn((batch_size, q_len, k_len), device=device, dtype=dtype) * scale
key = torch.randn((batch_size, kv_len, k_len), device=device, dtype=dtype) * scale
value = torch.randn((batch_size, kv_len, k_len), device=device, dtype=dtype) * scale

attn_bias = None
if use_attn_bias:
attn_bias = torch.randn((batch_size, 1, kv_len), device=device) * scale
attn_bias = (
torch.randn((batch_size, 1, kv_len), device=device, dtype=dtype) * scale
)
attn_bias = attn_bias.expand(batch_size, q_len, kv_len)

query.requires_grad_(True)
@@ -185,9 +210,6 @@ def test_memory_efficient_attention_backward(
ref = ref_attention(query, key, value, attn_bias)
ref.backward(grad_out)

# there is some extra precision loss in the CPU implementation due to an
# extra accumulation step in grad_q, which is not present in the CUDA
# implementation
atol = 2e-4 + 2e-6 * k_len * kv_len * math.sqrt(batch_size) * math.sqrt(q_len)

# (for mypy)
@@ -202,7 +224,7 @@ def test_memory_efficient_attention_backward(
]:
assert torch.allclose(
calc_grad, ref_grad, atol=atol
), f"{name} doesn't match (max_diff={(calc_grad - ref_grad).abs().max()} > {atol})"
), f"{name} doesn't match (max_diff={(calc_grad - ref_grad).abs().max()} > {atol}) - dtype={dtype}"


def _vec_binom_test(x, n, p):
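Note: a minimal sketch of the pattern the updated tests follow (run the kernel on fp16 inputs, compare against a float32 reference). The shapes are arbitrary, ref_attention_fp32 is an ad-hoc helper, and omitting the op= argument (relying on default dispatch) is an assumption; only the atol=2e-4 tolerance comes from the test above.

import torch
import xformers.ops

def ref_attention_fp32(q, k, v):
    # Mirror the test reference: upcast half-precision inputs to float32.
    q, k, v = q.float(), k.float(), v.float()
    q = q * (1 / q.shape[-1] ** 0.5)
    return (q @ k.transpose(-2, -1)).softmax(-1) @ v

if torch.cuda.is_available():
    q = torch.randn(2, 32, 32, device="cuda", dtype=torch.half) * 3
    k = torch.randn(2, 64, 32, device="cuda", dtype=torch.half) * 3
    v = torch.randn(2, 64, 32, device="cuda", dtype=torch.half) * 3
    out = xformers.ops.memory_efficient_attention(q, k, v).float()
    ref = ref_attention_fp32(q, k, v)
    print((out - ref).abs().max())  # the forward test asserts allclose with atol=2e-4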
1 change: 1 addition & 0 deletions third_party/cutlass/include/cutlass/half.h
@@ -622,6 +622,7 @@ struct numeric_limits<cutlass::half_t> {
static cutlass::half_t round_error() { return cutlass::half_t(0.5f); }

/// Returns positive infinity
CUTLASS_HOST_DEVICE
static cutlass::half_t infinity() { return cutlass::half_t::bitcast(0x7c00); }

/// Returns smallest finite value
57 changes: 37 additions & 20 deletions xformers/benchmarks/benchmark_mem_eff_attention.py
@@ -5,6 +5,7 @@


import itertools
import math
from functools import partial

import torch
@@ -22,7 +23,8 @@ def ref_attention(q, k, v, attn_bias=None, p=0.0):
# equivalent to (q @ k.transpose(-2, -1) + m).softmax(-1) @ v
# but faster, and is what is used in PyTorch now
attn = torch.baddbmm(attn_bias, q, k.transpose(-2, -1))
attn = attn.softmax(-1)
dtype = attn.dtype
attn = attn.to(torch.float).softmax(-1).to(dtype)
if p > 0:
attn = torch.nn.functional.dropout(attn, p=p)
return attn @ v
@@ -32,9 +34,7 @@
device = torch.device("cuda")

NUM_THREADS = [1] if device.type == "cuda" else [1, 40]
SHAPES = list(
itertools.product([1, 8, 32, 256], [127, 128, 512, 513, 1023, 1024], [16, 32])
) + list(itertools.product([32, 256], [128, 512, 1024, 2048], [16, 32, 64, 128, 256]))
SHAPES = list(itertools.product([32, 256], [128, 512, 1024], [16, 32, 128]))
SHAPES = list(set(SHAPES))
SHAPES.sort()

@@ -56,26 +56,37 @@ def product_dict(**kwargs):
shape=SHAPES,
num_threads=NUM_THREADS,
use_attn_bias=[False, True],
dtype=[torch.half, torch.float],
)
)


def benchmark_forward(shape, num_threads: int, use_attn_bias: bool):
def benchmark_forward(shape, num_threads: int, use_attn_bias: bool, dtype):
B, M, K = shape
q = torch.rand(shape, device=device)
if (
K > op.SUPPORTED_MAX_K
or (use_attn_bias and not op.SUPPORTS_ATTN_BIAS)
or (dtype not in op.SUPPORTED_DTYPES)
):
return
q = torch.rand(shape, device=device, dtype=dtype)
attn_bias = None
if use_attn_bias:
attn_bias = torch.rand(shape[0], 1, shape[1], device=device).expand(
shape[0], shape[1], shape[1]
)
sub_label = f"B={B}, M={M}, K={K}"
attn_bias = torch.rand(
shape[0], 1, shape[1], device=device, dtype=dtype
).expand(shape[0], shape[1], shape[1])
dtype_str = {
torch.half: "f16",
torch.float: "f32",
}[dtype]
sub_label = f"{dtype_str} B={B}, M={M}, K={K}"

if K > op.SUPPORTED_MAX_K or (use_attn_bias and not op.SUPPORTS_ATTN_BIAS):
raise NotImplementedError()
if True:
r = xformers.ops.memory_efficient_attention(q, q, q, attn_bias, op=op)
rr = ref_attention(q, q, q, attn_bias)
assert (r - rr).abs().max() < 1e-5
r = xformers.ops.memory_efficient_attention(q, q, q, attn_bias, op=op).float()
rr = ref_attention(
q.float(), q.float(), q.float(), attn_bias.float() if attn_bias is not None else None
)
assert (r - rr).abs().max() < 2e-4, (r - rr).abs().max()
del r, rr

yield benchmark.Timer(
@@ -106,18 +117,23 @@ def benchmark_forward(shape, num_threads: int, use_attn_bias: bool):
)


def benchmark_backward(shape, num_threads: int, use_attn_bias: bool):
def benchmark_backward(shape, num_threads: int, use_attn_bias: bool, dtype):
B, M, K = shape
q = torch.rand(shape, device=device, requires_grad=True)
q = torch.rand(shape, device=device, dtype=dtype, requires_grad=True)
attn_bias = None
if use_attn_bias:
attn_bias = torch.rand(shape[0], 1, shape[1], device=device).expand(
shape[0], shape[1], shape[1]
)
sub_label = f"B={B}, M={M}, K={K}"

if K > op.SUPPORTED_MAX_K or (use_attn_bias and not op.SUPPORTS_ATTN_BIAS):
raise NotImplementedError()
if (
K > op.SUPPORTED_MAX_K
or (use_attn_bias and not op.SUPPORTS_ATTN_BIAS)
# only fp32 is supported at the moment
or (dtype not in {torch.float})
):
return
if True:
r = xformers.ops.memory_efficient_attention(q, q, q, attn_bias, op=op)
r.backward(torch.ones_like(q))
@@ -127,7 +143,8 @@ def benchmark_backward(shape, num_threads: int, use_attn_bias: bool):

rr = ref_attention(q, q, q, attn_bias)
rr.backward(torch.ones_like(q))
assert (grad - q.grad).abs().max() < 1e-5, f"{(grad - q.grad).abs().max()}"
atol = 2e-4 + 2e-6 * K * M * math.sqrt(B) * math.sqrt(M)
assert (grad - q.grad).abs().max() < atol, f"{(grad - q.grad).abs().max()}"
q.grad = None
del r, rr, grad

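Note: the baddbmm form in the benchmark's ref_attention is the fused equivalent of (q @ k.transpose(-2, -1) + attn_bias).softmax(-1) @ v, as its comment states. A standalone sketch of that equivalence, plus the float32 softmax added in this diff; the shapes are arbitrary and float32 tensors are used so it also runs on CPU — it is an illustration, not part of the commit.

import torch

B, M, K = 2, 8, 4
q = torch.randn(B, M, K)
k = torch.randn(B, M, K)
v = torch.randn(B, M, K)
attn_bias = torch.randn(B, M, M)

# baddbmm fuses the bias add into the batched matmul.
attn = torch.baddbmm(attn_bias, q, k.transpose(-2, -1))
assert torch.allclose(attn, q @ k.transpose(-2, -1) + attn_bias, atol=1e-5)

# The diff runs the softmax in float32 and casts back, so half-precision
# inputs do not lose accuracy in the normalization step.
dtype = attn.dtype
out = attn.to(torch.float).softmax(-1).to(dtype) @ v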