use sgl-kernel moe_align_block_size #2581

Merged 1 commit on Dec 25, 2024
4 changes: 2 additions & 2 deletions python/pyproject.toml
@@ -21,9 +21,9 @@ runtime_common = ["aiohttp", "decord", "fastapi",
     "orjson", "outlines>=0.0.44,<0.1.0",
     "packaging", "pillow", "prometheus-client>=0.20.0",
     "psutil", "pydantic", "python-multipart",
-    "pyzmq>=25.1.2", "torchao>=0.7.0", "gemlite", "uvicorn", "uvloop",
+    "pyzmq>=25.1.2", "torchao>=0.7.0", "uvicorn", "uvloop",
     "xgrammar>=0.1.6"]
-srt = ["sglang[runtime_common]", "torch", "vllm>=0.6.3.post1,<=0.6.4.post1", "cuda-python", "flashinfer==0.1.6"]
+srt = ["sglang[runtime_common]", "torch", "vllm>=0.6.3.post1,<=0.6.4.post1", "cuda-python", "flashinfer==0.1.6", "sgl-kernel"]

 # HIP (Heterogeneous-computing Interface for Portability) for AMD
 # => base docker rocm/vllm-dev:20241022, not from public vllm whl
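Note on the change above: `gemlite` is dropped from `runtime_common` and `sgl-kernel` is added to the `srt` extra, so installing with `pip install "sglang[srt]"` should pull in the kernel package alongside vLLM. A minimal sanity check, assuming sgl-kernel ships the `sgl_kernel` module name used by the import added in fused_moe.py further down:

import importlib.util

# Both modules should resolve after installing the `srt` extra; `sgl_kernel`
# matches the new import in fused_moe.py below.
for mod in ("sgl_kernel", "vllm"):
    spec = importlib.util.find_spec(mod)
    print(f"{mod}: {'available' if spec is not None else 'missing'}")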
23 changes: 20 additions & 3 deletions python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py
@@ -11,6 +11,7 @@
 import torch
 import triton
 import triton.language as tl
+from sgl_kernel import moe_align_block_size as sgl_moe_align_block_size
 from vllm import _custom_ops as ops

 from sglang.srt.layers.moe.topk import select_experts
@@ -266,9 +267,25 @@ def moe_align_block_size(
         (max_num_m_blocks,), dtype=torch.int32, device=topk_ids.device
     )
     num_tokens_post_pad = torch.empty((1), dtype=torch.int32, device=topk_ids.device)
-    ops.moe_align_block_size(
-        topk_ids, num_experts, block_size, sorted_ids, expert_ids, num_tokens_post_pad
-    )
+    # FIXME(zhyncs)
+    if num_experts >= 256:
+        sgl_moe_align_block_size(
+            topk_ids,
+            num_experts,
+            block_size,
+            sorted_ids,
+            expert_ids,
+            num_tokens_post_pad,
+        )
+    else:
+        ops.moe_align_block_size(
+            topk_ids,
+            num_experts,
+            block_size,
+            sorted_ids,
+            expert_ids,
+            num_tokens_post_pad,
+        )
     return sorted_ids, expert_ids, num_tokens_post_pad

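Note on the change above: the wrapper now dispatches to sgl-kernel's `moe_align_block_size` when `num_experts >= 256` (the DeepSeek V3 case, which routes across 256 experts) and keeps vLLM's `_custom_ops` path for smaller expert counts. A toy usage sketch, not part of the PR: it assumes a CUDA device, an installed sgl-kernel wheel, and that the wrapper keeps the `(topk_ids, block_size, num_experts)` argument order of the vLLM function it was adapted from.

import torch

from sglang.srt.layers.moe.fused_moe_triton.fused_moe import moe_align_block_size

# Dummy routing: 8 tokens, top-2 experts each, 256 experts, block size 16.
num_tokens, top_k, num_experts, block_size = 8, 2, 256, 16
topk_ids = torch.randint(
    0, num_experts, (num_tokens, top_k), dtype=torch.int32, device="cuda"
)

# sorted_ids groups token slots by expert and pads each expert to a multiple of
# block_size; expert_ids holds the expert for each block; num_tokens_post_pad is
# the padded total. With num_experts >= 256 this exercises the sgl-kernel path.
sorted_ids, expert_ids, num_tokens_post_pad = moe_align_block_size(
    topk_ids, block_size, num_experts
)
print(sorted_ids.shape, expert_ids.shape, int(num_tokens_post_pad.item()))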
6 changes: 6 additions & 0 deletions python/sglang/srt/model_executor/model_runner.py
@@ -95,6 +95,12 @@ def __init__(
         ):
             logger.info("MLA optimization is turned on. Use triton backend.")
             self.server_args.attention_backend = "triton"
+            # FIXME(HandH1998)
+            if (
+                "DeepseekV3ForCausalLM" in self.model_config.hf_config.architectures
+                and not self.server_args.disable_cuda_graph
+            ):
+                self.server_args.disable_cuda_graph = True

         if self.server_args.enable_double_sparsity:
             logger.info(
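Note on the change above: when the loaded checkpoint's HF config lists `DeepseekV3ForCausalLM`, the runner force-sets `disable_cuda_graph`, a temporary workaround flagged by the FIXME. A small sketch, not part of the PR, showing how the same architecture check can be reproduced offline with transformers; the repo id is only an illustrative choice.

from transformers import AutoConfig

# Any checkpoint whose config.json lists DeepseekV3ForCausalLM would trigger
# the workaround above.
hf_config = AutoConfig.from_pretrained("deepseek-ai/DeepSeek-V3", trust_remote_code=True)
if "DeepseekV3ForCausalLM" in (hf_config.architectures or []):
    print("CUDA graph would be disabled for this model")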