
[Doc] Basic guide for writing unit tests for new models (vllm-project#11951)

Signed-off-by: DarkLight1337 <[email protected]>
Signed-off-by: Harry Mellor <[email protected]>
DarkLight1337 authored and hmellor committed Jan 12, 2025
1 parent 7e41adf commit 53f69b1
Showing 6 changed files with 81 additions and 3 deletions.
2 changes: 1 addition & 1 deletion docs/source/contributing/model/basic.md
@@ -1,6 +1,6 @@
(new-model-basic)=

# Basic Implementation
# Implementing a Basic Model

This guide walks you through the steps to implement a basic vLLM model.

1 change: 1 addition & 0 deletions docs/source/contributing/model/index.md
@@ -10,6 +10,7 @@ This section provides more information on how to integrate a [PyTorch](https://p
basic
registration
tests
multimodal
```

3 changes: 1 addition & 2 deletions docs/source/contributing/model/registration.md
@@ -1,6 +1,6 @@
(new-model-registration)=

# Model Registration
# Registering a Model to vLLM

vLLM relies on a model registry to determine how to run each model.
A list of pre-registered architectures can be found [here](#supported-models).
@@ -15,7 +15,6 @@ This gives you the ability to modify the codebase and test your model.

After you have implemented your model (see [tutorial](#new-model-basic)), put it into the <gh-dir:vllm/model_executor/models> directory.
Then, add your model class to `_VLLM_MODELS` in <gh-file:vllm/model_executor/models/registry.py> so that it is automatically registered upon importing vLLM.
You should also include an example HuggingFace repository for this model in <gh-file:tests/models/registry.py> to run the unit tests.
Finally, update our [list of supported models](#supported-models) to promote your model!

```{important}
63 changes: 63 additions & 0 deletions docs/source/contributing/model/tests.md
@@ -0,0 +1,63 @@
(new-model-tests)=

# Writing Unit Tests

This page explains how to write unit tests to verify the implementation of your model.

## Required Tests

These tests are necessary to get your PR merged into the vLLM library.
Without them, the CI for your PR will fail.

### Model loading

Include an example HuggingFace repository for your model in <gh-file:tests/models/registry.py>.
This enables a unit test that loads dummy weights to ensure that the model can be initialized in vLLM.

```{important}
The list of models in each section should be maintained in alphabetical order.
```

```{tip}
If your model requires a development version of HF Transformers, you can set
`min_transformers_version` to skip the test in CI until the model is released.
```

## Optional Tests

These tests are not required to get your PR merged into the vLLM library.
However, passing them provides more confidence that your implementation is correct, and helps avoid future regressions.

### Model correctness

These tests compare the model outputs of vLLM against [HF Transformers](https://github.com/huggingface/transformers). You can add new tests under the subdirectories of <gh-dir:tests/models>.

#### Generative models

For [generative models](#generative-models), there are two levels of correctness tests, as defined in <gh-file:tests/models/utils.py>:

- Exact correctness (`check_outputs_equal`): The text output by vLLM should exactly match the text output by HF.
- Logprobs similarity (`check_logprobs_close`): Each token sampled by vLLM should appear in the top-k logprobs output by HF, and vice versa.
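The logprobs criterion can be sketched as follows. This toy function mirrors the idea behind `check_logprobs_close` (the real helper in <gh-file:tests/models/utils.py> has a different, richer signature): at every position, the token one implementation sampled must appear among the other's top-k candidates.

```python
def logprobs_close(vllm_out, hf_out) -> bool:
    """Each element is (sampled_token_id, {token_id: logprob} for the top-k)."""
    for (v_tok, v_top), (h_tok, h_top) in zip(vllm_out, hf_out):
        # The sampled token of each side must be in the other's top-k.
        if v_tok not in h_top or h_tok not in v_top:
            return False
    return True


# Toy data: position 0 agrees exactly; position 1 samples different tokens,
# but each token is within the other's top-k, so the check still passes.
vllm = [(5, {5: -0.1, 7: -2.3}), (7, {7: -0.5, 9: -1.2})]
hf = [(5, {5: -0.2, 8: -2.0}), (9, {9: -0.4, 7: -1.1})]
assert logprobs_close(vllm, hf)
```

This looser criterion tolerates small numerical differences (e.g. kernel or dtype effects) that can flip the argmax without indicating a real implementation bug.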

#### Pooling models

For [pooling models](#pooling-models), we simply check the cosine similarity between the vLLM and HF embeddings, as defined in <gh-file:tests/models/embedding/utils.py>.
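A minimal sketch of the idea: compute the cosine similarity between the vLLM and HF embedding vectors and require it to be close to 1. The helper in <gh-file:tests/models/embedding/utils.py> plays this role; the threshold and exact API here are illustrative.

```python
import math


def cosine_similarity(a: list[float], b: list[float]) -> float:
    dot = sum(x * y for x, y in zip(a, b))
    norm = math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b))
    return dot / norm


vllm_emb = [0.10, 0.20, 0.30]
hf_emb = [0.11, 0.19, 0.31]  # nearly identical direction

# Embeddings pointing in almost the same direction score close to 1.
assert cosine_similarity(vllm_emb, hf_emb) > 0.99
```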

(mm-processing-tests)=

### Multi-modal processing

#### Common tests

Adding your model to <gh-file:tests/models/multimodal/processing/test_common.py> verifies that the following input combinations result in the same outputs:

- Text + multi-modal data
- Tokens + multi-modal data
- Text + cached multi-modal data
- Tokens + cached multi-modal data

#### Model-specific tests

You can add a new file under <gh-dir:tests/models/multimodal/processing> to run tests that only apply to your model.

For example, if the HF processor for your model accepts user-specified keyword arguments, you can verify that the keyword arguments are being applied correctly, such as in <gh-file:tests/models/multimodal/processing/test_phi3v.py>.
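The shape of such a test can be sketched like this. The processor below is a made-up stand-in, and `num_crops` is used only as an example kwarg; consult test_phi3v.py for how the real test drives the actual HF processor.

```python
def fake_processor(image_size: int, num_crops: int = 4) -> int:
    """Stand-in processor: the number of image patches grows with num_crops."""
    return (num_crops + 1) * (image_size // 336) ** 2


def test_num_crops_is_applied():
    baseline = fake_processor(672)                # default num_crops=4
    more = fake_processor(672, num_crops=16)      # user-specified kwarg
    # The kwarg must actually change the processing result.
    assert more > baseline


test_num_crops_is_applied()
```

The point of such tests is to catch kwargs that are silently accepted but never forwarded to the underlying processor.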
5 changes: 5 additions & 0 deletions tests/models/registry.py
@@ -22,6 +22,11 @@ class _HfExamplesInfo:
for speculative decoding.
"""

min_transformers_version: Optional[str] = None
"""
The minimum version of HF Transformers that is required to run this model.
"""

is_available_online: bool = True
"""
Set this to ``False`` if the name of this architecture no longer exists on
10 changes: 10 additions & 0 deletions tests/models/test_initialization.py
@@ -1,7 +1,9 @@
from unittest.mock import patch

import pytest
from packaging.version import Version
from transformers import PretrainedConfig
from transformers import __version__ as TRANSFORMERS_VERSION

from vllm import LLM

@@ -13,6 +15,14 @@ def test_can_initialize(model_arch):
model_info = HF_EXAMPLE_MODELS.get_hf_info(model_arch)
if not model_info.is_available_online:
pytest.skip("Model is not available online")
if model_info.min_transformers_version is not None:
current_version = TRANSFORMERS_VERSION
required_version = model_info.min_transformers_version
if Version(current_version) < Version(required_version):
pytest.skip(
f"You have `transformers=={current_version}` installed, but "
f"`transformers>={required_version}` is required to run this "
"model")

# Avoid OOM
def hf_overrides(hf_config: PretrainedConfig) -> PretrainedConfig:
