
Commit ae4e7fc
support moe
Signed-off-by: jiang1.li <[email protected]>
bigPYJ1151 committed Jan 8, 2025
1 parent 259abd8 commit ae4e7fc
Showing 1 changed file with 31 additions and 3 deletions.
34 changes: 31 additions & 3 deletions vllm/model_executor/layers/fused_moe/layer.py
@@ -83,6 +83,15 @@ def create_weights(self, layer: torch.nn.Module, num_experts: int,
         layer.register_parameter("w2_weight", w2_weight)
         set_weight_attrs(w2_weight, extra_weight_attrs)
 
+    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
+        super().process_weights_after_loading(layer)
+
+        if current_platform.is_cpu():
+            import intel_extension_for_pytorch as ipex
+            layer.ipex_fusion = ipex.llm.modules.GatedMLPMOE(layer.w13_weight,
+                                                             layer.w2_weight,
+                                                             use_prepack=True)
+
     def apply(
         self,
         layer: torch.nn.Module,
@@ -142,9 +151,28 @@ def forward_cuda(
                             topk_ids=topk_ids,
                             inplace=True)
 
-    def forward_cpu(self, *args, **kwargs):
-        raise NotImplementedError(
-            "The CPU backend currently does not support MoE.")
+    def forward_cpu(
+        self,
+        layer: torch.nn.Module,
+        x: torch.Tensor,
+        use_grouped_topk: bool,
+        top_k: int,
+        router_logits: torch.Tensor,
+        renormalize: bool,
+        topk_group: Optional[int] = None,
+        num_expert_group: Optional[int] = None,
+        custom_routing_function: Optional[Callable] = None,
+        **kwargs,
+    ):
+        return layer.ipex_fusion(
+            x,
+            use_grouped_topk,
+            top_k,
+            router_logits,
+            renormalize,
+            topk_group,
+            num_expert_group,
+        )

     def forward_tpu(
         self,
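
For context, the computation that the prepacked IPEX GatedMLPMOE module performs can be approximated in plain PyTorch as below. This is an illustrative sketch only: the tensor shapes, the helper name naive_gated_mlp_moe, and the naive per-expert loop are assumptions for readability, not the IPEX implementation. It shows the top-k softmax routing and the SiLU-gated expert MLP (w13 holding the fused gate/up projection, w2 the down projection) that forward_cpu now delegates to layer.ipex_fusion.

    # Illustrative reference only -- a naive gated-MLP MoE forward in plain PyTorch.
    # Assumed shapes (not taken from the commit):
    #   x:             [num_tokens, hidden_size]
    #   w13_weight:    [num_experts, 2 * intermediate_size, hidden_size]
    #   w2_weight:     [num_experts, hidden_size, intermediate_size]
    #   router_logits: [num_tokens, num_experts]
    import torch
    import torch.nn.functional as F

    def naive_gated_mlp_moe(x: torch.Tensor, w13_weight: torch.Tensor,
                            w2_weight: torch.Tensor, router_logits: torch.Tensor,
                            top_k: int, renormalize: bool = True) -> torch.Tensor:
        # Pick the top_k experts per token from a softmax over the router logits.
        routing_weights = F.softmax(router_logits, dim=-1)
        topk_weights, topk_ids = torch.topk(routing_weights, top_k, dim=-1)
        if renormalize:
            topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
        topk_weights = topk_weights.to(x.dtype)

        out = torch.zeros_like(x)
        intermediate_size = w2_weight.shape[-1]
        for expert_id in range(w13_weight.shape[0]):
            # Tokens (and their top-k slot) routed to this expert.
            token_idx, slot = (topk_ids == expert_id).nonzero(as_tuple=True)
            if token_idx.numel() == 0:
                continue
            xs = x[token_idx]
            # w13 packs the gate and up projections; SiLU-gate the first half.
            gate, up = (xs @ w13_weight[expert_id].t()).split(intermediate_size,
                                                              dim=-1)
            expert_out = (F.silu(gate) * up) @ w2_weight[expert_id].t()
            out.index_add_(0, token_idx,
                           topk_weights[token_idx, slot].unsqueeze(-1) * expert_out)
        return out

Creating the fusion once in process_weights_after_loading (with use_prepack=True, which, as the flag name suggests, lets IPEX repack the expert weights ahead of time) keeps forward_cpu itself down to a single call on layer.ipex_fusion.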
