Commit
some cleanup
sfc-gh-aqiao committed Dec 18, 2024
1 parent 787708a commit e943905
Showing 9 changed files with 107 additions and 102 deletions.
1 change: 1 addition & 0 deletions requirements-common.txt
@@ -35,3 +35,4 @@ setuptools>=74.1.1; python_version > '3.11' # Setuptools is used by triton, we n
einops # Required for Qwen2-VL.
compressed-tensors == 0.8.0 # required for compressed-tensors
depyf==0.18.0 # required for profiling and debugging torch.compile
librosa >= 0.10.2 # required for audio processing including Whisper
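
librosa is added so audio inputs for Whisper can be decoded and resampled before feature extraction. As a minimal hedged sketch of that preprocessing step (the file name and the 16 kHz target are illustrative assumptions, not taken from this commit):

# Hedged sketch: decode an audio clip to the 16 kHz mono waveform that
# Whisper-style models expect. The path "sample.wav" is illustrative.
import librosa

waveform, sample_rate = librosa.load("sample.wav", sr=16000, mono=True)
print(waveform.shape, sample_rate)  # e.g. (480000,) 16000 for a 30-second clip
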
4 changes: 2 additions & 2 deletions tests/models/encoder_decoder/audio/test_whisper.py
@@ -35,8 +35,8 @@
EXPECTED = {
"openai/whisper-medium": [
" The first words I spoke in the original phonograph, a little piece"
" of practical poetry. Mary had a little lamb, its fleece was quite as"
" slow, and everywhere that Mary went the lamb was sure to go.",
" of practical poetry. Mary had a little lamb, its fleece was white as"
" snow, and everywhere that Mary went the lamb would shun it all.",
" And the old one pitch on the way to Edgar Martinez swung on the line"
" down the left field line for Obeysmith. Here comes Joy. Here is"
" Jorgen at third base. They're gonna wave him in. The throw to the"
1 change: 1 addition & 0 deletions tests/models/registry.py
@@ -124,6 +124,7 @@ class _HfExamplesInfo:
# [Encoder-decoder]
"BartModel": _HfExamplesInfo("facebook/bart-base"),
"BartForConditionalGeneration": _HfExamplesInfo("facebook/bart-large-cnn"),
"WhisperForConditionalGeneration": _HfExamplesInfo("openai/whisper-large-v3"), # noqa: E501
# Florence-2 uses BartFastTokenizer which can't be loaded from AutoTokenizer
# Therefore, we borrow the BartTokenizer from the original Bart model
"Florence2ForConditionalGeneration": _HfExamplesInfo("microsoft/Florence-2-base", # noqa: E501
23 changes: 17 additions & 6 deletions vllm/inputs/preprocess.py
@@ -184,10 +184,16 @@ def _tokenize_prompt(
corresponding token IDs.
"""
tokenizer = self.get_tokenizer_group()

add_special_tokens = None
if self.model_config.hf_config.model_type == "whisper":
# For Whisper, special tokens should be provided by the user based
# on the task and language of their request. Also needed to avoid
# appending an EOS token to the prompt which disrupts generation.
add_special_tokens = False
return tokenizer.encode(request_id=request_id,

Check failure on line 193 in vllm/inputs/preprocess.py (GitHub Actions / mypy 3.9, 3.10, 3.11, 3.12):
Unexpected keyword argument "add_special_tokens" for "encode" of "BaseTokenizerGroup" [call-arg]
prompt=prompt,
lora_request=lora_request)
lora_request=lora_request,
add_special_tokens=add_special_tokens)

async def _tokenize_prompt_async(
self,
@@ -197,10 +203,15 @@ async def _tokenize_prompt_async(
) -> List[int]:
"""Async version of :meth:`_tokenize_prompt`."""
tokenizer = self.get_tokenizer_group()

return await tokenizer.encode_async(request_id=request_id,
prompt=prompt,
lora_request=lora_request)
add_special_tokens = None
if self.model_config.hf_config.model_type == "whisper":
# For Whisper, special tokens should be provided by the user based
# on the task and language of their request. Also needed to avoid
# appending an EOS token to the prompt which disrupts generation.
add_special_tokens = False
return await tokenizer.encode_async(

Check failure on line 212 in vllm/inputs/preprocess.py (GitHub Actions / mypy 3.9, 3.10, 3.11, 3.12):
Unexpected keyword argument "add_special_tokens" for "encode_async" of "BaseTokenizerGroup" [call-arg]
request_id=request_id, prompt=prompt, lora_request=lora_request,
add_special_tokens=add_special_tokens)

def _can_process_multimodal(self) -> bool:
model_config = self.model_config
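
Because add_special_tokens is forced to False for Whisper, the decoder prompt has to spell out its own control tokens for task and language, and no EOS is appended. A hedged sketch of what a caller-side request could look like; the prompt layout and token strings below are illustrative assumptions, not part of this diff:

# Hedged sketch, not from this commit: assumed caller-side usage in which the
# decoder prompt carries the Whisper control tokens explicitly.
import librosa
from vllm import LLM, SamplingParams

audio, sr = librosa.load("sample.wav", sr=16000)  # illustrative input file
llm = LLM(model="openai/whisper-large-v3")

prompt = {
    "encoder_prompt": {
        "prompt": "",
        "multi_modal_data": {"audio": (audio, sr)},
    },
    # Task/language tokens are supplied by the caller; the tokenizer no longer
    # adds them (and no longer appends an EOS that would disrupt generation).
    "decoder_prompt": "<|startoftranscript|><|en|><|transcribe|><|notimestamps|>",
}
outputs = llm.generate(prompt, SamplingParams(temperature=0, max_tokens=200))
print(outputs[0].outputs[0].text)
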
128 changes: 61 additions & 67 deletions vllm/model_executor/models/whisper.py
@@ -6,16 +6,15 @@
import numpy as np
import torch
from torch import nn
from transformers import WhisperConfig, WhisperProcessor
from transformers import WhisperProcessor
from transformers.models.whisper.modeling_whisper import sinusoids

from vllm.attention import Attention, AttentionMetadata, AttentionType
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import (get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size)
from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData,
InputContext, token_inputs)
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.inputs import INPUT_REGISTRY, DummyData, InputContext
from vllm.logger import init_logger
from vllm.model_executor.layers.activation import FastGELU
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
QKVParallelLinear,
RowParallelLinear)
@@ -24,35 +23,17 @@
QuantizationConfig)
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.model_executor.model_loader.weight_utils import (
default_weight_loader, kv_cache_scales_loader)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs
from vllm.multimodal.utils import consecutive_placeholder_ranges
from vllm.sequence import SequenceData

from .interfaces import SupportsMultiModal
from .utils import AutoWeightsLoader, make_layers, maybe_prefix
from .utils import AutoWeightsLoader, make_layers, WeightsMapper

logger = init_logger(__name__)


def sinusoids(
length: int, channels: int, max_timescale: float = 10000
) -> torch.Tensor:
"""Returns sinusoids for positional embedding"""
if channels % 2 != 0:
raise ValueError(
f"Number of channels has to be divisible by 2 for sinusoidal "
f"positional embeddings, got {channels} channels."
)
log_timescale_increment = math.log(max_timescale) / (channels // 2 - 1)
inv_timescales = torch.exp(-log_timescale_increment *
torch.arange(channels // 2))
scaled_time = torch.arange(length).view(-1, 1) * inv_timescales.view(1, -1)
return torch.cat([scaled_time.sin(), scaled_time.cos()], dim=1)


class WhisperPositionalEmbedding(nn.Embedding):
def __init__(self, num_positions: int, embedding_dim: int,
padding_idx: Optional[int] = None):
@@ -216,6 +197,39 @@ def forward(
return output


class WhisperMLP(nn.Module):

def __init__(
self,
embed_dim: int,
ffn_dim: int,
act_fn: str,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
):
super().__init__()

self.activation_fn = get_act_fn(act_fn)
self.fc1 = ColumnParallelLinear(
input_size=embed_dim,
output_size=ffn_dim,
quant_config=quant_config,
prefix=f"{prefix}.fc1",
)
self.fc2 = RowParallelLinear(
input_size=ffn_dim,
output_size=embed_dim,
quant_config=quant_config,
prefix=f"{prefix}.fc2",
)

def forward(self, hidden_states: torch.Tensor):
hidden_states, _ = self.fc1(hidden_states)
hidden_states = self.activation_fn(hidden_states)
hidden_states, _ = self.fc2(hidden_states)
return hidden_states


class WhisperEncoderLayer(nn.Module):

def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
@@ -234,20 +248,12 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
prefix=f"{prefix}.self_attn",
)
self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
self.activation_fn = FastGELU()
self.fc1 = ColumnParallelLinear(
input_size = self.embed_dim,
output_size = config.encoder_ffn_dim,
bias = True,
self.mlp = WhisperMLP(
embed_dim=config.d_model,
ffn_dim=config.encoder_ffn_dim,
act_fn=config.activation_function,
quant_config=quant_config,
prefix=f"{prefix}.fc1",
)
self.fc2 = RowParallelLinear(
input_size = config.encoder_ffn_dim,
output_size = self.embed_dim,
bias = True,
quant_config=quant_config,
prefix=f"{prefix}.fc2",
prefix=f"{prefix}.mlp",
)
self.final_layer_norm = nn.LayerNorm(self.embed_dim)

@@ -267,9 +273,7 @@ def forward(
hidden_states = residual + hidden_states
residual = hidden_states
hidden_states = self.final_layer_norm(hidden_states)
hidden_states, _ = self.fc1(hidden_states)
hidden_states = self.activation_fn(hidden_states)
hidden_states, _ = self.fc2(hidden_states)
hidden_states = self.mlp(hidden_states)
hidden_states = residual + hidden_states

if hidden_states.dtype == torch.float16 and (
@@ -291,41 +295,31 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config

self.embed_dim = config.d_model
self.self_attn = WhisperAttention(
embed_dim=self.embed_dim,
embed_dim=config.d_model,
num_heads=config.decoder_attention_heads,
attn_type=AttentionType.DECODER,
cache_config=cache_config,
quant_config=quant_config,
prefix=f"{prefix}.self_attn",
)
self.activation_fn = FastGELU()

self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
self.self_attn_layer_norm = nn.LayerNorm(config.d_model)
self.encoder_attn = WhisperCrossAttention(
embed_dim=self.embed_dim,
embed_dim=config.d_model,
num_heads=config.decoder_attention_heads,
cache_config=cache_config,
quant_config=quant_config,
prefix=f"{prefix}.encoder_attn",
)
self.encoder_attn_layer_norm = nn.LayerNorm(self.embed_dim)
self.fc1 = ColumnParallelLinear(
input_size = self.embed_dim,
output_size = config.decoder_ffn_dim,
bias = True,
quant_config=quant_config,
prefix=f"{prefix}.fc1",
)
self.fc2 = RowParallelLinear(
input_size = config.decoder_ffn_dim,
output_size = self.embed_dim,
bias = True,
self.encoder_attn_layer_norm = nn.LayerNorm(config.d_model)
self.mlp = WhisperMLP(
embed_dim=config.d_model,
ffn_dim=config.decoder_ffn_dim,
act_fn=config.activation_function,
quant_config=quant_config,
prefix=f"{prefix}.fc2",
prefix=f"{prefix}.mlp",
)
self.final_layer_norm = nn.LayerNorm(self.embed_dim)
self.final_layer_norm = nn.LayerNorm(config.d_model)

def forward(
self,
Expand Down Expand Up @@ -355,9 +349,7 @@ def forward(

residual = hidden_states
hidden_states = self.final_layer_norm(hidden_states)
hidden_states, _ = self.fc1(hidden_states)
hidden_states = self.activation_fn(hidden_states)
hidden_states, _ = self.fc2(hidden_states)
hidden_states = self.mlp(hidden_states)
hidden_states = residual + hidden_states

return hidden_states
@@ -685,5 +677,7 @@ def sample(
def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
loader = AutoWeightsLoader(self, skip_prefixes=["proj_out."])
return loader.load_weights((name, loaded_weight)
for name, loaded_weight in weights)
loaded_weights = [(name, loaded_weight)
for name, loaded_weight in weights]
mapper = WeightsMapper({".fc1.": ".mlp.fc1.", ".fc2.": ".mlp.fc2."})
return loader.load_weights(loaded_weights, mapper=mapper)
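
The load_weights change above remaps HF checkpoint names such as model.decoder.layers.0.fc1.weight onto the new mlp submodule before loading. A minimal sketch of that substring renaming; the helper below is an illustration of the idea, not vLLM's actual WeightsMapper implementation:

# Hedged sketch of the substring renaming that the WeightsMapper call above
# relies on; illustrative only.
from typing import Dict, Iterable, Iterator, Tuple

import torch


def remap_weight_names(
    weights: Iterable[Tuple[str, torch.Tensor]],
    substr_map: Dict[str, str],
) -> Iterator[Tuple[str, torch.Tensor]]:
    for name, tensor in weights:
        for old, new in substr_map.items():
            name = name.replace(old, new)
        yield name, tensor


# e.g. "model.encoder.layers.0.fc1.weight" -> "model.encoder.layers.0.mlp.fc1.weight"
mapping = {".fc1.": ".mlp.fc1.", ".fc2.": ".mlp.fc2."}
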
7 changes: 0 additions & 7 deletions vllm/transformers_utils/tokenizer_group/__init__.py
@@ -17,18 +17,11 @@ def init_tokenizer_from_configs(model_config: ModelConfig,
scheduler_config: SchedulerConfig,
parallel_config: ParallelConfig,
lora_config: LoRAConfig):
add_special_tokens = None
if model_config.hf_config.model_type == "whisper":
# For Whisper models, the special tokens should be provided by the user
# based on the task and language of their request. Also needed to avoid
# appending an EOS token to the prompt which disrupts generation.
add_special_tokens = False
init_kwargs = dict(tokenizer_id=model_config.tokenizer,
enable_lora=bool(lora_config),
max_num_seqs=scheduler_config.max_num_seqs,
max_loras=lora_config.max_loras if lora_config else 0,
max_input_length=None,
add_special_tokens=add_special_tokens,
tokenizer_mode=model_config.tokenizer_mode,
trust_remote_code=model_config.trust_remote_code,
revision=model_config.tokenizer_revision)
26 changes: 16 additions & 10 deletions vllm/transformers_utils/tokenizer_group/ray_tokenizer_group.py
@@ -112,7 +112,8 @@ def _finalize_encode(self, actor: ray.ObjectRef,
def encode(self,
prompt: str,
request_id: Optional[str] = None,
lora_request: Optional[LoRARequest] = None) -> List[int]:
lora_request: Optional[LoRARequest] = None,
add_special_tokens: Optional[bool] = None) -> List[int]:
"""Encode a prompt using the tokenizer group.
We pick an idle actor and use it to encode the prompt.
@@ -132,7 +133,8 @@ def encode(self,
ret = ray.get(
actor.encode.remote(request_id=request_id,
prompt=prompt,
lora_request=lora_request))
lora_request=lora_request,
add_special_tokens=add_special_tokens))
except ActorDiedError as e:
# If the actor is dead, we first try to reinitialize it.
logger.warning("%s died with ActorDiedError, reinitializing.",
@@ -143,7 +145,8 @@
ret = ray.get(
actor.encode.remote(request_id=request_id,
prompt=prompt,
lora_request=lora_request))
lora_request=lora_request,
add_special_tokens=add_special_tokens))
except ActorDiedError as e:
logger.error(
"%s died for second time in a row, marking "
@@ -160,7 +163,8 @@ async def encode_async(
self,
prompt: str,
request_id: Optional[str] = None,
lora_request: Optional[LoRARequest] = None) -> List[int]:
lora_request: Optional[LoRARequest] = None,
add_special_tokens: Optional[bool] = None) -> List[int]:
"""Encode a prompt using the tokenizer group.
We pick an idle actor and use it to encode the prompt.
@@ -177,19 +181,21 @@
actor_is_alive = True
original_actor = actor
try:
ret = await actor.encode.remote(request_id=request_id,
prompt=prompt,
lora_request=lora_request)
ret = await actor.encode.remote(
request_id=request_id, prompt=prompt,
lora_request=lora_request,
add_special_tokens=add_special_tokens)
except ActorDiedError as e:
# If the actor is dead, we first try to reinitialize it.
logger.warning("%s died with ActorDiedError, reinitializing.",
actor,
exc_info=e)
actor = self._init_actor()
try:
ret = await actor.encode.remote(request_id=request_id,
prompt=prompt,
lora_request=lora_request)
ret = await actor.encode.remote(
request_id=request_id, prompt=prompt,
lora_request=lora_request,
add_special_tokens=add_special_tokens)
except ActorDiedError as e:
logger.error(
"%s died for second time in a row, marking "
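
The mypy failures flagged in vllm/inputs/preprocess.py come from BaseTokenizerGroup.encode and encode_async not yet declaring the new keyword. A hedged sketch of the signature change that these overrides imply; this is an assumed follow-up fix, not part of this commit:

# Hedged sketch, not from this commit: how the base class could declare
# add_special_tokens so the subclass overrides above type-check.
from abc import ABC, abstractmethod
from typing import List, Optional

from vllm.lora.request import LoRARequest


class BaseTokenizerGroup(ABC):

    @abstractmethod
    def encode(self,
               prompt: str,
               request_id: Optional[str] = None,
               lora_request: Optional[LoRARequest] = None,
               add_special_tokens: Optional[bool] = None) -> List[int]:
        ...

    @abstractmethod
    async def encode_async(
            self,
            prompt: str,
            request_id: Optional[str] = None,
            lora_request: Optional[LoRARequest] = None,
            add_special_tokens: Optional[bool] = None) -> List[int]:
        ...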