BF16 + TP1 working
simon-mo committed Dec 26, 2024
1 parent f01c04a commit a5239d5
Showing 5 changed files with 90 additions and 42 deletions.
9 changes: 7 additions & 2 deletions examples/offline_inference.py
@@ -11,12 +11,17 @@
sampling_params = SamplingParams(temperature=0.8, top_p=0.95)

# Create an LLM.
llm = LLM(model="facebook/opt-125m")
# llm = LLM(model="facebook/opt-125m")
llm = LLM(model="/data/xmo/vllm/deepseekv3-lite-base-latest",
tokenizer="/data/xmo/whales",
trust_remote_code=True,
enforce_eager=True,
max_model_len=1024)
# Generate texts from the prompts. The output is a list of RequestOutput objects
# that contain the prompt, generated text, and other information.
outputs = llm.generate(prompts, sampling_params)
# Print the outputs.
for output in outputs:
prompt = output.prompt
generated_text = output.outputs[0].text
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
16 changes: 8 additions & 8 deletions vllm/config.py
@@ -160,7 +160,7 @@ class ModelConfig:
override default pooling config for the pooling model.
logits_processor_pattern: Optional regex pattern specifying valid
logits processor qualified names that can be passed with the
`logits_processors` extra completion argument. Defaults to None,
`logits_processors` extra completion argument. Defaults to None,
which allows no processors.
generation_config: Configuration parameter file for generation.
"""
@@ -363,7 +363,7 @@ def __init__(self,
def maybe_pull_model_tokenizer_for_s3(self, model: str,
tokenizer: str) -> None:
"""
Pull the model config or tokenizer to a temporary
Pull the model config or tokenizer to a temporary
directory in case of S3.
Args:
@@ -721,7 +721,7 @@ def get_hidden_size(self) -> int:
def get_head_size(self) -> int:
# TODO remove hard code
if hasattr(self.hf_text_config, "model_type"
) and self.hf_text_config.model_type == 'deepseek_v2':
) and (self.hf_text_config.model_type in ('deepseek_v2', 'deepseek_v3')):

# FlashAttention supports only head_size 32, 64, 128, 256,
# we need to pad head_size 192 to 256
return 256
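For context, the DeepSeek-V2/V3 query/key head concatenates a non-RoPE part and a RoPE part (128 + 64 = 192 in the published configs), which is the 192 the comment above refers to. A minimal sketch of the padding rule, assuming those dimensions:

```python
# Sketch of the head-size padding described above. The 128/64 split is assumed
# from the published DeepSeek-V2/V3 configs, not read from this diff.
qk_nope_head_dim = 128          # non-RoPE portion of each q/k head
qk_rope_head_dim = 64           # RoPE portion of each q/k head
head_size = qk_nope_head_dim + qk_rope_head_dim        # 192

FLASH_ATTN_HEAD_SIZES = (32, 64, 128, 256)             # sizes FlashAttention supports
padded_head_size = next(s for s in FLASH_ATTN_HEAD_SIZES if s >= head_size)
assert padded_head_size == 256
```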
@@ -874,14 +874,14 @@ def try_get_generation_config(self) -> Dict[str, Any]:

def get_diff_sampling_param(self) -> Dict[str, Any]:
"""
This method returns a dictionary containing the parameters
that differ from the default sampling parameters, but only
if `generation_config` is set. If `generation_config` is not
This method returns a dictionary containing the parameters
that differ from the default sampling parameters, but only
if `generation_config` is set. If `generation_config` is not
set, an empty dictionary is returned.
Returns:
Dict[str, Any]: A dictionary with the differing sampling
parameters if `generation_config` is set, otherwise an
Dict[str, Any]: A dictionary with the differing sampling
parameters if `generation_config` is set, otherwise an
empty dictionary.
"""
if self.generation_config is None:
8 changes: 6 additions & 2 deletions vllm/model_executor/layers/fused_moe/fused_moe.py
@@ -421,14 +421,15 @@ def fused_topk(
return topk_weights, topk_ids


# This is used by the Deepseek-V2 model
# This is used by the Deepseek-V2 and Deepseek-V3 model
def grouped_topk(hidden_states: torch.Tensor,
gating_output: torch.Tensor,
topk: int,
renormalize: bool,
num_expert_group: int = 0,
topk_group: int = 0,
scoring_func: str = "softmax"):
scoring_func: str = "softmax",
e_score_correction_bias: Optional[torch.Tensor] = None):

assert hidden_states.shape[0] == gating_output.shape[0], (
"Number of tokens mismatch")
@@ -440,6 +441,9 @@ def grouped_topk(hidden_states: torch.Tensor,
else:
raise ValueError(f"Unsupported scoring function: {scoring_func}")

if e_score_correction_bias is not None:
scores.add_(e_score_correction_bias.unsqueeze(0))

num_token = scores.shape[0]
group_scores = scores.view(num_token, num_expert_group,
-1).max(dim=-1).values # [n, n_group]
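To make the routing change easier to follow outside the diff, here is a standalone, pure-PyTorch sketch of grouped top-k selection with the new `e_score_correction_bias`, reconstructed from the hunk above. It mirrors the shapes used by vLLM's `grouped_topk` but is illustrative, not the library function itself; the sigmoid branch is assumed to already exist, since DeepSeek-V3's config selects `scoring_func="sigmoid"`.

```python
from typing import Optional

import torch


def grouped_topk_sketch(hidden_states: torch.Tensor,
                        gating_output: torch.Tensor,
                        topk: int,
                        renormalize: bool,
                        num_expert_group: int,
                        topk_group: int,
                        scoring_func: str = "softmax",
                        e_score_correction_bias: Optional[torch.Tensor] = None):
    assert hidden_states.shape[0] == gating_output.shape[0], (
        "Number of tokens mismatch")

    if scoring_func == "softmax":
        scores = torch.softmax(gating_output, dim=-1)
    elif scoring_func == "sigmoid":
        scores = gating_output.sigmoid()
    else:
        raise ValueError(f"Unsupported scoring function: {scoring_func}")

    # DeepSeek-V3 "noaux_tc" routing: a learned per-expert bias nudges the
    # selection scores (this is the line added in the hunk above).
    if e_score_correction_bias is not None:
        scores = scores + e_score_correction_bias.unsqueeze(0)

    num_token, num_experts = scores.shape
    # Best expert score inside each group: [n_token, n_group]
    group_scores = scores.view(num_token, num_expert_group, -1).max(dim=-1).values
    # Keep only the topk_group highest-scoring groups per token.
    group_idx = torch.topk(group_scores, k=topk_group, dim=-1, sorted=False)[1]
    group_mask = torch.zeros_like(group_scores)
    group_mask.scatter_(1, group_idx, 1)
    score_mask = group_mask.unsqueeze(-1).expand(
        num_token, num_expert_group,
        num_experts // num_expert_group).reshape(num_token, -1)
    masked_scores = scores.masked_fill(~score_mask.bool(), 0.0)

    # Final per-token expert choice among the surviving groups.
    topk_weights, topk_ids = torch.topk(masked_scores, k=topk, dim=-1, sorted=False)
    if renormalize:
        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
    return topk_weights, topk_ids.to(torch.int32)


# Example: 4 tokens, 8 experts in 4 groups; keep 2 groups, then pick 2 experts.
weights, ids = grouped_topk_sketch(
    hidden_states=torch.randn(4, 16),
    gating_output=torch.randn(4, 8),
    topk=2,
    renormalize=True,
    num_expert_group=4,
    topk_group=2,
    scoring_func="sigmoid",
    e_score_correction_bias=torch.zeros(8))
```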
35 changes: 26 additions & 9 deletions vllm/model_executor/layers/fused_moe/layer.py
@@ -81,7 +81,9 @@ def apply(
use_grouped_topk: bool,
topk_group: Optional[int] = None,
num_expert_group: Optional[int] = None,
custom_routing_function: Optional[Callable] = None
custom_routing_function: Optional[Callable] = None,
scoring_func: str = "softmax",
e_score_correction_bias: Optional[torch.Tensor] = None
) -> torch.Tensor:
return self.forward(x=x,
layer=layer,
@@ -91,7 +93,9 @@ def apply(
use_grouped_topk=use_grouped_topk,
topk_group=topk_group,
num_expert_group=num_expert_group,
custom_routing_function=custom_routing_function)
custom_routing_function=custom_routing_function,
scoring_func=scoring_func,
e_score_correction_bias=e_score_correction_bias)

def forward_cuda(
self,
@@ -103,7 +107,9 @@ def forward_cuda(
renormalize: bool,
topk_group: Optional[int] = None,
num_expert_group: Optional[int] = None,
custom_routing_function: Optional[Callable] = None
custom_routing_function: Optional[Callable] = None,
scoring_func: str = "softmax",
e_score_correction_bias: Optional[torch.Tensor] = None
) -> torch.Tensor:
topk_weights, topk_ids = FusedMoE.select_experts(
hidden_states=x,
@@ -113,7 +119,9 @@ def forward_cuda(
renormalize=renormalize,
topk_group=topk_group,
num_expert_group=num_expert_group,
custom_routing_function=custom_routing_function)
custom_routing_function=custom_routing_function,
scoring_func=scoring_func,
e_score_correction_bias=e_score_correction_bias)

return fused_experts(hidden_states=x,
w1=layer.w13_weight,
@@ -136,7 +144,8 @@ def forward_tpu(
renormalize: bool,
topk_group: Optional[int] = None,
num_expert_group: Optional[int] = None,
custom_routing_function: Optional[Callable] = None
custom_routing_function: Optional[Callable] = None,
**kwargs,
) -> torch.Tensor:
assert not use_grouped_topk
assert num_expert_group is None
@@ -190,6 +199,7 @@ def __init__(
prefix: str = "",
custom_routing_function: Optional[Callable] = None,
scoring_func: str = "softmax",
e_score_correction_bias: Optional[torch.Tensor] = None,
):
super().__init__()

@@ -210,9 +220,12 @@ def __init__(
self.topk_group = topk_group
self.custom_routing_function = custom_routing_function
self.scoring_func = scoring_func
self.e_score_correction_bias = e_score_correction_bias

if self.scoring_func != "softmax" and not self.use_grouped_topk:
raise ValueError("Only softmax scoring function is supported for non-grouped topk.")
raise ValueError(
"Only softmax scoring function is supported for non-grouped topk."

)

if quant_config is None:
self.quant_method: Optional[QuantizeMethodBase] = (
@@ -447,7 +460,8 @@ def select_experts(hidden_states: torch.Tensor,
topk_group: Optional[int] = None,
num_expert_group: Optional[int] = None,
custom_routing_function: Optional[Callable] = None,
scoring_func: str = "softmax"):
scoring_func: str = "softmax",
e_score_correction_bias: Optional[torch.Tensor] = None):
from vllm.model_executor.layers.fused_moe.fused_moe import (
fused_topk, grouped_topk)

@@ -462,7 +476,8 @@ def select_experts(hidden_states: torch.Tensor,
renormalize=renormalize,
num_expert_group=num_expert_group,
topk_group=topk_group,
scoring_func=scoring_func)
scoring_func=scoring_func,
e_score_correction_bias=e_score_correction_bias)
elif custom_routing_function is None:
topk_weights, topk_ids = fused_topk(hidden_states=hidden_states,
gating_output=router_logits,
@@ -491,7 +506,9 @@ def forward(self, hidden_states: torch.Tensor,
use_grouped_topk=self.use_grouped_topk,
topk_group=self.topk_group,
num_expert_group=self.num_expert_group,
custom_routing_function=self.custom_routing_function)
custom_routing_function=self.custom_routing_function,
scoring_func=self.scoring_func,
e_score_correction_bias=self.e_score_correction_bias)

if self.reduce_results and self.tp_size > 1:
final_hidden_states = tensor_model_parallel_all_reduce(
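Putting the layer.py changes together: `select_experts` now forwards the two new keyword arguments into `grouped_topk`, while the plain and custom routing paths are unchanged (and `forward_tpu` simply swallows the extras via `**kwargs`). Below is a hedged sketch of that dispatch, written as a free function against the kernels imported in the hunk above rather than as the actual `FusedMoE` static method.

```python
from typing import Callable, Optional

import torch

from vllm.model_executor.layers.fused_moe.fused_moe import (fused_topk,
                                                             grouped_topk)


def route_sketch(hidden_states: torch.Tensor,
                 router_logits: torch.Tensor,
                 top_k: int,
                 use_grouped_topk: bool,
                 renormalize: bool,
                 topk_group: Optional[int] = None,
                 num_expert_group: Optional[int] = None,
                 custom_routing_function: Optional[Callable] = None,
                 scoring_func: str = "softmax",
                 e_score_correction_bias: Optional[torch.Tensor] = None):
    if use_grouped_topk:
        # DeepSeek-V2/V3 path: group-limited routing, now with the optional
        # scoring function and selection-bias arguments threaded through.
        assert topk_group is not None and num_expert_group is not None
        return grouped_topk(hidden_states=hidden_states,
                            gating_output=router_logits,
                            topk=top_k,
                            renormalize=renormalize,
                            num_expert_group=num_expert_group,
                            topk_group=topk_group,
                            scoring_func=scoring_func,
                            e_score_correction_bias=e_score_correction_bias)
    if custom_routing_function is None:
        # Default softmax top-k routing.
        return fused_topk(hidden_states=hidden_states,
                          gating_output=router_logits,
                          topk=top_k,
                          renormalize=renormalize)
    return custom_routing_function(hidden_states=hidden_states,
                                   gating_output=router_logits,
                                   topk=top_k,
                                   renormalize=renormalize)
```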
64 changes: 43 additions & 21 deletions vllm/model_executor/models/deepseek_v3.py
@@ -90,6 +90,28 @@ def forward(self, x):
return x


class MoEGate(nn.Module):

def __init__(
self,
config: PretrainedConfig,
):
super().__init__()
# TODO(simon): make this replicated linear
self.weight = nn.Parameter(
torch.empty(config.n_routed_experts, config.hidden_size))
if config.topk_method == "noaux_tc":
self.e_score_correction_bias = nn.Parameter(
torch.empty((config.n_routed_experts)))

else:
self.e_score_correction_bias = None

def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
return torch.nn.functional.linear(hidden_states,
self.weight,
bias=None)


class DeepseekV3MoE(nn.Module):

def __init__(
@@ -112,24 +134,22 @@ def __init__(
raise ValueError(f"Unsupported activation: {config.hidden_act}. "
"Only silu is supported for now.")

self.experts = FusedMoE(num_experts=config.n_routed_experts,
top_k=config.num_experts_per_tok,
hidden_size=config.hidden_size,
intermediate_size=config.moe_intermediate_size,
reduce_results=False,
renormalize=config.norm_topk_prob,
quant_config=quant_config,
use_grouped_topk=True,
num_expert_group=config.n_group,
topk_group=config.topk_group,
prefix=f"{prefix}.experts",
scoring_func=config.scoring_func)

self.gate = ReplicatedLinear(config.hidden_size,
config.n_routed_experts,
bias=False,
quant_config=None,
prefix=f"{prefix}.gate")
self.gate = MoEGate(config)
self.experts = FusedMoE(
num_experts=config.n_routed_experts,
top_k=config.num_experts_per_tok,
hidden_size=config.hidden_size,
intermediate_size=config.moe_intermediate_size,
reduce_results=False,
renormalize=config.norm_topk_prob,
quant_config=quant_config,
use_grouped_topk=True,
num_expert_group=config.n_group,
topk_group=config.topk_group,
prefix=f"{prefix}.experts",
scoring_func=config.scoring_func,
e_score_correction_bias=self.gate.e_score_correction_bias)

if config.n_shared_experts is not None:
intermediate_size = (config.moe_intermediate_size *
config.n_shared_experts)
@@ -147,7 +167,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
if self.n_shared_experts is not None:
shared_output = self.shared_experts(hidden_states)
# router_logits: (num_tokens, n_experts)
router_logits, _ = self.gate(hidden_states)
router_logits = self.gate(hidden_states)
final_hidden_states = self.experts(
hidden_states=hidden_states,
router_logits=router_logits) * self.routed_scaling_factor
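Two details of the gate change are worth spelling out: unlike `ReplicatedLinear`, whose forward returns an `(output, output_bias)` tuple, `MoEGate` returns the logits tensor directly, which is why the forward above drops the tuple unpacking; and the optional `e_score_correction_bias` lives on the gate so it can be handed to `FusedMoE` for expert selection. A small shape sketch with illustrative sizes (not taken from the real config):

```python
import torch

# Illustrative sizes only; the real values come from the DeepSeek-V3 config.
num_tokens, hidden_size, n_routed_experts = 4, 32, 16

gate_weight = torch.randn(n_routed_experts, hidden_size)   # stands in for MoEGate.weight
hidden_states = torch.randn(num_tokens, hidden_size)

# MoEGate.forward is a bias-free linear projection to per-expert logits,
# so the caller gets a single tensor back (no (output, bias) tuple).
router_logits = torch.nn.functional.linear(hidden_states, gate_weight, bias=None)
assert router_logits.shape == (num_tokens, n_routed_experts)
```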
@@ -244,8 +264,7 @@ def __init__(
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.o_proj")
if rope_scaling:
rope_scaling["rope_type"] = 'deepseek_yarn'
rope_scaling["rope_type"] = 'deepseek_yarn'
self.rotary_emb = get_rope(qk_rope_head_dim,
rotary_dim=qk_rope_head_dim,
max_position=max_position_embeddings,
@@ -624,6 +643,9 @@ def load_weights(self, weights: Iterable[Tuple[str,
if is_pp_missing_parameter(name, self):
continue

if name not in params_dict:
for key in params_dict:
print(key)
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
