
[Bug]: load_weights() does not work for RobertaModel embeddings since weights start with "roberta." #11821

Closed
chmeyers opened this issue Jan 8, 2025 · 5 comments · Fixed by #11940
Labels: bug (Something isn't working)

Comments


chmeyers commented Jan 8, 2025

Your current environment

No response

Model Input Dumps

No response

🐛 Describe the bug

The RobertaEmbeddingModel here:

return BertModel(vllm_config=vllm_config,

just uses the base BertModel class, so when model.load_weights() is called the checkpoint's parameter names (which start with "roberta.") don't match the model's params_dict, leading to this stack trace:

File "/home/ray/anaconda3/lib/python3.10/site-packages/vllm/model_executor/models/bert.py", line 448, in load_weights
self.model.load_weights(weights)
File "/home/ray/anaconda3/lib/python3.10/site-packages/vllm/model_executor/models/bert.py", line 394, in load_weights
param = params_dict[name]
KeyError: 'roberta.embeddings.LayerNorm.weight'

I think it should rename the weights to strip the "roberta." prefix, similar to

def weight_filter():
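
For illustration, a minimal sketch of that renaming, assuming the Roberta wrapper delegates to an inner BertModel via self.model as in the stack trace above (rename_roberta_weights is a hypothetical helper, not vLLM's actual fix):

from typing import Iterable, Tuple

import torch

def rename_roberta_weights(
        weights: Iterable[Tuple[str, torch.Tensor]],
) -> Iterable[Tuple[str, torch.Tensor]]:
    # Strip the checkpoint's "roberta." prefix so names line up with
    # BertModel's params_dict; names without the prefix pass through unchanged.
    prefix = "roberta."
    for name, tensor in weights:
        yield (name[len(prefix):] if name.startswith(prefix) else name,
               tensor)

# The wrapper's load_weights could then call:
#     self.model.load_weights(rename_roberta_weights(weights))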

DarkLight1337 (Member)

cc @maxdebayser

chmeyers (Author) commented Jan 8, 2025

Script for Repro:

import os

import torch
from torch import nn
from huggingface_hub import snapshot_download
from safetensors import safe_open


from vllm import EngineArgs, LLMEngine
from vllm.config import (LoadConfig, ModelConfig, VllmConfig)
from vllm.model_executor.model_loader.loader import _initialize_model
from vllm.model_executor.model_loader.utils import set_default_torch_dtype
from vllm.model_executor.model_loader import BaseModelLoader

DOWNLOAD_PATTERN = ["*.json", "*.py", "*.safetensors", "*.txt", "*.model"]
model_dir = snapshot_download("FacebookAI/roberta-base", allow_patterns=DOWNLOAD_PATTERN)

class WeightsLoader(BaseModelLoader):
    def __init__(self, load_config: LoadConfig):
        super().__init__(load_config)

    def download_model(self, model_config: ModelConfig) -> None:
        pass

    def load_model(self, *, vllm_config: VllmConfig) -> nn.Module:
        device_config = vllm_config.device_config
        model_config = vllm_config.model_config

        with set_default_torch_dtype(model_config.dtype):
            with torch.device(device_config.device):
                model = _initialize_model(vllm_config=vllm_config)
            safetensorfile = os.path.join(model_dir, "model.safetensors")
            with safe_open(safetensorfile, framework="pt", device="cpu") as f:
                for name in f.keys():
                    buf = f.get_tensor(name)
                    print(name)
                    model.load_weights([(name, buf)])

        return model.eval()

engine_args = EngineArgs(model=model_dir, load_format=WeightsLoader, device="cpu")
LLMEngine.from_engine_args(engine_args)
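
Note that this loader feeds model.load_weights() one (name, tensor) pair at a time, so the first checkpoint key missing from params_dict raises the KeyError immediately.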

chmeyers (Author) commented Jan 8, 2025

Actually, after wondering how it was working on the default load path, I realized that it wasn't working there, either.

So a simpler repro:

from huggingface_hub import snapshot_download
from vllm import EngineArgs, LLMEngine

DOWNLOAD_PATTERN = ["*.json", "*.py", "*.safetensors", "*.txt", "*.model"]
model_dir = snapshot_download("FacebookAI/roberta-base", allow_patterns=DOWNLOAD_PATTERN)

engine_args = EngineArgs(model=model_dir, device="cpu")
LLMEngine.from_engine_args(engine_args)

noooop (Contributor) commented Jan 9, 2025

@maxdebayser

  • FacebookAI/roberta-base: weight names without the "roberta." prefix

  • BAAI/bge-m3: weight names with the "roberta." prefix

Very dirty.
refer to

NickLucche (Contributor)

I can look into this
