From 7f3d464aec9991cfcfe29f8e8da4c137c81c2e4c Mon Sep 17 00:00:00 2001 From: Francisco Massa Date: Mon, 23 May 2022 20:14:20 +0200 Subject: [PATCH] Enable dropout in memory-efficient attention (#334) * Merge compute_scaling_coeffs and update_scaling_coeffs into a single function It wasn't needed to break it in two functions to begin with * Add CUDA implementation for dropout * clang-format * Make p be drop probability * Only CUDA supports dropout * Add benchmarks * Remove unused variables * Fix test * Cleanups and comments --- tests/test_mem_eff_attention.py | 151 +++++- .../benchmarks/benchmark_mem_eff_attention.py | 35 +- .../components/attention/csrc/attention.cpp | 6 +- .../attention/csrc/cpu/attention.cpp | 16 +- .../attention/csrc/cuda/attention.cu | 464 +++++++++++++++--- xformers/ops.py | 21 +- 6 files changed, 603 insertions(+), 90 deletions(-) diff --git a/tests/test_mem_eff_attention.py b/tests/test_mem_eff_attention.py index 9e93182d29..654418839a 100644 --- a/tests/test_mem_eff_attention.py +++ b/tests/test_mem_eff_attention.py @@ -5,18 +5,23 @@ import pytest import torch +from scipy.stats import binom_test import xformers.ops +cuda_only = pytest.mark.skipif(not torch.cuda.is_available(), reason="requires CUDA") _devices = ["cpu", "cuda"] if torch.cuda.is_available() else ["cpu"] -def ref_attention(q, k, v, attn_bias=None): +def ref_attention(q, k, v, attn_bias=None, drop_mask=None, p=0.0): q = q * (1 / q.shape[-1] ** 0.5) - if attn_bias is None: - return (q @ k.transpose(-2, -1)).softmax(-1) @ v - else: - return (q @ k.transpose(-2, -1) + attn_bias).softmax(-1) @ v + attn = q @ k.transpose(-2, -1) + if attn_bias is not None: + attn = attn + attn_bias + attn = attn.softmax(-1) + if drop_mask is not None: + attn = attn * (drop_mask / (1 - p)) + return attn @ v @pytest.mark.parametrize("use_attn_bias", [False, True]) @@ -72,7 +77,9 @@ def test_logsumexp(device, q_len, kv_len, batch_size, k_len): key = torch.randn((batch_size, kv_len, k_len), device=device) * scale value = torch.randn((batch_size, kv_len, k_len), device=device) * scale - _, lse = torch.ops.xformers.efficient_attention(query, key, value, True, None) + _, lse, _, _ = torch.ops.xformers.efficient_attention( + query, key, value, True, None, 0.0 + ) ref_lse = ((query / k_len ** 0.5) @ key.transpose(-2, -1)).logsumexp(-1) assert torch.allclose(lse, ref_lse, atol=2e-4) @@ -133,3 +140,135 @@ def test_memory_efficient_attention_backward( assert torch.allclose( grad_v, value.grad, atol=atol ), f"grad_v doesn't match {(grad_v - value.grad).abs().max()}" + + +def _vec_binom_test(x, n, p): + """ + vectorized implementation of scipy.stats.binom_test + this makes our tests much faster + reference: https://github.com/scipy/scipy/blob/v1.8.0/scipy/stats/_morestats.py#L2609-L2702 + """ + import numpy as np + from scipy.stats import distributions + + x = np.atleast_1d(x) + d = distributions.binom.pmf(x, n, p)[:, None] + rerr = 1 + 1e-7 + # x < p * n case + i = np.arange(np.ceil(p * n), n + 1) + y = np.sum(distributions.binom.pmf(i, n, p) <= d * rerr, axis=1) + pval1 = distributions.binom.cdf(x, n, p) + distributions.binom.sf(n - y, n, p) + + # other case + i = np.arange(np.floor(p * n) + 1) + y = np.sum(distributions.binom.pmf(i, n, p) <= d * rerr, axis=1) + pval2 = distributions.binom.cdf(y - 1, n, p) + distributions.binom.sf(x - 1, n, p) + + pval = np.where(x < p * n, pval1, pval2) + pval = np.minimum(1.0, pval) + return pval + + +@cuda_only +@pytest.mark.parametrize("seed", [42, 124]) +@pytest.mark.parametrize("p", [0.3, 0.7]) 
+@pytest.mark.parametrize("k_len", [32]) +@pytest.mark.parametrize("batch_size", [1, 4]) +@pytest.mark.parametrize("kv_len", [3, 15, 32, 33]) +@pytest.mark.parametrize("q_len", [2, 33]) +@pytest.mark.parametrize("device", ["cuda"]) +def test_dropout(device, q_len, kv_len, batch_size, k_len, p, seed): + scale = 3 + query = torch.randn((batch_size, q_len, k_len), device=device) * scale + key = torch.randn((batch_size, kv_len, k_len), device=device) * scale + value = torch.randn((batch_size, kv_len, k_len), device=device) * scale + + attn_bias = None + + torch.manual_seed(seed) + out = xformers.ops.memory_efficient_attention(query, key, value, attn_bias, p) + + torch.manual_seed(seed) + out2 = xformers.ops.memory_efficient_attention(query, key, value, attn_bias, p) + + assert torch.allclose(out, out2) + + mask = torch.empty((batch_size, q_len, kv_len), device=device) + + torch.manual_seed(seed) + mask = torch.ops.xformers._temp_dropout(mask, p) + + ref = ref_attention(query, key, value, attn_bias, mask, p) + assert torch.allclose(out, ref, atol=2e-4), f"{(out - ref).abs().max()}" + + num_trials = 1000 + p_val_tol = 0.0001 + keep_prob = 1 - p + masks = [] + for i in range(num_trials): + mask = torch.ops.xformers._temp_dropout(mask, p) + masks.append(mask.clone().cpu()) + masks = torch.stack(masks, dim=0) + p_value = binom_test(masks.sum(), masks.numel(), p=keep_prob) + assert p_value > p_val_tol, p_value + masks = masks.sum(0).flatten() + p_values = _vec_binom_test(masks, num_trials, p=keep_prob) + assert all(p_values > p_val_tol) + + +@cuda_only +@pytest.mark.parametrize("p", [0.3, 0.7]) +@pytest.mark.parametrize("k_len", [5, 6, 32]) +@pytest.mark.parametrize("batch_size", [1, 4]) +@pytest.mark.parametrize("kv_len", [3, 15, 32, 33]) +@pytest.mark.parametrize("q_len", [2, 33]) +@pytest.mark.parametrize("device", ["cuda"]) +def test_dropout_backward(device, q_len, kv_len, batch_size, k_len, p): + scale = 3 + query = torch.randn((batch_size, q_len, k_len), device=device) * scale + key = torch.randn((batch_size, kv_len, k_len), device=device) * scale + value = torch.randn((batch_size, kv_len, k_len), device=device) * scale + + query.requires_grad_(True) + key.requires_grad_(True) + value.requires_grad_(True) + + grad_out = torch.ones_like(query) + + attn_bias = None + + seed = 42 + torch.manual_seed(seed) + out = xformers.ops.memory_efficient_attention(query, key, value, attn_bias, p) + + out.backward(grad_out) + + grad_q = query.grad + grad_k = key.grad + grad_v = value.grad + + query.grad = None + key.grad = None + value.grad = None + + mask = torch.empty((batch_size, q_len, kv_len), device=device) + + torch.manual_seed(seed) + mask = torch.ops.xformers._temp_dropout(mask, p) + + ref = ref_attention(query, key, value, attn_bias, mask, p) + ref.backward(grad_out) + + # there is some extra precision loss in the CPU implementation due to an + # extra accumulation step in grad_q, which is not present in the CUDA + # implementation + atol = 5e-4 if device == "cuda" else 6e-4 + assert torch.allclose( + grad_q, query.grad, atol=atol + ), f"grad_q doesn't match {(grad_q - query.grad).abs().max()}" + assert torch.allclose( + grad_k, key.grad, atol=atol + ), f"grad_k doesn't match {(grad_k - key.grad).abs().max()}" + assert torch.allclose( + grad_v, value.grad, atol=atol + ), f"grad_v doesn't match {(grad_v - value.grad).abs().max()}" diff --git a/xformers/benchmarks/benchmark_mem_eff_attention.py b/xformers/benchmarks/benchmark_mem_eff_attention.py index 9679af0591..28238a23fa 100644 --- 
a/xformers/benchmarks/benchmark_mem_eff_attention.py +++ b/xformers/benchmarks/benchmark_mem_eff_attention.py @@ -14,14 +14,19 @@ import xformers.ops -def ref_attention(q, k, v, attn_bias=None): +def ref_attention(q, k, v, attn_bias=None, p=0.0): q = q * (1.0 / q.shape[-1] ** 0.5) + attn = q @ k.transpose(-2, -1) if attn_bias is None: - return (q @ k.transpose(-2, -1)).softmax(-1) @ v + attn = q @ k.transpose(-2, -1) else: # equivalent to (q @ k.transpose(-2, -1) + m).softmax(-1) @ v # but faster, and is what is used in PyTorch now - return torch.baddbmm(attn_bias, q, k.transpose(-2, -1)).softmax(-1) @ v + attn = torch.baddbmm(attn_bias, q, k.transpose(-2, -1)) + attn = attn.softmax(-1) + if p > 0: + attn = torch.nn.functional.dropout(attn, p=p) + return attn @ v min_run_time = 2 @@ -32,6 +37,8 @@ def ref_attention(q, k, v, attn_bias=None): itertools.product([1, 8, 32, 256], [127, 128, 512, 513, 1023, 1024], [16, 32]) ) +p = 0.0 + results = [] mem_use: Dict[str, Dict[str, float]] = dict(optimized={}, vanilla={}) @@ -59,15 +66,17 @@ def benchmark_forward(): rr = ref_attention(q, q, q, attn_bias) assert (r - rr).abs().max() < 1e-5 + del r, rr torch.cuda.reset_peak_memory_stats() torch.cuda.synchronize() results.append( benchmark.Timer( - stmt="fn(q, q, q, attn_bias)", + stmt="fn(q, q, q, attn_bias, p)", globals={ "q": q, "attn_bias": attn_bias, + "p": p, "fn": xformers.ops.memory_efficient_attention, }, label=f"attention (use_attn_bias={use_attn_bias})", @@ -87,10 +96,11 @@ def benchmark_forward(): torch.cuda.synchronize() results.append( benchmark.Timer( - stmt="fn(q, q, q, attn_bias)", + stmt="fn(q, q, q, attn_bias, p)", globals={ "q": q, "attn_bias": attn_bias, + "p": p, "fn": ref_attention, }, label=f"attention (use_attn_bias={use_attn_bias})", @@ -106,9 +116,6 @@ def benchmark_forward(): memory_str = f"Memory used: {memory} MB" print("Vanilla", memory_str) - compare = benchmark.Compare(results) - compare.print() - pprint.pprint(mem_use) @@ -141,8 +148,10 @@ def benchmark_backward(): assert ( grad - q.grad ).abs().max() < 1e-5, f"{(grad - q.grad).abs().max()}" + q.grad = None + del r, rr, grad - out = xformers.ops.memory_efficient_attention(q, q, q, attn_bias) + out = xformers.ops.memory_efficient_attention(q, q, q, attn_bias, p) grad = torch.ones_like(q) torch.cuda.reset_peak_memory_stats() @@ -167,7 +176,7 @@ def benchmark_backward(): print("Optimized", memory_str) - out = ref_attention(q, q, q, attn_bias) + out = ref_attention(q, q, q, attn_bias, p) torch.cuda.reset_peak_memory_stats() torch.cuda.synchronize() results.append( @@ -190,11 +199,11 @@ def benchmark_backward(): memory_str = f"Memory used: {memory} MB" print("Vanilla", memory_str) - compare = benchmark.Compare(results) - compare.print() - pprint.pprint(mem_use) benchmark_forward() benchmark_backward() + +compare = benchmark.Compare(results) +compare.print() diff --git a/xformers/components/attention/csrc/attention.cpp b/xformers/components/attention/csrc/attention.cpp index 5886e02ee9..c8bc3d60d0 100644 --- a/xformers/components/attention/csrc/attention.cpp +++ b/xformers/components/attention/csrc/attention.cpp @@ -2,7 +2,9 @@ TORCH_LIBRARY_FRAGMENT(xformers, m) { m.def(TORCH_SELECTIVE_SCHEMA( - "xformers::efficient_attention(Tensor query, Tensor key, Tensor value, bool compute_logsumexp, Tensor? attn_bias) -> (Tensor, Tensor)")); + "xformers::efficient_attention(Tensor query, Tensor key, Tensor value, bool compute_logsumexp, Tensor? 
attn_bias, float p) -> (Tensor, Tensor, int, int)")); m.def(TORCH_SELECTIVE_SCHEMA( - "xformers::efficient_attention_backward(Tensor grad_out, Tensor query, Tensor key, Tensor value, Tensor logsumexp, Tensor? attn_bias) -> (Tensor, Tensor, Tensor)")); + "xformers::efficient_attention_backward(Tensor grad_out, Tensor query, Tensor key, Tensor value, Tensor logsumexp, Tensor? attn_bias, float p, int rng_seed, int rng_offset) -> (Tensor, Tensor, Tensor)")); + m.def(TORCH_SELECTIVE_SCHEMA( + "xformers::_temp_dropout(Tensor out, float p) -> Tensor")); } diff --git a/xformers/components/attention/csrc/cpu/attention.cpp b/xformers/components/attention/csrc/cpu/attention.cpp index 247575dff9..a56fcfc750 100644 --- a/xformers/components/attention/csrc/cpu/attention.cpp +++ b/xformers/components/attention/csrc/cpu/attention.cpp @@ -115,12 +115,13 @@ void attention_kernel( }); } -std::tuple attention( +std::tuple attention( const at::Tensor& query, const at::Tensor& key, const at::Tensor& value, bool compute_logsumexp, - const c10::optional& attn_bias_) { + const c10::optional& attn_bias_, + double p) { TORCH_CHECK(query.dim() == key.dim()); TORCH_CHECK(query.dim() == value.dim()); TORCH_CHECK(query.dim() == 3); @@ -155,6 +156,8 @@ std::tuple attention( TORCH_CHECK(key.is_contiguous()); TORCH_CHECK(value.is_contiguous()); + TORCH_CHECK(p == 0, "CPU implementation does not support dropout"); + int64_t B = query.size(0); int64_t M = query.size(1); int64_t K = query.size(2); @@ -177,7 +180,7 @@ std::tuple attention( _tensor_accessor_or_dummy(attn_bias, zeros)); }); - return std::make_tuple(res, logsumexp); + return std::make_tuple(res, logsumexp, 1, 1); } template @@ -268,7 +271,10 @@ std::tuple attention_backward( const at::Tensor& key, const at::Tensor& value, const at::Tensor& logsumexp, - const c10::optional& attn_bias_) { + const c10::optional& attn_bias_, + double p, + int64_t rng_seed, + int64_t rng_offset) { TORCH_CHECK(query.dim() == grad_out.dim()); TORCH_CHECK(query.dim() == key.dim()); TORCH_CHECK(query.dim() == value.dim()); @@ -307,6 +313,8 @@ std::tuple attention_backward( TORCH_CHECK(!value.is_sparse(), "value must be a dense tensor"); TORCH_CHECK(!grad_out.is_sparse(), "grad_out must be a dense tensor"); + TORCH_CHECK(p == 0, "CPU implementation does not support dropout"); + int64_t B = query.size(0); int64_t M = query.size(1); int64_t N = key.size(1); diff --git a/xformers/components/attention/csrc/cuda/attention.cu b/xformers/components/attention/csrc/cuda/attention.cu index ef080aa2e2..725fb8d28c 100644 --- a/xformers/components/attention/csrc/cuda/attention.cu +++ b/xformers/components/attention/csrc/cuda/attention.cu @@ -4,8 +4,12 @@ #include #include +#include #include #include +#include + +#include #include "sputnik/vector_utils.h" @@ -142,6 +146,36 @@ __device__ void compute_dot( } } +/* +struct RNGMaskGenerator { + + uint64_t seed_; + uint64_t offset_; + int64_t N_; + int64_t global_offset_; + curandStatePhilox4_32_10_t state_; + + __device__ __forceinline__ RNGMaskGenerator (at::PhiloxCudaState philox_args) +{ auto seeds = at:cuda::philox::unpack(philox_args); seed_ = std::get<0>(seeds); + offset_ = std::get<1>(seeds); + } + + __device__ __forceinline__ void set_sublocation(int64_t x, int64_t y) { + int64_t total_offset = global_offset_ + x * N_ + y; + // ideally we would use the code below, but initializing the rng + // takes a significant portion of time. So we instead modify the seed + // so that each thread has a different seed. 
This has fewer statistical + // guarantees than by doing it properly, but is much faster + // curand_init(seed_, total_offset, offset_, &state_); + curand_init(seed_ + (total_offset << 8) + offset_, 0, 0, &state_); + } + + __device__ __forceinline__ float4 generate() { + return curand_uniform4(&state_); + } + +} +*/ template < typename scalar_t, typename vec_t, @@ -174,6 +208,55 @@ __device__ void compute_final_mult( } } +template +//__device__ __forceinline__ void apply_masking( +__device__ void apply_masking( + scalar_t s_delta[kBlockSizeQ][kBlockSizeK], + at::PhiloxCudaState philox_args, + int64_t global_offset, + int64_t N, + scalar_t p, + int64_t col_offset) { + // strategy: initialize the rng so that each element in the attention + // matrix has its own subsequence, so that we can easily retrieve + // the element during backward + + curandStatePhilox4_32_10_t state; + auto seeds = at::cuda::philox::unpack(philox_args); + + // we will always sample 4 random floats at a time + // as it's more efficient + constexpr int kSampled = 4; + + // because the forward and the backward have different + // access patterns, we round the rng offset so that it's + // a multiple of kSampled, and add the delta needed + int delta = col_offset & (kSampled - 1); + col_offset = col_offset - delta; + +#pragma unroll + for (int64_t q_item_idx = 0; q_item_idx < kBlockSizeQ; q_item_idx++) { +#pragma unroll + for (int64_t k_item_idx = 0; k_item_idx < kBlockSizeK; + k_item_idx += kSampled) { + int64_t offset = global_offset + q_item_idx * N + k_item_idx + col_offset; + // ideally we would use the code below, but initializing the rng + // takes a significant portion of time. So we instead modify the seed + // so that each thread has a different seed. This has fewer statistical + // guarantees than by doing it properly, but is much faster + curand_init( + std::get<0>(seeds), offset, std::get<1>(seeds) + delta, &state); + // curand_init(std::get<0>(seeds) + (offset << 8) + std::get<1>(seeds), 0, + // 0, &state); + float4 rand = curand_uniform4(&state); + for (int kk = 0; kk < kSampled; kk++) { + if (k_item_idx + kk < kBlockSizeK) + s_delta[q_item_idx][k_item_idx + kk] *= (&rand.x)[kk] < p; + } + } + } +} + template __device__ __forceinline__ void compute_max( scalar_t a[kBlockSizeQ][kBlockSizeK], @@ -193,38 +276,24 @@ __device__ __forceinline__ void compute_max( } template -__device__ __forceinline__ void compute_scaling_coeffs( +__device__ __forceinline__ void compute_and_update_scaling_coeffs( scalar_t m_i[kBlockSizeQ], scalar_t m_prime[kBlockSizeQ], + scalar_t s_prime[kBlockSizeQ], scalar_t si[kBlockSizeQ][kBlockSizeK], scalar_t m_delta[kBlockSizeQ], scalar_t s_delta[kBlockSizeQ][kBlockSizeK]) { #pragma unroll - for (int64_t q_item_idx = 0; q_item_idx < kBlockSizeQ; q_item_idx++) + for (int64_t q_item_idx = 0; q_item_idx < kBlockSizeQ; q_item_idx++) { m_delta[q_item_idx] = std::exp(m_prime[q_item_idx] - m_i[q_item_idx]); + m_prime[q_item_idx] = m_i[q_item_idx]; + s_prime[q_item_idx] = s_prime[q_item_idx] * m_delta[q_item_idx]; #pragma unroll - for (int64_t q_item_idx = 0; q_item_idx < kBlockSizeQ; q_item_idx++) -#pragma unroll - for (int64_t k_item_idx = 0; k_item_idx < kBlockSizeK; k_item_idx++) + for (int64_t k_item_idx = 0; k_item_idx < kBlockSizeK; k_item_idx++) { s_delta[q_item_idx][k_item_idx] = std::exp(si[q_item_idx][k_item_idx] - m_i[q_item_idx]); -} - -template -__device__ __forceinline__ void update_scaling_coeffs( - scalar_t m_delta[kBlockSizeQ], - scalar_t m_i[kBlockSizeQ], - scalar_t 
s_delta[kBlockSizeQ][kBlockSizeK], - scalar_t m_prime[kBlockSizeQ], - scalar_t s_prime[kBlockSizeQ]) { -#pragma unroll - for (int64_t q_item_idx = 0; q_item_idx < kBlockSizeQ; q_item_idx++) { - s_prime[q_item_idx] = s_prime[q_item_idx] * m_delta[q_item_idx]; -#pragma unroll - for (int64_t k_item_idx = 0; k_item_idx < kBlockSizeK; k_item_idx++) s_prime[q_item_idx] += s_delta[q_item_idx][k_item_idx]; - - m_prime[q_item_idx] = m_i[q_item_idx]; + } } } @@ -256,7 +325,12 @@ __device__ void compute_loop( scalar_t s_prime[kBlockSizeQ], vec_t buffer[kBlockSizeQ][BUFFER_SIZE] /*TODO [BUFFER_SIZE limitation]*/, int64_t K, - scalar_t* attn_bias_i) { + scalar_t* attn_bias_i, + at::PhiloxCudaState philox_args, + int64_t global_offset, + int64_t N, + scalar_t p, + int64_t col_offset) { scalar_t si[kBlockSizeQ][kBlockSizeK] = {0}; compute_dot( query_block, key_i, si, K); @@ -271,14 +345,15 @@ __device__ void compute_loop( scalar_t m_delta[kBlockSizeQ]; scalar_t s_delta[kBlockSizeQ][kBlockSizeK]; - compute_scaling_coeffs( - m_i, m_prime, si, m_delta, s_delta); + compute_and_update_scaling_coeffs( + m_i, m_prime, s_prime, si, m_delta, s_delta); + + if (p < 1.0) + apply_masking( + s_delta, philox_args, global_offset, N, p, col_offset); compute_final_mult( value_i, s_delta, m_delta, buffer, K); - - update_scaling_coeffs( - m_delta, m_i, s_delta, m_prime, s_prime); } template < @@ -328,7 +403,10 @@ struct UnrollLoop { vec_t buffer[kBlockSizeQ][BUFFER_SIZE] /*TODO [BUFFER_SIZE limitation]*/, int64_t K, int64_t N, - at::TensorAccessor attn_bias) { + at::TensorAccessor attn_bias, + at::PhiloxCudaState philox_args, + int64_t global_offset, + scalar_t p) { constexpr int64_t step = kBlockSizeK * WARP_SIZE; int64_t l; if (first) { @@ -354,7 +432,12 @@ struct UnrollLoop { s_prime, buffer, K, - attn_bias_i); + attn_bias_i, + philox_args, + global_offset, + N, + p, + l); } } { @@ -375,7 +458,10 @@ struct UnrollLoop { buffer, K, N, - attn_bias); + attn_bias, + philox_args, + global_offset, + p); } } }; @@ -404,7 +490,10 @@ struct UnrollLoop< vec_t buffer[kBlockSizeQ][BUFFER_SIZE] /*TODO [BUFFER_SIZE limitation]*/, int64_t K, int64_t N, - at::TensorAccessor attn_bias) {} + at::TensorAccessor attn_bias, + at::PhiloxCudaState philox_args, + int64_t global_offset, + scalar_t p) {} }; template < @@ -421,7 +510,9 @@ __global__ void attention_kernel( at::PackedTensorAccessor query, at::PackedTensorAccessor key, at::PackedTensorAccessor value, - at::PackedTensorAccessor attn_bias) { + at::PackedTensorAccessor attn_bias, + scalar_t p, + at::PhiloxCudaState philox_args) { constexpr int kVecSize = sizeof(vec_t) / sizeof(scalar_t); static_assert( integerIsPowerOf2(kBlockSizeK * WARP_SIZE), @@ -435,6 +526,8 @@ __global__ void attention_kernel( int64_t query_idx = blockIdx.x * (blockDim.y * kBlockSizeQ) + threadIdx.y * kBlockSizeQ; + int64_t global_offset = batch_idx * M * N + query_idx * N; + if (query_idx >= M) return; @@ -444,7 +537,7 @@ __global__ void attention_kernel( // TODO [BUFFER_SIZE limitation]: the current strategy assumes a // statically-known size for K. 
Ideally we would like to remove this // limitation in the future, so that any K is supported - vec_t buffer[kBlockSizeQ][BUFFER_SIZE] = {0}; + vec_t buffer[kBlockSizeQ][BUFFER_SIZE] = {}; scalar_t s_prime[kBlockSizeQ] = {0}; scalar_t m_prime[kBlockSizeQ]; for (int64_t q_item_idx = 0; q_item_idx < kBlockSizeQ; q_item_idx++) { @@ -476,7 +569,10 @@ __global__ void attention_kernel( buffer, K, N, - attn_bias[batch_idx]); + attn_bias[batch_idx], + philox_args, + global_offset, + p); aggregate_coeffs( m_prime, s_prime, buffer, K); @@ -487,9 +583,10 @@ __global__ void attention_kernel( #pragma unroll for (int64_t q_item_idx = 0; q_item_idx < kBlockSizeQ; q_item_idx++) { tmp = buffer[q_item_idx][k]; - iDiv(s_prime[q_item_idx], &tmp); + iDiv(s_prime[q_item_idx] * p, &tmp); - output_block[q_item_idx][k] = tmp; + if (query_idx + q_item_idx < M) + output_block[q_item_idx][k] = tmp; } } @@ -521,7 +618,9 @@ void launch_attention( const at::Tensor& query, const at::Tensor& key, const at::Tensor& value, - const at::Tensor& attn_bias) { + const at::Tensor& attn_bias, + float p, + at::PhiloxCudaState rng_engine_inputs) { cudaStream_t stream = at::cuda::getCurrentCUDAStream(); int64_t B = query.size(0); @@ -561,7 +660,9 @@ void launch_attention( query.packed_accessor(), key.packed_accessor(), value.packed_accessor(), - attn_bias_packed); + attn_bias_packed, + p, + rng_engine_inputs); } else if ((K % 2) == 0) { TORCH_CHECK( K / 2 <= BUFFER_SIZE, @@ -579,7 +680,9 @@ void launch_attention( query.packed_accessor(), key.packed_accessor(), value.packed_accessor(), - attn_bias_packed); + attn_bias_packed, + p, + rng_engine_inputs); } else { TORCH_CHECK( @@ -598,16 +701,19 @@ void launch_attention( query.packed_accessor(), key.packed_accessor(), value.packed_accessor(), - attn_bias_packed); + attn_bias_packed, + p, + rng_engine_inputs); } } -std::tuple attention( +std::tuple attention( const at::Tensor& query, const at::Tensor& key, const at::Tensor& value, bool compute_logsumexp, - const c10::optional& attn_bias_) { + const c10::optional& attn_bias_, + double p) { TORCH_CHECK(query.dim() == key.dim()); TORCH_CHECK(query.dim() == value.dim()); TORCH_CHECK(query.dim() == 3); @@ -658,17 +764,42 @@ std::tuple attention( at::Tensor res = at::zeros({B, M, K}, query.options()); at::Tensor logsumexp = at::empty({B, M}, query.options()); + // invert from drop probability to keep probability + p = 1.0 - p; + + auto gen = at::get_generator_or_default( + c10::nullopt, at::cuda::detail::getDefaultCUDAGenerator()); + + at::PhiloxCudaState rng_engine_inputs; + { + // See Note [Acquire lock when using random generators] + std::lock_guard lock(gen->mutex_); + // each element in the attention matrix will have its own subsequence + // in the generator, so the offset is 1 globally + // int64_t counter_offset = p > 0 ? 1 : 0; + int64_t counter_offset = p > 0 ? 4 : 0; + rng_engine_inputs = gen->philox_cuda_state(counter_offset); + } + // have to pass compute_logsumexp as a template parameter // otherwise there is a slowdown in the kernel... 
if (compute_logsumexp) { - launch_attention(res, logsumexp, query, key, value, attn_bias); + launch_attention( + res, logsumexp, query, key, value, attn_bias, p, rng_engine_inputs); } else { - launch_attention(res, logsumexp, query, key, value, attn_bias); + launch_attention( + res, logsumexp, query, key, value, attn_bias, p, rng_engine_inputs); } AT_CUDA_CHECK(cudaGetLastError()); - return std::make_tuple(res, logsumexp); + // uint64_t -> int64_t bitwise casting as PyTorch don't support uint64_t + // so just fake it as a int64_t + int64_t seed, offset; + std::memcpy(&seed, &rng_engine_inputs.seed_, sizeof(seed)); + std::memcpy(&offset, &rng_engine_inputs.offset_.val, sizeof(offset)); + + return std::make_tuple(res, logsumexp, seed, offset); } template < @@ -687,7 +818,9 @@ __global__ void attention_backward_grad_v_kernel( at::PackedTensorAccessor value, at::PackedTensorAccessor tmp_sum_i, at::PackedTensorAccessor logsumexp_normalizer, - at::PackedTensorAccessor attn_bias) { + at::PackedTensorAccessor attn_bias, + scalar_t p, + at::PhiloxCudaState philox_args) { int64_t K = query.size(2); int64_t B = query.size(0); int64_t M = query.size(1); @@ -777,6 +910,7 @@ __global__ void attention_backward_grad_v_kernel( } } } + scalar_t one_over_p = 1.0 / p; #pragma unroll for (int k_item_idx = 0; k_item_idx < kBlockSizeK; k_item_idx++) { #pragma unroll @@ -785,10 +919,16 @@ __global__ void attention_backward_grad_v_kernel( std::exp( attn_v[q_item_idx][k_item_idx] - normalizer[q_item_idx] + attn_b[k_item_idx]) * - maskQ[q_item_idx] * maskK[k_item_idx]; + maskQ[q_item_idx] * maskK[k_item_idx] * one_over_p; } } + if (p < 1.0) { + int64_t global_offset = batch_idx * M * N + query_idx * N; + apply_masking( + attn_v, philox_args, global_offset, N, p, l); + } + #pragma unroll for (int k_item_idx = 0; k_item_idx < kBlockSizeK; k_item_idx++) { #pragma unroll @@ -848,7 +988,9 @@ __global__ void attention_backward_grad_qk_kernel( at::PackedTensorAccessor value, at::PackedTensorAccessor tmp_sum_i, at::PackedTensorAccessor logsumexp_normalizer, - at::PackedTensorAccessor attn_bias) { + at::PackedTensorAccessor attn_bias, + scalar_t p, + at::PhiloxCudaState philox_args) { int64_t K = query.size(2); int64_t B = query.size(0); int64_t M = query.size(1); @@ -945,6 +1087,7 @@ __global__ void attention_backward_grad_qk_kernel( } } } + scalar_t one_over_p = 1.0 / p; #pragma unroll for (int k_item_idx = 0; k_item_idx < kBlockSizeK; k_item_idx++) { #pragma unroll @@ -957,6 +1100,12 @@ __global__ void attention_backward_grad_qk_kernel( } } + if (p < 1.0) { + int64_t global_offset = batch_idx * M * N + query_idx * N; + apply_masking( + grad_attn_v, philox_args, global_offset, N, p, l); + } + #pragma unroll for (int k_item_idx = 0; k_item_idx < kBlockSizeK; k_item_idx++) { #pragma unroll @@ -964,7 +1113,8 @@ __global__ void attention_backward_grad_qk_kernel( fact[kBlockSizeQ * threadIdx.x + q_item_idx] [kBlockSizeK * threadIdx.y + k_item_idx] = attn_v[q_item_idx][k_item_idx] * scale * - (grad_attn_v[q_item_idx][k_item_idx] - tmp_sum[q_item_idx]); + (grad_attn_v[q_item_idx][k_item_idx] * one_over_p - + tmp_sum[q_item_idx]); } } __syncthreads(); @@ -1025,7 +1175,9 @@ void launch_attention_backward( const at::Tensor& value, const at::Tensor& logsumexp, at::Tensor& tmp_sum_i, - const at::Tensor& attn_bias) { + const at::Tensor& attn_bias, + float p, + at::PhiloxCudaState rng_engine_inputs) { cudaStream_t stream = at::cuda::getCurrentCUDAStream(); auto attn_bias_packed = _packed_tensor_accessor_or_dummy(attn_bias); @@ -1072,7 
+1224,9 @@ void launch_attention_backward( value.packed_accessor(), tmp_sum_i.packed_accessor(), logsumexp.packed_accessor(), - attn_bias_packed); + attn_bias_packed, + p, + rng_engine_inputs); } else { attention_backward_grad_v_kernel< scalar_t, @@ -1089,7 +1243,9 @@ void launch_attention_backward( value.packed_accessor(), tmp_sum_i.packed_accessor(), logsumexp.packed_accessor(), - attn_bias_packed); + attn_bias_packed, + p, + rng_engine_inputs); } if ((M % TILE_SIZEQ2 == 0) && (N % TILE_SIZEK2 == 0)) { @@ -1109,7 +1265,9 @@ void launch_attention_backward( value.packed_accessor(), tmp_sum_i.packed_accessor(), logsumexp.packed_accessor(), - attn_bias_packed); + attn_bias_packed, + p, + rng_engine_inputs); } else { attention_backward_grad_qk_kernel< scalar_t, @@ -1127,7 +1285,9 @@ void launch_attention_backward( value.packed_accessor(), tmp_sum_i.packed_accessor(), logsumexp.packed_accessor(), - attn_bias_packed); + attn_bias_packed, + p, + rng_engine_inputs); } } @@ -1137,7 +1297,10 @@ std::tuple attention_backward( const at::Tensor& key, const at::Tensor& value, const at::Tensor& logsumexp, - const c10::optional& attn_bias_) { + const c10::optional& attn_bias_, + double p, + int64_t rng_seed, + int64_t rng_offset) { TORCH_CHECK(query.dim() == grad_out_.dim()); TORCH_CHECK(query.dim() == key.dim()); TORCH_CHECK(query.dim() == value.dim()); @@ -1202,10 +1365,20 @@ std::tuple attention_backward( at::Tensor tmp_sum_i = at::zeros({B, M}, query.options()); + // invert from drop probability to keep probability + p = 1.0 - p; + // using scalar_t = float; // using vec_t = float4; // using vec_t = float; + // get the state where we are supposed to be for the rng + // in orther to sample the same dropout elements + uint64_t seed, offset; + std::memcpy(&seed, &rng_seed, sizeof(seed)); + std::memcpy(&offset, &rng_offset, sizeof(offset)); + at::PhiloxCudaState rng_engine_inputs(seed, offset); + if ((K % 4) == 0) { launch_attention_backward( grad_q, @@ -1217,7 +1390,9 @@ std::tuple attention_backward( value, logsumexp, tmp_sum_i, - attn_bias); + attn_bias, + p, + rng_engine_inputs); } else if ((K % 2) == 0) { launch_attention_backward( grad_q, @@ -1229,7 +1404,9 @@ std::tuple attention_backward( value, logsumexp, tmp_sum_i, - attn_bias); + attn_bias, + p, + rng_engine_inputs); } else { launch_attention_backward( grad_q, @@ -1241,7 +1418,9 @@ std::tuple attention_backward( value, logsumexp, tmp_sum_i, - attn_bias); + attn_bias, + p, + rng_engine_inputs); } AT_CUDA_CHECK(cudaGetLastError()); @@ -1249,6 +1428,173 @@ std::tuple attention_backward( return std::make_tuple(grad_q, grad_k, grad_v); } +// the functions below are only used for testing +// there is a lot of repetition compared to +// the forward code, so this could be refactored +// in the future + +template < + bool first, + typename scalar_t, + typename vec_t, + int kBlockSizeK, + int kBlockSizeQ, + int WARP_SIZE> +struct UnrollLoopForMask { + static __device__ __forceinline__ void eval( + scalar_t* output[kBlockSizeQ], + int64_t N, + int64_t M, + at::PhiloxCudaState philox_args, + int64_t global_offset, + scalar_t p) { + constexpr int64_t step = kBlockSizeK * WARP_SIZE; + int64_t l; + if (first) { + l = threadIdx.x * kBlockSizeK; + } else { + l = N - (N & (2 * step - 1)) + threadIdx.x * kBlockSizeK; + } + // this is equivalent to N - N % step, but faster + // guaranteed to be the same as step is a power of 2 + int64_t end_iter = kBlockSizeK == 1 ? 
N : N - (N & (step - 1)); + scalar_t s_delta[kBlockSizeQ][kBlockSizeK]; + int64_t query_idx = + blockIdx.x * (blockDim.y * kBlockSizeQ) + threadIdx.y * kBlockSizeQ; + // if (l < end_iter) { + { + for (; l < end_iter; l += step) { + for (int jj = 0; jj < kBlockSizeQ; jj++) { + for (int kk = 0; kk < kBlockSizeK; kk++) { + s_delta[jj][kk] = 1; + } + } + + apply_masking( + s_delta, philox_args, global_offset, N, p, l); + + for (int jj = 0; jj < kBlockSizeQ; jj++) { + for (int kk = 0; kk < kBlockSizeK; kk++) { + if (query_idx + jj < M) + output[jj][l + kk] = s_delta[jj][kk]; + } + } + } + } + { + UnrollLoopForMask< + false, + scalar_t, + vec_t, + kBlockSizeK / 2, + kBlockSizeQ, + WARP_SIZE>::eval(output, N, M, philox_args, global_offset, p); + } + } +}; + +template < + bool first, + typename scalar_t, + typename vec_t, + int kBlockSizeQ, + int WARP_SIZE> +struct UnrollLoopForMask { + static __device__ __forceinline__ void eval( + scalar_t* s_delta[kBlockSizeQ], + int64_t N, + int64_t M, + at::PhiloxCudaState philox_args, + int64_t global_offset, + scalar_t p) {} +}; + +template < + typename scalar_t, + typename vec_t, + int kBlockSizeK, + int kBlockSizeQ, + int WARP_SIZE> +__global__ void dropout_kernel( + at::PackedTensorAccessor output, + scalar_t p, + at::PhiloxCudaState philox_args) { + static_assert( + integerIsPowerOf2(kBlockSizeK * WARP_SIZE), + "kBlockSizeK * WARP_SIZE should be a power of 2"); + int64_t B = output.size(0); + int64_t M = output.size(1); + int64_t N = output.size(2); + + int64_t batch_idx = blockIdx.y; + int64_t query_idx = + blockIdx.x * (blockDim.y * kBlockSizeQ) + threadIdx.y * kBlockSizeQ; + + int64_t global_offset = batch_idx * M * N + query_idx * N; + + if (query_idx >= M) + return; + + scalar_t* output_block[kBlockSizeQ]; + for (int64_t q_item_idx = 0; q_item_idx < kBlockSizeQ; q_item_idx++) { + int64_t index = query_idx + q_item_idx; + index = index >= M ? M - 1 : index; + output_block[q_item_idx] = output[batch_idx][index].data(); + } + + UnrollLoopForMask< + true, + scalar_t, + vec_t, + kBlockSizeK, + kBlockSizeQ, + WARP_SIZE>::eval(output_block, N, M, philox_args, global_offset, p); +} + +at::Tensor _dropout_mask(at::Tensor output, double p) { + at::cuda::CUDAGuard device_guard(output.device()); + int64_t B = output.size(0); + int64_t M = output.size(1); + int64_t N = output.size(2); + + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + constexpr int WARP_SIZE = 4; + + constexpr int kBlockSizeK = 32; + constexpr int kBlockSizeQ = 2; + + constexpr int TILE_SIZE = 32; + + dim3 grid(ceil_div(M, int64_t(TILE_SIZE)), B); + dim3 block(WARP_SIZE, TILE_SIZE / kBlockSizeQ); + + using scalar_t = float; + + auto gen = at::get_generator_or_default( + c10::nullopt, at::cuda::detail::getDefaultCUDAGenerator()); + + at::PhiloxCudaState rng_engine_inputs; + { + // See Note [Acquire lock when using random generators] + std::lock_guard lock(gen->mutex_); + // each element in the attention matrix will have its own subsequence + // in the generator, so the offset is 1 globally + // int64_t counter_offset = p > 0 ? 1 : 0; + int64_t counter_offset = p > 0 ? 
4 : 0; + rng_engine_inputs = gen->philox_cuda_state(counter_offset); + } + + // invert from drop probability to keep probability + p = 1.0 - p; + + dropout_kernel + <<>>( + output.packed_accessor(), p, rng_engine_inputs); + + return output; +} + } // namespace TORCH_LIBRARY_IMPL(xformers, CUDA, m) { @@ -1258,4 +1604,6 @@ TORCH_LIBRARY_IMPL(xformers, CUDA, m) { m.impl( TORCH_SELECTIVE_NAME("xformers::efficient_attention_backward"), TORCH_FN(attention_backward)); + m.impl( + TORCH_SELECTIVE_NAME("xformers::_temp_dropout"), TORCH_FN(_dropout_mask)); } diff --git a/xformers/ops.py b/xformers/ops.py index fc66022450..595cd019b7 100644 --- a/xformers/ops.py +++ b/xformers/ops.py @@ -33,20 +33,26 @@ def masked_matmul(a, b, mask=None): class _MemoryEfficientAttentionOp(torch.autograd.Function): @staticmethod - def forward(ctx, query, key, value, attn_bias): - out, lse = torch.ops.xformers.efficient_attention( - query, key, value, True, attn_bias + def forward(ctx, query, key, value, attn_bias, p): + out, lse, rng_seed, rng_offset = torch.ops.xformers.efficient_attention( + query, key, value, True, attn_bias, p ) ctx.save_for_backward(query, key, value, lse, attn_bias) + ctx.p = p + ctx.rng_seed = rng_seed + ctx.rng_offset = rng_offset return out @staticmethod def backward(ctx, grad): query, key, value, lse, attn_bias = ctx.saved_tensors + p = ctx.p + rng_seed = ctx.rng_seed + rng_offset = ctx.rng_offset grad_q, grad_k, grad_v = torch.ops.xformers.efficient_attention_backward( - grad, query, key, value, lse, attn_bias + grad, query, key, value, lse, attn_bias, p, rng_seed, rng_offset ) - return grad_q, grad_k, grad_v, None + return grad_q, grad_k, grad_v, None, None def memory_efficient_attention( @@ -54,6 +60,7 @@ def memory_efficient_attention( key: torch.Tensor, value: torch.Tensor, attn_bias: Optional[torch.Tensor] = None, + p: float = 0.0, ): """ Implements the memory-efficient attention mechanism following @@ -63,6 +70,6 @@ def memory_efficient_attention( # fast-path that doesn't require computing the logsumexp for backward computation if all(x.requires_grad is False for x in [query, key, value]): return torch.ops.xformers.efficient_attention( - query, key, value, False, attn_bias + query, key, value, False, attn_bias, p )[0] - return _MemoryEfficientAttentionOp.apply(query, key, value, attn_bias) + return _MemoryEfficientAttentionOp.apply(query, key, value, attn_bias, p)
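
Below is a minimal usage sketch of the new `p` (dropout probability) argument introduced by this patch. It is not part of the diff above: it assumes a CUDA build of xformers that includes this change (dropout is only implemented in the CUDA kernels; the CPU path raises when p != 0), and it mirrors the reproducibility and reference checks in tests/test_mem_eff_attention.py.

    import torch
    import xformers.ops

    device = "cuda"
    # shapes follow the tests above: (batch_size, seq_len, head_dim)
    query = torch.randn(4, 33, 32, device=device)
    key = torch.randn(4, 15, 32, device=device)
    value = torch.randn(4, 15, 32, device=device)

    p = 0.3  # drop probability, as in torch.nn.functional.dropout

    # The CUDA kernel draws its Philox seed/offset from the default CUDA
    # generator, so seeding makes the sampled dropout mask reproducible.
    torch.manual_seed(42)
    out1 = xformers.ops.memory_efficient_attention(query, key, value, None, p)
    torch.manual_seed(42)
    out2 = xformers.ops.memory_efficient_attention(query, key, value, None, p)
    assert torch.allclose(out1, out2)

    # With the default p=0.0 the output matches a plain softmax attention.
    out = xformers.ops.memory_efficient_attention(query, key, value)
    ref = (query * query.shape[-1] ** -0.5) @ key.transpose(-2, -1)
    ref = ref.softmax(-1) @ value
    assert torch.allclose(out, ref, atol=2e-4)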