Adds method to read the pooling types from model's files (vllm-project#9506)

Signed-off-by: Flavia Beo <[email protected]>
Signed-off-by: Max de Bayser <[email protected]>
Co-authored-by: Max de Bayser <[email protected]>
Signed-off-by: Loc Huynh <[email protected]>
2 people authored and JC1DA committed Nov 11, 2024
1 parent 303dde8 commit eb95aca
Showing 10 changed files with 342 additions and 25 deletions.
4 changes: 2 additions & 2 deletions examples/fp8/quantizer/quantize.py
@@ -230,7 +230,7 @@ def calibrate_loop():

def main(args):
if not torch.cuda.is_available():
raise EnvironmentError("GPU is required for inference.")
raise OSError("GPU is required for inference.")

random.seed(RAND_SEED)
np.random.seed(RAND_SEED)
@@ -314,7 +314,7 @@ def main(args):

# Workaround for wo quantization
if args.qformat in ["int8_wo", "int4_wo", "full_prec"]:
with open(f"{export_path}/config.json", 'r') as f:
with open(f"{export_path}/config.json") as f:
tensorrt_llm_config = json.load(f)
if args.qformat == "int8_wo":
tensorrt_llm_config["quantization"]["quant_algo"] = 'W8A16'
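Note: both hunks above are behavior-preserving cleanups: `EnvironmentError` has been an alias of `OSError` since Python 3.3, and `'r'` is already the default mode for `open()`. A minimal standalone sketch verifying the equivalence (not part of the quantizer script):

```python
import json
import tempfile

# EnvironmentError is an alias of OSError on Python 3, so the raised type is unchanged.
assert EnvironmentError is OSError

# open() defaults to mode 'r', so dropping the explicit argument is equivalent.
with tempfile.NamedTemporaryFile("w", suffix=".json", delete=False) as f:
    json.dump({"quantization": {"quant_algo": "W8A16"}}, f)
with open(f.name) as cfg:  # same as open(f.name, 'r')
    assert json.load(cfg)["quantization"]["quant_algo"] == "W8A16"
```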
7 changes: 7 additions & 0 deletions tests/engine/test_arg_utils.py
@@ -30,6 +30,13 @@ def test_limit_mm_per_prompt_parser(arg, expected):
assert args.limit_mm_per_prompt == expected


def test_valid_pooling_config():
parser = EngineArgs.add_cli_args(FlexibleArgumentParser())
args = parser.parse_args(["--pooling-type=MEAN"])
engine_args = EngineArgs.from_cli_args(args=args)
assert engine_args.pooling_type == 'MEAN'


@pytest.mark.parametrize(
("arg"),
[
50 changes: 50 additions & 0 deletions tests/model_executor/test_model_load_with_params.py
@@ -0,0 +1,50 @@
import os

import pytest

from vllm.model_executor.layers.pooler import PoolingType
from vllm.model_executor.models.bert import BertEmbeddingModel
from vllm.platforms import current_platform

MAX_MODEL_LEN = 128
MODEL_NAME = os.environ.get("MODEL_NAME", "BAAI/bge-base-en-v1.5")
REVISION = os.environ.get("REVISION", "main")


@pytest.mark.skipif(current_platform.is_rocm(),
reason="Xformers backend is not supported on ROCm.")
def test_model_loading_with_params(vllm_runner):
"""
Test parameter weight loading with tp>1.
"""
with vllm_runner(model_name=MODEL_NAME,
revision=REVISION,
dtype="float16",
max_model_len=MAX_MODEL_LEN) as model:
output = model.encode("Write a short story about a robot that"
" dreams for the first time.\n")

model_config = model.model.llm_engine.model_config

model_tokenizer = model.model.llm_engine.tokenizer

# asserts on the bert model config file
assert model_config.encoder_config["max_seq_length"] == 512
assert model_config.encoder_config["do_lower_case"]

# asserts on the pooling config files
assert model_config.pooler_config.pooling_type == PoolingType.CLS.name
assert model_config.pooler_config.pooling_norm

# asserts on the tokenizer loaded
assert model_tokenizer.tokenizer_id == "BAAI/bge-base-en-v1.5"
assert model_tokenizer.tokenizer_config["do_lower_case"]
assert model_tokenizer.tokenizer.model_max_length == 512

model = model.model.llm_engine.model_executor\
.driver_worker.model_runner.model
assert isinstance(model, BertEmbeddingModel)
assert model._pooler.pooling_type == PoolingType.CLS
assert model._pooler.normalize
# assert output
assert output
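Note: the `max_seq_length == 512` and `do_lower_case` assertions above come from the Sentence Transformers sidecar config shipped in the model repo, not from the Hugging Face `config.json`. A hedged sketch of where those values live (this assumes `BAAI/bge-base-en-v1.5` follows the conventional `sentence_bert_config.json` layout; requires `huggingface_hub` and network access):

```python
import json

from huggingface_hub import hf_hub_download

# Download the Sentence Transformers sidecar file that the new encoder_config
# is derived from (file name per the Sentence Transformers convention).
path = hf_hub_download("BAAI/bge-base-en-v1.5", "sentence_bert_config.json")
with open(path) as f:
    sbert_cfg = json.load(f)

print(sbert_cfg)  # expected: {'max_seq_length': 512, 'do_lower_case': True}
```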
72 changes: 72 additions & 0 deletions tests/test_config.py
@@ -1,6 +1,8 @@
import pytest

from vllm.config import ModelConfig
from vllm.model_executor.layers.pooler import PoolingType
from vllm.platforms import current_platform


@pytest.mark.parametrize(("model_id", "expected_task"), [
@@ -102,6 +104,76 @@ def test_get_sliding_window():
assert mistral_model_config.get_sliding_window() == TEST_SLIDING_WINDOW


@pytest.mark.skipif(current_platform.is_rocm(),
reason="Xformers backend is not supported on ROCm.")
def test_get_pooling_config():
model_id = "sentence-transformers/all-MiniLM-L12-v2"
minilm_model_config = ModelConfig(
model_id,
task="auto",
tokenizer=model_id,
tokenizer_mode="auto",
trust_remote_code=False,
seed=0,
dtype="float16",
revision=None,
)

minilm_pooling_config = minilm_model_config._init_pooler_config(
pooling_type=None,
pooling_norm=None,
pooling_returned_token_ids=None,
pooling_softmax=None,
pooling_step_tag_id=None)

assert minilm_pooling_config.pooling_norm
assert minilm_pooling_config.pooling_type == PoolingType.MEAN.name


@pytest.mark.skipif(current_platform.is_rocm(),
reason="Xformers backend is not supported on ROCm.")
def test_get_pooling_config_from_args():
model_id = "sentence-transformers/all-MiniLM-L12-v2"
minilm_model_config = ModelConfig(model_id,
task="auto",
tokenizer=model_id,
tokenizer_mode="auto",
trust_remote_code=False,
seed=0,
dtype="float16",
revision=None)

minilm_pooling_config = minilm_model_config._init_pooler_config(
pooling_type='CLS',
pooling_norm=True,
pooling_returned_token_ids=None,
pooling_softmax=None,
pooling_step_tag_id=None)

assert minilm_pooling_config.pooling_norm
assert minilm_pooling_config.pooling_type == PoolingType.CLS.name


@pytest.mark.skipif(current_platform.is_rocm(),
reason="Xformers backend is not supported on ROCm.")
def test_get_bert_tokenization_sentence_transformer_config():
bge_model_config = ModelConfig(
model="BAAI/bge-base-en-v1.5",
task="auto",
tokenizer="BAAI/bge-base-en-v1.5",
tokenizer_mode="auto",
trust_remote_code=False,
seed=0,
dtype="float16",
revision=None,
)

bert_bge_model_config = bge_model_config._get_encoder_config()

assert bert_bge_model_config["max_seq_length"] == 512
assert bert_bge_model_config["do_lower_case"]


def test_rope_customization():
TEST_ROPE_SCALING = {"rope_type": "dynamic", "factor": 2.0}
TEST_ROPE_THETA = 16_000_000.0
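Note: `get_pooling_config` (exercised by `test_get_pooling_config`) relies on the pooling metadata that Sentence Transformers models publish next to their weights. A sketch of the conventional layout for `sentence-transformers/all-MiniLM-L12-v2`; the file names and keys follow the Sentence Transformers format and are illustrative, not copied from this diff:

```python
# modules.json lists the pipeline stages; a Pooling module points at a folder whose
# config.json records the pooling mode, and a Normalize module implies pooling_norm=True.
modules_json = [
    {"idx": 0, "name": "0", "path": "", "type": "sentence_transformers.models.Transformer"},
    {"idx": 1, "name": "1", "path": "1_Pooling", "type": "sentence_transformers.models.Pooling"},
    {"idx": 2, "name": "2", "path": "2_Normalize", "type": "sentence_transformers.models.Normalize"},
]

# 1_Pooling/config.json enables mean pooling for this model, which is why the test
# expects PoolingType.MEAN when no --pooling-type override is given.
pooling_config = {
    "word_embedding_dimension": 384,
    "pooling_mode_cls_token": False,
    "pooling_mode_mean_tokens": True,
    "pooling_mode_max_tokens": False,
    "pooling_mode_mean_sqrt_len_tokens": False,
}
```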
14 changes: 8 additions & 6 deletions tests/utils.py
@@ -15,6 +15,7 @@
import pytest
import requests
import torch
import torch.nn.functional as F
from openai.types.completion import Completion
from typing_extensions import ParamSpec

@@ -515,13 +516,14 @@ def compare_all_settings(model: str,
ref_result = copy.deepcopy(ref_result)
compare_result = copy.deepcopy(compare_result)
if "embedding" in ref_result and method == "encode":
ref_embedding = torch.tensor(ref_result["embedding"])
compare_embedding = torch.tensor(
compare_result["embedding"])
mse = ((ref_embedding - compare_embedding)**2).mean()
assert mse < 1e-6, (
sim = F.cosine_similarity(
torch.tensor(ref_result["embedding"]),
torch.tensor(compare_result["embedding"]),
dim=0,
)
assert sim >= 0.999, (
f"Embedding for {model=} are not the same.\n"
f"mse={mse}\n")
f"cosine_similarity={sim}\n")
del ref_result["embedding"]
del compare_result["embedding"]
assert ref_result == compare_result, (
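Note: the switch from MSE to cosine similarity makes the embedding comparison scale-invariant: two embeddings that point in the same direction but differ slightly in magnitude can exceed a fixed MSE bound while still being equivalent for retrieval. A small self-contained illustration (a sketch, not part of the test suite):

```python
import torch
import torch.nn.functional as F

ref = torch.randn(768)
cmp = ref * 1.01  # same direction, 1% larger magnitude (e.g. minor numeric drift)

mse = ((ref - cmp) ** 2).mean()
sim = F.cosine_similarity(ref, cmp, dim=0)

print(f"mse={mse.item():.2e}")  # roughly 1e-4 for unit-scale vectors, above the old 1e-6 bound
print(f"sim={sim.item():.6f}")  # 1.0 up to float error, comfortably above the new 0.999 bound
```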
28 changes: 23 additions & 5 deletions vllm/config.py
@@ -13,10 +13,10 @@
from vllm.model_executor.models import ModelRegistry
from vllm.platforms import current_platform
from vllm.tracing import is_otel_available, otel_import_error_traceback
from vllm.transformers_utils.config import (ConfigFormat, get_config,
get_hf_image_processor_config,
get_hf_text_config,
is_encoder_decoder, uses_mrope)
from vllm.transformers_utils.config import (
ConfigFormat, get_config, get_hf_image_processor_config,
get_hf_text_config, get_pooling_config,
get_sentence_transformer_tokenizer_config, is_encoder_decoder, uses_mrope)
from vllm.utils import (GiB_bytes, cuda_device_count_stateless, get_cpu_memory,
print_warning_once)

@@ -197,6 +197,7 @@ def __init__(
code_revision, rope_scaling, rope_theta,
config_format)
self.hf_text_config = get_hf_text_config(self.hf_config)
self.encoder_config = self._get_encoder_config()
self.hf_image_processor_config = get_hf_image_processor_config(
self.model, revision)
self.dtype = _get_and_verify_dtype(self.hf_text_config, dtype)
@@ -229,7 +230,8 @@ def __init__(
max_model_len=max_model_len,
disable_sliding_window=self.disable_sliding_window,
sliding_window_len=self.get_hf_config_sliding_window(),
spec_target_max_model_len=spec_target_max_model_len)
spec_target_max_model_len=spec_target_max_model_len,
encoder_config=self.encoder_config)
self.served_model_name = get_served_model_name(model,
served_model_name)
self.multimodal_config = self._init_multimodal_config(
@@ -273,6 +275,10 @@ def _init_multimodal_config(

return None

def _get_encoder_config(self):
return get_sentence_transformer_tokenizer_config(
self.model, self.revision)

def _init_pooler_config(
self,
pooling_type: Optional[str] = None,
@@ -282,6 +288,14 @@ def _init_pooler_config(
pooling_returned_token_ids: Optional[List[int]] = None
) -> Optional["PoolerConfig"]:
if self.task == "embedding":
pooling_config = get_pooling_config(self.model, self.revision)
if pooling_config is not None:
# override only if the user did not
# specify pooling_type and/or pooling_norm
if pooling_type is None:
pooling_type = pooling_config["pooling_type"]
if pooling_norm is None:
pooling_norm = pooling_config["normalize"]
return PoolerConfig(
pooling_type=pooling_type,
pooling_norm=pooling_norm,
@@ -1795,6 +1809,7 @@ def _get_and_verify_max_len(
disable_sliding_window: bool,
sliding_window_len: Optional[Union[int, List[Optional[int]]]],
spec_target_max_model_len: Optional[int] = None,
encoder_config: Optional[Any] = None,
) -> int:
"""Get and verify the model's maximum length."""
derived_max_model_len = float("inf")
@@ -1877,6 +1892,9 @@
"original_max_position_embeddings"]
derived_max_model_len *= scaling_factor

if encoder_config and "max_seq_length" in encoder_config:
derived_max_model_len = encoder_config["max_seq_length"]

# If the user specified a max length, make sure it is smaller than the
# derived length from the HF model config.
if max_model_len is None:
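Note: the net effect of the `_init_pooler_config` change is a simple precedence rule: explicitly passed values (e.g. via `--pooling-type`) win, and anything left as `None` is filled from the pooling config discovered in the model files. A standalone sketch of that merge logic under those assumptions (not the verbatim vLLM code):

```python
from typing import Optional, Tuple


def resolve_pooling(pooling_type: Optional[str], pooling_norm: Optional[bool],
                    detected: Optional[dict]) -> Tuple[Optional[str], Optional[bool]]:
    """Explicit engine arguments take priority; model-file values only fill in Nones."""
    if detected is not None:
        if pooling_type is None:
            pooling_type = detected["pooling_type"]
        if pooling_norm is None:
            pooling_norm = detected["normalize"]
    return pooling_type, pooling_norm


# all-MiniLM-L12-v2 publishes MEAN pooling plus normalization in its model files:
detected = {"pooling_type": "MEAN", "normalize": True}
assert resolve_pooling(None, None, detected) == ("MEAN", True)  # fully auto-detected
assert resolve_pooling("CLS", None, detected) == ("CLS", True)  # explicit type wins, norm still detected
```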
3 changes: 2 additions & 1 deletion vllm/engine/arg_utils.py
@@ -16,6 +16,7 @@
VllmConfig)
from vllm.executor.executor_base import ExecutorBase
from vllm.logger import init_logger
from vllm.model_executor.layers.pooler import PoolingType
from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
from vllm.platforms import current_platform
from vllm.transformers_utils.config import (
@@ -863,7 +864,7 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:

parser.add_argument(
'--pooling-type',
choices=['LAST', 'ALL', 'CLS', 'STEP'],
choices=[pt.name for pt in PoolingType],
default=None,
help='Used to configure the pooling method in the embedding model.'
)
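Note: deriving the CLI choices from the enum keeps the parser in sync with `PoolingType`, so the new `MEAN` member is accepted without touching `arg_utils.py` again. Roughly (a sketch using a stand-in enum rather than the vLLM import):

```python
import argparse
from enum import IntEnum


class PoolingType(IntEnum):  # stand-in mirroring vllm/model_executor/layers/pooler.py
    LAST = 0
    ALL = 1
    CLS = 2
    STEP = 3
    MEAN = 4


parser = argparse.ArgumentParser()
parser.add_argument("--pooling-type", choices=[pt.name for pt in PoolingType], default=None)
print(parser.parse_args(["--pooling-type", "MEAN"]).pooling_type)  # -> MEAN
```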
14 changes: 13 additions & 1 deletion vllm/model_executor/layers/pooler.py
@@ -16,6 +16,7 @@ class PoolingType(IntEnum):
ALL = 1
CLS = 2
STEP = 3
MEAN = 4


class Pooler(nn.Module):
@@ -27,7 +28,7 @@ class Pooler(nn.Module):
3. Returns structured results as `PoolerOutput`.
Attributes:
pooling_type: The type of pooling to use (LAST, ALL, CLS).
pooling_type: The type of pooling to use.
normalize: Whether to normalize the pooled data.
"""

@@ -97,6 +98,17 @@ def forward(
for prompt_len in prompt_lens:
pooled_data.append(hidden_states[offset:offset + prompt_len])
offset += prompt_len
elif self.pooling_type == PoolingType.MEAN:
# Calculate mean pooling
cumsum = torch.cumsum(hidden_states, dim=0)
start_indices = torch.cat([
torch.tensor([0], device=hidden_states.device),
torch.cumsum(prompt_lens[:-1], dim=0)
])
end_indices = torch.cumsum(prompt_lens, dim=0)
pooled_data = (
cumsum[end_indices - 1] - cumsum[start_indices] +
hidden_states[start_indices]) / prompt_lens.unsqueeze(1)
elif self.pooling_type == PoolingType.STEP:
if self.returned_token_ids is not None and len(
self.returned_token_ids) > 0:
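Note: the MEAN branch computes every prompt's mean in one vectorized pass. For a prompt spanning rows `[start, end)`, `cumsum[end - 1] - cumsum[start] + hidden_states[start]` is the inclusive sum of that slice, which is then divided by the prompt length. A small check of that identity against a naive per-prompt loop (a sketch, assuming `prompt_lens` is a 1-D integer tensor as in the surrounding code):

```python
import torch

hidden_states = torch.randn(10, 4)
prompt_lens = torch.tensor([3, 5, 2])  # rows 0-2, 3-7, 8-9

# Vectorized mean pooling, mirroring the new MEAN branch.
cumsum = torch.cumsum(hidden_states, dim=0)
start = torch.cat([torch.tensor([0]), torch.cumsum(prompt_lens[:-1], dim=0)])
end = torch.cumsum(prompt_lens, dim=0)
pooled = (cumsum[end - 1] - cumsum[start] + hidden_states[start]) / prompt_lens.unsqueeze(1)

# Naive reference: average each prompt's rows independently.
offset, expected = 0, []
for n in prompt_lens.tolist():
    expected.append(hidden_states[offset:offset + n].mean(dim=0))
    offset += n

assert torch.allclose(pooled, torch.stack(expected), atol=1e-5)
```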
(The remaining changed files are not rendered in this excerpt.)
