fix: resolve fp8 moe issue (#2387)
zhyncs authored Dec 7, 2024
1 parent c36736c commit d332aa3
Showing 2 changed files with 27 additions and 56 deletions.
49 changes: 2 additions & 47 deletions python/sglang/srt/layers/quantization/__init__.py
@@ -22,7 +22,7 @@
 from vllm.model_executor.layers.quantization.tpu_int8 import Int8TpuConfig
 
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
-from sglang.srt.layers.quantization.fp8 import Fp8Config, Fp8MoEMethod
+from sglang.srt.layers.quantization.fp8 import Fp8Config
 
 QUANTIZATION_METHODS: Dict[str, Type[QuantizationConfig]] = {
     "aqlm": AQLMConfig,
@@ -53,50 +53,6 @@ def get_quantization_config(quantization: str) -> Type[QuantizationConfig]:
     return QUANTIZATION_METHODS[quantization]
 
 
-def fp8_moe_apply(
-    self,
-    layer: torch.nn.Module,
-    x: torch.Tensor,
-    router_logits: torch.Tensor,
-    top_k: int,
-    renormalize: bool,
-    use_grouped_topk: bool,
-    topk_group: Optional[int] = None,
-    num_expert_group: Optional[int] = None,
-    custom_routing_function: Optional[Callable] = None,
-) -> torch.Tensor:
-    """Enhanced apply method for FP8 MoE."""
-    from sglang.srt.layers.fused_moe_triton import FusedMoE
-    from sglang.srt.layers.fused_moe_triton.fused_moe import fused_experts
-
-    # Expert selection
-    topk_weights, topk_ids = FusedMoE.select_experts(
-        hidden_states=x,
-        router_logits=router_logits,
-        use_grouped_topk=use_grouped_topk,
-        top_k=top_k,
-        renormalize=renormalize,
-        topk_group=topk_group,
-        num_expert_group=num_expert_group,
-        custom_routing_function=custom_routing_function,
-    )
-
-    # Expert fusion with FP8 quantization
-    return fused_experts(
-        x,
-        layer.w13_weight,
-        layer.w2_weight,
-        topk_weights=topk_weights,
-        topk_ids=topk_ids,
-        inplace=True,
-        use_fp8_w8a8=True,
-        w1_scale=layer.w13_weight_scale,
-        w2_scale=layer.w2_weight_scale,
-        a1_scale=layer.w13_input_scale,
-        a2_scale=layer.w2_input_scale,
-    )
-
-
 def fp8_get_quant_method(self, layer, prefix):
     """Enhanced get_quant_method for FP8 config."""
     from vllm.model_executor.layers.linear import LinearBase
@@ -106,7 +62,7 @@ def fp8_get_quant_method(self, layer, prefix):
 
     from sglang.srt.layers.fused_moe_triton.layer import FusedMoE
     from sglang.srt.layers.linear import UnquantizedLinearMethod
-    from sglang.srt.layers.quantization.fp8 import Fp8LinearMethod
+    from sglang.srt.layers.quantization.fp8 import Fp8LinearMethod, Fp8MoEMethod
 
     if isinstance(layer, LinearBase):
         if is_layer_skipped(prefix, self.ignored_layers):
@@ -151,7 +107,6 @@ def awq_get_quant_method(self, layer, prefix):
 
 def apply_monkey_patches():
     """Apply all monkey patches in one place."""
-    setattr(Fp8MoEMethod, "apply", fp8_moe_apply)
     setattr(Fp8Config, "get_quant_method", fp8_get_quant_method)
     setattr(GPTQMarlinConfig, "get_quant_method", gptq_get_quant_method)
     setattr(AWQMarlinConfig, "get_quant_method", awq_get_quant_method)
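The patches that survive this commit all rely on the same mechanism: a replacement function defined on the sglang side is bound onto a vLLM config class with setattr, so later calls to that method on any instance go through sglang's version. Below is a minimal, runnable sketch of that pattern; ThirdPartyConfig and patched_get_quant_method are illustrative stand-ins, not the real vLLM or sglang names.

# Illustrative stand-in for a vLLM quantization config class.
class ThirdPartyConfig:
    def get_quant_method(self, layer, prefix):
        return "original"


# Replacement defined on "our" side; `self` is the patched config instance.
def patched_get_quant_method(self, layer, prefix):
    return "patched"


def apply_monkey_patches():
    # Rebinding the attribute on the class affects existing and future
    # instances alike, which is why it only needs to run once at import time.
    setattr(ThirdPartyConfig, "get_quant_method", patched_get_quant_method)


apply_monkey_patches()
print(ThirdPartyConfig().get_quant_method(layer=None, prefix=""))  # -> patched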
34 changes: 25 additions & 9 deletions python/sglang/srt/layers/quantization/fp8.py
@@ -24,11 +24,6 @@
 )
 from vllm.model_executor.parameter import ModelWeightParameter, PerTensorScaleParameter
 
-from sglang.srt.layers.fused_moe_triton import (
-    FusedMoE,
-    FusedMoEMethodBase,
-    FusedMoeWeightScaleSupported,
-)
 from sglang.srt.layers.linear import LinearMethodBase, UnquantizedLinearMethod
 from sglang.srt.layers.quantization.base_config import (
     QuantizationConfig,
@@ -100,6 +95,8 @@ def get_quant_method(
     ) -> Optional["QuantizeMethodBase"]:
         from vllm.attention.layer import Attention  # Avoid circular import
 
+        from sglang.srt.layers.fused_moe_triton import FusedMoE
+
         if isinstance(layer, LinearBase):
             if is_layer_skipped(prefix, self.ignored_layers):
                 return UnquantizedLinearMethod()
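The fused_moe_triton imports that previously sat at the top of fp8.py now live inside the methods that use them, so importing fp8.py no longer pulls in the Triton MoE stack up front. Below is a minimal, runnable sketch of that deferred-import pattern, with heavy_module as a hypothetical stand-in for sglang.srt.layers.fused_moe_triton.

import sys
import types

# Fabricate the "heavy" module so the sketch runs anywhere; in the real code
# this is sglang.srt.layers.fused_moe_triton.
heavy_module = types.ModuleType("heavy_module")
heavy_module.FusedMoE = type("FusedMoE", (), {})
sys.modules["heavy_module"] = heavy_module


def get_quant_method(layer_kind: str):
    # Importing inside the function defers the cost (and any import-order
    # issues) until the first call; later calls hit the sys.modules cache.
    from heavy_module import FusedMoE

    return FusedMoE if layer_kind == "moe" else None


print(get_quant_method("moe"))     # the lazily imported class
print(get_quant_method("linear"))  # None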
@@ -306,7 +303,7 @@ def apply(
         )
 
 
-class Fp8MoEMethod(FusedMoEMethodBase):
+class Fp8MoEMethod:
     """MoE method for FP8.
     Supports loading FP8 checkpoints with static weight scale and
     dynamic/static activation scale.
@@ -319,7 +316,25 @@ class Fp8MoEMethod(FusedMoEMethodBase):
         quant_config: The quantization config.
     """
 
-    def __init__(self, quant_config: Fp8Config):
+    def __new__(cls, *args, **kwargs):
+        from sglang.srt.layers.fused_moe_triton import FusedMoEMethodBase
+
+        if not hasattr(cls, "_initialized"):
+            original_init = cls.__init__
+            new_cls = type(
+                cls.__name__,
+                (FusedMoEMethodBase,),
+                {
+                    "__init__": original_init,
+                    **{k: v for k, v in cls.__dict__.items() if k != "__dict__"},
+                },
+            )
+            obj = super(new_cls, new_cls).__new__(new_cls)
+            obj.__init__(*args, **kwargs)
+            return obj
+        return super().__new__(cls)
+
+    def __init__(self, quant_config):
         self.quant_config = quant_config
 
     def create_weights(
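Since the class statement no longer names FusedMoEMethodBase, __new__ rebuilds the class the first time it is instantiated: it imports the base lazily, creates a new type with the same namespace but with the base injected as parent, and constructs that type instead. Below is a minimal, runnable sketch of the idea, with HeavyBase as a stand-in for FusedMoEMethodBase and a simplified re-basing check in place of the _initialized flag.

import abc


class HeavyBase(abc.ABC):
    """Stand-in for FusedMoEMethodBase, which the real code imports lazily."""

    @abc.abstractmethod
    def apply(self, x):
        ...


class LazyMethod:
    """Declared without the heavy base; re-based onto it on first use."""

    def __new__(cls, *args, **kwargs):
        base = HeavyBase  # in fp8.py this is a function-local import
        if not issubclass(cls, base):
            # Rebuild the class with the heavy base injected, copying the
            # original namespace (methods, __init__, docstring) across.
            namespace = {
                k: v
                for k, v in cls.__dict__.items()
                if k not in ("__dict__", "__weakref__")
            }
            new_cls = type(cls.__name__, (base,), namespace)
            obj = super(new_cls, new_cls).__new__(new_cls)
            # type.__call__ skips __init__ when __new__ returns an instance
            # of a different class, so it has to be invoked by hand here.
            obj.__init__(*args, **kwargs)
            return obj
        return super().__new__(cls)

    def __init__(self, scale):
        self.scale = scale

    def apply(self, x):
        return x * self.scale


m = LazyMethod(2.0)
print(isinstance(m, HeavyBase))  # True: the instance satisfies the heavy base
print(m.apply(3.0))              # 6.0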
@@ -331,6 +346,7 @@ def create_weights(
         params_dtype: torch.dtype,
         **extra_weight_attrs,
     ):
+        from sglang.srt.layers.fused_moe_triton import FusedMoeWeightScaleSupported
 
         if self.quant_config.is_checkpoint_fp8_serialized:
             params_dtype = torch.float8_e4m3fn
@@ -521,8 +537,8 @@ def apply(
         num_expert_group: Optional[int] = None,
         custom_routing_function: Optional[Callable] = None,
     ) -> torch.Tensor:
-
-        from vllm.model_executor.layers.fused_moe import fused_experts
+        from sglang.srt.layers.fused_moe_triton import FusedMoE
+        from sglang.srt.layers.fused_moe_triton.fused_moe import fused_experts
 
         topk_weights, topk_ids = FusedMoE.select_experts(
             hidden_states=x,