From 63d6c0e7fad5d05eec9127194af123251ba5ffae Mon Sep 17 00:00:00 2001
From: Yida Wu <yida.wu@amd.com>
Date: Mon, 6 Jan 2025 22:25:00 +0000
Subject: [PATCH] Fix activation overflow in DeepSeek-V2 MoE scaling

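DeepseekV2MoE.forward used to multiply the routed-expert output by
routed_scaling_factor before adding the shared-expert output. With a
large scaling factor, that multiplication can overflow in reduced
precision (float16 tops out around 65504), yielding inf/nan
activations.

Rather than scaling the routed output up, keep the hidden stream a
factor of routed_scaling_factor smaller: divide the shared-expert
output by routed_scaling_factor inside the MoE block, pre-scale the
attention output before the post-attention layernorm on MoE layers,
and scale both hidden states and residual after dense-MLP layers.
RMSNorm is scale-invariant (up to its epsilon), so the normalized
inputs to the attention and MLP blocks are unchanged.

At the MoE combine step the rewrite is a pure rescaling:

    s * routed + shared == s * (routed + shared * (1. / s))

A minimal standalone sketch of the failure mode and of this identity
(the value of s is illustrative, not taken from a real config):

    import torch

    s = 16.0  # stands in for routed_scaling_factor
    x = torch.tensor(5000.0, dtype=torch.float16)
    print(x * s)  # inf: 80000 exceeds the float16 maximum (~65504)

    routed = torch.randn(4, 8)  # mock routed-expert output
    shared = torch.randn(4, 8)  # mock shared-expert output
    old = routed * s + shared               # pre-patch combine
    new = (routed + shared * (1. / s)) * s  # post-patch, rescaled back
    assert torch.allclose(old, new)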
---
 vllm/model_executor/models/deepseek_v2.py | 21 +++++++++++++++++----
 1 file changed, 17 insertions(+), 4 deletions(-)

diff --git a/vllm/model_executor/models/deepseek_v2.py b/vllm/model_executor/models/deepseek_v2.py
index 4cf4e6c358bf2..d63bfc72e1a26 100644
--- a/vllm/model_executor/models/deepseek_v2.py
+++ b/vllm/model_executor/models/deepseek_v2.py
@@ -147,11 +147,13 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
             shared_output = self.shared_experts(hidden_states)
         # router_logits: (num_tokens, n_experts)
         router_logits, _ = self.gate(hidden_states)
-        final_hidden_states = self.experts(
-            hidden_states=hidden_states,
-            router_logits=router_logits) * self.routed_scaling_factor
+        final_hidden_states = self.experts(hidden_states=hidden_states,
+                                           router_logits=router_logits)
         if shared_output is not None:
-            final_hidden_states = final_hidden_states + shared_output
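+            # Scale the shared-expert branch down instead of scaling the
+            # routed output up, keeping intermediate activations in range.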
+            final_hidden_states = final_hidden_states + shared_output \
+                * (1. / self.routed_scaling_factor)
         if self.tp_size > 1:
             final_hidden_states = tensor_model_parallel_all_reduce(
                 final_hidden_states)
@@ -375,6 +377,7 @@ def __init__(
                                        eps=config.rms_norm_eps)
         self.post_attention_layernorm = RMSNorm(config.hidden_size,
                                                 eps=config.rms_norm_eps)
+        self.routed_scaling_factor = config.routed_scaling_factor
 
     def forward(
         self,
@@ -399,9 +402,19 @@ def forward(
         )
 
         # Fully Connected
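+        # DeepseekV2MoE no longer multiplies its output by
+        # routed_scaling_factor, so pre-scale the attention output to keep
+        # the residual stream consistently scaled down.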
+        if isinstance(self.mlp, DeepseekV2MoE):
+            hidden_states *= 1. / self.mlp.routed_scaling_factor
         hidden_states, residual = self.post_attention_layernorm(
             hidden_states, residual)
         hidden_states = self.mlp(hidden_states)
+        if isinstance(self.mlp, DeepseekV2MLP):
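+            # Dense-MLP output is not scaled inside the block; scale it and
+            # the residual here to stay on the same scaled-down stream.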
+            hidden_states *= 1. / self.routed_scaling_factor
+            residual *= 1. / self.routed_scaling_factor
         return hidden_states, residual