From 3d3f5e783a4797206b230c112599b39680225e57 Mon Sep 17 00:00:00 2001
From: "jiang1.li"
Date: Thu, 19 Dec 2024 10:39:16 +0000
Subject: [PATCH] support moe

---
 vllm/model_executor/layers/fused_moe/layer.py | 34 +++++++++++++++++--
 1 file changed, 31 insertions(+), 3 deletions(-)

diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py
index 8c6f7c6e06515..a4050ca5d535b 100644
--- a/vllm/model_executor/layers/fused_moe/layer.py
+++ b/vllm/model_executor/layers/fused_moe/layer.py
@@ -71,6 +71,15 @@ def create_weights(self, layer: torch.nn.Module, num_experts: int,
         layer.register_parameter("w2_weight", w2_weight)
         set_weight_attrs(w2_weight, extra_weight_attrs)
 
+    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
+        super().process_weights_after_loading(layer)
+
+        if current_platform.is_cpu():
+            import intel_extension_for_pytorch as ipex
+            layer.ipex_fusion = ipex.llm.modules.GatedMLPMOE(layer.w13_weight,
+                                                             layer.w2_weight,
+                                                             use_prepack=True)
+
     def apply(
         self,
         layer: torch.nn.Module,
@@ -122,9 +131,28 @@ def forward_cuda(
                              topk_ids=topk_ids,
                              inplace=True)
 
-    def forward_cpu(self, *args, **kwargs):
-        raise NotImplementedError(
-            "The CPU backend currently does not support MoE.")
+    def forward_cpu(
+        self,
+        layer: torch.nn.Module,
+        x: torch.Tensor,
+        use_grouped_topk: bool,
+        top_k: int,
+        router_logits: torch.Tensor,
+        renormalize: bool,
+        topk_group: Optional[int] = None,
+        num_expert_group: Optional[int] = None,
+        custom_routing_function: Optional[Callable] = None,
+        **kwargs,
+    ):
+        return layer.ipex_fusion(
+            x,
+            use_grouped_topk,
+            top_k,
+            router_logits,
+            renormalize,
+            topk_group,
+            num_expert_group,
+        )
 
     def forward_tpu(
         self,
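
Net effect of the patch: instead of raising NotImplementedError, the CPU backend prepacks the expert weights into ipex.llm.modules.GatedMLPMOE once after weight loading, and forward_cpu simply dispatches to that fused module with the routing arguments. As an aid to reading the diff, here is a minimal pure-PyTorch sketch of the computation a gated-MLP MoE layer of this kind performs (top-k softmax routing plus SwiGLU experts). It assumes vLLM's usual weight layout (w13_weight: [num_experts, 2 * intermediate, hidden] with gate and up projections stacked; w2_weight: [num_experts, hidden, intermediate]); it illustrates the math only and is not the IPEX kernel or its API.

# Rough pure-PyTorch reference of the MoE computation the fused CPU path replaces.
# Assumed layout (matches vLLM's unquantized FusedMoE weights):
#   w13: [num_experts, 2 * intermediate, hidden]  -- gate and up projections stacked
#   w2:  [num_experts, hidden, intermediate]      -- down projection
import torch
import torch.nn.functional as F


def moe_reference(x: torch.Tensor,
                  w13: torch.Tensor,
                  w2: torch.Tensor,
                  router_logits: torch.Tensor,
                  top_k: int,
                  renormalize: bool = True) -> torch.Tensor:
    num_tokens, _ = x.shape
    intermediate = w2.shape[-1]

    # Softmax over experts, keep the top-k experts per token.
    routing = torch.softmax(router_logits, dim=-1)
    topk_weights, topk_ids = torch.topk(routing, top_k, dim=-1)
    if renormalize:
        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)

    out = torch.zeros_like(x)
    for token in range(num_tokens):
        for slot in range(top_k):
            e = int(topk_ids[token, slot])
            gate_up = x[token] @ w13[e].t()                  # [2 * intermediate]
            gate, up = gate_up[:intermediate], gate_up[intermediate:]
            act = F.silu(gate) * up                          # SwiGLU activation
            out[token] += topk_weights[token, slot] * (act @ w2[e].t())
    return out


if __name__ == "__main__":
    torch.manual_seed(0)
    num_experts, hidden, inter, tokens, top_k = 8, 64, 128, 4, 2
    x = torch.randn(tokens, hidden)
    w13 = torch.randn(num_experts, 2 * inter, hidden) * 0.02
    w2 = torch.randn(num_experts, hidden, inter) * 0.02
    logits = torch.randn(tokens, num_experts)
    print(moe_reference(x, w13, w2, logits, top_k).shape)    # torch.Size([4, 64])

The per-token loop above is deliberately naive; the point of the patch is that on CPU this work is done by a single prepacked IPEX module built once in process_weights_after_loading, rather than by the Triton fused_experts kernel used in forward_cuda.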