From d6d1b4dc79a9e8308f15559b9ec3273acdf8fcf9 Mon Sep 17 00:00:00 2001
From: root
Date: Mon, 9 Dec 2024 23:00:02 +0800
Subject: [PATCH] benchmark decoding attention kernel

---
 .../sglang_triton_vs_flashinfer.py | 172 ++++++++++++++++++
 1 file changed, 172 insertions(+)
 create mode 100644 benchmark/kernels/decoding_attention_triton/sglang_triton_vs_flashinfer.py

diff --git a/benchmark/kernels/decoding_attention_triton/sglang_triton_vs_flashinfer.py b/benchmark/kernels/decoding_attention_triton/sglang_triton_vs_flashinfer.py
new file mode 100644
index 00000000000..a56849f4818
--- /dev/null
+++ b/benchmark/kernels/decoding_attention_triton/sglang_triton_vs_flashinfer.py
@@ -0,0 +1,172 @@
+import itertools
+
+import torch
+import triton
+import triton.language as tl
+from flashinfer import BatchDecodeWithPagedKVCacheWrapper
+
+from sglang.srt.layers.attention.triton_ops.decode_attention import decode_attention_fwd
+
+
+def decode_attention_sglang(
+    q, kv_data, batch_size, kv_len, head_num_q, head_num_kv, head_dim, num_kv_splits
+):
+
+    k_buffer = kv_data[:, 0].view(-1, head_num_kv, head_dim).contiguous()
+    v_buffer = kv_data[:, 1].view(-1, head_num_kv, head_dim).contiguous()
+    o = torch.empty_like(q)
+    total_tokens = batch_size * kv_len
+    req_to_token = torch.arange(0, total_tokens).to(0).int().view(batch_size, kv_len)
+    b_req_idx = torch.arange(0, batch_size).to(0).int()
+    b_seq_len = torch.full((batch_size,), kv_len, dtype=torch.int32, device="cuda")
+    max_len_in_batch = kv_len
+    sm_scale = 1.0 / (head_dim**0.5)
+
+    attn_logits = torch.empty(
+        (batch_size, head_num_q, num_kv_splits, head_dim + 1),
+        dtype=torch.float32,
+        device="cuda",
+    )
+
+    decode_attention_fwd(
+        q,
+        k_buffer,
+        v_buffer,
+        o,
+        req_to_token,
+        b_req_idx,
+        b_seq_len,
+        attn_logits,
+        num_kv_splits,
+        sm_scale,
+    )
+
+    return o
+
+
+def decode_attention_flashinfer(
+    q, kv_data, batch_size, kv_len, head_num_q, head_num_kv, head_dim, dtype
+):
+
+    total_tokens = batch_size * kv_len
+    kv_indptr = torch.arange(0, batch_size + 1).to(0).int() * kv_len
+    kv_indices = torch.arange(0, total_tokens).to(0).int()
+    kv_last_page_len = torch.full((batch_size,), 1, dtype=torch.int32, device="cuda")
+
+    flashinfer_decode_wrapper.end_forward()
+    flashinfer_decode_wrapper.begin_forward(
+        kv_indptr,
+        kv_indices,
+        kv_last_page_len,
+        head_num_q,
+        head_num_kv,
+        head_dim,
+        1,
+        pos_encoding_mode="NONE",
+        data_type=dtype,
+    )
+    o = flashinfer_decode_wrapper.forward(
+        q.contiguous().view(-1, head_num_q, head_dim), kv_data
+    )
+
+    return o
+
+
+def calculate_diff():
+
+    dtype = torch.bfloat16
+    batch_size = 4
+    kv_len = 16
+    head_num_q = 32
+    head_num_kv = 32
+    head_dim = 128
+
+    q = torch.randn(batch_size, head_num_q, head_dim, dtype=dtype, device="cuda")
+    kv_data = torch.randn(
+        batch_size * kv_len, 2, head_num_kv, head_dim, dtype=dtype, device="cuda"
+    )
+
+    output_sglang = decode_attention_sglang(
+        q,
+        kv_data,
+        batch_size,
+        kv_len,
+        head_num_q,
+        head_num_kv,
+        head_dim,
+        num_kv_splits=8,
+    )
+    output_flashinfer = decode_attention_flashinfer(
+        q, kv_data, batch_size, kv_len, head_num_q, head_num_kv, head_dim, dtype=dtype
+    )
+
+    print(f"SGLang output={output_sglang}")
+    print(f"FlashInfer output={output_flashinfer}")
+    if torch.allclose(output_sglang, output_flashinfer, atol=1e-2, rtol=1e-2):
+        print("✅ SGLang[Triton] and FlashInfer match")
+    else:
+        print("❌ SGLang[Triton] and FlashInfer differ")
+
+
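+# Sweep configuration shared by both providers: head_dim and dtype stay fixed,
+# while (head_num, batch_size, kv_len) are crossed into one grid of configs.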
+head_dim = 128
+dtype = torch.float16
+batch_size_range = [2**i for i in range(0, 8, 2)]
+kv_len_range = [2**i for i in range(6, 13, 1)]
+head_num_range = [32, 64]
+configs = list(itertools.product(head_num_range, batch_size_range, kv_len_range))
+
+
+@triton.testing.perf_report(
+    triton.testing.Benchmark(
+        x_names=["head_num", "batch_size", "kv_len"],
+        x_vals=[list(_) for _ in configs],
+        line_arg="provider",
+        line_vals=["sglang_triton", "flashinfer"],
+        line_names=["SGLang[triton]", "FlashInfer"],
+        styles=[("green", "-"), ("red", "-")],
+        ylabel="us",
+        plot_name="decode-attention-performance",
+        args={},
+    )
+)
+def benchmark(head_num, batch_size, kv_len, provider):
+    head_num_q = head_num_kv = head_num
+    q = torch.randn(batch_size, head_num_q, head_dim, dtype=dtype, device="cuda")
+    kv_data = torch.randn(
+        batch_size * kv_len, 2, head_num_kv, head_dim, dtype=dtype, device="cuda"
+    )
+    quantiles = [0.5, 0.2, 0.8]
+    if provider == "sglang_triton":
+        ms, min_ms, max_ms = triton.testing.do_bench(
+            lambda: decode_attention_sglang(
+                q,
+                kv_data,
+                batch_size,
+                kv_len,
+                head_num_q,
+                head_num_kv,
+                head_dim,
+                num_kv_splits=8,
+            ),
+            quantiles=quantiles,
+        )
+    if provider == "flashinfer":
+        ms, min_ms, max_ms = triton.testing.do_bench(
+            lambda: decode_attention_flashinfer(
+                q, kv_data, batch_size, kv_len, head_num_q, head_num_kv, head_dim, dtype
+            ),
+            quantiles=quantiles,
+        )
+    return 1000 * ms, 1000 * max_ms, 1000 * min_ms
+
+
+if __name__ == "__main__":
+    workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.int8, device="cuda")
+    global flashinfer_decode_wrapper
+    flashinfer_decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(
+        workspace_buffer, "NHD", use_tensor_cores=False
+    )
+
+    calculate_diff()
+
+    benchmark.run(print_data=True)
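
For reference, a minimal way to run this comparison locally, assuming a CUDA-capable GPU with sglang and flashinfer installed and the command issued from the repository root:

    python benchmark/kernels/decoding_attention_triton/sglang_triton_vs_flashinfer.py

The script first runs calculate_diff(), printing both outputs and whether they agree within atol=rtol=1e-2, then runs the sweep via benchmark.run(print_data=True); timings are reported in microseconds (do_bench returns milliseconds, which the script multiplies by 1000). Triton's perf_report runner also accepts a save_path argument, e.g. benchmark.run(print_data=True, save_path="./results"), if the table and plot should be written to disk; the directory name here is only an example.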