diff --git a/src/llamafactory/chat/hf_engine.py b/src/llamafactory/chat/hf_engine.py index 42f8f20072..9817318deb 100644 --- a/src/llamafactory/chat/hf_engine.py +++ b/src/llamafactory/chat/hf_engine.py @@ -156,7 +156,7 @@ def _process_args( if image is not None: mm_inputs = template.mm_plugin.get_mm_inputs( - images=[image], feature_seqlens={"token_type_ids": prompt_length}, processor=processor + images=[image], imglens=[1], seqlens=[prompt_length], processor=processor ) for key, value in mm_inputs.items(): value = value if isinstance(value, torch.Tensor) else torch.tensor(value) diff --git a/src/llamafactory/data/aligner.py b/src/llamafactory/data/aligner.py index a3440ff55f..05ff7ef0bb 100644 --- a/src/llamafactory/data/aligner.py +++ b/src/llamafactory/data/aligner.py @@ -14,9 +14,7 @@ import os from functools import partial -from typing import TYPE_CHECKING, Any, Dict, List, Sequence, Union - -from datasets import Features +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Union from ..extras.logging import get_logger from .data_utils import Role @@ -27,16 +25,24 @@ from transformers import Seq2SeqTrainingArguments from ..hparams import DataArguments + from .mm_plugin import ImageInput from .parser import DatasetAttr logger = get_logger(__name__) -def _convert_images(images: Sequence[Any], dataset_attr: "DatasetAttr", data_args: "DataArguments") -> List[Any]: +def _convert_images( + images: Sequence["ImageInput"], + dataset_attr: "DatasetAttr", + data_args: "DataArguments", +) -> Optional[List["ImageInput"]]: r""" Optionally concatenates image path to dataset dir when loading from local disk. """ + if len(images) == 0: + return None + images = images[:] if dataset_attr.load_from in ["script", "file"]: for i in range(len(images)): @@ -47,66 +53,67 @@ def _convert_images(images: Sequence[Any], dataset_attr: "DatasetAttr", data_arg def convert_alpaca( - examples: Dict[str, List[Any]], dataset_attr: "DatasetAttr", data_args: "DataArguments" -) -> Dict[str, List[Any]]: + example: Dict[str, Any], + dataset_attr: "DatasetAttr", + data_args: "DataArguments", +) -> Dict[str, Any]: r""" Converts alpaca format dataset to the standard format. 
""" - outputs = {"prompt": [], "response": [], "system": [], "tools": [], "images": []} + prompt = [] + if dataset_attr.history and isinstance(example[dataset_attr.history], list): + for old_prompt, old_response in example[dataset_attr.history]: + prompt.append({"role": Role.USER.value, "content": old_prompt}) + prompt.append({"role": Role.ASSISTANT.value, "content": old_response}) + + query = [] + if dataset_attr.prompt and example[dataset_attr.prompt]: + query.append(example[dataset_attr.prompt]) + + if dataset_attr.query and example[dataset_attr.query]: + query.append(example[dataset_attr.query]) + + prompt.append({"role": Role.USER.value, "content": "\n".join(query)}) # "prompt\nquery" + + if dataset_attr.kto_tag and isinstance(example[dataset_attr.kto_tag], bool): # kto example + response = [{"role": Role.ASSISTANT.value, "content": example[dataset_attr.response]}] + if example[dataset_attr.kto_tag]: + response = response + [{"role": Role.ASSISTANT.value, "content": ""}] + else: + response = [{"role": Role.ASSISTANT.value, "content": ""}] + response + elif ( + dataset_attr.ranking + and isinstance(example[dataset_attr.chosen], str) + and isinstance(example[dataset_attr.rejected], str) + ): # pairwise example + response = [ + {"role": Role.ASSISTANT.value, "content": example[dataset_attr.chosen]}, + {"role": Role.ASSISTANT.value, "content": example[dataset_attr.rejected]}, + ] + elif dataset_attr.response and isinstance(example[dataset_attr.response], str): # normal example + response = [{"role": Role.ASSISTANT.value, "content": example[dataset_attr.response]}] + else: # unsupervised + response = [] + convert_images = partial(_convert_images, dataset_attr=dataset_attr, data_args=data_args) - for i in range(len(examples[dataset_attr.prompt])): - prompt = [] - if dataset_attr.history and isinstance(examples[dataset_attr.history][i], list): - for old_prompt, old_response in examples[dataset_attr.history][i]: - prompt.append({"role": Role.USER.value, "content": old_prompt}) - prompt.append({"role": Role.ASSISTANT.value, "content": old_response}) - - content = [] - if dataset_attr.prompt and examples[dataset_attr.prompt][i]: - content.append(examples[dataset_attr.prompt][i]) - - if dataset_attr.query and examples[dataset_attr.query][i]: - content.append(examples[dataset_attr.query][i]) - - prompt.append({"role": Role.USER.value, "content": "\n".join(content)}) # "prompt\nquery" - - if dataset_attr.kto_tag and isinstance(examples[dataset_attr.kto_tag][i], bool): # kto example - response = [{"role": Role.ASSISTANT.value, "content": examples[dataset_attr.response][i]}] - if examples[dataset_attr.kto_tag][i]: - response = response + [{"role": Role.ASSISTANT.value, "content": ""}] - else: - response = [{"role": Role.ASSISTANT.value, "content": ""}] + response - elif ( - dataset_attr.ranking - and isinstance(examples[dataset_attr.chosen][i], str) - and isinstance(examples[dataset_attr.rejected][i], str) - ): # pairwise example - response = [ - {"role": Role.ASSISTANT.value, "content": examples[dataset_attr.chosen][i]}, - {"role": Role.ASSISTANT.value, "content": examples[dataset_attr.rejected][i]}, - ] - elif dataset_attr.response and isinstance(examples[dataset_attr.response][i], str): # normal example - response = [{"role": Role.ASSISTANT.value, "content": examples[dataset_attr.response][i]}] - else: # unsupervised - response = [] - - outputs["prompt"].append(prompt) - outputs["response"].append(response) - outputs["system"].append(examples[dataset_attr.system][i] if dataset_attr.system else 
"") - outputs["tools"].append(examples[dataset_attr.tools][i] if dataset_attr.tools else "") - outputs["images"].append(convert_images(examples[dataset_attr.images][i]) if dataset_attr.images else []) - - return outputs + output = { + "_prompt": prompt, + "_response": response, + "_system": example[dataset_attr.system] if dataset_attr.system else "", + "_tools": example[dataset_attr.tools] if dataset_attr.tools else "", + "_images": convert_images(example[dataset_attr.images]) if dataset_attr.images else None, + } + return output def convert_sharegpt( - examples: Dict[str, List[Any]], dataset_attr: "DatasetAttr", data_args: "DataArguments" -) -> Dict[str, List[Any]]: + example: Dict[str, Any], + dataset_attr: "DatasetAttr", + data_args: "DataArguments", +) -> Dict[str, Any]: r""" Converts sharegpt format dataset to the standard format. """ - outputs = {"prompt": [], "response": [], "system": [], "tools": [], "images": []} - convert_images = partial(_convert_images, dataset_attr=dataset_attr, data_args=data_args) tag_mapping = { dataset_attr.user_tag: Role.USER.value, dataset_attr.assistant_tag: Role.ASSISTANT.value, @@ -117,74 +124,77 @@ def convert_sharegpt( odd_tags = (dataset_attr.user_tag, dataset_attr.observation_tag) even_tags = (dataset_attr.assistant_tag, dataset_attr.function_tag) accept_tags = (odd_tags, even_tags) - for i, messages in enumerate(examples[dataset_attr.messages]): - if len(messages) == 0: - continue - - if dataset_attr.system_tag and messages[0][dataset_attr.role_tag] == dataset_attr.system_tag: - system = messages[0][dataset_attr.content_tag] - messages = messages[1:] - else: - system = examples[dataset_attr.system][i] if dataset_attr.system else "" + messages = example[dataset_attr.messages] + if ( + dataset_attr.system_tag + and len(messages) != 0 + and messages[0][dataset_attr.role_tag] == dataset_attr.system_tag + ): + system = messages[0][dataset_attr.content_tag] + messages = messages[1:] + else: + system = example[dataset_attr.system] if dataset_attr.system else "" - aligned_messages = [] - broken_data = False - for turn_idx, message in enumerate(messages): - if message[dataset_attr.role_tag] not in accept_tags[turn_idx % 2]: - logger.warning("Invalid role tag in {}.".format(messages)) - broken_data = True + aligned_messages = [] + broken_data = False + for turn_idx, message in enumerate(messages): + if message[dataset_attr.role_tag] not in accept_tags[turn_idx % 2]: + logger.warning("Invalid role tag in {}.".format(messages)) + broken_data = True - aligned_messages.append( - {"role": tag_mapping[message[dataset_attr.role_tag]], "content": message[dataset_attr.content_tag]} - ) + aligned_messages.append( + {"role": tag_mapping[message[dataset_attr.role_tag]], "content": message[dataset_attr.content_tag]} + ) - if (not dataset_attr.ranking and len(aligned_messages) % 2 != 0) or ( - dataset_attr.ranking and len(aligned_messages) % 2 == 0 + if (not dataset_attr.ranking and len(aligned_messages) % 2 != 0) or ( + dataset_attr.ranking and len(aligned_messages) % 2 == 0 + ): + logger.warning("Invalid message count in {}.".format(messages)) + broken_data = True + + if dataset_attr.kto_tag and isinstance(example[dataset_attr.kto_tag], bool): # kto example + prompt = aligned_messages[:-1] + response = aligned_messages[-1:] + if example[dataset_attr.kto_tag]: + response = response + [{"role": Role.ASSISTANT.value, "content": ""}] + else: + response = [{"role": Role.ASSISTANT.value, "content": ""}] + response + elif ( + dataset_attr.ranking + and 
isinstance(example[dataset_attr.chosen], dict) + and isinstance(example[dataset_attr.rejected], dict) + ): # pairwise example + chosen = example[dataset_attr.chosen] + rejected = example[dataset_attr.rejected] + if ( + chosen[dataset_attr.role_tag] not in accept_tags[-1] + or rejected[dataset_attr.role_tag] not in accept_tags[-1] ): - logger.warning("Invalid message count in {}.".format(messages)) + logger.warning("Invalid role tag in {}.".format([chosen, rejected])) broken_data = True - if dataset_attr.kto_tag and isinstance(examples[dataset_attr.kto_tag][i], bool): # kto example - prompt = aligned_messages[:-1] - response = aligned_messages[-1:] - if examples[dataset_attr.kto_tag][i]: - response = response + [{"role": Role.ASSISTANT.value, "content": ""}] - else: - response = [{"role": Role.ASSISTANT.value, "content": ""}] + response - elif ( - dataset_attr.ranking - and isinstance(examples[dataset_attr.chosen][i], dict) - and isinstance(examples[dataset_attr.rejected][i], dict) - ): # pairwise example - chosen = examples[dataset_attr.chosen][i] - rejected = examples[dataset_attr.rejected][i] - if ( - chosen[dataset_attr.role_tag] not in accept_tags[-1] - or rejected[dataset_attr.role_tag] not in accept_tags[-1] - ): - logger.warning("Invalid role tag in {}.".format([chosen, rejected])) - broken_data = True - - prompt = aligned_messages - response = [ - {"role": tag_mapping[chosen[dataset_attr.role_tag]], "content": chosen[dataset_attr.content_tag]}, - {"role": tag_mapping[rejected[dataset_attr.role_tag]], "content": rejected[dataset_attr.content_tag]}, - ] - else: # normal example - prompt = aligned_messages[:-1] - response = aligned_messages[-1:] - - if broken_data: - logger.warning("Skipping this abnormal example.") - continue - - outputs["prompt"].append(prompt) - outputs["response"].append(response) - outputs["system"].append(system) - outputs["tools"].append(examples[dataset_attr.tools][i] if dataset_attr.tools else "") - outputs["images"].append(convert_images(examples[dataset_attr.images][i]) if dataset_attr.images else []) - - return outputs + prompt = aligned_messages + response = [ + {"role": tag_mapping[chosen[dataset_attr.role_tag]], "content": chosen[dataset_attr.content_tag]}, + {"role": tag_mapping[rejected[dataset_attr.role_tag]], "content": rejected[dataset_attr.content_tag]}, + ] + else: # normal example + prompt = aligned_messages[:-1] + response = aligned_messages[-1:] + + if broken_data: + logger.warning("Skipping this abnormal example.") + prompt, response = [], [] + + convert_images = partial(_convert_images, dataset_attr=dataset_attr, data_args=data_args) + output = { + "_prompt": prompt, + "_response": response, + "_system": system, + "_tools": example[dataset_attr.tools] if dataset_attr.tools else "", + "_images": convert_images(example[dataset_attr.images]) if dataset_attr.images else None, + } + return output def align_dataset( @@ -195,11 +205,11 @@ def align_dataset( ) -> Union["Dataset", "IterableDataset"]: r""" Aligned dataset: - prompt: [{"role": "user", "content": "..."}] * (2T - 1) - response: [{"role": "assistant", "content": "..."}] * N (N > 1 for ranking dataset) - system: "..." - tools: "...", - images: [], + _prompt: [{"role": "user", "content": "..."}] * (2T - 1) + _response: [{"role": "assistant", "content": "..."}] * N (N > 1 for ranking dataset) + _system: "..." 
+ _tools: "...", + _images: [], """ if dataset_attr.formatting == "alpaca": convert_func = partial(convert_alpaca, dataset_attr=dataset_attr, data_args=data_args) @@ -207,19 +217,6 @@ def align_dataset( convert_func = partial(convert_sharegpt, dataset_attr=dataset_attr, data_args=data_args) column_names = list(next(iter(dataset)).keys()) - features = Features.from_dict( - { - "prompt": [ - {"role": {"dtype": "string", "_type": "Value"}, "content": {"dtype": "string", "_type": "Value"}} - ], - "response": [ - {"role": {"dtype": "string", "_type": "Value"}, "content": {"dtype": "string", "_type": "Value"}} - ], - "system": {"dtype": "string", "_type": "Value"}, - "tools": {"dtype": "string", "_type": "Value"}, - "images": [{"_type": "Image"}], - } - ) kwargs = {} if not data_args.streaming: kwargs = dict( @@ -230,8 +227,7 @@ def align_dataset( return dataset.map( convert_func, - batched=True, + batched=False, remove_columns=column_names, - features=features, **kwargs, ) diff --git a/src/llamafactory/data/collator.py b/src/llamafactory/data/collator.py index eecf905290..73508b470d 100644 --- a/src/llamafactory/data/collator.py +++ b/src/llamafactory/data/collator.py @@ -16,12 +16,18 @@ # limitations under the License. from dataclasses import dataclass -from typing import Any, Dict, Literal, Sequence +from typing import TYPE_CHECKING, Any, Dict, Literal, Optional, Sequence import torch from transformers import DataCollatorForSeq2Seq +if TYPE_CHECKING: + from transformers import ProcessorMixin + + from .template import Template + + def prepare_4d_attention_mask(attention_mask_with_indices: "torch.Tensor", dtype: "torch.dtype") -> "torch.Tensor": r""" Expands the attention mask with indices from (batch_size, seq_len) to (batch_size, 1, seq_len, seq_len), @@ -65,41 +71,29 @@ def prepare_4d_attention_mask(attention_mask_with_indices: "torch.Tensor", dtype class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq): r""" Data collator that supports VLMs. - """ - def __call__(self, features: Sequence[Dict[str, Any]]) -> Dict[str, "torch.Tensor"]: - if "token_type_ids" in features[0].keys(): - for feature in features: - feature["token_type_ids"] = feature["token_type_ids"][0] - - extra_features = {} - if "pixel_values" in features[0].keys(): - pixel_values = [] - for feature in features: - if feature["pixel_values"] is None: - pixel_values.append(torch.zeros(0, dtype=torch.float)) - else: - pixel_values.append(torch.tensor(feature["pixel_values"], dtype=torch.float)) + Features should contain input_ids, attention_mask, labels and images. 
+ """ - extra_features["pixel_values"] = torch.cat(pixel_values, dim=0) - if extra_features["pixel_values"].numel() == 0: - extra_features["pixel_values"] = None + template: Optional["Template"] = None + processor: Optional["ProcessorMixin"] = None - if "image_grid_thw" in features[0].keys(): - image_grid_thw = [] - for feature in features: - if feature["image_grid_thw"] is None: - image_grid_thw.append(torch.zeros(0, dtype=torch.long)) - else: - image_grid_thw.append(torch.tensor(feature["image_grid_thw"], dtype=torch.long)) + def __call__(self, features: Sequence[Dict[str, Any]]) -> Dict[str, "torch.Tensor"]: + batch_images, batch_imglens, batch_seqlens = [], [], [] + for feature in features: + images = feature.pop("images") or [] # avoid NoneType + batch_images.extend(images) + batch_imglens.append(len(images)) + batch_seqlens.append(len(feature["input_ids"])) - extra_features["image_grid_thw"] = torch.cat(image_grid_thw, dim=0) - if extra_features["image_grid_thw"].numel() == 0: - extra_features["image_grid_thw"] = None + mm_inputs = self.template.mm_plugin.get_mm_inputs(batch_images, batch_imglens, batch_seqlens, self.processor) + if "token_type_ids" in mm_inputs: + token_type_ids = mm_inputs.pop("token_type_ids") + for i, feature in enumerate(features): + feature["token_type_ids"] = token_type_ids[i] - features = [{key: feature[key] for key in feature if key not in extra_features.keys()} for feature in features] features: Dict[str, "torch.Tensor"] = super().__call__(features) - features.update({key: value for key, value in extra_features.items() if value is not None}) + features.update(mm_inputs) return features @@ -141,16 +135,8 @@ def __call__(self, features: Sequence[Dict[str, Any]]) -> Dict[str, "torch.Tenso "input_ids": feature["{}_input_ids".format(key)], "attention_mask": feature["{}_attention_mask".format(key)], "labels": feature["{}_labels".format(key)], + "images": feature["images"], } - if "{}_token_type_ids".format(key) in feature: - target_feature["token_type_ids"] = feature["{}_token_type_ids".format(key)] - - if "pixel_values" in feature: # image data are same for chosen and rejected - target_feature["pixel_values"] = feature["pixel_values"] - - if "image_grid_thw" in feature: - target_feature["image_grid_thw"] = feature["image_grid_thw"] - concatenated_features.append(target_feature) return super().__call__(concatenated_features) @@ -171,22 +157,14 @@ def __call__(self, features: Sequence[Dict[str, Any]]) -> Dict[str, "torch.Tenso "input_ids": feature["input_ids"], "attention_mask": feature["attention_mask"], "labels": feature["labels"], + "images": feature["images"], } kl_feature = { "input_ids": feature["kl_input_ids"], "attention_mask": feature["kl_attention_mask"], "labels": feature["kl_labels"], + "images": feature["images"], } - if "token_type_ids" in feature: - target_feature["token_type_ids"] = feature["token_type_ids"] - kl_feature["token_type_ids"] = feature["kl_token_type_ids"] - - if "pixel_values" in feature: - target_feature["pixel_values"] = feature["pixel_values"] - - if "image_grid_thw" in feature: - target_feature["image_grid_thw"] = feature["image_grid_thw"] - target_features.append(target_feature) kl_features.append(kl_feature) kto_tags.append(feature["kto_tags"]) @@ -196,7 +174,7 @@ def __call__(self, features: Sequence[Dict[str, Any]]) -> Dict[str, "torch.Tenso batch["kl_input_ids"] = kl_batch["input_ids"] batch["kl_attention_mask"] = kl_batch["attention_mask"] batch["kl_labels"] = kl_batch["labels"] - if "token_type_ids" in batch: + if 
"token_type_ids" in kl_batch: batch["kl_token_type_ids"] = kl_batch["token_type_ids"] batch["kto_tags"] = torch.tensor(kto_tags) diff --git a/src/llamafactory/data/loader.py b/src/llamafactory/data/loader.py index 069ea1997e..f24c6cdb9e 100644 --- a/src/llamafactory/data/loader.py +++ b/src/llamafactory/data/loader.py @@ -14,7 +14,7 @@ import os import sys -from typing import TYPE_CHECKING, Dict, Literal, Optional, Sequence, Union +from typing import TYPE_CHECKING, Dict, Literal, Optional, Sequence, Tuple, Union import numpy as np from datasets import DatasetDict, load_dataset, load_from_disk @@ -180,7 +180,13 @@ def _get_preprocessed_dataset( desc="Running tokenizer on dataset", ) - dataset = dataset.map(preprocess_func, batched=True, remove_columns=column_names, **kwargs) + dataset = dataset.map( + preprocess_func, + batched=True, + batch_size=data_args.preprocessing_batch_size, + remove_columns=column_names, + **kwargs, + ) if training_args.should_log: try: @@ -202,7 +208,7 @@ def get_dataset( stage: Literal["pt", "sft", "rm", "ppo", "kto"], tokenizer: "PreTrainedTokenizer", processor: Optional["ProcessorMixin"] = None, -) -> "DatasetModule": +) -> Tuple["DatasetModule", "Template"]: template = get_template_and_fix_tokenizer(tokenizer, data_args.template, data_args.tool_format) if data_args.train_on_prompt and template.efficient_eos: raise ValueError("Current template does not support `train_on_prompt`.") @@ -273,4 +279,4 @@ def get_dataset( if "validation" in dataset_dict: dataset_module["eval_dataset"] = dataset_dict["validation"] - return dataset_module + return dataset_module, template diff --git a/src/llamafactory/data/mm_plugin.py b/src/llamafactory/data/mm_plugin.py index c0ac8e4210..5e1b5bd86c 100644 --- a/src/llamafactory/data/mm_plugin.py +++ b/src/llamafactory/data/mm_plugin.py @@ -1,5 +1,6 @@ from copy import deepcopy -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple +from io import BytesIO +from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Tuple, TypedDict, Union from PIL.Image import Image from transformers import ProcessorMixin @@ -9,34 +10,53 @@ if is_pillow_available(): - import torch from PIL import Image + from PIL.Image import Image as ImageObject if TYPE_CHECKING: - from PIL.Image import Image as ImageObject + import torch from transformers import PreTrainedTokenizer, ProcessorMixin from transformers.image_processing_utils import BaseImageProcessor + class EncodedImage(TypedDict): + path: Optional[str] + bytes: Optional[bytes] + + ImageInput = Union[str, EncodedImage, ImageObject] + -def _regularize_images(images: Sequence["ImageObject"], processor: "ProcessorMixin") -> List["ImageObject"]: +def _regularize_images(images: Sequence["ImageInput"], processor: "ProcessorMixin") -> List["ImageObject"]: r""" - Regularizes images to avoid error. Including resizing and mode convert. + Regularizes images to avoid error. Including reading, resizing and converting. 
""" - images = images[:] image_resolution = getattr(processor, "image_resolution", 512) - for i in range(len(images)): - if max(images[i].width, images[i].height) > image_resolution: - factor = image_resolution / max(images[i].width, images[i].height) - images[i] = images[i].resize((int(images[i].width * factor), int(images[i].height * factor))) + results = [] + for image in images: + if isinstance(image, str): + image = Image.open(image) + elif isinstance(image, dict): + if image["bytes"] is not None: + image = Image.open(BytesIO(image["bytes"])) + else: + image = Image.open(image["path"]) - if images[i].mode != "RGB": - images[i] = images[i].convert("RGB") + if not isinstance(image, ImageObject): + raise ValueError("Expect input is a list of Images, but got {}.".format(type(image))) - return images + if max(image.width, image.height) > image_resolution: + factor = image_resolution / max(image.width, image.height) + image = image.resize((int(image.width * factor), int(image.height * factor))) + if image.mode != "RGB": + image = image.convert("RGB") -def _get_mm_inputs(images: Sequence["ImageObject"], processor: "ProcessorMixin") -> Dict[str, "torch.Tensor"]: + results.append(image) + + return results + + +def _get_mm_inputs(images: Sequence["ImageInput"], processor: "ProcessorMixin") -> Dict[str, "torch.Tensor"]: r""" Processes visual inputs. @@ -53,26 +73,27 @@ def _get_mm_inputs(images: Sequence["ImageObject"], processor: "ProcessorMixin") if len(images) != 0: images = _regularize_images(images, processor) image_inputs = image_processor(images=images, return_tensors="pt") - else: # add NoneType for fake images - image = Image.new("RGB", (64, 64), (255, 255, 255)) - image_inputs = image_processor(images=[image], return_tensors="pt") - image_inputs = {key: None for key in image_inputs.keys()} + else: + image_inputs = {} return image_inputs def _get_paligemma_token_type_ids( - images: Sequence["ImageObject"], input_len: int, processor: "ProcessorMixin" + imglens: Sequence[int], seqlens: Sequence[int], processor: "ProcessorMixin" ) -> List[List[int]]: r""" Gets paligemma token type ids for computing loss. 
Returns: - token_type_ids: shape (1, seq_len) + batch_token_type_ids: shape (batch_size, sequence_length) """ - num_images = len(images) - image_seqlen = num_images * getattr(processor, "image_seqlen") - return [[0] * image_seqlen + [1] * (input_len - image_seqlen)] + batch_token_type_ids = [] + for imglen, seqlen in zip(imglens, seqlens): + image_seqlen = imglen * getattr(processor, "image_seqlen") + batch_token_type_ids.append([0] * image_seqlen + [1] * (seqlen - image_seqlen)) + + return batch_token_type_ids class BasePlugin: @@ -82,7 +103,7 @@ def __init__(self, image_token: str) -> None: def process_messages( self, messages: Sequence[Dict[str, str]], - images: Sequence["ImageObject"], + images: Sequence["ImageInput"], processor: Optional["ProcessorMixin"], ) -> List[Dict[str, str]]: r""" @@ -94,7 +115,7 @@ def process_token_ids( self, input_ids: List[int], labels: Optional[List[int]], - images: Sequence["ImageObject"], + images: Sequence["ImageInput"], tokenizer: "PreTrainedTokenizer", processor: Optional["ProcessorMixin"], ) -> Tuple[List[int], Optional[List[int]]]: @@ -105,10 +126,11 @@ def process_token_ids( def get_mm_inputs( self, - images: Sequence["ImageObject"], - feature_seqlens: Dict[str, int], + images: Sequence["ImageInput"], + imglens: Sequence[int], + seqlens: Sequence[int], processor: Optional["ProcessorMixin"], - ) -> Dict[str, Any]: + ) -> Dict[str, Union[List[int], "torch.Tensor"]]: r""" Builds batched multimodal inputs for VLMs. """ @@ -119,31 +141,32 @@ class LlavaPlugin(BasePlugin): def process_messages( self, messages: Sequence[Dict[str, str]], - images: Sequence["ImageObject"], + images: Sequence["ImageInput"], processor: Optional["ProcessorMixin"], ) -> List[Dict[str, str]]: - num_images = 0 + num_image_tokens = 0 image_seqlen = getattr(processor, "image_seqlen") messages = deepcopy(messages) for message in messages: content = message["content"] while IMAGE_PLACEHOLDER in content: - num_images += 1 + num_image_tokens += 1 content = content.replace(IMAGE_PLACEHOLDER, "{{image}}", 1) message["content"] = content.replace("{{image}}", self.image_token * image_seqlen) - if len(images) != num_images: + if len(images) != num_image_tokens: raise ValueError("The number of images does not match the number of {} tokens".format(IMAGE_PLACEHOLDER)) return messages def get_mm_inputs( self, - images: Sequence["ImageObject"], - feature_seqlens: Dict[str, int], + images: Sequence["ImageInput"], + imglens: Sequence[int], + seqlens: Sequence[int], processor: Optional["ProcessorMixin"], - ) -> Dict[str, Any]: + ) -> Dict[str, Union[List[int], "torch.Tensor"]]: return _get_mm_inputs(images, processor) @@ -151,20 +174,20 @@ class PaliGemmaPlugin(BasePlugin): def process_messages( self, messages: Sequence[Dict[str, str]], - images: Sequence["ImageObject"], + images: Sequence["ImageInput"], processor: Optional["ProcessorMixin"], ) -> List[Dict[str, str]]: - num_images = 0 + num_image_tokens = 0 messages = deepcopy(messages) for message in messages: content = message["content"] while IMAGE_PLACEHOLDER in content: - num_images += 1 + num_image_tokens += 1 content = content.replace(IMAGE_PLACEHOLDER, "{{image}}", 1) message["content"] = content.replace("{{image}}", "") - if len(images) != num_images: + if len(images) != num_image_tokens: raise ValueError("The number of images does not match the number of {} tokens".format(IMAGE_PLACEHOLDER)) return messages @@ -173,7 +196,7 @@ def process_token_ids( self, input_ids: List[int], labels: Optional[List[int]], - images: 
Sequence["ImageObject"], + images: Sequence["ImageInput"], tokenizer: "PreTrainedTokenizer", processor: Optional["ProcessorMixin"], ) -> Tuple[List[int], Optional[List[int]]]: @@ -188,14 +211,13 @@ def process_token_ids( def get_mm_inputs( self, - images: Sequence["ImageObject"], - feature_seqlens: Dict[str, int], + images: Sequence["ImageInput"], + imglens: Sequence[int], + seqlens: Sequence[int], processor: Optional["ProcessorMixin"], - ) -> Dict[str, Any]: + ) -> Dict[str, Union[List[int], "torch.Tensor"]]: mm_inputs = _get_mm_inputs(images, processor) - for feature_name, feature_length in feature_seqlens.items(): - mm_inputs[feature_name] = _get_paligemma_token_type_ids(images, feature_length, processor) - + mm_inputs["token_type_ids"] = _get_paligemma_token_type_ids(imglens, seqlens, processor) return mm_inputs @@ -203,7 +225,7 @@ class Qwen2vlPlugin(BasePlugin): def process_messages( self, messages: Sequence[Dict[str, str]], - images: Sequence["ImageObject"], + images: Sequence["ImageInput"], processor: Optional["ProcessorMixin"], ) -> List[Dict[str, str]]: image_processor: "BaseImageProcessor" = getattr(processor, "image_processor") @@ -213,36 +235,37 @@ def process_messages( else: image_grid_thw = [] - num_images = 0 + num_image_tokens = 0 messages = deepcopy(messages) for message in messages: content = message["content"] while IMAGE_PLACEHOLDER in content: - if num_images >= len(image_grid_thw): + if num_image_tokens >= len(image_grid_thw): raise ValueError("`len(images)` is less than the number of {} tokens.".format(IMAGE_PLACEHOLDER)) content = content.replace( IMAGE_PLACEHOLDER, "<|vision_start|>{}<|vision_end|>".format( - self.image_token * (image_grid_thw[num_images].prod() // merge_length) + self.image_token * (image_grid_thw[num_image_tokens].prod() // merge_length) ), 1, ) - num_images += 1 + num_image_tokens += 1 message["content"] = content - if len(images) != num_images: + if len(images) != num_image_tokens: raise ValueError("The number of images does not match the number of {} tokens".format(IMAGE_PLACEHOLDER)) return messages def get_mm_inputs( self, - images: Sequence["ImageObject"], - feature_seqlens: Dict[str, int], + images: Sequence["ImageInput"], + imglens: Sequence[int], + seqlens: Sequence[int], processor: Optional["ProcessorMixin"], - ) -> Dict[str, Any]: + ) -> Dict[str, Union[List[int], "torch.Tensor"]]: return _get_mm_inputs(images, processor) diff --git a/src/llamafactory/data/processors/feedback.py b/src/llamafactory/data/processors/feedback.py index 19539f3c86..045182d937 100644 --- a/src/llamafactory/data/processors/feedback.py +++ b/src/llamafactory/data/processors/feedback.py @@ -21,10 +21,10 @@ if TYPE_CHECKING: - from PIL.Image import Image from transformers import PreTrainedTokenizer, ProcessorMixin from ...hparams import DataArguments + from ..mm_plugin import ImageInput from ..template import Template @@ -37,12 +37,12 @@ def _encode_feedback_example( kl_response: Sequence[Dict[str, str]], system: Optional[str], tools: Optional[str], - images: Sequence["Image"], + images: Sequence["ImageInput"], template: "Template", tokenizer: "PreTrainedTokenizer", processor: Optional["ProcessorMixin"], cutoff_len: int, -) -> Tuple[List[int], List[int], List[int], List[int], bool, Dict[str, Any]]: +) -> Tuple[List[int], List[int], List[int], List[int], bool]: if response[0]["content"]: # desired example kto_tag = True messages = prompt + [response[0]] @@ -78,15 +78,7 @@ def _encode_feedback_example( labels = [IGNORE_INDEX] * source_len + response_ids 
kl_input_ids = kl_prompt_ids + kl_response_ids kl_labels = [IGNORE_INDEX] * kl_source_len + kl_response_ids - extra_inputs = template.mm_plugin.get_mm_inputs( - images=images, - feature_seqlens={ - "token_type_ids": len(input_ids), - "kl_token_type_ids": len(kl_input_ids), - }, - processor=processor, - ) - return input_ids, labels, kl_input_ids, kl_labels, kto_tag, extra_inputs + return input_ids, labels, kl_input_ids, kl_labels, kto_tag def preprocess_feedback_dataset( @@ -97,20 +89,20 @@ def preprocess_feedback_dataset( data_args: "DataArguments", ) -> Dict[str, List[Any]]: # create unrelated input-output pairs for estimating the KL term by flipping the matched pairs - kl_response = examples["response"][::-1] + kl_response = examples["_response"][::-1] model_inputs = defaultdict(list) - for i in range(len(examples["prompt"])): - if len(examples["prompt"][i]) % 2 != 1 or len(examples["response"][i]) < 2: - logger.warning("Dropped invalid example: {}".format(examples["prompt"][i] + examples["response"][i])) + for i in range(len(examples["_prompt"])): + if len(examples["_prompt"][i]) % 2 != 1 or len(examples["_response"][i]) < 2: + logger.warning("Dropped invalid example: {}".format(examples["_prompt"][i] + examples["_response"][i])) continue - input_ids, labels, kl_input_ids, kl_labels, kto_tag, extra_inputs = _encode_feedback_example( - prompt=examples["prompt"][i], - response=examples["response"][i], + input_ids, labels, kl_input_ids, kl_labels, kto_tag = _encode_feedback_example( + prompt=examples["_prompt"][i], + response=examples["_response"][i], kl_response=kl_response[i], - system=examples["system"][i], - tools=examples["tools"][i], - images=examples["images"][i], + system=examples["_system"][i], + tools=examples["_tools"][i], + images=examples["_images"][i] or [], template=template, tokenizer=tokenizer, processor=processor, @@ -123,8 +115,7 @@ def preprocess_feedback_dataset( model_inputs["kl_attention_mask"].append([1] * len(kl_input_ids)) model_inputs["kl_labels"].append(kl_labels) model_inputs["kto_tags"].append(kto_tag) - for key, value in extra_inputs.items(): - model_inputs[key].append(value) + model_inputs["images"].append(examples["_images"][i]) desirable_num = sum([1 for tag in model_inputs["kto_tags"] if tag]) undesirable_num = len(model_inputs["kto_tags"]) - desirable_num diff --git a/src/llamafactory/data/processors/pairwise.py b/src/llamafactory/data/processors/pairwise.py index 9c5565d963..fa7e3fd2fb 100644 --- a/src/llamafactory/data/processors/pairwise.py +++ b/src/llamafactory/data/processors/pairwise.py @@ -21,10 +21,10 @@ if TYPE_CHECKING: - from PIL.Image import Image from transformers import PreTrainedTokenizer, ProcessorMixin from ...hparams import DataArguments + from ..mm_plugin import ImageInput from ..template import Template @@ -36,12 +36,12 @@ def _encode_pairwise_example( response: Sequence[Dict[str, str]], system: Optional[str], tools: Optional[str], - images: Sequence["Image"], + images: Sequence["ImageInput"], template: "Template", tokenizer: "PreTrainedTokenizer", processor: Optional["ProcessorMixin"], cutoff_len: int, -) -> Tuple[List[int], List[int], List[int], List[int], Dict[str, Any]]: +) -> Tuple[List[int], List[int], List[int], List[int]]: chosen_messages = template.mm_plugin.process_messages(prompt + [response[0]], images, processor) rejected_messages = template.mm_plugin.process_messages(prompt + [response[1]], images, processor) prompt_ids, chosen_ids = template.encode_oneturn(tokenizer, chosen_messages, system, tools) @@ -62,15 +62,7 @@ 
def _encode_pairwise_example( chosen_labels = [IGNORE_INDEX] * source_len + chosen_ids rejected_input_ids = prompt_ids + rejected_ids rejected_labels = [IGNORE_INDEX] * source_len + rejected_ids - extra_inputs = template.mm_plugin.get_mm_inputs( - images=images, - feature_seqlens={ - "chosen_token_type_ids": len(chosen_input_ids), - "rejected_token_type_ids": len(rejected_input_ids), - }, - processor=processor, - ) - return chosen_input_ids, chosen_labels, rejected_input_ids, rejected_labels, extra_inputs + return chosen_input_ids, chosen_labels, rejected_input_ids, rejected_labels def preprocess_pairwise_dataset( @@ -82,17 +74,17 @@ def preprocess_pairwise_dataset( ) -> Dict[str, List[Any]]: # build input pairs with format ` X`, `Y1 ` and `Y2 ` model_inputs = defaultdict(list) - for i in range(len(examples["prompt"])): - if len(examples["prompt"][i]) % 2 != 1 or len(examples["response"][i]) < 2: - logger.warning("Dropped invalid example: {}".format(examples["prompt"][i] + examples["response"][i])) + for i in range(len(examples["_prompt"])): + if len(examples["_prompt"][i]) % 2 != 1 or len(examples["_response"][i]) < 2: + logger.warning("Dropped invalid example: {}".format(examples["_prompt"][i] + examples["_response"][i])) continue - chosen_input_ids, chosen_labels, rejected_input_ids, rejected_labels, extra_inputs = _encode_pairwise_example( - prompt=examples["prompt"][i], - response=examples["response"][i], - system=examples["system"][i], - tools=examples["tools"][i], - images=examples["images"][i], + chosen_input_ids, chosen_labels, rejected_input_ids, rejected_labels = _encode_pairwise_example( + prompt=examples["_prompt"][i], + response=examples["_response"][i], + system=examples["_system"][i], + tools=examples["_tools"][i], + images=examples["_images"][i] or [], template=template, tokenizer=tokenizer, processor=processor, @@ -104,8 +96,7 @@ def preprocess_pairwise_dataset( model_inputs["rejected_input_ids"].append(rejected_input_ids) model_inputs["rejected_attention_mask"].append([1] * len(rejected_input_ids)) model_inputs["rejected_labels"].append(rejected_labels) - for key, value in extra_inputs.items(): - model_inputs[key].append(value) + model_inputs["images"].append(examples["_images"][i]) return model_inputs diff --git a/src/llamafactory/data/processors/pretrain.py b/src/llamafactory/data/processors/pretrain.py index 9342225964..77282bad65 100644 --- a/src/llamafactory/data/processors/pretrain.py +++ b/src/llamafactory/data/processors/pretrain.py @@ -30,7 +30,7 @@ def preprocess_pretrain_dataset( ) -> Dict[str, List[Any]]: # build grouped texts with format `X1 X2 X3 ...` if packing is enabled eos_token = "<|end_of_text|>" if data_args.template == "llama3" else tokenizer.eos_token - text_examples = [messages[0]["content"] + eos_token for messages in examples["prompt"]] + text_examples = [messages[0]["content"] + eos_token for messages in examples["_prompt"]] if not data_args.packing: if data_args.template == "gemma": diff --git a/src/llamafactory/data/processors/supervised.py b/src/llamafactory/data/processors/supervised.py index f488dfcac1..00e5ed4407 100644 --- a/src/llamafactory/data/processors/supervised.py +++ b/src/llamafactory/data/processors/supervised.py @@ -21,10 +21,10 @@ if TYPE_CHECKING: - from PIL.Image import Image from transformers import PreTrainedTokenizer, ProcessorMixin from ...hparams import DataArguments + from ..mm_plugin import ImageInput from ..template import Template @@ -36,14 +36,14 @@ def _encode_supervised_example( response: Sequence[Dict[str, 
str]], system: Optional[str], tools: Optional[str], - images: Sequence["Image"], + images: Sequence["ImageInput"], template: "Template", tokenizer: "PreTrainedTokenizer", processor: Optional["ProcessorMixin"], cutoff_len: int, train_on_prompt: bool, mask_history: bool, -) -> Tuple[List[int], List[int], Dict[str, Any]]: +) -> Tuple[List[int], List[int]]: messages = template.mm_plugin.process_messages(prompt + response, images, processor) input_ids, labels = template.mm_plugin.process_token_ids([], [], images, tokenizer, processor) encoded_pairs = template.encode_multiturn(tokenizer, messages, system, tools) @@ -83,10 +83,7 @@ def _encode_supervised_example( input_ids += [tokenizer.eos_token_id] labels += [tokenizer.eos_token_id] - extra_inputs = template.mm_plugin.get_mm_inputs( - images=images, feature_seqlens={"token_type_ids": len(input_ids)}, processor=processor - ) - return input_ids, labels, extra_inputs + return input_ids, labels def preprocess_supervised_dataset( @@ -99,17 +96,17 @@ def preprocess_supervised_dataset( # build inputs with format ` X Y ` and labels with format ` ... Y ` # for multiturn examples, we only mask the prompt part in each prompt-response pair. model_inputs = defaultdict(list) - for i in range(len(examples["prompt"])): - if len(examples["prompt"][i]) % 2 != 1 or len(examples["response"][i]) != 1: - logger.warning("Dropped invalid example: {}".format(examples["prompt"][i] + examples["response"][i])) + for i in range(len(examples["_prompt"])): + if len(examples["_prompt"][i]) % 2 != 1 or len(examples["_response"][i]) != 1: + logger.warning("Dropped invalid example: {}".format(examples["_prompt"][i] + examples["_response"][i])) continue - input_ids, labels, extra_inputs = _encode_supervised_example( - prompt=examples["prompt"][i], - response=examples["response"][i], - system=examples["system"][i], - tools=examples["tools"][i], - images=examples["images"][i], + input_ids, labels = _encode_supervised_example( + prompt=examples["_prompt"][i], + response=examples["_response"][i], + system=examples["_system"][i], + tools=examples["_tools"][i], + images=examples["_images"][i] or [], template=template, tokenizer=tokenizer, processor=processor, @@ -120,8 +117,7 @@ def preprocess_supervised_dataset( model_inputs["input_ids"].append(input_ids) model_inputs["attention_mask"].append([1] * len(input_ids)) model_inputs["labels"].append(labels) - for key, value in extra_inputs.items(): - model_inputs[key].append(value) + model_inputs["images"].append(examples["_images"][i]) return model_inputs @@ -143,17 +139,17 @@ def preprocess_packed_supervised_dataset( batch_input_ids, batch_labels = [], [] lengths = [] length2indexes = defaultdict(list) - for i in range(len(examples["prompt"])): - if len(examples["prompt"][i]) % 2 != 1 or len(examples["response"][i]) != 1: - logger.warning("Dropped invalid example: {}".format(examples["prompt"][i] + examples["response"][i])) + for i in range(len(examples["_prompt"])): + if len(examples["_prompt"][i]) % 2 != 1 or len(examples["_response"][i]) != 1: + logger.warning("Dropped invalid example: {}".format(examples["_prompt"][i] + examples["_response"][i])) continue - input_ids, labels, _ = _encode_supervised_example( - prompt=examples["prompt"][i], - response=examples["response"][i], - system=examples["system"][i], - tools=examples["tools"][i], - images=examples["images"][i], + input_ids, labels = _encode_supervised_example( + prompt=examples["_prompt"][i], + response=examples["_response"][i], + system=examples["_system"][i], + 
tools=examples["_tools"][i], + images=examples["_images"][i] or [], template=template, tokenizer=tokenizer, processor=None, @@ -199,6 +195,7 @@ def preprocess_packed_supervised_dataset( model_inputs["input_ids"].append(packed_input_ids) model_inputs["attention_mask"].append(packed_attention_masks) model_inputs["labels"].append(packed_labels) + model_inputs["images"].append(examples["_images"][i]) return model_inputs diff --git a/src/llamafactory/data/processors/unsupervised.py b/src/llamafactory/data/processors/unsupervised.py index 67cbb7b649..6f251969ef 100644 --- a/src/llamafactory/data/processors/unsupervised.py +++ b/src/llamafactory/data/processors/unsupervised.py @@ -21,10 +21,10 @@ if TYPE_CHECKING: - from PIL.Image import Image from transformers import PreTrainedTokenizer, ProcessorMixin from ...hparams import DataArguments + from ..mm_plugin import ImageInput from ..template import Template @@ -36,12 +36,12 @@ def _encode_unsupervised_example( response: Sequence[Dict[str, str]], system: Optional[str], tools: Optional[str], - images: Sequence["Image"], + images: Sequence["ImageInput"], template: "Template", tokenizer: "PreTrainedTokenizer", processor: Optional["ProcessorMixin"], cutoff_len: int, -) -> Tuple[List[int], List[int], Dict[str, Any]]: +) -> Tuple[List[int], List[int]]: if len(response) == 1: messages = prompt + response else: @@ -56,10 +56,7 @@ def _encode_unsupervised_example( source_len, target_len = infer_seqlen(len(input_ids), len(labels), cutoff_len) input_ids = input_ids[:source_len] labels = labels[:target_len] - extra_inputs = template.mm_plugin.get_mm_inputs( - images=images, feature_seqlens={"token_type_ids": len(input_ids)}, processor=processor - ) - return input_ids, labels, extra_inputs + return input_ids, labels def preprocess_unsupervised_dataset( @@ -71,17 +68,17 @@ def preprocess_unsupervised_dataset( ) -> Dict[str, List[Any]]: # build inputs with format ` X` and labels with format `Y ` model_inputs = defaultdict(list) - for i in range(len(examples["prompt"])): - if len(examples["prompt"][i]) % 2 != 1: - logger.warning("Dropped invalid example: {}".format(examples["prompt"][i] + examples["response"][i])) + for i in range(len(examples["_prompt"])): + if len(examples["_prompt"][i]) % 2 != 1: + logger.warning("Dropped invalid example: {}".format(examples["_prompt"][i] + examples["_response"][i])) continue - input_ids, labels, extra_inputs = _encode_unsupervised_example( - prompt=examples["prompt"][i], - response=examples["response"][i], - system=examples["system"][i], - tools=examples["tools"][i], - images=examples["images"][i], + input_ids, labels = _encode_unsupervised_example( + prompt=examples["_prompt"][i], + response=examples["_response"][i], + system=examples["_system"][i], + tools=examples["_tools"][i], + images=examples["_images"][i] or [], template=template, tokenizer=tokenizer, processor=processor, @@ -90,8 +87,7 @@ def preprocess_unsupervised_dataset( model_inputs["input_ids"].append(input_ids) model_inputs["attention_mask"].append([1] * len(input_ids)) model_inputs["labels"].append(labels) - for key, value in extra_inputs.items(): - model_inputs[key].append(value) + model_inputs["images"].append(examples["_images"][i]) return model_inputs diff --git a/src/llamafactory/hparams/data_args.py b/src/llamafactory/hparams/data_args.py index 0cb4a56df0..1adcf2d0df 100644 --- a/src/llamafactory/hparams/data_args.py +++ b/src/llamafactory/hparams/data_args.py @@ -73,6 +73,10 @@ class DataArguments: default=False, metadata={"help": "Overwrite the cached 
training and evaluation sets."}, ) + preprocessing_batch_size: int = field( + default=1000, + metadata={"help": "The number of examples in one group in pre-processing."}, + ) preprocessing_num_workers: Optional[int] = field( default=None, metadata={"help": "The number of processes to use for the pre-processing."}, diff --git a/src/llamafactory/train/dpo/workflow.py b/src/llamafactory/train/dpo/workflow.py index f474a90f26..5135f5a2ec 100644 --- a/src/llamafactory/train/dpo/workflow.py +++ b/src/llamafactory/train/dpo/workflow.py @@ -41,13 +41,14 @@ def run_dpo( ): tokenizer_module = load_tokenizer(model_args) tokenizer = tokenizer_module["tokenizer"] - dataset_module = get_dataset(model_args, data_args, training_args, stage="rm", **tokenizer_module) + dataset_module, template = get_dataset(model_args, data_args, training_args, stage="rm", **tokenizer_module) model = load_model(tokenizer, model_args, finetuning_args, training_args.do_train) data_collator = PairwiseDataCollatorWithPadding( - tokenizer=tokenizer, + template=template, pad_to_multiple_of=8, label_pad_token_id=IGNORE_INDEX if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id, + **tokenizer_module, ) # Create reference model @@ -60,7 +61,7 @@ def run_dpo( ref_model = None # Update arguments - training_args.remove_unused_columns = False # important for pairwise dataset + training_args.remove_unused_columns = False # important for multimodal and pairwise dataset # Initialize our Trainer trainer = CustomDPOTrainer( diff --git a/src/llamafactory/train/kto/workflow.py b/src/llamafactory/train/kto/workflow.py index fa85de37c5..8d282685a6 100644 --- a/src/llamafactory/train/kto/workflow.py +++ b/src/llamafactory/train/kto/workflow.py @@ -41,13 +41,14 @@ def run_kto( ): tokenizer_module = load_tokenizer(model_args) tokenizer = tokenizer_module["tokenizer"] - dataset_module = get_dataset(model_args, data_args, training_args, stage="kto", **tokenizer_module) + dataset_module, template = get_dataset(model_args, data_args, training_args, stage="kto", **tokenizer_module) model = load_model(tokenizer, model_args, finetuning_args, training_args.do_train) data_collator = KTODataCollatorWithPadding( - tokenizer=tokenizer, + template=template, pad_to_multiple_of=8, label_pad_token_id=IGNORE_INDEX if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id, + **tokenizer_module, ) # Create reference model @@ -57,7 +58,7 @@ def run_kto( ref_model = create_ref_model(model_args, finetuning_args) # Update arguments - training_args.remove_unused_columns = False # important for pairwise dataset + training_args.remove_unused_columns = False # important for multimodal and pairwise dataset # Initialize our Trainer trainer = CustomKTOTrainer( diff --git a/src/llamafactory/train/ppo/workflow.py b/src/llamafactory/train/ppo/workflow.py index 94c4320d84..7fa5c252a2 100644 --- a/src/llamafactory/train/ppo/workflow.py +++ b/src/llamafactory/train/ppo/workflow.py @@ -41,11 +41,11 @@ def run_ppo( ): tokenizer_module = load_tokenizer(model_args) tokenizer = tokenizer_module["tokenizer"] - dataset_module = get_dataset(model_args, data_args, training_args, stage="ppo", **tokenizer_module) + dataset_module, template = get_dataset(model_args, data_args, training_args, stage="ppo", **tokenizer_module) model = load_model(tokenizer, model_args, finetuning_args, training_args.do_train, add_valuehead=True) tokenizer.padding_side = "left" # use left-padding in generation while using right-padding in training - data_collator = 
MultiModalDataCollatorForSeq2Seq(tokenizer=tokenizer) + data_collator = MultiModalDataCollatorForSeq2Seq(template=template, **tokenizer_module) # Create reference model and reward model ref_model = create_ref_model(model_args, finetuning_args, add_valuehead=True) diff --git a/src/llamafactory/train/pt/workflow.py b/src/llamafactory/train/pt/workflow.py index 1052a9d193..91c66fa98f 100644 --- a/src/llamafactory/train/pt/workflow.py +++ b/src/llamafactory/train/pt/workflow.py @@ -42,7 +42,7 @@ def run_pt( ): tokenizer_module = load_tokenizer(model_args) tokenizer = tokenizer_module["tokenizer"] - dataset_module = get_dataset(model_args, data_args, training_args, stage="pt", **tokenizer_module) + dataset_module, _ = get_dataset(model_args, data_args, training_args, stage="pt", **tokenizer_module) model = load_model(tokenizer, model_args, finetuning_args, training_args.do_train) data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False) diff --git a/src/llamafactory/train/rm/workflow.py b/src/llamafactory/train/rm/workflow.py index f0afd7dcc8..9adf582775 100644 --- a/src/llamafactory/train/rm/workflow.py +++ b/src/llamafactory/train/rm/workflow.py @@ -41,12 +41,12 @@ def run_rm( ): tokenizer_module = load_tokenizer(model_args) tokenizer = tokenizer_module["tokenizer"] - dataset_module = get_dataset(model_args, data_args, training_args, stage="rm", **tokenizer_module) + dataset_module, template = get_dataset(model_args, data_args, training_args, stage="rm", **tokenizer_module) model = load_model(tokenizer, model_args, finetuning_args, training_args.do_train, add_valuehead=True) - data_collator = PairwiseDataCollatorWithPadding(tokenizer, pad_to_multiple_of=8) + data_collator = PairwiseDataCollatorWithPadding(template=template, pad_to_multiple_of=8, **tokenizer_module) # Update arguments - training_args.remove_unused_columns = False # important for pairwise dataset + training_args.remove_unused_columns = False # important for multimodal and pairwise dataset # Initialize our Trainer trainer = PairwiseTrainer( diff --git a/src/llamafactory/train/sft/workflow.py b/src/llamafactory/train/sft/workflow.py index 5e3787f148..a577e8796f 100644 --- a/src/llamafactory/train/sft/workflow.py +++ b/src/llamafactory/train/sft/workflow.py @@ -43,24 +43,26 @@ def run_sft( ): tokenizer_module = load_tokenizer(model_args) tokenizer = tokenizer_module["tokenizer"] - dataset_module = get_dataset(model_args, data_args, training_args, stage="sft", **tokenizer_module) + dataset_module, template = get_dataset(model_args, data_args, training_args, stage="sft", **tokenizer_module) model = load_model(tokenizer, model_args, finetuning_args, training_args.do_train) if getattr(model, "is_quantized", False) and not training_args.do_train: setattr(model, "_hf_peft_config_loaded", True) # hack here: make model compatible with prediction data_collator = SFTDataCollatorWith4DAttentionMask( - tokenizer=tokenizer, + template=template, pad_to_multiple_of=8 if training_args.do_train else None, # for shift short attention label_pad_token_id=IGNORE_INDEX if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id, block_diag_attn=model_args.block_diag_attn, attn_implementation=getattr(model.config, "_attn_implementation", None), compute_dtype=model_args.compute_dtype, + **tokenizer_module, ) # Override the decoding parameters of Seq2SeqTrainer training_args.generation_max_length = training_args.generation_max_length or data_args.cutoff_len training_args.generation_num_beams = data_args.eval_num_beams or 
training_args.generation_num_beams + training_args.remove_unused_columns = False # important for multimodal and pairwise dataset # Metric utils metric_module = {} diff --git a/src/llamafactory/train/test_utils.py b/src/llamafactory/train/test_utils.py index fedc873d04..2a3b3eeeee 100644 --- a/src/llamafactory/train/test_utils.py +++ b/src/llamafactory/train/test_utils.py @@ -105,7 +105,7 @@ def load_reference_model( def load_train_dataset(**kwargs) -> "Dataset": model_args, data_args, training_args, _, _ = get_train_args(kwargs) tokenizer_module = load_tokenizer(model_args) - dataset_module = get_dataset(model_args, data_args, training_args, stage=kwargs["stage"], **tokenizer_module) + dataset_module, _ = get_dataset(model_args, data_args, training_args, stage=kwargs["stage"], **tokenizer_module) return dataset_module["train_dataset"] diff --git a/tests/data/test_mm_plugin.py b/tests/data/test_mm_plugin.py index c2950605f6..a40080ec07 100644 --- a/tests/data/test_mm_plugin.py +++ b/tests/data/test_mm_plugin.py @@ -47,11 +47,15 @@ NO_IMAGES = [] +IMGLENS = [1] + +NO_IMGLENS = [0] + INPUT_IDS = [0, 1, 2, 3, 4] LABELS = [0, 1, 2, 3, 4] -FEATURE_SEQLENS = {"token_type_ids": 1024} +SEQLENS = [1024] def _get_mm_inputs(processor: "ProcessorMixin") -> Dict[str, "torch.Tensor"]: @@ -80,11 +84,11 @@ def test_base_plugin(): # test mm_messages assert base_plugin.process_messages(MM_MESSAGES, IMAGES, processor) == MM_MESSAGES assert base_plugin.process_token_ids(INPUT_IDS, LABELS, IMAGES, tokenizer, processor) == (INPUT_IDS, LABELS) - _is_close(base_plugin.get_mm_inputs(IMAGES, FEATURE_SEQLENS, processor), {}) + _is_close(base_plugin.get_mm_inputs(IMAGES, IMGLENS, SEQLENS, processor), {}) # test text_messages assert base_plugin.process_messages(TEXT_MESSAGES, NO_IMAGES, processor) == TEXT_MESSAGES assert base_plugin.process_token_ids(INPUT_IDS, LABELS, NO_IMAGES, tokenizer, processor) == (INPUT_IDS, LABELS) - _is_close(base_plugin.get_mm_inputs(NO_IMAGES, FEATURE_SEQLENS, processor), {}) + _is_close(base_plugin.get_mm_inputs(NO_IMAGES, NO_IMGLENS, SEQLENS, processor), {}) def test_llava_plugin(): @@ -101,11 +105,11 @@ def test_llava_plugin(): # test mm_messages assert llava_plugin.process_messages(MM_MESSAGES, IMAGES, processor) == expected_mm_messages assert llava_plugin.process_token_ids(INPUT_IDS, LABELS, IMAGES, tokenizer, processor) == (INPUT_IDS, LABELS) - _is_close(llava_plugin.get_mm_inputs(IMAGES, FEATURE_SEQLENS, processor), mm_inputs) + _is_close(llava_plugin.get_mm_inputs(IMAGES, IMGLENS, SEQLENS, processor), mm_inputs) # test text_messages assert llava_plugin.process_messages(TEXT_MESSAGES, NO_IMAGES, processor) == TEXT_MESSAGES assert llava_plugin.process_token_ids(INPUT_IDS, LABELS, NO_IMAGES, tokenizer, processor) == (INPUT_IDS, LABELS) - _is_close(llava_plugin.get_mm_inputs(NO_IMAGES, FEATURE_SEQLENS, processor), {"pixel_values": None}) + _is_close(llava_plugin.get_mm_inputs(NO_IMAGES, NO_IMGLENS, SEQLENS, processor), {}) @pytest.mark.skipif(not HF_TOKEN, reason="Gated model.") @@ -128,7 +132,7 @@ def test_paligemma_plugin(): expected_input_ids, expected_labels, ) - _is_close(paligemma_plugin.get_mm_inputs(IMAGES, FEATURE_SEQLENS, processor), mm_inputs) + _is_close(paligemma_plugin.get_mm_inputs(IMAGES, IMGLENS, SEQLENS, processor), mm_inputs) # test text_messages assert paligemma_plugin.process_messages(TEXT_MESSAGES, NO_IMAGES, processor) == TEXT_MESSAGES assert paligemma_plugin.process_token_ids(INPUT_IDS, LABELS, NO_IMAGES, tokenizer, processor) == ( @@ -136,8 +140,8 @@ def 
test_paligemma_plugin(): LABELS, ) _is_close( - paligemma_plugin.get_mm_inputs(NO_IMAGES, FEATURE_SEQLENS, processor), - {"pixel_values": None, "token_type_ids": [[1] * 1024]}, + paligemma_plugin.get_mm_inputs(NO_IMAGES, NO_IMGLENS, SEQLENS, processor), + {"token_type_ids": [[1] * 1024]}, ) @@ -158,11 +162,8 @@ def test_qwen2_vl_plugin(): # test mm_messages assert qwen2_vl_plugin.process_messages(MM_MESSAGES, IMAGES, processor) == expected_mm_messages assert qwen2_vl_plugin.process_token_ids(INPUT_IDS, LABELS, IMAGES, tokenizer, processor) == (INPUT_IDS, LABELS) - _is_close(qwen2_vl_plugin.get_mm_inputs(IMAGES, FEATURE_SEQLENS, processor), mm_inputs) + _is_close(qwen2_vl_plugin.get_mm_inputs(IMAGES, IMGLENS, SEQLENS, processor), mm_inputs) # test text_messages assert qwen2_vl_plugin.process_messages(TEXT_MESSAGES, NO_IMAGES, processor) == TEXT_MESSAGES assert qwen2_vl_plugin.process_token_ids(INPUT_IDS, LABELS, NO_IMAGES, tokenizer, processor) == (INPUT_IDS, LABELS) - _is_close( - qwen2_vl_plugin.get_mm_inputs(NO_IMAGES, FEATURE_SEQLENS, processor), - {"pixel_values": None, "image_grid_thw": None}, - ) + _is_close(qwen2_vl_plugin.get_mm_inputs(NO_IMAGES, NO_IMGLENS, SEQLENS, processor), {})
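
For reference, a minimal sketch of the reworked multimodal input contract introduced above: get_mm_inputs now takes the flat list of images for the whole batch plus per-example image counts (imglens) and sequence lengths (seqlens), and the collator gathers those three lists from the features before padding. The DummyPlugin, the toy feature dicts and the token ids below are invented for illustration; only the call shapes mirror the patch.

from typing import Dict, List, Optional, Sequence, Union

import torch


class DummyPlugin:
    """Toy stand-in for BasePlugin; a real plugin would run an image processor here."""

    def get_mm_inputs(
        self,
        images: Sequence[str],        # flat list of images for the whole batch
        imglens: Sequence[int],       # number of images contributed by each example
        seqlens: Sequence[int],       # input_ids length of each example
        processor: Optional[object],
    ) -> Dict[str, Union[List[int], "torch.Tensor"]]:
        # Empty dict when the batch has no images, otherwise batched visual features.
        return {} if len(images) == 0 else {"pixel_values": torch.zeros(len(images), 3, 4, 4)}


# How the collator in this patch assembles the three batch-level lists:
features = [
    {"input_ids": [0, 1, 2], "images": ["img0.jpg"]},
    {"input_ids": [3, 4], "images": None},  # text-only example
]

batch_images, batch_imglens, batch_seqlens = [], [], []
for feature in features:
    images = feature.pop("images") or []  # avoid NoneType
    batch_images.extend(images)
    batch_imglens.append(len(images))
    batch_seqlens.append(len(feature["input_ids"]))

mm_inputs = DummyPlugin().get_mm_inputs(batch_images, batch_imglens, batch_seqlens, processor=None)
print(sorted(mm_inputs))  # ['pixel_values'] when any example has images, [] otherwise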
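
Likewise, a rough sketch of the per-example aligner output after this change: convert_alpaca and convert_sharegpt now receive a single example and emit underscore-prefixed columns, with dataset.map running batched=False at alignment time and the new preprocessing_batch_size applying only to the later tokenization map. The toy_example column names and values below are made up; only the shape of the aligned record follows the patch.

from typing import Any, Dict

# One record in the source dataset (toy values, column names are arbitrary here).
toy_example: Dict[str, Any] = {
    "instruction": "Describe the image.",
    "output": "A cat on a sofa.",
    "images": ["cat.png"],
}

# Shape of the aligned record produced by the non-batched convert functions.
aligned: Dict[str, Any] = {
    "_prompt": [{"role": "user", "content": toy_example["instruction"]}],
    "_response": [{"role": "assistant", "content": toy_example["output"]}],
    "_system": "",
    "_tools": "",
    "_images": toy_example["images"] or None,  # None for examples without images, as in _convert_images
}

# dataset.map(convert_func, batched=False, remove_columns=column_names, ...) yields
# records of this shape; with the explicit Features schema removed, image entries stay
# as paths/bytes/PIL objects until mm_plugin opens and regularizes them at collation.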