Skip to content

Commit

Permalink
[Hardware][TPU] workaround fix for MoE on TPU (vllm-project#11764)
Browse files Browse the repository at this point in the history
Signed-off-by: Harry Mellor <[email protected]>
  • Loading branch information
avshalomman authored and hmellor committed Jan 12, 2025
1 parent d4fa146 commit dc53376
Show file tree
Hide file tree
Showing 3 changed files with 60 additions and 1 deletion.
7 changes: 7 additions & 0 deletions tests/kernels/test_moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
from vllm.model_executor.layers.fused_moe import fused_moe
from vllm.model_executor.layers.fused_moe.fused_moe import (
fused_topk, moe_align_block_size)
from vllm.model_executor.layers.fused_moe.moe_torch_iterative import (
fused_moe as iterative_moe)
from vllm.model_executor.layers.quantization.utils.marlin_utils_test import (
marlin_quantize)
from vllm.model_executor.models.mixtral import MixtralMoE
Expand Down Expand Up @@ -46,6 +48,11 @@ def test_fused_moe(
triton_output = fused_moe(a, w1, w2, score, topk, renormalize=False)
torch_output = torch_moe(a, w1, w2, score, topk)
torch.testing.assert_close(triton_output, torch_output, atol=2e-2, rtol=0)
iterative_output = iterative_moe(a, w1, w2, score, topk, renormalize=False)
torch.testing.assert_close(iterative_output,
torch_output,
atol=2e-2,
rtol=0)


@pytest.mark.parametrize("dtype",
Expand Down
3 changes: 2 additions & 1 deletion vllm/model_executor/layers/fused_moe/layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,8 @@
else:
fused_experts = None # type: ignore
if current_platform.is_tpu():
from .moe_pallas import fused_moe as fused_moe_pallas
# the iterative moe implementation is used until the moe_pallas is fixed
from .moe_torch_iterative import fused_moe as fused_moe_pallas
else:
fused_moe_pallas = None # type: ignore
logger = init_logger(__name__)
Expand Down
51 changes: 51 additions & 0 deletions vllm/model_executor/layers/fused_moe/moe_torch_iterative.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
import torch
import torch.nn.functional as F


def fused_moe(
hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
gating_output: torch.Tensor,
topk: int,
renormalize: bool,
) -> torch.Tensor:
"""
Args:
hidden_states: [*, hidden_size]
w1: [num_experts, intermediate_size * 2, hidden_size]
w2: [num_experts, hidden_size, intermediate_size]
gating_output: [*, num_experts]
"""
orig_shape = hidden_states.shape
hidden_size = hidden_states.shape[-1]
num_tokens = hidden_states.shape[:-1].numel()
num_experts = w1.shape[0]
intermediate_size = w2.shape[-1]
dtype = hidden_states.dtype

hidden_states = hidden_states.view(num_tokens, hidden_size)
gating_output = gating_output.view(num_tokens, num_experts)
topk_weights = gating_output.softmax(dim=-1, dtype=torch.float)
topk_weights, selected_experts = topk_weights.topk(topk, dim=-1)
if renormalize:
topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
topk_weights = topk_weights.to(dtype)

final_hidden_states = None
for expert_idx in range(num_experts):
expert_w1 = w1[expert_idx]
expert_w2 = w2[expert_idx]
expert_mask = (selected_experts == expert_idx)
expert_weights = (topk_weights * expert_mask).sum(dim=-1, keepdim=True)
x = F.linear(hidden_states, expert_w1)
gate = F.silu(x[:, :intermediate_size])
x = x[:, intermediate_size:] * gate
x = F.linear(x, expert_w2)
current_hidden_states = x * expert_weights
if final_hidden_states is None:
final_hidden_states = current_hidden_states
else:
final_hidden_states = final_hidden_states + current_hidden_states

return final_hidden_states.view(orig_shape) # type: ignore

0 comments on commit dc53376

Please sign in to comment.