Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[VLM][Bugfix] Multi-modal processor compatible with V1 multi-input #11674

Merged
merged 3 commits into from
Jan 2, 2025
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
246 changes: 110 additions & 136 deletions vllm/multimodal/inputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@
from collections import UserDict, defaultdict
from collections.abc import Mapping, Sequence
from dataclasses import dataclass
from typing import Any, Literal, TypedDict, TypeVar, Union, cast, final
from typing import (Any, Literal, Optional, TypedDict, TypeVar, Union, cast,
final)

import numpy as np
import torch
Expand All @@ -11,7 +12,7 @@
from transformers import BatchFeature
from typing_extensions import NotRequired, TypeAlias

from vllm.utils import JSONTree, is_list_of, json_map_leaves
from vllm.utils import JSONTree, full_groupby, is_list_of, json_map_leaves

_T = TypeVar("_T")

Expand Down Expand Up @@ -160,11 +161,8 @@ def nested_tensors_equal(a: NestedTensors, b: NestedTensors) -> bool:


@dataclass(frozen=True)
class MultiModalFieldItem:
"""
Contains metadata and data in :class:`MultiModalKwargs`
corresponding to a data item in :class:`MultiModalDataItems`.
"""
class MultiModalFieldElem:
"""Contains metadata and data of an item in :class:`MultiModalKwargs`."""
field: "BaseMultiModalField"
data: NestedTensors

Expand All @@ -186,34 +184,34 @@ class BaseMultiModalField(ABC):
def _reduce_data(self, batch: list[NestedTensors]) -> NestedTensors:
raise NotImplementedError

def _build_item(self, data: NestedTensors) -> MultiModalFieldItem:
return MultiModalFieldItem(self, data)
def _build_elem(self, data: NestedTensors) -> MultiModalFieldElem:
return MultiModalFieldElem(self, data)

def reduce(self, batch: list[MultiModalFieldItem]) -> MultiModalFieldItem:
"""Merge multiple instances of :class:`MultiModalFieldItem` together."""
def reduce(self, batch: list[MultiModalFieldElem]) -> MultiModalFieldElem:
"""Merge multiple instances of :class:`MultiModalFieldElem` together."""
fields = [item.field for item in batch]
if len(set(fields)) > 1:
raise ValueError(f"Cannot merge different {fields=}")

data = self._reduce_data([item.data for item in batch])

return self._build_item(data)
return self._build_elem(data)


@dataclass(frozen=True)
class MultiModalBatchedField(BaseMultiModalField):
"""
A :class:`BaseMultiModalField` implementation where an item is obtained by
directly indexing into the first dimension of the underlying data.
A :class:`BaseMultiModalField` implementation where an element in the batch
is obtained by indexing into the first dimension of the underlying data.
"""

def build_items(self, batch: NestedTensors) -> list[MultiModalFieldItem]:
return [self._build_item(item) for item in batch]
def build_elems(self, batch: NestedTensors) -> list[MultiModalFieldElem]:
return [self._build_elem(item) for item in batch]

def _reduce_data(self, batch: list[NestedTensors]) -> NestedTensors:
if len(batch) > 0 and is_list_of(batch, torch.Tensor, check="all"):
first_shape = batch[0].shape
if all(item.shape == first_shape for item in batch):
if all(elem.shape == first_shape for elem in batch):
return torch.stack(batch)

return batch
Expand All @@ -222,24 +220,24 @@ def _reduce_data(self, batch: list[NestedTensors]) -> NestedTensors:
@dataclass(frozen=True)
class MultiModalFlatField(BaseMultiModalField):
"""
A :class:`BaseMultiModalField` implementation where an item is obtained by
slicing along the first dimension of the underlying data.
A :class:`BaseMultiModalField` implementation where an element in the batch
is obtained by slicing along the first dimension of the underlying data.
"""

def build_items(
def build_elems(
self,
batch: NestedTensors,
slices: Sequence[slice],
) -> list[MultiModalFieldItem]:
return [self._build_item(batch[slice_]) for slice_ in slices]
) -> list[MultiModalFieldElem]:
return [self._build_elem(batch[slice_]) for slice_ in slices]

def _reduce_data(self, batch: list[NestedTensors]) -> NestedTensors:
if len(batch) > 0 and is_list_of(batch, torch.Tensor, check="all"):
first_shape = batch[0].shape
if all(item.shape[1:] == first_shape[1:] for item in batch):
if all(elem.shape[1:] == first_shape[1:] for elem in batch):
return torch.concat(batch)

return [elem for item in batch for elem in item]
return [e for elem in batch for e in elem]


class MultiModalFieldConfig:
Expand Down Expand Up @@ -267,115 +265,111 @@ def __init__(
) -> None:
super().__init__()

self._field_cls = field_cls
self._modality = modality
self._field_config = field_config
self.field_cls = field_cls
self.modality = modality
self.field_config = field_config
ywang96 marked this conversation as resolved.
Show resolved Hide resolved

def build_items(
def build_elems(
self,
key: str,
batch: NestedTensors,
) -> list[MultiModalFieldItem]:
field = self._field_cls(key=key, modality=self._modality)
return field.build_items(batch, **self._field_config) # type: ignore
) -> list[MultiModalFieldElem]:
field = self.field_cls(key=key, modality=self.modality)
return field.build_elems(batch, **self.field_config) # type: ignore


class MultiModalKwargs(UserDict[str, NestedTensors]):
class MultiModalKwargsItem(UserDict[str, MultiModalFieldElem]):
"""
A collection of :class:`MultiModalFieldElem`
corresponding to a data item in :class:`MultiModalDataItems`.
"""
A dictionary that represents the keyword arguments to
:meth:`~torch.nn.Module.forward`.

The metadata :code:`items_by_key` defines how to split batched keyword
arguments corresponding to each data item in :class:`MultiModalDataItems`:

- For a keyword argument, we can access the :code:`i` th item in the batch
via :code:`items_by_key[key][i]`.
- We can gather the keyword arguments belonging to a modality by finding
the keys with items that belong to that modality, then accessing
the :code:`i` th item in the batch for each such key.
@staticmethod
def from_elems(elems: list[MultiModalFieldElem]) -> "MultiModalKwargsItem":
return MultiModalKwargsItem({elem.field.key: elem for elem in elems})

Example:
@property
def modality(self) -> str:
modalities = {elem.field.modality for elem in self.data.values()}
assert len(modalities) == 1, f"Found different modalities={modalities}"
return next(iter(modalities))

.. code-block:: python

# All items belong to the "image" modality
items_by_key={
"pixel_values": [a, b, c, d], # "image" modality
"image_grid_thw": [e, f, g, h], # "image" modality
"pixel_values_video": [h, i, j], # "video" modality
"video_grid_thw": [k, l, m], # "video" modality
}
# NOTE: UserDict is for V0 compatibility.
# V1 should access individual items via `get_item`.
class MultiModalKwargs(UserDict[str, NestedTensors]):
"""
A dictionary that represents the keyword arguments to
:meth:`~torch.nn.Module.forward`.

- The keyword arguments belonging to the first image are
:code:`{"pixel_values": a, "image_grid_thw": e}`.
- The keyword arguments belonging to the second video are
:code:`{"pixel_values_video": i, "video_grid_thw": l}`.
The metadata :code:`items` enables us to obtain the keyword arguments
corresponding to each data item in :class:`MultiModalDataItems`, via
:meth:`get_num_items` and :meth:`get_item`.
"""

@staticmethod
def from_hf_inputs(
hf_inputs: BatchFeature,
config_by_key: Mapping[str, MultiModalFieldConfig],
*,
enable_sanity_checks: bool = False,
):
# NOTE: This skips fields in `hf_inputs` that are not in `config_by_key`
# We assume that those fields are not used in vLLM
items_by_key = {
key: config.build_items(key, batch)
for key, config in config_by_key.items()
if (batch := hf_inputs.get(key)) is not None
}

return MultiModalKwargs.from_items_by_key(
items_by_key,
enable_sanity_checks=enable_sanity_checks,
)
elems_by_key = dict[str, list[MultiModalFieldElem]]()
keys_by_modality = defaultdict[str, set[str]](set)
for key, config in config_by_key.items():
batch = hf_inputs.get(key)
if batch is not None:
elems = config.build_elems(key, batch)
if len(elems) > 0:
elems_by_key[key] = elems
keys_by_modality[config.modality].add(key)

items = list[MultiModalKwargsItem]()
for modality, keys in keys_by_modality.items():
elems_in_modality = {k: elems_by_key[k] for k in keys}
batch_sizes = {k: len(v) for k, v in elems_in_modality.items()}

if len(set(batch_sizes.values())) > 1:
raise ValueError(
f"Cannot merge different batch sizes for {modality=}! "
f"Found: {batch_sizes=}")

batch_size = next(iter(batch_sizes.values()))
for item_idx in range(batch_size):
elems = [v[item_idx] for v in elems_in_modality.values()]
items.append(MultiModalKwargsItem.from_elems(elems))

return MultiModalKwargs.from_items(items)

@staticmethod
def from_items_by_key(
items_by_key: Mapping[str, list[MultiModalFieldItem]],
*,
enable_sanity_checks: bool = False,
) -> "MultiModalKwargs":
def from_items(items: list[MultiModalKwargsItem]) -> "MultiModalKwargs":
"""Construct a new :class:`MultiModalKwargs` from multiple items."""
elems_by_key = defaultdict[str, list[MultiModalFieldElem]](list)
for item in items:
for key, elem in item.items():
elems_by_key[key].append(elem)

data = {
key: items[0].field.reduce(items).data
for key, items in items_by_key.items() if len(items) > 0
key: elems[0].field.reduce(elems).data
for key, elems in elems_by_key.items() if len(elems) > 0
}

return MultiModalKwargs(data,
items_by_key=items_by_key,
enable_sanity_checks=enable_sanity_checks)
return MultiModalKwargs(data, items=items)

def __init__(
self,
data: Mapping[str, NestedTensors],
*,
items_by_key: Mapping[str, list[MultiModalFieldItem]] = {},
enable_sanity_checks: bool = False,
items: Optional[Sequence[MultiModalKwargsItem]] = None,
) -> None:
super().__init__(data)

# Shallow copy to avoid footgun in case a defaultdict is passed in
self._items_by_key = dict(items_by_key)
items_by_modality = full_groupby(items or [], key=lambda x: x.modality)
self._items_by_modality = dict(items_by_modality)

keys_by_modality = defaultdict[str, set[str]](set)
for key, items in items_by_key.items():
for item in items:
keys_by_modality[item.field.modality].add(key)

self._keys_by_modality = dict(keys_by_modality)

if enable_sanity_checks:
for modality, keys in keys_by_modality.items():
items_in_modality = {k: items_by_key[k] for k in keys}
batch_sizes = {k: len(v) for k, v in items_in_modality.items()}
batch_size = next(iter(batch_sizes.values()), 0)
assert all(bs == batch_size
for bs in batch_sizes.values()), dict(
modality=modality,
batch_sizes=batch_sizes,
items_by_key=items_by_key)
@property
def modalities(self):
return self._items_by_modality.keys()
Comment on lines +370 to +372
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
@property
def modalities(self):
return self._items_by_modality.keys()
@property
def modalities(self):
return list(self._items_by_modality.keys())

This is probably more intuitive?


@staticmethod
def _try_stack(nested_tensors: NestedTensors) -> NestedTensors:
Expand Down Expand Up @@ -452,58 +446,38 @@ def as_kwargs(
def __eq__(self, other: object) -> bool:
if not isinstance(other, self.__class__):
return False
if self._items_by_key != other._items_by_key:
if self._items_by_modality != other._items_by_modality:
return False

ks = self.keys()
return (ks == other.keys()
and all(nested_tensors_equal(self[k], other[k]) for k in ks))

def get_item(self, key: str, item_index: int) -> MultiModalFieldItem:
return self._items_by_key[key][item_index]
def get_num_items(self, modality: str) -> int:
"""Get the number of items belonging to a modality."""
if not self._items_by_modality:
raise RuntimeError(
"`get_num_items` is not supported when "
"MultiModalKwargs is not initialized with `items`")

def get_items_by_modality(
self,
modality: str,
item_index: int,
) -> Mapping[str, MultiModalFieldItem]:
return len(self._items_by_modality[modality])

def get_item(self, modality: str, item_index: int) -> MultiModalKwargsItem:
"""
Get the keyword arguments corresponding to an item identified by
its modality and index.
"""
if modality not in self._keys_by_modality:
available_modalities = set(self._keys_by_modality.keys())
if not self._items_by_modality:
raise RuntimeError(
"`get_item` is not supported when "
"MultiModalKwargs is not initialized with `items`")
ywang96 marked this conversation as resolved.
Show resolved Hide resolved

if modality not in self._items_by_modality:
available_modalities = set(self._items_by_modality.keys())
raise KeyError(f"Modality {modality!r} not found. "
f"Available modalities: {available_modalities}")

keys_to_gather = self._keys_by_modality[modality]

return {
key: self.get_item(key, item_index)
for key in keys_to_gather if key in self
}

@staticmethod
def from_items_by_modality(
items_by_modality: Mapping[str, list[Mapping[str,
MultiModalFieldItem]]],
*,
enable_sanity_checks: bool = False,
) -> "MultiModalKwargs":
"""
Construct a new :class:`MultiModalKwargs` from multiple items returned
by :meth:`get_fields_by_modality`.
"""
items_by_key = defaultdict[str, list[MultiModalFieldItem]](list)
for fields in items_by_modality.values():
for field in fields:
for k, v in field.items():
items_by_key[k].append(v)

return MultiModalKwargs.from_items_by_key(
items_by_key,
enable_sanity_checks=enable_sanity_checks,
)
return self._items_by_modality[modality][item_index]


MultiModalPlaceholderDict = Mapping[str, Sequence[PlaceholderRange]]
Expand Down
Loading
Loading