From 5e20c98a176f9ac6f9c864d512e2b5580444d735 Mon Sep 17 00:00:00 2001
From: BBuf <1182563586@qq.com>
Date: Mon, 9 Dec 2024 11:14:43 +0800
Subject: [PATCH 1/7] add fused moe unittest

---
 benchmark/kernels/fused_moe_triton/README.md | 9 +-
 ...280,device_name=NVIDIA_A800-SXM4-80GB.json | 146 ++++++++++++++++++
 ...640,device_name=NVIDIA_A800-SXM4-80GB.json | 146 ++++++++++++++++++
 .../srt/layers/quantization/__init__.py | 2 +-
 python/sglang/srt/layers/quantization/fp8.py | 2 +-
 srt/layers/base.py | 1 +
 srt/layers/fused_moe_triton/__init__.py | 1 +
 srt/layers/quantization/fp8.py | 1 +
 test/srt/run_suite.py | 1 +
 test/srt/test_fused_moe.py | 97 ++++++++++++
 10 files changed, 402 insertions(+), 4 deletions(-)
 create mode 100644 python/sglang/srt/layers/fused_moe_triton/configs/E=64,N=1280,device_name=NVIDIA_A800-SXM4-80GB.json
 create mode 100644 python/sglang/srt/layers/fused_moe_triton/configs/E=64,N=640,device_name=NVIDIA_A800-SXM4-80GB.json
 create mode 100644 srt/layers/base.py
 create mode 100644 srt/layers/fused_moe_triton/__init__.py
 create mode 100644 srt/layers/quantization/fp8.py
 create mode 100644 test/srt/test_fused_moe.py

diff --git a/benchmark/kernels/fused_moe_triton/README.md b/benchmark/kernels/fused_moe_triton/README.md
index ba29ede5099..b2b700dfaf2 100644
--- a/benchmark/kernels/fused_moe_triton/README.md
+++ b/benchmark/kernels/fused_moe_triton/README.md
@@ -10,7 +10,7 @@ Example usage:
 ```bash
 # Tune Qwen2-57B with FP8 and TP=4
 python benchmark/kernels/fused_moe_triton/tuning_fused_moe_triton.py \
-    --model Qwen/Qwen2-57B-A14B-Instruct-FP8 \
+    --model Qwen/Qwen2-57B-A14B-Instruct \
     --tp-size 4 \
     --dtype fp8_w8a8 \
     --tune
@@ -34,7 +34,7 @@ python benchmark/kernels/fused_moe_triton/benchmark_vllm_vs_sglang_fused_moe_tri
 
 # Compare with FP8 mode for Qwen2-57B
 python benchmark/kernels/fused_moe_triton/benchmark_vllm_vs_sglang_fused_moe_triton.py \
-    --model Qwen/Qwen2-57B-A14B-Instruct-FP8 \
+    --model Qwen/Qwen2-57B-A14B-Instruct \
     --use-fp8
 
 # Compare with custom TP size
@@ -43,3 +43,8 @@ python benchmark/kernels/fused_moe_triton/benchmark_vllm_vs_sglang_fused_moe_tri
 ```
 
 The benchmark results will be saved as plots and data files in the specified output directory (default: `./configs/benchmark_ops/vllm_sglang_fused_moe/`).
+
+- `benchmark_torch_compile_fused_moe.py`: A tool for benchmarking the performance of the fused MoE kernel with `torch.compile` and original fused MoE kernel.
+
+Usage is the same as `benchmark_vllm_vs_sglang_fused_moe_triton.py`.
+
diff --git a/python/sglang/srt/layers/fused_moe_triton/configs/E=64,N=1280,device_name=NVIDIA_A800-SXM4-80GB.json b/python/sglang/srt/layers/fused_moe_triton/configs/E=64,N=1280,device_name=NVIDIA_A800-SXM4-80GB.json
new file mode 100644
index 00000000000..283ffd8ff1d
--- /dev/null
+++ b/python/sglang/srt/layers/fused_moe_triton/configs/E=64,N=1280,device_name=NVIDIA_A800-SXM4-80GB.json
@@ -0,0 +1,146 @@
+{
+    "1": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 32,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 5
+    },
+    "2": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 8,
+        "num_stages": 5
+    },
+    "4": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "8": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "16": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "24": {
+        "BLOCK_SIZE_M": 32,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "32": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "48": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "64": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 256,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 2
+    },
+    "96": {
+        "BLOCK_SIZE_M": 32,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "128": {
+        "BLOCK_SIZE_M": 32,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "256": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 8,
+        "num_stages": 3
+    },
+    "512": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 8,
+        "num_stages": 3
+    },
+    "1024": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 8,
+        "num_stages": 3
+    },
+    "1536": {
+        "BLOCK_SIZE_M": 128,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 8,
+        "num_stages": 4
+    },
+    "2048": {
+        "BLOCK_SIZE_M": 128,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "3072": {
+        "BLOCK_SIZE_M": 128,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "4096": {
+        "BLOCK_SIZE_M": 128,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    }
+}
diff --git a/python/sglang/srt/layers/fused_moe_triton/configs/E=64,N=640,device_name=NVIDIA_A800-SXM4-80GB.json b/python/sglang/srt/layers/fused_moe_triton/configs/E=64,N=640,device_name=NVIDIA_A800-SXM4-80GB.json
new file mode 100644
index 00000000000..8a18afe7d6d
--- /dev/null
+++ b/python/sglang/srt/layers/fused_moe_triton/configs/E=64,N=640,device_name=NVIDIA_A800-SXM4-80GB.json
@@ -0,0 +1,146 @@
+{
+    "1": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "2": {
+        "BLOCK_SIZE_M": 32,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 5
+    },
+    "4": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 32,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "8": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 8,
+        "num_stages": 5
+    },
+    "16": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 2
+    },
+    "24": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 5
+    },
+    "32": {
+        "BLOCK_SIZE_M": 32,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "48": {
+        "BLOCK_SIZE_M": 32,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "64": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "96": {
+        "BLOCK_SIZE_M": 32,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "128": {
+        "BLOCK_SIZE_M": 32,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "256": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 8,
+        "num_stages": 3
+    },
+    "512": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 8,
+        "num_stages": 3
+    },
+    "1024": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "1536": {
+        "BLOCK_SIZE_M": 128,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 8,
+        "num_stages": 3
+    },
+    "2048": {
+        "BLOCK_SIZE_M": 128,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "3072": {
+        "BLOCK_SIZE_M": 128,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "4096": {
+        "BLOCK_SIZE_M": 128,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    }
+}
diff --git a/python/sglang/srt/layers/quantization/__init__.py b/python/sglang/srt/layers/quantization/__init__.py
index 3e2078c4a4d..066ff153efe 100644
--- a/python/sglang/srt/layers/quantization/__init__.py
+++ b/python/sglang/srt/layers/quantization/__init__.py
@@ -22,7 +22,7 @@
 from vllm.model_executor.layers.quantization.tpu_int8 import Int8TpuConfig
 
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
-from sglang.srt.layers.quantization.fp8 import Fp8Config, Fp8MoEMethod
+from vllm.model_executor.layers.quantization.fp8 import Fp8Config, Fp8MoEMethod
 
 QUANTIZATION_METHODS: Dict[str, Type[QuantizationConfig]] = {
     "aqlm": AQLMConfig,
diff --git a/python/sglang/srt/layers/quantization/fp8.py b/python/sglang/srt/layers/quantization/fp8.py
index acdce0b8cbd..c2c35a42b25 100644
--- a/python/sglang/srt/layers/quantization/fp8.py
+++ b/python/sglang/srt/layers/quantization/fp8.py
@@ -1,7 +1,7 @@
 # Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/model_executor/layers/quantization/fp8.py
 
 import logging
-from typing import Any, Callable, Dict, List, Optional
+from typing import Any, Callable, Dict, List, Optional, TYPE_CHECKING
 
 import torch
 from torch.nn import Module
diff --git a/srt/layers/base.py b/srt/layers/base.py
new file mode 100644
index 00000000000..0519ecba6ea
--- /dev/null
+++ b/srt/layers/base.py
@@ -0,0 +1 @@
+ 
\ No newline at end of file
diff --git a/srt/layers/fused_moe_triton/__init__.py b/srt/layers/fused_moe_triton/__init__.py
new file mode 100644
index 00000000000..0519ecba6ea
--- /dev/null
+++ b/srt/layers/fused_moe_triton/__init__.py
@@ -0,0 +1 @@
+ 
\ No newline at end of file
diff --git a/srt/layers/quantization/fp8.py b/srt/layers/quantization/fp8.py
new file mode 100644
index 00000000000..0519ecba6ea
--- /dev/null
+++ b/srt/layers/quantization/fp8.py
@@ -0,0 +1 @@
+ 
\ No newline at end of file
diff --git a/test/srt/run_suite.py b/test/srt/run_suite.py
index 5035810f86a..abc5015da4a 100644
--- a/test/srt/run_suite.py
+++ b/test/srt/run_suite.py
@@ -17,6 +17,7 @@
         "test_eval_accuracy_mini.py",
         "test_get_weights_by_name.py",
         "test_gguf.py",
+        "test_fused_moe.py",
         "test_input_embeddings.py",
         "test_json_constrained.py",
         "test_large_max_new_tokens.py",
diff --git a/test/srt/test_fused_moe.py b/test/srt/test_fused_moe.py
new file mode 100644
index 00000000000..c5ef1c584cc
--- /dev/null
+++ b/test/srt/test_fused_moe.py
@@ -0,0 +1,97 @@
+import unittest
+import torch
+from sglang.srt.layers.activation import SiluAndMul
+from sglang.srt.layers.fused_moe_triton.fused_moe import fused_moe
+from vllm.model_executor.layers.fused_moe import fused_moe as fused_moe_vllm
+
+class TestFusedMOE(unittest.TestCase):
+    NUM_EXPERTS = [8, 64]
+    TOP_KS = [2, 6]
+
+    def torch_naive_moe(self, a, w1, w2, score, topk):
+        B, D = a.shape
+        a = a.view(B, -1, D).repeat(1, topk, 1).reshape(-1, D)
+        out = torch.zeros(B * topk, w2.shape[1], dtype=a.dtype, device=a.device)
+        score = torch.softmax(score, dim=-1, dtype=torch.float32)
+        topk_weight, topk_ids = torch.topk(score, topk)
+        topk_weight = topk_weight.view(-1)
+        topk_ids = topk_ids.view(-1)
+        for i in range(w1.shape[0]):
+            mask = topk_ids == i
+            if mask.sum():
+                out[mask] = SiluAndMul()(
+                    a[mask] @ w1[i].transpose(0, 1)) @ w2[i].transpose(0, 1)
+        return (out.view(B, -1, w2.shape[1]) *
+                topk_weight.view(B, -1, 1).to(out.dtype)).sum(dim=1)
+
+    def _test_case(self, m, n, k, e, topk, dtype, use_fp8_w8a8=False):
+        if use_fp8_w8a8:
+            # AssertionError: fp8e4nv data type is not supported on CUDA arch < 89
+            capability = torch.cuda.get_device_capability()
+            if not (capability[0] >= 9 or capability == (8, 9)):
+                return
+
+            a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
+            w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10
+            w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10
+            w1 = w1.to(torch.float8_e4m3fn)
+            w2 = w2.to(torch.float8_e4m3fn)
+            score = torch.randn((m, e), device="cuda", dtype=dtype)
+
+            w1_scale = torch.randn(e, dtype=torch.float32, device="cuda")
+            w2_scale = torch.randn(e, dtype=torch.float32, device="cuda")
+            a1_scale = torch.randn(1, dtype=torch.float32, device="cuda")
+            a2_scale = torch.randn(1, dtype=torch.float32, device="cuda")
+
+            sglang_output = fused_moe(
+                a, w1, w2, score, topk,
+                renormalize=False,
+                use_fp8_w8a8=True,
+                w1_scale=w1_scale,
+                w2_scale=w2_scale,
+                a1_scale=a1_scale,
+                a2_scale=a2_scale
+            )
+
+            vllm_output = fused_moe_vllm(
+                a, w1, w2, score, topk,
+                renormalize=False,
+                use_fp8_w8a8=True,
+                w1_scale=w1_scale,
+                w2_scale=w2_scale,
+                a1_scale=a1_scale,
+                a2_scale=a2_scale
+            )
+
+            torch.testing.assert_close(sglang_output, vllm_output, atol=2e-2, rtol=0)
+
+        else:
+            a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
+            w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10
+            w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10
+            score = torch.randn((m, e), device="cuda", dtype=dtype)
+
+            triton_output = fused_moe(a, w1, w2, score, topk, renormalize=False)
+            torch_output = self.torch_naive_moe(a, w1, w2, score, topk)
+            torch.testing.assert_close(triton_output, torch_output, atol=2e-2, rtol=0)
+
+    def test_various_configurations(self):
+        m_values = [1, 33, 64, 222, 1024 * 128]
+        n_values = [128, 1024, 2048]
+        k_values = [128, 511, 1024]
+        dtypes = [torch.float16, torch.bfloat16]
+        fp8_modes = [False, True]
+
+        for m in m_values:
+            for n in n_values:
+                for k in k_values:
+                    for e in self.NUM_EXPERTS:
+                        for topk in self.TOP_KS:
+                            for dtype in dtypes:
+                                for use_fp8_w8a8 in fp8_modes:
+                                    with self.subTest(m=m, n=n, k=k, e=e, topk=topk, dtype=dtype, fp8=use_fp8_w8a8):
+                                        self._test_case(m, n, k, e, topk, dtype, use_fp8_w8a8=use_fp8_w8a8)
+
+if __name__ == "__main__":
+    unittest.main()
+

From 96028cafe02b7b62161f6d9c1f0bbe6b13cb7c8d Mon Sep 17 00:00:00 2001
From: BBuf <1182563586@qq.com>
Date: Mon, 9 Dec 2024 11:16:35 +0800
Subject: [PATCH 2/7] refine

---
 srt/layers/base.py | 1 -
 srt/layers/fused_moe_triton/__init__.py | 1 -
 srt/layers/quantization/fp8.py | 1 -
 3 files changed, 3 deletions(-)
 delete mode 100644 srt/layers/base.py
 delete mode 100644 srt/layers/fused_moe_triton/__init__.py
 delete mode 100644 srt/layers/quantization/fp8.py

diff --git a/srt/layers/base.py b/srt/layers/base.py
deleted file mode 100644
index 0519ecba6ea..00000000000
--- a/srt/layers/base.py
+++ /dev/null
@@ -1 +0,0 @@
- 
\ No newline at end of file
diff --git a/srt/layers/fused_moe_triton/__init__.py b/srt/layers/fused_moe_triton/__init__.py
deleted file mode 100644
index 0519ecba6ea..00000000000
--- a/srt/layers/fused_moe_triton/__init__.py
+++ /dev/null
@@ -1 +0,0 @@
- 
\ No newline at end of file
diff --git a/srt/layers/quantization/fp8.py b/srt/layers/quantization/fp8.py
deleted file mode 100644
index 0519ecba6ea..00000000000
--- a/srt/layers/quantization/fp8.py
+++ /dev/null
@@ -1 +0,0 @@
- 
\ No newline at end of file

From 37cae671ac913bac23d353039f4ad2fda458d9d3 Mon Sep 17 00:00:00 2001
From: BBuf <1182563586@qq.com>
Date: Mon, 9 Dec 2024 11:18:18 +0800
Subject: [PATCH 3/7] refine

---
 python/sglang/srt/layers/quantization/__init__.py | 2 +-
 python/sglang/srt/layers/quantization/fp8.py | 3 +--
 2 files changed, 2 insertions(+), 3 deletions(-)

diff --git a/python/sglang/srt/layers/quantization/__init__.py b/python/sglang/srt/layers/quantization/__init__.py
index 066ff153efe..3e2078c4a4d 100644
--- a/python/sglang/srt/layers/quantization/__init__.py
+++ b/python/sglang/srt/layers/quantization/__init__.py
@@ -22,7 +22,7 @@
 from vllm.model_executor.layers.quantization.tpu_int8 import Int8TpuConfig
 
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
-from vllm.model_executor.layers.quantization.fp8 import Fp8Config, Fp8MoEMethod
+from sglang.srt.layers.quantization.fp8 import Fp8Config, Fp8MoEMethod
 
 QUANTIZATION_METHODS: Dict[str, Type[QuantizationConfig]] = {
     "aqlm": AQLMConfig,
diff --git a/python/sglang/srt/layers/quantization/fp8.py b/python/sglang/srt/layers/quantization/fp8.py
index c2c35a42b25..17266d7f2aa 100644
--- a/python/sglang/srt/layers/quantization/fp8.py
+++ b/python/sglang/srt/layers/quantization/fp8.py
@@ -1,8 +1,7 @@
 # Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/model_executor/layers/quantization/fp8.py
 
 import logging
-from typing import Any, Callable, Dict, List, Optional, TYPE_CHECKING
-
+from typing import Any, Callable, Dict, List, Optional
 import torch
 from torch.nn import Module
 from torch.nn.parameter import Parameter

From defd24f5a9a2bbf1b873c5aa35ec839a676884cd Mon Sep 17 00:00:00 2001
From: BBuf <1182563586@qq.com>
Date: Mon, 9 Dec 2024 11:19:20 +0800
Subject: [PATCH 4/7] refine

---
 python/sglang/srt/layers/quantization/fp8.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/python/sglang/srt/layers/quantization/fp8.py b/python/sglang/srt/layers/quantization/fp8.py
index 17266d7f2aa..acdce0b8cbd 100644
--- a/python/sglang/srt/layers/quantization/fp8.py
+++ b/python/sglang/srt/layers/quantization/fp8.py
@@ -2,6 +2,7 @@
 
 import logging
 from typing import Any, Callable, Dict, List, Optional
+
 import torch
 from torch.nn import Module
 from torch.nn.parameter import Parameter

From 89174c8f3482fcb002b97f17571ce6d6aefbf666 Mon Sep 17 00:00:00 2001
From: Xiaoyu Zhang <35585791+BBuf@users.noreply.github.com>
Date: Mon, 9 Dec 2024 12:17:53 +0800
Subject: [PATCH 5/7] lint

---
 test/srt/test_fused_moe.py | 59 ++++++++++++++++++++++++++++----------
 1 file changed, 44 insertions(+), 15 deletions(-)

diff --git a/test/srt/test_fused_moe.py b/test/srt/test_fused_moe.py
index c5ef1c584cc..7b50c551a82 100644
--- a/test/srt/test_fused_moe.py
+++ b/test/srt/test_fused_moe.py
@@ -1,13 +1,16 @@
 import unittest
+
 import torch
+from vllm.model_executor.layers.fused_moe import fused_moe as fused_moe_vllm
+
 from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.fused_moe_triton.fused_moe import fused_moe
-from vllm.model_executor.layers.fused_moe import fused_moe as fused_moe_vllm
+
 
 class TestFusedMOE(unittest.TestCase):
     NUM_EXPERTS = [8, 64]
     TOP_KS = [2, 6]
-
+
     def torch_naive_moe(self, a, w1, w2, score, topk):
         B, D = a.shape
         a = a.view(B, -1, D).repeat(1, topk, 1).reshape(-1, D)
@@ -19,10 +22,12 @@ def torch_naive_moe(self, a, w1, w2, score, topk):
         for i in range(w1.shape[0]):
             mask = topk_ids == i
             if mask.sum():
-                out[mask] = SiluAndMul()(
-                    a[mask] @ w1[i].transpose(0, 1)) @ w2[i].transpose(0, 1)
-        return (out.view(B, -1, w2.shape[1]) *
-                topk_weight.view(B, -1, 1).to(out.dtype)).sum(dim=1)
+                out[mask] = SiluAndMul()(a[mask] @ w1[i].transpose(0, 1)) @ w2[
+                    i
+                ].transpose(0, 1)
+        return (
+            out.view(B, -1, w2.shape[1]) * topk_weight.view(B, -1, 1).to(out.dtype)
+        ).sum(dim=1)
 
     def _test_case(self, m, n, k, e, topk, dtype, use_fp8_w8a8=False):
         if use_fp8_w8a8:
@@ -30,7 +35,7 @@ def _test_case(self, m, n, k, e, topk, dtype, use_fp8_w8a8=False):
             capability = torch.cuda.get_device_capability()
             if not (capability[0] >= 9 or capability == (8, 9)):
                 return
-
+
             a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
             w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10
             w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10
@@ -44,23 +49,31 @@ def _test_case(self, m, n, k, e, topk, dtype, use_fp8_w8a8=False):
             a2_scale = torch.randn(1, dtype=torch.float32, device="cuda")
 
             sglang_output = fused_moe(
-                a, w1, w2, score, topk,
+                a,
+                w1,
+                w2,
+                score,
+                topk,
                 renormalize=False,
                 use_fp8_w8a8=True,
                 w1_scale=w1_scale,
                 w2_scale=w2_scale,
                 a1_scale=a1_scale,
-                a2_scale=a2_scale
+                a2_scale=a2_scale,
             )
 
             vllm_output = fused_moe_vllm(
-                a, w1, w2, score, topk,
+                a,
+                w1,
+                w2,
+                score,
+                topk,
                 renormalize=False,
                 use_fp8_w8a8=True,
                 w1_scale=w1_scale,
                 w2_scale=w2_scale,
                 a1_scale=a1_scale,
-                a2_scale=a2_scale
+                a2_scale=a2_scale,
             )
 
             torch.testing.assert_close(sglang_output, vllm_output, atol=2e-2, rtol=0)
@@ -70,7 +83,7 @@ def _test_case(self, m, n, k, e, topk, dtype, use_fp8_w8a8=False):
             w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10
             w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10
             score = torch.randn((m, e), device="cuda", dtype=dtype)
-
+
             triton_output = fused_moe(a, w1, w2, score, topk, renormalize=False)
             torch_output = self.torch_naive_moe(a, w1, w2, score, topk)
             torch.testing.assert_close(triton_output, torch_output, atol=2e-2, rtol=0)
@@ -89,9 +102,25 @@ def test_various_configurations(self):
                     for topk in self.TOP_KS:
                         for dtype in dtypes:
                             for use_fp8_w8a8 in fp8_modes:
-                                with self.subTest(m=m, n=n, k=k, e=e, topk=topk, dtype=dtype, fp8=use_fp8_w8a8):
-                                    self._test_case(m, n, k, e, topk, dtype, use_fp8_w8a8=use_fp8_w8a8)
+                                with self.subTest(
+                                    m=m,
+                                    n=n,
+                                    k=k,
+                                    e=e,
+                                    topk=topk,
+                                    dtype=dtype,
+                                    fp8=use_fp8_w8a8,
+                                ):
+                                    self._test_case(
+                                        m,
+                                        n,
+                                        k,
+                                        e,
+                                        topk,
+                                        dtype,
+                                        use_fp8_w8a8=use_fp8_w8a8,
+                                    )
+
 
 if __name__ == "__main__":
     unittest.main()
-

From bbf8de56bd47cb30b928d9727e954a804b0515ee Mon Sep 17 00:00:00 2001
From: Xiaoyu Zhang <35585791+BBuf@users.noreply.github.com>
Date: Mon, 9 Dec 2024 12:19:45 +0800
Subject: [PATCH 6/7] lint

---
 benchmark/kernels/fused_moe_triton/README.md | 1 -
 1 file changed, 1 deletion(-)

diff --git a/benchmark/kernels/fused_moe_triton/README.md b/benchmark/kernels/fused_moe_triton/README.md
index b2b700dfaf2..2a3e37f6874 100644
--- a/benchmark/kernels/fused_moe_triton/README.md
+++ b/benchmark/kernels/fused_moe_triton/README.md
@@ -47,4 +47,3 @@ The benchmark results will be saved as plots and data files in the specified out
 - `benchmark_torch_compile_fused_moe.py`: A tool for benchmarking the performance of the fused MoE kernel with `torch.compile` and original fused MoE kernel.
 
 Usage is the same as `benchmark_vllm_vs_sglang_fused_moe_triton.py`.
-

From cd361a6b9de9f41a07ee6c682c7dda9461c9b8db Mon Sep 17 00:00:00 2001
From: Lianmin Zheng
Date: Sun, 8 Dec 2024 22:45:53 -0800
Subject: [PATCH 7/7] Update test/srt/run_suite.py

---
 test/srt/run_suite.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/test/srt/run_suite.py b/test/srt/run_suite.py
index abc5015da4a..cb6a60612dd 100644
--- a/test/srt/run_suite.py
+++ b/test/srt/run_suite.py
@@ -15,9 +15,9 @@
         "test_double_sparsity.py",
         "test_embedding_openai_server.py",
         "test_eval_accuracy_mini.py",
+        "test_fused_moe.py",
         "test_get_weights_by_name.py",
         "test_gguf.py",
-        "test_fused_moe.py",
         "test_input_embeddings.py",
         "test_json_constrained.py",
         "test_large_max_new_tokens.py",
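The new test is registered with `test/srt/run_suite.py` by patches 1 and 7, but it can also be exercised on its own with the standard `unittest` runner. Below is a minimal sketch (not part of the patches above), assuming a CUDA-capable GPU, an environment where both sglang and vllm are importable as required by the test's imports, and `test/srt/` as the working directory:

# Minimal sketch: run the fused MoE comparison test directly with unittest.
# Assumes a CUDA GPU, sglang and vllm installed, and execution from the
# repository's test/srt/ directory so that test_fused_moe.py is importable.
import unittest

from test_fused_moe import TestFusedMOE

if __name__ == "__main__":
    suite = unittest.TestLoader().loadTestsFromTestCase(TestFusedMOE)
    unittest.TextTestRunner(verbosity=2).run(suite)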