[Bugfix] Fix RobertaModel loading #11940

Merged
merged 2 commits into from Jan 11, 2025
27 changes: 26 additions & 1 deletion tests/model_executor/test_model_load_with_params.py
@@ -2,7 +2,7 @@

import pytest

-from vllm.model_executor.layers.pooler import PoolingType
+from vllm.model_executor.layers.pooler import CLSPool, PoolingType
from vllm.model_executor.models.bert import BertEmbeddingModel
from vllm.model_executor.models.roberta import RobertaEmbeddingModel
from vllm.platforms import current_platform
@@ -92,3 +92,28 @@ def test_roberta_model_loading_with_params(vllm_runner):

    # assert output
    assert output


@pytest.mark.skipif(current_platform.is_rocm(),
                    reason="Xformers backend is not supported on ROCm.")
def test_facebook_roberta_model_loading_with_params(vllm_runner):
    """
    Test loading roberta-base model with no lm_head.
    """
    model_name = "FacebookAI/roberta-base"
    with vllm_runner(model_name=model_name,
                     dtype="float16",
                     max_model_len=MAX_MODEL_LEN) as model:
        output = model.encode("Write a short story about a robot that"
                              " dreams for the first time.\n")

        model_tokenizer = model.model.llm_engine.tokenizer
        assert model_tokenizer.tokenizer_id == model_name

        model = model.model.llm_engine.model_executor\
            .driver_worker.model_runner.model
        assert not hasattr(model, "lm_head")
        assert isinstance(model, RobertaEmbeddingModel)
        assert isinstance(model._pooler, CLSPool)

        assert output
1 change: 1 addition & 0 deletions tests/models/embedding/language/test_embedding.py
@@ -25,6 +25,7 @@
pytest.param("ssmits/Qwen2-7B-Instruct-embed-base"),
pytest.param("Alibaba-NLP/gte-Qwen2-1.5B-instruct"),
pytest.param("Alibaba-NLP/gte-Qwen2-7B-instruct"),
pytest.param("sentence-transformers/stsb-roberta-base-v2"),
],
)
@pytest.mark.parametrize("dtype", ["half"])
51 changes: 40 additions & 11 deletions vllm/model_executor/models/roberta.py
@@ -1,3 +1,4 @@
import itertools
from typing import Iterable, List, Optional, Tuple

import torch
@@ -20,6 +21,30 @@
from .interfaces import SupportsCrossEncoding


def roberta_task_weights_filter(
    all_weights: Iterable[Tuple[str, torch.Tensor]]
) -> Tuple[Iterable[Tuple[str, torch.Tensor]], Iterable[Tuple[str,
                                                              torch.Tensor]]]:
    """
    Separate the task-specific weights that sit on top of the base
    encoder from the encoder weights themselves, returning two lazy
    generators over the original iterator. The "roberta." prefix is
    stripped from the encoder weights so they are loadable by the
    vanilla BertModel.
    """
    # itertools.tee duplicates the lazy iterator so both streams can be
    # consumed independently; it buffers only the items one stream has
    # consumed ahead of the other, not the whole checkpoint.
    all_weights1, all_weights2 = itertools.tee(all_weights)

    def encoder_decoder_weights():
        for name, weight in all_weights1:
            if name.startswith("roberta."):
                yield (name[len("roberta."):], weight)

    return encoder_decoder_weights(), ((n, w) for n, w in all_weights2
                                       if not n.startswith("roberta."))


class RobertaEmbedding(nn.Module):

    def __init__(self, config: RobertaConfig):
@@ -152,6 +177,18 @@ def _build_model(self,
                         prefix=prefix,
                         embedding_class=RobertaEmbedding)

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        weights = self.hf_to_vllm_mapper.apply(weights)
        # Lazily split the stream into "roberta."-prefixed weights and
        # everything else, for checkpoints like FacebookAI/roberta-base.
        bert_weights, task_weights = roberta_task_weights_filter(weights)
        loaded = self.model.load_weights(bert_weights)
        if not len(loaded):
            # Fallback for checkpoints such as
            # sentence-transformers/stsb-roberta-base-v2, which share the
            # same architecture but carry no "roberta." prefix.
            loaded = self.model.load_weights(task_weights)
        assert len(loaded), "Unable to load RobertaEmbeddingModel"
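
To make the fallback concrete, here is a hedged sketch of the two checkpoint layouts it handles; the weight names are representative examples, not an exhaustive listing of either checkpoint:

# Representative weight names (illustrative, not exhaustive).
# FacebookAI/roberta-base stores encoder weights under a "roberta." prefix:
prefixed = [
    "roberta.embeddings.word_embeddings.weight",
    "roberta.encoder.layer.0.attention.self.query.weight",
]
# sentence-transformers/stsb-roberta-base-v2 stores them unprefixed:
unprefixed = [
    "embeddings.word_embeddings.weight",
    "encoder.layer.0.attention.self.query.weight",
]
# With the first layout, bert_weights is non-empty and loads directly.
# With the second, bert_weights yields nothing, so the loader retries
# with the task_weights stream, which then contains the encoder weights.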


class RobertaForSequenceClassification(nn.Module, SupportsCrossEncoding):
"""A model that uses Roberta to provide embedding functionalities.
@@ -181,20 +218,12 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):

-        self_weights = []
-
-        def weight_filter():
-            for name, weight in weights:
-                if name.startswith("roberta."):
-                    yield (name[len("roberta."):], weight)
-                else:
-                    self_weights.append((name, weight))
-
-        self.roberta.load_weights(weight_filter())
+        bert_weights, task_weights = roberta_task_weights_filter(weights)
+        self.roberta.load_weights(bert_weights)

        params_dict = dict(self.named_parameters())

-        for name, loaded_weight in self_weights:
+        for name, loaded_weight in task_weights:
            if name.startswith("classifier"):
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader",