forked from vllm-project/vllm
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Adds method to read the pooling types from model's files (vllm-projec…
…t#9506) Signed-off-by: Flavia Beo <[email protected]> Signed-off-by: Max de Bayser <[email protected]> Co-authored-by: Max de Bayser <[email protected]> Signed-off-by: Loc Huynh <[email protected]>
- Loading branch information
Showing
10 changed files
with
342 additions
and
25 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,50 @@ | ||
import os | ||
|
||
import pytest | ||
|
||
from vllm.model_executor.layers.pooler import PoolingType | ||
from vllm.model_executor.models.bert import BertEmbeddingModel | ||
from vllm.platforms import current_platform | ||
|
||
MAX_MODEL_LEN = 128 | ||
MODEL_NAME = os.environ.get("MODEL_NAME", "BAAI/bge-base-en-v1.5") | ||
REVISION = os.environ.get("REVISION", "main") | ||
|
||
|
||
@pytest.mark.skipif(current_platform.is_rocm(), | ||
reason="Xformers backend is not supported on ROCm.") | ||
def test_model_loading_with_params(vllm_runner): | ||
""" | ||
Test parameter weight loading with tp>1. | ||
""" | ||
with vllm_runner(model_name=MODEL_NAME, | ||
revision=REVISION, | ||
dtype="float16", | ||
max_model_len=MAX_MODEL_LEN) as model: | ||
output = model.encode("Write a short story about a robot that" | ||
" dreams for the first time.\n") | ||
|
||
model_config = model.model.llm_engine.model_config | ||
|
||
model_tokenizer = model.model.llm_engine.tokenizer | ||
|
||
# asserts on the bert model config file | ||
assert model_config.encoder_config["max_seq_length"] == 512 | ||
assert model_config.encoder_config["do_lower_case"] | ||
|
||
# asserts on the pooling config files | ||
assert model_config.pooler_config.pooling_type == PoolingType.CLS.name | ||
assert model_config.pooler_config.pooling_norm | ||
|
||
# asserts on the tokenizer loaded | ||
assert model_tokenizer.tokenizer_id == "BAAI/bge-base-en-v1.5" | ||
assert model_tokenizer.tokenizer_config["do_lower_case"] | ||
assert model_tokenizer.tokenizer.model_max_length == 512 | ||
|
||
model = model.model.llm_engine.model_executor\ | ||
.driver_worker.model_runner.model | ||
assert isinstance(model, BertEmbeddingModel) | ||
assert model._pooler.pooling_type == PoolingType.CLS | ||
assert model._pooler.normalize | ||
# assert output | ||
assert output |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.