[Model][LoRA] LoRA support added for MolmoForCausalLM #11439

Merged: 12 commits, Dec 31, 2024
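For context, a minimal usage sketch of what this change enables: serving Molmo with a LoRA adapter through vLLM's existing LoRA API. The adapter name, path, rank, and prompt below are placeholders for illustration, not part of this PR.

```python
from vllm import LLM, SamplingParams
from vllm.lora.request import LoRARequest

# Enable the LoRA machinery when constructing the engine.
llm = LLM(
    model="allenai/Molmo-7B-D-0924",
    trust_remote_code=True,   # Molmo ships custom code on the HF Hub
    enable_lora=True,
    max_lora_rank=64,         # adjust to the adapter's actual rank
)

# Hypothetical adapter trained on top of Molmo's language model.
lora = LoRARequest("molmo-adapter", 1, "/path/to/molmo-lora-adapter")

outputs = llm.generate(
    "Describe the weather in this picture.",
    SamplingParams(max_tokens=64),
    lora_request=lora,
)
print(outputs[0].outputs[0].text)
```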
2 changes: 1 addition & 1 deletion docs/source/models/supported_models.md
@@ -666,7 +666,7 @@ See [this page](#generative-models) for more information on how to use generative models.
     - Molmo
     - T + I
     - `allenai/Molmo-7B-D-0924`, `allenai/Molmo-72B-0924`, etc.
-    -
+    - ✅︎
     - ✅︎
     - ✅︎
   * - `NVLM_D_Model`
45 changes: 42 additions & 3 deletions vllm/model_executor/models/molmo.py
@@ -36,14 +36,15 @@
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     ParallelLMHead, VocabParallelEmbedding)
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
+from vllm.model_executor.models.module_mapping import MultiModelKeys
 from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs
 from vllm.multimodal.inputs import NestedTensors, PlaceholderRange
 from vllm.multimodal.utils import cached_get_tokenizer
 from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, IntermediateTensors,
                            SequenceData)
 from vllm.transformers_utils.processor import get_processor

-from .interfaces import SupportsMultiModal, SupportsPP
+from .interfaces import SupportsLoRA, SupportsMultiModal, SupportsPP
 from .utils import (AutoWeightsLoader, WeightsMapper, is_pp_missing_parameter,
                     make_empty_intermediate_tensors_factory, make_layers,
                     maybe_prefix, merge_multimodal_embeddings)
@@ -1161,8 +1162,8 @@ def input_processor_for_molmo(ctx: InputContext, inputs: DecoderOnlyInputs):
 @MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_molmo_image_tokens)
 @INPUT_REGISTRY.register_dummy_data(dummy_data_for_molmo)
 @INPUT_REGISTRY.register_input_processor(input_processor_for_molmo)
-class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
-
+class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP,
+                       SupportsLoRA):
     hf_to_vllm_mapper = WeightsMapper(
         orig_to_new_substr={
             # vision backbone mapping
@@ -1191,6 +1192,32 @@ class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
         },
     )

+    packed_modules_mapping = {
+        "qkv_proj": ["qkv_proj"],
+        "gate_up_proj": ["gate_up_proj"],  # language model
+        "merged_linear": ["gate_proj", "up_proj"]  # image_projector
+    }
+
+    # LoRA specific attributes
+    supported_lora_modules = [
+        # language model
+        "qkv_proj",
+        "o_proj",
+        "gate_up_proj",
+        "down_proj",  # same name with image_projector
+        # vision tower
+        "wq",
+        "wk",
+        "wv",
+        "wo",
+        "w1",
+        "w2",
+        # image_projector
+        "merged_linear",
+    ]
+    embedding_modules = {}
+    embedding_padding_modules = []
+
     # BitandBytes specific attributes
     bitsandbytes_stacked_params_mapping = {
         "gate_proj": ("merged_linear", 0),
@@ -1202,8 +1229,10 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         config = vllm_config.model_config.hf_config
         quant_config = vllm_config.quant_config
         multimodal_config = vllm_config.model_config.multimodal_config
+        lora_config = vllm_config.lora_config
         self.config = config
         self.multimodal_config = multimodal_config
+        self.lora_config = lora_config

         vision_config = VisionBackboneConfig()
         self.vision_backbone = MolmoVisionBackbone(config, vision_config,
@@ -1377,6 +1406,16 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         weights = _get_weights_with_merged_embedding(weights)
         return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)

+    def get_mm_mapping(self) -> MultiModelKeys:
+        """
+        Get the module prefix in multimodal models
+        """
+        return MultiModelKeys.from_string_field(
+            language_model="model",
+            connector="vision_backbone.image_projector",
+            tower_model="vision_backbone",
+        )


 def _get_weights_with_merged_embedding(
         weights: Iterable[Tuple[str, torch.Tensor]]
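A side note on the new `get_mm_mapping` hook: `MultiModelKeys.from_string_field` groups module prefixes into language-model, connector, and vision-tower buckets, which lets the LoRA machinery tell where a given adapter target lives. The sketch below only illustrates that grouping; it is not vLLM internals, and it assumes each field is treated as a list of string prefixes.

```python
# Rough sketch: partition a Molmo parameter name by the prefixes that
# get_mm_mapping() reports. Illustrative helper only, not a vLLM API.
def classify_module(name: str, language_model, connector, tower_model) -> str:
    # The connector is checked first because its prefix
    # ("vision_backbone.image_projector") is nested under the tower
    # prefix ("vision_backbone").
    for group, prefixes in (("connector", connector),
                            ("tower_model", tower_model),
                            ("language_model", language_model)):
        if any(name.startswith(p) for p in prefixes):
            return group
    return "other"

print(classify_module("vision_backbone.image_projector.merged_linear",
                      ["model"], ["vision_backbone.image_projector"],
                      ["vision_backbone"]))  # -> "connector"
```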