From cd716aa3e9c9b3ca8204e9f1c2aa40523c18fa6c Mon Sep 17 00:00:00 2001
From: Avshalom
Date: Mon, 6 Jan 2025 10:10:46 +0200
Subject: [PATCH 1/5] fix moe on tpu

Signed-off-by: Avshalom
---
 vllm/model_executor/layers/fused_moe/layer.py |  2 +-
 .../layers/fused_moe/moe_torch_iterative.py   | 51 +++++++++++++++++++
 2 files changed, 52 insertions(+), 1 deletion(-)
 create mode 100644 vllm/model_executor/layers/fused_moe/moe_torch_iterative.py

diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py
index b108cbd52c218..c45e96a3cd63a 100644
--- a/vllm/model_executor/layers/fused_moe/layer.py
+++ b/vllm/model_executor/layers/fused_moe/layer.py
@@ -19,7 +19,7 @@
 else:
     fused_experts = None  # type: ignore
 if current_platform.is_tpu():
-    from .moe_pallas import fused_moe as fused_moe_pallas
+    from .moe_torch_iterative import fused_moe as fused_moe_pallas
 else:
     fused_moe_pallas = None  # type: ignore
 logger = init_logger(__name__)
diff --git a/vllm/model_executor/layers/fused_moe/moe_torch_iterative.py b/vllm/model_executor/layers/fused_moe/moe_torch_iterative.py
new file mode 100644
index 0000000000000..c60591beb3d6f
--- /dev/null
+++ b/vllm/model_executor/layers/fused_moe/moe_torch_iterative.py
@@ -0,0 +1,51 @@
+import torch
+import torch.nn.functional as F
+
+def fused_moe(
+    hidden_states: torch.Tensor,
+    w1: torch.Tensor,
+    w2: torch.Tensor,
+    gating_output: torch.Tensor,
+    topk: int,
+    renormalize: bool,
+) -> torch.Tensor:
+    """
+    Args:
+        hidden_states: [*, hidden_size]
+        w1: [num_experts, intermediate_size * 2, hidden_size]
+        w2: [num_experts, hidden_size, intermediate_size]
+        gating_output: [*, num_experts]
+    """
+    orig_shape = hidden_states.shape
+    hidden_size = hidden_states.shape[-1]
+    num_tokens = hidden_states.shape[:-1].numel()
+    num_experts = w1.shape[0]
+    intermediate_size = w2.shape[-1]
+    device = hidden_states.device
+    dtype = hidden_states.dtype
+
+    hidden_states = hidden_states.view(num_tokens, hidden_size)
+    gating_output = gating_output.view(num_tokens, num_experts)
+    topk_weights = gating_output.softmax(dim=-1, dtype=torch.float)
+    topk_weights, selected_experts = topk_weights.topk(topk, dim=-1)
+    if renormalize:
+        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
+    topk_weights = topk_weights.to(dtype)
+    
+    final_hidden_states = None
+    for expert_idx in range(num_experts):
+        expert_w1 = w1[expert_idx]
+        expert_w2 = w2[expert_idx]
+        expert_mask = (selected_experts == expert_idx)
+        expert_weights = (topk_weights * expert_mask).sum(dim=-1, keepdim=True)
+        x = F.linear(hidden_states, expert_w1)
+        gate = F.silu(x[:, :intermediate_size])
+        x = x[:, intermediate_size:] * gate
+        x = F.linear(x, expert_w2)
+        current_hidden_states = x * expert_weights
+        if final_hidden_states is None:
+            final_hidden_states = current_hidden_states
+        else:
+            final_hidden_states = final_hidden_states + current_hidden_states
+
+    return final_hidden_states.view(orig_shape)

From bdb6d39f9cca5610de206d544d444198d0f0f24d Mon Sep 17 00:00:00 2001
From: Avshalom
Date: Mon, 6 Jan 2025 10:18:49 +0200
Subject: [PATCH 2/5] fix lint

Signed-off-by: Avshalom
---
 vllm/model_executor/layers/fused_moe/moe_torch_iterative.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/vllm/model_executor/layers/fused_moe/moe_torch_iterative.py b/vllm/model_executor/layers/fused_moe/moe_torch_iterative.py
index c60591beb3d6f..481d2f3329e20 100644
--- a/vllm/model_executor/layers/fused_moe/moe_torch_iterative.py
+++ b/vllm/model_executor/layers/fused_moe/moe_torch_iterative.py
@@ -1,6 +1,7 @@
 import torch
 import torch.nn.functional as F
 
+
 def fused_moe(
     hidden_states: torch.Tensor,
     w1: torch.Tensor,
@@ -21,7 +22,6 @@ def fused_moe(
     num_tokens = hidden_states.shape[:-1].numel()
     num_experts = w1.shape[0]
     intermediate_size = w2.shape[-1]
-    device = hidden_states.device
     dtype = hidden_states.dtype
 
     hidden_states = hidden_states.view(num_tokens, hidden_size)
@@ -31,7 +31,7 @@ def fused_moe(
     if renormalize:
         topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
     topk_weights = topk_weights.to(dtype)
-    
+
     final_hidden_states = None
     for expert_idx in range(num_experts):
         expert_w1 = w1[expert_idx]

From f6571a41f82eee411ac76259f22321611557408b Mon Sep 17 00:00:00 2001
From: Avshalom
Date: Mon, 6 Jan 2025 10:26:15 +0200
Subject: [PATCH 3/5] mypy

Signed-off-by: Avshalom
---
 vllm/model_executor/layers/fused_moe/moe_torch_iterative.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/vllm/model_executor/layers/fused_moe/moe_torch_iterative.py b/vllm/model_executor/layers/fused_moe/moe_torch_iterative.py
index 481d2f3329e20..bcff55f4fdf16 100644
--- a/vllm/model_executor/layers/fused_moe/moe_torch_iterative.py
+++ b/vllm/model_executor/layers/fused_moe/moe_torch_iterative.py
@@ -48,4 +48,4 @@ def fused_moe(
         else:
             final_hidden_states = final_hidden_states + current_hidden_states
 
-    return final_hidden_states.view(orig_shape)
+    return final_hidden_states.view(orig_shape)  # type: ignore

From 9c1511784080a1aaf41f0254d29ad23b4f152ccc Mon Sep 17 00:00:00 2001
From: Avshalom
Date: Wed, 8 Jan 2025 08:34:13 +0200
Subject: [PATCH 4/5] add tests

Signed-off-by: Avshalom
---
 tests/kernels/test_moe.py                     | 7 +++++++
 vllm/model_executor/layers/fused_moe/layer.py | 1 +
 2 files changed, 8 insertions(+)

diff --git a/tests/kernels/test_moe.py b/tests/kernels/test_moe.py
index 8b23b62826053..e0db5129340a3 100644
--- a/tests/kernels/test_moe.py
+++ b/tests/kernels/test_moe.py
@@ -14,6 +14,8 @@
 from vllm.model_executor.layers.fused_moe import fused_moe
 from vllm.model_executor.layers.fused_moe.fused_moe import (
     fused_topk, moe_align_block_size)
+from vllm.model_executor.layers.fused_moe.moe_torch_iterative import (
+    fused_moe as iterative_moe)
 from vllm.model_executor.layers.quantization.utils.marlin_utils_test import (
     marlin_quantize)
 from vllm.model_executor.models.mixtral import MixtralMoE
@@ -46,6 +48,11 @@ def test_fused_moe(
     triton_output = fused_moe(a, w1, w2, score, topk, renormalize=False)
     torch_output = torch_moe(a, w1, w2, score, topk)
     torch.testing.assert_close(triton_output, torch_output, atol=2e-2, rtol=0)
+    iterative_output = iterative_moe(a, w1, w2, score, topk)
+    torch.testing.assert_close(iterative_output,
+                               torch_output,
+                               atol=2e-2,
+                               rtol=0)
 
 
 @pytest.mark.parametrize("dtype",
diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py
index c45e96a3cd63a..ca8d0bcabd156 100644
--- a/vllm/model_executor/layers/fused_moe/layer.py
+++ b/vllm/model_executor/layers/fused_moe/layer.py
@@ -19,6 +19,7 @@
 else:
     fused_experts = None  # type: ignore
 if current_platform.is_tpu():
+    # the iterative moe implementation is used until the moe_pallas is fixed
    from .moe_torch_iterative import fused_moe as fused_moe_pallas
 else:
     fused_moe_pallas = None  # type: ignore

From 41857109a874107ab16826ef6f0703e368d60842 Mon Sep 17 00:00:00 2001
From: Avshalom
Date: Sun, 12 Jan 2025 10:21:54 +0200
Subject: [PATCH 5/5] fixing test

Signed-off-by: Avshalom
---
 tests/kernels/test_moe.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tests/kernels/test_moe.py b/tests/kernels/test_moe.py
index e0db5129340a3..7fa5de1984452 100644
--- a/tests/kernels/test_moe.py
+++ b/tests/kernels/test_moe.py
@@ -48,7 +48,7 @@ def test_fused_moe(
     triton_output = fused_moe(a, w1, w2, score, topk, renormalize=False)
     torch_output = torch_moe(a, w1, w2, score, topk)
     torch.testing.assert_close(triton_output, torch_output, atol=2e-2, rtol=0)
-    iterative_output = iterative_moe(a, w1, w2, score, topk)
+    iterative_output = iterative_moe(a, w1, w2, score, topk, renormalize=False)
     torch.testing.assert_close(iterative_output,
                                torch_output,
                                atol=2e-2,
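
A minimal, standalone sketch of how the new iterative MoE path can be
exercised end to end. It assumes a vLLM checkout with the patches above
applied; the shapes, seed, tolerance, and the ref_moe helper are
illustrative choices for this note, not code taken from vLLM's test suite.

    import torch
    import torch.nn.functional as F

    from vllm.model_executor.layers.fused_moe.moe_torch_iterative import (
        fused_moe)


    def ref_moe(a, w1, w2, score, topk):
        # Naive per-token reference: route every token through its top-k
        # experts and accumulate the gate-weighted expert outputs.
        inter = w2.shape[-1]
        weights = score.softmax(dim=-1, dtype=torch.float)
        topk_weights, topk_ids = weights.topk(topk, dim=-1)
        out = torch.zeros_like(a)
        for t in range(a.shape[0]):
            for w, e in zip(topk_weights[t], topk_ids[t]):
                up = F.linear(a[t], w1[e])           # [2 * inter]
                h = F.silu(up[:inter]) * up[inter:]  # same gating order as the patch
                out[t] += w.to(a.dtype) * F.linear(h, w2[e])
        return out


    torch.manual_seed(0)
    num_tokens, hidden, inter, experts, topk = 16, 64, 128, 8, 2
    a = torch.randn(num_tokens, hidden) / 10
    w1 = torch.randn(experts, 2 * inter, hidden) / 10
    w2 = torch.randn(experts, hidden, inter) / 10
    score = torch.randn(num_tokens, experts)

    out = fused_moe(a, w1, w2, score, topk, renormalize=False)
    torch.testing.assert_close(out, ref_moe(a, w1, w2, score, topk),
                               atol=2e-2, rtol=0)
    print("iterative MoE matches the per-token reference")

Because moe_torch_iterative uses only plain torch ops (no Pallas kernels),
this sketch runs on CPU as-is; on TPU it is the same function that layer.py
now dispatches to in place of moe_pallas.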