[Kernel] unified_attention for Attention.forward (#11967)
Signed-off-by: Chen Zhang <[email protected]>
heheda12345 authored Jan 13, 2025
1 parent 5340a30 commit 0f8cafe
Showing 10 changed files with 87 additions and 45 deletions.
26 changes: 14 additions & 12 deletions vllm/attention/layer.py
@@ -134,15 +134,10 @@ def forward(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
_kv_cache: torch.Tensor,
_attn_metadata: AttentionMetadata,
) -> torch.Tensor:

if self.use_direct_call:
return self.impl.forward(query, key, value, kv_cache,
attn_metadata, self._k_scale,
self._v_scale)
elif self.use_output:
if self.use_output:
output = torch.empty_like(query)
hidden_size = query.size(-1)
# Reshape the query, key, and value tensors.
@@ -154,12 +149,19 @@ def forward(
key = key.view(-1, self.num_kv_heads, self.head_size)
if value is not None:
value = value.view(-1, self.num_kv_heads, self.head_size)
torch.ops.vllm.unified_attention_with_output(
query, key, value, output, self.layer_name)
if self.use_direct_call:
unified_attention_with_output(query, key, value, output,
self.layer_name)
else:
torch.ops.vllm.unified_attention_with_output(
query, key, value, output, self.layer_name)
return output.view(-1, hidden_size)
else:
return torch.ops.vllm.unified_attention(query, key, value,
self.layer_name)
if self.use_direct_call:
return unified_attention(query, key, value, self.layer_name)
else:
return torch.ops.vllm.unified_attention(
query, key, value, self.layer_name)

def extra_repr(self) -> str:
s = f"head_size={self.impl.head_size}" # type: ignore
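The new control flow above dispatches to the unified attention entry points in one of two ways: when use_direct_call is set, the plain Python functions unified_attention / unified_attention_with_output are invoked directly; otherwise the call goes through the registered torch.ops.vllm custom ops, presumably so that a compiled graph sees a single opaque attention op rather than per-backend Python. Below is a minimal, self-contained sketch of that dispatch pattern using a hypothetical "toy" op namespace and a TinyAttention module; it is illustrative only, not vLLM's actual registration or implementation, and assumes PyTorch 2.4+ for torch.library.custom_op.

import torch
from torch.library import custom_op

# Hypothetical "toy::" namespace; vLLM registers its real ops under torch.ops.vllm.
@custom_op("toy::unified_attention", mutates_args=())
def unified_attention(query: torch.Tensor, key: torch.Tensor,
                      value: torch.Tensor, layer_name: str) -> torch.Tensor:
    # In vLLM the layer would fetch its kv cache and attention metadata from
    # the current forward context via layer_name; here we just run attention.
    scores = query @ key.transpose(-1, -2) / query.size(-1) ** 0.5
    return torch.softmax(scores, dim=-1) @ value

@unified_attention.register_fake
def _(query, key, value, layer_name):
    # Shape/dtype-only implementation so the op can be traced by torch.compile.
    return torch.empty_like(query)

class TinyAttention(torch.nn.Module):
    def __init__(self, layer_name: str, use_direct_call: bool):
        super().__init__()
        self.layer_name = layer_name
        self.use_direct_call = use_direct_call

    def forward(self, q, k, v):
        if self.use_direct_call:
            # Eager path: plain Python call, easy to step through and profile.
            return unified_attention(q, k, v, self.layer_name)
        # Compiled path: the graph contains only one opaque custom op.
        return torch.ops.toy.unified_attention(q, k, v, self.layer_name)

q = k = v = torch.randn(2, 4, 8)
print(TinyAttention("layers.0.attn", use_direct_call=False)(q, k, v).shape)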
1 change: 0 additions & 1 deletion vllm/utils.py
@@ -2171,5 +2171,4 @@ def bind_kv_cache(
forward_ctx = ctx[layer_name]
assert len(forward_ctx.kv_cache) == len(kv_cache)
for ve, ve_kv_cache in enumerate(kv_cache):
assert forward_ctx.kv_cache[ve].numel() == 0
forward_ctx.kv_cache[ve] = ve_kv_cache[kv_cache_idx]
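The deleted assertion previously required each kv-cache slot to be empty (numel() == 0) before binding; with it gone, bind_kv_cache can rebind an already-populated static forward context, e.g. dummy tensors during a profiling run and the real caches afterwards (the OpenVINO worker below even rebinds empty tensors to revert). A minimal, self-contained sketch of that relaxed contract follows; LayerCtx and bind_kv_cache_sketch are hypothetical stand-ins, not vLLM's actual helper.

from dataclasses import dataclass, field
from typing import Dict, List

import torch

@dataclass
class LayerCtx:
    # One kv-cache slot per virtual engine (pipeline-parallel copy).
    kv_cache: List[torch.Tensor] = field(
        default_factory=lambda: [torch.tensor([])])

def bind_kv_cache_sketch(ctx: Dict[str, LayerCtx],
                         kv_cache: List[List[torch.Tensor]]) -> None:
    # Attach the idx-th cache of every virtual engine to the idx-th layer.
    # Rebinding an already-bound slot is allowed now that the assert is gone.
    for idx, layer_name in enumerate(sorted(ctx)):
        for ve, ve_kv_cache in enumerate(kv_cache):
            ctx[layer_name].kv_cache[ve] = ve_kv_cache[idx]

ctx = {"layers.0.attn": LayerCtx(), "layers.1.attn": LayerCtx()}
dummy_caches = [[torch.zeros(1), torch.zeros(1)]]          # profiling run
real_caches = [[torch.zeros(16, 8), torch.zeros(16, 8)]]   # after allocation
bind_kv_cache_sketch(ctx, dummy_caches)
bind_kv_cache_sketch(ctx, real_caches)  # second bind succeeds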
13 changes: 11 additions & 2 deletions vllm/worker/hpu_model_runner.py
@@ -28,6 +28,7 @@
from vllm.attention import AttentionMetadata, get_attn_backend
from vllm.config import DeviceConfig, VllmConfig
from vllm.distributed.parallel_state import get_world_group
from vllm.forward_context import set_forward_context
from vllm.logger import init_logger
from vllm.lora.layers import LoRAMapping
from vllm.lora.request import LoRARequest
@@ -40,7 +41,8 @@
from vllm.sampling_params import SamplingParams
from vllm.sequence import (IntermediateTensors, SequenceData,
SequenceGroupMetadata)
from vllm.utils import is_pin_memory_available, make_tensor_with_pad
from vllm.utils import (bind_kv_cache, is_pin_memory_available,
make_tensor_with_pad)
from vllm.worker.model_runner_base import (
ModelRunnerBase, ModelRunnerInputBase,
_add_attn_metadata_broadcastable_dict,
@@ -1286,6 +1288,9 @@ def create_dummy_seq_group_metadata(self,
def profile_run(self) -> None:
num_layers = self.model_config.get_num_layers(self.parallel_config)
kv_caches = [None] * num_layers
bind_kv_cache(
self.vllm_config.compilation_config.static_forward_context,
[kv_caches])
max_seq_len = self.bucketing_global_state.prompt_seq_bucket_cfg[-1]
max_batch_size = min(self.max_num_batched_tokens // max_seq_len,
self.scheduler_config.max_num_seqs)
@@ -1943,7 +1948,11 @@ def execute_model(
f"graphs{'T' if use_graphs else 'F'}")
else:
model_event_name = 'model_executable'
with self.profiler.record_event('internal', model_event_name):
with set_forward_context(
model_input.attn_metadata, self.vllm_config,
model_input.virtual_engine), \
self.profiler.record_event(
'internal', model_event_name):
hidden_states = self.model.forward(
**execute_model_kwargs,
selected_token_indices=sampling_metadata.selected_token_indices
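The same wrapping pattern recurs in the Neuron, OpenVINO, TPU, and XPU runners below: model execution is placed inside set_forward_context(...) so the unified attention ops can look up the current attention metadata, virtual engine, and the kv cache bound via bind_kv_cache, instead of receiving them as forward() arguments (hence the _kv_cache / _attn_metadata rename in layer.py). A simplified, self-contained sketch of such a context manager is shown below; the names and fields are illustrative assumptions, not vLLM's actual forward_context API.

import contextlib
from dataclasses import dataclass
from typing import Any, Optional

@dataclass
class ForwardCtx:
    attn_metadata: Any
    virtual_engine: int

_current_ctx: Optional[ForwardCtx] = None

@contextlib.contextmanager
def set_forward_context_sketch(attn_metadata: Any, virtual_engine: int = 0):
    # Install per-step state for the duration of one model forward pass.
    global _current_ctx
    prev = _current_ctx
    _current_ctx = ForwardCtx(attn_metadata, virtual_engine)
    try:
        yield
    finally:
        _current_ctx = prev  # restore so later steps start clean

def get_forward_context_sketch() -> ForwardCtx:
    assert _current_ctx is not None, "forward pass ran outside the context"
    return _current_ctx

# The runner wraps the forward pass; attention layers read the context.
with set_forward_context_sketch(attn_metadata={"slot_mapping": None}):
    print(get_forward_context_sketch().virtual_engine)  # -> 0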
3 changes: 3 additions & 0 deletions vllm/worker/hpu_worker.py
@@ -20,6 +20,7 @@
from vllm.model_executor import set_random_seed
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sequence import ExecuteModelRequest
from vllm.utils import bind_kv_cache
from vllm.worker.cache_engine import CacheEngine
from vllm.worker.hpu_model_runner import HPUModelRunner
from vllm.worker.model_runner_base import ModelRunnerBase
@@ -215,6 +216,8 @@ def _init_cache_engine(self):
self.cache_engine[ve].gpu_cache
for ve in range(self.parallel_config.pipeline_parallel_size)
]
bind_kv_cache(self.compilation_config.static_forward_context,
self.hpu_cache)

def _warm_up_model(self) -> None:
# NOTE(kzawora): We should use virtual engine index here
17 changes: 10 additions & 7 deletions vllm/worker/neuron_model_runner.py
@@ -8,6 +8,7 @@
from transformers_neuronx.config import GenerationConfig

from vllm.config import VllmConfig
from vllm.forward_context import set_forward_context
from vllm.logger import init_logger
from vllm.model_executor import SamplingMetadata
from vllm.model_executor.layers.sampler import SamplerOutput
@@ -314,13 +315,15 @@ def execute_model(
raise ValueError(
"NeuronModelRunner does not support multi-step execution.")

hidden_states = self.model(
input_ids=model_input.input_tokens,
positions=model_input.input_positions,
input_block_ids=model_input.input_block_ids,
**MultiModalKwargs.as_kwargs(model_input.multi_modal_kwargs or {},
device=self.device),
)
with set_forward_context(None, self.vllm_config, 0):
hidden_states = self.model(
input_ids=model_input.input_tokens,
positions=model_input.input_positions,
input_block_ids=model_input.input_block_ids,
**MultiModalKwargs.as_kwargs(model_input.multi_modal_kwargs
or {},
device=self.device),
)

# Compute the logits only if the on-device sampling is turned off as
# on-device sampling outputs the token ids.
4 changes: 3 additions & 1 deletion vllm/worker/openvino_model_runner.py
@@ -8,6 +8,7 @@
from vllm.attention import get_attn_backend
from vllm.attention.backends.openvino import OpenVINOAttentionMetadata
from vllm.config import VllmConfig
from vllm.forward_context import set_forward_context
from vllm.logger import init_logger
from vllm.model_executor import SamplingMetadata
from vllm.model_executor.layers.sampler import SamplerOutput
@@ -350,7 +351,8 @@ def execute_model(
device=self.device),
}

hidden_states = model_executable(**execute_model_kwargs)
with set_forward_context(attn_metadata, self.vllm_config, 0):
hidden_states = model_executable(**execute_model_kwargs)

# Compute the logits.
logits = self.model.compute_logits(hidden_states, sampling_metadata)
13 changes: 11 additions & 2 deletions vllm/worker/openvino_worker.py
@@ -20,6 +20,7 @@
from vllm.platforms import current_platform
from vllm.sampling_params import SamplingParams
from vllm.sequence import ExecuteModelRequest, SequenceGroupMetadata
from vllm.utils import bind_kv_cache
from vllm.worker.openvino_model_runner import OpenVINOModelRunner
from vllm.worker.worker_base import LoraNotSupportedWorkerBase, WorkerBase

@@ -339,6 +340,8 @@ def _init_cache_engine(self) -> None:
ov_device,
)
self.kv_cache = self.cache_engine.kv_cache
bind_kv_cache(self.compilation_config.static_forward_context,
[self.kv_cache])
self.model_runner.block_size = self.cache_engine.block_size

assert self.kv_cache is not None
@@ -507,12 +510,18 @@ def model_profile_run():

self.model_runner.block_size = tmp_cache_config.block_size

bind_kv_cache(self.compilation_config.static_forward_context,
profiling_cache_engine.kv_cache)
# Run the model with the dummy inputs.
self.model_runner.execute_model(seqs,
profiling_cache_engine.kv_cache)

# explicitly delete temporary KV cache manager to free KV cache
# when real inputs will be passed to OV
# Explicitly revert bind_kv_cache and delete temporary KV cache
# manager to free KV cache when real inputs will be passed to OV
bind_kv_cache(self.compilation_config.static_forward_context, [[
torch.tensor([])
for _ in range(len(profiling_cache_engine.kv_cache))
]])
del profiling_cache_engine

logger.info(
28 changes: 18 additions & 10 deletions vllm/worker/tpu_model_runner.py
@@ -13,6 +13,7 @@

from vllm.attention import AttentionMetadata, get_attn_backend
from vllm.config import VllmConfig
from vllm.forward_context import set_forward_context
from vllm.logger import init_logger
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.model_executor.model_loader import get_model
@@ -265,8 +266,9 @@ def _dummy_run(
torch._dynamo.mark_dynamic(t, 0)
torch._dynamo.mark_dynamic(p, 0)
# Dummy run.
self.model(token_ids, position_ids, attn_metadata, input_lens, t, p,
num_samples, kv_caches)
with set_forward_context(attn_metadata, self.vllm_config, 0):
self.model(token_ids, position_ids, attn_metadata, input_lens, t,
p, num_samples, kv_caches)

def warmup_model(
self,
@@ -663,10 +665,13 @@ def execute_model(
input_lens = model_input.input_lens[i:i + 1].to(self.device)
t = model_input.t[i:i + 1].to(self.device)
p = model_input.p[i:i + 1].to(self.device)
output_token_ids = self.model(token_ids, position_ids,
attn_metadata, input_lens, t, p,
model_input.num_samples,
kv_caches)
with set_forward_context(model_input.attn_metadata,
self.vllm_config,
model_input.virtual_engine):
output_token_ids = self.model(token_ids, position_ids,
attn_metadata, input_lens, t,
p, model_input.num_samples,
kv_caches)
next_token_ids.append(output_token_ids[0])
start_idx = end_idx

@@ -711,10 +716,13 @@ def execute_model(
input_lens = model_input.input_lens.to(self.device)
for i in range(num_steps):
slot_mapping = attn_metadata.slot_mapping
output_token_ids = self.model(token_ids, position_ids,
attn_metadata, input_lens, t, p,
model_input.num_samples,
kv_caches)
with set_forward_context(model_input.attn_metadata,
self.vllm_config,
model_input.virtual_engine):
output_token_ids = self.model(token_ids, position_ids,
attn_metadata, input_lens, t,
p, model_input.num_samples,
kv_caches)
self.cached_step_outputs.append(output_token_ids)

if i < num_steps - 1:
6 changes: 5 additions & 1 deletion vllm/worker/tpu_worker.py
@@ -12,7 +12,7 @@
from vllm.logger import init_logger
from vllm.model_executor import set_random_seed
from vllm.sequence import ExecuteModelRequest
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, get_dtype_size
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, bind_kv_cache, get_dtype_size
from vllm.worker.tpu_model_runner import ExecutionMode, TPUModelRunner
from vllm.worker.worker_base import (LocalOrDistributedWorkerBase,
LoraNotSupportedWorkerBase, WorkerBase,
@@ -108,6 +108,8 @@ def determine_num_available_blocks(self) -> Tuple[int, int]:
torch.tensor([], dtype=torch.float32,
device=self.device))
for _ in range(num_layers)]
bind_kv_cache(self.compilation_config.static_forward_context,
[kv_caches])
self.model_runner._dummy_run(
batch_size=1,
seq_len=self.scheduler_config.max_num_batched_tokens,
@@ -170,6 +172,8 @@ def initialize_cache(
device="cpu")
cpu_v_cache = torch.zeros_like(cpu_k_cache)
self.cpu_cache.append((cpu_k_cache, cpu_v_cache))
bind_kv_cache(self.compilation_config.static_forward_context,
[self.tpu_cache])
self._warmup_model()

def _warmup_model(self) -> None:
21 changes: 12 additions & 9 deletions vllm/worker/xpu_model_runner.py
@@ -12,6 +12,7 @@
from vllm.attention import get_attn_backend
from vllm.config import VllmConfig
from vllm.distributed import get_pp_group
from vllm.forward_context import set_forward_context
from vllm.inputs import INPUT_REGISTRY, InputRegistry
from vllm.logger import init_logger
from vllm.model_executor import SamplingMetadataCache
@@ -562,15 +563,17 @@ def execute_model(
if (self.observability_config is not None
and self.observability_config.collect_model_forward_time):
model_forward_start_time = time.time()

hidden_or_intermediate_states = model_executable(
input_ids=model_input.input_tokens,
positions=model_input.input_positions,
kv_caches=kv_caches,
attn_metadata=model_input.attn_metadata,
intermediate_tensors=intermediate_tensors,
**MultiModalKwargs.as_kwargs(model_input.multi_modal_kwargs or {},
device=self.device))
with set_forward_context(model_input.attn_metadata, self.vllm_config,
model_input.virtual_engine):
hidden_or_intermediate_states = model_executable(
input_ids=model_input.input_tokens,
positions=model_input.input_positions,
kv_caches=kv_caches,
attn_metadata=model_input.attn_metadata,
intermediate_tensors=intermediate_tensors,
**MultiModalKwargs.as_kwargs(model_input.multi_modal_kwargs
or {},
device=self.device))
# Compute the logits in the last pipeline stage.
if not get_pp_group().is_last_rank:
return hidden_or_intermediate_states