fix moe on tpu
avshalomman committed Jan 6, 2025
1 parent 996357e commit 0bbd8c8
Showing 2 changed files with 52 additions and 1 deletion.
2 changes: 1 addition & 1 deletion vllm/model_executor/layers/fused_moe/layer.py
@@ -19,7 +19,7 @@
 else:
     fused_experts = None  # type: ignore
 if current_platform.is_tpu():
-    from .moe_pallas import fused_moe as fused_moe_pallas
+    from .moe_torch_iterative import fused_moe as fused_moe_pallas
 else:
     fused_moe_pallas = None  # type: ignore
 logger = init_logger(__name__)
51 changes: 51 additions & 0 deletions vllm/model_executor/layers/fused_moe/moe_torch_iterative.py
@@ -0,0 +1,51 @@
import torch
import torch.nn.functional as F

def fused_moe(
    hidden_states: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    gating_output: torch.Tensor,
    topk: int,
    renormalize: bool,
) -> torch.Tensor:
    """
    Args:
        hidden_states: [*, hidden_size]
        w1: [num_experts, intermediate_size * 2, hidden_size]
        w2: [num_experts, hidden_size, intermediate_size]
        gating_output: [*, num_experts]
    """
    orig_shape = hidden_states.shape
    hidden_size = hidden_states.shape[-1]
    num_tokens = hidden_states.shape[:-1].numel()
    num_experts = w1.shape[0]
    intermediate_size = w2.shape[-1]
    device = hidden_states.device
    dtype = hidden_states.dtype

    hidden_states = hidden_states.view(num_tokens, hidden_size)
    gating_output = gating_output.view(num_tokens, num_experts)
    topk_weights = gating_output.softmax(dim=-1, dtype=torch.float)
    topk_weights, selected_experts = topk_weights.topk(topk, dim=-1)
    if renormalize:
        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
    topk_weights = topk_weights.to(dtype)

    final_hidden_states = None
    for expert_idx in range(num_experts):
        expert_w1 = w1[expert_idx]
        expert_w2 = w2[expert_idx]
        expert_mask = (selected_experts == expert_idx)
        expert_weights = (topk_weights * expert_mask).sum(dim=-1, keepdim=True)
        x = F.linear(hidden_states, expert_w1)
        gate = F.silu(x[:, :intermediate_size])
        x = x[:, intermediate_size:] * gate
        x = F.linear(x, expert_w2)
        current_hidden_states = x * expert_weights
        if final_hidden_states is None:
            final_hidden_states = current_hidden_states
        else:
            final_hidden_states = final_hidden_states + current_hidden_states

    return final_hidden_states.view(orig_shape)
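To make the routing step above concrete: each token keeps its topk largest softmax scores, and renormalize rescales the kept weights to sum to 1 per token. A small worked illustration of that step in isolation (toy logits chosen for this sketch):

import torch

# Toy gating logits for 2 tokens over 4 experts (illustrative values only).
gating_output = torch.tensor([[2.0, 0.5, 1.0, -1.0],
                              [0.0, 3.0, 1.0, -2.0]])
topk = 2
topk_weights = gating_output.softmax(dim=-1, dtype=torch.float)
topk_weights, selected_experts = topk_weights.topk(topk, dim=-1)
print(selected_experts)          # tensor([[0, 2], [1, 2]])
# renormalize=True: rescale the kept top-k weights to sum to 1 per token.
topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
print(topk_weights.sum(dim=-1))  # tensor([1., 1.])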

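An end-to-end shape check of the new function (a sketch with toy sizes; assumes the module is importable from the path this commit adds):

import torch
from vllm.model_executor.layers.fused_moe.moe_torch_iterative import fused_moe

# Toy dimensions for illustration only.
num_tokens, hidden_size, intermediate_size, num_experts, topk = 4, 8, 16, 3, 2
hidden_states = torch.randn(num_tokens, hidden_size)
w1 = torch.randn(num_experts, intermediate_size * 2, hidden_size)  # packed gate + up projections
w2 = torch.randn(num_experts, hidden_size, intermediate_size)      # down projection
gating_output = torch.randn(num_tokens, num_experts)

out = fused_moe(hidden_states, w1, w2, gating_output, topk, renormalize=True)
assert out.shape == hidden_states.shape  # output keeps [*, hidden_size]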
GitHub Actions check annotations on this commit:

• Ruff (3.12) — vllm/model_executor/layers/fused_moe/moe_torch_iterative.py:24:5: F841 Local variable `device` is assigned to but never used
• mypy (3.9, 3.10, 3.11, 3.12) — moe_torch_iterative.py, line 51: Item "None" of "Optional[Any]" has no attribute "view" [union-attr]
