
Commit 46621ac

add test
Signed-off-by: jiang1.li <[email protected]>
bigPYJ1151 committed Jan 9, 2025
1 parent 707e345 commit 46621ac
Showing 2 changed files with 15 additions and 4 deletions.
4 changes: 4 additions & 0 deletions tests/models/decoder_only/language/test_models.py
@@ -48,6 +48,10 @@
         ),
         pytest.param("stabilityai/stablelm-3b-4e1t"), # stablelm
         pytest.param("bigcode/starcoder2-3b"), # starcoder2
+        pytest.param(
+            "ehristoforu/Falcon3-MoE-2x7B-Insruct", # mixtral
+            marks=[pytest.mark.cpu_model],
+        )
     ])
 @pytest.mark.parametrize("dtype", ["half"])
 @pytest.mark.parametrize("max_tokens", [32])
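The new entry is gated behind the custom `cpu_model` pytest marker, so runs can opt in to (or skip) this CPU-only case with pytest's standard `-m` selection. A minimal, self-contained sketch of how such a marker drives test selection (the toy test and model names below are illustrative, not vLLM code):

import pytest

@pytest.mark.parametrize("model", [
    pytest.param("some/gpu-only-model"),
    pytest.param("some/cpu-capable-model", marks=[pytest.mark.cpu_model]),
])
def test_selection_demo(model):
    # Stand-in body; the real test loads the model and compares generations.
    assert isinstance(model, str)

Run only the CPU-marked cases with `pytest -m cpu_model`, or exclude them on GPU runs with `pytest -m "not cpu_model"`.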
15 changes: 11 additions & 4 deletions vllm/model_executor/layers/fused_moe/layer.py
@@ -13,6 +13,7 @@
     QuantizationConfig, QuantizeMethodBase)
 from vllm.model_executor.utils import set_weight_attrs
 from vllm.platforms import current_platform
+from vllm.platforms.interface import CpuArchEnum
 
 if current_platform.is_cuda_alike():
     from .fused_moe import fused_experts
@@ -87,10 +88,15 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
         super().process_weights_after_loading(layer)
 
         if current_platform.is_cpu():
-            import intel_extension_for_pytorch as ipex
-            layer.ipex_fusion = ipex.llm.modules.GatedMLPMOE(layer.w13_weight,
-                                                             layer.w2_weight,
-                                                             use_prepack=True)
+            if current_platform.get_cpu_architecture() == CpuArchEnum.X86:
+                import intel_extension_for_pytorch as ipex
+                layer.ipex_fusion = ipex.llm.modules.GatedMLPMOE(
+                    layer.w13_weight,
+                    layer.w2_weight,
+                    use_prepack=True,
+                )
+            else:
+                raise NotImplementedError("CPU MOE only supports x86 arch.")
 
     def apply(
         self,
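With this gating, a non-x86 CPU build now fails fast with a clear NotImplementedError instead of hitting an ImportError from intel_extension_for_pytorch, which targets x86. The same fail-fast pattern in isolation (the `build_ipex_moe` helper is hypothetical; the platform calls and IPEX module names are the ones used in the diff):

import torch

from vllm.platforms import current_platform
from vllm.platforms.interface import CpuArchEnum

def build_ipex_moe(w13_weight: torch.Tensor, w2_weight: torch.Tensor):
    # Hypothetical helper mirroring the diff's gating logic.
    if current_platform.get_cpu_architecture() != CpuArchEnum.X86:
        # Prepacked IPEX MoE kernels are x86-only, so raise a clear error
        # up front rather than failing inside the import below.
        raise NotImplementedError("CPU MOE only supports x86 arch.")
    import intel_extension_for_pytorch as ipex  # deferred x86-only dependency
    return ipex.llm.modules.GatedMLPMOE(w13_weight, w2_weight,
                                        use_prepack=True)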
@@ -164,6 +170,7 @@ def forward_cpu(
         custom_routing_function: Optional[Callable] = None,
         **kwargs,
     ):
+        assert custom_routing_function is None
         return layer.ipex_fusion(
             x,
             use_grouped_topk,
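The new assertion makes a limitation explicit: the fused IPEX path appears to handle top-k routing itself (as the `use_grouped_topk` argument suggests), so a caller-supplied `custom_routing_function` would otherwise be silently ignored. A hedged sketch of an equivalent guard with a more descriptive failure (this wrapper is illustrative, not vLLM code):

from typing import Callable, Optional

def reject_custom_routing(
        custom_routing_function: Optional[Callable]) -> None:
    # Illustrative guard: the fused CPU path has no hook for custom routing,
    # so refuse it loudly instead of silently dropping it.
    if custom_routing_function is not None:
        raise NotImplementedError(
            "The IPEX fused MoE path does not support a custom "
            "routing function.")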
