Support Qwen2-VL Fine-Tuning on Video Datasets #5365

Merged: 2 commits, merged on Sep 4, 2024
14 changes: 14 additions & 0 deletions data/dataset_info.json
@@ -38,6 +38,20 @@
"assistant_tag": "assistant"
}
},
"mllm_video_demo": {
"file_name": "mllm_video_demo.json",
"formatting": "sharegpt",
"columns": {
"messages": "messages",
"videos": "videos"
},
"tags": {
"role_tag": "role",
"content_tag": "content",
"user_tag": "user",
"assistant_tag": "assistant"
}
},
"alpaca_en": {
"hf_hub_url": "llamafactory/alpaca_en",
"ms_hub_url": "llamafactory/alpaca_en"
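The new "mllm_video_demo" entry declares which columns hold the conversation and the video list, and how the role/content keys are named inside each message. The sketch below is illustrative only, not LLaMA-Factory's actual loader: it shows how such a columns/tags declaration maps a raw ShareGPT-style record onto the internal message format.

from typing import Any, Dict, List

# Mirrors the "mllm_video_demo" entry above; illustrative stand-in, not the real loader.
DATASET_ATTR = {
    "columns": {"messages": "messages", "videos": "videos"},
    "tags": {
        "role_tag": "role",
        "content_tag": "content",
        "user_tag": "user",
        "assistant_tag": "assistant",
    },
}

def normalize_sample(raw: Dict[str, Any]) -> Dict[str, Any]:
    """Pick the declared columns and rename the role/content keys."""
    cols, tags = DATASET_ATTR["columns"], DATASET_ATTR["tags"]
    messages: List[Dict[str, str]] = [
        {"role": turn[tags["role_tag"]], "content": turn[tags["content_tag"]]}
        for turn in raw[cols["messages"]]
    ]
    return {"messages": messages, "videos": raw.get(cols["videos"], [])}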
Binary file added data/mllm_demo_data/1.mp4
Binary file not shown.
Binary file added data/mllm_demo_data/2.avi
Binary file not shown.
Binary file added data/mllm_demo_data/3.mp4
Binary file not shown.
47 changes: 47 additions & 0 deletions data/mllm_video_demo.json
@@ -0,0 +1,47 @@
[
{
"messages": [
{
"content": "<video>Why is this video funny?",
"role": "user"
},
{
"content": "Because a baby is reading, and he is so cute!",
"role": "assistant"
}
],
"videos": [
"mllm_demo_data/1.mp4"
]
},
{
"messages": [
{
"content": "<video>What is she doing?",
"role": "user"
},
{
"content": "She is cooking.",
"role": "assistant"
}
],
"videos": [
"mllm_demo_data/2.avi"
]
},
{
"messages": [
{
"content": "<video>What's in the video?",
"role": "user"
},
{
"content": "A baby is playing in the living room.",
"role": "assistant"
}
],
"videos": [
"mllm_demo_data/3.mp4"
]
}
]
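Each sample pairs one <video> placeholder in the user turn with one path in "videos", relative to the data/ directory. A quick sanity check over the demo file (a sketch that assumes it is run from the repository root):

import json
import os

# Assumes the working directory is the repository root.
with open("data/mllm_video_demo.json", encoding="utf-8") as f:
    samples = json.load(f)

for i, sample in enumerate(samples):
    placeholders = sum(msg["content"].count("<video>") for msg in sample["messages"])
    assert placeholders == len(sample["videos"]), f"sample {i}: placeholder/video count mismatch"
    for path in sample["videos"]:
        assert os.path.isfile(os.path.join("data", path)), f"missing video file: {path}"
print(f"checked {len(samples)} samples")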
8 changes: 5 additions & 3 deletions src/llamafactory/chat/base_engine.py
@@ -18,11 +18,11 @@


if TYPE_CHECKING:
from numpy.typing import NDArray
from transformers import PreTrainedModel, PreTrainedTokenizer
from vllm import AsyncLLMEngine

from ..data import Template
from ..data.mm_plugin import ImageInput, VideoInput
from ..hparams import DataArguments, FinetuningArguments, GeneratingArguments, ModelArguments


@@ -56,7 +56,8 @@ async def chat(
messages: Sequence[Dict[str, str]],
system: Optional[str] = None,
tools: Optional[str] = None,
image: Optional["NDArray"] = None,
image: Optional["ImageInput"] = None,
video: Optional["VideoInput"] = None,
**input_kwargs,
) -> List["Response"]: ...

@@ -66,7 +67,8 @@ async def stream_chat(
messages: Sequence[Dict[str, str]],
system: Optional[str] = None,
tools: Optional[str] = None,
image: Optional["NDArray"] = None,
image: Optional["ImageInput"] = None,
video: Optional["VideoInput"] = None,
**input_kwargs,
) -> AsyncGenerator[str, None]: ...

25 changes: 15 additions & 10 deletions src/llamafactory/chat/chat_model.py
@@ -27,8 +27,7 @@


if TYPE_CHECKING:
from PIL.Image import Image

from ..data.mm_plugin import ImageInput, VideoInput
from .base_engine import BaseEngine, Response


@@ -56,31 +55,36 @@ def chat(
messages: Sequence[Dict[str, str]],
system: Optional[str] = None,
tools: Optional[str] = None,
image: Optional["Image"] = None,
image: Optional["ImageInput"] = None,
video: Optional["VideoInput"] = None,
**input_kwargs,
) -> List["Response"]:
task = asyncio.run_coroutine_threadsafe(self.achat(messages, system, tools, image, **input_kwargs), self._loop)
task = asyncio.run_coroutine_threadsafe(
self.achat(messages, system, tools, image, video, **input_kwargs), self._loop
)
return task.result()

async def achat(
self,
messages: Sequence[Dict[str, str]],
system: Optional[str] = None,
tools: Optional[str] = None,
image: Optional["Image"] = None,
image: Optional["ImageInput"] = None,
video: Optional["VideoInput"] = None,
**input_kwargs,
) -> List["Response"]:
return await self.engine.chat(messages, system, tools, image, **input_kwargs)
return await self.engine.chat(messages, system, tools, image, video, **input_kwargs)

def stream_chat(
self,
messages: Sequence[Dict[str, str]],
system: Optional[str] = None,
tools: Optional[str] = None,
image: Optional["Image"] = None,
image: Optional["ImageInput"] = None,
video: Optional["VideoInput"] = None,
**input_kwargs,
) -> Generator[str, None, None]:
generator = self.astream_chat(messages, system, tools, image, **input_kwargs)
generator = self.astream_chat(messages, system, tools, image, video, **input_kwargs)
while True:
try:
task = asyncio.run_coroutine_threadsafe(generator.__anext__(), self._loop)
@@ -93,10 +97,11 @@ async def astream_chat(
messages: Sequence[Dict[str, str]],
system: Optional[str] = None,
tools: Optional[str] = None,
image: Optional["Image"] = None,
image: Optional["ImageInput"] = None,
video: Optional["VideoInput"] = None,
**input_kwargs,
) -> AsyncGenerator[str, None]:
async for new_token in self.engine.stream_chat(messages, system, tools, image, **input_kwargs):
async for new_token in self.engine.stream_chat(messages, system, tools, image, video, **input_kwargs):
yield new_token

def get_scores(
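With the video argument threaded through ChatModel and its async variants, a caller can pass a video path next to the messages. A minimal usage sketch; the model path, the qwen2_vl template name, and the argument keys are assumptions based on the rest of the repository, not part of this diff:

from llamafactory.chat import ChatModel

# Assumed arguments; adjust the model path and template to your setup.
chat_model = ChatModel({
    "model_name_or_path": "Qwen/Qwen2-VL-7B-Instruct",
    "template": "qwen2_vl",
})

messages = [{"role": "user", "content": "<video>What is she doing?"}]
for response in chat_model.chat(messages, video="data/mllm_demo_data/2.avi"):
    print(response.response_text)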
51 changes: 32 additions & 19 deletions src/llamafactory/chat/hf_engine.py
@@ -22,19 +22,19 @@
from transformers import GenerationConfig, TextIteratorStreamer

from ..data import get_template_and_fix_tokenizer
from ..extras.constants import IMAGE_PLACEHOLDER
from ..extras.constants import IMAGE_PLACEHOLDER, VIDEO_PLACEHOLDER
from ..extras.logging import get_logger
from ..extras.misc import get_logits_processor
from ..model import load_model, load_tokenizer
from .base_engine import BaseEngine, Response


if TYPE_CHECKING:
from PIL.Image import Image
from transformers import PreTrainedModel, PreTrainedTokenizer, ProcessorMixin
from trl import PreTrainedModelWrapper

from ..data import Template
from ..data.mm_plugin import ImageInput, VideoInput
from ..hparams import DataArguments, FinetuningArguments, GeneratingArguments, ModelArguments


@@ -78,20 +78,30 @@ def _process_args(
messages: Sequence[Dict[str, str]],
system: Optional[str] = None,
tools: Optional[str] = None,
image: Optional["Image"] = None,
image: Optional["ImageInput"] = None,
video: Optional["VideoInput"] = None,
input_kwargs: Optional[Dict[str, Any]] = {},
) -> Tuple[Dict[str, Any], int]:
mm_input_dict = {"images": [], "videos": [], "imglens": [0], "vidlens": [0]}
if image is not None:
mm_input_dict.update({"images": [image], "imglens": [1]})
if IMAGE_PLACEHOLDER not in messages[0]["content"]:
messages[0]["content"] = IMAGE_PLACEHOLDER + messages[0]["content"]

messages = template.mm_plugin.process_messages(messages, [image], processor)
if video is not None:
mm_input_dict.update({"videos": [video], "vidlens": [1]})
if VIDEO_PLACEHOLDER not in messages[0]["content"]:
messages[0]["content"] = VIDEO_PLACEHOLDER + messages[0]["content"]

messages = template.mm_plugin.process_messages(
messages, mm_input_dict["images"], mm_input_dict["videos"], processor
)
paired_messages = messages + [{"role": "assistant", "content": ""}]
system = system or generating_args["default_system"]
prompt_ids, _ = template.encode_oneturn(tokenizer, paired_messages, system, tools)
if image is not None:
prompt_ids, _ = template.mm_plugin.process_token_ids(prompt_ids, None, [image], tokenizer, processor)
prompt_ids, _ = template.mm_plugin.process_token_ids(
prompt_ids, None, mm_input_dict["images"], mm_input_dict["videos"], tokenizer, processor
)

prompt_length = len(prompt_ids)
inputs = torch.tensor([prompt_ids], device=model.device)
@@ -154,13 +164,10 @@ def _process_args(
logits_processor=get_logits_processor(),
)

if image is not None:
mm_inputs = template.mm_plugin.get_mm_inputs(
images=[image], imglens=[1], seqlens=[prompt_length], processor=processor
)
for key, value in mm_inputs.items():
value = value if isinstance(value, torch.Tensor) else torch.tensor(value)
gen_kwargs[key] = value.to(model.device)
mm_inputs = template.mm_plugin.get_mm_inputs(**mm_input_dict, seqlens=[prompt_length], processor=processor)
for key, value in mm_inputs.items():
value = value if isinstance(value, torch.Tensor) else torch.tensor(value)
gen_kwargs[key] = value.to(model.device)

return gen_kwargs, prompt_length

@@ -175,11 +182,12 @@ def _chat(
messages: Sequence[Dict[str, str]],
system: Optional[str] = None,
tools: Optional[str] = None,
image: Optional["Image"] = None,
image: Optional["ImageInput"] = None,
video: Optional["VideoInput"] = None,
input_kwargs: Optional[Dict[str, Any]] = {},
) -> List["Response"]:
gen_kwargs, prompt_length = HuggingfaceEngine._process_args(
model, tokenizer, processor, template, generating_args, messages, system, tools, image, input_kwargs
model, tokenizer, processor, template, generating_args, messages, system, tools, image, video, input_kwargs
)
generate_output = model.generate(**gen_kwargs)
response_ids = generate_output[:, prompt_length:]
@@ -210,11 +218,12 @@ def _stream_chat(
messages: Sequence[Dict[str, str]],
system: Optional[str] = None,
tools: Optional[str] = None,
image: Optional["Image"] = None,
image: Optional["ImageInput"] = None,
video: Optional["VideoInput"] = None,
input_kwargs: Optional[Dict[str, Any]] = {},
) -> Callable[[], str]:
gen_kwargs, _ = HuggingfaceEngine._process_args(
model, tokenizer, processor, template, generating_args, messages, system, tools, image, input_kwargs
model, tokenizer, processor, template, generating_args, messages, system, tools, image, video, input_kwargs
)
streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
gen_kwargs["streamer"] = streamer
@@ -267,7 +276,8 @@ async def chat(
messages: Sequence[Dict[str, str]],
system: Optional[str] = None,
tools: Optional[str] = None,
image: Optional["Image"] = None,
image: Optional["ImageInput"] = None,
video: Optional["VideoInput"] = None,
**input_kwargs,
) -> List["Response"]:
if not self.can_generate:
@@ -284,6 +294,7 @@ async def chat(
system,
tools,
image,
video,
input_kwargs,
)
async with self.semaphore:
@@ -295,7 +306,8 @@ async def stream_chat(
messages: Sequence[Dict[str, str]],
system: Optional[str] = None,
tools: Optional[str] = None,
image: Optional["Image"] = None,
image: Optional["ImageInput"] = None,
video: Optional["VideoInput"] = None,
**input_kwargs,
) -> AsyncGenerator[str, None]:
if not self.can_generate:
@@ -312,6 +324,7 @@ async def stream_chat(
system,
tools,
image,
video,
input_kwargs,
)
async with self.semaphore:
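The main change in _process_args is that image and video inputs are collected into a single mm_input_dict and the matching placeholder is prepended to the first user message when it is missing, so text-only, image-only, and video-only requests share one code path. A stripped-down illustration of that normalization step (not the engine code itself):

from typing import Any, Dict, List, Optional

IMAGE_PLACEHOLDER = "<image>"  # placeholder literals as defined in extras/constants.py
VIDEO_PLACEHOLDER = "<video>"

def collect_mm_inputs(
    messages: List[Dict[str, str]],
    image: Optional[Any] = None,
    video: Optional[Any] = None,
) -> Dict[str, Any]:
    """Collect multimodal inputs and ensure the first user turn carries a placeholder."""
    mm_input_dict = {"images": [], "videos": [], "imglens": [0], "vidlens": [0]}
    if image is not None:
        mm_input_dict.update({"images": [image], "imglens": [1]})
        if IMAGE_PLACEHOLDER not in messages[0]["content"]:
            messages[0]["content"] = IMAGE_PLACEHOLDER + messages[0]["content"]
    if video is not None:
        mm_input_dict.update({"videos": [video], "vidlens": [1]})
        if VIDEO_PLACEHOLDER not in messages[0]["content"]:
            messages[0]["content"] = VIDEO_PLACEHOLDER + messages[0]["content"]
    return mm_input_dict

Because get_mm_inputs is now called unconditionally with these (possibly empty) lists, the mm_plugin decides whether any pixel values need to be attached to the generation kwargs, instead of the engine special-casing the image-only path.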
27 changes: 26 additions & 1 deletion src/llamafactory/data/aligner.py
@@ -25,7 +25,7 @@
from transformers import Seq2SeqTrainingArguments

from ..hparams import DataArguments
from .mm_plugin import ImageInput
from .mm_plugin import ImageInput, VideoInput
from .parser import DatasetAttr


@@ -52,6 +52,26 @@ def _convert_images(
return images


def _convert_videos(
videos: Sequence["VideoInput"],
dataset_attr: "DatasetAttr",
data_args: "DataArguments",
) -> Optional[List["VideoInput"]]:
r"""
Optionally concatenates video path to dataset dir when loading from local disk.
"""
if len(videos) == 0:
return None

videos = videos[:]
if dataset_attr.load_from in ["script", "file"]:
for i in range(len(videos)):
if isinstance(videos[i], str) and os.path.isfile(os.path.join(data_args.dataset_dir, videos[i])):
videos[i] = os.path.join(data_args.dataset_dir, videos[i])

return videos


def convert_alpaca(
example: Dict[str, Any],
dataset_attr: "DatasetAttr",
@@ -96,12 +116,14 @@ def convert_alpaca(
response = []

convert_images = partial(_convert_images, dataset_attr=dataset_attr, data_args=data_args)
convert_videos = partial(_convert_videos, dataset_attr=dataset_attr, data_args=data_args)
output = {
"_prompt": prompt,
"_response": response,
"_system": example[dataset_attr.system] if dataset_attr.system else "",
"_tools": example[dataset_attr.tools] if dataset_attr.tools else "",
"_images": convert_images(example[dataset_attr.images]) if dataset_attr.images else None,
"_videos": convert_videos(example[dataset_attr.videos]) if dataset_attr.videos else None,
}
return output

@@ -187,12 +209,14 @@ def convert_sharegpt(
prompt, response = [], []

convert_images = partial(_convert_images, dataset_attr=dataset_attr, data_args=data_args)
convert_videos = partial(_convert_videos, dataset_attr=dataset_attr, data_args=data_args)
output = {
"_prompt": prompt,
"_response": response,
"_system": system,
"_tools": example[dataset_attr.tools] if dataset_attr.tools else "",
"_images": convert_images(example[dataset_attr.images]) if dataset_attr.images else None,
"_videos": convert_videos(example[dataset_attr.videos]) if dataset_attr.videos else None,
}
return output

@@ -210,6 +234,7 @@ def align_dataset(
_system: "..."
_tools: "...",
_images: [],
_videos: [],
"""
if dataset_attr.formatting == "alpaca":
convert_func = partial(convert_alpaca, dataset_attr=dataset_attr, data_args=data_args)
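The new _convert_videos helper mirrors _convert_images: when a dataset is loaded from a local script or file, a relative video path is joined with dataset_dir, but only if the joined path points to an existing file; otherwise the original value is kept. A small sketch of that rule with a hypothetical layout (dataset_dir "data" containing mllm_demo_data/):

import os

dataset_dir = "data"  # hypothetical dataset_dir, matching the shipped demo data
videos = ["mllm_demo_data/1.mp4", "/absolute/path/clip.mp4"]

resolved = []
for video in videos:
    candidate = os.path.join(dataset_dir, video)
    # Same rule as _convert_videos: rewrite only when the joined path is a real file.
    resolved.append(candidate if isinstance(video, str) and os.path.isfile(candidate) else video)

print(resolved)  # e.g. ["data/mllm_demo_data/1.mp4", "/absolute/path/clip.mp4"]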