-
-
Notifications
You must be signed in to change notification settings - Fork 5.4k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Deepseek v3 #11502
Merged
+887
−61
Merged
Deepseek v3 #11502
Changes from 13 commits
Commits
Show all changes
14 commits
Select commit
Hold shift + click to select a range
f01c04a
Initial V3 file and sigmoid change
simon-mo 1afe830
BF16 + TP1 working
simon-mo eea0049
Support deepseek_v3 w8a8 fp8 block-wise quantization
mgoin 083d904
Format
mgoin e71e6aa
Format
mgoin 0807252
Fix yapf
mgoin addc2ed
Merge branch 'deepseek_v3-fp8-support' of github.com:vllm-project/vll…
simon-mo 3c96d48
fp8 plumbing
simon-mo 2f53441
e2e working for base model
simon-mo 2b57c55
Merge branch 'main' of github.com:vllm-project/vllm into deepseek-v3
simon-mo a035d02
format
simon-mo bcd044f
restore some small diff
simon-mo 71e6b03
nit for TPU
simon-mo a8c7457
nits
simon-mo File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -73,16 +73,18 @@ def create_weights(self, layer: torch.nn.Module, num_experts: int, | |
set_weight_attrs(w2_weight, extra_weight_attrs) | ||
|
||
def apply( | ||
self, | ||
layer: torch.nn.Module, | ||
x: torch.Tensor, | ||
router_logits: torch.Tensor, | ||
top_k: int, | ||
renormalize: bool, | ||
use_grouped_topk: bool, | ||
topk_group: Optional[int] = None, | ||
num_expert_group: Optional[int] = None, | ||
custom_routing_function: Optional[Callable] = None | ||
self, | ||
layer: torch.nn.Module, | ||
x: torch.Tensor, | ||
router_logits: torch.Tensor, | ||
top_k: int, | ||
renormalize: bool, | ||
use_grouped_topk: bool, | ||
topk_group: Optional[int] = None, | ||
num_expert_group: Optional[int] = None, | ||
custom_routing_function: Optional[Callable] = None, | ||
scoring_func: str = "softmax", | ||
e_score_correction_bias: Optional[torch.Tensor] = None | ||
) -> torch.Tensor: | ||
return self.forward(x=x, | ||
layer=layer, | ||
|
@@ -92,19 +94,23 @@ def apply( | |
use_grouped_topk=use_grouped_topk, | ||
topk_group=topk_group, | ||
num_expert_group=num_expert_group, | ||
custom_routing_function=custom_routing_function) | ||
custom_routing_function=custom_routing_function, | ||
scoring_func=scoring_func, | ||
e_score_correction_bias=e_score_correction_bias) | ||
|
||
def forward_cuda( | ||
self, | ||
layer: torch.nn.Module, | ||
x: torch.Tensor, | ||
use_grouped_topk: bool, | ||
top_k: int, | ||
router_logits: torch.Tensor, | ||
renormalize: bool, | ||
topk_group: Optional[int] = None, | ||
num_expert_group: Optional[int] = None, | ||
custom_routing_function: Optional[Callable] = None | ||
self, | ||
layer: torch.nn.Module, | ||
x: torch.Tensor, | ||
use_grouped_topk: bool, | ||
top_k: int, | ||
router_logits: torch.Tensor, | ||
renormalize: bool, | ||
topk_group: Optional[int] = None, | ||
num_expert_group: Optional[int] = None, | ||
custom_routing_function: Optional[Callable] = None, | ||
scoring_func: str = "softmax", | ||
e_score_correction_bias: Optional[torch.Tensor] = None | ||
) -> torch.Tensor: | ||
topk_weights, topk_ids = FusedMoE.select_experts( | ||
hidden_states=x, | ||
|
@@ -114,7 +120,9 @@ def forward_cuda( | |
renormalize=renormalize, | ||
topk_group=topk_group, | ||
num_expert_group=num_expert_group, | ||
custom_routing_function=custom_routing_function) | ||
custom_routing_function=custom_routing_function, | ||
scoring_func=scoring_func, | ||
e_score_correction_bias=e_score_correction_bias) | ||
|
||
return fused_experts(hidden_states=x, | ||
w1=layer.w13_weight, | ||
|
@@ -128,21 +136,29 @@ def forward_cpu(self, *args, **kwargs): | |
"The CPU backend currently does not support MoE.") | ||
|
||
def forward_tpu( | ||
self, | ||
layer: torch.nn.Module, | ||
x: torch.Tensor, | ||
use_grouped_topk: bool, | ||
top_k: int, | ||
router_logits: torch.Tensor, | ||
renormalize: bool, | ||
topk_group: Optional[int] = None, | ||
num_expert_group: Optional[int] = None, | ||
custom_routing_function: Optional[Callable] = None | ||
self, | ||
layer: torch.nn.Module, | ||
x: torch.Tensor, | ||
use_grouped_topk: bool, | ||
top_k: int, | ||
router_logits: torch.Tensor, | ||
renormalize: bool, | ||
topk_group: Optional[int] = None, | ||
num_expert_group: Optional[int] = None, | ||
custom_routing_function: Optional[Callable] = None, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This should have the same signature as |
||
scoring_func: str = "softmax", | ||
e_score_correction_bias: Optional[torch.Tensor] = None | ||
) -> torch.Tensor: | ||
assert not use_grouped_topk | ||
assert num_expert_group is None | ||
assert topk_group is None | ||
assert custom_routing_function is None | ||
if scoring_func != "softmax": | ||
raise NotImplementedError( | ||
"Only softmax scoring function is supported for TPU.") | ||
if e_score_correction_bias is not None: | ||
raise NotImplementedError( | ||
"Expert score correction bias is not supported for TPU.") | ||
return fused_moe_pallas(hidden_states=x, | ||
w1=layer.w13_weight, | ||
w2=layer.w2_weight, | ||
|
@@ -156,7 +172,7 @@ def forward_tpu( | |
class FusedMoE(torch.nn.Module): | ||
"""FusedMoE layer for MoE models. | ||
This layer contains both MergedColumnParallel weights (gate_up_proj / | ||
This layer contains both MergedColumnParallel weights (gate_up_proj / | ||
w13) and RowParallelLinear weights (down_proj/ w2). | ||
Note: Mixtral uses w1, w2, and w3 for gate, up, and down_proj. We | ||
|
@@ -190,6 +206,8 @@ def __init__( | |
tp_size: Optional[int] = None, | ||
prefix: str = "", | ||
custom_routing_function: Optional[Callable] = None, | ||
scoring_func: str = "softmax", | ||
e_score_correction_bias: Optional[torch.Tensor] = None, | ||
): | ||
super().__init__() | ||
|
||
|
@@ -210,6 +228,12 @@ def __init__( | |
self.num_expert_group = num_expert_group | ||
self.topk_group = topk_group | ||
self.custom_routing_function = custom_routing_function | ||
self.scoring_func = scoring_func | ||
self.e_score_correction_bias = e_score_correction_bias | ||
|
||
if self.scoring_func != "softmax" and not self.use_grouped_topk: | ||
raise ValueError("Only softmax scoring function is supported for " | ||
"non-grouped topk.") | ||
|
||
if quant_config is None: | ||
self.quant_method: Optional[QuantizeMethodBase] = ( | ||
|
@@ -446,7 +470,9 @@ def select_experts(hidden_states: torch.Tensor, | |
renormalize: bool, | ||
topk_group: Optional[int] = None, | ||
num_expert_group: Optional[int] = None, | ||
custom_routing_function: Optional[Callable] = None): | ||
custom_routing_function: Optional[Callable] = None, | ||
scoring_func: str = "softmax", | ||
e_score_correction_bias: Optional[torch.Tensor] = None): | ||
from vllm.model_executor.layers.fused_moe.fused_moe import ( | ||
fused_topk, grouped_topk) | ||
|
||
|
@@ -460,7 +486,9 @@ def select_experts(hidden_states: torch.Tensor, | |
topk=top_k, | ||
renormalize=renormalize, | ||
num_expert_group=num_expert_group, | ||
topk_group=topk_group) | ||
topk_group=topk_group, | ||
scoring_func=scoring_func, | ||
e_score_correction_bias=e_score_correction_bias) | ||
elif custom_routing_function is None: | ||
topk_weights, topk_ids = fused_topk(hidden_states=hidden_states, | ||
gating_output=router_logits, | ||
|
@@ -489,7 +517,9 @@ def forward(self, hidden_states: torch.Tensor, | |
use_grouped_topk=self.use_grouped_topk, | ||
topk_group=self.topk_group, | ||
num_expert_group=self.num_expert_group, | ||
custom_routing_function=self.custom_routing_function) | ||
custom_routing_function=self.custom_routing_function, | ||
scoring_func=self.scoring_func, | ||
e_score_correction_bias=self.e_score_correction_bias) | ||
|
||
if self.reduce_results and self.tp_size > 1: | ||
final_hidden_states = tensor_model_parallel_all_reduce( | ||
|
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
These need to be added to all methods that use MoE (gptq, compressed-tensors, etc)