From 83b60027b7b1071295898964e019da1edf3cc4a7 Mon Sep 17 00:00:00 2001 From: Avshalom Date: Wed, 8 Jan 2025 08:34:13 +0200 Subject: [PATCH] add tests Signed-off-by: Avshalom --- tests/kernels/test_moe.py | 7 +++++++ vllm/model_executor/layers/fused_moe/layer.py | 1 + 2 files changed, 8 insertions(+) diff --git a/tests/kernels/test_moe.py b/tests/kernels/test_moe.py index 8b23b62826053..e0db5129340a3 100644 --- a/tests/kernels/test_moe.py +++ b/tests/kernels/test_moe.py @@ -14,6 +14,8 @@ from vllm.model_executor.layers.fused_moe import fused_moe from vllm.model_executor.layers.fused_moe.fused_moe import ( fused_topk, moe_align_block_size) +from vllm.model_executor.layers.fused_moe.moe_torch_iterative import ( + fused_moe as iterative_moe) from vllm.model_executor.layers.quantization.utils.marlin_utils_test import ( marlin_quantize) from vllm.model_executor.models.mixtral import MixtralMoE @@ -46,6 +48,11 @@ def test_fused_moe( triton_output = fused_moe(a, w1, w2, score, topk, renormalize=False) torch_output = torch_moe(a, w1, w2, score, topk) torch.testing.assert_close(triton_output, torch_output, atol=2e-2, rtol=0) + iterative_output = iterative_moe(a, w1, w2, score, topk) + torch.testing.assert_close(iterative_output, + torch_output, + atol=2e-2, + rtol=0) @pytest.mark.parametrize("dtype", diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index c45e96a3cd63a..ca8d0bcabd156 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -19,6 +19,7 @@ else: fused_experts = None # type: ignore if current_platform.is_tpu(): + # the iterative moe implementation is used until the moe_pallas is fixed from .moe_torch_iterative import fused_moe as fused_moe_pallas else: fused_moe_pallas = None # type: ignore