Skip to content

Commit

Permalink
[Misc] Optimize Qwen2-VL LoRA test (#11663)
Browse files Browse the repository at this point in the history
Signed-off-by: Jee Jee Li <[email protected]>
  • Loading branch information
jeejeelee authored Jan 1, 2025
1 parent 365801f commit 11d8a09
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 4 deletions.
5 changes: 2 additions & 3 deletions tests/lora/test_qwen2vl.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from vllm.lora.request import LoRARequest
from vllm.platforms import current_platform

MODEL_PATH = "Qwen/Qwen2-VL-7B-Instruct"
MODEL_PATH = "Qwen/Qwen2-VL-2B-Instruct"

PROMPT_TEMPLATE = (
"<|im_start|>system\nYou are a helpful assistant.<|im_end|>"
Expand Down Expand Up @@ -49,10 +49,9 @@ def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int) -> List[str]:
# Print the outputs.
generated_texts: List[str] = []
for output in outputs:
prompt = output.prompt
generated_text = output.outputs[0].text.strip()
generated_texts.append(generated_text)
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
print(f"Generated text: {generated_text!r}")
return generated_texts


Expand Down
20 changes: 19 additions & 1 deletion vllm/model_executor/models/qwen2_vl.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@
GPTQMarlinConfig)
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.module_mapping import MultiModelKeys
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import (ImageItem, ModalityData,
MultiModalFieldConfig, MultiModalKwargs,
Expand Down Expand Up @@ -926,15 +927,23 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal,
}

# LoRA specific attributes
# TODO Support LoRA for the visual encoder in the future.
supported_lora_modules = [
"qkv_proj",
"o_proj",
"gate_up_proj",
"down_proj",
# vision tower
"qkv",
"attn.proj", # Distinguish patch_embed.proj
"fc1",
"fc2",
# projector
"mlp.0",
"mlp.2"
]
embedding_modules = {}
embedding_padding_modules = []

# To ensure correct weight loading and mapping.
hf_to_vllm_mapper = WeightsMapper(orig_to_new_prefix={
"lm_head.": "language_model.lm_head.",
Expand Down Expand Up @@ -1231,3 +1240,12 @@ def load_weights(self, weights: Iterable[Tuple[str,

loader = AutoWeightsLoader(self)
return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)

def get_mm_mapping(self) -> MultiModelKeys:
"""
Get the module prefix in multimodal models
"""
return MultiModelKeys.from_string_field(
language_model="language_model",
connector="visual.",
tower_model="visual.merger.")

0 comments on commit 11d8a09

Please sign in to comment.