Skip to content

Commit

Permalink
[V1] Support audio language models on V1 (#11733)
Browse files Browse the repository at this point in the history
Signed-off-by: Roger Wang <[email protected]>
  • Loading branch information
ywang96 authored Jan 7, 2025
1 parent 869e829 commit 2de197b
Show file tree
Hide file tree
Showing 3 changed files with 29 additions and 12 deletions.
4 changes: 2 additions & 2 deletions docs/source/models/supported_models.md
Original file line number Diff line number Diff line change
Expand Up @@ -710,7 +710,7 @@ See [this page](#generative-models) for more information on how to use generativ
- `Qwen/Qwen2-Audio-7B-Instruct`
-
- ✅︎
-
- ✅︎
* - `Qwen2VLForConditionalGeneration`
- Qwen2-VL
- T + I<sup>E+</sup> + V<sup>E+</sup>
Expand All @@ -724,7 +724,7 @@ See [this page](#generative-models) for more information on how to use generativ
- `fixie-ai/ultravox-v0_3`
-
- ✅︎
-
- ✅︎
```

<sup>E</sup> Pre-computed embeddings can be inputted for this modality.
Expand Down
9 changes: 6 additions & 3 deletions vllm/model_executor/models/qwen2_audio.py
Original file line number Diff line number Diff line change
Expand Up @@ -335,13 +335,16 @@ def _process_audio_input(self,
selected_audio_feature = audio_outputs.last_hidden_state
audio_features = self.multi_modal_projector(selected_audio_feature)
num_audios, max_audio_tokens, embed_dim = audio_features.shape
audio_output_lengths = audio_output_lengths.unsqueeze(1)
audio_features_mask = torch.arange(max_audio_tokens).expand(
num_audios, max_audio_tokens
).to(audio_output_lengths.device) < audio_output_lengths.unsqueeze(1)
num_audios, max_audio_tokens).to(
audio_output_lengths.device) < audio_output_lengths
masked_audio_features = audio_features[audio_features_mask].view(
-1, embed_dim)

return masked_audio_features
# Split to tuple of embeddings for individual audio input.
return torch.split(masked_audio_features,
audio_output_lengths.flatten().tolist())

def get_multimodal_embeddings(self, **kwargs) -> Optional[NestedTensors]:
audio_input = self._parse_and_validate_audio_input(**kwargs)
Expand Down
28 changes: 21 additions & 7 deletions vllm/model_executor/models/ultravox.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
# Adapted from https://github.com/fixie-ai/ultravox/blob/ecd58c4041030bae2ad15aa6bcf04ab43199ea02/ultravox/model/ultravox_model.py
"""PyTorch Ultravox model."""

import math
from functools import cached_property
from typing import (Any, Iterable, List, Literal, Mapping, Optional, Set,
Expand All @@ -14,6 +13,7 @@
from transformers.models.whisper import WhisperFeatureExtractor
from transformers.models.whisper.modeling_whisper import WhisperEncoder

from vllm import envs
from vllm.attention import AttentionMetadata
from vllm.config import VllmConfig
from vllm.model_executor.layers.activation import SiluAndMul, get_act_fn
Expand All @@ -35,8 +35,11 @@
from .interfaces import SupportsMultiModal, SupportsPP
from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn,
init_vllm_registered_model, maybe_prefix,
merge_multimodal_embeddings,
merge_multimodal_embeddings_from_map)

_AUDIO_PLACEHOLDER_OVERRIDE = "<|reserved_special_token_0|>"
_AUDIO_PLACEHOLDER_TOKEN = 128002
_AUDIO_TOKENS_PER_SECOND = 6.25


Expand Down Expand Up @@ -64,7 +67,14 @@ def _get_hf_processor(
# Ignored in initialization
sampling_rate: Optional[int] = None,
) -> ProcessorMixin:
return self.ctx.get_hf_processor()
hf_processor = self.ctx.get_hf_processor()

# NOTE: Ultravox processing definition uses '<|eot_id|>' as the
# placeholder that will cause confusion with the actual end of turn
# token, thus we override placeholder with a reserved special
# token.
hf_processor.audio_token_replacement = _AUDIO_PLACEHOLDER_OVERRIDE
return hf_processor

def _get_feature_extractor(
self,
Expand Down Expand Up @@ -465,11 +475,15 @@ def get_input_embeddings(
inputs_embeds = self.language_model.get_input_embeddings(input_ids)
if multimodal_embeddings is not None:

# TODO(ywang96): use merge_multimodal_embeddings after
# v0 is deprecated
merge_multimodal_embeddings_from_map(
inputs_embeds, multimodal_embeddings,
attn_metadata.multi_modal_placeholder_index_maps["audio"])
# TODO(ywang96): remove this block after v0 is deprecated.
if not envs.VLLM_USE_V1:
merge_multimodal_embeddings_from_map(
inputs_embeds, multimodal_embeddings,
attn_metadata.multi_modal_placeholder_index_maps["audio"])
else:
inputs_embeds = merge_multimodal_embeddings(
input_ids, inputs_embeds, multimodal_embeddings,
_AUDIO_PLACEHOLDER_TOKEN)
return inputs_embeds

def forward(self,
Expand Down

0 comments on commit 2de197b

Please sign in to comment.