From 2cf42b588d5f768da30e02e935ce7fdf2af3e7b8 Mon Sep 17 00:00:00 2001
From: BBuf <1182563586@qq.com>
Date: Sun, 15 Dec 2024 13:32:34 +0800
Subject: [PATCH 1/3] add a benchmark for hf/vllm/sglang rmsnorm

---
 .../kernels/rmsnorm/benchmark_rmsnorm.py | 178 ++++++++++++++++++
 1 file changed, 178 insertions(+)
 create mode 100644 benchmark/kernels/rmsnorm/benchmark_rmsnorm.py

diff --git a/benchmark/kernels/rmsnorm/benchmark_rmsnorm.py b/benchmark/kernels/rmsnorm/benchmark_rmsnorm.py
new file mode 100644
index 00000000000..0790b31ca35
--- /dev/null
+++ b/benchmark/kernels/rmsnorm/benchmark_rmsnorm.py
@@ -0,0 +1,178 @@
+import itertools
+import torch
+import triton
+import triton.language as tl
+from typing import Optional, Union, Tuple
+from torch import nn
+from flashinfer.norm import fused_add_rmsnorm, rmsnorm
+from vllm import _custom_ops as vllm_ops
+
+class HuggingFaceRMSNorm(nn.Module):
+    def __init__(self, hidden_size: int, eps: float = 1e-6) -> None:
+        super().__init__()
+        self.weight = nn.Parameter(torch.ones(hidden_size))
+        self.variance_epsilon = eps
+
+    def forward(
+        self,
+        x: torch.Tensor,
+        residual: Optional[torch.Tensor] = None,
+    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
+        orig_dtype = x.dtype
+        x = x.to(torch.float32)
+        if residual is not None:
+            x = x + residual.to(torch.float32)
+            residual = x.to(orig_dtype)
+
+        variance = x.pow(2).mean(dim=-1, keepdim=True)
+        x = x * torch.rsqrt(variance + self.variance_epsilon)
+        x = x.to(orig_dtype) * self.weight
+        if residual is None:
+            return x
+        else:
+            return x, residual
+
+def rmsnorm_naive(x: torch.Tensor, weight: torch.Tensor, residual: Optional[torch.Tensor] = None, eps: float = 1e-6):
+    naive_norm = HuggingFaceRMSNorm(x.shape[-1], eps=eps)
+    naive_norm.weight = nn.Parameter(weight)
+    naive_norm = naive_norm.to(x.device)
+
+    orig_shape = x.shape
+    x = x.view(-1, x.shape[-1])
+    if residual is not None:
+        residual = residual.view(-1, residual.shape[-1])
+
+    output = naive_norm(x, residual)
+
+    if isinstance(output, tuple):
+        output = (output[0].view(orig_shape), output[1].view(orig_shape))
+    else:
+        output = output.view(orig_shape)
+    return output
+
+def rmsnorm_flashinfer(x: torch.Tensor, weight: torch.Tensor, residual: Optional[torch.Tensor] = None, eps: float = 1e-6):
+    orig_shape = x.shape
+    x = x.view(-1, x.shape[-1])
+    if residual is not None:
+        residual = residual.view(-1, residual.shape[-1])
+
+    if residual is not None:
+        fused_add_rmsnorm(x, residual, weight, eps)
+        output = (x, residual)
+    else:
+        output = rmsnorm(x, weight, eps)
+
+    if isinstance(output, tuple):
+        output = (output[0].view(orig_shape), output[1].view(orig_shape))
+    else:
+        output = output.view(orig_shape)
+    return output
+
+def rmsnorm_vllm(x: torch.Tensor, weight: torch.Tensor, residual: Optional[torch.Tensor] = None, eps: float = 1e-6):
+    orig_shape = x.shape
+    x = x.view(-1, x.shape[-1])
+    if residual is not None:
+        residual = residual.view(-1, residual.shape[-1])
+
+    if residual is not None:
+        vllm_ops.fused_add_rms_norm(x, residual, weight, eps)
+        output = (x, residual)
+    else:
+        out = torch.empty_like(x)
+        vllm_ops.rms_norm(out, x, weight, eps)
+        output = out
+
+    if isinstance(output, tuple):
+        output = (output[0].view(orig_shape), output[1].view(orig_shape))
+    else:
+        output = output.view(orig_shape)
+    return output
+
+def calculate_diff(batch_size, seq_len, hidden_size, use_residual=True):
+    dtype = torch.bfloat16
+    x = torch.randn(batch_size, seq_len, hidden_size, dtype=dtype, device="cuda")
+    weight = torch.ones(hidden_size, dtype=dtype, device="cuda")
+    residual = torch.randn_like(x) if use_residual else None
+
+    output_naive = rmsnorm_naive(x.clone(), weight, residual.clone() if residual is not None else None)
+    output_flashinfer = rmsnorm_flashinfer(x.clone(), weight, residual.clone() if residual is not None else None)
+    output_vllm = rmsnorm_vllm(x.clone(), weight, residual.clone() if residual is not None else None)
+
+    if use_residual:
+        output_naive = output_naive[0]
+        output_flashinfer = output_flashinfer[0]
+        output_vllm = output_vllm[0]
+
+    print(f"Naive output={output_naive}")
+    print(f"FlashInfer output={output_flashinfer}")
+    print(f"VLLM output={output_vllm}")
+
+    if (torch.allclose(output_naive, output_flashinfer, atol=1e-2, rtol=1e-2) and
+        torch.allclose(output_naive, output_vllm, atol=1e-2, rtol=1e-2)):
+        print("✅ All implementations match")
+    else:
+        print("❌ Implementations differ")
+
+batch_size_range = [2**i for i in range(0, 7, 2)]
+seq_length_range = [2**i for i in range(6, 11, 1)]
+head_num_range = [32, 48]
+configs = list(itertools.product(head_num_range, batch_size_range, seq_length_range))
+
+def get_benchmark(use_residual):
+    @triton.testing.perf_report(
+        triton.testing.Benchmark(
+            x_names=["head_num", "batch_size", "seq_len"],
+            x_vals=[list(_) for _ in configs],
+            line_arg="provider",
+            line_vals=["huggingface", "flashinfer", "vllm"],
+            line_names=["HuggingFace", "FlashInfer", "VLLM"],
+            styles=[("blue", "-"), ("green", "-"), ("red", "-")],
+            ylabel="us",
+            plot_name=f"rmsnorm-performance-{'with' if use_residual else 'without'}-residual",
+            args={},
+        )
+    )
+    def benchmark(head_num, batch_size, seq_len, provider):
+        dtype = torch.bfloat16
+        hidden_size = head_num * 128  # assuming head_dim = 128
+
+        x = torch.randn(batch_size, seq_len, hidden_size, dtype=dtype, device="cuda")
+        weight = torch.ones(hidden_size, dtype=dtype, device="cuda")
+        residual = torch.randn_like(x) if use_residual else None
+
+        quantiles = [0.5, 0.2, 0.8]
+
+        if provider == "huggingface":
+            ms, min_ms, max_ms = triton.testing.do_bench(
+                lambda: rmsnorm_naive(x.clone(), weight, residual.clone() if residual is not None else None),
+                quantiles=quantiles,
+            )
+        elif provider == "flashinfer":
+            ms, min_ms, max_ms = triton.testing.do_bench(
+                lambda: rmsnorm_flashinfer(x.clone(), weight, residual.clone() if residual is not None else None),
+                quantiles=quantiles,
+            )
+        else:
+            ms, min_ms, max_ms = triton.testing.do_bench(
+                lambda: rmsnorm_vllm(x.clone(), weight, residual.clone() if residual is not None else None),
+                quantiles=quantiles,
+            )
+
+        return 1000 * ms, 1000 * max_ms, 1000 * min_ms
+
+    return benchmark
+
+if __name__ == "__main__":
+    import argparse
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--use_residual", action="store_true", help="Whether to use residual connection")
+    parser.add_argument("--save_path", type=str, default="./configs/benchmark_ops/rmsnorm/", help="Path to save rmsnorm benchmark results")
+    args = parser.parse_args()
+
+    # Run correctness test
+    calculate_diff(batch_size=4, seq_len=128, hidden_size=4096, use_residual=args.use_residual)
+
+    # Get the benchmark function with proper use_residual setting
+    benchmark = get_benchmark(args.use_residual)
+    # Run performance benchmark
+    benchmark.run(print_data=True, save_path=args.save_path)

From 650fd24b96e03722a83e519cf29b067ace163709 Mon Sep 17 00:00:00 2001
From: BBuf <1182563586@qq.com>
Date: Sun, 15 Dec 2024 13:43:51 +0800
Subject: [PATCH 2/3] format

---
 .../kernels/rmsnorm/benchmark_rmsnorm.py | 115 +++++++++++-----
 1 file changed, 84 insertions(+), 31 deletions(-)

diff --git a/benchmark/kernels/rmsnorm/benchmark_rmsnorm.py b/benchmark/kernels/rmsnorm/benchmark_rmsnorm.py
index 0790b31ca35..ba37556a728 100644
--- a/benchmark/kernels/rmsnorm/benchmark_rmsnorm.py
+++ b/benchmark/kernels/rmsnorm/benchmark_rmsnorm.py
@@ -1,12 +1,14 @@
 import itertools
+from typing import Optional, Tuple, Union
+
 import torch
 import triton
 import triton.language as tl
-from typing import Optional, Union, Tuple
-from torch import nn
 from flashinfer.norm import fused_add_rmsnorm, rmsnorm
+from torch import nn
 from vllm import _custom_ops as vllm_ops
 
+
 class HuggingFaceRMSNorm(nn.Module):
     def __init__(self, hidden_size: int, eps: float = 1e-6) -> None:
         super().__init__()
@@ -32,48 +34,66 @@ def forward(
         else:
             return x, residual
 
-def rmsnorm_naive(x: torch.Tensor, weight: torch.Tensor, residual: Optional[torch.Tensor] = None, eps: float = 1e-6):
+
+def rmsnorm_naive(
+    x: torch.Tensor,
+    weight: torch.Tensor,
+    residual: Optional[torch.Tensor] = None,
+    eps: float = 1e-6,
+):
     naive_norm = HuggingFaceRMSNorm(x.shape[-1], eps=eps)
     naive_norm.weight = nn.Parameter(weight)
     naive_norm = naive_norm.to(x.device)
-    
+
     orig_shape = x.shape
     x = x.view(-1, x.shape[-1])
     if residual is not None:
         residual = residual.view(-1, residual.shape[-1])
-    
+
     output = naive_norm(x, residual)
-    
+
     if isinstance(output, tuple):
         output = (output[0].view(orig_shape), output[1].view(orig_shape))
     else:
         output = output.view(orig_shape)
     return output
 
-def rmsnorm_flashinfer(x: torch.Tensor, weight: torch.Tensor, residual: Optional[torch.Tensor] = None, eps: float = 1e-6):
+
+def rmsnorm_flashinfer(
+    x: torch.Tensor,
+    weight: torch.Tensor,
+    residual: Optional[torch.Tensor] = None,
+    eps: float = 1e-6,
+):
     orig_shape = x.shape
     x = x.view(-1, x.shape[-1])
     if residual is not None:
         residual = residual.view(-1, residual.shape[-1])
-    
+
     if residual is not None:
         fused_add_rmsnorm(x, residual, weight, eps)
         output = (x, residual)
     else:
         output = rmsnorm(x, weight, eps)
-    
+
     if isinstance(output, tuple):
         output = (output[0].view(orig_shape), output[1].view(orig_shape))
     else:
-        output = output.view(orig_shape) 
+        output = output.view(orig_shape)
     return output
 
-def rmsnorm_vllm(x: torch.Tensor, weight: torch.Tensor, residual: Optional[torch.Tensor] = None, eps: float = 1e-6):
+
+def rmsnorm_vllm(
+    x: torch.Tensor,
+    weight: torch.Tensor,
+    residual: Optional[torch.Tensor] = None,
+    eps: float = 1e-6,
+):
     orig_shape = x.shape
     x = x.view(-1, x.shape[-1])
     if residual is not None:
         residual = residual.view(-1, residual.shape[-1])
-    
+
     if residual is not None:
         vllm_ops.fused_add_rms_norm(x, residual, weight, eps)
         output = (x, residual)
@@ -81,22 +101,29 @@ def rmsnorm_vllm(x: torch.Tensor, weight: torch.Tensor, residual: Optional[torch
         out = torch.empty_like(x)
         vllm_ops.rms_norm(out, x, weight, eps)
         output = out
-    
+
     if isinstance(output, tuple):
         output = (output[0].view(orig_shape), output[1].view(orig_shape))
     else:
         output = output.view(orig_shape)
     return output
 
+
 def calculate_diff(batch_size, seq_len, hidden_size, use_residual=True):
     dtype = torch.bfloat16
     x = torch.randn(batch_size, seq_len, hidden_size, dtype=dtype, device="cuda")
     weight = torch.ones(hidden_size, dtype=dtype, device="cuda")
     residual = torch.randn_like(x) if use_residual else None
 
-    output_naive = rmsnorm_naive(x.clone(), weight, residual.clone() if residual is not None else None)
-    output_flashinfer = rmsnorm_flashinfer(x.clone(), weight, residual.clone() if residual is not None else None)
-    output_vllm = rmsnorm_vllm(x.clone(), weight, residual.clone() if residual is not None else None)
+    output_naive = rmsnorm_naive(
+        x.clone(), weight, residual.clone() if residual is not None else None
+    )
+    output_flashinfer = rmsnorm_flashinfer(
+        x.clone(), weight, residual.clone() if residual is not None else None
+    )
+    output_vllm = rmsnorm_vllm(
+        x.clone(), weight, residual.clone() if residual is not None else None
+    )
 
     if use_residual:
         output_naive = output_naive[0]
@@ -106,18 +133,21 @@ def calculate_diff(batch_size, seq_len, hidden_size, use_residual=True):
     print(f"Naive output={output_naive}")
     print(f"FlashInfer output={output_flashinfer}")
     print(f"VLLM output={output_vllm}")
-    
-    if (torch.allclose(output_naive, output_flashinfer, atol=1e-2, rtol=1e-2) and
-        torch.allclose(output_naive, output_vllm, atol=1e-2, rtol=1e-2)):
+
+    if torch.allclose(
+        output_naive, output_flashinfer, atol=1e-2, rtol=1e-2
+    ) and torch.allclose(output_naive, output_vllm, atol=1e-2, rtol=1e-2):
         print("✅ All implementations match")
     else:
         print("❌ Implementations differ")
 
+
 batch_size_range = [2**i for i in range(0, 7, 2)]
 seq_length_range = [2**i for i in range(6, 11, 1)]
 head_num_range = [32, 48]
 configs = list(itertools.product(head_num_range, batch_size_range, seq_length_range))
 
+
 def get_benchmark(use_residual):
     @triton.testing.perf_report(
         triton.testing.Benchmark(
@@ -135,43 +165,66 @@ def get_benchmark(use_residual):
     def benchmark(head_num, batch_size, seq_len, provider):
         dtype = torch.bfloat16
         hidden_size = head_num * 128  # assuming head_dim = 128
-        
+
         x = torch.randn(batch_size, seq_len, hidden_size, dtype=dtype, device="cuda")
         weight = torch.ones(hidden_size, dtype=dtype, device="cuda")
         residual = torch.randn_like(x) if use_residual else None
-        
+
         quantiles = [0.5, 0.2, 0.8]
-        
+
         if provider == "huggingface":
             ms, min_ms, max_ms = triton.testing.do_bench(
-                lambda: rmsnorm_naive(x.clone(), weight, residual.clone() if residual is not None else None),
+                lambda: rmsnorm_naive(
+                    x.clone(),
+                    weight,
+                    residual.clone() if residual is not None else None,
+                ),
                 quantiles=quantiles,
             )
         elif provider == "flashinfer":
             ms, min_ms, max_ms = triton.testing.do_bench(
-                lambda: rmsnorm_flashinfer(x.clone(), weight, residual.clone() if residual is not None else None),
+                lambda: rmsnorm_flashinfer(
+                    x.clone(),
+                    weight,
+                    residual.clone() if residual is not None else None,
+                ),
                 quantiles=quantiles,
             )
         else:
             ms, min_ms, max_ms = triton.testing.do_bench(
-                lambda: rmsnorm_vllm(x.clone(), weight, residual.clone() if residual is not None else None),
+                lambda: rmsnorm_vllm(
+                    x.clone(),
+                    weight,
+                    residual.clone() if residual is not None else None,
+                ),
                 quantiles=quantiles,
             )
-        
+
         return 1000 * ms, 1000 * max_ms, 1000 * min_ms
-    
+
     return benchmark
 
+
 if __name__ == "__main__":
     import argparse
+
     parser = argparse.ArgumentParser()
-    parser.add_argument("--use_residual", action="store_true", help="Whether to use residual connection")
-    parser.add_argument("--save_path", type=str, default="./configs/benchmark_ops/rmsnorm/", help="Path to save rmsnorm benchmark results")
+    parser.add_argument(
+        "--use_residual", action="store_true", help="Whether to use residual connection"
+    )
+    parser.add_argument(
+        "--save_path",
+        type=str,
+        default="./configs/benchmark_ops/rmsnorm/",
+        help="Path to save rmsnorm benchmark results",
+    )
     args = parser.parse_args()
 
     # Run correctness test
-    calculate_diff(batch_size=4, seq_len=128, hidden_size=4096, use_residual=args.use_residual)
-    
+    calculate_diff(
+        batch_size=4, seq_len=128, hidden_size=4096, use_residual=args.use_residual
+    )
+
     # Get the benchmark function with proper use_residual setting
     benchmark = get_benchmark(args.use_residual)
     # Run performance benchmark

From 172ff513ebc1bb0457d642c17a35b735cd6293a5 Mon Sep 17 00:00:00 2001
From: Yineng Zhang
Date: Sun, 15 Dec 2024 13:51:02 +0800
Subject: [PATCH 3/3] upd

---
 benchmark/kernels/rmsnorm/benchmark_rmsnorm.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/benchmark/kernels/rmsnorm/benchmark_rmsnorm.py b/benchmark/kernels/rmsnorm/benchmark_rmsnorm.py
index ba37556a728..ad7b180ce1d 100644
--- a/benchmark/kernels/rmsnorm/benchmark_rmsnorm.py
+++ b/benchmark/kernels/rmsnorm/benchmark_rmsnorm.py
@@ -155,7 +155,7 @@ def get_benchmark(use_residual):
             x_vals=[list(_) for _ in configs],
             line_arg="provider",
             line_vals=["huggingface", "flashinfer", "vllm"],
-            line_names=["HuggingFace", "FlashInfer", "VLLM"],
+            line_names=["HuggingFace", "FlashInfer", "vLLM"],
             styles=[("blue", "-"), ("green", "-"), ("red", "-")],
             ylabel="us",
             plot_name=f"rmsnorm-performance-{'with' if use_residual else 'without'}-residual",
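
Usage note: once the series is applied, the new script can be launched directly from the repository root. A typical invocation (assuming a CUDA-capable GPU with working flashinfer and vllm installations; the flag names below come from the script's argparse setup) is:

    python benchmark/kernels/rmsnorm/benchmark_rmsnorm.py --use_residual --save_path ./configs/benchmark_ops/rmsnorm/

The script first runs the calculate_diff correctness check against the HuggingFace reference and then sweeps the (head_num, batch_size, seq_len) grid defined in configs, reporting median and 20th/80th-percentile latencies in microseconds for the HuggingFace, FlashInfer, and vLLM providers.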