diff --git a/python/sglang/srt/models/grok.py b/python/sglang/srt/models/grok.py
index 2b52e2b1bcc..e74a0bc5a54 100644
--- a/python/sglang/srt/models/grok.py
+++ b/python/sglang/srt/models/grok.py
@@ -25,9 +25,11 @@
 from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.rotary_embedding import get_rope
 
+from sglang.srt.layers.activation import GeluAndMul
 from sglang.srt.layers.fused_moe_triton import FusedMoE
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.linear import (
+    MergedColumnParallelLinear,
     QKVParallelLinear,
     ReplicatedLinear,
     RowParallelLinear,
@@ -40,10 +42,43 @@
     VocabParallelEmbedding,
 )
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
-from sglang.srt.model_loader.loader import DefaultModelLoader
 from sglang.srt.model_loader.weight_utils import default_weight_loader
 
 
+class Grok1MLP(nn.Module):
+    def __init__(
+        self,
+        hidden_size: int,
+        intermediate_size: int,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+        reduce_results=True,
+    ) -> None:
+        super().__init__()
+        self.gate_up_proj = MergedColumnParallelLinear(
+            hidden_size,
+            [intermediate_size] * 2,
+            bias=False,
+            quant_config=quant_config,
+            prefix=f"{prefix}.gate_up_proj",
+        )
+        self.down_proj = RowParallelLinear(
+            intermediate_size,
+            hidden_size,
+            bias=False,
+            quant_config=quant_config,
+            prefix=f"{prefix}.down_proj",
+            reduce_results=reduce_results,
+        )
+        self.act_fn = GeluAndMul(approximate="tanh")
+
+    def forward(self, x):
+        gate_up, _ = self.gate_up_proj(x)
+        x = self.act_fn(gate_up)
+        x, _ = self.down_proj(x)
+        return x
+
+
 class Grok1MoE(nn.Module):
     """A tensor-parallel MoE implementation for Grok1 that shards each expert
     across all ranks.
@@ -55,6 +90,7 @@ class Grok1MoE(nn.Module):
 
     def __init__(
         self,
+        config: PretrainedConfig,
         num_experts: int,
         top_k: int,
         hidden_size: int,
@@ -62,6 +98,7 @@ def __init__(
         params_dtype: Optional[torch.dtype] = None,
         quant_config: Optional[QuantizationConfig] = None,
         tp_size: Optional[int] = None,
+        reduce_results=True,
     ):
         super().__init__()
         self.hidden_size = hidden_size
@@ -75,13 +112,16 @@ def __init__(
             quant_config=None,
         )
 
+        self.router_logit_softcapping = getattr(
+            config, "router_logit_softcapping", 30.0
+        )
         self.experts = FusedMoE(
             num_experts=num_experts,
             top_k=top_k,
             hidden_size=hidden_size,
             intermediate_size=intermediate_size,
             params_dtype=params_dtype,
-            reduce_results=True,
+            reduce_results=reduce_results,
             renormalize=False,
             quant_config=quant_config,
             tp_size=tp_size,
@@ -91,9 +131,12 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         # NOTE: hidden_states can have either 1D or 2D shape.
         orig_shape = hidden_states.shape
         hidden_states = hidden_states.view(-1, self.hidden_size)
+        # router_logits: (num_tokens, n_experts)
         router_logits, _ = self.gate(hidden_states)
         router_logits = 30.0 * F.tanh(router_logits / 30.0)
+
+        # need to assert self.gate.quant_method is unquantized
         final_hidden_states = self.experts(hidden_states, router_logits)
         return final_hidden_states.view(orig_shape)
 
@@ -101,16 +144,18 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
 class Grok1Attention(nn.Module):
     def __init__(
         self,
+        config: PretrainedConfig,
         hidden_size: int,
         num_heads: int,
         num_kv_heads: int,
         layer_id: int = 0,
         max_position: int = 4096 * 32,
         rope_theta: float = 10000,
-        logit_cap: float = 30,
         quant_config: Optional[QuantizationConfig] = None,
     ) -> None:
         super().__init__()
+        self.config = config
+        self.layer_id = layer_id
         self.hidden_size = hidden_size
         tp_size = get_tensor_model_parallel_world_size()
         self.total_num_heads = num_heads
@@ -126,7 +171,7 @@ def __init__(
             # the KV heads across multiple tensor parallel GPUs.
             assert tp_size % self.total_num_kv_heads == 0
         self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
-        self.head_dim = 128
+        self.head_dim = getattr(config, "head_dim", 128)
         self.q_size = self.num_heads * self.head_dim
         self.kv_size = self.num_kv_heads * self.head_dim
         self.scaling = self.head_dim**-0.5
@@ -140,7 +185,6 @@ def __init__(
             bias=False,
             quant_config=quant_config,
         )
-
         self.o_proj = RowParallelLinear(
             self.total_num_heads * self.head_dim,
             hidden_size,
@@ -154,6 +198,9 @@ def __init__(
             base=int(self.rope_theta),
             is_neox_style=True,
         )
+
+        logit_cap = max(getattr(config, "attn_logit_softcapping", 30.0), 0.0)
+
         self.attn = RadixAttention(
             self.num_heads,
             self.head_dim,
@@ -162,7 +209,6 @@ def __init__(
             layer_id=layer_id,
             logit_cap=logit_cap,
         )
-        # TODO(lianmin): load logit cap from config
 
     def forward(
         self,
@@ -186,10 +232,12 @@ def __init__(
         quant_config: Optional[QuantizationConfig] = None,
     ) -> None:
         super().__init__()
+        self.num_experts = config.num_local_experts
         self.hidden_size = config.hidden_size
         rope_theta = getattr(config, "rope_theta", 10000)
         self.self_attn = Grok1Attention(
+            config=config,
             hidden_size=self.hidden_size,
             num_heads=config.num_attention_heads,
             max_position=config.max_position_embeddings,
@@ -199,11 +247,17 @@ def __init__(
             quant_config=quant_config,
         )
         self.block_sparse_moe = Grok1MoE(
+            config=config,
             num_experts=config.num_local_experts,
             top_k=config.num_experts_per_tok,
             hidden_size=config.hidden_size,
-            intermediate_size=config.intermediate_size,
+            intermediate_size=getattr(
+                config,
+                "moe_intermediate_size",
+                getattr(config, "intermediate_size", None),
+            ),
             quant_config=quant_config,
+            reduce_results=True,
         )
         self.pre_attn_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
         self.post_attn_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
@@ -284,6 +338,7 @@ def __init__(
         self,
         config: PretrainedConfig,
         quant_config: Optional[QuantizationConfig] = None,
+        cache_config=None,
     ) -> None:
         super().__init__()
         self.config = config
@@ -310,6 +365,8 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
             ("qkv_proj", "q_proj", "q"),
             ("qkv_proj", "k_proj", "k"),
             ("qkv_proj", "v_proj", "v"),
+            ("gate_up_proj", "gate_proj", 0),
+            ("gate_up_proj", "up_proj", 1),
         ]
 
         # Params for weights, fp8 weight scales, fp8 activation scales
@@ -345,6 +402,11 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
                         continue
                     name = name.replace(weight_name, param_name)
+                    if (
+                        name.endswith(".bias") or name.endswith("_bias")
+                    ) and name not in params_dict:
+                        continue
+
                     param = params_dict[name]
                     weight_loader = param.weight_loader
                     weight_loader(
@@ -357,7 +419,9 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
                     break
                 else:
                     # Skip loading extra bias for GPTQ models.
-                    if name.endswith(".bias") and name not in params_dict:
+                    if (
+                        name.endswith(".bias") or name.endswith("_bias")
+                    ) and name not in params_dict:
                         continue
                     # Skip loading kv_scale from ckpts towards new design.
                     if name.endswith(".kv_scale") and name not in params_dict:
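For reviewers, here is a minimal single-GPU sketch (not part of the diff, and not sglang's implementation) of what the new tensor-parallel `Grok1MLP` computes. It assumes `MergedColumnParallelLinear(hidden_size, [intermediate_size] * 2)` behaves like one fused linear whose output splits into gate and up halves, and that `GeluAndMul(approximate="tanh")` applies the tanh-approximated GELU to the gate half and multiplies it with the up half. The class name `ReferenceGrokMLP` is hypothetical.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class ReferenceGrokMLP(nn.Module):
    """Plain-PyTorch reference for the gated-GELU MLP added in this diff (illustrative only)."""

    def __init__(self, hidden_size: int, intermediate_size: int) -> None:
        super().__init__()
        # Stand-in for MergedColumnParallelLinear(hidden_size, [intermediate_size] * 2)
        self.gate_up_proj = nn.Linear(hidden_size, 2 * intermediate_size, bias=False)
        # Stand-in for RowParallelLinear(intermediate_size, hidden_size)
        self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Split the fused projection into gate and up halves.
        gate, up = self.gate_up_proj(x).chunk(2, dim=-1)
        # GeluAndMul(approximate="tanh"): gelu_tanh(gate) * up, then project back down.
        return self.down_proj(F.gelu(gate, approximate="tanh") * up)


if __name__ == "__main__":
    mlp = ReferenceGrokMLP(hidden_size=64, intermediate_size=256)
    print(mlp(torch.randn(2, 64)).shape)  # torch.Size([2, 64])
```

The `router_logit_softcapping` value read from the config (default 30.0) corresponds to the existing `30.0 * F.tanh(router_logits / 30.0)` clamp in `Grok1MoE.forward`, which bounds router logits to (-30, 30) before expert selection, analogous to the `attn_logit_softcapping` cap now read from the config in `Grok1Attention`.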