diff --git a/benchmark/kernels/fused_moe_triton/README.md b/benchmark/kernels/fused_moe_triton/README.md
new file mode 100644
index 00000000000..ba29ede5099
--- /dev/null
+++ b/benchmark/kernels/fused_moe_triton/README.md
@@ -0,0 +1,45 @@
+## Benchmark Kernels
+
+This directory contains benchmarking tools for MoE (Mixture of Experts) kernels.
+
+### Tuning Tool
+
+- `tuning_fused_moe_triton.py`: A tool for tuning the `fused_moe_triton` kernel. Adapted from [vLLM's benchmark_moe.py](https://github.com/vllm-project/vllm/blob/main/benchmarks/kernels/benchmark_moe.py), with added support for various model architectures.
+
+Example usage:
+```bash
+# Tune Qwen2-57B with FP8 and TP=4
+python benchmark/kernels/fused_moe_triton/tuning_fused_moe_triton.py \
+    --model Qwen/Qwen2-57B-A14B-Instruct-FP8 \
+    --tp-size 4 \
+    --dtype fp8_w8a8 \
+    --tune
+
+# Tune Mixtral-8x7B with default settings
+python benchmark/kernels/fused_moe_triton/tuning_fused_moe_triton.py \
+    --model mistralai/Mixtral-8x7B-Instruct-v0.1 \
+    --tune
+```
+
+After tuning, a configuration file (e.g., `E=64,N=640,device_name=NVIDIA_GeForce_RTX_4090,dtype=fp8_w8a8.json`) will be generated in the current directory. You can move this file to `sglang/srt/layers/fused_moe_triton/configs/` to use it in `sglang`.
+
+### Performance Comparison Tool
+
+- `benchmark_vllm_vs_sglang_fused_moe_triton.py`: A tool for comparing the performance of the fused MoE Triton kernels in the vLLM and SGLang implementations. Supports various model architectures and data types.
+
+Example usage:
+```bash
+# Compare with default settings (Mixtral model)
+python benchmark/kernels/fused_moe_triton/benchmark_vllm_vs_sglang_fused_moe_triton.py
+
+# Compare with FP8 mode for Qwen2-57B
+python benchmark/kernels/fused_moe_triton/benchmark_vllm_vs_sglang_fused_moe_triton.py \
+    --model Qwen/Qwen2-57B-A14B-Instruct-FP8 \
+    --use-fp8
+
+# Compare with custom TP size
+python benchmark/kernels/fused_moe_triton/benchmark_vllm_vs_sglang_fused_moe_triton.py \
+    --tp-size 4
+```
+
+The benchmark results will be saved as plots and data files in the directory given by `--save-path` (default: `./configs/benchmark_ops/vllm_sglang_fused_moe/`).
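For reference, the generated file is plain JSON that maps each tuned batch size to a set of Triton launch parameters (the fields of `BenchmarkConfig` in `tuning_fused_moe_triton.py`). The snippet below is a minimal sketch of inspecting one such file; the filename is the example from above, and the concrete values in the file depend on your GPU.

```python
import json

# Filename taken from the README example; adjust to whatever the tuner produced.
path = "E=64,N=640,device_name=NVIDIA_GeForce_RTX_4090,dtype=fp8_w8a8.json"

with open(path) as f:
    configs = json.load(f)

# Keys are the tuned batch sizes (as strings in JSON); each value holds the kernel
# parameters: BLOCK_SIZE_M, BLOCK_SIZE_N, BLOCK_SIZE_K, GROUP_SIZE_M, num_warps, num_stages.
for batch_size, cfg in sorted(configs.items(), key=lambda kv: int(kv[0])):
    print(batch_size, cfg)
```

These per-batch-size entries are what the fused MoE kernel later matches against the actual number of tokens at runtime.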
diff --git a/benchmark/kernels/fused_moe_triton/benchmark_vllm_vs_sglang_fused_moe_triton.py b/benchmark/kernels/fused_moe_triton/benchmark_vllm_vs_sglang_fused_moe_triton.py new file mode 100644 index 00000000000..33b85f40e3b --- /dev/null +++ b/benchmark/kernels/fused_moe_triton/benchmark_vllm_vs_sglang_fused_moe_triton.py @@ -0,0 +1,237 @@ +import argparse +import numbers +from typing import Optional + +import torch +import triton +from torch.nn import init +from torch.nn.parameter import Parameter +from transformers import AutoConfig +from vllm.model_executor.layers.fused_moe.fused_moe import fused_moe as fused_moe_vllm +from vllm.model_executor.layers.fused_moe.fused_moe import ( + get_moe_configs as get_moe_configs_vllm, +) +from vllm.utils import FlexibleArgumentParser + +from sglang.srt.layers.fused_moe_triton.fused_moe import fused_moe as fused_moe_sglang +from sglang.srt.layers.fused_moe_triton.fused_moe import ( + get_moe_configs as get_moe_configs_sglang, +) + + +def get_model_config(model_name: str, tp_size: int): + """Get model configuration parameters""" + config = AutoConfig.from_pretrained(model_name) + + if config.architectures[0] == "DbrxForCausalLM": + E = config.ffn_config.moe_num_experts + topk = config.ffn_config.moe_top_k + intermediate_size = config.ffn_config.ffn_hidden_size + shard_intermediate_size = 2 * intermediate_size // tp_size + elif config.architectures[0] == "JambaForCausalLM": + E = config.num_experts + topk = config.num_experts_per_tok + intermediate_size = config.intermediate_size + shard_intermediate_size = 2 * intermediate_size // tp_size + elif config.architectures[0] == "Qwen2MoeForCausalLM": + E = config.num_experts + topk = config.num_experts_per_tok + intermediate_size = config.moe_intermediate_size + shard_intermediate_size = 2 * intermediate_size // tp_size + else: + # Default: Mixtral, Grok1, etc. 
+ E = config.num_local_experts + topk = config.num_experts_per_tok + intermediate_size = config.intermediate_size + shard_intermediate_size = 2 * intermediate_size // tp_size + + return { + "num_experts": E, + "topk": topk, + "hidden_size": config.hidden_size, + "shard_intermediate_size": shard_intermediate_size, + "dtype": config.torch_dtype, + } + + +def fused_moe_vllm_api( + x, + w1, + w2, + input_gating, + topk, + use_fp8_w8a8=False, + w1_scale=None, + w2_scale=None, + a1_scale=None, + a2_scale=None, +): + return fused_moe_vllm( + x, + w1, + w2, + input_gating, + topk, + renormalize=True, + inplace=True, + use_fp8_w8a8=use_fp8_w8a8, + w1_scale=w1_scale, + w2_scale=w2_scale, + a1_scale=a1_scale, + a2_scale=a2_scale, + ) + + +def fused_moe_sglang_api( + x, + w1, + w2, + input_gating, + topk, + use_fp8_w8a8=False, + w1_scale=None, + w2_scale=None, + a1_scale=None, + a2_scale=None, +): + return fused_moe_sglang( + x, + w1, + w2, + input_gating, + topk, + renormalize=True, + inplace=True, + use_fp8_w8a8=use_fp8_w8a8, + w1_scale=w1_scale, + w2_scale=w2_scale, + a1_scale=a1_scale, + a2_scale=a2_scale, + ) + + +@triton.testing.perf_report( + triton.testing.Benchmark( + x_names=["batch_size"], + x_vals=list(range(1, 513)), + line_arg="provider", + line_vals=[ + "vllm_fused_moe_triton", + "sglang_fused_moe_triton", + ], + line_names=[ + "vllm_fused_moe_triton", + "sglang_fused_moe_triton", + ], + styles=[ + ("blue", "-"), + ("green", "-"), + ], + ylabel="Time (ms)", + plot_name="fused-moe-performance", + args={}, + ) +) +def benchmark(batch_size, provider, model_config, use_fp8=False): + print(f"benchmark for batch_size={batch_size}") + torch.set_default_device("cuda") + torch.cuda.manual_seed_all(0) + + num_tokens = batch_size + num_experts = model_config["num_experts"] + hidden_size = model_config["hidden_size"] + shard_intermediate_size = model_config["shard_intermediate_size"] + topk = model_config["topk"] + dtype = model_config["dtype"] + + x = torch.randn(num_tokens, hidden_size, dtype=dtype) + + if use_fp8: + init_dtype = dtype + w1 = torch.randn( + num_experts, shard_intermediate_size, hidden_size, dtype=init_dtype + ) + w2 = torch.randn( + num_experts, hidden_size, shard_intermediate_size // 2, dtype=init_dtype + ) + w1 = w1.to(torch.float8_e4m3fn) + w2 = w2.to(torch.float8_e4m3fn) + w1_scale = torch.randn(num_experts, dtype=torch.float32) + w2_scale = torch.randn(num_experts, dtype=torch.float32) + a1_scale = torch.randn(1, dtype=torch.float32) + a2_scale = torch.randn(1, dtype=torch.float32) + else: + w1 = torch.randn(num_experts, shard_intermediate_size, hidden_size, dtype=dtype) + w2 = torch.randn( + num_experts, hidden_size, shard_intermediate_size // 2, dtype=dtype + ) + w1_scale = w2_scale = a1_scale = a2_scale = None + + input_gating = torch.randn(num_tokens, num_experts, dtype=torch.float32) + + # Warmup + api_func = ( + fused_moe_vllm_api + if provider == "vllm_fused_moe_triton" + else fused_moe_sglang_api + ) + for _ in range(10): + y = api_func( + x, + w1, + w2, + input_gating, + topk, + use_fp8_w8a8=use_fp8, + w1_scale=w1_scale, + w2_scale=w2_scale, + a1_scale=a1_scale, + a2_scale=a2_scale, + ) + torch.cuda.synchronize() + + quantiles = [0.5, 0.2, 0.8] + ms, min_ms, max_ms = triton.testing.do_bench( + lambda: api_func( + x, + w1, + w2, + input_gating, + topk, + use_fp8_w8a8=use_fp8, + w1_scale=w1_scale, + w2_scale=w2_scale, + a1_scale=a1_scale, + a2_scale=a2_scale, + )[0], + quantiles=quantiles, + ) + return ms, min_ms, max_ms + + +def main(): + parser = 
FlexibleArgumentParser() + parser.add_argument( + "--model", type=str, default="mistralai/Mixtral-8x7B-Instruct-v0.1" + ) + parser.add_argument("--tp-size", type=int, default=2) + parser.add_argument("--use-fp8", action="store_true") + parser.add_argument( + "--save-path", + type=str, + default="./configs/benchmark_ops/vllm_sglang_fused_moe/", + ) + args = parser.parse_args() + + model_config = get_model_config(args.model, args.tp_size) + benchmark.run( + show_plots=True, + print_data=True, + save_path=args.save_path, + model_config=model_config, + use_fp8=args.use_fp8, + ) + + +if __name__ == "__main__": + main() diff --git a/benchmark/kernels/fused_moe_triton/tuning_fused_moe_triton.py b/benchmark/kernels/fused_moe_triton/tuning_fused_moe_triton.py new file mode 100644 index 00000000000..9b232264ab3 --- /dev/null +++ b/benchmark/kernels/fused_moe_triton/tuning_fused_moe_triton.py @@ -0,0 +1,446 @@ +# Adapted from https://github.com/vllm-project/vllm/blob/main/benchmarks/kernels/benchmark_moe.py +import argparse +import time +from datetime import datetime +from typing import Any, Dict, List, Tuple, TypedDict + +import ray +import torch +import triton +from ray.experimental.tqdm_ray import tqdm +from transformers import AutoConfig +from vllm.platforms import current_platform +from vllm.utils import FlexibleArgumentParser + +from sglang.srt.layers.fused_moe_triton.fused_moe import * + + +class BenchmarkConfig(TypedDict): + BLOCK_SIZE_M: int + BLOCK_SIZE_N: int + BLOCK_SIZE_K: int + GROUP_SIZE_M: int + num_warps: int + num_stages: int + + +def benchmark_config( + config: BenchmarkConfig, + num_tokens: int, + num_experts: int, + shard_intermediate_size: int, + hidden_size: int, + topk: int, + dtype: torch.dtype, + use_fp8_w8a8: bool, + use_int8_w8a16: bool, + num_iters: int = 100, +) -> float: + init_dtype = torch.float16 if use_fp8_w8a8 else dtype + x = torch.randn(num_tokens, hidden_size, dtype=dtype) + if use_int8_w8a16: + w1 = torch.randint( + -127, + 127, + ( + num_experts, + shard_intermediate_size, + hidden_size, + ), + dtype=torch.int8, + ) + w2 = torch.randint( + -127, + 127, + ( + num_experts, + hidden_size, + shard_intermediate_size // 2, + ), + dtype=torch.int8, + ) + else: + w1 = torch.randn( + num_experts, shard_intermediate_size, hidden_size, dtype=init_dtype + ) + w2 = torch.randn( + num_experts, hidden_size, shard_intermediate_size // 2, dtype=init_dtype + ) + gating_output = torch.randn(num_iters, num_tokens, num_experts, dtype=torch.float32) + + w1_scale = None + w2_scale = None + a1_scale = None + a2_scale = None + if use_int8_w8a16: + w1_scale = torch.randn( + (num_experts, 2 * shard_intermediate_size), dtype=torch.float32 + ) + w2_scale = torch.randn((hidden_size, num_experts), dtype=torch.float32) + if use_fp8_w8a8: + w1_scale = torch.randn(num_experts, dtype=torch.float32) + w2_scale = torch.randn(num_experts, dtype=torch.float32) + a1_scale = torch.randn(1, dtype=torch.float32) + a2_scale = torch.randn(1, dtype=torch.float32) + + w1 = w1.to(torch.float8_e4m3fn) + w2 = w2.to(torch.float8_e4m3fn) + + input_gating = torch.empty(num_tokens, num_experts, dtype=torch.float32) + + def prepare(i: int): + input_gating.copy_(gating_output[i]) + + def run(): + from sglang.srt.layers.fused_moe_triton.fused_moe import override_config + + with override_config(config): + fused_moe( + x, + w1, + w2, + input_gating, + topk, + renormalize=True, + inplace=True, + use_fp8_w8a8=use_fp8_w8a8, + use_int8_w8a16=use_int8_w8a16, + w1_scale=w1_scale, + w2_scale=w2_scale, + a1_scale=a1_scale, + 
a2_scale=a2_scale, + ) + + # JIT compilation & warmup + run() + torch.cuda.synchronize() + + # Capture 10 invocations with CUDA graph + graph = torch.cuda.CUDAGraph() + with torch.cuda.graph(graph): + for _ in range(10): + run() + torch.cuda.synchronize() + + # Warmup + for _ in range(5): + graph.replay() + torch.cuda.synchronize() + + start_event = torch.cuda.Event(enable_timing=True) + end_event = torch.cuda.Event(enable_timing=True) + + latencies: List[float] = [] + for i in range(num_iters): + prepare(i) + torch.cuda.synchronize() + + start_event.record() + graph.replay() + end_event.record() + end_event.synchronize() + latencies.append(start_event.elapsed_time(end_event)) + avg = sum(latencies) / (num_iters * 10) * 1000 # us + graph.reset() + return avg + + +def get_configs_compute_bound() -> List[Dict[str, int]]: + # Reduced search space for faster tuning. + # TODO(woosuk): Increase the search space and use a performance model to + # prune the search space. + configs: List[BenchmarkConfig] = [] + for num_stages in [2, 3, 4, 5]: + for block_m in [16, 32, 64, 128, 256]: + for block_k in [64, 128, 256]: + for block_n in [32, 64, 128, 256]: + for num_warps in [4, 8]: + for group_size in [1, 16, 32, 64]: + configs.append( + { + "BLOCK_SIZE_M": block_m, + "BLOCK_SIZE_N": block_n, + "BLOCK_SIZE_K": block_k, + "GROUP_SIZE_M": group_size, + "num_warps": num_warps, + "num_stages": num_stages, + } + ) + return configs + + +@ray.remote(num_gpus=1) +class BenchmarkWorker: + + def __init__(self, seed: int) -> None: + torch.set_default_device("cuda") + current_platform.seed_everything(seed) + self.seed = seed + + def benchmark( + self, + num_tokens: int, + num_experts: int, + shard_intermediate_size: int, + hidden_size: int, + topk: int, + dtype: torch.dtype, + use_fp8_w8a8: bool, + use_int8_w8a16: bool, + ) -> Tuple[Dict[str, int], float]: + current_platform.seed_everything(self.seed) + dtype_str = get_config_dtype_str( + dtype, use_int8_w8a16=use_int8_w8a16, use_fp8_w8a8=use_fp8_w8a8 + ) + # NOTE(woosuk): The current naming convention uses w2.shape[2], which + # is the intermediate size after silu_and_mul. + op_config = get_moe_configs( + num_experts, shard_intermediate_size // 2, dtype_str + ) + if op_config is None: + config = get_default_config( + num_tokens, + num_experts, + shard_intermediate_size, + hidden_size, + topk, + dtype_str, + ) + else: + config = op_config[min(op_config.keys(), key=lambda x: abs(x - num_tokens))] + kernel_time = benchmark_config( + config, + num_tokens, + num_experts, + shard_intermediate_size, + hidden_size, + topk, + dtype, + use_fp8_w8a8, + use_int8_w8a16, + ) + return config, kernel_time + + def tune( + self, + num_tokens: int, + num_experts: int, + shard_intermediate_size: int, + hidden_size: int, + topk: int, + dtype: torch.dtype, + use_fp8_w8a8: bool, + use_int8_w8a16: bool, + search_space: List[Dict[str, int]], + ) -> Dict[str, int]: + best_config = None + best_time = float("inf") + for config in tqdm(search_space): + try: + kernel_time = benchmark_config( + config, + num_tokens, + num_experts, + shard_intermediate_size, + hidden_size, + topk, + dtype, + use_fp8_w8a8, + use_int8_w8a16, + num_iters=10, + ) + except triton.runtime.autotuner.OutOfResources: + # Some configurations may be invalid and fail to compile. 
+ continue + + if kernel_time < best_time: + best_time = kernel_time + best_config = config + now = datetime.now() + print(f"{now.ctime()}] Completed tuning for batch_size={num_tokens}") + assert best_config is not None + return best_config + + +def sort_config(config: BenchmarkConfig) -> BenchmarkConfig: + return { + "BLOCK_SIZE_M": config["BLOCK_SIZE_M"], + "BLOCK_SIZE_N": config["BLOCK_SIZE_N"], + "BLOCK_SIZE_K": config["BLOCK_SIZE_K"], + "GROUP_SIZE_M": config["GROUP_SIZE_M"], + "num_warps": config["num_warps"], + "num_stages": config["num_stages"], + } + + +def save_configs( + configs: Dict[int, BenchmarkConfig], + num_experts: int, + shard_intermediate_size: int, + hidden_size: int, + topk: int, + dtype: torch.dtype, + use_fp8_w8a8: bool, + use_int8_w8a16: bool, +) -> None: + dtype_str = get_config_dtype_str( + dtype, use_int8_w8a16=use_int8_w8a16, use_fp8_w8a8=use_fp8_w8a8 + ) + + # NOTE(woosuk): The current naming convention uses w2.shape[2], which + # is the intermediate size after silu_and_mul. + filename = get_config_file_name( + num_experts, shard_intermediate_size // 2, dtype_str + ) + + print(f"Writing best config to {filename}...") + with open(filename, "w") as f: + json.dump(configs, f, indent=4) + f.write("\n") + + +def main(args: argparse.Namespace): + print(args) + + config = AutoConfig.from_pretrained(args.model) + if config.architectures[0] == "DbrxForCausalLM": + E = config.ffn_config.moe_num_experts + topk = config.ffn_config.moe_top_k + intermediate_size = config.ffn_config.ffn_hidden_size + shard_intermediate_size = 2 * intermediate_size // args.tp_size + elif config.architectures[0] == "JambaForCausalLM": + E = config.num_experts + topk = config.num_experts_per_tok + intermediate_size = config.intermediate_size + shard_intermediate_size = 2 * intermediate_size // args.tp_size + elif config.architectures[0] == "Qwen2MoeForCausalLM": + E = config.num_experts + topk = config.num_experts_per_tok + intermediate_size = config.moe_intermediate_size + shard_intermediate_size = 2 * intermediate_size // args.tp_size + else: + # Default: Mixtral. 
+ E = config.num_local_experts + topk = config.num_experts_per_tok + intermediate_size = config.intermediate_size + shard_intermediate_size = 2 * intermediate_size // args.tp_size + + hidden_size = config.hidden_size + dtype = config.torch_dtype + use_fp8_w8a8 = args.dtype == "fp8_w8a8" + use_int8_w8a16 = args.dtype == "int8_w8a16" + + if args.batch_size is None: + batch_sizes = [ + 1, + 2, + 4, + 8, + 16, + 24, + 32, + 48, + 64, + 96, + 128, + 256, + 512, + 1024, + 1536, + 2048, + 3072, + 4096, + ] + else: + batch_sizes = [args.batch_size] + + ray.init() + num_gpus = int(ray.available_resources()["GPU"]) + workers = [BenchmarkWorker.remote(args.seed) for _ in range(num_gpus)] + + def _distribute(method: str, inputs: List[Any]) -> List[Any]: + outputs = [] + worker_idx = 0 + for input_args in inputs: + worker = workers[worker_idx] + worker_method = getattr(worker, method) + output = worker_method.remote(*input_args) + outputs.append(output) + worker_idx = (worker_idx + 1) % num_gpus + return ray.get(outputs) + + if args.tune: + search_space = get_configs_compute_bound() + print(f"Start tuning over {len(search_space)} configurations...") + + start = time.time() + configs = _distribute( + "tune", + [ + ( + batch_size, + E, + shard_intermediate_size, + hidden_size, + topk, + dtype, + use_fp8_w8a8, + use_int8_w8a16, + search_space, + ) + for batch_size in batch_sizes + ], + ) + best_configs = { + M: sort_config(config) for M, config in zip(batch_sizes, configs) + } + save_configs( + best_configs, + E, + shard_intermediate_size, + hidden_size, + topk, + dtype, + use_fp8_w8a8, + use_int8_w8a16, + ) + end = time.time() + print(f"Tuning took {end - start:.2f} seconds") + else: + outputs = _distribute( + "benchmark", + [ + ( + batch_size, + E, + shard_intermediate_size, + hidden_size, + topk, + dtype, + use_fp8_w8a8, + use_int8_w8a16, + ) + for batch_size in batch_sizes + ], + ) + + for batch_size, (config, kernel_time) in zip(batch_sizes, outputs): + print(f"Batch size: {batch_size}, config: {config}") + print(f"Kernel time: {kernel_time:.2f} us") + + +if __name__ == "__main__": + parser = FlexibleArgumentParser() + parser.add_argument( + "--model", type=str, default="mistralai/Mixtral-8x7B-Instruct-v0.1" + ) + parser.add_argument("--tp-size", "-tp", type=int, default=2) + parser.add_argument( + "--dtype", type=str, choices=["auto", "fp8_w8a8", "int8_w8a16"], default="auto" + ) + parser.add_argument("--seed", type=int, default=0) + parser.add_argument("--batch-size", type=int, required=False) + parser.add_argument("--tune", action="store_true") + args = parser.parse_args() + + main(args) diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index 144ade58ea7..dfa68fd959e 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -169,10 +169,11 @@ def __post_init__(self): gpu_mem = get_amdgpu_memory_capacity() else: gpu_mem = get_nvgpu_memory_capacity() + if gpu_mem < 25000: - self.chunked_prefill_size //= 4 # make it 2048 - self.cuda_graph_max_bs = 4 - logger.info("Automatically adjust --chunked-prefill-size for small GPUs.") + logger.warning( + "Your GPU has less than 25GB memory. You may want to set a smaller --chunked-prefill-size (e.g., 512) to improve performance." + ) # Choose kernel backends if not is_flashinfer_available():
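Returning to the tuning workflow: once a tuned JSON file has been copied into `sglang/srt/layers/fused_moe_triton/configs/`, you can reproduce the lookup the kernel performs with the same `get_moe_configs` helper the scripts above use. The following is a minimal sketch, assuming the `E=64,N=640,dtype=fp8_w8a8` example from the README; `get_moe_configs` returns `None` when no config file matches the current device.

```python
from sglang.srt.layers.fused_moe_triton.fused_moe import get_moe_configs

# Same (E, N, dtype) triple the tuner encodes in the filename; note that N is
# shard_intermediate_size // 2, i.e. the intermediate size after silu_and_mul.
configs = get_moe_configs(64, 640, "fp8_w8a8")

if configs is None:
    print("No tuned config found; the kernel falls back to its default parameters.")
else:
    # Keys are the tuned batch sizes; pick the entry closest to the actual number
    # of tokens, mirroring the lookup in BenchmarkWorker.benchmark above.
    num_tokens = 64
    best = configs[min(configs.keys(), key=lambda m: abs(m - num_tokens))]
    print(best)
```

This mirrors the closest-batch-size lookup in `BenchmarkWorker.benchmark`, so it is a quick way to confirm that sglang will actually use the tuned parameters rather than its defaults.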