From 16a4016d0f295432a44f5b0a546c9b7fc35db285 Mon Sep 17 00:00:00 2001
From: zhyncs
Date: Wed, 25 Dec 2024 14:27:05 -0800
Subject: [PATCH] use sgl-kernel moe_align_block_size

Dispatch moe_align_block_size to the sgl-kernel implementation for
models with at least 256 experts (e.g. DeepSeek V3), falling back to
vLLM's kernel otherwise. Add sgl-kernel to the srt dependencies, drop
gemlite from runtime_common, and temporarily disable CUDA graph for
DeepseekV3ForCausalLM.
---
 python/pyproject.toml                         |  4 ++--
 .../layers/moe/fused_moe_triton/fused_moe.py  | 23 ++++++++++++++++++++---
 .../sglang/srt/model_executor/model_runner.py |  6 ++++++
 3 files changed, 28 insertions(+), 5 deletions(-)

diff --git a/python/pyproject.toml b/python/pyproject.toml
index d459c523f10..c8ac5cb4418 100644
--- a/python/pyproject.toml
+++ b/python/pyproject.toml
@@ -21,9 +21,9 @@ runtime_common = ["aiohttp", "decord", "fastapi",
     "orjson", "outlines>=0.0.44,<0.1.0",
     "packaging", "pillow", "prometheus-client>=0.20.0",
     "psutil", "pydantic", "python-multipart",
-    "pyzmq>=25.1.2", "torchao>=0.7.0", "gemlite", "uvicorn", "uvloop",
+    "pyzmq>=25.1.2", "torchao>=0.7.0", "uvicorn", "uvloop",
     "xgrammar>=0.1.6"]
-srt = ["sglang[runtime_common]", "torch", "vllm>=0.6.3.post1,<=0.6.4.post1", "cuda-python", "flashinfer==0.1.6"]
+srt = ["sglang[runtime_common]", "torch", "vllm>=0.6.3.post1,<=0.6.4.post1", "cuda-python", "flashinfer==0.1.6", "sgl-kernel"]
 
 # HIP (Heterogeneous-computing Interface for Portability) for AMD
 # => base docker rocm/vllm-dev:20241022, not from public vllm whl
diff --git a/python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py b/python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py
index a645e5f7d8f..10856184235 100644
--- a/python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py
+++ b/python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py
@@ -11,6 +11,7 @@
 import torch
 import triton
 import triton.language as tl
+from sgl_kernel import moe_align_block_size as sgl_moe_align_block_size
 from vllm import _custom_ops as ops
 
 from sglang.srt.layers.moe.topk import select_experts
@@ -266,9 +267,25 @@
         (max_num_m_blocks,), dtype=torch.int32, device=topk_ids.device
     )
     num_tokens_post_pad = torch.empty((1), dtype=torch.int32, device=topk_ids.device)
-    ops.moe_align_block_size(
-        topk_ids, num_experts, block_size, sorted_ids, expert_ids, num_tokens_post_pad
-    )
+    # FIXME(zhyncs): sgl-kernel currently targets large expert counts (e.g. DeepSeek V3's 256); fall back to vLLM otherwise
+    if num_experts >= 256:
+        sgl_moe_align_block_size(
+            topk_ids,
+            num_experts,
+            block_size,
+            sorted_ids,
+            expert_ids,
+            num_tokens_post_pad,
+        )
+    else:
+        ops.moe_align_block_size(
+            topk_ids,
+            num_experts,
+            block_size,
+            sorted_ids,
+            expert_ids,
+            num_tokens_post_pad,
+        )
 
     return sorted_ids, expert_ids, num_tokens_post_pad
diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py
index 2612f8840fa..f98cc14fbd4 100644
--- a/python/sglang/srt/model_executor/model_runner.py
+++ b/python/sglang/srt/model_executor/model_runner.py
@@ -95,6 +95,12 @@ def __init__(
         ):
             logger.info("MLA optimization is turned on. Use triton backend.")
             self.server_args.attention_backend = "triton"
+        # FIXME(HandH1998): CUDA graph is not yet supported for DeepSeek V3; disable it for now
+        if (
+            "DeepseekV3ForCausalLM" in self.model_config.hf_config.architectures
+            and not self.server_args.disable_cuda_graph
+        ):
+            self.server_args.disable_cuda_graph = True
 
         if self.server_args.enable_double_sparsity:
             logger.info(
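
A note on the alignment contract, for review context: moe_align_block_size
flattens the top-k routing ids, groups them by expert, and pads each expert's
segment up to a multiple of block_size, so each Triton block in the fused MoE
kernel touches only a single expert's weights. Both the vLLM op and the
sgl-kernel op write into the preallocated max-size buffers created above; the
CPU sketch below returns exact-size tensors instead, and it assumes vLLM's
convention of filling padding slots with topk_ids.numel(). The helper name
moe_align_block_size_ref and the example values are illustrative, not part of
this patch:

    import torch

    def moe_align_block_size_ref(topk_ids, num_experts, block_size):
        flat = topk_ids.flatten()
        sentinel = flat.numel()  # assumed fill value for padding slots, as in vLLM
        sorted_ids, expert_ids = [], []
        for e in range(num_experts):
            idx = (flat == e).nonzero(as_tuple=True)[0].tolist()
            if not idx:
                continue  # an expert that received no tokens contributes no blocks
            n_blocks = -(-len(idx) // block_size)  # ceil(len(idx) / block_size)
            idx += [sentinel] * (n_blocks * block_size - len(idx))
            sorted_ids += idx
            expert_ids += [e] * n_blocks  # one expert id per block of block_size slots
        return (
            torch.tensor(sorted_ids, dtype=torch.int32),
            torch.tensor(expert_ids, dtype=torch.int32),
            torch.tensor([len(sorted_ids)], dtype=torch.int32),
        )

    topk_ids = torch.tensor([[0, 2], [1, 2], [0, 1]])  # 3 tokens, top-2 routing
    ids, experts, post_pad = moe_align_block_size_ref(topk_ids, num_experts=4, block_size=4)
    # ids == [0, 4, 6, 6, 2, 5, 6, 6, 1, 3, 6, 6]: each expert's slots padded to 4
    # experts == [0, 1, 2], post_pad == [12]; expert 3 got no tokens, so no block

The num_experts >= 256 cutoff matches DeepSeek V3's expert count, so with this
dispatch every other MoE model keeps using vLLM's kernel unchanged.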