support qwen2-vl #5290

Merged · 3 commits · Aug 29, 2024
14 changes: 14 additions & 0 deletions data/dataset_info.json
@@ -38,6 +38,20 @@
"assistant_tag": "assistant"
}
},
"qwen2vl_demo": {
"file_name": "qwen2vl_demo.json",
"formatting": "sharegpt",
"columns": {
"messages": "messages",
"images": "images"
},
"tags": {
"role_tag": "role",
"content_tag": "content",
"user_tag": "user",
"assistant_tag": "assistant"
}
},
"alpaca_en": {
"hf_hub_url": "llamafactory/alpaca_en",
"ms_hub_url": "llamafactory/alpaca_en"
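This registry entry is what lets the training configs below refer to the dataset by name: `dataset: qwen2vl_demo` resolves through `dataset_info.json`, which maps the sharegpt-formatted `messages` and `images` columns of the demo file onto the loader's expectations.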
140 changes: 140 additions & 0 deletions data/qwen2vl_demo.json
@@ -0,0 +1,140 @@
[
  {
    "messages": [
      {
        "content": "<|image_pad|>Who are they?",
        "role": "user"
      },
      {
        "content": "They're Kane and Gretzka from Bayern Munich.",
        "role": "assistant"
      },
      {
        "content": "What are they doing?",
        "role": "user"
      },
      {
        "content": "They are celebrating on the soccer field.",
        "role": "assistant"
      }
    ],
    "images": [
      "mllm_demo_data/1.jpg"
    ]
  },
  {
    "messages": [
      {
        "content": "<|image_pad|>Who is he?",
        "role": "user"
      },
      {
        "content": "He's Thomas Muller from Bayern Munich.",
        "role": "assistant"
      },
      {
        "content": "<|image_pad|>Why is he on the ground?",
        "role": "user"
      },
      {
        "content": "Because he's sliding on his knees to celebrate.",
        "role": "assistant"
      }
    ],
    "images": [
      "mllm_demo_data/2.jpg",
      "mllm_demo_data/2.jpg"
    ]
  },
  {
    "messages": [
      {
        "content": "<|image_pad|>Please describe this image",
        "role": "user"
      },
      {
        "content": "Chinese astronaut Gui Haichao is giving a speech.",
        "role": "assistant"
      },
      {
        "content": "What has he accomplished?",
        "role": "user"
      },
      {
        "content": "He was appointed to be a payload specialist on Shenzhou 16 mission in June 2022, thus becoming the first Chinese civilian of Group 3 in space on 30 May 2023. He is responsible for the on-orbit operation of space science experimental payloads.",
        "role": "assistant"
      }
    ],
    "images": [
      "mllm_demo_data/3.jpg"
    ]
  },
  {
    "messages": [
      {
        "content": "<|image_pad|>他们是谁?",
        "role": "user"
      },
      {
        "content": "他们是拜仁慕尼黑的凯恩和格雷茨卡。",
        "role": "assistant"
      },
      {
        "content": "<|image_pad|>他们在做什么?",
        "role": "user"
      },
      {
        "content": "他们在足球场上庆祝。",
        "role": "assistant"
      }
    ],
    "images": [
      "mllm_demo_data/1.jpg",
      "mllm_demo_data/1.jpg"
    ]
  },
  {
    "messages": [
      {
        "content": "<|image_pad|>他是谁?",
        "role": "user"
      },
      {
        "content": "他是来自拜仁慕尼黑的托马斯·穆勒。",
        "role": "assistant"
      },
      {
        "content": "他为什么在地上?",
        "role": "user"
      },
      {
        "content": "因为他正在双膝跪地滑行庆祝。",
        "role": "assistant"
      }
    ],
    "images": [
      "mllm_demo_data/2.jpg"
    ]
  },
  {
    "messages": [
      {
        "content": "<|image_pad|>请描述这张图片",
        "role": "user"
      },
      {
        "content": "中国宇航员桂海潮正在讲话。",
        "role": "assistant"
      },
      {
        "content": "他取得过哪些成就?",
        "role": "user"
      },
      {
        "content": "他于2022年6月被任命为神舟十六号任务的有效载荷专家,从而成为2023年5月30日进入太空的首位平民宇航员。他负责在轨操作空间科学实验有效载荷。",
        "role": "assistant"
      }
    ],
    "images": [
      "mllm_demo_data/3.jpg"
    ]
  }
]
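A quick sanity check on the demo file (a sketch, not part of the PR): every `<|image_pad|>` marker across a sample's messages should line up with one entry in its `images` list, since the encoder expands the i-th marker using the grid of the i-th image. Note that samples referencing the same picture twice duplicate its path, keeping the pairing one-to-one.

import json

with open("data/qwen2vl_demo.json") as f:
    samples = json.load(f)

for sample in samples:
    # count image placeholders across all turns of the conversation
    num_pads = sum(m["content"].count("<|image_pad|>") for m in sample["messages"])
    assert num_pads == len(sample["images"]), f"mismatch in sample with images {sample['images']}"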
40 changes: 40 additions & 0 deletions examples/train_full/qwen2vl_full_sft.yaml
@@ -0,0 +1,40 @@
### model
model_name_or_path: qwen2-vl-hf/qwen2-vl-7b-hf
visual_inputs: true

### method
stage: sft
do_train: true
finetuning_type: full
deepspeed: examples/deepspeed/ds_z3_config.json

### dataset
dataset: qwen2vl_demo
template: qwen2vl
cutoff_len: 1024
max_samples: 1000
overwrite_cache: true
preprocessing_num_workers: 16

### output
output_dir: saves/qwen2-vl-7b/full/sft
logging_steps: 10
save_steps: 500
plot_loss: true
overwrite_output_dir: true

### train
per_device_train_batch_size: 1
gradient_accumulation_steps: 1
learning_rate: 1.0e-5
num_train_epochs: 3.0
lr_scheduler_type: cosine
warmup_ratio: 0.1
bf16: true
ddp_timeout: 180000000

### eval
val_size: 0.1
per_device_eval_batch_size: 1
eval_strategy: steps
eval_steps: 500
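Assuming the project's usual entry point, this file would be launched with something like `llamafactory-cli train examples/train_full/qwen2vl_full_sft.yaml`; since it pulls in the ZeRO-3 DeepSpeed config, a multi-GPU launch is the expected setting for full fine-tuning of the 7B model.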
40 changes: 40 additions & 0 deletions examples/train_lora/qwen2vl_lora_sft.yaml
@@ -0,0 +1,40 @@
### model
model_name_or_path: qwen2-vl-hf/qwen2-vl-7b-hf
visual_inputs: true

### method
stage: sft
do_train: true
finetuning_type: lora
lora_target: all

### dataset
dataset: qwen2vl_demo
template: qwen2vl
cutoff_len: 1024
max_samples: 1000
overwrite_cache: true
preprocessing_num_workers: 16

### output
output_dir: saves/qwen2-vl-7b/lora/sft
logging_steps: 10
save_steps: 500
plot_loss: true
overwrite_output_dir: true

### train
per_device_train_batch_size: 2
gradient_accumulation_steps: 1
learning_rate: 1.0e-4
num_train_epochs: 3.0
lr_scheduler_type: cosine
warmup_ratio: 0.1
bf16: true
ddp_timeout: 180000000

### eval
val_size: 0.1
per_device_eval_batch_size: 1
eval_strategy: steps
eval_steps: 500
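The LoRA variant differs from the full fine-tuning config exactly where you would expect: `finetuning_type: lora` with `lora_target: all` replaces full updates and DeepSpeed, the learning rate is ten times higher (1.0e-4 vs. 1.0e-5), and the per-device batch size doubles to 2.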
28 changes: 28 additions & 0 deletions src/llamafactory/data/collator.py
@@ -72,9 +72,37 @@ class SFTDataCollatorWith4DAttentionMask(DataCollatorForSeq2Seq):
    compute_dtype: "torch.dtype" = torch.float32

    def __call__(self, features: Sequence[Dict[str, Any]]) -> Dict[str, "torch.Tensor"]:
        image_grid_thw = None
        if "image_grid_thw" in features[0]:
            # collect visual inputs, skipping the dummy entries (t == 0) that mark text-only examples
            image_grid_thw_list = [
                torch.Tensor(feature["image_grid_thw"]).long()
                for feature in features
                if feature["image_grid_thw"][0][0] > 0
            ]
            pixel_values_list = [
                torch.Tensor(feature["pixel_values"]) for feature in features if feature["image_grid_thw"][0][0] > 0
            ]
            if image_grid_thw_list:
                image_grid_thw = torch.cat(image_grid_thw_list, 0)
            else:  # every example in the batch is text-only
                image_grid_thw = None
            if pixel_values_list:
                pixel_values = torch.cat(pixel_values_list, 0)
            else:  # every example in the batch is text-only
                pixel_values = None
            features = [
                {key: feature[key] for key in feature if key not in ["image_grid_thw", "pixel_values"]}
                for feature in features
            ]

        features = super().__call__(features)
        if self.block_diag_attn and self.attn_implementation != "flash_attention_2":
            features["attention_mask"] = prepare_4d_attention_mask(features["attention_mask"], self.compute_dtype)
        if image_grid_thw is not None:
            features["image_grid_thw"] = image_grid_thw
            features["pixel_values"] = pixel_values

        return features
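A minimal sketch of the merge behavior above, with made-up shapes (the grid values and the 1176-dim patch features are illustrative, not taken from the PR): per-example visual tensors are concatenated along dim 0, so the model receives one flat patch sequence and one (num_images, 3) grid tensor per batch, while text-only examples (marked by t == 0) contribute nothing.

import torch

# two hypothetical examples, one image each, with grid (t, h, w) = (1, 4, 4)
features = [
    {"image_grid_thw": [[1, 4, 4]], "pixel_values": torch.randn(16, 1176)},
    {"image_grid_thw": [[1, 4, 4]], "pixel_values": torch.randn(16, 1176)},
]

grids = [torch.Tensor(f["image_grid_thw"]).long() for f in features if f["image_grid_thw"][0][0] > 0]
pixels = [torch.Tensor(f["pixel_values"]) for f in features if f["image_grid_thw"][0][0] > 0]

print(torch.cat(grids, 0).shape)   # torch.Size([2, 3]): one row per image in the batch
print(torch.cat(pixels, 0).shape)  # torch.Size([32, 1176]): patches stacked along dim 0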

14 changes: 14 additions & 0 deletions src/llamafactory/data/processors/processor_utils.py
@@ -78,6 +78,20 @@ def get_paligemma_token_type_ids(input_len: int, processor: "ProcessorMixin") ->
    return [0] * image_seq_length + [1] * (input_len - image_seq_length)


def get_qwen2vl_image_inputs(images: Sequence["ImageObject"], processor: "ProcessorMixin") -> "Dict[str, torch.Tensor]":
    r"""
    Processes visual inputs for qwen2_vl. Supports multiple images.
    """
    image_processor: "BaseImageProcessor" = getattr(processor, "image_processor")
    if len(images) != 0:
        image_inputs = image_processor(images=images, return_tensors="pt")
    else:
        # text-only example: encode a dummy 56x56 white image, then zero its temporal
        # grid entry so downstream code can detect and drop it
        image = Image.new("RGB", (56, 56), (255, 255, 255))
        image_inputs = image_processor(images=[image], return_tensors="pt")
        image_inputs["image_grid_thw"][0][0] = 0
    return {"pixel_values": image_inputs["pixel_values"], "image_grid_thw": image_inputs["image_grid_thw"]}


def infer_seqlen(source_len: int, target_len: int, cutoff_len: int) -> Tuple[int, int]:
r"""
Computes the real sequence length after truncation by the cutoff_len.
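A hedged usage sketch of the helper above (the `processor` is assumed to be a loaded Qwen2-VL processor; the grid values shown are illustrative):

from PIL import Image

# with real images, one (t, h, w) row comes back per image
inputs = get_qwen2vl_image_inputs([Image.open("mllm_demo_data/1.jpg")], processor)
print(inputs["image_grid_thw"])  # e.g. tensor([[1, 94, 138]]), depending on image size

# with no images, a 56x56 white dummy is encoded and flagged via t == 0,
# which is exactly the marker the data collator filters on
inputs = get_qwen2vl_image_inputs([], processor)
assert inputs["image_grid_thw"][0][0] == 0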
44 changes: 39 additions & 5 deletions src/llamafactory/data/processors/supervised.py
@@ -17,10 +17,17 @@

from ...extras.constants import IGNORE_INDEX
from ...extras.logging import get_logger
from .processor_utils import get_paligemma_token_type_ids, get_pixel_values, greedy_knapsack, infer_seqlen
from .processor_utils import (
    get_paligemma_token_type_ids,
    get_pixel_values,
    get_qwen2vl_image_inputs,
    greedy_knapsack,
    infer_seqlen,
)


if TYPE_CHECKING:
    from PIL.Image import Image as ImageObject
    from transformers import PreTrainedTokenizer, ProcessorMixin

    from ...hparams import DataArguments
@@ -36,13 +43,32 @@ def _encode_supervised_example(
    system: Optional[str],
    tools: Optional[str],
    template: "Template",
    images: Sequence["ImageObject"],
    tokenizer: "PreTrainedTokenizer",
    processor: Optional["ProcessorMixin"],
    cutoff_len: int,
    train_on_prompt: bool,
    mask_history: bool,
) -> Tuple[List[int], List[int]]:
    if processor is not None and not hasattr(processor, "image_seq_length"):  # llava-like models
    if processor is not None and "image_grid_thw" in processor.model_input_names:  # qwen2_vl models
        image_processor = getattr(processor, "image_processor")
        merge_length = image_processor.merge_size**2
        if len(images) > 0:
            image_grid_thw = get_qwen2vl_image_inputs(images, processor)["image_grid_thw"]
            index = 0
            for message in prompt:
                content = message["content"]
                while "<|image_pad|>" in content:
                    # expand one <|image_pad|> at a time into the vision-token span for the
                    # corresponding image; <|placeholder|> keeps the freshly inserted pad
                    # tokens from being matched again by this loop
                    content = content.replace(
                        "<|image_pad|>",
                        template.vision_start_token
                        + "<|placeholder|>" * (image_grid_thw[index].prod() // merge_length)
                        + template.vision_end_token,
                        1,
                    )
                    index += 1
                message["content"] = content.replace("<|placeholder|>", "<|image_pad|>")
    elif processor is not None and not hasattr(processor, "image_seq_length"):  # llava-like models
        prompt[0]["content"] = template.image_token + prompt[0]["content"]

    messages = prompt + response
@@ -107,6 +133,8 @@ def preprocess_supervised_dataset(
model_inputs["pixel_values"] = []
if hasattr(processor, "image_seq_length"): # paligemma models
model_inputs["token_type_ids"] = []
if "image_grid_thw" in processor.model_input_names: # qwen2_vl models
model_inputs["image_grid_thw"] = []

    for i in range(len(examples["prompt"])):
        if len(examples["prompt"][i]) % 2 != 1 or len(examples["response"][i]) != 1:
@@ -118,6 +146,7 @@
            response=examples["response"][i],
            system=examples["system"][i],
            tools=examples["tools"][i],
            images=examples["images"][i],
            template=template,
            tokenizer=tokenizer,
            processor=processor,
@@ -129,9 +158,14 @@
model_inputs["attention_mask"].append([1] * len(input_ids))
model_inputs["labels"].append(labels)
if processor is not None:
model_inputs["pixel_values"].append(get_pixel_values(examples["images"][i], processor))
if hasattr(processor, "image_seq_length"): # paligemma models
model_inputs["token_type_ids"].append(get_paligemma_token_type_ids(len(input_ids), processor))
if "image_grid_thw" in processor.model_input_names: # qwen2_vl models
image_inputs = get_qwen2vl_image_inputs(examples["images"][i], processor)
model_inputs["pixel_values"].append(image_inputs["pixel_values"])
model_inputs["image_grid_thw"].append(image_inputs["image_grid_thw"])
else:
model_inputs["pixel_values"].append(get_pixel_values(examples["images"][i], processor))
if hasattr(processor, "image_seq_length"): # paligemma models
model_inputs["token_type_ids"].append(get_paligemma_token_type_ids(len(input_ids), processor))

    return model_inputs
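To make the placeholder expansion in `_encode_supervised_example` concrete, here is a self-contained sketch with made-up numbers: a grid of (1, 4, 4), a `merge_size` of 2, and Qwen2-VL's `<|vision_start|>`/`<|vision_end|>` literals standing in for `template.vision_start_token`/`template.vision_end_token`.

# merge_size = 2 means 2x2 patches merge into one token, so merge_length = 4
merge_length = 2 ** 2
num_pads = (1 * 4 * 4) // merge_length  # grid.prod() // merge_length = 4 pad tokens

content = "<|image_pad|>Who are they?"
content = content.replace(
    "<|image_pad|>",
    "<|vision_start|>" + "<|placeholder|>" * num_pads + "<|vision_end|>",
    1,
)
content = content.replace("<|placeholder|>", "<|image_pad|>")
print(content)
# <|vision_start|><|image_pad|><|image_pad|><|image_pad|><|image_pad|><|vision_end|>Who are they?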
