[benchmark] Add fused_moe_triton benchmark and tuning tools (#2225)
Co-authored-by: Lianmin Zheng <[email protected]>
Co-authored-by: HAI <[email protected]>
1 parent 419a57e · commit 262e370
Showing 4 changed files with 732 additions and 3 deletions.
@@ -0,0 +1,45 @@
## Benchmark Kernels

This directory contains benchmarking tools for MoE (Mixture of Experts) kernels.

### Tuning Tool

- `tuning_fused_moe_triton.py`: A tool for tuning the `fused_moe_triton` kernel. Adapted from [vllm's benchmark_moe.py](https://github.com/vllm-project/vllm/blob/main/benchmarks/kernels/benchmark_moe.py), with added support for various model architectures.

Example usage:
```bash
# Tune Qwen2-57B with FP8 and TP=4
python benchmark/kernels/fused_moe_triton/tuning_fused_moe_triton.py \
    --model Qwen/Qwen2-57B-A14B-Instruct-FP8 \
    --tp-size 4 \
    --dtype fp8_w8a8 \
    --tune

# Tune Mixtral-8x7B with default settings
python benchmark/kernels/fused_moe_triton/tuning_fused_moe_triton.py \
    --model mistralai/Mixtral-8x7B-Instruct-v0.1 \
    --tune
```

After tuning, a configuration file (e.g., `E=64,N=640,device_name=NVIDIA_GeForce_RTX_4090,dtype=fp8_w8a8.json`) will be generated in the current directory. You can move this file to `sglang/srt/layers/fused_moe_triton/configs/` to use it in `sglang`.

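For example, copying a tuned config into a local `sglang` checkout might look like this (the JSON file name below is just the example above; yours will depend on the model, GPU, and dtype you tuned for, and the destination path depends on your checkout layout):

```bash
# Illustrative only: adjust the file name and destination path to your setup
cp "E=64,N=640,device_name=NVIDIA_GeForce_RTX_4090,dtype=fp8_w8a8.json" \
   sglang/srt/layers/fused_moe_triton/configs/
```
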
### Performance Comparison Tool

- `benchmark_vllm_vs_sglang_fused_moe_triton.py`: A tool for comparing the performance of fused MoE kernels between the vllm and sglang implementations. Supports various model architectures and data types.

Example usage:
```bash
# Compare with default settings (Mixtral model)
python benchmark/kernels/fused_moe_triton/benchmark_vllm_vs_sglang_fused_moe_triton.py

# Compare with FP8 mode for Qwen2-57B
python benchmark/kernels/fused_moe_triton/benchmark_vllm_vs_sglang_fused_moe_triton.py \
    --model Qwen/Qwen2-57B-A14B-Instruct-FP8 \
    --use-fp8

# Compare with custom TP size
python benchmark/kernels/fused_moe_triton/benchmark_vllm_vs_sglang_fused_moe_triton.py \
    --tp-size 4
```

The benchmark results will be saved as plots and data files in the specified output directory (default: `./configs/benchmark_ops/vllm_sglang_fused_moe/`).
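
For example, after a run you can inspect the saved artifacts (the exact file names are generated by Triton's `perf_report` utility and may vary across Triton versions):

```bash
# Illustrative only: list the plots and data files written by the benchmark
ls ./configs/benchmark_ops/vllm_sglang_fused_moe/
```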
237 changes: 237 additions & 0 deletions
benchmark/kernels/fused_moe_triton/benchmark_vllm_vs_sglang_fused_moe_triton.py
@@ -0,0 +1,237 @@
import torch
import triton
from transformers import AutoConfig
from vllm.model_executor.layers.fused_moe.fused_moe import fused_moe as fused_moe_vllm
from vllm.utils import FlexibleArgumentParser

from sglang.srt.layers.fused_moe_triton.fused_moe import fused_moe as fused_moe_sglang


def get_model_config(model_name: str, tp_size: int):
    """Get model configuration parameters"""
    config = AutoConfig.from_pretrained(model_name)

    if config.architectures[0] == "DbrxForCausalLM":
        E = config.ffn_config.moe_num_experts
        topk = config.ffn_config.moe_top_k
        intermediate_size = config.ffn_config.ffn_hidden_size
        shard_intermediate_size = 2 * intermediate_size // tp_size
    elif config.architectures[0] == "JambaForCausalLM":
        E = config.num_experts
        topk = config.num_experts_per_tok
        intermediate_size = config.intermediate_size
        shard_intermediate_size = 2 * intermediate_size // tp_size
    elif config.architectures[0] == "Qwen2MoeForCausalLM":
        E = config.num_experts
        topk = config.num_experts_per_tok
        intermediate_size = config.moe_intermediate_size
        shard_intermediate_size = 2 * intermediate_size // tp_size
    else:
        # Default: Mixtral, Grok1, etc.
        E = config.num_local_experts
        topk = config.num_experts_per_tok
        intermediate_size = config.intermediate_size
        shard_intermediate_size = 2 * intermediate_size // tp_size

    return {
        "num_experts": E,
        "topk": topk,
        "hidden_size": config.hidden_size,
        "shard_intermediate_size": shard_intermediate_size,
        "dtype": config.torch_dtype,
    }

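# Example output of get_model_config (illustrative values): for
# "mistralai/Mixtral-8x7B-Instruct-v0.1" with tp_size=2 this is roughly
#     {"num_experts": 8, "topk": 2, "hidden_size": 4096,
#      "shard_intermediate_size": 14336, "dtype": torch.bfloat16}
# since Mixtral routes each token to 2 of 8 local experts and has
# intermediate_size 14336, giving 2 * 14336 // 2 per tensor-parallel shard.
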
def fused_moe_vllm_api(
    x,
    w1,
    w2,
    input_gating,
    topk,
    use_fp8_w8a8=False,
    w1_scale=None,
    w2_scale=None,
    a1_scale=None,
    a2_scale=None,
):
    return fused_moe_vllm(
        x,
        w1,
        w2,
        input_gating,
        topk,
        renormalize=True,
        inplace=True,
        use_fp8_w8a8=use_fp8_w8a8,
        w1_scale=w1_scale,
        w2_scale=w2_scale,
        a1_scale=a1_scale,
        a2_scale=a2_scale,
    )


def fused_moe_sglang_api(
    x,
    w1,
    w2,
    input_gating,
    topk,
    use_fp8_w8a8=False,
    w1_scale=None,
    w2_scale=None,
    a1_scale=None,
    a2_scale=None,
):
    return fused_moe_sglang(
        x,
        w1,
        w2,
        input_gating,
        topk,
        renormalize=True,
        inplace=True,
        use_fp8_w8a8=use_fp8_w8a8,
        w1_scale=w1_scale,
        w2_scale=w2_scale,
        a1_scale=a1_scale,
        a2_scale=a2_scale,
    )


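# The report below sweeps batch sizes 1..512 and, for each size, times the vllm
# and sglang fused MoE Triton kernels; benchmark.run() in main() prints the
# per-batch-size latencies and saves the plot/data under the chosen save_path.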
@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=["batch_size"],
        x_vals=list(range(1, 513)),
        line_arg="provider",
        line_vals=[
            "vllm_fused_moe_triton",
            "sglang_fused_moe_triton",
        ],
        line_names=[
            "vllm_fused_moe_triton",
            "sglang_fused_moe_triton",
        ],
        styles=[
            ("blue", "-"),
            ("green", "-"),
        ],
        ylabel="Time (ms)",
        plot_name="fused-moe-performance",
        args={},
    )
)
def benchmark(batch_size, provider, model_config, use_fp8=False):
    print(f"benchmark for batch_size={batch_size}")
    torch.set_default_device("cuda")
    torch.cuda.manual_seed_all(0)

    num_tokens = batch_size
    num_experts = model_config["num_experts"]
    hidden_size = model_config["hidden_size"]
    shard_intermediate_size = model_config["shard_intermediate_size"]
    topk = model_config["topk"]
    dtype = model_config["dtype"]

    x = torch.randn(num_tokens, hidden_size, dtype=dtype)

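    # FP8 path: weights are drawn in the model dtype and then cast to
    # float8_e4m3fn; the weight/activation scales are random placeholders, so
    # the outputs are not numerically meaningful here, only the kernel timing is.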
    if use_fp8:
        init_dtype = dtype
        w1 = torch.randn(
            num_experts, shard_intermediate_size, hidden_size, dtype=init_dtype
        )
        w2 = torch.randn(
            num_experts, hidden_size, shard_intermediate_size // 2, dtype=init_dtype
        )
        w1 = w1.to(torch.float8_e4m3fn)
        w2 = w2.to(torch.float8_e4m3fn)
        w1_scale = torch.randn(num_experts, dtype=torch.float32)
        w2_scale = torch.randn(num_experts, dtype=torch.float32)
        a1_scale = torch.randn(1, dtype=torch.float32)
        a2_scale = torch.randn(1, dtype=torch.float32)
    else:
        w1 = torch.randn(num_experts, shard_intermediate_size, hidden_size, dtype=dtype)
        w2 = torch.randn(
            num_experts, hidden_size, shard_intermediate_size // 2, dtype=dtype
        )
        w1_scale = w2_scale = a1_scale = a2_scale = None

    input_gating = torch.randn(num_tokens, num_experts, dtype=torch.float32)

    # Warmup
    api_func = (
        fused_moe_vllm_api
        if provider == "vllm_fused_moe_triton"
        else fused_moe_sglang_api
    )
    for _ in range(10):
        y = api_func(
            x,
            w1,
            w2,
            input_gating,
            topk,
            use_fp8_w8a8=use_fp8,
            w1_scale=w1_scale,
            w2_scale=w2_scale,
            a1_scale=a1_scale,
            a2_scale=a2_scale,
        )
    torch.cuda.synchronize()

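    # do_bench returns timings at the requested quantiles, in order: the median
    # (0.5), then the 0.2 and 0.8 quantiles, so min_ms/max_ms below are p20/p80
    # of the measured runs rather than absolute extremes.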
    quantiles = [0.5, 0.2, 0.8]
    ms, min_ms, max_ms = triton.testing.do_bench(
        lambda: api_func(
            x,
            w1,
            w2,
            input_gating,
            topk,
            use_fp8_w8a8=use_fp8,
            w1_scale=w1_scale,
            w2_scale=w2_scale,
            a1_scale=a1_scale,
            a2_scale=a2_scale,
        )[0],
        quantiles=quantiles,
    )
    return ms, min_ms, max_ms


def main():
    parser = FlexibleArgumentParser()
    parser.add_argument(
        "--model", type=str, default="mistralai/Mixtral-8x7B-Instruct-v0.1"
    )
    parser.add_argument("--tp-size", type=int, default=2)
    parser.add_argument("--use-fp8", action="store_true")
    parser.add_argument(
        "--save-path",
        type=str,
        default="./configs/benchmark_ops/vllm_sglang_fused_moe/",
    )
    args = parser.parse_args()

    model_config = get_model_config(args.model, args.tp_size)
    benchmark.run(
        show_plots=True,
        print_data=True,
        save_path=args.save_path,
        model_config=model_config,
        use_fp8=args.use_fp8,
    )


if __name__ == "__main__":
    main()