Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Migrate llama_classification to use the /classify interface #2417

Merged
merged 1 commit into from
Dec 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 11 additions & 22 deletions python/sglang/srt/models/llama_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from torch import nn
from transformers import LlamaConfig

from sglang.srt.layers.logits_processor import LogitsProcessorOutput
from sglang.srt.layers.pooler import EmbeddingPoolerOutput, Pooler, PoolingType
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_loader.weight_utils import default_weight_loader
Expand All @@ -40,7 +40,7 @@ def __init__(
self.classification_head = nn.Linear(
config.hidden_size, config.classification_out_size, bias=False
)
self.eos_token_id = config.eos_token_id
self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=False)

@torch.no_grad()
def forward(
Expand All @@ -49,28 +49,17 @@ def forward(
positions: torch.Tensor,
forward_batch: ForwardBatch,
input_embeds: torch.Tensor = None,
) -> torch.Tensor:
hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
is_eos_token = input_ids == self.eos_token_id
hidden_states = hidden_states[is_eos_token]
scores = self.classification_head(hidden_states)

if scores.shape[0] != forward_batch.batch_size:
print("Warning: the EOS tokens are missing in some sentences.")
scores = torch.ones(
(forward_batch.batch_size, self.config.classification_out_size)
).to(input_ids.device)
get_embedding: bool = True,
) -> EmbeddingPoolerOutput:
assert (
get_embedding
), "LlamaForClassification is only used for embedding. Please add --is-embedding when you launch the server."

logits_output = LogitsProcessorOutput(
next_token_logits=scores,
next_token_logprobs=scores,
normalized_prompt_logprobs=scores,
input_token_logprobs=torch.ones_like(input_ids),
input_top_logprobs=None,
output_top_logprobs=None,
)
hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
last_token_hidden = self.pooler(hidden_states, forward_batch).embeddings
scores = self.classification_head(last_token_hidden)

return logits_output
return EmbeddingPoolerOutput(scores)

def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
params_dict = dict(self.named_parameters())
Expand Down
22 changes: 19 additions & 3 deletions scripts/deprecated/test_httpserver_classify.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""
Usage:
python3 -m sglang.launch_server --disable-cuda-graph --model-path /model/llama-classification
python3 -m sglang.launch_server --model-path /model/llama-classification --is-embedding --disable-radix-cache

python3 test_httpserver_classify.py
"""
Expand All @@ -11,7 +11,7 @@
import requests


def get_logits(url, prompt):
def get_logits_deprecated(url: str, prompt: str):
response = requests.post(
url + "/generate",
json={
Expand All @@ -25,7 +25,7 @@ def get_logits(url, prompt):
return response.json()["meta_info"]["normalized_prompt_logprob"]


def get_logits_batch(url, prompts):
def get_logits_batch_deprecated(url: str, prompts: list[str]):
response = requests.post(
url + "/generate",
json={
Expand All @@ -46,6 +46,22 @@ def get_logits_batch(url, prompts):
return logits


def get_logits(url: str, prompt: str):
response = requests.post(
url + "/classify",
json={"text": prompt},
)
return response.json()["embedding"]


def get_logits_batch(url: str, prompts: list[str]):
response = requests.post(
url + "/classify",
json={"text": prompts},
)
return np.array([x["embedding"] for x in response.json()])


if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--host", type=str, default="http://127.0.0.1")
Expand Down
Loading