Add MLLM YI-VL and save processor config during training #3748

Merged · 16 commits · May 15, 2024
15 changes: 15 additions & 0 deletions src/llmtuner/data/template.py
@@ -856,6 +856,21 @@ def get_template_and_fix_tokenizer(
)


_register_template(
    name="yi_vl",
    format_user=StringFormatter(slots=["### Human: {{content}}\n### Assistant:"]),
    format_separator=EmptyFormatter(slots=["\n"]),
    default_system=(
        "This is a chat between an inquisitive human and an AI assistant. "
        "Assume the role of the AI assistant. Read all the images carefully, "
        "and respond to the human's questions with informative, helpful, detailed and polite answers. "
        "这是一个好奇的人类和一个人工智能助手之间的对话。假设你扮演这个AI助手的角色。"
        "仔细阅读所有的图像,并对人类的问题做出信息丰富、有帮助、详细的和礼貌的回答。\n"
    ),
    stop_words=["###"],
)


_register_template(
    name="yuan",
    format_user=StringFormatter(slots=["{{content}}", {"token": "<sep>"}]),
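For readers unfamiliar with the template registry, here is a rough, illustrative sketch (not part of this PR) of the chat framing the yi_vl entry produces for a single user turn; the exact interleaving of the system prompt and the separator is handled by the framework's formatters:

```python
# Illustrative only: approximates the "yi_vl" user/assistant framing registered above.
question = "What is shown in the image?"
prompt = "### Human: {}\n### Assistant:".format(question)
print(prompt)
# ### Human: What is shown in the image?
# ### Assistant:
# Generation stops once the registered stop word "###" is produced.
```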
4 changes: 2 additions & 2 deletions src/llmtuner/model/patcher.py
@@ -17,7 +17,7 @@
from .utils.quantization import configure_quantization
from .utils.rope import configure_rope
from .utils.valuehead import prepare_valuehead_model
from .utils.visual import autocast_projector_dtype, configure_hidden_size
from .utils.visual import autocast_projector_dtype, configure_visual_model


if TYPE_CHECKING:
@@ -54,7 +54,7 @@ def patch_config(
    configure_longlora(config, model_args, is_trainable)
    configure_quantization(config, tokenizer, model_args, init_kwargs)
    configure_moe(config, model_args, is_trainable)
    configure_hidden_size(config)
    configure_visual_model(config)

    if model_args.use_cache and not is_trainable:
        setattr(config, "use_cache", True)
32 changes: 28 additions & 4 deletions src/llmtuner/model/utils/visual.py
@@ -1,22 +1,38 @@
from typing import TYPE_CHECKING, Tuple

import torch
import transformers.models
from transformers.activations import ACT2FN

from ...extras.logging import get_logger


if TYPE_CHECKING:
    from transformers import PretrainedConfig, PreTrainedModel
    from transformers import LlavaConfig, PretrainedConfig, PreTrainedModel

    from ...hparams import ModelArguments


logger = get_logger(__name__)


def configure_hidden_size(config: "PretrainedConfig") -> None:
    if getattr(config, "model_type", None) == "llava":
        setattr(config, "hidden_size", getattr(config.text_config, "hidden_size", None))
class LlavaMultiModalProjector(torch.nn.Module):
    def __init__(self, config: "LlavaConfig"):
        super().__init__()

        self.linear_1 = torch.nn.Linear(config.vision_config.hidden_size, config.text_config.hidden_size, bias=True)
        self.linear_2 = torch.nn.LayerNorm(config.text_config.hidden_size, bias=True)
        self.linear_3 = torch.nn.Linear(config.text_config.hidden_size, config.text_config.hidden_size, bias=True)
        self.linear_4 = torch.nn.LayerNorm(config.text_config.hidden_size, bias=True)
        self.act = ACT2FN[config.projector_hidden_act]

    def forward(self, image_features):
        hidden_states = self.linear_1(image_features)
        hidden_states = self.linear_2(hidden_states)
        hidden_states = self.act(hidden_states)
        hidden_states = self.linear_3(hidden_states)
        hidden_states = self.linear_4(hidden_states)
        return hidden_states


def autocast_projector_dtype(
@@ -31,3 +47,11 @@ def _mm_projector_forward_post_hook(
        logger.info("Casting multimodal projector outputs in {}.".format(model_args.compute_dtype))
        mm_projector: "torch.nn.Module" = getattr(model, mm_projector_name)
        mm_projector.register_forward_hook(_mm_projector_forward_post_hook)


def configure_visual_model(config: "PretrainedConfig") -> None:
    if getattr(config, "model_type", None) == "llava":
        setattr(config, "hidden_size", getattr(config.text_config, "hidden_size", None))

    if getattr(config, "is_yi_vl_derived_model", None):
        transformers.models.llava.modeling_llava.LlavaMultiModalProjector = LlavaMultiModalProjector
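As a quick sanity check, the Yi-VL projector above should map vision-tower features to the language model's hidden size. The sketch below (not part of this PR) uses a SimpleNamespace with assumed hidden sizes in place of a real LlavaConfig, since the class only reads three config attributes:

```python
# Sketch: shape check for the Yi-VL projector defined above (assumed sizes, mock config).
from types import SimpleNamespace

import torch

mock_config = SimpleNamespace(
    vision_config=SimpleNamespace(hidden_size=1152),  # assumed vision tower width
    text_config=SimpleNamespace(hidden_size=4096),    # assumed language model width
    projector_hidden_act="gelu",
)
projector = LlavaMultiModalProjector(mock_config)
image_features = torch.randn(2, 576, 1152)            # (batch, patches, vision_hidden_size)
print(projector(image_features).shape)                # torch.Size([2, 576, 4096])
```

Compared with the stock LLaVA projector (linear, activation, linear), this variant adds a LayerNorm after each linear layer, which is why configure_visual_model monkey-patches it into transformers.models.llava.modeling_llava whenever is_yi_vl_derived_model is set on the config.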
12 changes: 11 additions & 1 deletion src/llmtuner/train/sft/trainer.py
@@ -13,6 +13,7 @@


if TYPE_CHECKING:
    from transformers import ProcessorMixin
    from transformers.trainer import PredictionOutput

    from ...hparams import FinetuningArguments
@@ -26,9 +27,12 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer):
    Inherits Seq2SeqTrainer to compute generative metrics such as BLEU and ROUGE.
    """

    def __init__(self, finetuning_args: "FinetuningArguments", **kwargs) -> None:
    def __init__(
        self, finetuning_args: "FinetuningArguments", processor: Optional["ProcessorMixin"], **kwargs
    ) -> None:
        super().__init__(**kwargs)
        self.finetuning_args = finetuning_args
        self.processor = processor
        if finetuning_args.use_badam:
            from badam import clip_grad_norm_for_sparse_tensor

@@ -45,6 +49,12 @@ def create_scheduler(
        create_custom_scheduler(self.args, num_training_steps, optimizer)
        return super().create_scheduler(num_training_steps, optimizer)

    def _save(self, output_dir: Optional[str] = None, state_dict: Optional[Dict[str, "torch.Tensor"]] = None) -> None:
        super()._save(output_dir, state_dict)
        if self.processor is not None:
            output_dir = output_dir if output_dir is not None else self.args.output_dir
            getattr(self.processor, "image_processor").save_pretrained(output_dir)

    def prediction_step(
        self,
        model: "torch.nn.Module",
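The new _save override writes the processor's image processor configuration (preprocessor_config.json) into the same output directory as the model weights, so multimodal checkpoints can be reloaded without pointing back at the original model repo. A minimal sketch, with a hypothetical output path:

```python
# Sketch: reload the image processor config that _save() now writes next to the checkpoint.
# "saves/yi-vl-6b/lora/sft" is a hypothetical output_dir.
from transformers import AutoImageProcessor

image_processor = AutoImageProcessor.from_pretrained("saves/yi-vl-6b/lora/sft")
```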
2 changes: 1 addition & 1 deletion src/llmtuner/train/sft/workflow.py
@@ -55,10 +55,10 @@ def run_sft(
        model=model,
        args=training_args,
        finetuning_args=finetuning_args,
        tokenizer=tokenizer,
        data_collator=data_collator,
        callbacks=callbacks,
        compute_metrics=ComputeMetrics(tokenizer) if training_args.predict_with_generate else None,
        **tokenizer_module,
        **split_dataset(dataset, data_args, training_args),
    )

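The trainer call now unpacks **tokenizer_module instead of passing tokenizer= directly. load_tokenizer itself is not shown in this diff, but given the new processor keyword on CustomSeq2SeqTrainer, the dictionary presumably has the shape sketched below (placeholder values; processor is expected to be None for text-only models):

```python
# Presumed shape of tokenizer_module (placeholders; the real objects come from load_tokenizer).
tokenizer_module = {"tokenizer": "<PreTrainedTokenizer>", "processor": None}

# **tokenizer_module expands to tokenizer=... and processor=..., matching the updated
# CustomSeq2SeqTrainer.__init__ signature in trainer.py.
print(list(tokenizer_module))  # ['tokenizer', 'processor']
```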