diff --git a/tests/kernels/test_encoder_decoder_attn.py b/tests/kernels/test_encoder_decoder_attn.py index 614674375786e..e008a56de6208 100644 --- a/tests/kernels/test_encoder_decoder_attn.py +++ b/tests/kernels/test_encoder_decoder_attn.py @@ -142,12 +142,18 @@ class that Attention will automatically select when it is constructed. torch.tensor([], dtype=torch.float32, device=CUDA_DEVICE)) # Construct KV cache - kv_cache = make_kv_cache(test_pt.num_blocks, - test_pt.num_heads, - test_pt.head_size, - test_pt.block_size, - device=CUDA_DEVICE, - backend=test_pt.backend_name) + if test_pt.attn_type in (AttentionType.DECODER, + AttentionType.ENCODER_DECODER): + kv_cache = make_kv_cache(test_pt.num_blocks, + test_pt.num_heads, + test_pt.head_size, + test_pt.block_size, + device=CUDA_DEVICE, + backend=test_pt.backend_name) + else: + kv_cache = torch.tensor([]) + + attn.kv_cache = [kv_cache] return TestResources(scale, attn, kv_cache) diff --git a/tests/test_utils.py b/tests/test_utils.py index 14d2fbd63b90d..6810e0302f897 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -7,9 +7,11 @@ import torch from vllm_test_utils import monitor +from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config from vllm.utils import (FlexibleArgumentParser, PlaceholderModule, - StoreBoolean, deprecate_kwargs, get_open_port, - memory_profiling, merge_async_iterators, supports_kw) + StoreBoolean, bind_kv_cache, deprecate_kwargs, + get_open_port, memory_profiling, merge_async_iterators, + supports_kw) from .utils import error_on_warning, fork_new_process_for_each_test @@ -325,6 +327,85 @@ def measure_current_non_torch(): lib.cudaFree(handle2) +def test_bind_kv_cache(): + from vllm.attention import Attention + + ctx = { + 'layers.0.self_attn': Attention(32, 128, 0.1), + 'layers.1.self_attn': Attention(32, 128, 0.1), + 'layers.2.self_attn': Attention(32, 128, 0.1), + 'layers.3.self_attn': Attention(32, 128, 0.1), + } + kv_cache = [ + torch.zeros((1, )), + torch.zeros((1, )), + torch.zeros((1, )), + torch.zeros((1, )), + ] + bind_kv_cache(ctx, [kv_cache]) + assert ctx['layers.0.self_attn'].kv_cache[0] is kv_cache[0] + assert ctx['layers.1.self_attn'].kv_cache[0] is kv_cache[1] + assert ctx['layers.2.self_attn'].kv_cache[0] is kv_cache[2] + assert ctx['layers.3.self_attn'].kv_cache[0] is kv_cache[3] + +def test_bind_kv_cache_non_attention(): + from vllm.attention import Attention + + # example from Jamba PP=2 + ctx = { + 'model.layers.20.attn': Attention(32, 128, 0.1), + 'model.layers.28.attn': Attention(32, 128, 0.1), + } + kv_cache = [ + torch.zeros((1, )), + torch.zeros((1, )), + ] + bind_kv_cache(ctx, [kv_cache]) + assert ctx['model.layers.20.attn'].kv_cache[0] is kv_cache[0] + assert ctx['model.layers.28.attn'].kv_cache[0] is kv_cache[1] + + +def test_bind_kv_cache_encoder_decoder(): + from vllm.attention import Attention, AttentionType + + # example from bart + ctx = { + 'encoder.layers.0.self_attn.attn': + Attention(32, 128, 0.1, attn_type=AttentionType.ENCODER), + 'decoder.layers.0.encoder_attn.attn': + Attention(32, 128, 0.1, attn_type=AttentionType.ENCODER_DECODER), + 'decoder.layers.0.self_attn.attn': + Attention(32, 128, 0.1, attn_type=AttentionType.DECODER), + } + + kv_cache = [ + torch.zeros((1, )), + ] + encoder_kv_cache = ctx['encoder.layers.0.self_attn.attn'].kv_cache + + bind_kv_cache(ctx, [kv_cache]) + assert ctx['encoder.layers.0.self_attn.attn'].kv_cache is encoder_kv_cache + assert ctx['decoder.layers.0.encoder_attn.attn'].kv_cache[0] is kv_cache[0] + assert ctx['decoder.layers.0.self_attn.attn'].kv_cache[0] is kv_cache[0] + + +def test_bind_kv_cache_pp(): + cfg = VllmConfig(parallel_config=ParallelConfig(pipeline_parallel_size=2)) + with set_current_vllm_config(cfg): + from vllm.attention import Attention + + ctx = { + 'layers.0.self_attn': Attention(32, 128, 0.1), + } + kv_cache = [ + [torch.zeros((1, ))], + [torch.zeros((1, ))] + ] + bind_kv_cache(ctx, kv_cache) + assert ctx['layers.0.self_attn'].kv_cache[0] is kv_cache[0][0] + assert ctx['layers.0.self_attn'].kv_cache[1] is kv_cache[1][0] + + def test_placeholder_module_error_handling(): placeholder = PlaceholderModule("placeholder_1234") diff --git a/tests/v1/engine/test_engine_core.py b/tests/v1/engine/test_engine_core.py index 8dd9b23fbdd5f..5b1732036e807 100644 --- a/tests/v1/engine/test_engine_core.py +++ b/tests/v1/engine/test_engine_core.py @@ -4,6 +4,7 @@ import pytest from transformers import AutoTokenizer +from tests.utils import fork_new_process_for_each_test from vllm import SamplingParams from vllm.engine.arg_utils import EngineArgs from vllm.platforms import current_platform @@ -36,6 +37,7 @@ def make_request() -> EngineCoreRequest: ) +@fork_new_process_for_each_test def test_engine_core(monkeypatch): with monkeypatch.context() as m: @@ -138,6 +140,7 @@ def test_engine_core(monkeypatch): assert len(engine_core.scheduler.running) == 0 +@fork_new_process_for_each_test def test_engine_core_advanced_sampling(monkeypatch): """ A basic end-to-end test to verify that the engine functions correctly diff --git a/tests/v1/engine/test_engine_core_client.py b/tests/v1/engine/test_engine_core_client.py index 5a21806e57a11..7eac16f2cf542 100644 --- a/tests/v1/engine/test_engine_core_client.py +++ b/tests/v1/engine/test_engine_core_client.py @@ -6,6 +6,7 @@ import pytest from transformers import AutoTokenizer +from tests.utils import fork_new_process_for_each_test from vllm import SamplingParams from vllm.engine.arg_utils import EngineArgs from vllm.platforms import current_platform @@ -75,6 +76,7 @@ async def loop_until_done_async(client: EngineCoreClient, outputs: Dict): break +@fork_new_process_for_each_test @pytest.mark.parametrize("multiprocessing_mode", [True, False]) def test_engine_core_client(monkeypatch, multiprocessing_mode: bool): @@ -143,6 +145,7 @@ def test_engine_core_client(monkeypatch, multiprocessing_mode: bool): client.abort_requests([request.request_id]) +@fork_new_process_for_each_test @pytest.mark.asyncio async def test_engine_core_client_asyncio(monkeypatch): diff --git a/vllm/attention/layer.py b/vllm/attention/layer.py index f1b3598e60b54..55e4e14027f79 100644 --- a/vllm/attention/layer.py +++ b/vllm/attention/layer.py @@ -121,6 +121,13 @@ def __init__( compilation_config.static_forward_context[prefix] = self self.layer_name = prefix self.attn_type = attn_type + # use a placeholder kv cache tensor during init, which will be replaced + # by bind_kv_cache + # this variable will not be accessed if use_direct_call is True + self.kv_cache = [ + torch.tensor([]) for _ in range(get_current_vllm_config( + ).parallel_config.pipeline_parallel_size) + ] def forward( self, @@ -148,11 +155,11 @@ def forward( if value is not None: value = value.view(-1, self.num_kv_heads, self.head_size) torch.ops.vllm.unified_attention_with_output( - query, key, value, output, kv_cache, self.layer_name) + query, key, value, output, self.layer_name) return output.view(-1, hidden_size) else: return torch.ops.vllm.unified_attention(query, key, value, - kv_cache, self.layer_name) + self.layer_name) def extra_repr(self) -> str: s = f"head_size={self.impl.head_size}" # type: ignore @@ -230,12 +237,12 @@ def unified_attention( query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, - kv_cache: torch.Tensor, layer_name: str, ) -> torch.Tensor: forward_context: ForwardContext = get_forward_context() - attn_metadata = forward_context.dynamic_forward_context - self = forward_context.static_forward_context[layer_name] + attn_metadata = forward_context.attn_metadata + self = forward_context.attn_layers[layer_name] + kv_cache = self.kv_cache[forward_context.virtual_engine] return self.impl.forward(query, key, value, kv_cache, attn_metadata, self._k_scale, self._v_scale) @@ -244,7 +251,6 @@ def unified_attention_fake( query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, - kv_cache: torch.Tensor, layer_name: str, ) -> torch.Tensor: return torch.empty_like(query).contiguous() @@ -253,7 +259,7 @@ def unified_attention_fake( direct_register_custom_op( op_name="unified_attention", op_func=unified_attention, - mutates_args=["kv_cache"], + mutates_args=[], fake_impl=unified_attention_fake, dispatch_key=current_platform.dispatch_key, ) @@ -264,12 +270,12 @@ def unified_attention_with_output( key: torch.Tensor, value: torch.Tensor, output: torch.Tensor, - kv_cache: torch.Tensor, layer_name: str, ) -> None: forward_context: ForwardContext = get_forward_context() - attn_metadata = forward_context.dynamic_forward_context - self = forward_context.static_forward_context[layer_name] + attn_metadata = forward_context.attn_metadata + self = forward_context.attn_layers[layer_name] + kv_cache = self.kv_cache[forward_context.virtual_engine] self.impl.forward(query, key, value, @@ -285,7 +291,6 @@ def unified_attention_with_output_fake( key: torch.Tensor, value: torch.Tensor, output: torch.Tensor, - kv_cache: torch.Tensor, layer_name: str, ) -> None: return @@ -294,7 +299,7 @@ def unified_attention_with_output_fake( direct_register_custom_op( op_name="unified_attention_with_output", op_func=unified_attention_with_output, - mutates_args=["kv_cache", "output"], + mutates_args=["output"], fake_impl=unified_attention_with_output_fake, dispatch_key=current_platform.dispatch_key, ) diff --git a/vllm/config.py b/vllm/config.py index 19609085cc960..13b5390008a35 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -2780,7 +2780,6 @@ def model_post_init(self, __context: Any) -> None: compilation_time: float = PrivateAttr # Per-model forward context - # Mainly used to store attention cls # Map from layer name to the attention cls static_forward_context: Dict[str, Any] = PrivateAttr diff --git a/vllm/forward_context.py b/vllm/forward_context.py index 7f56575279e9b..828b394ec5d21 100644 --- a/vllm/forward_context.py +++ b/vllm/forward_context.py @@ -2,7 +2,7 @@ from collections import defaultdict from contextlib import contextmanager from dataclasses import dataclass -from typing import Any, Dict, Optional +from typing import TYPE_CHECKING, Any, Dict, Optional import torch @@ -10,6 +10,9 @@ from vllm.config import VllmConfig from vllm.logger import init_logger +if TYPE_CHECKING: + from vllm.attention.backends.abstract import AttentionMetadata + logger = init_logger(__name__) track_batchsize: bool = envs.VLLM_LOG_BATCHSIZE_INTERVAL >= 0 @@ -21,9 +24,12 @@ @dataclass class ForwardContext: - static_forward_context: Dict[str, Any] + # copy from vllm_config.compilation_config.static_forward_context + attn_layers: Dict[str, Any] # TODO: extend to support per-layer dynamic forward context - dynamic_forward_context: Any + attn_metadata: "AttentionMetadata" # set dynamically for each forward pass + # TODO: remove after making all virtual_engines share the same kv cache + virtual_engine: int # set dynamically for each forward pass _forward_context: Optional[ForwardContext] = None @@ -38,34 +44,35 @@ def get_forward_context() -> ForwardContext: @contextmanager -def set_forward_context(context: Any, vllm_config: VllmConfig): +def set_forward_context(attn_metadata: Any, + vllm_config: VllmConfig, + virtual_engine: int = 0): """A context manager that stores the current forward context, can be attention metadata, etc. Here we can inject common logic for every model forward pass. """ global forward_start_time - need_to_track_batchsize = track_batchsize and context is not None + need_to_track_batchsize = track_batchsize and attn_metadata is not None if need_to_track_batchsize: forward_start_time = time.perf_counter() global _forward_context prev_context = _forward_context _forward_context = ForwardContext( - static_forward_context=vllm_config.compilation_config. - static_forward_context, - dynamic_forward_context=context) + attn_layers=vllm_config.compilation_config.static_forward_context, + virtual_engine=virtual_engine, + attn_metadata=attn_metadata) try: yield finally: - global batchsize_counter global last_logging_time, batchsize_logging_interval if need_to_track_batchsize: - if hasattr(context, "num_prefill_tokens"): + if hasattr(attn_metadata, "num_prefill_tokens"): # for v0 attention backends - batchsize = context.num_prefill_tokens + \ - context.num_decode_tokens + batchsize = attn_metadata.num_prefill_tokens + \ + attn_metadata.num_decode_tokens else: # for v1 attention backends - batchsize = context.num_input_tokens + batchsize = attn_metadata.num_input_tokens # we use synchronous scheduling right now, # adding a sync point here should not affect # scheduling of the next batch diff --git a/vllm/utils.py b/vllm/utils.py index 487088591ebc2..8c3e5200b3d98 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -2138,3 +2138,38 @@ def get_mp_context(): _check_multiproc_method() mp_method = envs.VLLM_WORKER_MULTIPROC_METHOD return multiprocessing.get_context(mp_method) + + +def bind_kv_cache( + ctx: Dict[str, Any], + kv_cache: List[List[torch.Tensor]], # [virtual_engine][layer_index] +) -> None: + # Bind the kv_cache tensor to Attention modules, similar to + # ctx[layer_name].kv_cache[ve]=kv_cache[ve][extract_layer_index(layer_name)] + # Special things handled here: + # 1. Some models have non-attention layers, e.g., Jamba + # 2. Pipeline parallelism, each rank only has a subset of layers + # 3. Encoder attention has no kv cache + # 4. Encoder-decoder models, encoder-decoder attention and decoder-only + # attention of the same layer (e.g., bart's decoder.layers.1.self_attn + # and decoder.layers.1.encoder_attn) is mapped to the same kv cache + # tensor + from vllm.attention import AttentionType + from vllm.model_executor.models.utils import extract_layer_index + layer_need_kv_cache = [ + layer_name for layer_name in ctx + if ctx[layer_name].attn_type in (AttentionType.DECODER, + AttentionType.ENCODER_DECODER) + ] + layer_index_sorted = sorted( + set( + extract_layer_index(layer_name) + for layer_name in layer_need_kv_cache)) + for layer_name in layer_need_kv_cache: + kv_cache_idx = layer_index_sorted.index( + extract_layer_index(layer_name)) + forward_ctx = ctx[layer_name] + assert len(forward_ctx.kv_cache) == len(kv_cache) + for ve, ve_kv_cache in enumerate(kv_cache): + assert forward_ctx.kv_cache[ve].numel() == 0 + forward_ctx.kv_cache[ve] = ve_kv_cache[kv_cache_idx] diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index a1d4f9b135789..fb87dc5a8222a 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -16,7 +16,8 @@ from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs from vllm.sampling_params import SamplingType from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler, - LayerBlockType, cdiv, is_pin_memory_available) + LayerBlockType, bind_kv_cache, cdiv, + is_pin_memory_available) from vllm.v1.attention.backends.flash_attn import (FlashAttentionBackend, FlashAttentionMetadata) from vllm.v1.engine.mm_input_mapper import MMInputMapperClient @@ -860,3 +861,6 @@ def initialize_kv_cache(self, num_blocks: int) -> None: torch.zeros(kv_cache_shape, dtype=self.kv_cache_dtype, device=self.device)) + bind_kv_cache( + self.vllm_config.compilation_config.static_forward_context, + [self.kv_caches]) diff --git a/vllm/worker/cpu_enc_dec_model_runner.py b/vllm/worker/cpu_enc_dec_model_runner.py index cc24cfe04d2ba..fa6775cbd6c66 100644 --- a/vllm/worker/cpu_enc_dec_model_runner.py +++ b/vllm/worker/cpu_enc_dec_model_runner.py @@ -305,7 +305,8 @@ def execute_model( intermediate_tensors, } - with set_forward_context(model_input.attn_metadata, self.vllm_config): + with set_forward_context(model_input.attn_metadata, self.vllm_config, + model_input.virtual_engine): hidden_states = model_executable(**execute_model_kwargs) # Compute the logits. diff --git a/vllm/worker/cpu_model_runner.py b/vllm/worker/cpu_model_runner.py index f1531e0fc0675..d99db4e0c6c40 100644 --- a/vllm/worker/cpu_model_runner.py +++ b/vllm/worker/cpu_model_runner.py @@ -526,7 +526,8 @@ def execute_model( execute_model_kwargs.update( {"previous_hidden_states": previous_hidden_states}) - with set_forward_context(model_input.attn_metadata, self.vllm_config): + with set_forward_context(model_input.attn_metadata, self.vllm_config, + model_input.virtual_engine): hidden_states = model_executable( input_ids=model_input.input_tokens, positions=model_input.input_positions, diff --git a/vllm/worker/cpu_pooling_model_runner.py b/vllm/worker/cpu_pooling_model_runner.py index 17b2fd2564a04..d31ba89e12375 100644 --- a/vllm/worker/cpu_pooling_model_runner.py +++ b/vllm/worker/cpu_pooling_model_runner.py @@ -69,7 +69,8 @@ def execute_model( intermediate_tensors, } - with set_forward_context(model_input.attn_metadata, self.vllm_config): + with set_forward_context(model_input.attn_metadata, self.vllm_config, + model_input.virtual_engine): hidden_states = model_executable(**execute_model_kwargs) # Only perform pooling in the driver worker. diff --git a/vllm/worker/cpu_worker.py b/vllm/worker/cpu_worker.py index b5dfebfce6f75..494c6506f3c0f 100644 --- a/vllm/worker/cpu_worker.py +++ b/vllm/worker/cpu_worker.py @@ -13,7 +13,7 @@ from vllm.logger import init_logger from vllm.model_executor import set_random_seed from vllm.sequence import ExecuteModelRequest -from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE +from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, bind_kv_cache from vllm.worker.cpu_enc_dec_model_runner import CPUEncoderDecoderModelRunner from vllm.worker.cpu_model_runner import CPUModelRunner, CPUModelRunnerBase from vllm.worker.cpu_pooling_model_runner import CPUPoolingModelRunner @@ -293,6 +293,8 @@ def _init_cache_engine(self) -> None: self.cache_engine[ve].cpu_cache for ve in range(self.parallel_config.pipeline_parallel_size) ] + bind_kv_cache(self.compilation_config.static_forward_context, + self.cpu_cache) self.model_runner.block_size = self.cache_engine[0].block_size assert all( diff --git a/vllm/worker/enc_dec_model_runner.py b/vllm/worker/enc_dec_model_runner.py index 4d5d918087be8..8a161b740042d 100644 --- a/vllm/worker/enc_dec_model_runner.py +++ b/vllm/worker/enc_dec_model_runner.py @@ -175,7 +175,8 @@ def execute_model( } if self.has_inner_state else {} multi_modal_kwargs = model_input.multi_modal_kwargs or {} - with set_forward_context(model_input.attn_metadata, self.vllm_config): + with set_forward_context(model_input.attn_metadata, self.vllm_config, + model_input.virtual_engine): hidden_or_intermediate_states = model_executable( input_ids=model_input.input_tokens, positions=model_input.input_positions, diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 1c6d1bbee78ee..2b918483d3675 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -1527,7 +1527,8 @@ def capture_model(self, kv_caches: List[List[torch.Tensor]]) -> None: self._update_inputs_to_capture_for_enc_dec_model( capture_inputs) - with set_forward_context(attn_metadata, self.vllm_config): + with set_forward_context(attn_metadata, self.vllm_config, + virtual_engine): graph_runner.capture(**capture_inputs) self.graph_memory_pool = graph_runner.graph.pool() self.graph_runners[virtual_engine][batch_size] = ( @@ -1695,7 +1696,7 @@ def execute_model( if not bypass_model_exec: with set_forward_context(model_input.attn_metadata, - self.vllm_config): + self.vllm_config, virtual_engine): hidden_or_intermediate_states = model_executable( input_ids=model_input.input_tokens, positions=model_input.input_positions, diff --git a/vllm/worker/pooling_model_runner.py b/vllm/worker/pooling_model_runner.py index f79b3773bcbd2..6de227f3cb2b9 100644 --- a/vllm/worker/pooling_model_runner.py +++ b/vllm/worker/pooling_model_runner.py @@ -105,7 +105,8 @@ def execute_model( if model_input.token_types is not None: cross_enc_kwargs["token_type_ids"] = model_input.token_types - with set_forward_context(model_input.attn_metadata, self.vllm_config): + with set_forward_context(model_input.attn_metadata, self.vllm_config, + virtual_engine): hidden_or_intermediate_states = model_executable( input_ids=model_input.input_tokens, positions=model_input.input_positions, diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py index f51b51d433d3d..0f12549e3f3fd 100644 --- a/vllm/worker/worker.py +++ b/vllm/worker/worker.py @@ -21,7 +21,7 @@ from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.sequence import (ExecuteModelRequest, IntermediateTensors, SequenceGroupMetadata, SequenceGroupMetadataDelta) -from vllm.utils import GiB_bytes, memory_profiling +from vllm.utils import GiB_bytes, bind_kv_cache, memory_profiling from vllm.worker.cache_engine import CacheEngine from vllm.worker.enc_dec_model_runner import EncoderDecoderModelRunner from vllm.worker.model_runner import GPUModelRunnerBase, ModelRunner @@ -285,6 +285,8 @@ def _init_cache_engine(self): self.cache_engine[ve].gpu_cache for ve in range(self.parallel_config.pipeline_parallel_size) ] + bind_kv_cache(self.compilation_config.static_forward_context, + self.gpu_cache) def _warm_up_model(self) -> None: if not self.model_config.enforce_eager: diff --git a/vllm/worker/worker_base.py b/vllm/worker/worker_base.py index 249b3ed2dfd37..a835718e1db19 100644 --- a/vllm/worker/worker_base.py +++ b/vllm/worker/worker_base.py @@ -43,6 +43,7 @@ def __init__( self.prompt_adapter_config = vllm_config.prompt_adapter_config self.observability_config = vllm_config.observability_config self.kv_transfer_config = vllm_config.kv_transfer_config + self.compilation_config = vllm_config.compilation_config from vllm.platforms import current_platform self.current_platform = current_platform