Commit
add tests
Signed-off-by: Avshalom <[email protected]>
avshalomman committed Jan 8, 2025
1 parent 3a2f669 commit 83b6002
Showing 2 changed files with 8 additions and 0 deletions.
7 changes: 7 additions & 0 deletions tests/kernels/test_moe.py
@@ -14,6 +14,8 @@
 from vllm.model_executor.layers.fused_moe import fused_moe
 from vllm.model_executor.layers.fused_moe.fused_moe import (
     fused_topk, moe_align_block_size)
+from vllm.model_executor.layers.fused_moe.moe_torch_iterative import (
+    fused_moe as iterative_moe)
 from vllm.model_executor.layers.quantization.utils.marlin_utils_test import (
     marlin_quantize)
 from vllm.model_executor.models.mixtral import MixtralMoE
@@ -46,6 +48,11 @@ def test_fused_moe(
     triton_output = fused_moe(a, w1, w2, score, topk, renormalize=False)
     torch_output = torch_moe(a, w1, w2, score, topk)
     torch.testing.assert_close(triton_output, torch_output, atol=2e-2, rtol=0)
+    iterative_output = iterative_moe(a, w1, w2, score, topk)
+    torch.testing.assert_close(iterative_output,
+                               torch_output,
+                               atol=2e-2,
+                               rtol=0)
 
 
 @pytest.mark.parametrize("dtype",
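For context on what the new assertion compares, the sketch below is a minimal, self-contained reference MoE forward in the same spirit as the torch_moe helper used above. It is not the code from tests/kernels/test_moe.py or moe_torch_iterative.py; the weight layout (w1 as stacked gate/up projections, w2 as the down projection) and the SiLU-and-mul expert MLP are assumptions carried over from Mixtral-style MoE layers, and the function name is hypothetical.

import torch
import torch.nn.functional as F


def reference_moe(a: torch.Tensor, w1: torch.Tensor, w2: torch.Tensor,
                  score: torch.Tensor, topk: int) -> torch.Tensor:
    """Naive per-expert MoE forward, written for clarity rather than speed."""
    # a:     [num_tokens, hidden]
    # w1:    [num_experts, 2 * intermediate, hidden]  (gate/up projections, assumed layout)
    # w2:    [num_experts, hidden, intermediate]      (down projection, assumed layout)
    # score: [num_tokens, num_experts]                (router logits)
    probs = torch.softmax(score, dim=-1, dtype=torch.float32)
    topk_weight, topk_ids = torch.topk(probs, topk, dim=-1)
    out = torch.zeros_like(a)
    for expert in range(w1.shape[0]):
        # Tokens (and the top-k slot they used) that routed to this expert.
        token_idx, slot_idx = (topk_ids == expert).nonzero(as_tuple=True)
        if token_idx.numel() == 0:
            continue
        x = a[token_idx]
        gate, up = (x @ w1[expert].t()).chunk(2, dim=-1)
        y = (F.silu(gate) * up) @ w2[expert].t()
        # Weight each expert's output by its (un-renormalized) routing probability.
        weight = topk_weight[token_idx, slot_idx].unsqueeze(-1).to(y.dtype)
        out[token_idx] += y * weight
    return out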
1 change: 1 addition & 0 deletions vllm/model_executor/layers/fused_moe/layer.py
@@ -19,6 +19,7 @@
 else:
     fused_experts = None  # type: ignore
 if current_platform.is_tpu():
+    # the iterative moe implementation is used until the moe_pallas is fixed
     from .moe_torch_iterative import fused_moe as fused_moe_pallas
 else:
     fused_moe_pallas = None  # type: ignore
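The hunk above follows the same conditional-import pattern already used for fused_experts: bind the platform-specific kernel to a module-level name when the platform matches, otherwise bind None. The snippet below is a generic, hypothetical sketch of that pattern, not vLLM's actual layer.py; the _is_tpu helper, fused_moe_impl, and run_moe names are stand-ins introduced only for illustration.

from typing import Callable, Optional


def _is_tpu() -> bool:
    # Stand-in for current_platform.is_tpu(); hard-coded so the sketch runs anywhere.
    return False


fused_moe_impl: Optional[Callable] = None
if _is_tpu():
    # On a TPU build this is where the platform kernel would be imported,
    # mirroring `from .moe_torch_iterative import fused_moe as fused_moe_pallas`.
    pass


def run_moe(*args, **kwargs):
    # Fail loudly only if the unsupported path is actually exercised at runtime.
    if fused_moe_impl is None:
        raise NotImplementedError("No MoE kernel is available on this platform.")
    return fused_moe_impl(*args, **kwargs)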
