Commit fd4ed14: "13rps"
sfc-gh-aqiao committed Dec 16, 2024
1 parent 04a0ef4

Showing 4 changed files with 49 additions and 20 deletions.
examples/offline_inference_whisper.py (2 changes: 1 addition & 1 deletion)

@@ -36,7 +36,7 @@
         ),
         decoder_prompt="<|startoftranscript|>",
     ),
-] #* 128
+] * 128
 
 # Create a sampling params object.
 sampling_params = SamplingParams(
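The only change here multiplies the prompt list by 128, which points at throughput benchmarking; the commit title "13rps" reads like a requests-per-second figure. As a minimal sketch (not part of the commit), assuming the script's existing llm, prompts, and sampling_params objects, the rate could be measured like this:

    import time

    start = time.perf_counter()
    outputs = llm.generate(prompts, sampling_params)  # 128 identical Whisper requests
    elapsed = time.perf_counter() - start

    # Offline throughput in requests per second
    print(f"{len(outputs) / elapsed:.1f} req/s over {len(outputs)} requests")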
vllm/attention/backends/flash_attn.py (3 changes: 0 additions & 3 deletions)

@@ -774,10 +774,7 @@ def forward(
                 out=prefill_output,
             )
 
-        print("METADATA", attn_metadata)
-
         if decode_meta := attn_metadata.decode_metadata:
-            print("DECODE_META", decode_meta)
             # Decoding run.
             # Use flash_attn_varlen_func kernel for speculative decoding
             # because different queries might have different lengths.
vllm/core/scheduler.py (2 changes: 1 addition & 1 deletion)

@@ -1382,7 +1382,7 @@ def schedule(
                 # the subsequent comms can still use delta, but
                 # `multi_modal_data` will be None.
                 multi_modal_data=(seq_group.multi_modal_data or seq_group.encoder_seq.multi_modal_data)
-                ,#if scheduler_outputs.num_prefill_groups > 0 else None,
+                if scheduler_outputs.num_prefill_groups > 0 else None,
                 multi_modal_placeholders=seq_group.multi_modal_placeholders
                 if scheduler_outputs.num_prefill_groups > 0 else None,
                 mm_processor_kwargs=seq_group.mm_processor_kwargs,
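The fix restores the conditional that the previous hack had parked behind a stray comma and a commented-out clause. Written out as one expression (names as in scheduler.py, shown for readability only), multimodal data is attached only while the scheduled batch still contains prefill groups; decode-only batches pass None, matching the comment above that `multi_modal_data` will be None for subsequent sends:

    # Sketch of the restored expression, not additional code from the commit
    multi_modal_data = (
        (seq_group.multi_modal_data or seq_group.encoder_seq.multi_modal_data)
        if scheduler_outputs.num_prefill_groups > 0
        else None
    )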
vllm/model_executor/models/whisper.py (62 changes: 47 additions & 15 deletions)

@@ -245,13 +245,39 @@ def __init__(
             cache_config=cache_config,
             prefix=prefix,
         )
+        self.attn = Attention(
+            self.num_heads,
+            self.head_dim,
+            self.scaling,
+            num_kv_heads=self.num_kv_heads,
+            cache_config=cache_config,
+            quant_config=quant_config,
+            prefix=f"{prefix}.attn",
+        )
 
     def forward(
         self,
         hidden_states: torch.Tensor,
-        encoder_hidden_states: torch.Tensor,
-        attn_metadata: AttentionMetadata = None,
+        encoder_hidden_states: Optional[torch.Tensor],
+        kv_cache: torch.Tensor,
+        attn_metadata: AttentionMetadata,
     ):
         q, _ = self.q_proj(hidden_states)
 
+        if encoder_hidden_states is not None:
+            k, _ = self.k_proj(encoder_hidden_states)
+            v, _ = self.v_proj(encoder_hidden_states)
+        else:
+            k = v = None
+
+        attn_output = self.attn(q, k, v, kv_cache, attn_metadata,
+                                attn_type=AttentionType.ENCODER_DECODER)
+
+        output, _ = self.out_proj(attn_output)
+
+        return output
 
         output, _ = self.out_proj(attn_output)
         # HACK
         query_lens = attn_metadata.query_start_loc.diff().tolist()
         hidden_states = list(hidden_states.split(query_lens))

@@ -404,7 +430,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
     def forward(
         self,
         hidden_states: torch.Tensor,
-        encoder_hidden_states: torch.Tensor,
+        encoder_hidden_states: Optional[torch.Tensor],
         kv_cache: torch.Tensor,
         attn_metadata: AttentionMetadata,
     ):

@@ -422,6 +448,7 @@ def forward(
         hidden_states = self.encoder_attn(
             hidden_states=hidden_states,
             encoder_hidden_states=encoder_hidden_states,
+            kv_cache=kv_cache,
             attn_metadata=attn_metadata,
         )
         hidden_states = residual + hidden_states

@@ -470,7 +497,6 @@ def forward(
         inputs_embeds = nn.functional.gelu(self.conv1(input_features))
         inputs_embeds = nn.functional.gelu(self.conv2(inputs_embeds))
         inputs_embeds = inputs_embeds.permute(0, 2, 1)
-        print("INPUTS EMBEDS", inputs_embeds.size())
 
         embed_pos = self.embed_positions.weight
 

@@ -511,7 +537,7 @@ def forward(
         self,
         input_ids,
         positions: torch.Tensor,
-        encoder_hidden_states: torch.Tensor,
+        encoder_hidden_states: Optional[torch.Tensor],
         kv_caches: List[torch.Tensor],
         attn_metadata: AttentionMetadata,
     ):

@@ -524,7 +550,7 @@
                 hidden_states,
                 encoder_hidden_states=encoder_hidden_states,
                 kv_cache=kv_caches[idx],
-                attn_metadata=attn_metadata
+                attn_metadata=attn_metadata,
             )
 
         hidden_states = self.layer_norm(hidden_states)

@@ -541,17 +567,21 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):

     def forward(
         self,
-        input_features: torch.FloatTensor,
+        input_features: Optional[torch.FloatTensor],
         input_ids: Optional[torch.Tensor],
         positions: torch.Tensor,
         kv_caches: List[torch.Tensor],
         attn_metadata: AttentionMetadata,
-    ) -> torch.Tensor:
-        encoder_outputs = self.encoder(
-            input_features,
-            kv_caches=kv_caches,
-            attn_metadata=attn_metadata,
-        )
+    ) -> torch.Tensor:
+        if input_features is not None:
+            # Prefill encoder kv-caches
+            encoder_outputs = self.encoder(
+                input_features,
+                kv_caches=kv_caches,
+                attn_metadata=attn_metadata,
+            )
+        else:
+            encoder_outputs = None
 
         decoder_outputs = self.decoder(
             input_ids=input_ids,

@@ -677,9 +707,11 @@ def forward(
         attn_metadata: AttentionMetadata,
         **kwargs,
     ) -> torch.Tensor:
-        print(attn_metadata.encoder_seq_lens, attn_metadata.encoder_seq_start_loc)
+        input_features = kwargs.get("input_features")
+        if input_features is not None:
+            input_features = input_features.to(torch.float16)
         decoder_outputs = self.model(
-            input_features=kwargs["input_features"].to(torch.float16),
+            input_features=input_features,
             input_ids=input_ids,
             positions=positions,
             kv_caches=kv_caches,
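Taken together, the whisper.py changes swap the hand-rolled cross-attention for vLLM's Attention layer with AttentionType.ENCODER_DECODER and run the audio encoder only on prefill. A condensed sketch of the resulting flow, paraphrasing the diff rather than adding new behavior:

    # Condensed from WhisperModel.forward in this diff: the encoder runs once at
    # prefill to fill the cross-attention KV cache; decode steps skip it and the
    # cross-attention layer is called with k = v = None, reading from kv_cache.
    def forward(self, input_features, input_ids, positions, kv_caches, attn_metadata):
        if input_features is not None:
            # Prefill: encode audio and populate the encoder-decoder KV caches
            encoder_outputs = self.encoder(input_features, kv_caches=kv_caches,
                                           attn_metadata=attn_metadata)
        else:
            # Decode: encoder outputs are not needed again
            encoder_outputs = None
        return self.decoder(input_ids=input_ids, positions=positions,
                            encoder_hidden_states=encoder_outputs,
                            kv_caches=kv_caches, attn_metadata=attn_metadata)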
