[Minor] Fix grok model loader (#2473)
merrymercy authored Dec 12, 2024
1 parent f0ed9c3 commit 5282a47
Showing 1 changed file with 72 additions and 8 deletions.
80 changes: 72 additions & 8 deletions python/sglang/srt/models/grok.py
@@ -25,9 +25,11 @@
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.rotary_embedding import get_rope

from sglang.srt.layers.activation import GeluAndMul
from sglang.srt.layers.fused_moe_triton import FusedMoE
from sglang.srt.layers.layernorm import RMSNorm
from sglang.srt.layers.linear import (
MergedColumnParallelLinear,
QKVParallelLinear,
ReplicatedLinear,
RowParallelLinear,
@@ -40,10 +42,43 @@
VocabParallelEmbedding,
)
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_loader.loader import DefaultModelLoader
from sglang.srt.model_loader.weight_utils import default_weight_loader


class Grok1MLP(nn.Module):
def __init__(
self,
hidden_size: int,
intermediate_size: int,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
reduce_results=True,
) -> None:
super().__init__()
self.gate_up_proj = MergedColumnParallelLinear(
hidden_size,
[intermediate_size] * 2,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.gate_up_proj",
)
self.down_proj = RowParallelLinear(
intermediate_size,
hidden_size,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.down_proj",
reduce_results=reduce_results,
)
self.act_fn = GeluAndMul(approximate="tanh")

def forward(self, x):
gate_up, _ = self.gate_up_proj(x)
x = self.act_fn(gate_up)
x, _ = self.down_proj(x)
return x
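
The new Grok1MLP above is a standard gated-GELU block: gate_up_proj produces the gate and up halves in a single fused matmul, GeluAndMul applies tanh-approximate GELU to the gate half and multiplies it by the up half, and down_proj maps the result back to hidden_size. A rough single-GPU sketch of the same computation in plain PyTorch (weight names are illustrative; tensor parallelism and quantization are ignored):

import torch
import torch.nn.functional as F

def gated_gelu_mlp(x, w_gate, w_up, w_down):
    # gate_up_proj fuses these two projections along the output dimension
    gate = x @ w_gate.t()
    up = x @ w_up.t()
    # GeluAndMul(approximate="tanh") computes gelu(gate) * up; down_proj follows
    return (F.gelu(gate, approximate="tanh") * up) @ w_down.t()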


class Grok1MoE(nn.Module):
"""A tensor-parallel MoE implementation for Grok1 that shards each expert
across all ranks.
@@ -55,13 +90,15 @@ class Grok1MoE(nn.Module):

def __init__(
self,
config: PretrainedConfig,
num_experts: int,
top_k: int,
hidden_size: int,
intermediate_size: int,
params_dtype: Optional[torch.dtype] = None,
quant_config: Optional[QuantizationConfig] = None,
tp_size: Optional[int] = None,
reduce_results=True,
):
super().__init__()
self.hidden_size = hidden_size
@@ -75,13 +112,16 @@ def __init__(
quant_config=None,
)

self.router_logit_softcapping = getattr(
config, "router_logit_softcapping", 30.0
)
self.experts = FusedMoE(
num_experts=num_experts,
top_k=top_k,
hidden_size=hidden_size,
intermediate_size=intermediate_size,
params_dtype=params_dtype,
reduce_results=True,
reduce_results=reduce_results,
renormalize=False,
quant_config=quant_config,
tp_size=tp_size,
@@ -91,26 +131,31 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
# NOTE: hidden_states can have either 1D or 2D shape.
orig_shape = hidden_states.shape
hidden_states = hidden_states.view(-1, self.hidden_size)

# router_logits: (num_tokens, n_experts)
router_logits, _ = self.gate(hidden_states)
router_logits = 30.0 * F.tanh(router_logits / 30.0)

# need to assert self.gate.quant_method is unquantized
final_hidden_states = self.experts(hidden_states, router_logits)
return final_hidden_states.view(orig_shape)
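
The router soft-caps its logits with a tanh before expert selection, smoothly bounding them to (-30, 30) instead of hard clipping; __init__ now reads the cap from router_logit_softcapping (default 30.0). A minimal sketch of the transform (function name is illustrative):

import torch

def softcap(logits: torch.Tensor, cap: float = 30.0) -> torch.Tensor:
    # Approximately the identity near zero, saturating toward +/-cap for large
    # values; matches the 30.0 * F.tanh(router_logits / 30.0) applied above.
    return cap * torch.tanh(logits / cap)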


class Grok1Attention(nn.Module):
def __init__(
self,
config: PretrainedConfig,
hidden_size: int,
num_heads: int,
num_kv_heads: int,
layer_id: int = 0,
max_position: int = 4096 * 32,
rope_theta: float = 10000,
logit_cap: float = 30,
quant_config: Optional[QuantizationConfig] = None,
) -> None:
super().__init__()
self.config = config
self.layer_id = layer_id
self.hidden_size = hidden_size
tp_size = get_tensor_model_parallel_world_size()
self.total_num_heads = num_heads
@@ -126,7 +171,7 @@ def __init__(
# the KV heads across multiple tensor parallel GPUs.
assert tp_size % self.total_num_kv_heads == 0
self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
self.head_dim = 128
self.head_dim = getattr(config, "head_dim", 128)
self.q_size = self.num_heads * self.head_dim
self.kv_size = self.num_kv_heads * self.head_dim
self.scaling = self.head_dim**-0.5
@@ -140,7 +185,6 @@ def __init__(
bias=False,
quant_config=quant_config,
)

self.o_proj = RowParallelLinear(
self.total_num_heads * self.head_dim,
hidden_size,
@@ -154,6 +198,9 @@ def __init__(
base=int(self.rope_theta),
is_neox_style=True,
)

logit_cap = max(getattr(config, "attn_logit_softcapping", 30.0), 0.0)

self.attn = RadixAttention(
self.num_heads,
self.head_dim,
@@ -162,7 +209,6 @@ def __init__(
layer_id=layer_id,
logit_cap=logit_cap,
)
# TODO(lianmin): load logit cap from config

def forward(
self,
@@ -186,10 +232,12 @@ def __init__(
quant_config: Optional[QuantizationConfig] = None,
) -> None:
super().__init__()
self.num_experts = config.num_local_experts
self.hidden_size = config.hidden_size

rope_theta = getattr(config, "rope_theta", 10000)
self.self_attn = Grok1Attention(
config=config,
hidden_size=self.hidden_size,
num_heads=config.num_attention_heads,
max_position=config.max_position_embeddings,
@@ -199,11 +247,17 @@ def __init__(
quant_config=quant_config,
)
self.block_sparse_moe = Grok1MoE(
config=config,
num_experts=config.num_local_experts,
top_k=config.num_experts_per_tok,
hidden_size=config.hidden_size,
intermediate_size=config.intermediate_size,
intermediate_size=getattr(
config,
"moe_intermediate_size",
getattr(config, "intermediate_size", None),
),
quant_config=quant_config,
reduce_results=True,
)
self.pre_attn_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.post_attn_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
@@ -284,6 +338,7 @@ def __init__(
self,
config: PretrainedConfig,
quant_config: Optional[QuantizationConfig] = None,
cache_config=None,
) -> None:
super().__init__()
self.config = config
@@ -310,6 +365,8 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
("qkv_proj", "q_proj", "q"),
("qkv_proj", "k_proj", "k"),
("qkv_proj", "v_proj", "v"),
("gate_up_proj", "gate_proj", 0),
("gate_up_proj", "up_proj", 1),
]
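
The two new gate_up_proj entries let the loader fold separate gate_proj and up_proj checkpoint tensors into the single fused weight created by MergedColumnParallelLinear, with shard ids 0 and 1 selecting the first and second halves of the output dimension. Conceptually (a sketch with illustrative sizes, ignoring tensor parallelism and quantization):

import torch

hidden_size, intermediate_size = 8, 16                # illustrative sizes
w_gate = torch.randn(intermediate_size, hidden_size)  # checkpoint "gate_proj" weight
w_up = torch.randn(intermediate_size, hidden_size)    # checkpoint "up_proj" weight

# shard_id 0 fills the first half of the fused weight, shard_id 1 the second half
w_gate_up = torch.cat([w_gate, w_up], dim=0)          # shape: (2 * intermediate_size, hidden_size)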

# Params for weights, fp8 weight scales, fp8 activation scales
Expand Down Expand Up @@ -345,6 +402,11 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
continue
name = name.replace(weight_name, param_name)

if (
name.endswith(".bias") or name.endswith("_bias")
) and name not in params_dict:
continue

param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(
@@ -357,7 +419,9 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
break
else:
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
if (
name.endswith(".bias") or name.endswith("_bias")
) and name not in params_dict:
continue
# Skip loading kv_scale from ckpts towards new design.
if name.endswith(".kv_scale") and name not in params_dict:
