From 8adf7bfd8c481956b686e72c8d8ce36d29d85227 Mon Sep 17 00:00:00 2001
From: root
Date: Tue, 17 Dec 2024 16:20:48 +0800
Subject: [PATCH] add warmup, gqa

---
 .../triton_flashinfer_cudnn.py | 161 +++++++++++-------
 1 file changed, 100 insertions(+), 61 deletions(-)

diff --git a/benchmark/kernels/decoding_attention_triton/triton_flashinfer_cudnn.py b/benchmark/kernels/decoding_attention_triton/triton_flashinfer_cudnn.py
index 8a748ec9269..f8c87d48db7 100644
--- a/benchmark/kernels/decoding_attention_triton/triton_flashinfer_cudnn.py
+++ b/benchmark/kernels/decoding_attention_triton/triton_flashinfer_cudnn.py
@@ -9,6 +9,7 @@ from flashinfer import BatchDecodeWithPagedKVCacheWrapper
 
 
 from sglang.srt.layers.attention.triton_ops.decode_attention import decode_attention_fwd
+from sglang.srt.utils import should_use_tensor_core
 
 
 def benchmark_forward(
@@ -38,7 +39,15 @@ def time_fwd(func, *args, **kwargs):
 
 
 def decode_attention_sglang(
-    q, kv_data, batch_size, kv_len, head_num_q, head_num_kv, head_dim, num_kv_splits
+    q,
+    kv_data,
+    batch_size,
+    kv_len,
+    head_num_q,
+    head_num_kv,
+    head_dim,
+    num_kv_splits,
+    warmup=10,
 ):
 
     k_buffer = kv_data[0].view(-1, head_num_kv, head_dim)
@@ -57,6 +66,20 @@ def decode_attention_sglang(
         device="cuda",
     )
 
+    for _ in range(warmup):
+        decode_attention_fwd(
+            q,
+            k_buffer,
+            v_buffer,
+            o,
+            req_to_token,
+            b_req_idx,
+            b_seq_len,
+            attn_logits,
+            num_kv_splits,
+            sm_scale,
+        )
+
     f = time_fwd(
         decode_attention_fwd,
         q,
@@ -74,10 +97,15 @@ def decode_attention_sglang(
     return f, o
 
 
-def decode_attention_flashinfer():
+def decode_attention_flashinfer(dtype, head_num_q, head_num_kv):
     workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.int8, device="cuda")
+    use_tensor_cores = should_use_tensor_core(
+        kv_cache_dtype=dtype,
+        num_attention_heads=head_num_q,
+        num_kv_heads=head_num_kv,
+    )
     flashinfer_decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(
-        workspace_buffer, "NHD", use_tensor_cores=False
+        workspace_buffer, "NHD", use_tensor_cores=use_tensor_cores
     )
 
     class FlashinferAttention(torch.autograd.Function):
@@ -92,6 +120,7 @@ def forward(
             head_num_kv,
             head_dim,
             dtype,
+            warmup=10,
         ):
             total_tokens = batch_size * kv_len
             kv_indptr = torch.arange(0, batch_size + 1).to(0).int() * kv_len
@@ -113,16 +142,17 @@ def forward(
                 data_type=dtype,
             )
 
+            for _ in range(warmup):
+                o = flashinfer_decode_wrapper.forward(
+                    q.contiguous().view(-1, head_num_q, head_dim), kv_data
+                )
+
             f = time_fwd(
                 flashinfer_decode_wrapper.forward,
                 q.contiguous().view(-1, head_num_q, head_dim),
                 kv_data,
             )
 
-            o = flashinfer_decode_wrapper.forward(
-                q.contiguous().view(-1, head_num_q, head_dim), kv_data
-            )
-
             return f, o
 
     return FlashinferAttention
@@ -144,14 +174,7 @@ def convert_to_cudnn_type(torch_type):
 
 
 def decode_attention_cudnn(
-    q,
-    kv_data,
-    batch_size,
-    kv_len,
-    head_num_q,
-    head_num_kv,
-    head_dim,
-    dtype,
+    q, kv_data, batch_size, kv_len, head_num_q, head_num_kv, head_dim, dtype, warmup=10
 ):
     # Prepare data: continuous q,k,v
     dims_q = (batch_size, head_num_q, 1, head_dim)
@@ -251,6 +274,9 @@ def decode_attention_cudnn(
         o: o_gpu,
     }
 
+    for _ in range(warmup):
+        graph.execute(variant_pack, workspace)
+
     f = time_fwd(
         graph.execute,
         variant_pack,
@@ -266,7 +292,7 @@ def calculate_diff():
     batch_size = 64
     kv_len = 4096
     head_num_q = 64
-    head_num_kv = 64
+    head_num_kv = 8
     head_dim = 128
 
     q = torch.randn(batch_size, head_num_q, head_dim, dtype=dtype, device="cuda")
@@ -290,7 +316,7 @@ def calculate_diff():
         num_kv_splits=8,
     )
 
-    attn_flashinfer = decode_attention_flashinfer().apply
+    attn_flashinfer = decode_attention_flashinfer(dtype, head_num_q, head_num_kv).apply
     _, output_flashinfer = attn_flashinfer(
         q, kv_data, batch_size, kv_len, head_num_q, head_num_kv, head_dim, dtype
     )
@@ -316,51 +342,64 @@ def calculate_diff():
 if __name__ == "__main__":
     calculate_diff()
 
-    attn_flashinfer = decode_attention_flashinfer().apply
     head_dim = 128
     dtype = torch.float16
     batch_size_range = [2**i for i in range(0, 8, 2)]
     kv_len_range = [2**i for i in range(6, 13, 1)]
-    head_num_range = [32, 64]
-    configs = list(itertools.product(head_num_range, batch_size_range, kv_len_range))
-
-    for head_num, batch_size, kv_len in configs:
-        head_num_q = head_num_kv = head_num
-        q = torch.randn(batch_size, head_num_q, head_dim, dtype=dtype, device="cuda")
-        kv_data = (
-            torch.randn(
-                batch_size * kv_len, head_num_kv, head_dim, dtype=dtype, device="cuda"
-            ),
-            torch.randn(
-                batch_size * kv_len, head_num_kv, head_dim, dtype=dtype, device="cuda"
-            ),
-        )
-        us_cudnn, output_cudnn = decode_attention_cudnn(
-            q, kv_data, batch_size, kv_len, head_num_q, head_num_kv, head_dim, dtype
-        )
-        us_sglang, output_sglang = decode_attention_sglang(
-            q,
-            kv_data,
-            batch_size,
-            kv_len,
-            head_num_q,
-            head_num_kv,
-            head_dim,
-            num_kv_splits=8,
-        )
-        us_flashinfer, _ = attn_flashinfer(
-            q, kv_data, batch_size, kv_len, head_num_q, head_num_kv, head_dim, dtype
-        )
-        print(
-            head_num,
-            " ",
-            batch_size,
-            " ",
-            kv_len,
-            " ",
-            us_cudnn,
-            " ",
-            us_sglang,
-            " ",
-            us_flashinfer,
-        )
+    configs = list(itertools.product(batch_size_range, kv_len_range))
+
+    for head_num_q, head_num_kv in [[32, 32], [64, 8], [40, 8]]:
+        attn_flashinfer = decode_attention_flashinfer(
+            dtype, head_num_q, head_num_kv
+        ).apply
+        for batch_size, kv_len in configs:
+            q = torch.randn(
+                batch_size, head_num_q, head_dim, dtype=dtype, device="cuda"
+            )
+            kv_data = (
+                torch.randn(
+                    batch_size * kv_len,
+                    head_num_kv,
+                    head_dim,
+                    dtype=dtype,
+                    device="cuda",
+                ),
+                torch.randn(
+                    batch_size * kv_len,
+                    head_num_kv,
+                    head_dim,
+                    dtype=dtype,
+                    device="cuda",
+                ),
+            )
+            us_cudnn, output_cudnn = decode_attention_cudnn(
+                q, kv_data, batch_size, kv_len, head_num_q, head_num_kv, head_dim, dtype
+            )
+            us_sglang, output_sglang = decode_attention_sglang(
+                q,
+                kv_data,
+                batch_size,
+                kv_len,
+                head_num_q,
+                head_num_kv,
+                head_dim,
+                num_kv_splits=8,
+            )
+            us_flashinfer, _ = attn_flashinfer(
+                q, kv_data, batch_size, kv_len, head_num_q, head_num_kv, head_dim, dtype
+            )
+            print(
+                head_num_q,
+                " ",
+                head_num_kv,
+                " ",
+                batch_size,
+                " ",
+                kv_len,
+                " ",
+                us_cudnn,
+                " ",
+                us_sglang,
+                " ",
+                us_flashinfer,
+            )
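A note on the use_tensor_cores switch wired up in this patch: the benchmark now lets sglang.srt.utils.should_use_tensor_core decide whether flashinfer's tensor-core decode kernel is used, so GQA shapes such as 64 query heads over 8 KV heads can exercise that path. The sketch below only illustrates the kind of dispatch heuristic involved; the group-size threshold and the FP8 branch are assumptions for illustration, not the exact logic of the sglang helper.

    # Illustrative sketch, not sglang's implementation: the real decision lives in
    # sglang.srt.utils.should_use_tensor_core; the threshold and FP8 handling here are assumed.
    import torch

    def should_use_tensor_core_sketch(kv_cache_dtype, num_attention_heads, num_kv_heads):
        # Assumption: FP8 KV caches require the tensor-core decode path.
        if kv_cache_dtype in (torch.float8_e4m3fn, torch.float8_e5m2):
            return True
        # For fp16/bf16, enable tensor cores when the GQA group size (query heads
        # per KV head) is large enough to batch into one matmul (threshold assumed).
        gqa_group_size = num_attention_heads // num_kv_heads
        return gqa_group_size >= 4

    # Under this assumed threshold, the benchmark's head configs resolve to:
    # [32, 32] -> False, [64, 8] -> True, [40, 8] -> True.
    for head_num_q, head_num_kv in [[32, 32], [64, 8], [40, 8]]:
        use_tc = should_use_tensor_core_sketch(torch.float16, head_num_q, head_num_kv)
        print(head_num_q, head_num_kv, use_tc)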