Skip to content

Commit

Permalink
[Bugfix] Fix RobertaModel loading (#11940)
Browse files Browse the repository at this point in the history
Signed-off-by: NickLucche <[email protected]>
  • Loading branch information
NickLucche authored Jan 11, 2025
1 parent a991f7d commit d697dc0
Show file tree
Hide file tree
Showing 3 changed files with 67 additions and 12 deletions.
27 changes: 26 additions & 1 deletion tests/model_executor/test_model_load_with_params.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import pytest

from vllm.model_executor.layers.pooler import PoolingType
from vllm.model_executor.layers.pooler import CLSPool, PoolingType
from vllm.model_executor.models.bert import BertEmbeddingModel
from vllm.model_executor.models.roberta import RobertaEmbeddingModel
from vllm.platforms import current_platform
Expand Down Expand Up @@ -92,3 +92,28 @@ def test_roberta_model_loading_with_params(vllm_runner):

# assert output
assert output


@pytest.mark.skipif(current_platform.is_rocm(),
                    reason="Xformers backend is not supported on ROCm.")
def test_facebook_roberta_model_loading_with_params(vllm_runner):
    """
    Check that FacebookAI/roberta-base (a checkpoint with no lm_head)
    loads as an embedding model.

    The loaded model must expose no `lm_head` attribute, be a
    `RobertaEmbeddingModel`, and use a `CLSPool` pooler.
    """
    model_name = "FacebookAI/roberta-base"
    prompt = ("Write a short story about a robot that"
              " dreams for the first time.\n")

    with vllm_runner(model_name=model_name,
                     dtype="float16",
                     max_model_len=MAX_MODEL_LEN) as runner:
        output = runner.encode(prompt)

        # The tokenizer must come from the same checkpoint as the model.
        engine = runner.model.llm_engine
        assert engine.tokenizer.tokenizer_id == model_name

        # Reach through the engine to the model instance that was built.
        loaded_model = engine.model_executor.driver_worker.model_runner.model
        assert not hasattr(loaded_model, "lm_head")
        assert isinstance(loaded_model, RobertaEmbeddingModel)
        assert isinstance(loaded_model._pooler, CLSPool)

        # Encoding must have produced a non-empty result.
        assert output
1 change: 1 addition & 0 deletions tests/models/embedding/language/test_embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
pytest.param("ssmits/Qwen2-7B-Instruct-embed-base"),
pytest.param("Alibaba-NLP/gte-Qwen2-1.5B-instruct"),
pytest.param("Alibaba-NLP/gte-Qwen2-7B-instruct"),
pytest.param("sentence-transformers/stsb-roberta-base-v2"),
],
)
@pytest.mark.parametrize("dtype", ["half"])
Expand Down
51 changes: 40 additions & 11 deletions vllm/model_executor/models/roberta.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import itertools
from typing import Iterable, List, Optional, Tuple

import torch
Expand All @@ -20,6 +21,30 @@
from .interfaces import SupportsCrossEncoding


def roberta_task_weights_filter(
    all_weights: Iterable[Tuple[str, torch.Tensor]]
) -> Tuple[Iterable[Tuple[str, torch.Tensor]], Iterable[Tuple[str,
                                                              torch.Tensor]]]:
    """
    Split a stream of named weights into two lazy streams.

    The first stream yields every weight whose name starts with
    "roberta.", with that prefix removed from the name. The second stream
    yields every other weight with its name unchanged. Each stream keeps
    the relative order of the input, and the two streams may be consumed
    in any order or interleaved.

    The input is traversed only once. While one stream is being consumed,
    items that belong to the other stream are parked in a backlog until
    that stream asks for them. Unlike `itertools.tee`, items already
    handed out by their own stream are never retained, so draining the
    "roberta." stream first keeps only the non-"roberta." weights
    referenced.

    Args:
        all_weights: iterable of (name, weight) pairs.

    Returns:
        A pair `(prefixed_weights, other_weights)` of single-use
        iterators as described above.
    """
    # Function-scope import keeps this block self-contained.
    from collections import deque

    prefix = "roberta."
    source = iter(all_weights)
    prefixed_backlog: deque = deque()
    other_backlog: deque = deque()

    def _split(own_backlog, sibling_backlog, wants_prefixed):
        """Yield the items of one category, parking the rest for the
        sibling stream."""
        while True:
            # Backlogged items were read from `source` before anything
            # still pending in it, so they must be emitted first.
            if own_backlog:
                yield own_backlog.popleft()
                continue
            for name, weight in source:
                if name.startswith(prefix) == wants_prefixed:
                    yield name, weight
                    # The sibling may have run while we were suspended and
                    # refilled our backlog; go back and re-check it.
                    break
                sibling_backlog.append((name, weight))
            else:
                # Source exhausted and nothing left in our backlog.
                return

    prefixed_weights = ((name[len(prefix):], weight)
                        for name, weight in _split(prefixed_backlog,
                                                   other_backlog, True))
    other_weights = _split(other_backlog, prefixed_backlog, False)
    return prefixed_weights, other_weights


class RobertaEmbedding(nn.Module):

def __init__(self, config: RobertaConfig):
Expand Down Expand Up @@ -152,6 +177,18 @@ def _build_model(self,
prefix=prefix,
embedding_class=RobertaEmbedding)

def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
    """
    Load checkpoint weights into the wrapped `self.model`.

    Names are first remapped through `self.hf_to_vllm_mapper`. The
    "roberta."-prefixed weights (prefix stripped) are tried first, which
    covers checkpoints such as FacebookAI/roberta-base. If that attempt
    loads nothing, the remaining un-prefixed weights are tried instead,
    which covers checkpoints such as
    `sentence-transformers/stsb-roberta-base-v2` that use the same
    architecture without the prefix.

    Args:
        weights: iterable of (name, tensor) pairs from the checkpoint.

    Raises:
        AssertionError: if neither attempt loads any weight.
    """
    weights = self.hf_to_vllm_mapper.apply(weights)
    # Lazily split into "roberta."-prefixed weights and everything else.
    bert_weights, task_weights = roberta_task_weights_filter(weights)
    # NOTE(review): `loaded` is presumably a collection of the parameter
    # names that were loaded -- confirm against `self.model.load_weights`.
    loaded = self.model.load_weights(bert_weights)
    if not loaded:
        # No prefixed weight matched: fall back to the un-prefixed names.
        loaded = self.model.load_weights(task_weights)
    if not loaded:
        # Raised explicitly rather than via `assert`, so the check is not
        # stripped under `python -O`; the exception type and message are
        # the same ones the assert produced.
        raise AssertionError("Unable to load RobertaEmbeddingModel")


class RobertaForSequenceClassification(nn.Module, SupportsCrossEncoding):
"""A model that uses Roberta to provide embedding functionalities.
Expand Down Expand Up @@ -181,20 +218,12 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):

def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):

self_weights = []

def weight_filter():
for name, weight in weights:
if name.startswith("roberta."):
yield (name[len("roberta."):], weight)
else:
self_weights.append((name, weight))

self.roberta.load_weights(weight_filter())
bert_weights, task_weights = roberta_task_weights_filter(weights)
self.roberta.load_weights(bert_weights)

params_dict = dict(self.named_parameters())

for name, loaded_weight in self_weights:
for name, loaded_weight in task_weights:
if name.startswith("classifier"):
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
Expand Down

0 comments on commit d697dc0

Please sign in to comment.