Add support for IBM Granite 3.x models (#2437)
frreiss authored Dec 11, 2024
1 parent f854829 commit 993956c
Showing 5 changed files with 562 additions and 1 deletion.
1 change: 1 addition & 0 deletions docs/references/supported_models.md
@@ -29,6 +29,7 @@
- SmolLM
- GLM-4
- Phi-3-Small
- IBM Granite 3

## Embedding Models

32 changes: 32 additions & 0 deletions python/sglang/lang/chat_template.py
@@ -320,6 +320,28 @@ def get_chat_template_by_model_path(model_path):
    )
)

register_chat_template(
    ChatTemplate(
        name="granite-3-instruct",
        default_system_prompt=None,
        role_prefix_and_suffix={
            "system": (
                "<|start_of_role|>system<|end_of_role|>",
                "<|end_of_text|>",
            ),
            "user": (
                "<|start_of_role|>user<|end_of_role|>",
                "<|end_of_text|>",
            ),
            "assistant": (
                "<|start_of_role|>assistant<|end_of_role|>",
                "<|end_of_text|>",
            ),
        },
        stop_str=("<|end_of_text|>",),
    )
)
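
For reference, here is a minimal standalone sketch of how the prefixes and suffixes registered above compose into a Granite 3 prompt. It re-implements the concatenation by hand rather than calling sglang's own rendering code, and the trailing open assistant turn is an assumption about how a generation prompt is typically finished, not something this diff specifies:

# Hand-rolled illustration of the granite-3-instruct template above; not sglang's renderer.
GRANITE_ROLES = {
    "system": ("<|start_of_role|>system<|end_of_role|>", "<|end_of_text|>"),
    "user": ("<|start_of_role|>user<|end_of_role|>", "<|end_of_text|>"),
    "assistant": ("<|start_of_role|>assistant<|end_of_role|>", "<|end_of_text|>"),
}

def render_granite_prompt(messages):
    prompt = ""
    for msg in messages:
        prefix, suffix = GRANITE_ROLES[msg["role"]]
        prompt += f"{prefix}{msg['content']}{suffix}"
    # Assumed: finish with an open assistant turn so the model produces the reply.
    return prompt + "<|start_of_role|>assistant<|end_of_role|>"

print(render_granite_prompt([{"role": "user", "content": "What is 2 + 2?"}]))
# <|start_of_role|>user<|end_of_role|>What is 2 + 2?<|end_of_text|><|start_of_role|>assistant<|end_of_role|>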


@register_chat_template_matching_function
def match_dbrx(model_path: str):
@@ -402,6 +424,16 @@ def match_c4ai_command_r(model_path: str):
        return get_chat_template("c4ai-command-r")


@register_chat_template_matching_function
def match_granite_instruct(model_path: str):
    model_path = model_path.lower()
    # When future versions of Granite are released, this code may
    # need to be updated. For now, assume that the Granite 3.0
    # template works across the board.
    if "granite" in model_path and "instruct" in model_path:
        return get_chat_template("granite-3-instruct")
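
As a usage illustration, any model path containing both "granite" and "instruct" now resolves to the new template through get_chat_template_by_model_path. The Hugging Face path below is only an example, not something taken from this diff:

from sglang.lang.chat_template import get_chat_template_by_model_path

# Example path; the matcher only requires "granite" and "instruct" to appear in it.
template = get_chat_template_by_model_path("ibm-granite/granite-3.0-8b-instruct")
assert template.name == "granite-3-instruct"
print(template.stop_str)  # ('<|end_of_text|>',)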


if __name__ == "__main__":
    messages = [
        {"role": "system", "content": None},  # None means default
12 changes: 11 additions & 1 deletion python/sglang/srt/layers/logits_processor.py
@@ -91,9 +91,12 @@ def from_forward_batch(cls, forward_batch: ForwardBatch):


class LogitsProcessor(nn.Module):
    def __init__(self, config, skip_all_gather: bool = False):
    def __init__(
        self, config, skip_all_gather: bool = False, logit_scale: Optional[float] = None
    ):
        super().__init__()
        self.config = config
        self.logit_scale = logit_scale
        self.do_tensor_parallel_all_gather = (
            not skip_all_gather and get_tensor_model_parallel_world_size() > 1
        )
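
The new constructor argument lets a model implementation hand a scaling factor to the processor. Below is a hypothetical sketch of that wiring; the config attribute name logits_scaling and the reciprocal form are assumptions about the Granite model file added elsewhere in this commit, not something shown in this hunk:

# Hypothetical wiring inside a model's __init__; "logits_scaling" is an assumed
# HF config attribute and is not taken from this diff.
logits_scaling = getattr(config, "logits_scaling", None)
self.logits_processor = LogitsProcessor(
    config,
    logit_scale=1.0 / logits_scaling if logits_scaling else None,
)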
@@ -240,6 +243,9 @@ def forward(
        all_logits = self._get_logits(states, lm_head)
        if self.do_tensor_parallel_all_gather:
            all_logits = tensor_model_parallel_all_gather(all_logits)

        # The LM head's weights may be zero-padded for parallelism. Remove any
        # extra logits that this padding may have produced.
        all_logits = all_logits[:, : self.config.vocab_size].float()

        if hasattr(self.config, "final_logit_softcapping"):
@@ -302,6 +308,10 @@ def _get_logits(
        else:
            # GGUF models
            logits = lm_head.linear_method.apply(lm_head, hidden_states, embedding_bias)

        # Optional scaling factor, backported from vLLM 0.4
        if self.logit_scale is not None:
            logits.mul_(self.logit_scale)  # In-place multiply
        return logits
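
To make the effect of the new hook concrete, here is a small self-contained PyTorch sketch with invented values: an in-place multiply by a scale below 1.0 flattens the softmax distribution over the vocabulary, while a scale above 1.0 sharpens it.

import torch

logits = torch.tensor([[4.0, 2.0, 0.0]])
print(torch.softmax(logits, dim=-1))  # ~[0.867, 0.117, 0.016]

logit_scale = 0.5                     # illustrative value only
logits.mul_(logit_scale)              # same in-place op as in _get_logits above
print(torch.softmax(logits, dim=-1))  # ~[0.665, 0.245, 0.090]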


