Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Misc][LoRA] Fix LoRA weight mapper #11495

Merged
merged 6 commits into from
Dec 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 10 additions & 5 deletions tests/lora/test_lora_checkpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ def test_load_checkpoints(
embedding_padding_modules=embed_padding_modules)


def test_lora_weights_mapping(baichuan_lora_files, ):
def test_lora_weights_mapping(baichuan_lora_files):
supported_lora_modules = BaiChuanBaseForCausalLM.supported_lora_modules
packed_modules_mapping = BaiChuanBaseForCausalLM.packed_modules_mapping
embedding_modules = BaiChuanBaseForCausalLM.embedding_modules
Expand All @@ -86,10 +86,14 @@ def test_lora_weights_mapping(baichuan_lora_files, ):
else:
expected_lora_modules.append(module)

hf_to_vllm_mapper = WeightsMapper(orig_to_new_prefix={
"model.": "language_model.model.",
}, )

hf_to_vllm_mapper = WeightsMapper(
orig_to_new_prefix={
"model.": "language_model.model.",
},
orig_to_new_substr={
".layers.": ".baichuan_layers.",
},
)
lora_model = LoRAModel.from_local_checkpoint(
baichuan_lora_files,
expected_lora_modules,
Expand All @@ -101,3 +105,4 @@ def test_lora_weights_mapping(baichuan_lora_files, ):
)
for name in lora_model.loras:
assert name.startswith(hf_to_vllm_mapper.orig_to_new_prefix["model."])
assert ".baichuan_layers." in name
6 changes: 5 additions & 1 deletion tests/lora/test_qwen2vl.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@

# After fine-tuning with LoRA, all generated content should start begin `A`.
EXPECTED_OUTPUT = [
"A stop sign stands prominently in the foreground, with a traditional Chinese gate and a black SUV in the background, illustrating a blend of modern and cultural elements.", # noqa: E501
"A red stop sign stands prominently in the foreground, with a traditional Chinese gate and a black SUV in the background, illustrating a blend of modern and cultural elements.", # noqa: E501
"A majestic skyscraper stands tall, partially obscured by a vibrant canopy of cherry blossoms, against a clear blue sky.", # noqa: E501
]

Expand Down Expand Up @@ -76,3 +76,7 @@ def test_qwen2vl_lora(qwen2vl_lora_files):
output1 = do_sample(llm, qwen2vl_lora_files, lora_id=1)
for i in range(len(EXPECTED_OUTPUT)):
assert EXPECTED_OUTPUT[i].startswith(output1[i])

output2 = do_sample(llm, qwen2vl_lora_files, lora_id=2)
for i in range(len(EXPECTED_OUTPUT)):
assert EXPECTED_OUTPUT[i].startswith(output2[i])
3 changes: 2 additions & 1 deletion vllm/lora/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,7 +231,8 @@ def from_local_checkpoint(
with safetensors.safe_open(lora_tensor_path,
framework="pt") as f: # type: ignore
for lora_module in f.keys(): # noqa
module_name, _, _ = parse_fine_tuned_lora_name(lora_module)
module_name, _, _ = parse_fine_tuned_lora_name(
lora_module, weights_mapper)
part_name = module_name.split(".")[-1]
if part_name not in expected_lora_modules:
unexpected_modules.append(module_name)
Expand Down
34 changes: 12 additions & 22 deletions vllm/lora/utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import copy
import os
import re
from typing import List, Optional, Set, Tuple, Type, Union
Expand Down Expand Up @@ -32,7 +31,6 @@
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.model_executor.models.utils import WeightsMapper
from vllm.utils import print_warning_once

logger = init_logger(__name__)

Expand Down Expand Up @@ -112,36 +110,28 @@ def parse_fine_tuned_lora_name(
is_bias whether the tensor is lora bias.
"""

w_mapper = None
if weights_mapper:
w_mapper = copy.deepcopy(weights_mapper)
# TODO: Currently only supports mapping for prefix, mapping for
# substr and subfix will be supported in the future.
for attr, mapping in [
("orig_to_new_substr", w_mapper.orig_to_new_substr),
("orig_to_new_suffix", w_mapper.orig_to_new_suffix),
]:
if mapping:
print_warning_once(
f"vLLM currently does not support mapping of LoRA weights "
f"for {mapping}.")
setattr(w_mapper, attr, {})

mapper = (lambda name: w_mapper._map_name(name)
if w_mapper is not None else name)
# LoRA weight qualified name always starts with `base_model.model.`,
# so we remove the prefix `base_model.model.` to make the following
# mapping correctly.
if "base_model.model." in name:
name = name.replace("base_model.model.", "")
name = weights_mapper._map_name(name) if weights_mapper else name
# recover the prefix `base_model.model.`
name = "base_model.model." + name

parts = name.split(".")
if parts[-1] == "weight" and (parts[-2] == "lora_A"
or parts[-2] == "lora_B"):
new_name = ".".join(parts[2:-2])
return mapper(new_name), parts[-2] == "lora_A", False
return new_name, parts[-2] == "lora_A", False

if parts[-1] == "lora_embedding_A" or parts[-1] == "lora_embedding_B":
new_name = ".".join(parts[2:-1])
return mapper(new_name), parts[-1] == "lora_embedding_A", False
return new_name, parts[-1] == "lora_embedding_A", False

if parts[-1] == "bias":
new_name = ".".join(parts[2:-2])
return mapper(new_name), False, True
return new_name, False, True

raise ValueError(f"{name} is unsupported LoRA weight")

Expand Down
2 changes: 2 additions & 0 deletions vllm/lora/worker_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,8 @@ def _load_adapter(self, lora_request: LoRARequest) -> LoRAModel:
packed_modules_mapping[module])
else:
expected_lora_modules.append(module)

expected_lora_modules = list(set(expected_lora_modules))
lora_path = get_adapter_absolute_path(lora_request.lora_path)

# For some models like Qwen2VL, we need to use hf_to_vllm_mapper
Expand Down
Loading