diff --git a/python/sglang/srt/layers/quantization/__init__.py b/python/sglang/srt/layers/quantization/__init__.py
index 3e2078c4a4d..48b733fdb78 100644
--- a/python/sglang/srt/layers/quantization/__init__.py
+++ b/python/sglang/srt/layers/quantization/__init__.py
@@ -22,7 +22,7 @@
 from vllm.model_executor.layers.quantization.tpu_int8 import Int8TpuConfig
 
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
-from sglang.srt.layers.quantization.fp8 import Fp8Config, Fp8MoEMethod
+from sglang.srt.layers.quantization.fp8 import Fp8Config
 
 QUANTIZATION_METHODS: Dict[str, Type[QuantizationConfig]] = {
     "aqlm": AQLMConfig,
@@ -53,50 +53,6 @@ def get_quantization_config(quantization: str) -> Type[QuantizationConfig]:
     return QUANTIZATION_METHODS[quantization]
 
 
-def fp8_moe_apply(
-    self,
-    layer: torch.nn.Module,
-    x: torch.Tensor,
-    router_logits: torch.Tensor,
-    top_k: int,
-    renormalize: bool,
-    use_grouped_topk: bool,
-    topk_group: Optional[int] = None,
-    num_expert_group: Optional[int] = None,
-    custom_routing_function: Optional[Callable] = None,
-) -> torch.Tensor:
-    """Enhanced apply method for FP8 MoE."""
-    from sglang.srt.layers.fused_moe_triton import FusedMoE
-    from sglang.srt.layers.fused_moe_triton.fused_moe import fused_experts
-
-    # Expert selection
-    topk_weights, topk_ids = FusedMoE.select_experts(
-        hidden_states=x,
-        router_logits=router_logits,
-        use_grouped_topk=use_grouped_topk,
-        top_k=top_k,
-        renormalize=renormalize,
-        topk_group=topk_group,
-        num_expert_group=num_expert_group,
-        custom_routing_function=custom_routing_function,
-    )
-
-    # Expert fusion with FP8 quantization
-    return fused_experts(
-        x,
-        layer.w13_weight,
-        layer.w2_weight,
-        topk_weights=topk_weights,
-        topk_ids=topk_ids,
-        inplace=True,
-        use_fp8_w8a8=True,
-        w1_scale=layer.w13_weight_scale,
-        w2_scale=layer.w2_weight_scale,
-        a1_scale=layer.w13_input_scale,
-        a2_scale=layer.w2_input_scale,
-    )
-
-
 def fp8_get_quant_method(self, layer, prefix):
     """Enhanced get_quant_method for FP8 config."""
     from vllm.model_executor.layers.linear import LinearBase
@@ -106,7 +62,7 @@ def fp8_get_quant_method(self, layer, prefix):
 
     from sglang.srt.layers.fused_moe_triton.layer import FusedMoE
     from sglang.srt.layers.linear import UnquantizedLinearMethod
-    from sglang.srt.layers.quantization.fp8 import Fp8LinearMethod
+    from sglang.srt.layers.quantization.fp8 import Fp8LinearMethod, Fp8MoEMethod
 
     if isinstance(layer, LinearBase):
         if is_layer_skipped(prefix, self.ignored_layers):
@@ -151,7 +107,6 @@ def awq_get_quant_method(self, layer, prefix):
 
 def apply_monkey_patches():
     """Apply all monkey patches in one place."""
-    setattr(Fp8MoEMethod, "apply", fp8_moe_apply)
     setattr(Fp8Config, "get_quant_method", fp8_get_quant_method)
     setattr(GPTQMarlinConfig, "get_quant_method", gptq_get_quant_method)
     setattr(AWQMarlinConfig, "get_quant_method", awq_get_quant_method)
diff --git a/python/sglang/srt/layers/quantization/fp8.py b/python/sglang/srt/layers/quantization/fp8.py
index acdce0b8cbd..0e3c7abd924 100644
--- a/python/sglang/srt/layers/quantization/fp8.py
+++ b/python/sglang/srt/layers/quantization/fp8.py
@@ -24,11 +24,6 @@
 )
 from vllm.model_executor.parameter import ModelWeightParameter, PerTensorScaleParameter
 
-from sglang.srt.layers.fused_moe_triton import (
-    FusedMoE,
-    FusedMoEMethodBase,
-    FusedMoeWeightScaleSupported,
-)
 from sglang.srt.layers.linear import LinearMethodBase, UnquantizedLinearMethod
 from sglang.srt.layers.quantization.base_config import (
     QuantizationConfig,
@@ -100,6 +95,8 @@ def get_quant_method(
     ) -> Optional["QuantizeMethodBase"]:
         from vllm.attention.layer import Attention  # Avoid circular import
 
+        from sglang.srt.layers.fused_moe_triton import FusedMoE
+
         if isinstance(layer, LinearBase):
             if is_layer_skipped(prefix, self.ignored_layers):
                 return UnquantizedLinearMethod()
@@ -306,7 +303,7 @@ def apply(
         )
 
 
-class Fp8MoEMethod(FusedMoEMethodBase):
+class Fp8MoEMethod:
     """MoE method for FP8.
     Supports loading FP8 checkpoints with static weight scale and
     dynamic/static activation scale.
@@ -319,7 +316,25 @@ class Fp8MoEMethod(FusedMoEMethodBase):
         quant_config: The quantization config.
     """
 
-    def __init__(self, quant_config: Fp8Config):
+    def __new__(cls, *args, **kwargs):
+        from sglang.srt.layers.fused_moe_triton import FusedMoEMethodBase
+
+        if not hasattr(cls, "_initialized"):
+            original_init = cls.__init__
+            new_cls = type(
+                cls.__name__,
+                (FusedMoEMethodBase,),
+                {
+                    "__init__": original_init,
+                    **{k: v for k, v in cls.__dict__.items() if k != "__dict__"},
+                },
+            )
+            obj = super(new_cls, new_cls).__new__(new_cls)
+            obj.__init__(*args, **kwargs)
+            return obj
+        return super().__new__(cls)
+
+    def __init__(self, quant_config):
         self.quant_config = quant_config
 
     def create_weights(
@@ -331,6 +346,7 @@ def create_weights(
         params_dtype: torch.dtype,
         **extra_weight_attrs,
     ):
+        from sglang.srt.layers.fused_moe_triton import FusedMoeWeightScaleSupported
         if self.quant_config.is_checkpoint_fp8_serialized:
             params_dtype = torch.float8_e4m3fn
 
@@ -521,8 +537,8 @@ def apply(
         num_expert_group: Optional[int] = None,
         custom_routing_function: Optional[Callable] = None,
     ) -> torch.Tensor:
-
-        from sglang.srt.layers.fused_moe_triton import FusedMoE
+        from sglang.srt.layers.fused_moe_triton import FusedMoE
+        from sglang.srt.layers.fused_moe_triton.fused_moe import fused_experts
 
         topk_weights, topk_ids = FusedMoE.select_experts(
             hidden_states=x,
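For readers unfamiliar with the `__new__` override added to `Fp8MoEMethod` in `fp8.py`, the sketch below shows the general deferred-base-class pattern it relies on: the base class is imported only when the first instance is constructed, and the subclass is rebuilt on the fly with that base injected, which avoids importing it (and creating a circular import) at module load time. This is a minimal, self-contained illustration under assumed names (`BaseMethod`, `get_base_method`, `LazyMethod`), not the SGLang implementation itself.

```python
# Hypothetical, standalone illustration of the deferred-base-class pattern;
# BaseMethod stands in for FusedMoEMethodBase and LazyMethod for Fp8MoEMethod.


class BaseMethod:
    """In the real code, this base lives in a module that cannot be imported
    at class-definition time without creating an import cycle."""

    def describe(self) -> str:
        return f"{type(self).__name__}(base={type(self).__bases__[0].__name__})"


def get_base_method():
    # Stand-in for the local import of FusedMoEMethodBase done inside __new__.
    return BaseMethod


class LazyMethod:
    def __new__(cls, *args, **kwargs):
        base = get_base_method()
        # Rebuild the class with the real base injected, copying the attributes
        # defined directly on LazyMethod (including __init__).
        new_cls = type(
            cls.__name__,
            (base,),
            {k: v for k, v in cls.__dict__.items() if k != "__dict__"},
        )
        obj = super(new_cls, new_cls).__new__(new_cls)
        # type.__call__ will not run __init__ because obj is not an instance of
        # LazyMethod, so it has to be invoked explicitly here.
        obj.__init__(*args, **kwargs)
        return obj

    def __init__(self, config):
        self.config = config


method = LazyMethod({"activation": "fp8"})
print(method.describe())               # -> "LazyMethod(base=BaseMethod)"
print(isinstance(method, BaseMethod))  # -> True
```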