Merge pull request #3835 from BUAADreamer/main
fix some features in llava-style training
hiyouga authored May 27, 2024
2 parents e626e26 + 576b020 commit 838f2fb
Showing 4 changed files with 31 additions and 1 deletion.
14 changes: 14 additions & 0 deletions data/dataset_info.json
@@ -38,6 +38,20 @@
"assistant_tag": "assistant"
}
},
"mllm_pt_demo": {
"hf_hub_url": "BUAADreamer/mllm_pt_demo",
"formatting": "sharegpt",
"columns": {
"messages": "messages",
"images": "images"
},
"tags": {
"role_tag": "role",
"content_tag": "content",
"user_tag": "user",
"assistant_tag": "assistant"
}
},
"alpaca_en": {
"hf_hub_url": "llamafactory/alpaca_en",
"ms_hub_url": "llamafactory/alpaca_en"
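For context, here is a minimal sketch of what a record in the sharegpt-formatted mllm_pt_demo entry above is declared to look like, and how the "tags" mapping resolves each turn. The sample record contents are invented for illustration; only the column and tag names come from the registration:

# Hypothetical record matching the "columns" mapping above: a "messages"
# list of chat turns plus an "images" list of image paths.
record = {
    "messages": [
        {"role": "user", "content": "<image>What is in this picture?"},
        {"role": "assistant", "content": "A cat sitting on a windowsill."},
    ],
    "images": ["images/0001.jpg"],
}

tags = {"role_tag": "role", "content_tag": "content",
        "user_tag": "user", "assistant_tag": "assistant"}

# Resolve each turn through the tag mapping, the way a loader would.
for turn in record["messages"]:
    speaker = "user" if turn[tags["role_tag"]] == tags["user_tag"] else "assistant"
    print(speaker, ":", turn[tags["content_tag"]])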
4 changes: 4 additions & 0 deletions src/llamafactory/hparams/model_args.py
@@ -85,6 +85,10 @@ class ModelArguments:
         default=False,
         metadata={"help": "Whether or not to use multimodal LLM that accepts visual inputs."},
     )
+    tune_mm_proj: bool = field(
+        default=False,
+        metadata={"help": "Whether or not to fine-tune only the multimodal projector for MLLM."},
+    )
     moe_aux_loss_coef: Optional[float] = field(
         default=None,
         metadata={"help": "Coefficient of the auxiliary router loss in mixture-of-experts model."},
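Because LLaMA-Factory parses its argument dataclasses with transformers' HfArgumentParser, the new field surfaces as a --tune_mm_proj command-line flag. A minimal standalone sketch of the same behavior (DemoArgs is a stand-in, not the real ModelArguments):

from dataclasses import dataclass, field

from transformers import HfArgumentParser


@dataclass
class DemoArgs:
    visual_inputs: bool = field(default=False, metadata={"help": "Use a multimodal LLM."})
    tune_mm_proj: bool = field(default=False, metadata={"help": "Only fine-tune the projector."})


parser = HfArgumentParser(DemoArgs)
# Bare boolean flags with a False default evaluate to True when present.
(args,) = parser.parse_args_into_dataclasses(["--visual_inputs", "--tune_mm_proj"])
print(args.visual_inputs, args.tune_mm_proj)  # True True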
7 changes: 7 additions & 0 deletions src/llamafactory/model/adapter.py
@@ -10,6 +10,7 @@
 from .utils.misc import find_all_linear_modules, find_expanded_modules
 from .utils.quantization import QuantizationMethod
 from .utils.unsloth import get_unsloth_peft_model, load_unsloth_peft_model
+from .utils.visual import filter_vision_tower_linear
 
 
 if TYPE_CHECKING:
@@ -58,6 +59,9 @@ def init_adapter(
     if model_args.visual_inputs and hasattr(model, "vision_tower"):  # freeze vision model
         model.vision_tower.requires_grad_(False)
 
+    if model_args.visual_inputs and hasattr(model, "language_model") and model_args.tune_mm_proj:  # freeze language model if only tuning mm_proj
+        model.language_model.requires_grad_(False)
+
     if finetuning_args.finetuning_type == "freeze" and is_trainable:
         logger.info("Fine-tuning method: Freeze")
         num_layers = (
@@ -180,6 +184,9 @@ def init_adapter(
         if finetuning_args.use_llama_pro:
             target_modules = find_expanded_modules(model, target_modules, finetuning_args.num_layer_trainable)
 
+        if model_args.visual_inputs:
+            target_modules = filter_vision_tower_linear(target_modules)
+
         if (
             finetuning_args.use_dora
             and getattr(model, "quantization_method", None) is not None
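Taken together: with visual_inputs and tune_mm_proj both set, the vision tower and the language model are frozen, leaving only the projector with gradients. A minimal sketch of the same freezing pattern on a toy module (the submodule names are stand-ins mirroring LLaVA's layout, not the real classes):

import torch.nn as nn

# Stand-in for a LLaVA-style model: vision tower + projector + language model.
class ToyMLLM(nn.Module):
    def __init__(self):
        super().__init__()
        self.vision_tower = nn.Linear(32, 16)
        self.multi_modal_projector = nn.Linear(16, 8)
        self.language_model = nn.Linear(8, 8)

model = ToyMLLM()
model.vision_tower.requires_grad_(False)    # always frozen for visual inputs
model.language_model.requires_grad_(False)  # frozen only when tune_mm_proj=True

trainable = [n for n, p in model.named_parameters() if p.requires_grad]
print(trainable)  # ['multi_modal_projector.weight', 'multi_modal_projector.bias']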
7 changes: 6 additions & 1 deletion src/llamafactory/model/utils/visual.py
@@ -1,4 +1,4 @@
-from typing import TYPE_CHECKING, Tuple
+from typing import TYPE_CHECKING, Tuple, List
 
 import torch
 import transformers.models
@@ -82,3 +82,8 @@ def configure_visual_model(config: "PretrainedConfig") -> None:
     if getattr(config, "is_yi_vl_derived_model", None):
         logger.info("Detected Yi-VL model, applying projector patch.")
         transformers.models.llava.modeling_llava.LlavaMultiModalProjector = LlavaMultiModalProjectorForYiVL
+
+
+def filter_vision_tower_linear(target_modules: List[str]) -> str:
+    target_modules = f"^(?!.*vision_tower).*(?:{'|'.join(target_modules)}).*"
+    return target_modules
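The returned string is a regex rather than a module list: PEFT's LoraConfig accepts a single regex string for target_modules, and the negative lookahead keeps LoRA off the vision tower's linear layers. A quick illustration of how the pattern behaves (the module names below are made-up examples, not taken from a real checkpoint):

import re

# Pattern produced by filter_vision_tower_linear(["q_proj", "v_proj"]):
pattern = r"^(?!.*vision_tower).*(?:q_proj|v_proj).*"

names = [
    "language_model.model.layers.0.self_attn.q_proj",  # matched: LoRA target
    "vision_tower.encoder.layers.0.self_attn.q_proj",  # excluded by the lookahead
    "multi_modal_projector.linear_1",                  # no q_proj/v_proj: skipped
]
for name in names:
    print(name, "->", bool(re.match(pattern, name)))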
