diff --git a/examples/offline_inference.py b/examples/offline_inference.py
index 23cc6e8539431..7710547c1f114 100644
--- a/examples/offline_inference.py
+++ b/examples/offline_inference.py
@@ -11,7 +11,12 @@
 sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
 
 # Create an LLM.
-llm = LLM(model="facebook/opt-125m")
+# llm = LLM(model="facebook/opt-125m")
+llm = LLM(model="/data/xmo/vllm/deepseekv3-lite-base-latest",
+          tokenizer="/data/xmo/whales",
+          trust_remote_code=True,
+          enforce_eager=True,
+          max_model_len=1024)
 # Generate texts from the prompts. The output is a list of RequestOutput objects
 # that contain the prompt, generated text, and other information.
 outputs = llm.generate(prompts, sampling_params)
@@ -19,4 +24,4 @@
 for output in outputs:
     prompt = output.prompt
     generated_text = output.outputs[0].text
-    print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
\ No newline at end of file
+    print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
diff --git a/vllm/config.py b/vllm/config.py
index 17602bda15c69..fbd553e8b9f28 100644
--- a/vllm/config.py
+++ b/vllm/config.py
@@ -160,7 +160,7 @@ class ModelConfig:
             override default pooling config for the pooling model.
         logits_processor_pattern: Optional regex pattern specifying valid
             logits processor qualified names that can be passed with the
-            `logits_processors` extra completion argument. Defaults to None, 
+            `logits_processors` extra completion argument. Defaults to None,
             which allows no processors.
         generation_config: Configuration parameter file for generation.
     """
@@ -363,7 +363,7 @@ def __init__(self,
     def maybe_pull_model_tokenizer_for_s3(self, model: str,
                                           tokenizer: str) -> None:
         """
-        Pull the model config or tokenizer to a temporary 
+        Pull the model config or tokenizer to a temporary
        directory in case of S3.
 
        Args:
@@ -721,7 +721,7 @@ def get_hidden_size(self) -> int:
     def get_head_size(self) -> int:
         # TODO remove hard code
         if hasattr(self.hf_text_config, "model_type"
-                   ) and self.hf_text_config.model_type == 'deepseek_v2':
+                   ) and (self.hf_text_config.model_type in ('deepseek_v2', 'deepseek_v3')):
             # FlashAttention supports only head_size 32, 64, 128, 256,
             # we need to pad head_size 192 to 256
             return 256
@@ -874,14 +874,14 @@ def try_get_generation_config(self) -> Dict[str, Any]:
 
     def get_diff_sampling_param(self) -> Dict[str, Any]:
         """
-        This method returns a dictionary containing the parameters 
-        that differ from the default sampling parameters, but only 
-        if `generation_config` is set. If `generation_config` is not 
+        This method returns a dictionary containing the parameters
+        that differ from the default sampling parameters, but only
+        if `generation_config` is set. If `generation_config` is not
         set, an empty dictionary is returned.
 
         Returns:
-            Dict[str, Any]: A dictionary with the differing sampling 
-            parameters if `generation_config` is set, otherwise an 
+            Dict[str, Any]: A dictionary with the differing sampling
+            parameters if `generation_config` is set, otherwise an
             empty dictionary.
         """
         if self.generation_config is None:
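
Note on the `get_head_size()` change in `vllm/config.py` above: DeepSeek-V3 reuses DeepSeek-V2's MLA attention layout, so the same head-size padding rule now applies to both model types. The sketch below is illustrative only (not vLLM code) and assumes the stock DeepSeek-V2/V3 head dimensions of `qk_nope_head_dim=128` and `qk_rope_head_dim=64`.

```python
# Illustrative only: why get_head_size() returns 256 for deepseek_v2/deepseek_v3.
# MLA uses a 128-dim non-rotary part plus a 64-dim rotary part per head (192 total),
# and FlashAttention only supports head sizes 32/64/128/256, so the head is padded.
FLASH_ATTN_HEAD_SIZES = (32, 64, 128, 256)


def padded_head_size(qk_nope_head_dim: int = 128, qk_rope_head_dim: int = 64) -> int:
    head_size = qk_nope_head_dim + qk_rope_head_dim  # 192 for DeepSeek-V2/V3
    return next(s for s in FLASH_ATTN_HEAD_SIZES if s >= head_size)


assert padded_head_size() == 256
```
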
""" if self.generation_config is None: diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index f4541f66d6d0b..7aa64bee2faba 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -421,14 +421,15 @@ def fused_topk( return topk_weights, topk_ids -# This is used by the Deepseek-V2 model +# This is used by the Deepseek-V2 and Deepseek-V3 model def grouped_topk(hidden_states: torch.Tensor, gating_output: torch.Tensor, topk: int, renormalize: bool, num_expert_group: int = 0, topk_group: int = 0, - scoring_func: str = "softmax"): + scoring_func: str = "softmax", + e_score_correction_bias: Optional[torch.Tensor] = None): assert hidden_states.shape[0] == gating_output.shape[0], ( "Number of tokens mismatch") @@ -440,6 +441,9 @@ def grouped_topk(hidden_states: torch.Tensor, else: raise ValueError(f"Unsupported scoring function: {scoring_func}") + if e_score_correction_bias is not None: + scores.add_(e_score_correction_bias.unsqueeze(0)) + num_token = scores.shape[0] group_scores = scores.view(num_token, num_expert_group, -1).max(dim=-1).values # [n, n_group] diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index 6f89c57e9ce72..45d4d6fd52526 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -81,7 +81,9 @@ def apply( use_grouped_topk: bool, topk_group: Optional[int] = None, num_expert_group: Optional[int] = None, - custom_routing_function: Optional[Callable] = None + custom_routing_function: Optional[Callable] = None, + scoring_func: str = "softmax", + e_score_correction_bias: Optional[torch.Tensor] = None ) -> torch.Tensor: return self.forward(x=x, layer=layer, @@ -91,7 +93,9 @@ def apply( use_grouped_topk=use_grouped_topk, topk_group=topk_group, num_expert_group=num_expert_group, - custom_routing_function=custom_routing_function) + custom_routing_function=custom_routing_function, + scoring_func=scoring_func, + e_score_correction_bias=e_score_correction_bias) def forward_cuda( self, @@ -103,7 +107,9 @@ def forward_cuda( renormalize: bool, topk_group: Optional[int] = None, num_expert_group: Optional[int] = None, - custom_routing_function: Optional[Callable] = None + custom_routing_function: Optional[Callable] = None, + scoring_func: str = "softmax", + e_score_correction_bias: Optional[torch.Tensor] = None ) -> torch.Tensor: topk_weights, topk_ids = FusedMoE.select_experts( hidden_states=x, @@ -113,7 +119,9 @@ def forward_cuda( renormalize=renormalize, topk_group=topk_group, num_expert_group=num_expert_group, - custom_routing_function=custom_routing_function) + custom_routing_function=custom_routing_function, + scoring_func=scoring_func, + e_score_correction_bias=e_score_correction_bias) return fused_experts(hidden_states=x, w1=layer.w13_weight, @@ -136,7 +144,8 @@ def forward_tpu( renormalize: bool, topk_group: Optional[int] = None, num_expert_group: Optional[int] = None, - custom_routing_function: Optional[Callable] = None + custom_routing_function: Optional[Callable] = None, + **kwargs, ) -> torch.Tensor: assert not use_grouped_topk assert num_expert_group is None @@ -190,6 +199,7 @@ def __init__( prefix: str = "", custom_routing_function: Optional[Callable] = None, scoring_func: str = "softmax", + e_score_correction_bias: Optional[torch.Tensor] = None, ): super().__init__() @@ -210,9 +220,12 @@ def __init__( self.topk_group = topk_group 
diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py
index 6f89c57e9ce72..45d4d6fd52526 100644
--- a/vllm/model_executor/layers/fused_moe/layer.py
+++ b/vllm/model_executor/layers/fused_moe/layer.py
@@ -81,7 +81,9 @@ def apply(
         use_grouped_topk: bool,
         topk_group: Optional[int] = None,
         num_expert_group: Optional[int] = None,
-        custom_routing_function: Optional[Callable] = None
+        custom_routing_function: Optional[Callable] = None,
+        scoring_func: str = "softmax",
+        e_score_correction_bias: Optional[torch.Tensor] = None
     ) -> torch.Tensor:
         return self.forward(x=x,
                             layer=layer,
@@ -91,7 +93,9 @@ def apply(
                             use_grouped_topk=use_grouped_topk,
                             topk_group=topk_group,
                             num_expert_group=num_expert_group,
-                            custom_routing_function=custom_routing_function)
+                            custom_routing_function=custom_routing_function,
+                            scoring_func=scoring_func,
+                            e_score_correction_bias=e_score_correction_bias)
 
     def forward_cuda(
         self,
@@ -103,7 +107,9 @@
         renormalize: bool,
         topk_group: Optional[int] = None,
         num_expert_group: Optional[int] = None,
-        custom_routing_function: Optional[Callable] = None
+        custom_routing_function: Optional[Callable] = None,
+        scoring_func: str = "softmax",
+        e_score_correction_bias: Optional[torch.Tensor] = None
     ) -> torch.Tensor:
         topk_weights, topk_ids = FusedMoE.select_experts(
             hidden_states=x,
@@ -113,7 +119,9 @@ def forward_cuda(
             renormalize=renormalize,
             topk_group=topk_group,
             num_expert_group=num_expert_group,
-            custom_routing_function=custom_routing_function)
+            custom_routing_function=custom_routing_function,
+            scoring_func=scoring_func,
+            e_score_correction_bias=e_score_correction_bias)
 
         return fused_experts(hidden_states=x,
                              w1=layer.w13_weight,
@@ -136,7 +144,8 @@ def forward_tpu(
         renormalize: bool,
         topk_group: Optional[int] = None,
         num_expert_group: Optional[int] = None,
-        custom_routing_function: Optional[Callable] = None
+        custom_routing_function: Optional[Callable] = None,
+        **kwargs,
     ) -> torch.Tensor:
         assert not use_grouped_topk
         assert num_expert_group is None
@@ -190,6 +199,7 @@ def __init__(
         prefix: str = "",
         custom_routing_function: Optional[Callable] = None,
         scoring_func: str = "softmax",
+        e_score_correction_bias: Optional[torch.Tensor] = None,
     ):
         super().__init__()
 
@@ -210,9 +220,12 @@ def __init__(
         self.topk_group = topk_group
         self.custom_routing_function = custom_routing_function
         self.scoring_func = scoring_func
+        self.e_score_correction_bias = e_score_correction_bias
 
         if self.scoring_func != "softmax" and not self.use_grouped_topk:
-            raise ValueError("Only softmax scoring function is supported for non-grouped topk.")
+            raise ValueError(
+                "Only softmax scoring function is supported for non-grouped topk."
+            )
 
         if quant_config is None:
             self.quant_method: Optional[QuantizeMethodBase] = (
@@ -447,7 +460,8 @@ def select_experts(hidden_states: torch.Tensor,
                        topk_group: Optional[int] = None,
                        num_expert_group: Optional[int] = None,
                        custom_routing_function: Optional[Callable] = None,
-                       scoring_func: str = "softmax"):
+                       scoring_func: str = "softmax",
+                       e_score_correction_bias: Optional[torch.Tensor] = None):
         from vllm.model_executor.layers.fused_moe.fused_moe import (
             fused_topk, grouped_topk)
 
@@ -462,7 +476,8 @@ def select_experts(hidden_states: torch.Tensor,
                 renormalize=renormalize,
                 num_expert_group=num_expert_group,
                 topk_group=topk_group,
-                scoring_func=scoring_func)
+                scoring_func=scoring_func,
+                e_score_correction_bias=e_score_correction_bias)
         elif custom_routing_function is None:
             topk_weights, topk_ids = fused_topk(hidden_states=hidden_states,
                                                 gating_output=router_logits,
@@ -491,7 +506,9 @@ def forward(self, hidden_states: torch.Tensor,
             use_grouped_topk=self.use_grouped_topk,
             topk_group=self.topk_group,
             num_expert_group=self.num_expert_group,
-            custom_routing_function=self.custom_routing_function)
+            custom_routing_function=self.custom_routing_function,
+            scoring_func=self.scoring_func,
+            e_score_correction_bias=self.e_score_correction_bias)
 
         if self.reduce_results and self.tp_size > 1:
             final_hidden_states = tensor_model_parallel_all_reduce(
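
The `layer.py` changes above are plumbing: `FusedMoE` now stores `scoring_func` and `e_score_correction_bias` and forwards them through `select_experts` into `grouped_topk`, while the TPU path accepts and ignores them via `**kwargs`. The toy snippet below (not vLLM code; the tensors and the favored expert index are made up) shows what the two knobs change upstream of expert selection: the activation applied to the router logits and an additive per-expert bias.

```python
import torch

torch.manual_seed(0)
router_logits = torch.randn(4, 8)          # [num_tokens, num_experts]
e_score_correction_bias = torch.zeros(8)
e_score_correction_bias[3] = 5.0           # strongly favor expert 3

for scoring_func in ("softmax", "sigmoid"):
    if scoring_func == "softmax":
        scores = torch.softmax(router_logits, dim=-1)
    else:  # DeepSeek-V3 checkpoints configure sigmoid scoring
        scores = router_logits.sigmoid()
    scores = scores + e_score_correction_bias.unsqueeze(0)
    # With the bias applied, expert 3 is always among the selected experts.
    print(scoring_func, scores.topk(2, dim=-1).indices.tolist())
```
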
" "Only silu is supported for now.") - self.experts = FusedMoE(num_experts=config.n_routed_experts, - top_k=config.num_experts_per_tok, - hidden_size=config.hidden_size, - intermediate_size=config.moe_intermediate_size, - reduce_results=False, - renormalize=config.norm_topk_prob, - quant_config=quant_config, - use_grouped_topk=True, - num_expert_group=config.n_group, - topk_group=config.topk_group, - prefix=f"{prefix}.experts", - scoring_func=config.scoring_func) - - self.gate = ReplicatedLinear(config.hidden_size, - config.n_routed_experts, - bias=False, - quant_config=None, - prefix=f"{prefix}.gate") + self.gate = MoEGate(config) + self.experts = FusedMoE( + num_experts=config.n_routed_experts, + top_k=config.num_experts_per_tok, + hidden_size=config.hidden_size, + intermediate_size=config.moe_intermediate_size, + reduce_results=False, + renormalize=config.norm_topk_prob, + quant_config=quant_config, + use_grouped_topk=True, + num_expert_group=config.n_group, + topk_group=config.topk_group, + prefix=f"{prefix}.experts", + scoring_func=config.scoring_func, + e_score_correction_bias=self.gate.e_score_correction_bias) + if config.n_shared_experts is not None: intermediate_size = (config.moe_intermediate_size * config.n_shared_experts) @@ -147,7 +167,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: if self.n_shared_experts is not None: shared_output = self.shared_experts(hidden_states) # router_logits: (num_tokens, n_experts) - router_logits, _ = self.gate(hidden_states) + router_logits = self.gate(hidden_states) final_hidden_states = self.experts( hidden_states=hidden_states, router_logits=router_logits) * self.routed_scaling_factor @@ -244,8 +264,7 @@ def __init__( bias=False, quant_config=quant_config, prefix=f"{prefix}.o_proj") - if rope_scaling: - rope_scaling["rope_type"] = 'deepseek_yarn' + rope_scaling["rope_type"] = 'deepseek_yarn' self.rotary_emb = get_rope(qk_rope_head_dim, rotary_dim=qk_rope_head_dim, max_position=max_position_embeddings, @@ -624,6 +643,9 @@ def load_weights(self, weights: Iterable[Tuple[str, if is_pp_missing_parameter(name, self): continue + if name not in params_dict: + for key in params_dict: + print(key) param = params_dict[name] weight_loader = getattr(param, "weight_loader", default_weight_loader)