[Bugfix] Update attention interface in Whisper (#11784)
Signed-off-by: Roger Wang <[email protected]>
ywang96 authored Jan 7, 2025
1 parent b278557 commit 0f3f3c8
Showing 1 changed file with 11 additions and 13 deletions.
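
The fix moves `attn_type` from a per-call argument of `Attention.forward()` into the `Attention` constructor: `WhisperAttention` now passes `attn_type=self.attn_type` when it builds its `Attention` layer, `WhisperCrossAttention` passes `attn_type=AttentionType.ENCODER_DECODER`, and both `forward()` methods drop the keyword from their attention calls.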
vllm/model_executor/models/whisper.py: 11 additions, 13 deletions
@@ -106,6 +106,7 @@ def __init__(
             cache_config=cache_config,
             quant_config=quant_config,
             prefix=f"{prefix}.attn",
+            attn_type=self.attn_type,
         )

     def _init_qkv(
@@ -134,12 +135,7 @@ def forward(
         qkv, _ = self.qkv_proj(hidden_states)
         q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)

-        attn_output = self.attn(q,
-                                k,
-                                v,
-                                kv_cache,
-                                attn_metadata,
-                                attn_type=self.attn_type)
+        attn_output = self.attn(q, k, v, kv_cache, attn_metadata)

         output, _ = self.out_proj(attn_output)

@@ -164,6 +160,7 @@ def __init__(
             cache_config=cache_config,
             quant_config=quant_config,
             prefix=prefix,
+            attn_type=AttentionType.ENCODER_DECODER,
         )

     def _init_qkv(
@@ -207,12 +204,13 @@ def forward(
         else:
             k = v = None

-        attn_output = self.attn(q,
-                                k,
-                                v,
-                                kv_cache,
-                                attn_metadata,
-                                attn_type=AttentionType.ENCODER_DECODER)
+        attn_output = self.attn(
+            q,
+            k,
+            v,
+            kv_cache,
+            attn_metadata,
+        )

         output, _ = self.out_proj(attn_output)

@@ -734,4 +732,4 @@ def load_weights(self, weights: Iterable[Tuple[str,
         loaded_weights = [(name, loaded_weight)
                           for name, loaded_weight in weights]
         mapper = WeightsMapper({".fc1.": ".mlp.fc1.", ".fc2.": ".mlp.fc2."})
-        return loader.load_weights(loaded_weights, mapper=mapper)
\ No newline at end of file
+        return loader.load_weights(loaded_weights, mapper=mapper)
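
In short, the per-call keyword becomes constructor state. Below is a minimal, self-contained sketch of that pattern; the `Attention` and `AttentionType` classes here are simplified stand-ins for vLLM's real ones, not its actual implementation:

```python
from enum import Enum


class AttentionType(Enum):
    DECODER = "decoder"
    ENCODER = "encoder"
    ENCODER_DECODER = "encoder_decoder"


class Attention:
    """Toy stand-in for vLLM's Attention layer (simplified, hypothetical)."""

    def __init__(self, attn_type: AttentionType = AttentionType.DECODER):
        # After this change, the attention type is fixed at construction time.
        self.attn_type = attn_type

    def forward(self, q, k, v, kv_cache=None, attn_metadata=None):
        # Dispatch on the stored type; callers no longer pass attn_type here.
        if self.attn_type is AttentionType.ENCODER_DECODER:
            # Cross-attention path: k/v may come from the encoder via the cache.
            pass
        return q  # placeholder for the real attention computation


# Call sites then mirror the updated Whisper code: the type is set once...
cross_attn = Attention(attn_type=AttentionType.ENCODER_DECODER)
# ...and forward() receives only tensors, cache, and metadata.
out = cross_attn.forward(q="q", k=None, v=None)
```

One benefit of this style is that the attention type is part of the layer's identity rather than something every caller must remember to supply consistently on each call.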
