From 993956c6b1e4da007c6d821c1d181221410825d2 Mon Sep 17 00:00:00 2001 From: Fred Reiss Date: Wed, 11 Dec 2024 06:30:23 -0800 Subject: [PATCH] Add support for IBM Granite 3.x models (#2437) --- docs/references/supported_models.md | 1 + python/sglang/lang/chat_template.py | 32 ++ python/sglang/srt/layers/logits_processor.py | 12 +- python/sglang/srt/models/granite.py | 517 +++++++++++++++++++ test/srt/models/test_generation_models.py | 1 + 5 files changed, 562 insertions(+), 1 deletion(-) create mode 100644 python/sglang/srt/models/granite.py diff --git a/docs/references/supported_models.md b/docs/references/supported_models.md index bf1044f8498..9dafc3d2a3d 100644 --- a/docs/references/supported_models.md +++ b/docs/references/supported_models.md @@ -29,6 +29,7 @@ - SmolLM - GLM-4 - Phi-3-Small +- IBM Granite 3 ## Embedding Models diff --git a/python/sglang/lang/chat_template.py b/python/sglang/lang/chat_template.py index 3e5ac8dd522..4a774c4fb6b 100644 --- a/python/sglang/lang/chat_template.py +++ b/python/sglang/lang/chat_template.py @@ -320,6 +320,28 @@ def get_chat_template_by_model_path(model_path): ) ) +register_chat_template( + ChatTemplate( + name="granite-3-instruct", + default_system_prompt=None, + role_prefix_and_suffix={ + "system": ( + "<|start_of_role|>system<|end_of_role|>", + "<|end_of_text|>", + ), + "user": ( + "<|start_of_role|>user<|end_of_role|>", + "<|end_of_text|>", + ), + "assistant": ( + "<|start_of_role|>assistant<|end_of_role|>", + "<|end_of_text|>", + ), + }, + stop_str=("<|end_of_text|>",), + ) +) + @register_chat_template_matching_function def match_dbrx(model_path: str): @@ -402,6 +424,16 @@ def match_c4ai_command_r(model_path: str): return get_chat_template("c4ai-command-r") +@register_chat_template_matching_function +def match_granite_instruct(model_path: str): + model_path = model_path.lower() + # When future versions of Granite are released, this code may + # need to be updated. For now, assume that the Granite 3.0 + # template works across the board. + if "granite" in model_path and "instruct" in model_path: + return get_chat_template("granite-3-instruct") + + if __name__ == "__main__": messages = [ {"role": "system", "content": None}, # None means default diff --git a/python/sglang/srt/layers/logits_processor.py b/python/sglang/srt/layers/logits_processor.py index 915cb47d271..3d82592496e 100644 --- a/python/sglang/srt/layers/logits_processor.py +++ b/python/sglang/srt/layers/logits_processor.py @@ -91,9 +91,12 @@ def from_forward_batch(cls, forward_batch: ForwardBatch): class LogitsProcessor(nn.Module): - def __init__(self, config, skip_all_gather: bool = False): + def __init__( + self, config, skip_all_gather: bool = False, logit_scale: Optional[float] = None + ): super().__init__() self.config = config + self.logit_scale = logit_scale self.do_tensor_parallel_all_gather = ( not skip_all_gather and get_tensor_model_parallel_world_size() > 1 ) @@ -240,6 +243,9 @@ def forward( all_logits = self._get_logits(states, lm_head) if self.do_tensor_parallel_all_gather: all_logits = tensor_model_parallel_all_gather(all_logits) + + # The LM head's weights may be zero-padded for parallelism. Remove any + # extra logits that this padding may have produced. all_logits = all_logits[:, : self.config.vocab_size].float() if hasattr(self.config, "final_logit_softcapping"): @@ -302,6 +308,10 @@ def _get_logits( else: # GGUF models logits = lm_head.linear_method.apply(lm_head, hidden_states, embedding_bias) + + # Optional scaling factor, backported from vLLM 0.4 + if self.logit_scale is not None: + logits.mul_(self.logit_scale) # In-place multiply return logits diff --git a/python/sglang/srt/models/granite.py b/python/sglang/srt/models/granite.py new file mode 100644 index 00000000000..d207ff61b26 --- /dev/null +++ b/python/sglang/srt/models/granite.py @@ -0,0 +1,517 @@ +# Copyright 2023-2024 SGLang Team +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +# Adapted from +# https://github.com/vllm-project/vllm/blob/c7f2cf2b7f67bce5842fedfdba508440fe257375/vllm/model_executor/models/llama.py#L1 +"""Inference-only Granite model compatible with HuggingFace weights.""" + +import logging +from typing import Any, Dict, Iterable, Optional, Tuple + +import torch +from torch import nn +from transformers import GraniteConfig +from vllm.distributed import get_tensor_model_parallel_world_size +from vllm.model_executor.layers.rotary_embedding import get_rope + +from sglang.srt.layers.activation import SiluAndMul +from sglang.srt.layers.layernorm import RMSNorm +from sglang.srt.layers.linear import ( + MergedColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear, +) +from sglang.srt.layers.logits_processor import LogitsProcessor, LogitsProcessorOutput +from sglang.srt.layers.pooler import Pooler, PoolingType +from sglang.srt.layers.quantization.base_config import QuantizationConfig +from sglang.srt.layers.radix_attention import RadixAttention +from sglang.srt.layers.vocab_parallel_embedding import ( + ParallelLMHead, + VocabParallelEmbedding, +) +from sglang.srt.model_executor.forward_batch_info import ForwardBatch +from sglang.srt.model_loader.weight_utils import default_weight_loader +from sglang.utils import get_exception_traceback + +logger = logging.getLogger(__name__) + + +class GraniteMLP(nn.Module): + def __init__( + self, + hidden_size: int, + intermediate_size: int, + hidden_act: str, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.gate_up_proj = MergedColumnParallelLinear( + hidden_size, + [intermediate_size] * 2, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.gate_up_proj", + ) + self.down_proj = RowParallelLinear( + intermediate_size, + hidden_size, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.down_proj", + ) + if hidden_act != "silu": + raise ValueError( + f"Unsupported activation: {hidden_act}. " + "Only silu is supported for now." + ) + self.act_fn = SiluAndMul() + + def forward(self, x): + gate_up, _ = self.gate_up_proj(x) + x = self.act_fn(gate_up) + x, _ = self.down_proj(x) + return x + + +class GraniteAttention(nn.Module): + def __init__( + self, + config: GraniteConfig, + hidden_size: int, + num_heads: int, + num_kv_heads: int, + layer_id: int = 0, + rope_theta: float = 10000, + rope_scaling: Optional[Dict[str, Any]] = None, + rope_is_neox_style: bool = True, + max_position_embeddings: int = 8192, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.hidden_size = hidden_size + tp_size = get_tensor_model_parallel_world_size() + self.total_num_heads = num_heads + assert self.total_num_heads % tp_size == 0 + self.num_heads = self.total_num_heads // tp_size + self.total_num_kv_heads = num_kv_heads + if self.total_num_kv_heads >= tp_size: + # Number of KV heads is greater than TP size, so we partition + # the KV heads across multiple tensor parallel GPUs. + assert self.total_num_kv_heads % tp_size == 0 + else: + # Number of KV heads is less than TP size, so we replicate + # the KV heads across multiple tensor parallel GPUs. + assert tp_size % self.total_num_kv_heads == 0 + self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size) + # MistralConfig has an optional head_dim introduced by Mistral-Nemo + self.head_dim = getattr( + config, "head_dim", self.hidden_size // self.total_num_heads + ) + self.q_size = self.num_heads * self.head_dim + self.kv_size = self.num_kv_heads * self.head_dim + self.scaling = config.attention_multiplier + self.rope_theta = rope_theta + self.max_position_embeddings = max_position_embeddings + + self.qkv_proj = QKVParallelLinear( + hidden_size, + self.head_dim, + self.total_num_heads, + self.total_num_kv_heads, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.qkv_proj", + ) + self.o_proj = RowParallelLinear( + self.total_num_heads * self.head_dim, + hidden_size, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.o_proj", + ) + + self.rotary_emb = get_rope( + self.head_dim, + rotary_dim=self.head_dim, + max_position=max_position_embeddings, + base=rope_theta, + rope_scaling=rope_scaling, + is_neox_style=rope_is_neox_style, + ) + self.attn = RadixAttention( + self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads, + layer_id=layer_id, + ) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + forward_batch: ForwardBatch, + ) -> torch.Tensor: + qkv, _ = self.qkv_proj(hidden_states) + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + q, k = self.rotary_emb(positions, q, k) + attn_output = self.attn(q, k, v, forward_batch) + output, _ = self.o_proj(attn_output) + return output + + +class GraniteDecoderLayer(nn.Module): + def __init__( + self, + config: GraniteConfig, + layer_id: int = 0, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.hidden_size = config.hidden_size + self.residual_multiplier = config.residual_multiplier + rope_theta = getattr(config, "rope_theta", 10000) + rope_scaling = getattr(config, "rope_scaling", None) + if rope_scaling is not None and getattr( + config, "original_max_position_embeddings", None + ): + rope_scaling["original_max_position_embeddings"] = ( + config.original_max_position_embeddings + ) + rope_is_neox_style = getattr(config, "rope_is_neox_style", True) + max_position_embeddings = getattr(config, "max_position_embeddings", 8192) + self.self_attn = GraniteAttention( + config=config, + hidden_size=self.hidden_size, + num_heads=config.num_attention_heads, + num_kv_heads=config.num_key_value_heads, + layer_id=layer_id, + rope_theta=rope_theta, + rope_scaling=rope_scaling, + rope_is_neox_style=rope_is_neox_style, + max_position_embeddings=max_position_embeddings, + quant_config=quant_config, + prefix=f"{prefix}.self_attn", + ) + self.mlp = GraniteMLP( + hidden_size=self.hidden_size, + intermediate_size=config.intermediate_size, + hidden_act=config.hidden_act, + quant_config=quant_config, + prefix=f"{prefix}.mlp", + ) + self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = RMSNorm( + config.hidden_size, eps=config.rms_norm_eps + ) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + forward_batch: ForwardBatch, + residual: Optional[torch.Tensor], + ) -> Tuple[torch.Tensor, torch.Tensor]: + # Self Attention + if residual is None: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + else: + hidden_states, residual = self.input_layernorm(hidden_states, residual) + hidden_states = ( + self.self_attn( + positions=positions, + hidden_states=hidden_states, + forward_batch=forward_batch, + ) + * self.residual_multiplier + ) # multiplier for Maximal Update Parameterization + + # Fully Connected + hidden_states, residual = self.post_attention_layernorm(hidden_states, residual) + hidden_states = self.mlp(hidden_states) * self.residual_multiplier + return hidden_states, residual + + +class GraniteModel(nn.Module): + def __init__( + self, + config: GraniteConfig, + quant_config: Optional[QuantizationConfig] = None, + ) -> None: + super().__init__() + self.config = config + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + self.embed_tokens = VocabParallelEmbedding( + config.vocab_size, config.hidden_size + ) + self.layers = nn.ModuleList( + [ + GraniteDecoderLayer( + config, i, quant_config=quant_config, prefix=f"model.layers.{i}" + ) + for i in range(config.num_hidden_layers) + ] + ) + self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + forward_batch: ForwardBatch, + input_embeds: torch.Tensor = None, + ) -> torch.Tensor: + if input_embeds is None: + hidden_states = self.embed_tokens(input_ids) + else: + hidden_states = input_embeds + residual = None + hidden_states *= self.config.embedding_multiplier + for i in range(len(self.layers)): + layer = self.layers[i] + hidden_states, residual = layer( + positions, + hidden_states, + forward_batch, + residual, + ) + hidden_states, _ = self.norm(hidden_states, residual) + return hidden_states + + +class GraniteForCausalLM(nn.Module): + def __init__( + self, + config: GraniteConfig, + quant_config: Optional[QuantizationConfig] = None, + ) -> None: + super().__init__() + self.config = config + self.quant_config = quant_config + self.model = GraniteModel(config, quant_config=quant_config) + # If tie_word_embeddings == True, then input and output embeddings are + # the same tensor. Enforce during object creation so that weights will + # load correctly even if the LM head weights don't have a separate entry + # in the state dict. + self.lm_head = ParallelLMHead( + config.vocab_size, config.hidden_size, quant_config=quant_config + ) + if self.config.tie_word_embeddings: + self.lm_head.tie_weights(self.model.embed_tokens) + + # Granite logit scaling factors are applied via division, but + # LogitsProcessor expects a multiplicative factor. + if hasattr(config, "logits_scaling"): + logit_scale = 1.0 / config.logits_scaling + else: + logit_scale = None + self.logits_processor = LogitsProcessor(config, logit_scale=logit_scale) + self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True) + self.stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + (".qkv_proj", ".q_proj", "q"), + (".qkv_proj", ".k_proj", "k"), + (".qkv_proj", ".v_proj", "v"), + (".gate_up_proj", ".gate_proj", 0), + (".gate_up_proj", ".up_proj", 1), + ] + + @torch.no_grad() + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + forward_batch: ForwardBatch, + input_embeds: torch.Tensor = None, + get_embedding: bool = False, + ) -> LogitsProcessorOutput: + hidden_states = self.model(input_ids, positions, forward_batch, input_embeds) + if not get_embedding: + logits_processor_output: LogitsProcessorOutput = self.logits_processor( + input_ids, hidden_states, self.lm_head, forward_batch + ) + return logits_processor_output + else: + return self.pooler(hidden_states, forward_batch) + + def get_hidden_dim(self, module_name): + # return input_dim, output_dim + if module_name in ["q_proj", "o_proj", "qkv_proj"]: + return self.config.hidden_size, self.config.hidden_size + elif module_name in ["kv_proj"]: + return self.config.hidden_size, self.config.hidden_size // ( + self.config.num_attention_heads // self.config.num_key_value_heads + ) + elif module_name == "gate_up_proj": + return self.config.hidden_size, self.config.intermediate_size + elif module_name == "down_proj": + return self.config.intermediate_size, self.config.hidden_size + else: + raise NotImplementedError() + + def get_module_name(self, name): + params_mapping = { + "q_proj": "qkv_proj", + "k_proj": "qkv_proj", + "v_proj": "qkv_proj", + "gate_proj": "gate_up_proj", + "up_proj": "gate_up_proj", + } + return params_mapping.get(name, name) + + def get_module_name_from_weight_name(self, name): + for param_name, weight_name, shard_id, num_shard in self.stacked_params_mapping: + if weight_name in name: + return ( + name.replace(weight_name, param_name)[: -len(".weight")], + num_shard, + ) + return name[: -len(".weight")], 1 + + def get_num_params(self): + params_dict = dict(self.named_parameters()) + return len(params_dict) + + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + (".qkv_proj", ".q_proj", "q"), + (".qkv_proj", ".k_proj", "k"), + (".qkv_proj", ".v_proj", "v"), + (".gate_up_proj", ".gate_proj", 0), + (".gate_up_proj", ".up_proj", 1), + ] + + params_dict = dict(self.named_parameters()) + + for name, loaded_weight in weights: + if "rotary_emb.inv_freq" in name or "projector" in name: + continue + if "rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name: + # Models trained using ColossalAI may include these tensors in + # the checkpoint. Skip them. + continue + if name.startswith("model.vision_tower") and name not in params_dict: + continue + if "lm_head.weight" in name and self.config.tie_word_embeddings: + # Input and output embeddings are tied, so the output embeddings + # may not be present in the checkpoint. We assume that the input + # embeddings are always present in the checkpoint. + continue + + for param_name, weight_name, shard_id in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + # This block only runs if the preceding for loop doesn't find + # a match for `name` in `stacked_params_mapping`. + + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + # Skip loading kv_scale from ckpts towards new design. + if name.endswith(".kv_scale") and name not in params_dict: + continue + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", default_weight_loader) + weight_loader(param, loaded_weight) + + def get_weights_by_name( + self, name: str, truncate_size: int = 100, tp_size: int = 1 + ) -> Optional[torch.Tensor]: + """Get the weights of the parameter by its name. Similar to `get_parameter` in Hugging Face. + + Only used for unit test with an unoptimized performance. + For optimized performance, please use torch.save and torch.load. + """ + try: + if name == "lm_head.weight" and self.config.tie_word_embeddings: + logger.info( + "word embedding is tied for this model, return embed_tokens.weight as lm_head.weight." + ) + return ( + self.model.embed_tokens.weight.cpu() + .to(torch.float32) + .numpy() + .tolist()[:truncate_size] + ) + + mapped_name = name + mapped_shard_id = None + for param_name, weight_name, shard_id in self.stacked_params_mapping: + if weight_name in name: + mapped_name = name.replace(weight_name, param_name) + mapped_shard_id = shard_id + break + params_dict = dict(self.named_parameters()) + param = params_dict[mapped_name] + if mapped_shard_id is not None: + if mapped_shard_id in ["q", "k", "v"]: + num_heads = self.config.num_attention_heads // tp_size + num_kv_heads = self.config.num_key_value_heads // tp_size + head_dim = ( + self.config.hidden_size // self.config.num_attention_heads + ) + if mapped_shard_id == "q": + offset = 0 + size = num_heads * head_dim + elif mapped_shard_id == "k": + offset = num_heads * head_dim + size = num_kv_heads * head_dim + elif mapped_shard_id == "v": + offset = (num_heads + num_kv_heads) * head_dim + size = num_kv_heads * head_dim + weight = param.data.narrow(0, offset, size) + elif mapped_shard_id in [0, 1]: + intermediate_size = self.config.intermediate_size + slice_size = intermediate_size // tp_size + if mapped_shard_id == 0: # gate_proj + offset = 0 + size = slice_size + elif mapped_shard_id == 1: # up_proj + offset = slice_size + size = slice_size + + weight = param.data.narrow(0, offset, size) + else: + weight = param.data + else: + weight = param.data + if tp_size > 1 and ("o_proj" in name or "down_proj" in name): + gathered_weights = [torch.zeros_like(weight) for _ in range(tp_size)] + torch.distributed.all_gather(gathered_weights, weight) + weight = torch.cat(gathered_weights, dim=1) + return weight.cpu().to(torch.float32).numpy().tolist()[:truncate_size] + + except Exception: + logger.error( + f"Error getting weights by name {name} in GraniteForCausalLM: {get_exception_traceback()}" + ) + return None + + +EntryClass = [GraniteForCausalLM] diff --git a/test/srt/models/test_generation_models.py b/test/srt/models/test_generation_models.py index d9f1795341c..fd27d5c07b6 100644 --- a/test/srt/models/test_generation_models.py +++ b/test/srt/models/test_generation_models.py @@ -57,6 +57,7 @@ class ModelCase: ModelCase("openai-community/gpt2"), ModelCase("microsoft/Phi-3-small-8k-instruct"), ModelCase("allenai/OLMo-2-1124-7B-Instruct", skip_long_prompt=True), + ModelCase("ibm-granite/granite-3.0-2b-instruct", skip_long_prompt=True), ] TORCH_DTYPES = [torch.float16]