
Commit

add warmup, gqa
mdattack committed Dec 17, 2024
1 parent 1d525ac commit 8adf7bf
Showing 1 changed file with 100 additions and 61 deletions.
benchmark/kernels/decoding_attention_triton/triton_flashinfer_cudnn.py (100 additions, 61 deletions)
@@ -9,6 +9,7 @@
from flashinfer import BatchDecodeWithPagedKVCacheWrapper

from sglang.srt.layers.attention.triton_ops.decode_attention import decode_attention_fwd
from sglang.srt.utils import should_use_tensor_core


def benchmark_forward(
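The new `should_use_tensor_core` import lets the benchmark choose FlashInfer's tensor-core decode path from the KV-cache dtype and the query/KV head counts, instead of hard-coding `use_tensor_cores=False` as before. The sketch below shows the kind of heuristic such a helper could apply; it is an illustrative assumption for this write-up, not sglang's actual implementation.

```python
import torch

def should_use_tensor_core_sketch(kv_cache_dtype, num_attention_heads, num_kv_heads):
    # Illustrative heuristic only -- not the real sglang.srt.utils helper.
    # FP8 KV caches are usually served by tensor-core decode kernels.
    if kv_cache_dtype in (torch.float8_e4m3fn, torch.float8_e5m2):
        return True
    # For GQA, a large query-to-KV head ratio gives tensor-core kernels
    # enough parallel work per KV head to pay off.
    gqa_group_size = num_attention_heads // num_kv_heads
    return gqa_group_size >= 4
```

Under this sketch, the 64/8 and 40/8 head configurations benchmarked below would take the tensor-core path, while 32/32 would not.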
@@ -38,7 +39,15 @@ def time_fwd(func, *args, **kwargs):


def decode_attention_sglang(
q, kv_data, batch_size, kv_len, head_num_q, head_num_kv, head_dim, num_kv_splits
q,
kv_data,
batch_size,
kv_len,
head_num_q,
head_num_kv,
head_dim,
num_kv_splits,
warmup=10,
):

k_buffer = kv_data[0].view(-1, head_num_kv, head_dim)
@@ -57,6 +66,20 @@ def decode_attention_sglang(
device="cuda",
)

for _ in range(warmup):
decode_attention_fwd(
q,
k_buffer,
v_buffer,
o,
req_to_token,
b_req_idx,
b_seq_len,
attn_logits,
num_kv_splits,
sm_scale,
)

f = time_fwd(
decode_attention_fwd,
q,
@@ -74,10 +97,15 @@ def decode_attention_sglang(
return f, o


def decode_attention_flashinfer():
def decode_attention_flashinfer(dtype, head_num_q, head_num_kv):
workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.int8, device="cuda")
use_tensor_cores = should_use_tensor_core(
kv_cache_dtype=dtype,
num_attention_heads=head_num_q,
num_kv_heads=head_num_kv,
)
flashinfer_decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(
workspace_buffer, "NHD", use_tensor_cores=False
workspace_buffer, "NHD", use_tensor_cores=use_tensor_cores
)

class FlashinferAttention(torch.autograd.Function):
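Each backend now runs `warmup` untimed iterations before the timed call, which is what the commit title's "add warmup" refers to. The first few launches of a Triton or FlashInfer kernel include JIT compilation, autotuning, and allocator growth, so timing them inflates the reported latency. A minimal sketch of the warmup-then-time pattern with CUDA events is shown below; `time_kernel` is a hypothetical name for illustration, while the file itself relies on its existing `benchmark_forward`/`time_fwd` helpers.

```python
import torch

def time_kernel(fn, *args, warmup=10, iters=30, **kwargs):
    # Untimed warmup absorbs JIT compilation, autotuning, and cache effects.
    for _ in range(warmup):
        fn(*args, **kwargs)
    torch.cuda.synchronize()

    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(iters):
        fn(*args, **kwargs)
    end.record()
    torch.cuda.synchronize()
    # Average milliseconds per call.
    return start.elapsed_time(end) / iters
```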
@@ -92,6 +120,7 @@ def forward(
head_num_kv,
head_dim,
dtype,
warmup=10,
):
total_tokens = batch_size * kv_len
kv_indptr = torch.arange(0, batch_size + 1).to(0).int() * kv_len
@@ -113,16 +142,17 @@ def forward(
data_type=dtype,
)

for _ in range(warmup):
o = flashinfer_decode_wrapper.forward(
q.contiguous().view(-1, head_num_q, head_dim), kv_data
)

f = time_fwd(
flashinfer_decode_wrapper.forward,
q.contiguous().view(-1, head_num_q, head_dim),
kv_data,
)

o = flashinfer_decode_wrapper.forward(
q.contiguous().view(-1, head_num_q, head_dim), kv_data
)

return f, o

return FlashinferAttention
@@ -144,14 +174,7 @@ def convert_to_cudnn_type(torch_type):


def decode_attention_cudnn(
q,
kv_data,
batch_size,
kv_len,
head_num_q,
head_num_kv,
head_dim,
dtype,
q, kv_data, batch_size, kv_len, head_num_q, head_num_kv, head_dim, dtype, warmup=10
):
# Prepare data: continuous q,k,v
dims_q = (batch_size, head_num_q, 1, head_dim)
@@ -251,6 +274,9 @@ def decode_attention_cudnn(
o: o_gpu,
}

for _ in range(warmup):
graph.execute(variant_pack, workspace)

f = time_fwd(
graph.execute,
variant_pack,
@@ -266,7 +292,7 @@ def calculate_diff():
batch_size = 64
kv_len = 4096
head_num_q = 64
head_num_kv = 64
head_num_kv = 8
head_dim = 128

q = torch.randn(batch_size, head_num_q, head_dim, dtype=dtype, device="cuda")
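`calculate_diff` now exercises a grouped-query attention (GQA) shape, 64 query heads sharing 8 KV heads, instead of the earlier 64/64 MHA layout. Conceptually, each KV head serves `head_num_q // head_num_kv` query heads; a reference check can expand the KV heads before an ordinary attention computation. The snippet below is a small illustration of that expansion, not code from this file.

```python
import torch

batch_size, kv_len, head_dim = 2, 16, 128
head_num_q, head_num_kv = 64, 8
group_size = head_num_q // head_num_kv  # 8 query heads share each KV head

k = torch.randn(batch_size, kv_len, head_num_kv, head_dim)
# Repeat every KV head group_size times so each query head gets its own K slice.
k_expanded = k.repeat_interleave(group_size, dim=2)
assert k_expanded.shape == (batch_size, kv_len, head_num_q, head_dim)
```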
@@ -290,7 +316,7 @@ def calculate_diff():
num_kv_splits=8,
)

attn_flashinfer = decode_attention_flashinfer().apply
attn_flashinfer = decode_attention_flashinfer(dtype, head_num_q, head_num_kv).apply
_, output_flashinfer = attn_flashinfer(
q, kv_data, batch_size, kv_len, head_num_q, head_num_kv, head_dim, dtype
)
@@ -316,51 +342,64 @@ if __name__ == "__main__":
if __name__ == "__main__":
calculate_diff()

attn_flashinfer = decode_attention_flashinfer().apply
head_dim = 128
dtype = torch.float16
batch_size_range = [2**i for i in range(0, 8, 2)]
kv_len_range = [2**i for i in range(6, 13, 1)]
head_num_range = [32, 64]
configs = list(itertools.product(head_num_range, batch_size_range, kv_len_range))

for head_num, batch_size, kv_len in configs:
head_num_q = head_num_kv = head_num
q = torch.randn(batch_size, head_num_q, head_dim, dtype=dtype, device="cuda")
kv_data = (
torch.randn(
batch_size * kv_len, head_num_kv, head_dim, dtype=dtype, device="cuda"
),
torch.randn(
batch_size * kv_len, head_num_kv, head_dim, dtype=dtype, device="cuda"
),
)
us_cudnn, output_cudnn = decode_attention_cudnn(
q, kv_data, batch_size, kv_len, head_num_q, head_num_kv, head_dim, dtype
)
us_sglang, output_sglang = decode_attention_sglang(
q,
kv_data,
batch_size,
kv_len,
head_num_q,
head_num_kv,
head_dim,
num_kv_splits=8,
)
us_flashinfer, _ = attn_flashinfer(
q, kv_data, batch_size, kv_len, head_num_q, head_num_kv, head_dim, dtype
)
print(
head_num,
" ",
batch_size,
" ",
kv_len,
" ",
us_cudnn,
" ",
us_sglang,
" ",
us_flashinfer,
)
configs = list(itertools.product(batch_size_range, kv_len_range))

for head_num_q, head_num_kv in [[32, 32], [64, 8], [40, 8]]:
attn_flashinfer = decode_attention_flashinfer(
dtype, head_num_q, head_num_kv
).apply
for batch_size, kv_len in configs:
q = torch.randn(
batch_size, head_num_q, head_dim, dtype=dtype, device="cuda"
)
kv_data = (
torch.randn(
batch_size * kv_len,
head_num_kv,
head_dim,
dtype=dtype,
device="cuda",
),
torch.randn(
batch_size * kv_len,
head_num_kv,
head_dim,
dtype=dtype,
device="cuda",
),
)
us_cudnn, output_cudnn = decode_attention_cudnn(
q, kv_data, batch_size, kv_len, head_num_q, head_num_kv, head_dim, dtype
)
us_sglang, output_sglang = decode_attention_sglang(
q,
kv_data,
batch_size,
kv_len,
head_num_q,
head_num_kv,
head_dim,
num_kv_splits=8,
)
us_flashinfer, _ = attn_flashinfer(
q, kv_data, batch_size, kv_len, head_num_q, head_num_kv, head_dim, dtype
)
print(
head_num_q,
" ",
head_num_kv,
" ",
batch_size,
" ",
kv_len,
" ",
us_cudnn,
" ",
us_sglang,
" ",
us_flashinfer,
)
