[Hardware][CPU] Support MOE models on x86 CPU (vllm-project#11831)
Signed-off-by: jiang1.li <[email protected]>
bigPYJ1151 authored and frreiss committed Jan 10, 2025
1 parent 70084ce commit 8872b3e
Showing 3 changed files with 43 additions and 4 deletions.
2 changes: 1 addition & 1 deletion docs/source/getting_started/installation/cpu-x86.md
@@ -5,7 +5,7 @@
vLLM initially supports basic model inferencing and serving on x86 CPU platform, with data types FP32, FP16 and BF16. vLLM CPU backend supports the following vLLM features:

- Tensor Parallel
- Model Quantization (`INT8 W8A8, AWQ`)
- Model Quantization (`INT8 W8A8, AWQ, GPTQ`)
- Chunked-prefill
- Prefix-caching
- FP8-E5M2 KV-Caching (TODO)
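With this change, Mixtral-style MoE checkpoints become usable on the x86 CPU backend alongside the features listed above. A minimal sketch of offline inference, assuming a CPU build of vLLM with intel_extension_for_pytorch installed (the model is the MoE checkpoint exercised by this commit's test; dtype and prompt are illustrative):

```python
from vllm import LLM, SamplingParams

# MoE checkpoint added to the CPU test matrix in this commit.
llm = LLM(model="ehristoforu/Falcon3-MoE-2x7B-Insruct", dtype="bfloat16")

params = SamplingParams(temperature=0.0, max_tokens=32)
outputs = llm.generate(["The capital of France is"], params)
print(outputs[0].outputs[0].text)
```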
4 changes: 4 additions & 0 deletions tests/models/decoder_only/language/test_models.py
@@ -48,6 +48,10 @@
),
pytest.param("stabilityai/stablelm-3b-4e1t"), # stablelm
pytest.param("bigcode/starcoder2-3b"), # starcoder2
pytest.param(
"ehristoforu/Falcon3-MoE-2x7B-Insruct", # mixtral
marks=[pytest.mark.cpu_model],
)
])
@pytest.mark.parametrize("dtype", ["half"])
@pytest.mark.parametrize("max_tokens", [32])
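The new entry uses pytest.param's marks argument so that only this one parametrized case carries the cpu_model marker, letting CI select or skip it independently of the other models. A small self-contained sketch of the pattern, assuming the cpu_model marker is registered in the project's pytest configuration (the test body here is illustrative only):

```python
import pytest

@pytest.mark.parametrize("model", [
    pytest.param("bigcode/starcoder2-3b"),  # plain case, no extra marks
    pytest.param(
        "ehristoforu/Falcon3-MoE-2x7B-Insruct",
        marks=[pytest.mark.cpu_model],  # selectable with `pytest -m cpu_model`
    ),
])
def test_marker_pattern(model):
    # Placeholder body; the real test generates tokens and compares outputs.
    assert isinstance(model, str)
```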
41 changes: 38 additions & 3 deletions vllm/model_executor/layers/fused_moe/layer.py
@@ -13,6 +13,7 @@
QuantizationConfig, QuantizeMethodBase)
from vllm.model_executor.utils import set_weight_attrs
from vllm.platforms import current_platform
from vllm.platforms.interface import CpuArchEnum

if current_platform.is_cuda_alike():
from .fused_moe import fused_experts
@@ -83,6 +84,20 @@ def create_weights(self, layer: torch.nn.Module, num_experts: int,
layer.register_parameter("w2_weight", w2_weight)
set_weight_attrs(w2_weight, extra_weight_attrs)

def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
super().process_weights_after_loading(layer)

if current_platform.is_cpu():
if current_platform.get_cpu_architecture() == CpuArchEnum.X86:
import intel_extension_for_pytorch as ipex
layer.ipex_fusion = ipex.llm.modules.GatedMLPMOE(
layer.w13_weight,
layer.w2_weight,
use_prepack=True,
)
else:
raise NotImplementedError("CPU MOE only supports x86 arch.")

def apply(
self,
layer: torch.nn.Module,
@@ -142,9 +157,29 @@ def forward_cuda(
topk_ids=topk_ids,
inplace=True)

def forward_cpu(self, *args, **kwargs):
raise NotImplementedError(
"The CPU backend currently does not support MoE.")
def forward_cpu(
self,
layer: torch.nn.Module,
x: torch.Tensor,
use_grouped_topk: bool,
top_k: int,
router_logits: torch.Tensor,
renormalize: bool,
topk_group: Optional[int] = None,
num_expert_group: Optional[int] = None,
custom_routing_function: Optional[Callable] = None,
**kwargs,
):
assert custom_routing_function is None
return layer.ipex_fusion(
x,
use_grouped_topk,
top_k,
router_logits,
renormalize,
topk_group,
num_expert_group,
)

def forward_tpu(
self,
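The new forward_cpu hands the entire expert computation to the prepacked ipex.llm.modules.GatedMLPMOE module attached in process_weights_after_loading, so no MoE math runs in vLLM's Python code on the CPU path. As a rough reference for what that fused call computes on Mixtral-style models, here is an unfused plain-PyTorch sketch; the weight layouts and SiLU gating follow vLLM's usual fused-MoE conventions and are assumptions, not taken from the IPEX kernel:

```python
import torch
import torch.nn.functional as F

def moe_reference(x, w13, w2, router_logits, top_k, renormalize=True):
    """Unfused top-k gated-MLP MoE (assumed shapes: x [tokens, hidden],
    w13 [experts, 2*inter, hidden], w2 [experts, hidden, inter],
    router_logits [tokens, experts])."""
    probs = torch.softmax(router_logits, dim=-1)
    topk_w, topk_ids = torch.topk(probs, top_k, dim=-1)
    if renormalize:
        topk_w = topk_w / topk_w.sum(dim=-1, keepdim=True)

    out = torch.zeros_like(x)
    for e in range(w13.shape[0]):
        # Tokens routed to expert e, and which of their top-k slots chose it.
        token_idx, slot = (topk_ids == e).nonzero(as_tuple=True)
        if token_idx.numel() == 0:
            continue
        h = x[token_idx] @ w13[e].t()          # fused gate/up projection
        gate, up = h.chunk(2, dim=-1)
        h = F.silu(gate) * up                  # SwiGLU activation
        h = h @ w2[e].t()                      # down projection
        out.index_add_(0, token_idx, h * topk_w[token_idx, slot].unsqueeze(-1))
    return out
```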
