[Minor] Fix grok model loader #2473

Merged
merged 1 commit on Dec 12, 2024
80 changes: 72 additions & 8 deletions python/sglang/srt/models/grok.py
@@ -25,9 +25,11 @@
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.rotary_embedding import get_rope

from sglang.srt.layers.activation import GeluAndMul
from sglang.srt.layers.fused_moe_triton import FusedMoE
from sglang.srt.layers.layernorm import RMSNorm
from sglang.srt.layers.linear import (
MergedColumnParallelLinear,
QKVParallelLinear,
ReplicatedLinear,
RowParallelLinear,
@@ -40,10 +42,43 @@
VocabParallelEmbedding,
)
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_loader.loader import DefaultModelLoader
from sglang.srt.model_loader.weight_utils import default_weight_loader


class Grok1MLP(nn.Module):
def __init__(
self,
hidden_size: int,
intermediate_size: int,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
reduce_results=True,
) -> None:
super().__init__()
self.gate_up_proj = MergedColumnParallelLinear(
hidden_size,
[intermediate_size] * 2,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.gate_up_proj",
)
self.down_proj = RowParallelLinear(
intermediate_size,
hidden_size,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.down_proj",
reduce_results=reduce_results,
)
self.act_fn = GeluAndMul(approximate="tanh")

def forward(self, x):
gate_up, _ = self.gate_up_proj(x)
x = self.act_fn(gate_up)
x, _ = self.down_proj(x)
return x
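
# --- Reviewer sketch, not part of this diff ---
# Unfused reference for Grok1MLP.forward: gate_up_proj stacks the gate and up
# projections (gate_proj -> shard 0, up_proj -> shard 1, matching the new
# stacked_params_mapping entries below), and GeluAndMul splits that output in
# half and computes gelu(gate, approximate="tanh") * up before the down
# projection. Plain-torch sketch with no tensor parallelism; the weight names
# are illustrative and it reuses the torch / F imports already in this file.
def grok1_mlp_reference(x, w_gate, w_up, w_down):
    gate = x @ w_gate.t()  # [*, intermediate_size]
    up = x @ w_up.t()      # [*, intermediate_size]
    return (F.gelu(gate, approximate="tanh") * up) @ w_down.t()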


class Grok1MoE(nn.Module):
"""A tensor-parallel MoE implementation for Grok1 that shards each expert
across all ranks.
@@ -55,13 +90,15 @@ class Grok1MoE(nn.Module):

def __init__(
self,
config: PretrainedConfig,
num_experts: int,
top_k: int,
hidden_size: int,
intermediate_size: int,
params_dtype: Optional[torch.dtype] = None,
quant_config: Optional[QuantizationConfig] = None,
tp_size: Optional[int] = None,
reduce_results=True,
):
super().__init__()
self.hidden_size = hidden_size
@@ -75,13 +112,16 @@ def __init__(
quant_config=None,
)

self.router_logit_softcapping = getattr(
config, "router_logit_softcapping", 30.0
)
self.experts = FusedMoE(
num_experts=num_experts,
top_k=top_k,
hidden_size=hidden_size,
intermediate_size=intermediate_size,
params_dtype=params_dtype,
reduce_results=True,
reduce_results=reduce_results,
renormalize=False,
quant_config=quant_config,
tp_size=tp_size,
@@ -91,26 +131,31 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
# NOTE: hidden_states can have either 1D or 2D shape.
orig_shape = hidden_states.shape
hidden_states = hidden_states.view(-1, self.hidden_size)

# router_logits: (num_tokens, n_experts)
router_logits, _ = self.gate(hidden_states)
router_logits = 30.0 * F.tanh(router_logits / 30.0)

# need to assert self.gate.quant_method is unquantized
final_hidden_states = self.experts(hidden_states, router_logits)
return final_hidden_states.view(orig_shape)
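
# --- Reviewer sketch, not part of this diff ---
# Tanh soft-capping of the router logits: capped = cap * tanh(logits / cap)
# bounds the logits to (-cap, cap) while staying nearly linear around zero.
# Note that forward() above still hard-codes cap = 30.0 even though
# self.router_logit_softcapping now reads the value from the config.
def softcap(logits: torch.Tensor, cap: float = 30.0) -> torch.Tensor:
    return cap * torch.tanh(logits / cap)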


class Grok1Attention(nn.Module):
def __init__(
self,
config: PretrainedConfig,
hidden_size: int,
num_heads: int,
num_kv_heads: int,
layer_id: int = 0,
max_position: int = 4096 * 32,
rope_theta: float = 10000,
logit_cap: float = 30,
quant_config: Optional[QuantizationConfig] = None,
) -> None:
super().__init__()
self.config = config
self.layer_id = layer_id
self.hidden_size = hidden_size
tp_size = get_tensor_model_parallel_world_size()
self.total_num_heads = num_heads
@@ -126,7 +171,7 @@ def __init__(
# the KV heads across multiple tensor parallel GPUs.
assert tp_size % self.total_num_kv_heads == 0
self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
self.head_dim = 128
self.head_dim = getattr(config, "head_dim", 128)
self.q_size = self.num_heads * self.head_dim
self.kv_size = self.num_kv_heads * self.head_dim
self.scaling = self.head_dim**-0.5
@@ -140,7 +185,6 @@ def __init__(
bias=False,
quant_config=quant_config,
)

self.o_proj = RowParallelLinear(
self.total_num_heads * self.head_dim,
hidden_size,
@@ -154,6 +198,9 @@ def __init__(
base=int(self.rope_theta),
is_neox_style=True,
)

logit_cap = max(getattr(config, "attn_logit_softcapping", 30.0), 0.0)

self.attn = RadixAttention(
self.num_heads,
self.head_dim,
@@ -162,7 +209,6 @@ def __init__(
layer_id=layer_id,
logit_cap=logit_cap,
)
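# --- Reviewer note, not part of this diff ---
# logit_cap is now read from config.attn_logit_softcapping (default 30.0,
# floored at 0.0) instead of being hard-coded, which is what the removed TODO
# below asked for. When logit_cap > 0, the attention kernels use it to apply
# the same cap * tanh(score / cap) soft-capping to the attention scores.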
# TODO(lianmin): load logit cap from config

def forward(
self,
@@ -186,10 +232,12 @@ def __init__(
quant_config: Optional[QuantizationConfig] = None,
) -> None:
super().__init__()
self.num_experts = config.num_local_experts
self.hidden_size = config.hidden_size

rope_theta = getattr(config, "rope_theta", 10000)
self.self_attn = Grok1Attention(
config=config,
hidden_size=self.hidden_size,
num_heads=config.num_attention_heads,
max_position=config.max_position_embeddings,
@@ -199,11 +247,17 @@ def __init__(
quant_config=quant_config,
)
self.block_sparse_moe = Grok1MoE(
config=config,
num_experts=config.num_local_experts,
top_k=config.num_experts_per_tok,
hidden_size=config.hidden_size,
intermediate_size=config.intermediate_size,
intermediate_size=getattr(
config,
"moe_intermediate_size",
getattr(config, "intermediate_size", None),
),
quant_config=quant_config,
reduce_results=True,
)
self.pre_attn_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.post_attn_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
@@ -284,6 +338,7 @@ def __init__(
self,
config: PretrainedConfig,
quant_config: Optional[QuantizationConfig] = None,
cache_config=None,
) -> None:
super().__init__()
self.config = config
@@ -310,6 +365,8 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
("qkv_proj", "q_proj", "q"),
("qkv_proj", "k_proj", "k"),
("qkv_proj", "v_proj", "v"),
("gate_up_proj", "gate_proj", 0),
("gate_up_proj", "up_proj", 1),
]
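
# --- Reviewer sketch, not part of this diff ---
# The two new ("gate_up_proj", ...) rows let checkpoints that store separate
# gate_proj / up_proj tensors be loaded into the merged gate_up_proj parameter
# of the new Grok1MLP. For each checkpoint tensor, the loop below rewrites the
# name and forwards the shard id to the merged parameter's weight_loader, e.g.
# (names illustrative):
#   "model.layers.0.mlp.up_proj.weight"
#       -> params_dict["model.layers.0.mlp.gate_up_proj.weight"], shard_id=1
#   "model.layers.0.self_attn.k_proj.weight"
#       -> params_dict["model.layers.0.self_attn.qkv_proj.weight"], shard_id="k"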

# Params for weights, fp8 weight scales, fp8 activation scales
@@ -345,6 +402,11 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
continue
name = name.replace(weight_name, param_name)

if (
name.endswith(".bias") or name.endswith("_bias")
) and name not in params_dict:
continue

param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(
@@ -357,7 +419,9 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
break
else:
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
if (
name.endswith(".bias") or name.endswith("_bias")
) and name not in params_dict:
continue
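# --- Reviewer note, not part of this diff ---
# Both this fallback branch and the stacked-params branch above now skip
# checkpoint bias tensors ending in ".bias" or "_bias" that have no matching
# model parameter (e.g. extra biases in quantized checkpoints) instead of
# raising a KeyError on params_dict[name].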
# Skip loading kv_scale from ckpts towards new design.
if name.endswith(".kv_scale") and name not in params_dict: