From b8fdc150463db62a6cc1b02a9c3fe6e48ee81e0e Mon Sep 17 00:00:00 2001
From: Jee Jee Li
Date: Mon, 23 Dec 2024 09:52:12 +0000
Subject: [PATCH 1/2] Done

Signed-off-by: Jee Jee Li
---
 tests/lora/test_lora_checkpoints.py    | 30 ++++++++++++++++++++++++++
 vllm/lora/models.py                    |  9 +++++---
 vllm/lora/utils.py                     | 25 +++++++++++++++++----
 vllm/lora/worker_manager.py            | 11 +++++++++-
 vllm/model_executor/models/qwen2_vl.py | 12 +++++------
 5 files changed, 73 insertions(+), 14 deletions(-)

diff --git a/tests/lora/test_lora_checkpoints.py b/tests/lora/test_lora_checkpoints.py
index 9a529e27b4cd8..9842203eb15e0 100644
--- a/tests/lora/test_lora_checkpoints.py
+++ b/tests/lora/test_lora_checkpoints.py
@@ -4,6 +4,7 @@

 from vllm.lora.models import LoRAModel
 from vllm.model_executor.models.baichuan import BaiChuanBaseForCausalLM
+from vllm.model_executor.models.utils import WeightsMapper

 lora_lst = [
     "baichuan7B", "baichuan7B-zero", "baichuan7B-zero-regex", "chatglm3-6b"
@@ -71,3 +72,32 @@ def test_load_checkpoints(
             device="cpu",
             embedding_modules=embedding_modules,
             embedding_padding_modules=embed_padding_modules)
+
+
+def test_lora_weights_mapping(baichuan_lora_files):
+    supported_lora_modules = BaiChuanBaseForCausalLM.supported_lora_modules
+    packed_modules_mapping = BaiChuanBaseForCausalLM.packed_modules_mapping
+    embedding_modules = BaiChuanBaseForCausalLM.embedding_modules
+    embed_padding_modules = BaiChuanBaseForCausalLM.embedding_padding_modules
+    expected_lora_modules: List[str] = []
+    for module in supported_lora_modules:
+        if module in packed_modules_mapping:
+            expected_lora_modules.extend(packed_modules_mapping[module])
+        else:
+            expected_lora_modules.append(module)
+
+    hf_to_vllm_mapper = WeightsMapper(orig_to_new_prefix={
+        "model.": "language_model.model.",
+    })
+
+    lora_model = LoRAModel.from_local_checkpoint(
+        baichuan_lora_files,
+        expected_lora_modules,
+        lora_model_id=1,
+        device="cpu",
+        embedding_modules=embedding_modules,
+        embedding_padding_modules=embed_padding_modules,
+        weights_mapper=hf_to_vllm_mapper,
+    )
+    for name in lora_model.loras:
+        assert name.startswith(hf_to_vllm_mapper.orig_to_new_prefix["model."])
diff --git a/vllm/lora/models.py b/vllm/lora/models.py
index 70806a77b9fff..f50db8e3b8e10 100644
--- a/vllm/lora/models.py
+++ b/vllm/lora/models.py
@@ -28,7 +28,7 @@
                              parse_fine_tuned_lora_name, replace_submodule)
 from vllm.model_executor.models import SupportsLoRA, supports_multimodal
 from vllm.model_executor.models.module_mapping import MultiModelKeys
-from vllm.model_executor.models.utils import PPMissingLayer
+from vllm.model_executor.models.utils import PPMissingLayer, WeightsMapper
 from vllm.utils import is_pin_memory_available

 logger = init_logger(__name__)
@@ -113,13 +113,14 @@ def from_lora_tensors(
         target_embedding_padding: Optional[int] = None,
         embedding_modules: Optional[Dict[str, str]] = None,
         embedding_padding_modules: Optional[List[str]] = None,
+        weights_mapper: Optional[WeightsMapper] = None,
     ) -> "LoRAModel":
         """Create a LoRAModel from a dictionary of tensors."""
         pin_memory = str(device) == "cpu" and is_pin_memory_available()
         loras: Dict[str, LoRALayerWeights] = {}
         for tensor_name, tensor in tensors.items():
             module_name, is_lora_a, is_bias = parse_fine_tuned_lora_name(
-                tensor_name)
+                tensor_name, weights_mapper)
             if module_name not in loras:
                 lora_embeddings_tensor = None
                 if embeddings:
@@ -187,6 +188,7 @@ def from_local_checkpoint(
         target_embedding_padding: Optional[int] = None,
         embedding_modules: Optional[Dict[str, str]] = None,
         embedding_padding_modules: Optional[List[str]] = None,
+        weights_mapper: Optional[WeightsMapper] = None,
     ) -> "LoRAModel":
         """Create a LoRAModel from a local checkpoint.
@@ -289,7 +291,8 @@ def from_local_checkpoint(
             embeddings=embeddings,
             target_embedding_padding=target_embedding_padding,
             embedding_modules=embedding_modules,
-            embedding_padding_modules=embedding_padding_modules)
+            embedding_padding_modules=embedding_padding_modules,
+            weights_mapper=weights_mapper)


 class LoRAModelManager(AdapterModelManager):
diff --git a/vllm/lora/utils.py b/vllm/lora/utils.py
index 5876494ce2824..9ca676332e3f0 100644
--- a/vllm/lora/utils.py
+++ b/vllm/lora/utils.py
@@ -30,6 +30,7 @@
 # yapf: enable
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
 from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
+from vllm.model_executor.models.utils import WeightsMapper

 logger = init_logger(__name__)

@@ -91,28 +92,44 @@ def replace_submodule(model: nn.Module, module_name: str,
     return new_module


-def parse_fine_tuned_lora_name(name: str) -> Tuple[str, bool, bool]:
+def parse_fine_tuned_lora_name(
+        name: str,
+        weights_mapper: Optional[WeightsMapper] = None
+) -> Tuple[str, bool, bool]:
     """Parse the name of lora weights.

     args:
         name: the name of the fine-tuned LoRA, e.g.
             base_model.model.dense1.weight
+        weights_mapper: maps the name of the weight, e.g.
+            `model.` -> `language_model.model.`
     return:
         Tuple(module_name, is_lora_a):
             module_name: the name of the module, e.g. model.dense1,
             is_lora_a whether the tensor is lora_a or lora_b.
             is_bias whether the tensor is lora bias.
     """

+    # TODO: Currently only supports mapping by prefix; mapping by substr
+    # and suffix will be supported in the future.
+    if weights_mapper is not None:
+        weights_mapper.orig_to_new_substr = {}
+        weights_mapper.orig_to_new_suffix = {}
+
+    mapper = (lambda name: weights_mapper._map_name(name)
+              if weights_mapper is not None else name)
     parts = name.split(".")
     if parts[-1] == "weight" and (parts[-2] == "lora_A"
                                   or parts[-2] == "lora_B"):
-        return ".".join(parts[2:-2]), parts[-2] == "lora_A", False
+        new_name = ".".join(parts[2:-2])
+        return mapper(new_name), parts[-2] == "lora_A", False

     if parts[-1] == "lora_embedding_A" or parts[-1] == "lora_embedding_B":
-        return ".".join(parts[2:-1]), parts[-1] == "lora_embedding_A", False
+        new_name = ".".join(parts[2:-1])
+        return mapper(new_name), parts[-1] == "lora_embedding_A", False

     if parts[-1] == "bias":
-        return ".".join(parts[2:-2]), False, True
+        new_name = ".".join(parts[2:-2])
+        return mapper(new_name), False, True

     raise ValueError(f"{name} is unsupported LoRA weight")
diff --git a/vllm/lora/worker_manager.py b/vllm/lora/worker_manager.py
index 93a5e27621912..ef8cc5886103e 100644
--- a/vllm/lora/worker_manager.py
+++ b/vllm/lora/worker_manager.py
@@ -92,6 +92,14 @@ def _load_adapter(self, lora_request: LoRARequest) -> LoRAModel:
                 else:
                     expected_lora_modules.append(module)
             lora_path = get_adapter_absolute_path(lora_request.lora_path)
+
+            # For some models like Qwen2VL, we need to use hf_to_vllm_mapper
+            # to ensure correct loading of lora weights.
+            hf_to_vllm_mapper = None
+            if (hasattr(model, "hf_to_vllm_mapper")
+                    and model.hf_to_vllm_mapper is not None):
+                hf_to_vllm_mapper = model.hf_to_vllm_mapper
+
             lora = self._lora_model_cls.from_local_checkpoint(
                 lora_path,
                 expected_lora_modules,
@@ -103,7 +111,8 @@ def _load_adapter(self, lora_request: LoRARequest) -> LoRAModel:
                 self.lora_config.lora_extra_vocab_size,
                 embedding_modules=self.embedding_modules,
                 embedding_padding_modules=self.embedding_padding_modules,
-            )
+                weights_mapper=hf_to_vllm_mapper)
+
         except Exception as e:
             raise RuntimeError(f"Loading lora {lora_path} failed") from e
         if lora.rank > self.lora_config.max_lora_rank:
diff --git a/vllm/model_executor/models/qwen2_vl.py b/vllm/model_executor/models/qwen2_vl.py
index b38ea923f0bf1..fb97eb1916002 100644
--- a/vllm/model_executor/models/qwen2_vl.py
+++ b/vllm/model_executor/models/qwen2_vl.py
@@ -901,6 +901,11 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal,
     ]
     embedding_modules = {}
     embedding_padding_modules = []
+    # To ensure correct weight loading and mapping.
+    hf_to_vllm_mapper = WeightsMapper(orig_to_new_prefix={
+        "lm_head.": "language_model.lm_head.",
+        "model.": "language_model.model.",
+    })

     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         super().__init__()
@@ -1190,11 +1195,6 @@ def sample(

     def load_weights(self, weights: Iterable[Tuple[str,
                                                    torch.Tensor]]) -> Set[str]:
-        hf_to_vllm_mapper = WeightsMapper(
-            orig_to_new_prefix={
-                "lm_head.": "language_model.lm_head.",
-                "model.": "language_model.model.",
-            })

         loader = AutoWeightsLoader(self)
-        return loader.load_weights(weights, mapper=hf_to_vllm_mapper)
+        return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)

From e5de5e9e68dd584408212f7ad32dfb1b67517056 Mon Sep 17 00:00:00 2001
From: Jee Jee Li
Date: Mon, 23 Dec 2024 10:19:05 +0000
Subject: [PATCH 2/2] Done

Signed-off-by: Jee Jee Li
---
 .../my_gemma_embedding.py               |  5 +-
 vllm/model_executor/models/aria.py      | 20 +++----
 vllm/model_executor/models/bert.py      |  4 +-
 vllm/model_executor/models/molmo.py     | 58 ++++++++++---------
 vllm/model_executor/models/phi3v.py     | 16 ++---
 vllm/model_executor/models/qwen2.py     |  5 +-
 vllm/model_executor/models/telechat2.py | 27 ++++-----
 vllm/model_executor/models/ultravox.py  |  7 ++-
 8 files changed, 74 insertions(+), 68 deletions(-)

diff --git a/tests/plugins/vllm_add_dummy_model/vllm_add_dummy_model/my_gemma_embedding.py b/tests/plugins/vllm_add_dummy_model/vllm_add_dummy_model/my_gemma_embedding.py
index d676eacffb056..5e7d7d1877e61 100644
--- a/tests/plugins/vllm_add_dummy_model/vllm_add_dummy_model/my_gemma_embedding.py
+++ b/tests/plugins/vllm_add_dummy_model/vllm_add_dummy_model/my_gemma_embedding.py
@@ -13,6 +13,7 @@


 class MyGemma2Embedding(nn.Module):
+    hf_to_vllm_mapper = WeightsMapper(orig_to_new_prefix={"model.": ""})

     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         super().__init__()
@@ -62,8 +63,8 @@ def pooler(
         return self._pooler(hidden_states, pooling_metadata)

     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
-        hf_to_vllm_mapper = WeightsMapper(orig_to_new_prefix={"model.": ""})
-        weights = hf_to_vllm_mapper.apply(weights)
+
+        weights = self.hf_to_vllm_mapper.apply(weights)
         weights = ((name, data) for name, data in weights
                    if not name.startswith("lm_head."))
         return self.model.load_weights(weights)
diff --git a/vllm/model_executor/models/aria.py b/vllm/model_executor/models/aria.py
index dd4b0c75cb84d..9437ad9688422 100644
--- a/vllm/model_executor/models/aria.py
+++ b/vllm/model_executor/models/aria.py
@@ -521,6 +521,15 @@ class AriaForConditionalGeneration(nn.Module, SupportsMultiModal):
     This model combines a vision tower, a multi-modal projector, and a language
     model to perform tasks that involve both image and text inputs.
     """
+    hf_to_vllm_mapper = WeightsMapper(
+        orig_to_new_prefix={
+            "language_model.model": "language_model",
+            "language_model.lm_head": "lm_head",
+        },
+        orig_to_new_suffix={
+            "router.weight": "router_weight",
+        },
+    )

     def __init__(
         self,
@@ -662,15 +671,6 @@ def sample(
         return next_tokens

     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
-        hf_to_vllm_mapper = WeightsMapper(
-            orig_to_new_prefix={
-                "language_model.model": "language_model",
-                "language_model.lm_head": "lm_head",
-            },
-            orig_to_new_suffix={
-                "router.weight": "router_weight",
-            },
-        )

         loader = AutoWeightsLoader(self)
-        loader.load_weights(weights, mapper=hf_to_vllm_mapper)
+        loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
diff --git a/vllm/model_executor/models/bert.py b/vllm/model_executor/models/bert.py
index 053d838432885..c1d47b1bc9bcd 100644
--- a/vllm/model_executor/models/bert.py
+++ b/vllm/model_executor/models/bert.py
@@ -409,6 +409,7 @@ class BertEmbeddingModel(nn.Module):
         model: An instance of BertModel used for forward operations.
         _pooler: An instance of Pooler used for pooling operations.
     """
+    hf_to_vllm_mapper = WeightsMapper(orig_to_new_prefix={"model.": ""})

     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         super().__init__()
@@ -441,8 +442,7 @@ def pooler(
         return self._pooler(hidden_states, pooling_metadata)

     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
-        hf_to_vllm_mapper = WeightsMapper(orig_to_new_prefix={"model.": ""})
-        weights = hf_to_vllm_mapper.apply(weights)
+        weights = self.hf_to_vllm_mapper.apply(weights)
         weights = ((name, data) for name, data in weights
                    if not name.startswith("lm_head."))
         self.model.load_weights(weights)
diff --git a/vllm/model_executor/models/molmo.py b/vllm/model_executor/models/molmo.py
index 9f744b6918818..63a25137f8aa9 100644
--- a/vllm/model_executor/models/molmo.py
+++ b/vllm/model_executor/models/molmo.py
@@ -1123,6 +1123,34 @@ def input_processor_for_molmo(ctx: InputContext, inputs: DecoderOnlyInputs):

 @INPUT_REGISTRY.register_input_processor(input_processor_for_molmo)
 class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
+    hf_to_vllm_mapper = WeightsMapper(
+        orig_to_new_substr={
+            # vision backbone mapping
+            "image_projector.w1.": "image_projector.gate_proj.",
+            "image_projector.w3.": "image_projector.up_proj.",
+            "image_projector.w2.": "image_projector.down_proj.",
+            # language backbone mapping
+            "att_proj": "self_attn.qkv_proj",
+            "attn_out": "self_attn.o_proj",
+            "q_norm": "self_attn.q_norm",
+            "k_norm": "self_attn.k_norm",
+            "ff_proj": "mlp.gate_up_proj",
+            "ff_out": "mlp.down_proj",
+            "attn_norm": "input_layernorm",
+            "ff_norm": "post_attention_layernorm",
+        },
+        orig_to_new_prefix={
+            # vision backbone mapping
+            "model.vision_backbone.": "vision_backbone.",
+            # language backbone mapping
+            "model.transformer.blocks.": "model.layers.",
+            "model.transformer.ln_f.": "model.norm.",
+            # lm_head is renamed to model.transformer.mlp.down_proj firstly,
+            # we need to run a second renaming for it
+            "model.transformer.mlp.down_proj.": "lm_head.",
+        },
+    )
+
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         super().__init__()
         config = vllm_config.model_config.hf_config
@@ -1298,36 +1326,10 @@ def sample(
         return next_tokens

     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
-        hf_to_vllm_mapper = WeightsMapper(
-            orig_to_new_substr={
-                # vision backbone mapping
-                "image_projector.w1.": "image_projector.gate_proj.",
-                "image_projector.w3.": "image_projector.up_proj.",
-                "image_projector.w2.": "image_projector.down_proj.",
-                # language backbone mapping
-                "att_proj": "self_attn.qkv_proj",
-                "attn_out": "self_attn.o_proj",
-                "q_norm": "self_attn.q_norm",
-                "k_norm": "self_attn.k_norm",
-                "ff_proj": "mlp.gate_up_proj",
-                "ff_out": "mlp.down_proj",
-                "attn_norm": "input_layernorm",
-                "ff_norm": "post_attention_layernorm",
-            },
-            orig_to_new_prefix={
-                # vision backbone mapping
-                "model.vision_backbone.": "vision_backbone.",
-                # language backbone mapping
-                "model.transformer.blocks.": "model.layers.",
-                "model.transformer.ln_f.": "model.norm.",
-                # lm_head is renamed to model.transformer.mlp.down_proj firstly,
-                # we need to run a second renaming for it
-                "model.transformer.mlp.down_proj.": "lm_head.",
-            },
-        )
+
         loader = AutoWeightsLoader(self)
         weights = _get_weights_with_merged_embedding(weights)
-        return loader.load_weights(weights, mapper=hf_to_vllm_mapper)
+        return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)


 def _get_weights_with_merged_embedding(
diff --git a/vllm/model_executor/models/phi3v.py b/vllm/model_executor/models/phi3v.py
index e2263f63f7bba..4e2e7f5761544 100644
--- a/vllm/model_executor/models/phi3v.py
+++ b/vllm/model_executor/models/phi3v.py
@@ -408,6 +408,13 @@ def _get_dummy_mm_inputs(
 @MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_phi3v_image_tokens)
 @MULTIMODAL_REGISTRY.register_processor(Phi3VMultiModalProcessor)
 class Phi3VForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
+    hf_to_vllm_mapper = WeightsMapper(
+        orig_to_new_prefix={
+            "model.vision_embed_tokens.wte": "embed_tokens",
+            "model.vision_embed_tokens.": "vision_embed_tokens.",
+            "lm_head.": "language_model.lm_head.",
+            "model.": "language_model.model.",
+        })

     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         super().__init__()
@@ -616,17 +623,10 @@ def sample(

     def load_weights(self, weights: Iterable[Tuple[str,
                                                    torch.Tensor]]) -> Set[str]:
-        hf_to_vllm_mapper = WeightsMapper(
-            orig_to_new_prefix={
-                "model.vision_embed_tokens.wte": "embed_tokens",
-                "model.vision_embed_tokens.": "vision_embed_tokens.",
-                "lm_head.": "language_model.lm_head.",
-                "model.": "language_model.model.",
-            })

         loader = AutoWeightsLoader(self)
         autoloaded_weights = loader.load_weights(weights,
-                                                 mapper=hf_to_vllm_mapper)
+                                                 mapper=self.hf_to_vllm_mapper)

         # The HF config doesn't specify whether these are tied,
         # so we detect it this way
diff --git a/vllm/model_executor/models/qwen2.py b/vllm/model_executor/models/qwen2.py
index 3ce4eb5869f21..7661bb285df95 100644
--- a/vllm/model_executor/models/qwen2.py
+++ b/vllm/model_executor/models/qwen2.py
@@ -529,6 +529,8 @@ class Qwen2EmbeddingModel(nn.Module, SupportsLoRA, SupportsPP):
     embedding_modules = {}
     embedding_padding_modules = []

+    hf_to_vllm_mapper = WeightsMapper(orig_to_new_prefix={"model.": ""})
+
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         super().__init__()
         config = vllm_config.model_config.hf_config
@@ -577,8 +579,7 @@ def pooler(
         return self._pooler(hidden_states, pooling_metadata)

     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
-        hf_to_vllm_mapper = WeightsMapper(orig_to_new_prefix={"model.": ""})
-        weights = hf_to_vllm_mapper.apply(weights)
+        weights = self.hf_to_vllm_mapper.apply(weights)
         weights = ((name, data) for name, data in weights
                    if not name.startswith("lm_head."))
         self.model.load_weights(weights)
diff --git a/vllm/model_executor/models/telechat2.py b/vllm/model_executor/models/telechat2.py
index 39c9103527f01..28c37bb96612c 100644
--- a/vllm/model_executor/models/telechat2.py
+++ b/vllm/model_executor/models/telechat2.py
@@ -31,6 +31,19 @@


 class TeleChat2Model(LlamaModel):
+    hf_to_vllm_mapper = WeightsMapper(
+        orig_to_new_prefix={
+            "transformer.": "model.",
+        },
+        orig_to_new_substr={
+            ".h.": ".layers.",
+            ".self_attention.": ".self_attn.",
+            ".word_embeddings.": ".embed_tokens.",
+            ".dense.": ".o_proj.",
+            ".ln_f.": ".norm.",
+        },
+    )
+
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         # 1. Initialize the LlamaModel with bias
         vllm_config.model_config.hf_config.bias = True
@@ -111,21 +124,9 @@ def _init_model(self, vllm_config: VllmConfig, prefix: str = ""):

     def load_weights(self, weights: Iterable[Tuple[str,
                                                    torch.Tensor]]) -> Set[str]:
-        hf_to_vllm_mapper = WeightsMapper(
-            orig_to_new_prefix={
-                "transformer.": "model.",
-            },
-            orig_to_new_substr={
-                ".h.": ".layers.",
-                ".self_attention.": ".self_attn.",
-                ".word_embeddings.": ".embed_tokens.",
-                ".dense.": ".o_proj.",
-                ".ln_f.": ".norm.",
-            },
-        )
         loader = AutoWeightsLoader(
             self,
             skip_prefixes=(["lm_head."]
                            if self.config.tie_word_embeddings else None),
         )
-        return loader.load_weights(weights, mapper=hf_to_vllm_mapper)
+        return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
diff --git a/vllm/model_executor/models/ultravox.py b/vllm/model_executor/models/ultravox.py
index c60b208c3d27d..509ad9e580ddf 100644
--- a/vllm/model_executor/models/ultravox.py
+++ b/vllm/model_executor/models/ultravox.py
@@ -302,6 +302,9 @@ def forward(
 @MULTIMODAL_REGISTRY.register_processor(UltravoxMultiModalProcessor)
 class UltravoxModel(nn.Module, SupportsMultiModal, SupportsPP):

+    hf_to_vllm_mapper = WeightsMapper(
+        orig_to_new_prefix={"audio_tower.model.encoder.": "audio_tower."})
+
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         super().__init__()
         config = vllm_config.model_config.hf_config
@@ -494,9 +497,7 @@ def sample(

     def load_weights(self, weights: Iterable[Tuple[str,
                                                    torch.Tensor]]) -> Set[str]:
-        hf_to_vllm_mapper = WeightsMapper(
-            orig_to_new_prefix={"audio_tower.model.encoder.": "audio_tower."})

         loader = AutoWeightsLoader(self,
                                    ignore_unexpected_prefixes=["audio_tower."])
-        return loader.load_weights(weights, mapper=hf_to_vllm_mapper)
+        return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
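
Illustrative sketch (not part of the patch series): the snippet below mimics what the new weights_mapper argument does to a LoRA tensor name, using a prefix-only mapping like the Qwen2-VL one added above. PrefixMapper and parse_lora_name here are hypothetical stand-ins written for this example; they are not vLLM's WeightsMapper or parse_fine_tuned_lora_name, only a simplified model of the behavior the patch introduces.

from typing import Dict, Optional, Tuple


class PrefixMapper:
    """Hypothetical stand-in for WeightsMapper: rewrites names by prefix only."""

    def __init__(self, orig_to_new_prefix: Dict[str, str]) -> None:
        self.orig_to_new_prefix = orig_to_new_prefix

    def map_name(self, name: str) -> str:
        # Replace the first matching original prefix with its new prefix.
        for old, new in self.orig_to_new_prefix.items():
            if name.startswith(old):
                return new + name[len(old):]
        return name


def parse_lora_name(name: str,
                    mapper: Optional[PrefixMapper] = None) -> Tuple[str, bool]:
    """Simplified model of the patched helper: strip the 'base_model.model.'
    prefix and the trailing 'lora_A/lora_B.weight', then remap the module name."""
    parts = name.split(".")
    module_name = ".".join(parts[2:-2])
    if mapper is not None:
        module_name = mapper.map_name(module_name)
    return module_name, parts[-2] == "lora_A"


if __name__ == "__main__":
    # HF-style checkpoint prefix "model." -> vLLM's "language_model.model."
    mapper = PrefixMapper({"model.": "language_model.model."})
    name = "base_model.model.model.layers.0.self_attn.qkv_proj.lora_A.weight"
    print(parse_lora_name(name, mapper))
    # ('language_model.model.layers.0.self_attn.qkv_proj', True)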