Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Kernel] Attention.forward with unified_attention when use_direct_call=True #11967

Merged
merged 7 commits into from
Jan 13, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 14 additions & 12 deletions vllm/attention/layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,15 +134,10 @@ def forward(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
_kv_cache: torch.Tensor,
_attn_metadata: AttentionMetadata,
) -> torch.Tensor:

if self.use_direct_call:
return self.impl.forward(query, key, value, kv_cache,
attn_metadata, self._k_scale,
self._v_scale)
elif self.use_output:
if self.use_output:
output = torch.empty_like(query)
hidden_size = query.size(-1)
# Reshape the query, key, and value tensors.
Expand All @@ -154,12 +149,19 @@ def forward(
key = key.view(-1, self.num_kv_heads, self.head_size)
if value is not None:
value = value.view(-1, self.num_kv_heads, self.head_size)
torch.ops.vllm.unified_attention_with_output(
query, key, value, output, self.layer_name)
if self.use_direct_call:
unified_attention_with_output(query, key, value, output,
self.layer_name)
else:
torch.ops.vllm.unified_attention_with_output(
query, key, value, output, self.layer_name)
return output.view(-1, hidden_size)
else:
return torch.ops.vllm.unified_attention(query, key, value,
self.layer_name)
if self.use_direct_call:
return unified_attention(query, key, value, self.layer_name)
else:
return torch.ops.vllm.unified_attention(
query, key, value, self.layer_name)

def extra_repr(self) -> str:
s = f"head_size={self.impl.head_size}" # type: ignore
Expand Down
1 change: 0 additions & 1 deletion vllm/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2171,5 +2171,4 @@ def bind_kv_cache(
forward_ctx = ctx[layer_name]
assert len(forward_ctx.kv_cache) == len(kv_cache)
for ve, ve_kv_cache in enumerate(kv_cache):
assert forward_ctx.kv_cache[ve].numel() == 0
forward_ctx.kv_cache[ve] = ve_kv_cache[kv_cache_idx]
13 changes: 11 additions & 2 deletions vllm/worker/hpu_model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
from vllm.attention import AttentionMetadata, get_attn_backend
from vllm.config import DeviceConfig, VllmConfig
from vllm.distributed.parallel_state import get_world_group
from vllm.forward_context import set_forward_context
from vllm.logger import init_logger
from vllm.lora.layers import LoRAMapping
from vllm.lora.request import LoRARequest
Expand All @@ -40,7 +41,8 @@
from vllm.sampling_params import SamplingParams
from vllm.sequence import (IntermediateTensors, SequenceData,
SequenceGroupMetadata)
from vllm.utils import is_pin_memory_available, make_tensor_with_pad
from vllm.utils import (bind_kv_cache, is_pin_memory_available,
make_tensor_with_pad)
from vllm.worker.model_runner_base import (
ModelRunnerBase, ModelRunnerInputBase,
_add_attn_metadata_broadcastable_dict,
Expand Down Expand Up @@ -1286,6 +1288,9 @@ def create_dummy_seq_group_metadata(self,
def profile_run(self) -> None:
num_layers = self.model_config.get_num_layers(self.parallel_config)
kv_caches = [None] * num_layers
bind_kv_cache(
self.vllm_config.compilation_config.static_forward_context,
[kv_caches])
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is this extra list construction for pipeline parallel? if yes, can you add comments here?

max_seq_len = self.bucketing_global_state.prompt_seq_bucket_cfg[-1]
max_batch_size = min(self.max_num_batched_tokens // max_seq_len,
self.scheduler_config.max_num_seqs)
Expand Down Expand Up @@ -1943,7 +1948,11 @@ def execute_model(
f"graphs{'T' if use_graphs else 'F'}")
else:
model_event_name = 'model_executable'
with self.profiler.record_event('internal', model_event_name):
with set_forward_context(
model_input.attn_metadata, self.vllm_config,
model_input.virtual_engine), \
self.profiler.record_event(
'internal', model_event_name):
hidden_states = self.model.forward(
**execute_model_kwargs,
selected_token_indices=sampling_metadata.selected_token_indices
Expand Down
3 changes: 3 additions & 0 deletions vllm/worker/hpu_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from vllm.model_executor import set_random_seed
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sequence import ExecuteModelRequest
from vllm.utils import bind_kv_cache
from vllm.worker.cache_engine import CacheEngine
from vllm.worker.hpu_model_runner import HPUModelRunner
from vllm.worker.model_runner_base import ModelRunnerBase
Expand Down Expand Up @@ -215,6 +216,8 @@ def _init_cache_engine(self):
self.cache_engine[ve].gpu_cache
for ve in range(self.parallel_config.pipeline_parallel_size)
]
bind_kv_cache(self.compilation_config.static_forward_context,
self.hpu_cache)

def _warm_up_model(self) -> None:
# NOTE(kzawora): We should use virtual engine index here
Expand Down
17 changes: 10 additions & 7 deletions vllm/worker/neuron_model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from transformers_neuronx.config import GenerationConfig

from vllm.config import VllmConfig
from vllm.forward_context import set_forward_context
from vllm.logger import init_logger
from vllm.model_executor import SamplingMetadata
from vllm.model_executor.layers.sampler import SamplerOutput
Expand Down Expand Up @@ -314,13 +315,15 @@ def execute_model(
raise ValueError(
"NeuronModelRunner does not support multi-step execution.")

hidden_states = self.model(
input_ids=model_input.input_tokens,
positions=model_input.input_positions,
input_block_ids=model_input.input_block_ids,
**MultiModalKwargs.as_kwargs(model_input.multi_modal_kwargs or {},
device=self.device),
)
with set_forward_context(None, self.vllm_config, 0):
hidden_states = self.model(
input_ids=model_input.input_tokens,
positions=model_input.input_positions,
input_block_ids=model_input.input_block_ids,
**MultiModalKwargs.as_kwargs(model_input.multi_modal_kwargs
or {},
device=self.device),
)

# Compute the logits only if the on-device sampling is turned off as
# on-device sampling outputs the token ids.
Expand Down
4 changes: 3 additions & 1 deletion vllm/worker/openvino_model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from vllm.attention import get_attn_backend
from vllm.attention.backends.openvino import OpenVINOAttentionMetadata
from vllm.config import VllmConfig
from vllm.forward_context import set_forward_context
from vllm.logger import init_logger
from vllm.model_executor import SamplingMetadata
from vllm.model_executor.layers.sampler import SamplerOutput
Expand Down Expand Up @@ -350,7 +351,8 @@ def execute_model(
device=self.device),
}

hidden_states = model_executable(**execute_model_kwargs)
with set_forward_context(attn_metadata, self.vllm_config, 0):
hidden_states = model_executable(**execute_model_kwargs)

# Compute the logits.
logits = self.model.compute_logits(hidden_states, sampling_metadata)
Expand Down
13 changes: 11 additions & 2 deletions vllm/worker/openvino_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from vllm.platforms import current_platform
from vllm.sampling_params import SamplingParams
from vllm.sequence import ExecuteModelRequest, SequenceGroupMetadata
from vllm.utils import bind_kv_cache
from vllm.worker.openvino_model_runner import OpenVINOModelRunner
from vllm.worker.worker_base import LoraNotSupportedWorkerBase, WorkerBase

Expand Down Expand Up @@ -339,6 +340,8 @@ def _init_cache_engine(self) -> None:
ov_device,
)
self.kv_cache = self.cache_engine.kv_cache
bind_kv_cache(self.compilation_config.static_forward_context,
[self.kv_cache])
self.model_runner.block_size = self.cache_engine.block_size

assert self.kv_cache is not None
Expand Down Expand Up @@ -507,12 +510,18 @@ def model_profile_run():

self.model_runner.block_size = tmp_cache_config.block_size

bind_kv_cache(self.compilation_config.static_forward_context,
profiling_cache_engine.kv_cache)
# Run the model with the dummy inputs.
self.model_runner.execute_model(seqs,
profiling_cache_engine.kv_cache)

# explicitly delete temporary KV cache manager to free KV cache
# when real inputs will be passed to OV
# Explicitly revert bind_kv_cache and delete temporary KV cache
# manager to free KV cache when real inputs will be passed to OV
bind_kv_cache(self.compilation_config.static_forward_context, [[
torch.tensor([])
for _ in range(len(profiling_cache_engine.kv_cache))
]])
del profiling_cache_engine

logger.info(
Expand Down
28 changes: 18 additions & 10 deletions vllm/worker/tpu_model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@

from vllm.attention import AttentionMetadata, get_attn_backend
from vllm.config import VllmConfig
from vllm.forward_context import set_forward_context
from vllm.logger import init_logger
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.model_executor.model_loader import get_model
Expand Down Expand Up @@ -265,8 +266,9 @@ def _dummy_run(
torch._dynamo.mark_dynamic(t, 0)
torch._dynamo.mark_dynamic(p, 0)
# Dummy run.
self.model(token_ids, position_ids, attn_metadata, input_lens, t, p,
num_samples, kv_caches)
with set_forward_context(attn_metadata, self.vllm_config, 0):
self.model(token_ids, position_ids, attn_metadata, input_lens, t,
p, num_samples, kv_caches)

def warmup_model(
self,
Expand Down Expand Up @@ -663,10 +665,13 @@ def execute_model(
input_lens = model_input.input_lens[i:i + 1].to(self.device)
t = model_input.t[i:i + 1].to(self.device)
p = model_input.p[i:i + 1].to(self.device)
output_token_ids = self.model(token_ids, position_ids,
attn_metadata, input_lens, t, p,
model_input.num_samples,
kv_caches)
with set_forward_context(model_input.attn_metadata,
self.vllm_config,
model_input.virtual_engine):
output_token_ids = self.model(token_ids, position_ids,
attn_metadata, input_lens, t,
p, model_input.num_samples,
kv_caches)
next_token_ids.append(output_token_ids[0])
start_idx = end_idx

Expand Down Expand Up @@ -711,10 +716,13 @@ def execute_model(
input_lens = model_input.input_lens.to(self.device)
for i in range(num_steps):
slot_mapping = attn_metadata.slot_mapping
output_token_ids = self.model(token_ids, position_ids,
attn_metadata, input_lens, t, p,
model_input.num_samples,
kv_caches)
with set_forward_context(model_input.attn_metadata,
self.vllm_config,
model_input.virtual_engine):
output_token_ids = self.model(token_ids, position_ids,
attn_metadata, input_lens, t,
p, model_input.num_samples,
kv_caches)
self.cached_step_outputs.append(output_token_ids)

if i < num_steps - 1:
Expand Down
6 changes: 5 additions & 1 deletion vllm/worker/tpu_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from vllm.logger import init_logger
from vllm.model_executor import set_random_seed
from vllm.sequence import ExecuteModelRequest
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, get_dtype_size
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, bind_kv_cache, get_dtype_size
from vllm.worker.tpu_model_runner import ExecutionMode, TPUModelRunner
from vllm.worker.worker_base import (LocalOrDistributedWorkerBase,
LoraNotSupportedWorkerBase, WorkerBase,
Expand Down Expand Up @@ -108,6 +108,8 @@ def determine_num_available_blocks(self) -> Tuple[int, int]:
torch.tensor([], dtype=torch.float32,
device=self.device))
for _ in range(num_layers)]
bind_kv_cache(self.compilation_config.static_forward_context,
[kv_caches])
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

TPU kv caches are actually two separate tensors. would that matter here?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It does not matter. The detailed structure of kv_cache in one layer (e.g., torch.Tensor/Tuple[torch.Tensor]) will not be accessed outside Attention backends.

self.model_runner._dummy_run(
batch_size=1,
seq_len=self.scheduler_config.max_num_batched_tokens,
Expand Down Expand Up @@ -170,6 +172,8 @@ def initialize_cache(
device="cpu")
cpu_v_cache = torch.zeros_like(cpu_k_cache)
self.cpu_cache.append((cpu_k_cache, cpu_v_cache))
bind_kv_cache(self.compilation_config.static_forward_context,
[self.tpu_cache])
self._warmup_model()

def _warmup_model(self) -> None:
Expand Down
21 changes: 12 additions & 9 deletions vllm/worker/xpu_model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from vllm.attention import get_attn_backend
from vllm.config import VllmConfig
from vllm.distributed import get_pp_group
from vllm.forward_context import set_forward_context
from vllm.inputs import INPUT_REGISTRY, InputRegistry
from vllm.logger import init_logger
from vllm.model_executor import SamplingMetadataCache
Expand Down Expand Up @@ -562,15 +563,17 @@ def execute_model(
if (self.observability_config is not None
and self.observability_config.collect_model_forward_time):
model_forward_start_time = time.time()

hidden_or_intermediate_states = model_executable(
input_ids=model_input.input_tokens,
positions=model_input.input_positions,
kv_caches=kv_caches,
attn_metadata=model_input.attn_metadata,
intermediate_tensors=intermediate_tensors,
**MultiModalKwargs.as_kwargs(model_input.multi_modal_kwargs or {},
device=self.device))
with set_forward_context(model_input.attn_metadata, self.vllm_config,
model_input.virtual_engine):
hidden_or_intermediate_states = model_executable(
input_ids=model_input.input_tokens,
positions=model_input.input_positions,
kv_caches=kv_caches,
attn_metadata=model_input.attn_metadata,
intermediate_tensors=intermediate_tensors,
**MultiModalKwargs.as_kwargs(model_input.multi_modal_kwargs
or {},
device=self.device))
# Compute the logits in the last pipeline stage.
if not get_pp_group().is_last_rank:
return hidden_or_intermediate_states
Expand Down
Loading