Support Several MLLM Models #4136

Status: Closed · wants to merge 46 commits

Commits (46)
9abd1b8
add videollava
BUAADreamer Jun 6, 2024
ddad20f
add videollava and demo video data
BUAADreamer Jun 6, 2024
115ffbe
add videollava and demo video data
BUAADreamer Jun 6, 2024
7cdc262
fix processor conflict
BUAADreamer Jun 7, 2024
0b7535e
fix supervised conflict
BUAADreamer Jun 7, 2024
4e97a83
support video-llava
BUAADreamer Jun 7, 2024
adb3b26
support video-llava
BUAADreamer Jun 7, 2024
ef76387
Merge branch 'hiyouga:main' into main
BUAADreamer Jun 8, 2024
76c6379
add av to requirements
BUAADreamer Jun 8, 2024
3a53b3c
add llava-next/idefics2
BUAADreamer Jun 8, 2024
3188a56
support video-llava/llava-next/idefics2(4.42)
BUAADreamer Jun 8, 2024
daeffb4
modify idefics2 template
BUAADreamer Jun 8, 2024
3fc87e8
Update requirements.txt
hiyouga Jun 8, 2024
0f46edb
Merge branch 'hiyouga:main' into main
BUAADreamer Jun 8, 2024
d2e4362
modify position of idefics2 in template
BUAADreamer Jun 8, 2024
307e423
Merge branch 'main' of https://github.com/BUAADreamer/LLaMA-Factory
BUAADreamer Jun 8, 2024
7d2a8f3
align preprocess_supervised_dataset implementation
BUAADreamer Jun 8, 2024
a0fe536
Merge branch 'hiyouga:main' into main
BUAADreamer Jun 11, 2024
7689c9d
Merge branch 'hiyouga:main' into main
BUAADreamer Jun 17, 2024
ed9e8ec
Merge branch 'hiyouga:main' into main
BUAADreamer Jun 19, 2024
7dbe875
Merge branch 'main' into main
hiyouga Jun 24, 2024
f5c0b34
Merge branch 'hiyouga:main' into main
BUAADreamer Jun 28, 2024
1d5e9d5
finetune right for video model
BUAADreamer Jun 29, 2024
a99abcd
add image_data/video_data key to template to flexibly support more MLLMs
BUAADreamer Jun 30, 2024
00a2923
Merge branch 'main' into main
BUAADreamer Jun 30, 2024
722e189
add video inference
BUAADreamer Jun 30, 2024
9092279
Merge branch 'main' of https://github.com/BUAADreamer/LLaMA-Factory
BUAADreamer Jun 30, 2024
98a2e3d
fix some
BUAADreamer Jul 1, 2024
a571c4c
Merge branch 'hiyouga:main' into main
BUAADreamer Jul 1, 2024
d092749
Merge branch 'hiyouga:main' into main
BUAADreamer Jul 2, 2024
7d0419d
support idefics2/llava_next inference right
BUAADreamer Jul 2, 2024
e65537d
add model constants
BUAADreamer Jul 2, 2024
5023974
Merge branch 'main' into main
BUAADreamer Jul 2, 2024
d5563d3
Merge branch 'hiyouga:main' into main
BUAADreamer Jul 15, 2024
ca44c8d
solve the predict problem of llava-next-video and the multi-gpu finet…
BUAADreamer Jul 15, 2024
abdc2fa
Merge branch 'hiyouga:main' into main
BUAADreamer Jul 15, 2024
7b5b32f
Merge branch 'main' into main
BUAADreamer Jul 22, 2024
66980bf
Merge branch 'main' into main
BUAADreamer Aug 22, 2024
f033b3d
add if condition for llava-video
Kuangdd01 Aug 24, 2024
a96e29e
Merge pull request #1 from BUAADreamer/main
Kuangdd01 Aug 24, 2024
800793a
Merge branch 'main' of https://github.com/Kuangdd01/LLaMA-Factory-X
Kuangdd01 Aug 24, 2024
9eac318
fix some errors
Kuangdd01 Aug 25, 2024
e116d34
remove redundant import
Kuangdd01 Aug 25, 2024
7e59b76
Merge pull request #2 from Kuangdd01/main
BUAADreamer Aug 25, 2024
201593d
add visual model config for llava-next-video
Kuangdd01 Aug 28, 2024
24526fe
Merge pull request #3 from Kuangdd01/main
BUAADreamer Aug 28, 2024
Files changed
29 changes: 29 additions & 0 deletions data/dataset_info.json
@@ -38,6 +38,35 @@
"assistant_tag": "assistant"
}
},
"video_demo": {
"file_name": "video_demo.json",
"formatting": "sharegpt",
"columns": {
"messages": "messages",
"videos": "videos"
},
"tags": {
"role_tag": "role",
"content_tag": "content",
"user_tag": "user",
"assistant_tag": "assistant"
}
},
"visual_mix_demo": {
"file_name": "visual_mix_demo.json",
"formatting": "sharegpt",
"columns": {
"messages": "messages",
"videos": "videos",
"images": "images"
},
"tags": {
"role_tag": "role",
"content_tag": "content",
"user_tag": "user",
"assistant_tag": "assistant"
}
},
"alpaca_en": {
"hf_hub_url": "llamafactory/alpaca_en",
"ms_hub_url": "llamafactory/alpaca_en"
Binary file added data/mllm_demo_data/1.mp4
Binary file not shown.
Binary file added data/mllm_demo_data/2.avi
Binary file not shown.
Binary file added data/mllm_demo_data/3.mp4
Binary file not shown.
47 changes: 47 additions & 0 deletions data/video_demo.json
@@ -0,0 +1,47 @@
[
  {
    "messages": [
      {
        "content": "Why is this video funny?",
        "role": "user"
      },
      {
        "content": "Because a baby is reading, and he is so cute!",
        "role": "assistant"
      }
    ],
    "videos": [
      "mllm_demo_data/1.mp4"
    ]
  },
  {
    "messages": [
      {
        "content": "What is she doing?",
        "role": "user"
      },
      {
        "content": "She is cooking",
        "role": "assistant"
      }
    ],
    "videos": [
      "mllm_demo_data/2.avi"
    ]
  },
  {
    "messages": [
      {
        "content": "What's in the video?",
        "role": "user"
      },
      {
        "content": "A baby is playing in the living room",
        "role": "assistant"
      }
    ],
    "videos": [
      "mllm_demo_data/3.mp4"
    ]
  }
]
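Each sample above follows the `sharegpt` schema registered in `dataset_info.json`: the `columns` entry maps the loader's `messages`/`videos` fields to JSON keys, and the `tags` entry names the role and content keys. As a quick standalone sanity check (a minimal sketch, not part of the PR; paths assume the repository root):

```python
import json
import os

# Validate the demo file against its dataset_info.json entry (illustrative only).
with open("data/dataset_info.json") as f:
    info = json.load(f)["video_demo"]

columns, tags = info["columns"], info["tags"]
with open(os.path.join("data", info["file_name"])) as f:
    samples = json.load(f)

for sample in samples:
    messages = sample[columns["messages"]]
    # One user turn followed by one assistant turn, using the declared tags.
    roles = [m[tags["role_tag"]] for m in messages]
    assert roles == [tags["user_tag"], tags["assistant_tag"]]
    for path in sample[columns["videos"]]:
        assert os.path.isfile(os.path.join("data", path)), f"missing video: {path}"
```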
56 changes: 56 additions & 0 deletions data/visual_mix_demo.json
@@ -0,0 +1,56 @@
[
  {
    "messages": [
      {
        "content": "<video>Why is this video funny?<image>",
        "role": "user"
      },
      {
        "content": "Because a baby is reading, and he is so cute!",
        "role": "assistant"
      }
    ],
    "videos": [
      "mllm_demo_data/1.mp4"
    ],
    "images": [
      "mllm_demo_data/1.jpg"
    ]
  },
  {
    "messages": [
      {
        "content": "<video>What is she doing?<image>",
        "role": "user"
      },
      {
        "content": "She is cooking",
        "role": "assistant"
      }
    ],
    "videos": [
      "mllm_demo_data/2.avi"
    ],
    "images": [
      "mllm_demo_data/2.jpg"
    ]
  },
  {
    "messages": [
      {
        "content": "<video>Why is this video funny?<image>",
        "role": "user"
      },
      {
        "content": "A baby is playing!",
        "role": "assistant"
      }
    ],
    "videos": [
      "mllm_demo_data/3.mp4"
    ],
    "images": [
      "mllm_demo_data/3.jpg"
    ]
  }
]
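Unlike `video_demo.json`, the mixed samples place explicit `<video>` and `<image>` placeholders inside the user turn; the template later substitutes model-specific media tokens at those positions. A simple consistency check is that the number of placeholders matches the lengths of the `videos` and `images` lists (a hedged sketch; the placeholder strings are taken from the demo data above, not from the template code):

```python
import json

VIDEO_PLACEHOLDER = "<video>"  # placeholder strings as used in visual_mix_demo.json
IMAGE_PLACEHOLDER = "<image>"

with open("data/visual_mix_demo.json") as f:
    samples = json.load(f)

for sample in samples:
    text = "".join(m["content"] for m in sample["messages"])
    # Every placeholder should correspond to exactly one media file.
    assert text.count(VIDEO_PLACEHOLDER) == len(sample.get("videos", []))
    assert text.count(IMAGE_PLACEHOLDER) == len(sample.get("images", []))
```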
3 changes: 2 additions & 1 deletion pyproject.toml
@@ -22,7 +22,8 @@ known-third-party = [
"peft",
"torch",
"transformers",
"trl"
"trl",
"av",
]

[tool.ruff.format]
1 change: 1 addition & 0 deletions requirements.txt
@@ -18,4 +18,5 @@ matplotlib>=3.7.0
fire
packaging
pyyaml
av
numpy<2.0.0
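`av` (PyAV) is the new dependency that makes video decoding possible. Roughly, a helper like `get_pixel_values_videos` has to open the container, sample a fixed number of frames, and hand them to the model's image processor. The sketch below shows only the decoding and sampling half; the frame count of 8 is an arbitrary assumption, not the PR's actual default:

```python
import av
import numpy as np

def sample_frames(path: str, num_frames: int = 8) -> np.ndarray:
    """Uniformly sample RGB frames from a video file (illustrative sketch)."""
    container = av.open(path)
    stream = container.streams.video[0]
    total = stream.frames  # may be 0 for some codecs; real code needs a fallback
    indices = set(np.linspace(0, max(total - 1, 0), num_frames).astype(int).tolist())
    frames = []
    for i, frame in enumerate(container.decode(stream)):
        if i in indices:
            frames.append(frame.to_ndarray(format="rgb24"))
    container.close()
    return np.stack(frames)  # (num_sampled, H, W, C), ready for an image processor

frames = sample_frames("data/mllm_demo_data/1.mp4")
```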
2 changes: 2 additions & 0 deletions src/llamafactory/chat/base_engine.py
@@ -57,6 +57,7 @@ async def chat(
        system: Optional[str] = None,
        tools: Optional[str] = None,
        image: Optional["NDArray"] = None,
        video: Optional["NDArray"] = None,
        **input_kwargs,
    ) -> List["Response"]: ...

@@ -67,6 +68,7 @@ async def stream_chat(
        system: Optional[str] = None,
        tools: Optional[str] = None,
        image: Optional["NDArray"] = None,
        video: Optional["NDArray"] = None,
        **input_kwargs,
    ) -> AsyncGenerator[str, None]: ...

6 changes: 4 additions & 2 deletions src/llamafactory/chat/chat_model.py
@@ -78,9 +78,10 @@ def stream_chat(
        system: Optional[str] = None,
        tools: Optional[str] = None,
        image: Optional["NDArray"] = None,
        video: Optional["NDArray"] = None,
        **input_kwargs,
    ) -> Generator[str, None, None]:
-        generator = self.astream_chat(messages, system, tools, image, **input_kwargs)
+        generator = self.astream_chat(messages, system, tools, image, video, **input_kwargs)
        while True:
            try:
                task = asyncio.run_coroutine_threadsafe(generator.__anext__(), self._loop)
@@ -94,9 +95,10 @@ async def astream_chat(
        system: Optional[str] = None,
        tools: Optional[str] = None,
        image: Optional["NDArray"] = None,
        video: Optional["NDArray"] = None,
        **input_kwargs,
    ) -> AsyncGenerator[str, None]:
-        async for new_token in self.engine.stream_chat(messages, system, tools, image, **input_kwargs):
+        async for new_token in self.engine.stream_chat(messages, system, tools, image, video, **input_kwargs):
            yield new_token

    def get_scores(
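With `video` threaded through `stream_chat` and `astream_chat`, callers can pass decoded video frames alongside an image. A hypothetical usage sketch (the model id, template name, and file path are placeholders for illustration, not values taken from this PR):

```python
import numpy as np

from llamafactory.chat import ChatModel

chat_model = ChatModel(dict(
    model_name_or_path="LanguageBind/Video-LLaVA-7B-hf",  # assumed model id
    template="video_llava",                               # assumed template name
))

messages = [{"role": "user", "content": "<video>What is happening in this video?"}]
video = np.load("frames.npy")  # pre-decoded frames, e.g. from the PyAV sketch above

for new_text in chat_model.stream_chat(messages, video=video):
    print(new_text, end="", flush=True)
```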
60 changes: 49 additions & 11 deletions src/llamafactory/chat/hf_engine.py
@@ -26,12 +26,11 @@
from ..extras.misc import get_logits_processor
from ..model import load_model, load_tokenizer
from .base_engine import BaseEngine, Response

from ..data.processors.processor_utils import get_pixel_values, get_pixel_values_videos

if TYPE_CHECKING:
    from numpy.typing import NDArray
    from transformers import PreTrainedModel, PreTrainedTokenizer, ProcessorMixin
    from transformers.image_processing_utils import BaseImageProcessor
    from trl import PreTrainedModelWrapper

    from ..data import Template
@@ -79,30 +78,58 @@ def _process_args(
        system: Optional[str] = None,
        tools: Optional[str] = None,
        image: Optional["NDArray"] = None,
        video: Optional["NDArray"] = None,
        input_kwargs: Optional[Dict[str, Any]] = {},
    ) -> Tuple[Dict[str, Any], int]:
        processor_class = "" if processor is None else type(processor).__name__

        visual_token_flag = template.image_token in messages[0]["content"] or \
            template.video_token in messages[0]["content"]

        if (
            processor is not None
            and image is not None
-            and not hasattr(processor, "image_seq_length")
+            and not processor_class == 'PaliGemmaProcessor'
-            and template.image_token not in messages[0]["content"]
-        ):  # llava-like models
+            and not visual_token_flag
+        ):
            messages[0]["content"] = template.image_token + messages[0]["content"]

        if (
            processor is not None
            and video is not None
            and template.video_token not in messages[0]["content"]
            and not visual_token_flag
        ):
            messages[0]["content"] = template.video_token + messages[0]["content"]

        if processor_class == 'Idefics2Processor':
            fake_image_token = processor.fake_image_token.content
            image_str = f"{fake_image_token}{template.image_token * processor.image_seq_len}{fake_image_token}"
            image_str = image_str * 5
            for j in range(len(messages)):
                content = messages[j]['content']
                content = content.replace(template.image_token, image_str)
                content = content.replace(f"{fake_image_token}{fake_image_token}", f"{fake_image_token}")
                messages[j]['content'] = content

        paired_messages = messages + [{"role": "assistant", "content": ""}]
        system = system or generating_args["default_system"]
        pixel_values = None
        pixel_values_video = None
        prompt_ids, _ = template.encode_oneturn(
            tokenizer=tokenizer, messages=paired_messages, system=system, tools=tools
        )

        if processor is not None and image is not None:  # add image features
-            image_processor: "BaseImageProcessor" = getattr(processor, "image_processor")
-            batch_feature = image_processor(image, return_tensors="pt")
-            pixel_values = batch_feature.to(model.device)["pixel_values"]  # shape (B, C, H, W)
-            if hasattr(processor, "image_seq_length"):  # paligemma models
+            pixel_values = get_pixel_values([image], processor, template.image_data_key)
+            if processor_class == 'PaliGemmaProcessor':  # paligemma models
                image_token_id = tokenizer.convert_tokens_to_ids(template.image_token)
                prompt_ids = [image_token_id] * getattr(processor, "image_seq_length") + prompt_ids

        if processor is not None and video is not None:  # add video features
            pixel_values_video = get_pixel_values_videos([video], processor, template.video_data_key)

        prompt_length = len(prompt_ids)
        inputs = torch.tensor([prompt_ids], device=model.device)
        attention_mask = torch.ones_like(inputs, dtype=torch.bool)
@@ -165,7 +192,12 @@ def _process_args(
        )

        if pixel_values is not None:
-            gen_kwargs["pixel_values"] = pixel_values
+            for key in template.image_data_key:
+                gen_kwargs[key] = pixel_values[key].to(model.device)

        if pixel_values_video is not None:
            for key in template.video_data_key:
                gen_kwargs[key] = pixel_values_video[key].to(model.device)

        return gen_kwargs, prompt_length

@@ -181,10 +213,11 @@ def _chat(
        system: Optional[str] = None,
        tools: Optional[str] = None,
        image: Optional["NDArray"] = None,
        video: Optional["NDArray"] = None,
        input_kwargs: Optional[Dict[str, Any]] = {},
    ) -> List["Response"]:
        gen_kwargs, prompt_length = HuggingfaceEngine._process_args(
-            model, tokenizer, processor, template, generating_args, messages, system, tools, image, input_kwargs
+            model, tokenizer, processor, template, generating_args, messages, system, tools, image, video, input_kwargs
        )
        generate_output = model.generate(**gen_kwargs)
        response_ids = generate_output[:, prompt_length:]
@@ -216,10 +249,11 @@ def _stream_chat(
        system: Optional[str] = None,
        tools: Optional[str] = None,
        image: Optional["NDArray"] = None,
        video: Optional["NDArray"] = None,
        input_kwargs: Optional[Dict[str, Any]] = {},
    ) -> Callable[[], str]:
        gen_kwargs, _ = HuggingfaceEngine._process_args(
-            model, tokenizer, processor, template, generating_args, messages, system, tools, image, input_kwargs
+            model, tokenizer, processor, template, generating_args, messages, system, tools, image, video, input_kwargs
        )
        streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
        gen_kwargs["streamer"] = streamer
@@ -273,6 +307,7 @@ async def chat(
        system: Optional[str] = None,
        tools: Optional[str] = None,
        image: Optional["NDArray"] = None,
        video: Optional["NDArray"] = None,
        **input_kwargs,
    ) -> List["Response"]:
        if not self.can_generate:
@@ -289,6 +324,7 @@
            system,
            tools,
            image,
            video,
            input_kwargs,
        )
        async with self.semaphore:
@@ -301,6 +337,7 @@ async def stream_chat(
        system: Optional[str] = None,
        tools: Optional[str] = None,
        image: Optional["NDArray"] = None,
        video: Optional["NDArray"] = None,
        **input_kwargs,
    ) -> AsyncGenerator[str, None]:
        if not self.can_generate:
@@ -317,6 +354,7 @@
            system,
            tools,
            image,
            video,
            input_kwargs,
        )
        async with self.semaphore:
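The least obvious branch in `_process_args` is the Idefics2 prompt expansion: each `<image>` placeholder becomes `image_seq_len` image tokens wrapped in fake image tokens, repeated five times (presumably matching Idefics2's image splitting into four crops plus the full image), with doubled fake tokens collapsed at the seams. Restated as a standalone function (a sketch; the default token strings and sequence length are assumptions based on the Idefics2 processor, not values defined in this PR):

```python
def expand_idefics2_image_tokens(
    content: str,
    image_token: str = "<image>",
    fake_image_token: str = "<fake_token_around_image>",  # assumed Idefics2 default
    image_seq_len: int = 64,                              # assumed Idefics2 default
    num_copies: int = 5,                                  # 4 crops + full image
) -> str:
    """Expand each <image> placeholder the way the Idefics2 branch above does."""
    image_str = f"{fake_image_token}{image_token * image_seq_len}{fake_image_token}"
    image_str = image_str * num_copies
    content = content.replace(image_token, image_str)
    # Collapse the doubled fake tokens created at the seams between copies.
    return content.replace(fake_image_token * 2, fake_image_token)

expanded = expand_idefics2_image_tokens("<image>Describe this image.")
```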