[Misc] Abstract out the logic for reading and writing media content #11527

Merged
merged 10 commits on Dec 27, 2024
Changes from 2 commits
6 changes: 2 additions & 4 deletions vllm/assets/audio.py
@@ -21,12 +21,10 @@ class AudioAsset:
name: Literal["winning_call", "mary_had_lamb"]

@property
def audio_and_sample_rate(self) -> tuple[npt.NDArray, int]:
def audio_and_sample_rate(self) -> tuple[npt.NDArray, float]:
audio_path = get_vllm_public_assets(filename=f"{self.name}.ogg",
s3_prefix=ASSET_DIR)
y, sr = librosa.load(audio_path, sr=None)
assert isinstance(sr, int)
return y, sr
return librosa.load(audio_path, sr=None)

@property
def url(self) -> str:
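As a usage sketch (not part of the diff; it assumes AudioAsset is constructed with the asset name shown above), the simplified property just forwards what librosa.load(..., sr=None) returns, which is why the sample-rate annotation widens from int to float:

```python
# Usage sketch (not from the PR): the property now returns librosa.load's
# (samples, sample_rate) tuple unchanged, with the sample rate typed as float.
from vllm.assets.audio import AudioAsset

asset = AudioAsset("winning_call")
audio, sample_rate = asset.audio_and_sample_rate
print(audio.shape, sample_rate)
```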
163 changes: 91 additions & 72 deletions vllm/entrypoints/chat_utils.py
@@ -6,7 +6,7 @@
from functools import lru_cache, partial
from pathlib import Path
from typing import (Any, Awaitable, Callable, Dict, Generic, Iterable, List,
Literal, Mapping, Optional, Tuple, TypeVar, Union, cast)
Literal, Optional, Tuple, TypeVar, Union, cast)

import jinja2.nodes
import transformers.utils.chat_template_utils as hf_chat_utils
@@ -23,19 +23,21 @@
ChatCompletionMessageParam as OpenAIChatCompletionMessageParam)
from openai.types.chat import (ChatCompletionMessageToolCallParam,
ChatCompletionToolMessageParam)
from openai.types.chat.chat_completion_content_part_input_audio_param import (
InputAudio)
# yapf: enable
# pydantic needs the TypedDict from typing_extensions
from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
from typing_extensions import Required, TypeAlias, TypedDict

import vllm.envs as envs
from vllm.config import ModelConfig
from vllm.logger import init_logger
from vllm.multimodal import MultiModalDataDict
from vllm.multimodal.utils import (async_get_and_parse_audio,
async_get_and_parse_image,
async_get_and_parse_video,
get_and_parse_audio, get_and_parse_image,
get_and_parse_video)
from vllm.multimodal.audio import AudioMediaIO
from vllm.multimodal.image import ImageMediaIO
from vllm.multimodal.utils import MediaConnector
from vllm.multimodal.video import VideoMediaIO
from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer
from vllm.utils import print_warning_once

@@ -368,14 +370,17 @@ def __init__(self, model_config: ModelConfig, tokenizer: AnyTokenizer):
self._tokenizer = tokenizer
self._allowed_items = (model_config.multimodal_config.limit_per_prompt
if model_config.multimodal_config else {})
self._consumed_items = {k: 0 for k in self._allowed_items}

self._items: List[_T] = []
self._items_by_modality = defaultdict[str, list[_T]](list)

@property
def model_config(self) -> ModelConfig:
return self._model_config

@property
def allowed_local_media_path(self):
return self._model_config.allowed_local_media_path

@staticmethod
@lru_cache(maxsize=None)
def _cached_token_str(tokenizer: AnyTokenizer, token_index: int) -> str:
@@ -435,38 +440,19 @@ def _placeholder_str(self, modality: ModalityStr,
else:
raise TypeError(f"Unknown modality: {modality}")

@staticmethod
def _combine(items: List[MultiModalDataDict]) -> MultiModalDataDict:
mm_lists: Mapping[str, List[object]] = defaultdict(list)

# Merge all the multi-modal items
for single_mm_data in items:
for mm_key, mm_item in single_mm_data.items():
if isinstance(mm_item, list):
mm_lists[mm_key].extend(mm_item)
else:
mm_lists[mm_key].append(mm_item)

# Unpack any single item lists for models that don't expect multiple.
return {
mm_key: mm_list[0] if len(mm_list) == 1 else mm_list
for mm_key, mm_list in mm_lists.items()
}

def add(self, modality: ModalityStr, item: _T) -> Optional[str]:
"""
Add a multi-modal item to the current prompt and returns the
placeholder string to use, if any.
"""
allowed_count = self._allowed_items.get(modality, 1)
current_count = self._consumed_items.get(modality, 0) + 1
current_count = len(self._items_by_modality[modality]) + 1
if current_count > allowed_count:
raise ValueError(
f"At most {allowed_count} {modality}(s) may be provided in "
"one request.")

self._consumed_items[modality] = current_count
self._items.append(item)
self._items_by_modality[modality].append(item)

return self._placeholder_str(modality, current_count)

@@ -475,22 +461,26 @@ def create_parser(self) -> "BaseMultiModalContentParser":
raise NotImplementedError


class MultiModalItemTracker(BaseMultiModalItemTracker[MultiModalDataDict]):
class MultiModalItemTracker(BaseMultiModalItemTracker[object]):

def all_mm_data(self) -> Optional[MultiModalDataDict]:
return self._combine(self._items) if self._items else None
if self._items_by_modality:
return dict(self._items_by_modality)

return None

def create_parser(self) -> "BaseMultiModalContentParser":
return MultiModalContentParser(self)


class AsyncMultiModalItemTracker(
BaseMultiModalItemTracker[Awaitable[MultiModalDataDict]]):
class AsyncMultiModalItemTracker(BaseMultiModalItemTracker[Awaitable[object]]):

async def all_mm_data(self) -> Optional[MultiModalDataDict]:
if self._items:
items = await asyncio.gather(*self._items)
return self._combine(items)
if self._items_by_modality:
return {
modality: await asyncio.gather(*items)
for modality, items in self._items_by_modality.items()
}

return None
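
For reference, a standalone sketch (illustrative values, not vLLM code) of the reworked bookkeeping: items are stored per modality in a defaultdict, and all_mm_data() returns that mapping directly instead of merging everything through the removed _combine helper; the async tracker gathers each modality's awaitables in the same shape.

```python
# Standalone sketch of the per-modality bookkeeping (illustrative values only).
from collections import defaultdict

items_by_modality: defaultdict[str, list] = defaultdict(list)
items_by_modality["image"].append("image-item-1")
items_by_modality["image"].append("image-item-2")
items_by_modality["audio"].append(("samples", 16000.0))

# The sync tracker's all_mm_data() now returns dict(items_by_modality);
# the async variant awaits each modality's list with asyncio.gather.
print(dict(items_by_modality))
# {'image': ['image-item-1', 'image-item-2'], 'audio': [('samples', 16000.0)]}
```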

@@ -522,7 +512,7 @@ def parse_audio(self, audio_url: str) -> None:
raise NotImplementedError

@abstractmethod
def parse_input_audio(self, input_audio: Dict[str, str]) -> None:
def parse_input_audio(self, input_audio: InputAudio) -> None:
raise NotImplementedError

@abstractmethod
@@ -537,31 +527,46 @@ def __init__(self, tracker: MultiModalItemTracker) -> None:

self._tracker = tracker

self._connector = MediaConnector(
allowed_local_media_path=tracker.allowed_local_media_path,
)
self._get_image = partial(
self._connector.load_from_url,
media_io=ImageMediaIO(),
fetch_timeout=envs.VLLM_IMAGE_FETCH_TIMEOUT,
)
self._get_audio = partial(
self._connector.load_from_url,
media_io=AudioMediaIO(),
fetch_timeout=envs.VLLM_AUDIO_FETCH_TIMEOUT,
)
self._get_video = partial(
self._connector.load_from_url,
media_io=VideoMediaIO(ImageMediaIO()),
fetch_timeout=envs.VLLM_VIDEO_FETCH_TIMEOUT,
)

def parse_image(self, image_url: str) -> None:
image = get_and_parse_image(image_url,
allowed_local_media_path=self._tracker.
_model_config.allowed_local_media_path)
image = self._get_image(image_url)

placeholder = self._tracker.add("image", image)
self._add_placeholder(placeholder)

def parse_audio(self, audio_url: str) -> None:
audio = get_and_parse_audio(audio_url)
audio = self._get_audio(audio_url)

placeholder = self._tracker.add("audio", audio)
self._add_placeholder(placeholder)

def parse_input_audio(self, input_audio: Dict[str, str]) -> None:
input_audio_data = input_audio.get("data","")
input_audio_format = input_audio.get("format","")
audio_url = f"data:audio/{input_audio_format};base64,{input_audio_data}"
audio = get_and_parse_audio(audio_url)
def parse_input_audio(self, input_audio: InputAudio) -> None:
audio_data = input_audio.get("data", "")
audio_format = input_audio.get("format", "")
audio_url = f"data:audio/{audio_format};base64,{audio_data}"

placeholder = self._tracker.add("audio", audio)
self._add_placeholder(placeholder)
return self.parse_audio(audio_url)

def parse_video(self, video_url: str) -> None:
video = get_and_parse_video(video_url)
video = self._get_video(video_url)

placeholder = self._tracker.add("video", video)
self._add_placeholder(placeholder)
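
Restated as a standalone sketch (the path and URL are hypothetical; the classes and keyword arguments are the ones imported and used in this diff): a single MediaConnector owns the URL-fetching logic, and each modality binds its own MediaIO codec and fetch timeout via functools.partial.

```python
# Sketch of the MediaConnector pattern shown above (hypothetical path and URL).
from functools import partial

import vllm.envs as envs
from vllm.multimodal.image import ImageMediaIO
from vllm.multimodal.utils import MediaConnector

connector = MediaConnector(allowed_local_media_path="")  # hypothetical value
get_image = partial(
    connector.load_from_url,
    media_io=ImageMediaIO(),
    fetch_timeout=envs.VLLM_IMAGE_FETCH_TIMEOUT,
)

image = get_image("https://example.com/cat.png")  # hypothetical URL
```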
Expand All @@ -573,33 +578,46 @@ def __init__(self, tracker: AsyncMultiModalItemTracker) -> None:
super().__init__()

self._tracker = tracker
self._connector = MediaConnector(
allowed_local_media_path=tracker.allowed_local_media_path,
)
self._get_image_async = partial(
self._connector.load_from_url_async,
media_io=ImageMediaIO(),
fetch_timeout=envs.VLLM_IMAGE_FETCH_TIMEOUT,
)
self._get_audio_async = partial(
self._connector.load_from_url_async,
media_io=AudioMediaIO(),
fetch_timeout=envs.VLLM_AUDIO_FETCH_TIMEOUT,
)
self._get_video_async = partial(
self._connector.load_from_url_async,
media_io=VideoMediaIO(ImageMediaIO()),
fetch_timeout=envs.VLLM_VIDEO_FETCH_TIMEOUT,
)

def parse_image(self, image_url: str) -> None:
image_coro = async_get_and_parse_image(
image_url,
allowed_local_media_path=self._tracker._model_config.
allowed_local_media_path)
image_coro = self._get_image_async(image_url)

placeholder = self._tracker.add("image", image_coro)
self._add_placeholder(placeholder)

def parse_audio(self, audio_url: str) -> None:
audio_coro = async_get_and_parse_audio(audio_url)
audio_coro = self._get_audio_async(audio_url)

placeholder = self._tracker.add("audio", audio_coro)
self._add_placeholder(placeholder)

def parse_input_audio(self, input_audio: Dict[str, str]) -> None:
input_audio_data = input_audio.get("data","")
input_audio_format = input_audio.get("format","")
audio_url = f"data:audio/{input_audio_format};base64,{input_audio_data}"
audio_coro = async_get_and_parse_audio(audio_url)
def parse_input_audio(self, input_audio: InputAudio) -> None:
audio_data = input_audio.get("data", "")
audio_format = input_audio.get("format", "")
audio_url = f"data:audio/{audio_format};base64,{audio_data}"

placeholder = self._tracker.add("audio", audio_coro)
self._add_placeholder(placeholder)
return self.parse_audio(audio_url)

def parse_video(self, video_url: str) -> None:
video = async_get_and_parse_video(video_url)
video = self._get_video_async(video_url)

placeholder = self._tracker.add("video", video)
self._add_placeholder(placeholder)
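
Both parsers now funnel input_audio parts through the ordinary audio path: the base64 payload is wrapped into a data URL and handed to parse_audio (or the async loader). A small sketch of that conversion, with a hypothetical payload:

```python
# Sketch of the input_audio -> data URL delegation (hypothetical payload).
input_audio = {"data": "UklGRi4AAABXQVZF...", "format": "wav"}

audio_data = input_audio.get("data", "")
audio_format = input_audio.get("format", "")
audio_url = f"data:audio/{audio_format};base64,{audio_data}"
# parser.parse_audio(audio_url)  # the call parse_input_audio now makes
```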
@@ -695,10 +713,13 @@ def _get_full_multimodal_text_prompt(placeholder_counts: Dict[str, int],
_RefusalParser = partial(cast, ChatCompletionContentPartRefusalParam)
_VideoParser = partial(cast, ChatCompletionContentPartVideoParam)

_ContentPart: TypeAlias = Union[str, Dict[str, str], InputAudio]

# Define a mapping from part types to their corresponding parsing functions.
MM_PARSER_MAP: Dict[str,
Callable[[ChatCompletionContentPartParam],
Union[str, Dict[str,str]]]] = {
MM_PARSER_MAP: Dict[
str,
Callable[[ChatCompletionContentPartParam], _ContentPart],
] = {
"text":
lambda part: _TextParser(part).get("text", ""),
"image_url":
@@ -715,8 +736,7 @@ def _get_full_multimodal_text_prompt(placeholder_counts: Dict[str, int],


def _parse_chat_message_content_mm_part(
part: ChatCompletionContentPartParam) -> Tuple[str,
Union[str, Dict[str, str]]]:
part: ChatCompletionContentPartParam) -> tuple[str, _ContentPart]:
"""
Parses a given multi-modal content part based on its type.

@@ -783,7 +803,7 @@ def _parse_chat_message_content_parts(
*,
wrap_dicts: bool,
) -> List[ConversationMessage]:
content: List[Union[str, Dict[str, str]]] = []
content = list[_ContentPart]()

mm_parser = mm_tracker.create_parser()

@@ -814,7 +834,7 @@ def _parse_chat_message_content_part(
mm_parser: BaseMultiModalContentParser,
*,
wrap_dicts: bool,
) -> Optional[Union[str, Dict[str, str]]]:
) -> Optional[_ContentPart]:
"""Parses a single part of a conversation. If wrap_dicts is True,
structured dictionary pieces for texts and images will be
wrapped in dictionaries, i.e., {"type": "text", "text", ...} and
@@ -823,8 +843,7 @@ def _parse_chat_message_content_part(
with multimodal placeholders.
"""
if isinstance(part, str): # Handle plain text parts
text = _TextParser(part)
return text
return part

# Handle structured dictionary parts
part_type, content = _parse_chat_message_content_mm_part(part)
@@ -855,7 +874,7 @@ def _parse_chat_message_content_part(
return {'type': 'audio'} if wrap_dicts else None

if part_type == "input_audio":
dict_content = cast(Dict[str, str], content)
dict_content = cast(InputAudio, content)
mm_parser.parse_input_audio(dict_content)
return {'type': 'audio'} if wrap_dicts else None

36 changes: 35 additions & 1 deletion vllm/multimodal/audio.py
@@ -1,17 +1,26 @@
import base64
from io import BytesIO
from pathlib import Path

import numpy as np
import numpy.typing as npt

from vllm.inputs.registry import InputContext
from vllm.utils import PlaceholderModule

from .base import MultiModalPlugin
from .base import MediaIO, MultiModalPlugin
from .inputs import AudioItem, MultiModalData, MultiModalKwargs

try:
import librosa
except ImportError:
librosa = PlaceholderModule("librosa") # type: ignore[assignment]

try:
import soundfile
except ImportError:
soundfile = PlaceholderModule("soundfile") # type: ignore[assignment]


class AudioPlugin(MultiModalPlugin):
"""Plugin for audio data."""
@@ -39,3 +48,28 @@ def resample_audio(
target_sr: float,
) -> npt.NDArray[np.floating]:
return librosa.resample(audio, orig_sr=orig_sr, target_sr=target_sr)


class AudioMediaIO(MediaIO[tuple[npt.NDArray, float]]):

def load_bytes(self, data: bytes) -> tuple[npt.NDArray, float]:
return librosa.load(BytesIO(data), sr=None)

def load_base64(
self,
media_type: str,
data: str,
) -> tuple[npt.NDArray, float]:
return self.load_bytes(base64.b64decode(data))

def load_file(self, filepath: Path) -> tuple[npt.NDArray, float]:
return librosa.load(filepath, sr=None)

def encode_base64(self, media: tuple[npt.NDArray, float]) -> str:
audio, sr = media

with BytesIO() as buffer:
soundfile.write(buffer, audio, sr, format="WAV")
data = buffer.getvalue()

return base64.b64encode(data).decode('utf-8')
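
A round-trip usage sketch for the new AudioMediaIO (not part of the diff; the file name is hypothetical and librosa/soundfile must be installed): load audio from a file, re-encode it as base64-wrapped WAV, then decode it again from base64.

```python
# Round-trip sketch for AudioMediaIO (hypothetical file; needs librosa + soundfile).
from pathlib import Path

from vllm.multimodal.audio import AudioMediaIO

audio_io = AudioMediaIO()

audio, sr = audio_io.load_file(Path("sample.ogg"))   # hypothetical file
b64_wav = audio_io.encode_base64((audio, sr))        # base64-encoded WAV bytes
audio_again, sr_again = audio_io.load_base64("audio/wav", b64_wav)
print(audio.shape, sr, audio_again.shape, sr_again)
```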