diff --git a/tests/test_mem_eff_attention.py b/tests/test_mem_eff_attention.py
index 550d2cb60f..01e1d46823 100644
--- a/tests/test_mem_eff_attention.py
+++ b/tests/test_mem_eff_attention.py
@@ -11,15 +11,20 @@ import xformers.ops

+torch.backends.cuda.matmul.allow_tf32 = False
 cuda_only = pytest.mark.skipif(not torch.cuda.is_available(), reason="requires CUDA")
 _devices = ["cpu", "cuda"] if torch.cuda.is_available() else ["cpu"]


 def ref_attention(q, k, v, attn_bias=None, drop_mask=None, p=0.0):
+    q = q.float()
+    k = k.float()
+    v = v.float()
+
     q = q * (1 / q.shape[-1] ** 0.5)
     attn = q @ k.transpose(-2, -1)
     if attn_bias is not None:
-        attn = attn + attn_bias
+        attn = attn + attn_bias.float()
     attn = attn.softmax(-1)
     if drop_mask is not None:
         attn = attn * (drop_mask / (1 - p))
@@ -32,6 +37,7 @@ def ref_attention(q, k, v, attn_bias=None, drop_mask=None, p=0.0):
 @pytest.mark.parametrize("kv_len", [3, 15, 32, 33, 64, 128])
 @pytest.mark.parametrize("q_len", [2, 3, 5, 32, 128])
 @pytest.mark.parametrize("device", _devices)
+@pytest.mark.parametrize("dtype", [torch.float, torch.half])
 @pytest.mark.parametrize(
     "op",
     [
@@ -46,25 +52,31 @@ def test_memory_efficient_attention(
     batch_size,
     k_len,
     use_attn_bias,
+    dtype,
     op: xformers.ops.MemoryEfficientAttentionOp,
 ):
     if (
         device not in op.SUPPORTED_DEVICES
         or k_len > op.SUPPORTED_MAX_K
         or (use_attn_bias and not op.SUPPORTS_ATTN_BIAS)
+        or dtype not in op.SUPPORTED_DTYPES
     ):
         return  # Or `pytest.xfail` ?
     scale = 3
-    query = torch.randn((batch_size, q_len, k_len), device=device) * scale
-    key = torch.randn((batch_size, kv_len, k_len), device=device) * scale
-    value = torch.randn((batch_size, kv_len, k_len), device=device) * scale
+    query = torch.randn((batch_size, q_len, k_len), device=device, dtype=dtype) * scale
+    key = torch.randn((batch_size, kv_len, k_len), device=device, dtype=dtype) * scale
+    value = torch.randn((batch_size, kv_len, k_len), device=device, dtype=dtype) * scale

     attn_bias = None
     if use_attn_bias:
-        attn_bias = torch.randn((batch_size, 1, kv_len), device=device) * scale
+        attn_bias = (
+            torch.randn((batch_size, 1, kv_len), device=device, dtype=dtype) * scale
+        )
         attn_bias = attn_bias.expand(batch_size, q_len, kv_len)

-    out = xformers.ops.memory_efficient_attention(query, key, value, attn_bias, op=op)
+    out = xformers.ops.memory_efficient_attention(
+        query, key, value, attn_bias, op=op
+    ).float()
     ref = ref_attention(query, key, value, attn_bias)

     assert torch.allclose(out, ref, atol=2e-4)
@@ -93,6 +105,7 @@ def test_key_query_all_ones(device, q_len, kv_len, batch_size, k_len):
 @pytest.mark.parametrize("kv_len", [3, 15, 32, 33])
 @pytest.mark.parametrize("q_len", [2, 3, 5])
 @pytest.mark.parametrize("device", _devices)
+@pytest.mark.parametrize("dtype", [torch.float, torch.half])
 @pytest.mark.parametrize(
     "op",
     [
@@ -106,18 +119,25 @@ def test_logsumexp(
     kv_len,
     batch_size,
     k_len,
+    dtype,
     op: xformers.ops.MemoryEfficientAttentionOp,
 ):
-    if device not in op.SUPPORTED_DEVICES or k_len > op.SUPPORTED_MAX_K:
+    if (
+        device not in op.SUPPORTED_DEVICES
+        or k_len > op.SUPPORTED_MAX_K
+        or dtype not in op.SUPPORTED_DTYPES
+    ):
         return
     scale = 3
-    query = torch.randn((batch_size, q_len, k_len), device=device) * scale
-    key = torch.randn((batch_size, kv_len, k_len), device=device) * scale
-    value = torch.randn((batch_size, kv_len, k_len), device=device) * scale
+    query = torch.randn((batch_size, q_len, k_len), device=device, dtype=dtype) * scale
+    key = torch.randn((batch_size, kv_len, k_len), device=device, dtype=dtype) * scale
+    value = torch.randn((batch_size, kv_len, k_len), device=device, dtype=dtype) * scale

     _, lse, _, _ = op.FORWARD_OPERATOR(query, key, value, True, None, 0.0)
-    ref_lse = ((query / k_len**0.5) @ key.transpose(-2, -1)).logsumexp(-1)
+    ref_lse = (
+        (query.float() / k_len**0.5) @ key.float().transpose(-2, -1)
+    ).logsumexp(-1)

     assert torch.allclose(lse, ref_lse, atol=2e-4)
@@ -129,6 +149,7 @@ def test_logsumexp(
 @pytest.mark.parametrize("kv_len", [3, 15, 32, 33, 64, 128])
 @pytest.mark.parametrize("q_len", [2, 3, 5, 32, 128])
 @pytest.mark.parametrize("device", _devices)
+@pytest.mark.parametrize("dtype", [torch.float])
 @pytest.mark.parametrize(
     "op",
     [
@@ -144,23 +165,27 @@ def test_memory_efficient_attention_backward(
     k_len,
     grad_out_contiguous,
     use_attn_bias,
+    dtype,
     op: xformers.ops.MemoryEfficientAttentionOp,
 ):
     if (
         device not in op.SUPPORTED_DEVICES
         or k_len > op.SUPPORTED_MAX_K
         or (use_attn_bias and not op.SUPPORTS_ATTN_BIAS)
+        or dtype not in op.SUPPORTED_DTYPES
     ):
         return
     scale = 3
-    query = torch.randn((batch_size, q_len, k_len), device=device) * scale
-    key = torch.randn((batch_size, kv_len, k_len), device=device) * scale
-    value = torch.randn((batch_size, kv_len, k_len), device=device) * scale
+    query = torch.randn((batch_size, q_len, k_len), device=device, dtype=dtype) * scale
+    key = torch.randn((batch_size, kv_len, k_len), device=device, dtype=dtype) * scale
+    value = torch.randn((batch_size, kv_len, k_len), device=device, dtype=dtype) * scale

     attn_bias = None
     if use_attn_bias:
-        attn_bias = torch.randn((batch_size, 1, kv_len), device=device) * scale
+        attn_bias = (
+            torch.randn((batch_size, 1, kv_len), device=device, dtype=dtype) * scale
+        )
         attn_bias = attn_bias.expand(batch_size, q_len, kv_len)

     query.requires_grad_(True)
@@ -185,9 +210,6 @@ def test_memory_efficient_attention_backward(
     ref = ref_attention(query, key, value, attn_bias)
     ref.backward(grad_out)

-    # there is some extra precision loss in the CPU implementation due to an
-    # extra accumulation step in grad_q, which is not present in the CUDA
-    # implementation
     atol = 2e-4 + 2e-6 * k_len * kv_len * math.sqrt(batch_size) * math.sqrt(q_len)

     # (for mypy)
@@ -202,7 +224,7 @@ def test_memory_efficient_attention_backward(
     ]:
         assert torch.allclose(
             calc_grad, ref_grad, atol=atol
-        ), f"{name} doesn't match (max_diff={(calc_grad - ref_grad).abs().max()} > {atol})"
+        ), f"{name} doesn't match (max_diff={(calc_grad - ref_grad).abs().max()} > {atol}) - dtype={dtype}"


 def _vec_binom_test(x, n, p):
diff --git a/third_party/cutlass/include/cutlass/half.h b/third_party/cutlass/include/cutlass/half.h
index 13d7146f8a..87b9648514 100644
--- a/third_party/cutlass/include/cutlass/half.h
+++ b/third_party/cutlass/include/cutlass/half.h
@@ -622,6 +622,7 @@ struct numeric_limits {
   static cutlass::half_t round_error() { return cutlass::half_t(0.5f); }

   /// Returns smallest finite value
+  CUTLASS_HOST_DEVICE
   static cutlass::half_t infinity() { return cutlass::half_t::bitcast(0x7c00); }

   /// Returns smallest finite value
diff --git a/xformers/benchmarks/benchmark_mem_eff_attention.py b/xformers/benchmarks/benchmark_mem_eff_attention.py
index a9638add93..997eaf77a6 100644
--- a/xformers/benchmarks/benchmark_mem_eff_attention.py
+++ b/xformers/benchmarks/benchmark_mem_eff_attention.py
@@ -5,6 +5,7 @@

 import itertools
+import math
 from functools import partial

 import torch
@@ -22,7 +23,8 @@ def ref_attention(q, k, v, attn_bias=None, p=0.0):
     # equivalent to (q @ k.transpose(-2, -1) + m).softmax(-1) @ v
     # but faster, and is what is used in PyTorch now
     attn = torch.baddbmm(attn_bias, q, k.transpose(-2, -1))
-    attn = attn.softmax(-1)
+    dtype = attn.dtype
+    attn = attn.to(torch.float).softmax(-1).to(dtype)
     if p > 0:
         attn = torch.nn.functional.dropout(attn, p=p)
     return attn @ v
@@ -32,9 +34,7 @@ def ref_attention(q, k, v, attn_bias=None, p=0.0):
 device = torch.device("cuda")
 NUM_THREADS = [1] if device.type == "cuda" else [1, 40]
-SHAPES = list(
-    itertools.product([1, 8, 32, 256], [127, 128, 512, 513, 1023, 1024], [16, 32])
-) + list(itertools.product([32, 256], [128, 512, 1024, 2048], [16, 32, 64, 128, 256]))
+SHAPES = list(itertools.product([32, 256], [128, 512, 1024], [16, 32, 128]))
 SHAPES = list(set(SHAPES))
 SHAPES.sort()
@@ -56,26 +56,37 @@ def product_dict(**kwargs):
         shape=SHAPES,
         num_threads=NUM_THREADS,
         use_attn_bias=[False, True],
+        dtype=[torch.half, torch.float],
     )
 )
-def benchmark_forward(shape, num_threads: int, use_attn_bias: bool):
+def benchmark_forward(shape, num_threads: int, use_attn_bias: bool, dtype):
     B, M, K = shape
-    q = torch.rand(shape, device=device)
+    if (
+        K > op.SUPPORTED_MAX_K
+        or (use_attn_bias and not op.SUPPORTS_ATTN_BIAS)
+        or (dtype not in op.SUPPORTED_DTYPES)
+    ):
+        return
+    q = torch.rand(shape, device=device, dtype=dtype)
     attn_bias = None
     if use_attn_bias:
-        attn_bias = torch.rand(shape[0], 1, shape[1], device=device).expand(
-            shape[0], shape[1], shape[1]
-        )
-    sub_label = f"B={B}, M={M}, K={K}"
+        attn_bias = torch.rand(
+            shape[0], 1, shape[1], device=device, dtype=dtype
+        ).expand(shape[0], shape[1], shape[1])
+    dtype_str = {
+        torch.half: "f16",
+        torch.float: "f32",
+    }[dtype]
+    sub_label = f"{dtype_str} B={B}, M={M}, K={K}"

-    if K > op.SUPPORTED_MAX_K or (use_attn_bias and not op.SUPPORTS_ATTN_BIAS):
-        raise NotImplementedError()
     if True:
-        r = xformers.ops.memory_efficient_attention(q, q, q, attn_bias, op=op)
-        rr = ref_attention(q, q, q, attn_bias)
-        assert (r - rr).abs().max() < 1e-5
+        r = xformers.ops.memory_efficient_attention(q, q, q, attn_bias, op=op).float()
+        rr = ref_attention(
+            q.float(), q.float(), q.float(), attn_bias.float() if attn_bias is not None else None
+        )
+        assert (r - rr).abs().max() < 2e-4, (r - rr).abs().max()
         del r, rr

     yield benchmark.Timer(
@@ -106,9 +117,9 @@ def benchmark_forward(shape, num_threads: int, use_attn_bias: bool):
     )


-def benchmark_backward(shape, num_threads: int, use_attn_bias: bool):
+def benchmark_backward(shape, num_threads: int, use_attn_bias: bool, dtype):
     B, M, K = shape
-    q = torch.rand(shape, device=device, requires_grad=True)
+    q = torch.rand(shape, device=device, dtype=dtype, requires_grad=True)
     attn_bias = None
     if use_attn_bias:
         attn_bias = torch.rand(shape[0], 1, shape[1], device=device).expand(
@@ -116,8 +127,13 @@ def benchmark_backward(shape, num_threads: int, use_attn_bias: bool):
             shape[0], shape[1], shape[1]
         )
     sub_label = f"B={B}, M={M}, K={K}"

-    if K > op.SUPPORTED_MAX_K or (use_attn_bias and not op.SUPPORTS_ATTN_BIAS):
-        raise NotImplementedError()
+    if (
+        K > op.SUPPORTED_MAX_K
+        or (use_attn_bias and not op.SUPPORTS_ATTN_BIAS)
+        # only fp32 is supported at the moment
+        or (dtype not in {torch.float})
+    ):
+        return
     if True:
         r = xformers.ops.memory_efficient_attention(q, q, q, attn_bias, op=op)
         r.backward(torch.ones_like(q))
@@ -127,7 +143,8 @@ def benchmark_backward(shape, num_threads: int, use_attn_bias: bool):
         rr = ref_attention(q, q, q, attn_bias)
         rr.backward(torch.ones_like(q))
-        assert (grad - q.grad).abs().max() < 1e-5, f"{(grad - q.grad).abs().max()}"
+        atol = 2e-4 + 2e-6 * K * M * math.sqrt(B) * math.sqrt(M)
+        assert (grad - q.grad).abs().max() < atol, f"{(grad - q.grad).abs().max()}"
         q.grad = None
         del r, rr, grad
diff --git a/xformers/components/attention/csrc/cuda/attention_forward_generic.cu b/xformers/components/attention/csrc/cuda/attention_forward_generic.cu
index 4218e1ba10..c6e8d04b19 100644
--- a/xformers/components/attention/csrc/cuda/attention_forward_generic.cu
+++ b/xformers/components/attention/csrc/cuda/attention_forward_generic.cu
@@ -3,6 +3,8 @@
 #include
 #include
+#include
+
 #include
 #include
@@ -19,32 +21,86 @@
 #include

+// #define FP16_ONLY_USE_TENSORCORES
+
+// XXX: Maybe CUDA will wake up one day and provide this
+template
+struct math;
+
+template <>
+struct math {
+  using scalar_t = cutlass::half_t;
+  using torch_dtype = half;
+  static constexpr at::ScalarType kAtScalarType = at::ScalarType::Half;
+
+  static __device__ __forceinline__ cutlass::half_t exp(
+      cutlass::half_t const& h) {
+    return cutlass::half_t(hexp(h.to_half()));
+  }
+  template
+  static __host__ at::PackedTensorAccessor packed_accessor(
+      at::Tensor const& tensor) {
+    return at::PackedTensorAccessor(
+        (scalar_t*)(tensor.data_ptr()),
+        tensor.sizes().data(),
+        tensor.strides().data());
+  }
+};
+constexpr at::ScalarType math::kAtScalarType;
+
+template <>
+struct math {
+  using scalar_t = float;
+  using torch_dtype = float;
+  static constexpr at::ScalarType kAtScalarType = at::ScalarType::Float;
+
+  static __device__ __forceinline__ float exp(float const& h) {
+    return expf(h);
+  }
+  template
+  static __host__ at::PackedTensorAccessor packed_accessor(
+      at::Tensor const& tensor) {
+    return tensor.packed_accessor();
+  }
+};
+constexpr at::ScalarType math::kAtScalarType;
+
 namespace {
 template
 constexpr __host__ __device__ inline integer ceil_div(integer n, integer m) {
   return (n + m - 1) / m;
 }

-template
+template <
+    typename scalar_t_,
+    typename accum_t_ = float,
+    typename output_t_ = float>
 struct AttentionKernel {
   using scalar_t = scalar_t_;
+  using accum_t = accum_t_;
+  using output_t = output_t_;

-  // Blocks
-  // NOTE: Looks like 16 works better for K <= 64
+// Blocks
+// NOTE: Looks like 16 works better for K <= 64
+#ifdef FP16_ONLY_USE_TENSORCORES
+  static constexpr int64_t kQueriesPerBlock = 64;
+  static constexpr int64_t kWarpSize = 32;
+  static constexpr int64_t kNumWarpsPerBlock = 2;
+#else
   static constexpr int64_t kQueriesPerBlock = 32;
+  static constexpr int64_t kWarpSize = 32;
+  static constexpr int64_t kNumWarpsPerBlock = 4;
+#endif
   static constexpr int64_t kNumBlocksX = 1;
   static int64_t getNumBlocksY(int64_t num_queries) {
     return ceil_div(num_queries, kQueriesPerBlock);
   }

-  // Threads
-  static constexpr int64_t kWarpSize = 32;
-  static constexpr int64_t kNumWarpsPerBlock = 4;
   static constexpr int64_t kSiDim1 = kNumWarpsPerBlock * kWarpSize;

   static void __device__ attention_kernel(
-      at::TensorAccessor output,
-      at::TensorAccessor logsumexp,
+      at::TensorAccessor output,
+      at::TensorAccessor logsumexp,
       at::TensorAccessor query,
       at::TensorAccessor key,
       at::TensorAccessor value) {
@@ -60,18 +116,18 @@ struct AttentionKernel {
     int32_t num_queries = query.size(0);
     int32_t K = key.size(1);

-    scalar_t __shared__ m_prime[kQueriesPerBlock];
-    scalar_t __shared__ mi[kQueriesPerBlock];
-    scalar_t __shared__ s_prime[kQueriesPerBlock];
-    scalar_t __shared__ si[kQueriesPerBlock][kSiDim1];
+    __shared__ cutlass::Array m_prime;
+    __shared__ cutlass::Array mi;
+    __shared__ cutlass::Array s_prime;
+    accum_t __shared__ si[kQueriesPerBlock][kSiDim1];

     for (int32_t q = 0; q + lane_id < kQueriesPerBlock; q += kWarpSize) {
-      mi[q + lane_id] = -std::numeric_limits::infinity();
+      mi[q + lane_id] = -std::numeric_limits::infinity();
     }
     if (warp_id == 0) {
       for (int32_t q = 0; q + lane_id < kQueriesPerBlock; q += kWarpSize) {
-        s_prime[q + lane_id] = 0;
-        m_prime[q + lane_id] = -std::numeric_limits::infinity();
+        s_prime[q + lane_id] = accum_t(0);
+        m_prime[q + lane_id] = -std::numeric_limits::infinity();
       }
     }
@@ -95,19 +151,19 @@ struct AttentionKernel {
     for (int32_t q = warp_id; q < kQueriesPerBlock;
          q += kNumWarpsPerBlock) { // parallel warps
       // 3. Update s_prime
-      scalar_t sp = 0;
-      scalar_t my_mi = mi[q];
+      accum_t sp = accum_t(0);
+      accum_t my_mi = mi[q];
       static_assert(
           kNumWarpsPerBlock * kWarpSize % kWarpSize == 0,
           ".. or add a condition to loop below");
       for (int32_t key_id = lane_id; key_id < kNumWarpsPerBlock * kWarpSize;
            key_id += kWarpSize) { // parallel lanes
-        scalar_t si_exp = expf(si[q][key_id] - my_mi) *
-            (key_id + iter_key_start < num_keys);
+        accum_t si_exp = math::exp(si[q][key_id] - my_mi);
+        si_exp *= accum_t(key_id + iter_key_start < num_keys);
         sp += si_exp;
         si[q][key_id] = si_exp;
       }
-      scalar_t m_prime_exp = expf(m_prime[q] - my_mi);
+      accum_t m_prime_exp = math::exp(m_prime[q] - my_mi);
       sp = warpSum(sp) + s_prime[q] * m_prime_exp;

       m_prime[q] = m_prime_exp;
@@ -135,14 +191,15 @@ struct AttentionKernel {
         int32_t(num_queries - warp_id - query_start()));
     if (iter_col_last > 0 && iter_query_last > 0) {
       // &output[query_start()][thread_id]
-      scalar_t* output_line_ptr =
+      output_t* output_line_ptr =
           output.data() + (query_start() + warp_id) * output_stride0 + lane_id;
       for (int32_t q = 0; q < iter_query_last; q += kNumWarpsPerBlock) {
         // parallel warps
-        scalar_t line_s_prime = s_prime[q + warp_id];
+        auto line_s_prime = s_prime[q + warp_id];
         for (int32_t value_col = 0; value_col < iter_col_last;
              value_col += kWarpSize) { // parallel lanes
-          output_line_ptr[value_col] /= line_s_prime;
+          output_line_ptr[value_col] =
+              output_t(accum_t(output_line_ptr[value_col]) / line_s_prime);
         }
         output_line_ptr += output_stride0 * kNumWarpsPerBlock;
       }
@@ -155,7 +212,7 @@ struct AttentionKernel {
       for (int64_t q = thread_id(); q < iter_query_last;
            q += kNumWarpsPerBlock * kWarpSize) {
         *(logsumexp.data() + query_start() + q) =
-            m_prime[q] + std::log(s_prime[q]);
+            accum_t(m_prime[q]) + std::log(accum_t(s_prime[q]));
       }
     }
   }
@@ -164,9 +221,9 @@ struct AttentionKernel {
   static __device__ void compute_dot_product_att_value(
       int32_t const& iter_key_start,
       at::TensorAccessor& value,
-      scalar_t m_prime[kQueriesPerBlock],
-      scalar_t si[kQueriesPerBlock][kSiDim1],
-      at::TensorAccessor& output) {
+      cutlass::Array const& m_prime,
+      accum_t si[kQueriesPerBlock][kSiDim1],
+      at::TensorAccessor& output) {
     using ThreadblockShape = cutlass::gemm::
         GemmShape;
     using WarpShape = cutlass::gemm::GemmShape;
     using InstructionShape = cutlass::gemm::GemmShape;
@@ -177,15 +234,17 @@ struct AttentionKernel {
     using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore<
         ThreadblockShape, // ThreadblockShape,
         WarpShape, // WarpShape,
         InstructionShape, // InstructionShape,
-        scalar_t, // ElementA,
+        accum_t, // ElementA,
         cutlass::layout::RowMajor, // LayoutA,
         scalar_t, // ElementB,
         cutlass::layout::RowMajor, // LayoutB,
-        scalar_t, // ElementC,
+        accum_t, // ElementC,
         cutlass::layout::RowMajor, // LayoutC,
         // Just use `cutlass::arch::OpClassTensorOp` for TensorCores (requires
         // sm>7.0)
-        cutlass::arch::OpClassSimt, // OpClass,
+        cutlass::arch::
+            OpClassSimt, // OpClass:
+                         // OpClassSimt/OpClassWmmaTensorOp/OpClassTensorOp
         2, // Stages,
         cutlass::arch::OpMultiplyAdd // Operator,
         >;
@@ -272,14 +331,6 @@ struct AttentionKernel {

     // Construct thread-scoped matrix multiply
     Mma mma(shared_storage, thread_id(), warp_id(), lane_id());

-    // Output results
-    // cutlass::gemm::warp::MmaSimtTileIterator,
-    // cutlass::gemm::Operand::kC, float, cutlass::layout::RowMajor,
-    // cutlass::gemm::warp::MmaSimtPolicy,
-    // cutlass::layout::RowMajorInterleaved<1>, cutlass::gemm::GemmShape<4, 4,
-    // 1>>, 1, 1>
-    typename Mma::Operator::IteratorC iterator_C(
-        {&output[query_start()][0], output.stride(0)}, lane_id());
     auto iterator_C_offset_m = (tb_tile_offset.m() * Mma::WarpCount::kM) +
         (warp_id() % Mma::WarpCount::kM);
     auto iterator_C_offset_n = (tb_tile_offset.n() * Mma::WarpCount::kN) +
@@ -292,7 +343,6 @@ struct AttentionKernel {
         cutlass::MatrixCoord(
             Mma::Operator::IteratorC::Policy::LaneMmaShape::kM,
             Mma::Operator::IteratorC::Policy::LaneMmaShape::kN);
-    iterator_C.add_tile_offset({iterator_C_offset_m, iterator_C_offset_n});

     typename Mma::FragmentC accum, accum2; // cutlass::Array

     // TODO: We could avoid all this mess using cutlass's Epilogue concept I
@@ -302,7 +352,7 @@ struct AttentionKernel {
         Mma::WarpGemm::kM * iterator_C_offset_m + lane_offset.row();
     const int32_t thread_offset_n =
         Mma::WarpGemm::kN * iterator_C_offset_n + lane_offset.column();
-    scalar_t* output_ptr = &output[query_start()][0];
+    output_t* output_ptr = &output[query_start()][0];
     const int32_t output_s0 = output.stride(0);
     const int32_t max_m = output.size(0) - query_start();
     const int32_t max_n = output.size(1);
@@ -318,9 +368,11 @@ struct AttentionKernel {
         accum2,
         thread_offset_m,
         thread_offset_n,
-        [&](scalar_t& accum_v, int32_t m, int32_t n) {
+        [&](typename Mma::FragmentC::reference accum_v,
+            int32_t m,
+            int32_t n) {
           if (m < max_m && n < max_n) {
-            accum_v = output_ptr[m * output_s0 + n] * m_prime[m];
+            accum_v = accum_t(output_ptr[m * output_s0 + n]) * m_prime[m];
           }
         });
     int gemm_k_iterations =
@@ -335,7 +387,7 @@ struct AttentionKernel {
     auto it1 = accum.begin();
     auto it2 = accum2.begin();
     while (it1 != accum.end()) {
-      *it1 += *it2;
+      *it1 = *it1 + *it2;
       ++it1;
       ++it2;
     }
@@ -344,9 +396,11 @@ struct AttentionKernel {
         accum,
         thread_offset_m,
         thread_offset_n,
-        [&](scalar_t& accum_v, int32_t const& m, int32_t const& n) {
+        [&](typename Mma::FragmentC::reference accum_v,
+            int32_t const& m,
+            int32_t const& n) {
           if (m < max_m && n < max_n) {
-            output_ptr[m * output_s0 + n] = accum_v;
+            output_ptr[m * output_s0 + n] = output_t(accum_v);
           }
         });
   }
@@ -365,8 +419,6 @@ struct AttentionKernel {
     using Iterations = typename Iterator::Iterations;
     using Element = typename Iterator::Element;

-    static_assert(Fragment::kStorageElements == kQueriesPerBlock);
-
     CUTLASS_PRAGMA_UNROLL
     for (int mma_m = 0; mma_m < Iterations::kRow; ++mma_m) { // 0
       CUTLASS_PRAGMA_UNROLL
@@ -396,9 +448,9 @@ struct AttentionKernel {
       int32_t const& iter_key_start,
       at::TensorAccessor& query,
      at::TensorAccessor& key,
-      scalar_t m_prime[kQueriesPerBlock],
-      scalar_t si[kQueriesPerBlock][kSiDim1],
-      scalar_t mi[kQueriesPerBlock]) {
+      cutlass::Array& m_prime,
+      accum_t si[kQueriesPerBlock][kSiDim1],
+      cutlass::Array& mi) {
     /*
     Computes the block-matrix product of:
     (a) query[query_start:query_end, :]
@@ -406,12 +458,19 @@ struct AttentionKernel {
         with
     (b) key[iter_key_start:iter_key_start + kNumWarpsPerBlock * kWarpSize]
     and stores that into `si`
     */
-
+#ifdef FP16_ONLY_USE_TENSORCORES
+    using ThreadblockShape = cutlass::gemm::
+        GemmShape;
+    using WarpShape = cutlass::gemm::GemmShape;
+    using InstructionShape = cutlass::gemm::GemmShape<8, 8, 4>;
+    using OpClass = cutlass::arch::OpClassTensorOp; // OpClassWmmaTensorOp?
+#else
     using ThreadblockShape = cutlass::gemm::
         GemmShape;
     using WarpShape = cutlass::gemm::GemmShape;
     using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>;
-
+    using OpClass = cutlass::arch::OpClassSimt;
+#endif
     using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore<
         ThreadblockShape, // ThreadblockShape,
         WarpShape, // WarpShape,
@@ -420,11 +479,9 @@ struct AttentionKernel {
         cutlass::layout::RowMajor, // LayoutA,
         scalar_t, // ElementB,
         cutlass::layout::ColumnMajor, // LayoutB,
-        scalar_t, // ElementC,
+        accum_t, // ElementC,
         cutlass::layout::RowMajor, // LayoutC,
-        // Just use `cutlass::arch::OpClassTensorOp` for TensorCores (requires
-        // sm>7.0)
-        cutlass::arch::OpClassSimt, // OpClass,
+        OpClass,
         2, // Stages,
         cutlass::arch::OpMultiplyAdd // Operator,
         >;
@@ -533,14 +590,14 @@ struct AttentionKernel {

     // 2. Update `mi`
     int64_t num_keys = key.size(0);
-    scalar_t scale = 1.0 / std::sqrt(scalar_t(K));
+    accum_t scale = accum_t(1.0 / std::sqrt(float(K)));
     static_assert(kQueriesPerBlock % kNumWarpsPerBlock == 0);
     for (int16_t q = 0; q < kQueriesPerBlock;
          q += kNumWarpsPerBlock) { // parallel warps
       if (query_start() + q + warp_id() >= num_queries) {
         continue;
       }
-      scalar_t currentMax = m_prime[q + warp_id()];
+      accum_t currentMax = m_prime[q + warp_id()];
       CUTLASS_PRAGMA_UNROLL
       for (int64_t key_id = 0; key_id < kSiDim1;
            key_id += kWarpSize) { // parallel lanes
@@ -549,7 +606,7 @@ struct AttentionKernel {
         }
         // TODO: Scaling could be done as part of an epilogue
         // in the cutlass calculation above
-        scalar_t dot_product = si[q + warp_id()][key_id + lane_id()];
+        accum_t dot_product = si[q + warp_id()][key_id + lane_id()];
         dot_product *= scale;
         si[q + warp_id()][key_id + lane_id()] = dot_product;
         currentMax = std::max(currentMax, dot_product);
@@ -560,17 +617,19 @@ struct AttentionKernel {
     }
   }

-  static __device__ __forceinline__ scalar_t warpMax(scalar_t val) {
+  static __device__ __forceinline__ accum_t warpMax(accum_t val) {
     for (int stride = kWarpSize / 2; stride > 0; stride >>= 1) {
-      scalar_t tmp = __shfl_xor_sync(0xffffffff, val, stride, kWarpSize);
+      accum_t tmp =
+          accum_t(__shfl_xor_sync(0xffffffff, val, stride, kWarpSize));
       val = tmp > val ? tmp : val;
     }
     return val;
   }

-  static __device__ __forceinline__ scalar_t warpSum(scalar_t val) {
+  static __device__ __forceinline__ accum_t warpSum(accum_t val) {
     for (int stride = kWarpSize / 2; stride > 0; stride >>= 1) {
-      scalar_t tmp = __shfl_xor_sync(0xffffffff, val, stride, kWarpSize);
+      accum_t tmp =
+          accum_t(__shfl_xor_sync(0xffffffff, val, stride, kWarpSize));
       val += tmp;
     }
     return val;
@@ -599,8 +658,8 @@ __global__ void __launch_bounds__(
     // number of resident blocks per multiprocessor
     12 / AK::kNumWarpsPerBlock)
     attention_kernel_batched(
-        at::PackedTensorAccessor output,
-        at::PackedTensorAccessor logsumexp,
+        at::PackedTensorAccessor output,
+        at::PackedTensorAccessor logsumexp,
         at::PackedTensorAccessor query,
         at::PackedTensorAccessor key,
         at::PackedTensorAccessor value) {
@@ -644,12 +703,6 @@ efficient_attention_forward_generic(
   TORCH_CHECK(key.is_contiguous());
   TORCH_CHECK(value.is_contiguous());

-  TORCH_CHECK(
-      query.scalar_type() == at::ScalarType::Float &&
-          key.scalar_type() == at::ScalarType::Float &&
-          value.scalar_type() == at::ScalarType::Float,
-      "Only float32 type is supported for now");
-
   at::Tensor attn_bias;
   if (attn_bias_.has_value()) {
     attn_bias = *attn_bias_;
@@ -667,22 +720,58 @@ efficient_attention_forward_generic(
   int64_t M = query.size(1);
   int64_t N = key.size(1);
   int64_t K = query.size(2);
-  at::Tensor res = at::zeros({B, M, K}, query.options());
-  at::Tensor logsumexp =
-      at::empty({B, compute_logsumexp ? M : 0}, query.options());
-
-  typedef float scalar_t;
-  using AK = AttentionKernel;
-
-  dim3 grid(AK::kNumBlocksX, AK::getNumBlocksY(M), B);
-  dim3 block(AK::kWarpSize, AK::kNumWarpsPerBlock, 1);
-
-  attention_kernel_batched<<>>(
-      res.packed_accessor(),
-      logsumexp.packed_accessor(),
-      query.packed_accessor(),
-      key.packed_accessor(),
-      value.packed_accessor());
+
+  using accum_t = float;
+
+  at::Tensor res;
+  at::Tensor logsumexp = at::empty(
+      {B, compute_logsumexp ? M : 0},
+      query.options().dtype(at::ScalarType::Float));
M : 0}, + query.options().dtype(at::ScalarType::Float)); + + if (query.scalar_type() == at::ScalarType::Float) { +#ifdef FP16_ONLY_USE_TENSORCORES + TORCH_CHECK( + false, "Only support f32 with FP16_ONLY_USE_TENSORCORES defined"); +#else + using scalar_t = float; + using output_t = float; + using AK = AttentionKernel; + using m = math; + + res = at::zeros( + {B, M, K}, query.options().dtype(math::kAtScalarType)); + + dim3 grid(AK::kNumBlocksX, AK::getNumBlocksY(M), B); + dim3 block(AK::kWarpSize, AK::kNumWarpsPerBlock, 1); + + attention_kernel_batched<<>>( + math::packed_accessor<3>(res), + logsumexp.packed_accessor(), + m::packed_accessor<3>(query), + m::packed_accessor<3>(key), + m::packed_accessor<3>(value)); +#endif + } else if (query.scalar_type() == at::ScalarType::Half) { + using scalar_t = cutlass::half_t; + using output_t = float; + using AK = AttentionKernel; + using m = math; + + res = at::zeros( + {B, M, K}, query.options().dtype(math::kAtScalarType)); + + dim3 grid(AK::kNumBlocksX, AK::getNumBlocksY(M), B); + dim3 block(AK::kWarpSize, AK::kNumWarpsPerBlock, 1); + + attention_kernel_batched<<>>( + math::packed_accessor<3>(res), + logsumexp.packed_accessor(), + m::packed_accessor<3>(query), + m::packed_accessor<3>(key), + m::packed_accessor<3>(value)); + } else { + TORCH_CHECK(false, "Only fp32 & half supported at the moment"); + } AT_CUDA_CHECK(cudaGetLastError()); return std::make_tuple(res, logsumexp, int64_t(), int64_t()); diff --git a/xformers/ops.py b/xformers/ops.py index f4effd7949..a1edb7738c 100644 --- a/xformers/ops.py +++ b/xformers/ops.py @@ -47,6 +47,7 @@ def no_such_operator(*args, **kwargs): class MemoryEfficientAttentionOp(torch.autograd.Function): FORWARD_OPERATOR = _get_xformers_operator("efficient_attention") SUPPORTED_DEVICES = {"cuda", "cpu"} + SUPPORTED_DTYPES = {torch.float} SUPPORTED_MAX_K: float = 32 SUPPORTS_ATTN_BIAS = True SUPPORTS_DROPOUT = True @@ -82,6 +83,7 @@ def backward(ctx, grad): class MemoryEfficientAttentionGenericForwardOp(MemoryEfficientAttentionOp): FORWARD_OPERATOR = _get_xformers_operator("efficient_attention_forward_generic") SUPPORTED_DEVICES = {"cuda"} + SUPPORTED_DTYPES = {torch.float, torch.half} SUPPORTED_MAX_K = math.inf SUPPORTS_ATTN_BIAS = False SUPPORTS_DROPOUT = False