Enable dropout in memory-efficient attention (facebookresearch#334)
* Merge compute_scaling_coeffs and update_scaling_coeffs into a single function

There was no need to split it into two functions in the first place

* Add CUDA implementation for dropout

* clang-format

* Make p the drop probability

* Only CUDA supports dropout

* Add benchmarks

* Remove unused variables

* Fix test

* Cleanups and comments
fmassa authored May 23, 2022
1 parent 340650d commit 7f3d464
Showing 6 changed files with 603 additions and 90 deletions.
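
For orientation, here is a minimal usage sketch of the updated Python entry point, mirroring the positional calls made by the new tests and benchmarks below. It is a sketch only, not part of the commit, and assumes a CUDA build of xformers, since only the CUDA path supports dropout in this change.

import torch
import xformers.ops

q = torch.randn(1, 16, 32, device="cuda")
k = torch.randn(1, 16, 32, device="cuda")
v = torch.randn(1, 16, 32, device="cuda")

p = 0.3  # drop probability; p = 0.0 disables dropout
torch.manual_seed(42)  # the tests rely on the global RNG seed for reproducible dropout
out = xformers.ops.memory_efficient_attention(q, k, v, None, p)  # (query, key, value, attn_bias, p)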
151 changes: 145 additions & 6 deletions tests/test_mem_eff_attention.py
@@ -5,18 +5,23 @@

import pytest
import torch
from scipy.stats import binom_test

import xformers.ops

cuda_only = pytest.mark.skipif(not torch.cuda.is_available(), reason="requires CUDA")
_devices = ["cpu", "cuda"] if torch.cuda.is_available() else ["cpu"]


def ref_attention(q, k, v, attn_bias=None):
def ref_attention(q, k, v, attn_bias=None, drop_mask=None, p=0.0):
q = q * (1 / q.shape[-1] ** 0.5)
if attn_bias is None:
return (q @ k.transpose(-2, -1)).softmax(-1) @ v
else:
return (q @ k.transpose(-2, -1) + attn_bias).softmax(-1) @ v
attn = q @ k.transpose(-2, -1)
if attn_bias is not None:
attn = attn + attn_bias
attn = attn.softmax(-1)
if drop_mask is not None:
attn = attn * (drop_mask / (1 - p))
return attn @ v
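
A note on the drop_mask / (1 - p) factor above: the reference implements inverted dropout, rescaling the surviving attention weights so their expectation matches the no-dropout output. A small illustrative check (not part of the commit):

import torch

p = 0.3
mask = (torch.rand(10_000_000) > p).float()    # entries kept with probability 1 - p
scaled = mask / (1 - p)                        # inverted-dropout rescaling
assert abs(scaled.mean().item() - 1.0) < 1e-2  # mean stays ~1, so the expected attention is unchanged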


@pytest.mark.parametrize("use_attn_bias", [False, True])
@@ -72,7 +77,9 @@ def test_logsumexp(device, q_len, kv_len, batch_size, k_len):
key = torch.randn((batch_size, kv_len, k_len), device=device) * scale
value = torch.randn((batch_size, kv_len, k_len), device=device) * scale

_, lse = torch.ops.xformers.efficient_attention(query, key, value, True, None)
_, lse, _, _ = torch.ops.xformers.efficient_attention(
query, key, value, True, None, 0.0
)
ref_lse = ((query / k_len ** 0.5) @ key.transpose(-2, -1)).logsumexp(-1)

assert torch.allclose(lse, ref_lse, atol=2e-4)
@@ -133,3 +140,135 @@ def test_memory_efficient_attention_backward(
assert torch.allclose(
grad_v, value.grad, atol=atol
), f"grad_v doesn't match {(grad_v - value.grad).abs().max()}"


def _vec_binom_test(x, n, p):
"""
vectorized implementation of scipy.stats.binom_test
this makes our tests much faster
reference: https://github.com/scipy/scipy/blob/v1.8.0/scipy/stats/_morestats.py#L2609-L2702
"""
import numpy as np
from scipy.stats import distributions

x = np.atleast_1d(x)
d = distributions.binom.pmf(x, n, p)[:, None]
rerr = 1 + 1e-7
# x < p * n case
i = np.arange(np.ceil(p * n), n + 1)
y = np.sum(distributions.binom.pmf(i, n, p) <= d * rerr, axis=1)
pval1 = distributions.binom.cdf(x, n, p) + distributions.binom.sf(n - y, n, p)

# other case
i = np.arange(np.floor(p * n) + 1)
y = np.sum(distributions.binom.pmf(i, n, p) <= d * rerr, axis=1)
pval2 = distributions.binom.cdf(y - 1, n, p) + distributions.binom.sf(x - 1, n, p)

pval = np.where(x < p * n, pval1, pval2)
pval = np.minimum(1.0, pval)
return pval
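
As an illustrative sanity check (not part of the commit), the vectorized helper above should agree closely, elementwise, with the scalar scipy.stats.binom_test this file already imports:

import numpy as np
from scipy.stats import binom_test

x = np.array([40, 55, 70])
n, keep_prob = 100, 0.7
expected = np.array([binom_test(xi, n, keep_prob) for xi in x])
np.testing.assert_allclose(_vec_binom_test(x, n, keep_prob), expected, rtol=1e-5)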


@cuda_only
@pytest.mark.parametrize("seed", [42, 124])
@pytest.mark.parametrize("p", [0.3, 0.7])
@pytest.mark.parametrize("k_len", [32])
@pytest.mark.parametrize("batch_size", [1, 4])
@pytest.mark.parametrize("kv_len", [3, 15, 32, 33])
@pytest.mark.parametrize("q_len", [2, 33])
@pytest.mark.parametrize("device", ["cuda"])
def test_dropout(device, q_len, kv_len, batch_size, k_len, p, seed):
scale = 3
query = torch.randn((batch_size, q_len, k_len), device=device) * scale
key = torch.randn((batch_size, kv_len, k_len), device=device) * scale
value = torch.randn((batch_size, kv_len, k_len), device=device) * scale

attn_bias = None

torch.manual_seed(seed)
out = xformers.ops.memory_efficient_attention(query, key, value, attn_bias, p)

torch.manual_seed(seed)
out2 = xformers.ops.memory_efficient_attention(query, key, value, attn_bias, p)

assert torch.allclose(out, out2)

mask = torch.empty((batch_size, q_len, kv_len), device=device)

torch.manual_seed(seed)
mask = torch.ops.xformers._temp_dropout(mask, p)

ref = ref_attention(query, key, value, attn_bias, mask, p)
assert torch.allclose(out, ref, atol=2e-4), f"{(out - ref).abs().max()}"

num_trials = 1000
p_val_tol = 0.0001
keep_prob = 1 - p
masks = []
for i in range(num_trials):
mask = torch.ops.xformers._temp_dropout(mask, p)
masks.append(mask.clone().cpu())
masks = torch.stack(masks, dim=0)
p_value = binom_test(masks.sum(), masks.numel(), p=keep_prob)
assert p_value > p_val_tol, p_value
masks = masks.sum(0).flatten()
p_values = _vec_binom_test(masks, num_trials, p=keep_prob)
assert all(p_values > p_val_tol)


@cuda_only
@pytest.mark.parametrize("p", [0.3, 0.7])
@pytest.mark.parametrize("k_len", [5, 6, 32])
@pytest.mark.parametrize("batch_size", [1, 4])
@pytest.mark.parametrize("kv_len", [3, 15, 32, 33])
@pytest.mark.parametrize("q_len", [2, 33])
@pytest.mark.parametrize("device", ["cuda"])
def test_dropout_backward(device, q_len, kv_len, batch_size, k_len, p):
scale = 3
query = torch.randn((batch_size, q_len, k_len), device=device) * scale
key = torch.randn((batch_size, kv_len, k_len), device=device) * scale
value = torch.randn((batch_size, kv_len, k_len), device=device) * scale

query.requires_grad_(True)
key.requires_grad_(True)
value.requires_grad_(True)

grad_out = torch.ones_like(query)

attn_bias = None

seed = 42
torch.manual_seed(seed)
out = xformers.ops.memory_efficient_attention(query, key, value, attn_bias, p)

out.backward(grad_out)

grad_q = query.grad
grad_k = key.grad
grad_v = value.grad

query.grad = None
key.grad = None
value.grad = None

mask = torch.empty((batch_size, q_len, kv_len), device=device)

torch.manual_seed(seed)
mask = torch.ops.xformers._temp_dropout(mask, p)

ref = ref_attention(query, key, value, attn_bias, mask, p)
ref.backward(grad_out)

# there is some extra precision loss in the CPU implementation due to an
# extra accumulation step in grad_q, which is not present in the CUDA
# implementation
atol = 5e-4 if device == "cuda" else 6e-4
assert torch.allclose(
grad_q, query.grad, atol=atol
), f"grad_q doesn't match {(grad_q - query.grad).abs().max()}"
assert torch.allclose(
grad_k, key.grad, atol=atol
), f"grad_k doesn't match {(grad_k - key.grad).abs().max()}"
assert torch.allclose(
grad_v, value.grad, atol=atol
), f"grad_v doesn't match {(grad_v - value.grad).abs().max()}"
35 changes: 22 additions & 13 deletions xformers/benchmarks/benchmark_mem_eff_attention.py
@@ -14,14 +14,19 @@
import xformers.ops


def ref_attention(q, k, v, attn_bias=None):
def ref_attention(q, k, v, attn_bias=None, p=0.0):
q = q * (1.0 / q.shape[-1] ** 0.5)
attn = q @ k.transpose(-2, -1)
if attn_bias is None:
return (q @ k.transpose(-2, -1)).softmax(-1) @ v
attn = q @ k.transpose(-2, -1)
else:
# equivalent to (q @ k.transpose(-2, -1) + m).softmax(-1) @ v
# but faster, and is what is used in PyTorch now
return torch.baddbmm(attn_bias, q, k.transpose(-2, -1)).softmax(-1) @ v
attn = torch.baddbmm(attn_bias, q, k.transpose(-2, -1))
attn = attn.softmax(-1)
if p > 0:
attn = torch.nn.functional.dropout(attn, p=p)
return attn @ v
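
A quick illustration (not part of the commit) of the equivalence mentioned in the comment above: torch.baddbmm folds the bias add into the batched matmul.

import torch

q = torch.randn(2, 4, 8)
k = torch.randn(2, 6, 8)
bias = torch.randn(2, 4, 6)
fused = torch.baddbmm(bias, q, k.transpose(-2, -1))  # bias + q @ k^T in one call
reference = q @ k.transpose(-2, -1) + bias
assert torch.allclose(fused, reference, atol=1e-5)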


min_run_time = 2
@@ -32,6 +37,8 @@ def ref_attention(q, k, v, attn_bias=None):
itertools.product([1, 8, 32, 256], [127, 128, 512, 513, 1023, 1024], [16, 32])
)

p = 0.0

results = []
mem_use: Dict[str, Dict[str, float]] = dict(optimized={}, vanilla={})

@@ -59,15 +66,17 @@ def benchmark_forward():

rr = ref_attention(q, q, q, attn_bias)
assert (r - rr).abs().max() < 1e-5
del r, rr

torch.cuda.reset_peak_memory_stats()
torch.cuda.synchronize()
results.append(
benchmark.Timer(
stmt="fn(q, q, q, attn_bias)",
stmt="fn(q, q, q, attn_bias, p)",
globals={
"q": q,
"attn_bias": attn_bias,
"p": p,
"fn": xformers.ops.memory_efficient_attention,
},
label=f"attention (use_attn_bias={use_attn_bias})",
@@ -87,10 +96,11 @@ def benchmark_forward():
torch.cuda.synchronize()
results.append(
benchmark.Timer(
stmt="fn(q, q, q, attn_bias)",
stmt="fn(q, q, q, attn_bias, p)",
globals={
"q": q,
"attn_bias": attn_bias,
"p": p,
"fn": ref_attention,
},
label=f"attention (use_attn_bias={use_attn_bias})",
@@ -106,9 +116,6 @@ def benchmark_forward():
memory_str = f"Memory used: {memory} MB"
print("Vanilla", memory_str)

compare = benchmark.Compare(results)
compare.print()

pprint.pprint(mem_use)


@@ -141,8 +148,10 @@ def benchmark_backward():
assert (
grad - q.grad
).abs().max() < 1e-5, f"{(grad - q.grad).abs().max()}"
q.grad = None
del r, rr, grad

out = xformers.ops.memory_efficient_attention(q, q, q, attn_bias)
out = xformers.ops.memory_efficient_attention(q, q, q, attn_bias, p)
grad = torch.ones_like(q)

torch.cuda.reset_peak_memory_stats()
@@ -167,7 +176,7 @@ def benchmark_backward():

print("Optimized", memory_str)

out = ref_attention(q, q, q, attn_bias)
out = ref_attention(q, q, q, attn_bias, p)
torch.cuda.reset_peak_memory_stats()
torch.cuda.synchronize()
results.append(
@@ -190,11 +199,11 @@ def benchmark_backward():
memory_str = f"Memory used: {memory} MB"
print("Vanilla", memory_str)

compare = benchmark.Compare(results)
compare.print()

pprint.pprint(mem_use)


benchmark_forward()
benchmark_backward()

compare = benchmark.Compare(results)
compare.print()
6 changes: 4 additions & 2 deletions xformers/components/attention/csrc/attention.cpp
@@ -2,7 +2,9 @@

TORCH_LIBRARY_FRAGMENT(xformers, m) {
m.def(TORCH_SELECTIVE_SCHEMA(
"xformers::efficient_attention(Tensor query, Tensor key, Tensor value, bool compute_logsumexp, Tensor? attn_bias) -> (Tensor, Tensor)"));
"xformers::efficient_attention(Tensor query, Tensor key, Tensor value, bool compute_logsumexp, Tensor? attn_bias, float p) -> (Tensor, Tensor, int, int)"));
m.def(TORCH_SELECTIVE_SCHEMA(
"xformers::efficient_attention_backward(Tensor grad_out, Tensor query, Tensor key, Tensor value, Tensor logsumexp, Tensor? attn_bias) -> (Tensor, Tensor, Tensor)"));
"xformers::efficient_attention_backward(Tensor grad_out, Tensor query, Tensor key, Tensor value, Tensor logsumexp, Tensor? attn_bias, float p, int rng_seed, int rng_offset) -> (Tensor, Tensor, Tensor)"));
m.def(TORCH_SELECTIVE_SCHEMA(
"xformers::_temp_dropout(Tensor out, float p) -> Tensor"));
}
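
For reference, a hedged sketch of calling the updated schemas directly from Python, the way the new tests do. It assumes the extension is built; the two trailing ints returned by the forward are presumably the RNG seed/offset consumed by efficient_attention_backward.

import torch
import xformers.ops  # importing xformers.ops makes the custom ops visible under torch.ops.xformers

q = torch.randn(1, 8, 32, device="cuda")
out, lse, rng_seed, rng_offset = torch.ops.xformers.efficient_attention(
    q, q, q, True, None, 0.0  # compute_logsumexp=True, no attn_bias, p=0.0 (drop probability)
)
mask = torch.ops.xformers._temp_dropout(torch.empty(1, 8, 8, device="cuda"), 0.3)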
16 changes: 12 additions & 4 deletions xformers/components/attention/csrc/cpu/attention.cpp
@@ -115,12 +115,13 @@ void attention_kernel(
});
}

std::tuple<at::Tensor, at::Tensor> attention(
std::tuple<at::Tensor, at::Tensor, int64_t, int64_t> attention(
const at::Tensor& query,
const at::Tensor& key,
const at::Tensor& value,
bool compute_logsumexp,
const c10::optional<at::Tensor>& attn_bias_) {
const c10::optional<at::Tensor>& attn_bias_,
double p) {
TORCH_CHECK(query.dim() == key.dim());
TORCH_CHECK(query.dim() == value.dim());
TORCH_CHECK(query.dim() == 3);
@@ -155,6 +156,8 @@ std::tuple<at::Tensor, at::Tensor> attention(
TORCH_CHECK(key.is_contiguous());
TORCH_CHECK(value.is_contiguous());

TORCH_CHECK(p == 0, "CPU implementation does not support dropout");

int64_t B = query.size(0);
int64_t M = query.size(1);
int64_t K = query.size(2);
@@ -177,7 +180,7 @@ std::tuple<at::Tensor, at::Tensor> attention(
_tensor_accessor_or_dummy<scalar_t>(attn_bias, zeros));
});

return std::make_tuple(res, logsumexp);
return std::make_tuple(res, logsumexp, 1, 1);
}

template <typename scalar_t>
@@ -268,7 +271,10 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> attention_backward(
const at::Tensor& key,
const at::Tensor& value,
const at::Tensor& logsumexp,
const c10::optional<at::Tensor>& attn_bias_) {
const c10::optional<at::Tensor>& attn_bias_,
double p,
int64_t rng_seed,
int64_t rng_offset) {
TORCH_CHECK(query.dim() == grad_out.dim());
TORCH_CHECK(query.dim() == key.dim());
TORCH_CHECK(query.dim() == value.dim());
@@ -307,6 +313,8 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> attention_backward(
TORCH_CHECK(!value.is_sparse(), "value must be a dense tensor");
TORCH_CHECK(!grad_out.is_sparse(), "grad_out must be a dense tensor");

TORCH_CHECK(p == 0, "CPU implementation does not support dropout");

int64_t B = query.size(0);
int64_t M = query.size(1);
int64_t N = key.size(1);
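
The TORCH_CHECK(p == 0, ...) guards added above mean the CPU kernels reject any nonzero dropout probability. A hedged sketch of the expected failure mode (assuming the extension is built and the check surfaces as a RuntimeError in Python, as TORCH_CHECK normally does):

import pytest
import torch
import xformers.ops

q = torch.randn(1, 4, 8)  # CPU tensors dispatch to the CPU kernel shown above
with pytest.raises(RuntimeError, match="does not support dropout"):
    torch.ops.xformers.efficient_attention(q, q, q, False, None, 0.5)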