support pp & non-attn layers
Signed-off-by: Chen Zhang <[email protected]>
heheda12345 committed Jan 2, 2025
1 parent 176dc6d commit c5a5155
Showing 4 changed files with 52 additions and 27 deletions.
tests/test_utils.py (28 additions, 2 deletions)
@@ -1,14 +1,14 @@
 import asyncio
 import os
 import socket
-from typing import AsyncIterator, Tuple
+from typing import TYPE_CHECKING, AsyncIterator, Tuple

 import pytest
 import torch

 from vllm.utils import (FlexibleArgumentParser, StoreBoolean, deprecate_kwargs,
                         get_open_port, memory_profiling, merge_async_iterators,
-                        supports_kw)
+                        supports_kw, register_kv_cache)

 from .utils import error_on_warning, fork_new_process_for_each_test
@@ -306,3 +306,29 @@ def test_memory_profiling():
     del weights
     lib.cudaFree(handle1)
     lib.cudaFree(handle2)
+
+
+def test_register_gpu_kv_cache():
+    from vllm.config import LayerForwardContext
+    from vllm.attention import Attention
+
+    # example from Jamba PP=2
+    ctx = {
+        'model.layers.20.attn':
+        LayerForwardContext(
+            attn_module=Attention(32, 128, 0.1),
+            kv_cache=None,
+        ),
+        'model.layers.28.attn':
+        LayerForwardContext(
+            attn_module=Attention(32, 128, 0.1),
+            kv_cache=None,
+        )
+    }
+    kv_cache = [
+        torch.zeros((1, )),
+        torch.zeros((1, )),
+    ]
+    register_kv_cache(ctx, kv_cache)
+    assert ctx['model.layers.20.attn'].kv_cache is kv_cache[0]
+    assert ctx['model.layers.28.attn'].kv_cache is kv_cache[1]
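
The ordering that this test depends on comes from extract_layer_index, which register_kv_cache uses to sort the context keys by their numeric layer index. A rough stand-in for that helper (hypothetical, not the actual vLLM implementation) capturing the behavior the test assumes:

    def extract_layer_index_sketch(layer_name: str) -> int:
        # Assumed behavior: pull the single integer component out of a dotted
        # layer name, e.g. 'model.layers.20.attn' -> 20.
        indices = [int(p) for p in layer_name.split(".") if p.isdigit()]
        assert len(indices) == 1, f"ambiguous layer name: {layer_name}"
        return indices[0]


    assert extract_layer_index_sketch('model.layers.20.attn') == 20
    assert sorted(['model.layers.28.attn', 'model.layers.20.attn'],
                  key=extract_layer_index_sketch) == [
                      'model.layers.20.attn', 'model.layers.28.attn'
                  ]
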
vllm/utils.py (13 additions, 1 deletion)
@@ -52,7 +52,7 @@
 from vllm.logger import enable_trace_function_call, init_logger

 if TYPE_CHECKING:
-    from vllm.config import VllmConfig
+    from vllm.config import LayerForwardContext, VllmConfig

 logger = init_logger(__name__)

@@ -1947,3 +1947,15 @@ def get_mp_context():
     _check_multiproc_method()
     mp_method = envs.VLLM_WORKER_MULTIPROC_METHOD
     return multiprocessing.get_context(mp_method)
+
+
+def register_kv_cache(ctx: Dict[str, "LayerForwardContext"],
+                      kv_cache: List[torch.Tensor]) -> None:
+    # Two things need to be handled here:
+    # 1. Some models have non-attention layers, e.g., Jamba.
+    # 2. With pipeline parallelism, each rank only has a subset of layers.
+    from vllm.model_executor.models.utils import extract_layer_index
+    layer_name_sorted = sorted(ctx.keys(), key=extract_layer_index)
+    for i, layer_name in enumerate(layer_name_sorted):
+        forward_ctx = ctx[layer_name]
+        forward_ctx.kv_cache = kv_cache[i]
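
To see why sorting by layer index and then enumerating handles both cases the comment lists, here is a self-contained sketch with simplified stand-ins for LayerForwardContext and extract_layer_index (illustrative only, not the real vLLM classes). A pipeline-parallel rank that owns layers 4-7 of a hybrid model, where only layers 4 and 6 are attention layers, maps them to rank-local cache slots 0 and 1 even though their global indices are larger.

    from dataclasses import dataclass
    from typing import Dict, List, Optional

    import torch


    @dataclass
    class _Ctx:
        # simplified stand-in for vllm.config.LayerForwardContext
        kv_cache: Optional[torch.Tensor] = None


    def _layer_index(name: str) -> int:
        # stand-in for extract_layer_index: 'model.layers.6.attn' -> 6
        return next(int(p) for p in name.split(".") if p.isdigit())


    def _register(ctx: Dict[str, _Ctx], kv_cache: List[torch.Tensor]) -> None:
        # same idea as register_kv_cache: sort by the global layer index,
        # then hand out the rank-local caches in that order
        for i, name in enumerate(sorted(ctx, key=_layer_index)):
            ctx[name].kv_cache = kv_cache[i]


    # a PP rank owning layers 4-7; layers 5 and 7 are non-attention (e.g. Mamba),
    # so only two forward contexts (and two caches) exist on this rank
    ctx = {'model.layers.6.attn': _Ctx(), 'model.layers.4.attn': _Ctx()}
    caches = [torch.zeros(1), torch.zeros(1)]
    _register(ctx, caches)
    assert ctx['model.layers.4.attn'].kv_cache is caches[0]
    assert ctx['model.layers.6.attn'].kv_cache is caches[1]
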
vllm/v1/worker/gpu_model_runner.py (5 additions, 11 deletions)
@@ -13,11 +13,11 @@
 from vllm.inputs import INPUT_REGISTRY
 from vllm.logger import init_logger
 from vllm.model_executor.model_loader import get_model
-from vllm.model_executor.models.utils import extract_layer_index
 from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs
 from vllm.sampling_params import SamplingType
 from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler,
-                        LayerBlockType, cdiv, is_pin_memory_available)
+                        LayerBlockType, cdiv, is_pin_memory_available,
+                        register_kv_cache)
 from vllm.v1.attention.backends.flash_attn import (FlashAttentionBackend,
                                                    FlashAttentionMetadata)
 from vllm.v1.engine.mm_input_mapper import MMHasher, MMInputMapperClient
@@ -754,12 +754,6 @@ def initialize_kv_cache(self, num_blocks: int) -> None:
                             dtype=self.kv_cache_dtype,
                             device=self.device))
         # register kv_cache for forward_context
-        if self.vllm_config.parallel_config.pipeline_parallel_size > 1:
-            # TODO(Chen): In pipeline parallelism, layer_name 'layers.i.xxx'
-            # is mapped to kv_caches[i - start_layer_idx]. Need to implement
-            # and verify after supporting PP in v1
-            raise NotImplementedError("Pipeline parallelism is not supported.")
-        ctx = self.vllm_config.compilation_config.static_forward_context
-        for layer_name, forward_ctx in ctx.items():
-            layer_id = extract_layer_index(layer_name)
-            forward_ctx.kv_cache = self.kv_caches[layer_id]
+        register_kv_cache(
+            self.vllm_config.compilation_config.static_forward_context,
+            self.kv_caches)
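
For context on why the per-layer indexing above was replaced: with pipeline parallelism, the layer name carries a global index while self.kv_caches holds only this rank's caches, which is the mismatch the removed TODO warned about. A tiny illustration (hypothetical numbers, mirroring the Jamba PP=2 test above):

    import torch

    kv_caches = [torch.zeros(1), torch.zeros(1)]  # this rank allocated 2 caches
    layer_id = 28  # global index extracted from 'model.layers.28.attn'
    try:
        kv_caches[layer_id]  # the old global-index mapping is out of range here
    except IndexError:
        print("global layer index does not fit the rank-local cache list")
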
vllm/worker/cache_engine.py (6 additions, 13 deletions)
@@ -1,15 +1,15 @@
"""CacheEngine class for managing the KV cache."""
from typing import Dict, List
from typing import List

import torch

from vllm.attention import get_attn_backend
from vllm.config import (CacheConfig, CompilationConfig, DeviceConfig,
LayerForwardContext, ModelConfig, ParallelConfig)
ModelConfig, ParallelConfig)
from vllm.logger import init_logger
from vllm.model_executor.models.utils import extract_layer_index
from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, LayerBlockType,
get_dtype_size, is_pin_memory_available)
get_dtype_size, is_pin_memory_available,
register_kv_cache)

logger = init_logger(__name__)

@@ -65,7 +65,8 @@ def __init__(
         self.gpu_cache = self._allocate_kv_cache(
             self.num_gpu_blocks, self.device_config.device_type)
         self.cpu_cache = self._allocate_kv_cache(self.num_cpu_blocks, "cpu")
-        self._register_gpu_kv_cache(compilation_config.static_forward_context)
+        register_kv_cache(compilation_config.static_forward_context,
+                          self.gpu_cache)

     def _allocate_kv_cache(
         self,
@@ -88,14 +89,6 @@ def _allocate_kv_cache(
                             device=device))
         return kv_cache

-    def _register_gpu_kv_cache(self, ctx: Dict[str,
-                                                LayerForwardContext]) -> None:
-        if self.parallel_config.pipeline_parallel_size > 1:
-            raise NotImplementedError
-        for layer_name, forward_ctx in ctx.items():
-            layer_id = extract_layer_index(layer_name)
-            forward_ctx.kv_cache = self.gpu_cache[layer_id]
-
     def swap_in(self, src_to_dst: torch.Tensor) -> None:
         for i in range(self.num_attention_layers):
             self.attn_backend.swap_blocks(self.cpu_cache[i], self.gpu_cache[i],