-
-
Notifications
You must be signed in to change notification settings - Fork 5.2k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
996357e
commit 0bbd8c8
Showing
2 changed files
with
52 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
51 changes: 51 additions & 0 deletions
51
vllm/model_executor/layers/fused_moe/moe_torch_iterative.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,51 @@ | ||
import torch | ||
import torch.nn.functional as F | ||
|
||
def fused_moe( | ||
hidden_states: torch.Tensor, | ||
w1: torch.Tensor, | ||
w2: torch.Tensor, | ||
gating_output: torch.Tensor, | ||
topk: int, | ||
renormalize: bool, | ||
) -> torch.Tensor: | ||
""" | ||
Args: | ||
hidden_states: [*, hidden_size] | ||
w1: [num_experts, intermediate_size * 2, hidden_size] | ||
w2: [num_experts, hidden_size, intermediate_size] | ||
gating_output: [*, num_experts] | ||
""" | ||
orig_shape = hidden_states.shape | ||
hidden_size = hidden_states.shape[-1] | ||
num_tokens = hidden_states.shape[:-1].numel() | ||
num_experts = w1.shape[0] | ||
intermediate_size = w2.shape[-1] | ||
device = hidden_states.device | ||
dtype = hidden_states.dtype | ||
|
||
hidden_states = hidden_states.view(num_tokens, hidden_size) | ||
gating_output = gating_output.view(num_tokens, num_experts) | ||
topk_weights = gating_output.softmax(dim=-1, dtype=torch.float) | ||
topk_weights, selected_experts = topk_weights.topk(topk, dim=-1) | ||
if renormalize: | ||
topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True) | ||
topk_weights = topk_weights.to(dtype) | ||
|
||
final_hidden_states = None | ||
for expert_idx in range(num_experts): | ||
expert_w1 = w1[expert_idx] | ||
expert_w2 = w2[expert_idx] | ||
expert_mask = (selected_experts == expert_idx) | ||
expert_weights = (topk_weights * expert_mask).sum(dim=-1, keepdim=True) | ||
x = F.linear(hidden_states, expert_w1) | ||
gate = F.silu(x[:, :intermediate_size]) | ||
x = x[:, intermediate_size:] * gate | ||
x = F.linear(x, expert_w2) | ||
current_hidden_states = x * expert_weights | ||
if final_hidden_states is None: | ||
final_hidden_states = current_hidden_states | ||
else: | ||
final_hidden_states = final_hidden_states + current_hidden_states | ||
|
||
return final_hidden_states.view(orig_shape) | ||
Check failure on line 51 in vllm/model_executor/layers/fused_moe/moe_torch_iterative.py GitHub Actions / mypy (3.9)
Check failure on line 51 in vllm/model_executor/layers/fused_moe/moe_torch_iterative.py GitHub Actions / mypy (3.10)
Check failure on line 51 in vllm/model_executor/layers/fused_moe/moe_torch_iterative.py GitHub Actions / mypy (3.11)
|