From 7af553ea30031446b4c1c74ad83187f9fd3de4e7 Mon Sep 17 00:00:00 2001 From: Cyrus Leung Date: Fri, 27 Dec 2024 19:21:23 +0800 Subject: [PATCH] [Misc] Abstract the logic for reading and writing media content (#11527) Signed-off-by: DarkLight1337 --- tests/entrypoints/openai/test_serving_chat.py | 1 + tests/entrypoints/test_chat_utils.py | 6 +- tests/multimodal/test_utils.py | 59 ++- vllm/assets/audio.py | 6 +- vllm/entrypoints/chat_utils.py | 129 +++-- vllm/multimodal/audio.py | 36 +- vllm/multimodal/base.py | 38 +- vllm/multimodal/image.py | 41 +- vllm/multimodal/utils.py | 477 ++++++++---------- vllm/multimodal/video.py | 87 +++- 10 files changed, 493 insertions(+), 387 deletions(-) diff --git a/tests/entrypoints/openai/test_serving_chat.py b/tests/entrypoints/openai/test_serving_chat.py index 51b255bb2a6db..61677b65af342 100644 --- a/tests/entrypoints/openai/test_serving_chat.py +++ b/tests/entrypoints/openai/test_serving_chat.py @@ -33,6 +33,7 @@ class MockModelConfig: hf_config = MockHFConfig() logits_processor_pattern = None diff_sampling_param: Optional[dict] = None + allowed_local_media_path: str = "" def get_diff_sampling_param(self): return self.diff_sampling_param or {} diff --git a/tests/entrypoints/test_chat_utils.py b/tests/entrypoints/test_chat_utils.py index 996e60bfee592..d63b963522e73 100644 --- a/tests/entrypoints/test_chat_utils.py +++ b/tests/entrypoints/test_chat_utils.py @@ -2,7 +2,6 @@ from typing import Optional import pytest -from PIL import Image from vllm.assets.image import ImageAsset from vllm.config import ModelConfig @@ -91,10 +90,7 @@ def _assert_mm_data_is_image_input( image_data = mm_data.get("image") assert image_data is not None - if image_count == 1: - assert isinstance(image_data, Image.Image) - else: - assert isinstance(image_data, list) and len(image_data) == image_count + assert isinstance(image_data, list) and len(image_data) == image_count def test_parse_chat_messages_single_image( diff --git a/tests/multimodal/test_utils.py b/tests/multimodal/test_utils.py index fd82fb0c55fd7..6029f2e514772 100644 --- a/tests/multimodal/test_utils.py +++ b/tests/multimodal/test_utils.py @@ -9,7 +9,7 @@ from PIL import Image, ImageChops from transformers import AutoConfig, AutoTokenizer -from vllm.multimodal.utils import (async_fetch_image, fetch_image, +from vllm.multimodal.utils import (MediaConnector, repeat_and_pad_placeholder_tokens) # Test different image extensions (JPG/PNG) and formats (gray/RGB/RGBA) @@ -23,7 +23,12 @@ @pytest.fixture(scope="module") def url_images() -> Dict[str, Image.Image]: - return {image_url: fetch_image(image_url) for image_url in TEST_IMAGE_URLS} + connector = MediaConnector() + + return { + image_url: connector.fetch_image(image_url) + for image_url in TEST_IMAGE_URLS + } def get_supported_suffixes() -> Tuple[str, ...]: @@ -43,8 +48,10 @@ def _image_equals(a: Image.Image, b: Image.Image) -> bool: @pytest.mark.asyncio @pytest.mark.parametrize("image_url", TEST_IMAGE_URLS) async def test_fetch_image_http(image_url: str): - image_sync = fetch_image(image_url) - image_async = await async_fetch_image(image_url) + connector = MediaConnector() + + image_sync = connector.fetch_image(image_url) + image_async = await connector.fetch_image_async(image_url) assert _image_equals(image_sync, image_async) @@ -53,6 +60,7 @@ async def test_fetch_image_http(image_url: str): @pytest.mark.parametrize("suffix", get_supported_suffixes()) async def test_fetch_image_base64(url_images: Dict[str, Image.Image], image_url: str, suffix: str): + connector = MediaConnector() url_image = url_images[image_url] try: @@ -75,48 +83,49 @@ async def test_fetch_image_base64(url_images: Dict[str, Image.Image], base64_image = base64.b64encode(f.read()).decode("utf-8") data_url = f"data:{mime_type};base64,{base64_image}" - data_image_sync = fetch_image(data_url) + data_image_sync = connector.fetch_image(data_url) if _image_equals(url_image, Image.open(f)): assert _image_equals(url_image, data_image_sync) else: pass # Lossy format; only check that image can be opened - data_image_async = await async_fetch_image(data_url) + data_image_async = await connector.fetch_image_async(data_url) assert _image_equals(data_image_sync, data_image_async) @pytest.mark.asyncio @pytest.mark.parametrize("image_url", TEST_IMAGE_URLS) async def test_fetch_image_local_files(image_url: str): + connector = MediaConnector() + with TemporaryDirectory() as temp_dir: - origin_image = fetch_image(image_url) + local_connector = MediaConnector(allowed_local_media_path=temp_dir) + + origin_image = connector.fetch_image(image_url) origin_image.save(os.path.join(temp_dir, os.path.basename(image_url)), quality=100, icc_profile=origin_image.info.get('icc_profile')) - image_async = await async_fetch_image( - f"file://{temp_dir}/{os.path.basename(image_url)}", - allowed_local_media_path=temp_dir) - - image_sync = fetch_image( - f"file://{temp_dir}/{os.path.basename(image_url)}", - allowed_local_media_path=temp_dir) + image_async = await local_connector.fetch_image_async( + f"file://{temp_dir}/{os.path.basename(image_url)}") + image_sync = local_connector.fetch_image( + f"file://{temp_dir}/{os.path.basename(image_url)}") # Check that the images are equal assert not ImageChops.difference(image_sync, image_async).getbbox() - with pytest.raises(ValueError): - await async_fetch_image( - f"file://{temp_dir}/../{os.path.basename(image_url)}", - allowed_local_media_path=temp_dir) - with pytest.raises(ValueError): - await async_fetch_image( + with pytest.raises(ValueError, match="must be a subpath"): + await local_connector.fetch_image_async( + f"file://{temp_dir}/../{os.path.basename(image_url)}") + with pytest.raises(RuntimeError, match="Cannot load local files"): + await connector.fetch_image_async( f"file://{temp_dir}/../{os.path.basename(image_url)}") - with pytest.raises(ValueError): - fetch_image(f"file://{temp_dir}/../{os.path.basename(image_url)}", - allowed_local_media_path=temp_dir) - with pytest.raises(ValueError): - fetch_image(f"file://{temp_dir}/../{os.path.basename(image_url)}") + with pytest.raises(ValueError, match="must be a subpath"): + local_connector.fetch_image( + f"file://{temp_dir}/../{os.path.basename(image_url)}") + with pytest.raises(RuntimeError, match="Cannot load local files"): + connector.fetch_image( + f"file://{temp_dir}/../{os.path.basename(image_url)}") @pytest.mark.parametrize("model", ["llava-hf/llava-v1.6-mistral-7b-hf"]) diff --git a/vllm/assets/audio.py b/vllm/assets/audio.py index 9033644e3264a..a46c67ad7e00e 100644 --- a/vllm/assets/audio.py +++ b/vllm/assets/audio.py @@ -21,12 +21,10 @@ class AudioAsset: name: Literal["winning_call", "mary_had_lamb"] @property - def audio_and_sample_rate(self) -> tuple[npt.NDArray, int]: + def audio_and_sample_rate(self) -> tuple[npt.NDArray, float]: audio_path = get_vllm_public_assets(filename=f"{self.name}.ogg", s3_prefix=ASSET_DIR) - y, sr = librosa.load(audio_path, sr=None) - assert isinstance(sr, int) - return y, sr + return librosa.load(audio_path, sr=None) @property def url(self) -> str: diff --git a/vllm/entrypoints/chat_utils.py b/vllm/entrypoints/chat_utils.py index 3df08c740d65b..a492d5496e025 100644 --- a/vllm/entrypoints/chat_utils.py +++ b/vllm/entrypoints/chat_utils.py @@ -6,7 +6,7 @@ from functools import lru_cache, partial from pathlib import Path from typing import (Any, Awaitable, Callable, Dict, Generic, Iterable, List, - Literal, Mapping, Optional, Tuple, TypeVar, Union, cast) + Literal, Optional, Tuple, TypeVar, Union, cast) import jinja2.nodes import transformers.utils.chat_template_utils as hf_chat_utils @@ -23,6 +23,8 @@ ChatCompletionMessageParam as OpenAIChatCompletionMessageParam) from openai.types.chat import (ChatCompletionMessageToolCallParam, ChatCompletionToolMessageParam) +from openai.types.chat.chat_completion_content_part_input_audio_param import ( + InputAudio) # yapf: enable # pydantic needs the TypedDict from typing_extensions from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast @@ -31,11 +33,7 @@ from vllm.config import ModelConfig from vllm.logger import init_logger from vllm.multimodal import MultiModalDataDict -from vllm.multimodal.utils import (async_get_and_parse_audio, - async_get_and_parse_image, - async_get_and_parse_video, - get_and_parse_audio, get_and_parse_image, - get_and_parse_video) +from vllm.multimodal.utils import MediaConnector from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer from vllm.utils import print_warning_once @@ -368,14 +366,17 @@ def __init__(self, model_config: ModelConfig, tokenizer: AnyTokenizer): self._tokenizer = tokenizer self._allowed_items = (model_config.multimodal_config.limit_per_prompt if model_config.multimodal_config else {}) - self._consumed_items = {k: 0 for k in self._allowed_items} - self._items: List[_T] = [] + self._items_by_modality = defaultdict[str, list[_T]](list) @property def model_config(self) -> ModelConfig: return self._model_config + @property + def allowed_local_media_path(self): + return self._model_config.allowed_local_media_path + @staticmethod @lru_cache(maxsize=None) def _cached_token_str(tokenizer: AnyTokenizer, token_index: int) -> str: @@ -435,38 +436,19 @@ def _placeholder_str(self, modality: ModalityStr, else: raise TypeError(f"Unknown modality: {modality}") - @staticmethod - def _combine(items: List[MultiModalDataDict]) -> MultiModalDataDict: - mm_lists: Mapping[str, List[object]] = defaultdict(list) - - # Merge all the multi-modal items - for single_mm_data in items: - for mm_key, mm_item in single_mm_data.items(): - if isinstance(mm_item, list): - mm_lists[mm_key].extend(mm_item) - else: - mm_lists[mm_key].append(mm_item) - - # Unpack any single item lists for models that don't expect multiple. - return { - mm_key: mm_list[0] if len(mm_list) == 1 else mm_list - for mm_key, mm_list in mm_lists.items() - } - def add(self, modality: ModalityStr, item: _T) -> Optional[str]: """ Add a multi-modal item to the current prompt and returns the placeholder string to use, if any. """ allowed_count = self._allowed_items.get(modality, 1) - current_count = self._consumed_items.get(modality, 0) + 1 + current_count = len(self._items_by_modality[modality]) + 1 if current_count > allowed_count: raise ValueError( f"At most {allowed_count} {modality}(s) may be provided in " "one request.") - self._consumed_items[modality] = current_count - self._items.append(item) + self._items_by_modality[modality].append(item) return self._placeholder_str(modality, current_count) @@ -475,22 +457,26 @@ def create_parser(self) -> "BaseMultiModalContentParser": raise NotImplementedError -class MultiModalItemTracker(BaseMultiModalItemTracker[MultiModalDataDict]): +class MultiModalItemTracker(BaseMultiModalItemTracker[object]): def all_mm_data(self) -> Optional[MultiModalDataDict]: - return self._combine(self._items) if self._items else None + if self._items_by_modality: + return dict(self._items_by_modality) + + return None def create_parser(self) -> "BaseMultiModalContentParser": return MultiModalContentParser(self) -class AsyncMultiModalItemTracker( - BaseMultiModalItemTracker[Awaitable[MultiModalDataDict]]): +class AsyncMultiModalItemTracker(BaseMultiModalItemTracker[Awaitable[object]]): async def all_mm_data(self) -> Optional[MultiModalDataDict]: - if self._items: - items = await asyncio.gather(*self._items) - return self._combine(items) + if self._items_by_modality: + return { + modality: await asyncio.gather(*items) + for modality, items in self._items_by_modality.items() + } return None @@ -522,7 +508,7 @@ def parse_audio(self, audio_url: str) -> None: raise NotImplementedError @abstractmethod - def parse_input_audio(self, input_audio: Dict[str, str]) -> None: + def parse_input_audio(self, input_audio: InputAudio) -> None: raise NotImplementedError @abstractmethod @@ -537,31 +523,31 @@ def __init__(self, tracker: MultiModalItemTracker) -> None: self._tracker = tracker + self._connector = MediaConnector( + allowed_local_media_path=tracker.allowed_local_media_path, + ) + def parse_image(self, image_url: str) -> None: - image = get_and_parse_image(image_url, - allowed_local_media_path=self._tracker. - _model_config.allowed_local_media_path) + image = self._connector.fetch_image(image_url) placeholder = self._tracker.add("image", image) self._add_placeholder(placeholder) def parse_audio(self, audio_url: str) -> None: - audio = get_and_parse_audio(audio_url) + audio = self._connector.fetch_audio(audio_url) placeholder = self._tracker.add("audio", audio) self._add_placeholder(placeholder) - def parse_input_audio(self, input_audio: Dict[str, str]) -> None: - input_audio_data = input_audio.get("data","") - input_audio_format = input_audio.get("format","") - audio_url = f"data:audio/{input_audio_format};base64,{input_audio_data}" - audio = get_and_parse_audio(audio_url) + def parse_input_audio(self, input_audio: InputAudio) -> None: + audio_data = input_audio.get("data", "") + audio_format = input_audio.get("format", "") + audio_url = f"data:audio/{audio_format};base64,{audio_data}" - placeholder = self._tracker.add("audio", audio) - self._add_placeholder(placeholder) + return self.parse_audio(audio_url) def parse_video(self, video_url: str) -> None: - video = get_and_parse_video(video_url) + video = self._connector.fetch_video(video_url) placeholder = self._tracker.add("video", video) self._add_placeholder(placeholder) @@ -573,33 +559,31 @@ def __init__(self, tracker: AsyncMultiModalItemTracker) -> None: super().__init__() self._tracker = tracker + self._connector = MediaConnector( + allowed_local_media_path=tracker.allowed_local_media_path, + ) def parse_image(self, image_url: str) -> None: - image_coro = async_get_and_parse_image( - image_url, - allowed_local_media_path=self._tracker._model_config. - allowed_local_media_path) + image_coro = self._connector.fetch_image_async(image_url) placeholder = self._tracker.add("image", image_coro) self._add_placeholder(placeholder) def parse_audio(self, audio_url: str) -> None: - audio_coro = async_get_and_parse_audio(audio_url) + audio_coro = self._connector.fetch_audio_async(audio_url) placeholder = self._tracker.add("audio", audio_coro) self._add_placeholder(placeholder) - def parse_input_audio(self, input_audio: Dict[str, str]) -> None: - input_audio_data = input_audio.get("data","") - input_audio_format = input_audio.get("format","") - audio_url = f"data:audio/{input_audio_format};base64,{input_audio_data}" - audio_coro = async_get_and_parse_audio(audio_url) + def parse_input_audio(self, input_audio: InputAudio) -> None: + audio_data = input_audio.get("data", "") + audio_format = input_audio.get("format", "") + audio_url = f"data:audio/{audio_format};base64,{audio_data}" - placeholder = self._tracker.add("audio", audio_coro) - self._add_placeholder(placeholder) + return self.parse_audio(audio_url) def parse_video(self, video_url: str) -> None: - video = async_get_and_parse_video(video_url) + video = self._connector.fetch_video_async(video_url) placeholder = self._tracker.add("video", video) self._add_placeholder(placeholder) @@ -695,10 +679,13 @@ def _get_full_multimodal_text_prompt(placeholder_counts: Dict[str, int], _RefusalParser = partial(cast, ChatCompletionContentPartRefusalParam) _VideoParser = partial(cast, ChatCompletionContentPartVideoParam) +_ContentPart: TypeAlias = Union[str, Dict[str, str], InputAudio] + # Define a mapping from part types to their corresponding parsing functions. -MM_PARSER_MAP: Dict[str, - Callable[[ChatCompletionContentPartParam], - Union[str, Dict[str,str]]]] = { +MM_PARSER_MAP: Dict[ + str, + Callable[[ChatCompletionContentPartParam], _ContentPart], +] = { "text": lambda part: _TextParser(part).get("text", ""), "image_url": @@ -715,8 +702,7 @@ def _get_full_multimodal_text_prompt(placeholder_counts: Dict[str, int], def _parse_chat_message_content_mm_part( - part: ChatCompletionContentPartParam) -> Tuple[str, - Union[str, Dict[str, str]]]: + part: ChatCompletionContentPartParam) -> tuple[str, _ContentPart]: """ Parses a given multi-modal content part based on its type. @@ -783,7 +769,7 @@ def _parse_chat_message_content_parts( *, wrap_dicts: bool, ) -> List[ConversationMessage]: - content: List[Union[str, Dict[str, str]]] = [] + content = list[_ContentPart]() mm_parser = mm_tracker.create_parser() @@ -814,7 +800,7 @@ def _parse_chat_message_content_part( mm_parser: BaseMultiModalContentParser, *, wrap_dicts: bool, -) -> Optional[Union[str, Dict[str, str]]]: +) -> Optional[_ContentPart]: """Parses a single part of a conversation. If wrap_dicts is True, structured dictionary pieces for texts and images will be wrapped in dictionaries, i.e., {"type": "text", "text", ...} and @@ -823,8 +809,7 @@ def _parse_chat_message_content_part( with multimodal placeholders. """ if isinstance(part, str): # Handle plain text parts - text = _TextParser(part) - return text + return part # Handle structured dictionary parts part_type, content = _parse_chat_message_content_mm_part(part) @@ -855,7 +840,7 @@ def _parse_chat_message_content_part( return {'type': 'audio'} if wrap_dicts else None if part_type == "input_audio": - dict_content = cast(Dict[str, str], content) + dict_content = cast(InputAudio, content) mm_parser.parse_input_audio(dict_content) return {'type': 'audio'} if wrap_dicts else None diff --git a/vllm/multimodal/audio.py b/vllm/multimodal/audio.py index ed3bb82bf0aaa..3e09ef1fcbb56 100644 --- a/vllm/multimodal/audio.py +++ b/vllm/multimodal/audio.py @@ -1,10 +1,14 @@ +import base64 +from io import BytesIO +from pathlib import Path + import numpy as np import numpy.typing as npt from vllm.inputs.registry import InputContext from vllm.utils import PlaceholderModule -from .base import MultiModalPlugin +from .base import MediaIO, MultiModalPlugin from .inputs import AudioItem, MultiModalData, MultiModalKwargs try: @@ -12,6 +16,11 @@ except ImportError: librosa = PlaceholderModule("librosa") # type: ignore[assignment] +try: + import soundfile +except ImportError: + soundfile = PlaceholderModule("soundfile") # type: ignore[assignment] + class AudioPlugin(MultiModalPlugin): """Plugin for audio data.""" @@ -39,3 +48,28 @@ def resample_audio( target_sr: float, ) -> npt.NDArray[np.floating]: return librosa.resample(audio, orig_sr=orig_sr, target_sr=target_sr) + + +class AudioMediaIO(MediaIO[tuple[npt.NDArray, float]]): + + def load_bytes(self, data: bytes) -> tuple[npt.NDArray, float]: + return librosa.load(BytesIO(data), sr=None) + + def load_base64( + self, + media_type: str, + data: str, + ) -> tuple[npt.NDArray, float]: + return self.load_bytes(base64.b64decode(data)) + + def load_file(self, filepath: Path) -> tuple[npt.NDArray, float]: + return librosa.load(filepath, sr=None) + + def encode_base64(self, media: tuple[npt.NDArray, float]) -> str: + audio, sr = media + + with BytesIO() as buffer: + soundfile.write(buffer, audio, sr, format="WAV") + data = buffer.getvalue() + + return base64.b64encode(data).decode('utf-8') diff --git a/vllm/multimodal/base.py b/vllm/multimodal/base.py index 1e5a46946c6c0..10488e24b30cc 100644 --- a/vllm/multimodal/base.py +++ b/vllm/multimodal/base.py @@ -1,6 +1,7 @@ from abc import ABC, abstractmethod from collections import defaultdict -from typing import (TYPE_CHECKING, Any, Callable, Dict, List, NamedTuple, +from pathlib import Path +from typing import (TYPE_CHECKING, Any, Callable, Generic, NamedTuple, Optional, Sequence, Tuple, Type, TypeVar, Union) from torch import nn @@ -118,7 +119,7 @@ def map_input( self, model_config: "ModelConfig", data: MultiModalData[Any], - mm_processor_kwargs: Optional[Dict[str, Any]], + mm_processor_kwargs: Optional[dict[str, Any]], ) -> MultiModalKwargs: """ Transform the data into a dictionary of model inputs using the @@ -254,10 +255,10 @@ class MultiModalPlaceholderMap: """ class IndexMap(NamedTuple): - src: List[int] - dest: List[int] + src: list[int] + dest: list[int] - src_ranges: List[range] + src_ranges: list[range] """ The indices of the multi-modal embeddings that will replace the corresponding placeholder embeddings pointed to by ``dest_ranges``. @@ -268,7 +269,7 @@ class IndexMap(NamedTuple): The total number of flattened multi-modal embeddings. """ - dest_ranges: List[range] + dest_ranges: list[range] """ The indices of the placeholder embeddings that will be replaced by the multimodal embeddings. @@ -288,7 +289,7 @@ def __init__(self): @classmethod def from_seq_group( cls, seq_group: "SequenceGroupMetadata", positions: range - ) -> Tuple[Optional[MultiModalDataDict], Dict[str, + ) -> Tuple[Optional[MultiModalDataDict], dict[str, "MultiModalPlaceholderMap"]]: """ Returns the multi-modal items that intersect with the portion of a @@ -376,9 +377,9 @@ def from_seq_group( def append_items_from_seq_group( self, positions: range, - multi_modal_items: List[_T], + multi_modal_items: list[_T], multi_modal_placeholders: Sequence[PlaceholderRange], - ) -> List[_T]: + ) -> list[_T]: """ Adds the multi-modal items that intersect ```positions`` to this placeholder map and returns the intersecting items. @@ -454,3 +455,22 @@ def index_map(self) -> "IndexMap": return MultiModalPlaceholderMap.IndexMap(src=src_indices, dest=dest_indices) + + +class MediaIO(ABC, Generic[_T]): + + @abstractmethod + def load_bytes(self, data: bytes) -> _T: + raise NotImplementedError + + @abstractmethod + def load_base64(self, media_type: str, data: str) -> _T: + """ + List of media types: + https://www.iana.org/assignments/media-types/media-types.xhtml + """ + raise NotImplementedError + + @abstractmethod + def load_file(self, filepath: Path) -> _T: + raise NotImplementedError diff --git a/vllm/multimodal/image.py b/vllm/multimodal/image.py index c705e1a3d1554..14c79dfadec0c 100644 --- a/vllm/multimodal/image.py +++ b/vllm/multimodal/image.py @@ -1,4 +1,7 @@ +import base64 from functools import lru_cache +from io import BytesIO +from pathlib import Path from typing import TYPE_CHECKING, Any, Dict, Optional import torch @@ -9,7 +12,7 @@ from vllm.transformers_utils.processor import get_image_processor from vllm.utils import is_list_of -from .base import MultiModalPlugin +from .base import MediaIO, MultiModalPlugin from .inputs import ImageItem, MultiModalData, MultiModalKwargs if TYPE_CHECKING: @@ -96,3 +99,39 @@ def rescale_image_size(image: Image.Image, if transpose >= 0: image = image.transpose(Image.Transpose(transpose)) return image + + +class ImageMediaIO(MediaIO[Image.Image]): + + def __init__(self, *, image_mode: str = "RGB") -> None: + super().__init__() + + self.image_mode = image_mode + + def load_bytes(self, data: bytes) -> Image.Image: + image = Image.open(BytesIO(data)) + image.load() + return image.convert(self.image_mode) + + def load_base64(self, media_type: str, data: str) -> Image.Image: + return self.load_bytes(base64.b64decode(data)) + + def load_file(self, filepath: Path) -> Image.Image: + image = Image.open(filepath) + image.load() + return image.convert(self.image_mode) + + def encode_base64( + self, + media: Image.Image, + *, + image_format: str = "JPEG", + ) -> str: + image = media + + with BytesIO() as buffer: + image = image.convert(self.image_mode) + image.save(buffer, image_format) + data = buffer.getvalue() + + return base64.b64encode(data).decode('utf-8') diff --git a/vllm/multimodal/utils.py b/vllm/multimodal/utils.py index a49da2bdee972..87b12a6fb33c1 100644 --- a/vllm/multimodal/utils.py +++ b/vllm/multimodal/utils.py @@ -1,8 +1,7 @@ -import base64 -import os from functools import lru_cache -from io import BytesIO -from typing import List, Optional, Tuple, TypeVar, Union +from pathlib import Path +from typing import Optional, TypeVar, Union +from urllib.parse import ParseResult, urlparse import numpy as np import numpy.typing as npt @@ -10,283 +9,246 @@ from PIL import Image import vllm.envs as envs -from vllm.connections import global_http_connection +from vllm.connections import HTTPConnection, global_http_connection from vllm.logger import init_logger from vllm.transformers_utils.tokenizer import AnyTokenizer, get_tokenizer -from vllm.utils import PlaceholderModule -from .inputs import MultiModalDataDict, PlaceholderRange - -try: - import decord -except ImportError: - decord = PlaceholderModule("decord") # type: ignore[assignment] - -try: - import librosa -except ImportError: - librosa = PlaceholderModule("librosa") # type: ignore[assignment] - -try: - import soundfile -except ImportError: - soundfile = PlaceholderModule("soundfile") # type: ignore[assignment] +from .audio import AudioMediaIO +from .base import MediaIO +from .image import ImageMediaIO +from .inputs import PlaceholderRange +from .video import VideoMediaIO logger = init_logger(__name__) cached_get_tokenizer = lru_cache(get_tokenizer) +_M = TypeVar("_M") -def _load_image_from_bytes(b: bytes) -> Image.Image: - image = Image.open(BytesIO(b)) - image.load() - return image - - -def _is_subpath(image_path: str, allowed_local_media_path: str) -> bool: - # Get the common path - common_path = os.path.commonpath([ - os.path.abspath(image_path), - os.path.abspath(allowed_local_media_path) - ]) - # Check if the common path is the same as allowed_local_media_path - return common_path == os.path.abspath(allowed_local_media_path) +class MediaConnector: -def _load_image_from_file(image_url: str, - allowed_local_media_path: str) -> Image.Image: - if not allowed_local_media_path: - raise ValueError("Invalid 'image_url': Cannot load local files without" - "'--allowed-local-media-path'.") - if allowed_local_media_path: - if not os.path.exists(allowed_local_media_path): - raise ValueError( - "Invalid '--allowed-local-media-path': " - f"The path {allowed_local_media_path} does not exist.") - if not os.path.isdir(allowed_local_media_path): + def __init__( + self, + connection: HTTPConnection = global_http_connection, + *, + allowed_local_media_path: str = "", + ) -> None: + super().__init__() + + self.connection = connection + + if allowed_local_media_path: + allowed_local_media_path_ = Path(allowed_local_media_path) + + if not allowed_local_media_path_.exists(): + raise ValueError( + "Invalid `--allowed-local-media-path`: The path " + f"{allowed_local_media_path_} does not exist.") + if not allowed_local_media_path_.is_dir(): + raise ValueError( + "Invalid `--allowed-local-media-path`: The path " + f"{allowed_local_media_path_} must be a directory.") + else: + allowed_local_media_path_ = None + + self.allowed_local_media_path = allowed_local_media_path_ + + def _load_data_url( + self, + url_spec: ParseResult, + media_io: MediaIO[_M], + ) -> _M: + data_spec, data = url_spec.path.split(",", 1) + media_type, data_type = data_spec.split(";", 1) + + if data_type != "base64": + msg = "Only base64 data URLs are supported for now." + raise NotImplementedError(msg) + + return media_io.load_base64(media_type, data) + + def _load_file_url( + self, + url_spec: ParseResult, + media_io: MediaIO[_M], + ) -> _M: + allowed_local_media_path = self.allowed_local_media_path + if allowed_local_media_path is None: + raise RuntimeError("Cannot load local files without " + "`--allowed-local-media-path`.") + + filepath = Path(url_spec.path) + if allowed_local_media_path not in filepath.resolve().parents: raise ValueError( - "Invalid '--allowed-local-media-path': " - f"The path {allowed_local_media_path} must be a directory.") - - # Only split once and assume the second part is the image path - _, image_path = image_url.split("file://", 1) - if not _is_subpath(image_path, allowed_local_media_path): - raise ValueError( - f"Invalid 'image_url': The file path {image_path} must" - " be a subpath of '--allowed-local-media-path'" - f" '{allowed_local_media_path}'.") - - image = Image.open(image_path) - image.load() - return image + f"The file path {filepath} must be a subpath " + f"of `--allowed-local-media-path` {allowed_local_media_path}.") + return media_io.load_file(filepath) -def _load_image_from_data_url(image_url: str) -> Image.Image: - # Only split once and assume the second part is the base64 encoded image - _, image_base64 = image_url.split(",", 1) - return load_image_from_base64(image_base64) - - -def fetch_image(image_url: str, - *, - image_mode: str = "RGB", - allowed_local_media_path: str = "") -> Image.Image: - """ - Load a PIL image from a HTTP or base64 data URL. - - By default, the image is converted into RGB format. - """ - if image_url.startswith('http'): - image_raw = global_http_connection.get_bytes( - image_url, - timeout=envs.VLLM_IMAGE_FETCH_TIMEOUT, - ) - image = _load_image_from_bytes(image_raw) - - elif image_url.startswith('data:image'): - image = _load_image_from_data_url(image_url) - elif image_url.startswith('file://'): - image = _load_image_from_file(image_url, allowed_local_media_path) - else: - raise ValueError("Invalid 'image_url': A valid 'image_url' must start " - "with either 'data:image', 'file://' or 'http'.") - - return image.convert(image_mode) - - -async def async_fetch_image(image_url: str, - *, - image_mode: str = "RGB", - allowed_local_media_path: str = "") -> Image.Image: - """ - Asynchronously load a PIL image from a HTTP or base64 data URL. - - By default, the image is converted into RGB format. - """ - if image_url.startswith('http'): - image_raw = await global_http_connection.async_get_bytes( - image_url, - timeout=envs.VLLM_IMAGE_FETCH_TIMEOUT, - ) - image = _load_image_from_bytes(image_raw) - - elif image_url.startswith('data:image'): - image = _load_image_from_data_url(image_url) - elif image_url.startswith('file://'): - image = _load_image_from_file(image_url, allowed_local_media_path) - else: - raise ValueError("Invalid 'image_url': A valid 'image_url' must start " - "with either 'data:image', 'file://' or 'http'.") + def load_from_url( + self, + url: str, + media_io: MediaIO[_M], + *, + fetch_timeout: Optional[int] = None, + ) -> _M: + url_spec = urlparse(url) - return image.convert(image_mode) + if url_spec.scheme.startswith("http"): + connection = self.connection + data = connection.get_bytes(url, timeout=fetch_timeout) + return media_io.load_bytes(data) -def _load_video_from_bytes(b: bytes, num_frames: int = 32) -> npt.NDArray: - video_path = BytesIO(b) - vr = decord.VideoReader(video_path, num_threads=1) - total_frame_num = len(vr) + if url_spec.scheme == "data": + return self._load_data_url(url_spec, media_io) - if total_frame_num > num_frames: - uniform_sampled_frames = np.linspace(0, - total_frame_num - 1, - num_frames, - dtype=int) - frame_idx = uniform_sampled_frames.tolist() - else: - frame_idx = [i for i in range(0, total_frame_num)] - frames = vr.get_batch(frame_idx).asnumpy() + if url_spec.scheme == "file": + return self._load_file_url(url_spec, media_io) - return frames + msg = "The URL must be either a HTTP, data or file URL." + raise ValueError(msg) + async def load_from_url_async( + self, + url: str, + media_io: MediaIO[_M], + *, + fetch_timeout: Optional[int] = None, + ) -> _M: + url_spec = urlparse(url) -def _load_video_from_data_url(video_url: str) -> npt.NDArray: - # Only split once and assume the second part is the base64 encoded video - _, video_base64 = video_url.split(",", 1) + if url_spec.scheme.startswith("http"): + connection = self.connection + data = await connection.async_get_bytes(url, timeout=fetch_timeout) - if video_url.startswith("data:video/jpeg;"): - return np.stack([ - np.array(load_image_from_base64(frame_base64)) - for frame_base64 in video_base64.split(",") - ]) + return media_io.load_bytes(data) - return load_video_from_base64(video_base64) + if url_spec.scheme == "data": + return self._load_data_url(url_spec, media_io) + if url_spec.scheme == "file": + return self._load_file_url(url_spec, media_io) -def fetch_video(video_url: str, *, num_frames: int = 32) -> npt.NDArray: - """ - Load video from a HTTP or base64 data URL. - """ - if video_url.startswith('http') or video_url.startswith('https'): - video_raw = global_http_connection.get_bytes( - video_url, - timeout=envs.VLLM_VIDEO_FETCH_TIMEOUT, - ) - video = _load_video_from_bytes(video_raw, num_frames) - elif video_url.startswith('data:video'): - video = _load_video_from_data_url(video_url) - else: - raise ValueError("Invalid 'video_url': A valid 'video_url' must start " - "with either 'data:video' or 'http'.") - return video + msg = "The URL must be either a HTTP, data or file URL." + raise ValueError(msg) + def fetch_audio( + self, + audio_url: str, + ) -> tuple[np.ndarray, Union[int, float]]: + """ + Load audio from a URL. + """ + audio_io = AudioMediaIO() -async def async_fetch_video(video_url: str, - *, - num_frames: int = 32) -> npt.NDArray: - """ - Asynchronously load video from a HTTP or base64 data URL. - - By default, the image is converted into RGB format. - """ - if video_url.startswith('http') or video_url.startswith('https'): - video_raw = await global_http_connection.async_get_bytes( - video_url, - timeout=envs.VLLM_VIDEO_FETCH_TIMEOUT, - ) - video = _load_video_from_bytes(video_raw, num_frames) - elif video_url.startswith('data:video'): - video = _load_video_from_data_url(video_url) - else: - raise ValueError("Invalid 'video_url': A valid 'video_url' must start " - "with either 'data:video' or 'http'.") - return video - - -def fetch_audio(audio_url: str) -> Tuple[np.ndarray, Union[int, float]]: - """ - Load audio from a URL. - """ - if audio_url.startswith("http"): - audio_bytes = global_http_connection.get_bytes( + return self.load_from_url( audio_url, - timeout=envs.VLLM_AUDIO_FETCH_TIMEOUT, + audio_io, + fetch_timeout=envs.VLLM_AUDIO_FETCH_TIMEOUT, ) - elif audio_url.startswith("data:audio"): - _, audio_base64 = audio_url.split(",", 1) - audio_bytes = base64.b64decode(audio_base64) - else: - raise ValueError("Invalid 'audio_url': A valid 'audio_url' must start " - "with either 'data:audio' or 'http'.") - - return librosa.load(BytesIO(audio_bytes), sr=None) + async def fetch_audio_async( + self, + audio_url: str, + ) -> tuple[np.ndarray, Union[int, float]]: + """ + Asynchronously fetch audio from a URL. + """ + audio_io = AudioMediaIO() -async def async_fetch_audio( - audio_url: str) -> Tuple[np.ndarray, Union[int, float]]: - """ - Asynchronously fetch audio from a URL. - """ - if audio_url.startswith("http"): - audio_bytes = await global_http_connection.async_get_bytes( + return await self.load_from_url_async( audio_url, - timeout=envs.VLLM_AUDIO_FETCH_TIMEOUT, + audio_io, + fetch_timeout=envs.VLLM_AUDIO_FETCH_TIMEOUT, ) - elif audio_url.startswith("data:audio"): - _, audio_base64 = audio_url.split(",", 1) - audio_bytes = base64.b64decode(audio_base64) - else: - raise ValueError("Invalid 'audio_url': A valid 'audio_url' must start " - "with either 'data:audio' or 'http'.") - - return librosa.load(BytesIO(audio_bytes), sr=None) + def fetch_image( + self, + image_url: str, + *, + image_mode: str = "RGB", + ) -> Image.Image: + """ + Load a PIL image from a HTTP or base64 data URL. -def get_and_parse_audio(audio_url: str) -> MultiModalDataDict: - audio, sr = fetch_audio(audio_url) - return {"audio": (audio, sr)} + By default, the image is converted into RGB format. + """ + image_io = ImageMediaIO(image_mode=image_mode) + return self.load_from_url( + image_url, + image_io, + fetch_timeout=envs.VLLM_IMAGE_FETCH_TIMEOUT, + ) -def get_and_parse_image( + async def fetch_image_async( + self, image_url: str, *, - allowed_local_media_path: str = "") -> MultiModalDataDict: - image = fetch_image(image_url, - allowed_local_media_path=allowed_local_media_path) - return {"image": image} - + image_mode: str = "RGB", + ) -> Image.Image: + """ + Asynchronously load a PIL image from a HTTP or base64 data URL. -def get_and_parse_video(video_url: str) -> MultiModalDataDict: - video = fetch_video(video_url) - return {"video": video} + By default, the image is converted into RGB format. + """ + image_io = ImageMediaIO(image_mode=image_mode) + return await self.load_from_url_async( + image_url, + image_io, + fetch_timeout=envs.VLLM_IMAGE_FETCH_TIMEOUT, + ) -async def async_get_and_parse_audio(audio_url: str) -> MultiModalDataDict: - audio, sr = await async_fetch_audio(audio_url) - return {"audio": (audio, sr)} - + def fetch_video( + self, + video_url: str, + *, + image_mode: str = "RGB", + num_frames: int = 32, + ) -> npt.NDArray: + """ + Load video from a HTTP or base64 data URL. + """ + image_io = ImageMediaIO(image_mode=image_mode) + video_io = VideoMediaIO(image_io, num_frames=num_frames) + + return self.load_from_url( + video_url, + video_io, + fetch_timeout=envs.VLLM_VIDEO_FETCH_TIMEOUT, + ) -async def async_get_and_parse_image( - image_url: str, + async def fetch_video_async( + self, + video_url: str, *, - allowed_local_media_path: str = "") -> MultiModalDataDict: - image = await async_fetch_image( - image_url, allowed_local_media_path=allowed_local_media_path) - return {"image": image} + image_mode: str = "RGB", + num_frames: int = 32, + ) -> npt.NDArray: + """ + Asynchronously load video from a HTTP or base64 data URL. + + By default, the image is converted into RGB format. + """ + image_io = ImageMediaIO(image_mode=image_mode) + video_io = VideoMediaIO(image_io, num_frames=num_frames) + + return await self.load_from_url_async( + video_url, + video_io, + fetch_timeout=envs.VLLM_VIDEO_FETCH_TIMEOUT, + ) -async def async_get_and_parse_video(video_url: str) -> MultiModalDataDict: - video = await async_fetch_video(video_url) - return {"video": video} +global_media_connector = MediaConnector() +"""The global :class:`MediaConnector` instance used by vLLM.""" + +fetch_audio = global_media_connector.fetch_audio +fetch_image = global_media_connector.fetch_image +fetch_video = global_media_connector.fetch_video def encode_audio_base64( @@ -294,10 +256,8 @@ def encode_audio_base64( sampling_rate: int, ) -> str: """Encode audio as base64.""" - buffered = BytesIO() - soundfile.write(buffered, audio, sampling_rate, format="WAV") - - return base64.b64encode(buffered.getvalue()).decode('utf-8') + audio_io = AudioMediaIO() + return audio_io.encode_base64((audio, sampling_rate)) def encode_image_base64( @@ -311,29 +271,14 @@ def encode_image_base64( By default, the image is converted into RGB format before being encoded. """ - buffered = BytesIO() - image = image.convert(image_mode) - image.save(buffered, format) - return base64.b64encode(buffered.getvalue()).decode('utf-8') - - -def load_image_from_base64(image: Union[bytes, str]) -> Image.Image: - """Load image from base64 format.""" - return _load_image_from_bytes(base64.b64decode(image)) + image_io = ImageMediaIO(image_mode=image_mode) + return image_io.encode_base64(image, image_format=format) def encode_video_base64(frames: npt.NDArray) -> str: - base64_frames = [] - frames_list = [frames[i] for i in range(frames.shape[0])] - for frame in frames_list: - img_base64 = encode_image_base64(Image.fromarray(frame)) - base64_frames.append(img_base64) - return ",".join(base64_frames) - - -def load_video_from_base64(video: Union[bytes, str]) -> npt.NDArray: - """Load video from base64 format.""" - return _load_video_from_bytes(base64.b64decode(video)) + image_io = ImageMediaIO() + video_io = VideoMediaIO(image_io) + return video_io.encode_base64(frames) def resolve_visual_encoder_outputs( @@ -389,7 +334,7 @@ def repeat_and_pad_token( repeat_count: int = 1, pad_token_left: Optional[_T] = None, pad_token_right: Optional[_T] = None, -) -> List[_T]: +) -> list[_T]: replacement = [token] * repeat_count if pad_token_left is not None: replacement = [pad_token_left] + replacement @@ -402,13 +347,13 @@ def repeat_and_pad_token( def repeat_and_pad_placeholder_tokens( tokenizer: AnyTokenizer, prompt: Optional[str], - prompt_token_ids: List[int], + prompt_token_ids: list[int], *, placeholder_token_id: int, - repeat_count: Union[int, List[int]], + repeat_count: Union[int, list[int]], pad_token_left: Optional[int] = None, pad_token_right: Optional[int] = None, -) -> Tuple[Optional[str], List[int], List[PlaceholderRange]]: +) -> tuple[Optional[str], list[int], list[PlaceholderRange]]: if isinstance(repeat_count, int): repeat_count = [repeat_count] @@ -450,8 +395,8 @@ def repeat_and_pad_placeholder_tokens( new_prompt += prompt_parts[i] + replacement_str new_prompt += prompt_parts[-1] - new_token_ids: List[int] = [] - placeholder_ranges: List[PlaceholderRange] = [] + new_token_ids = list[int]() + placeholder_ranges = list[PlaceholderRange]() placeholder_token_idx = 0 for i, token in enumerate(prompt_token_ids): if token == placeholder_token_id: @@ -481,7 +426,7 @@ def repeat_and_pad_placeholder_tokens( def consecutive_placeholder_ranges( num_items: int, item_size: int, - initial_offset: int = 0) -> List[PlaceholderRange]: + initial_offset: int = 0) -> list[PlaceholderRange]: """Returns a list of consecutive PlaceholderRanges of a fixed size""" return [ diff --git a/vllm/multimodal/video.py b/vllm/multimodal/video.py index c4be100562703..b7d43c830cc46 100644 --- a/vllm/multimodal/video.py +++ b/vllm/multimodal/video.py @@ -1,23 +1,32 @@ -from functools import lru_cache +import base64 +from functools import lru_cache, partial +from io import BytesIO +from pathlib import Path from typing import TYPE_CHECKING, Any, Dict, Optional import cv2 import numpy as np import numpy.typing as npt +from PIL import Image from vllm.inputs.registry import InputContext from vllm.logger import init_logger from vllm.transformers_utils.processor import get_video_processor from vllm.transformers_utils.tokenizer import get_tokenizer -from vllm.utils import is_list_of +from vllm.utils import PlaceholderModule, is_list_of -from .base import MultiModalData -from .image import ImagePlugin +from .base import MediaIO, MultiModalData +from .image import ImageMediaIO, ImagePlugin from .inputs import MultiModalKwargs, VideoItem if TYPE_CHECKING: from vllm.config import ModelConfig +try: + import decord +except ImportError: + decord = PlaceholderModule("decord") # type: ignore[assignment] + logger = init_logger(__name__) cached_get_video_processor = lru_cache(get_video_processor) @@ -107,3 +116,73 @@ def sample_frames_from_video(frames: npt.NDArray, frame_indices = np.linspace(0, total_frames - 1, num_frames, dtype=int) sampled_frames = frames[frame_indices, ...] return sampled_frames + + +class VideoMediaIO(MediaIO[npt.NDArray]): + + def __init__( + self, + image_io: ImageMediaIO, + *, + num_frames: int = 32, + ) -> None: + super().__init__() + + self.image_io = image_io + self.num_frames = num_frames + + def load_bytes(self, data: bytes) -> npt.NDArray: + vr = decord.VideoReader(BytesIO(data), num_threads=1) + total_frame_num = len(vr) + + num_frames = self.num_frames + if total_frame_num > num_frames: + uniform_sampled_frames = np.linspace(0, + total_frame_num - 1, + num_frames, + dtype=int) + frame_idx = uniform_sampled_frames.tolist() + else: + frame_idx = list(range(0, total_frame_num)) + + return vr.get_batch(frame_idx).asnumpy() + + def load_base64(self, media_type: str, data: str) -> npt.NDArray: + if media_type.lower() == "video/jpeg": + load_frame = partial( + self.image_io.load_base64, + "image/jpeg", + ) + + return np.stack([ + np.array(load_frame(frame_data)) + for frame_data in data.split(",") + ]) + + return self.load_bytes(base64.b64decode(data)) + + def load_file(self, filepath: Path) -> npt.NDArray: + with filepath.open("rb") as f: + data = f.read() + + return self.load_bytes(data) + + def encode_base64( + self, + media: npt.NDArray, + *, + video_format: str = "JPEG", + ) -> str: + video = media + + if video_format == "JPEG": + encode_frame = partial( + self.image_io.encode_base64, + image_format=video_format, + ) + + return ",".join( + encode_frame(Image.fromarray(frame)) for frame in video) + + msg = "Only JPEG format is supported for now." + raise NotImplementedError(msg)