hide kv cache behind torch.compile
Signed-off-by: Chen Zhang <[email protected]>
heheda12345 committed Jan 1, 2025
1 parent 74fa1d1 commit 176dc6d
Showing 8 changed files with 73 additions and 39 deletions.
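In short: KV-cache tensors are no longer passed through (or declared as mutated by) the compiled custom attention ops. Instead, each attention layer registers itself in a per-layer static forward context together with a placeholder kv_cache, the worker/runner swaps in the real cache tensor after allocation, and the ops fetch the cache by layer name at runtime. A minimal, self-contained sketch of that pattern follows; the names LayerCtx, register_layer, bind_kv_caches and unified_attention_sketch are invented for illustration and are not the actual vLLM classes or functions.

from dataclasses import dataclass
from typing import Dict

import torch


@dataclass
class LayerCtx:
    attn_module: object       # the attention layer that owns this cache
    kv_cache: torch.Tensor    # placeholder at init, real tensor after allocation


_layers: Dict[str, LayerCtx] = {}   # "static" per-model forward context


def register_layer(name: str, module: object) -> None:
    # at layer construction time only a placeholder tensor is available
    _layers[name] = LayerCtx(module, torch.tensor([]))


def bind_kv_caches(kv_caches: Dict[str, torch.Tensor]) -> None:
    # called once by the worker/runner after the real caches are allocated
    for name, cache in kv_caches.items():
        _layers[name].kv_cache = cache


def unified_attention_sketch(query: torch.Tensor, layer_name: str) -> torch.Tensor:
    # the op receives only the layer name; the cache comes from context,
    # so kv_cache no longer appears in the op signature
    ctx = _layers[layer_name]
    return query + ctx.kv_cache.sum()   # stand-in for the real attention kernel


register_layer("model.layers.0.self_attn.attn", object())
bind_kv_caches({"model.layers.0.self_attn.attn": torch.zeros(4)})
print(unified_attention_sketch(torch.ones(2), "model.layers.0.self_attn.attn"))

As the commit title suggests, the point of this indirection is to keep the cache tensors out of the compiled graph's op signatures.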
36 changes: 18 additions & 18 deletions vllm/attention/layer.py
@@ -7,7 +7,8 @@
 
 from vllm.attention import AttentionMetadata, AttentionType
 from vllm.attention.selector import backend_name_to_enum, get_attn_backend
-from vllm.config import CacheConfig, get_current_vllm_config
+from vllm.config import (CacheConfig, LayerForwardContext,
+                         get_current_vllm_config)
 from vllm.forward_context import ForwardContext, get_forward_context
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig)
@@ -117,7 +118,10 @@ def __init__(
         compilation_config = get_current_vllm_config().compilation_config
         if prefix in compilation_config.static_forward_context:
             raise ValueError(f"Duplicate layer name: {prefix}")
-        compilation_config.static_forward_context[prefix] = self
+        # use a placeholder kv cache tensor during init, which will be replaced
+        # after kv cache initialization
+        compilation_config.static_forward_context[
+            prefix] = LayerForwardContext(self, torch.tensor([]))
         self.layer_name = prefix
 
     def forward(
@@ -152,13 +156,11 @@ def forward(
             if value is not None:
                 value = value.view(-1, self.num_kv_heads, self.head_size)
             torch.ops.vllm.unified_attention_with_output(
-                query, key, value, output, kv_cache, attn_type,
-                self.layer_name)
+                query, key, value, output, attn_type, self.layer_name)
             return output.view(-1, hidden_size)
         else:
             return torch.ops.vllm.unified_attention(query, key, value,
-                                                    kv_cache, attn_type,
-                                                    self.layer_name)
+                                                    attn_type, self.layer_name)
 
     def extra_repr(self) -> str:
         s = f"head_size={self.impl.head_size}"  # type: ignore
@@ -236,17 +238,17 @@ def unified_attention(
     query: torch.Tensor,
     key: torch.Tensor,
     value: torch.Tensor,
-    kv_cache: torch.Tensor,
     attn_type: str,
     layer_name: str,
 ) -> torch.Tensor:
     forward_context: ForwardContext = get_forward_context()
-    attn_metadata = forward_context.dynamic_forward_context
-    self = forward_context.static_forward_context[layer_name]
+    attn_metadata = forward_context.attn_metadata
+    ctx = forward_context.layers[layer_name]
+    self = ctx.attn_module
     return self.impl.forward(query,
                              key,
                              value,
-                             kv_cache,
+                             ctx.kv_cache,
                              attn_metadata,
                              self._k_scale,
                              self._v_scale,
@@ -257,7 +259,6 @@ def unified_attention_fake(
     query: torch.Tensor,
     key: torch.Tensor,
     value: torch.Tensor,
-    kv_cache: torch.Tensor,
     attn_type: str,
     layer_name: str,
 ) -> torch.Tensor:
@@ -267,7 +268,7 @@
 direct_register_custom_op(
     op_name="unified_attention",
     op_func=unified_attention,
-    mutates_args=["kv_cache"],
+    mutates_args=[],
     fake_impl=unified_attention_fake,
     dispatch_key=current_platform.dispatch_key,
 )
@@ -278,17 +279,17 @@ def unified_attention_with_output(
     key: torch.Tensor,
     value: torch.Tensor,
     output: torch.Tensor,
-    kv_cache: torch.Tensor,
     attn_type: str,
     layer_name: str,
 ) -> None:
     forward_context: ForwardContext = get_forward_context()
-    attn_metadata = forward_context.dynamic_forward_context
-    self = forward_context.static_forward_context[layer_name]
+    attn_metadata = forward_context.attn_metadata
+    ctx = forward_context.layers[layer_name]
+    self = ctx.attn_module
     self.impl.forward(query,
                       key,
                       value,
-                      kv_cache,
+                      ctx.kv_cache,
                       attn_metadata,
                       self._k_scale,
                       self._v_scale,
@@ -301,7 +302,6 @@ def unified_attention_with_output_fake(
     key: torch.Tensor,
     value: torch.Tensor,
     output: torch.Tensor,
-    kv_cache: torch.Tensor,
     attn_type: str,
     layer_name: str,
 ) -> None:
@@ -311,7 +311,7 @@
 direct_register_custom_op(
     op_name="unified_attention_with_output",
     op_func=unified_attention_with_output,
-    mutates_args=["kv_cache", "output"],
+    mutates_args=["output"],
     fake_impl=unified_attention_with_output_fake,
     dispatch_key=current_platform.dispatch_key,
 )
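A side effect visible in this file is the shrinking of mutates_args: once kv_cache is no longer an argument of the custom op, it cannot be declared as mutated there, and the cache write becomes a side effect on a tensor reached through the forward context. A hedged sketch of the same idea using plain torch.library.custom_op (assumes PyTorch 2.4+; demo::cached_add and _CACHE are invented for illustration, and this is not vLLM's direct_register_custom_op helper):

import torch

_CACHE = {"layer0": torch.zeros(8)}  # stands in for a context-held kv cache


@torch.library.custom_op("demo::cached_add", mutates_args=())
def cached_add(x: torch.Tensor, layer_name: str) -> torch.Tensor:
    # the cache update is a side effect on state looked up by name,
    # so nothing in the op signature needs to be declared as mutated
    _CACHE[layer_name] += x.sum()
    return x + _CACHE[layer_name].sum()


@cached_add.register_fake
def _(x: torch.Tensor, layer_name: str) -> torch.Tensor:
    return torch.empty_like(x)


print(cached_add(torch.ones(3), "layer0"))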
12 changes: 9 additions & 3 deletions vllm/config.py
@@ -2616,6 +2616,12 @@ class CompilationLevel:
     PIECEWISE = 3
 
 
+@dataclass
+class LayerForwardContext:
+    attn_module: Any  # vllm.attention.layer.Attention
+    kv_cache: Any  # torch.Tensor
+
+
 class CompilationConfig(BaseModel):
     """
     Configuration for compilation.
@@ -2769,9 +2775,9 @@ def model_post_init(self, __context: Any) -> None:
     inductor_hash_cache: Any = PrivateAttr
 
     # Per-model forward context
-    # Mainly used to store attention cls
-    # Map from layer name to the attention cls
-    static_forward_context: Dict[str, Any] = PrivateAttr
+    # Map from layer name to the layer's forward context, which stores
+    # attention cls and kv_cache
+    static_forward_context: Dict[str, LayerForwardContext] = PrivateAttr
 
     def compute_hash(self) -> str:
         """
30 changes: 16 additions & 14 deletions vllm/forward_context.py
@@ -2,14 +2,17 @@
 from collections import defaultdict
 from contextlib import contextmanager
 from dataclasses import dataclass
-from typing import Any, Dict, Optional
+from typing import TYPE_CHECKING, Any, Dict, Optional
 
 import torch
 
 import vllm.envs as envs
-from vllm.config import VllmConfig
+from vllm.config import LayerForwardContext, VllmConfig
 from vllm.logger import init_logger
 
+if TYPE_CHECKING:
+    from vllm.attention.backends.abstract import AttentionMetadata
+
 logger = init_logger(__name__)
 
 track_batchsize: bool = envs.VLLM_LOG_BATCHSIZE_INTERVAL >= 0
@@ -21,9 +24,10 @@
 
 @dataclass
 class ForwardContext:
-    static_forward_context: Dict[str, Any]
+    # copy from vllm_config.compilation_config.static_forward_context
+    layers: Dict[str, LayerForwardContext]
     # TODO: extend to support per-layer dynamic forward context
-    dynamic_forward_context: Any
+    attn_metadata: "AttentionMetadata"  # set dynamically for each forward pass
 
 
 _forward_context: Optional[ForwardContext] = None
@@ -38,34 +42,32 @@ def get_forward_context() -> ForwardContext:
 
 
 @contextmanager
-def set_forward_context(context: Any, vllm_config: VllmConfig):
+def set_forward_context(attn_metadata: Any, vllm_config: VllmConfig):
     """A context manager that stores the current forward context,
     can be attention metadata, etc.
     Here we can inject common logic for every model forward pass.
     """
     global forward_start_time
-    need_to_track_batchsize = track_batchsize and context is not None
+    need_to_track_batchsize = track_batchsize and attn_metadata is not None
     if need_to_track_batchsize:
         forward_start_time = time.perf_counter()
     global _forward_context
     prev_context = _forward_context
     _forward_context = ForwardContext(
-        static_forward_context=vllm_config.compilation_config.
-        static_forward_context,
-        dynamic_forward_context=context)
+        layers=vllm_config.compilation_config.static_forward_context,
+        attn_metadata=attn_metadata)
     try:
         yield
     finally:
         global batchsize_counter
         global last_logging_time, batchsize_logging_interval
         if need_to_track_batchsize:
-            if hasattr(context, "num_prefill_tokens"):
+            if hasattr(attn_metadata, "num_prefill_tokens"):
                 # for v0 attention backends
-                batchsize = context.num_prefill_tokens + \
-                    context.num_decode_tokens
+                batchsize = attn_metadata.num_prefill_tokens + \
+                    attn_metadata.num_decode_tokens
             else:
                 # for v1 attention backends
-                batchsize = context.num_input_tokens
+                batchsize = attn_metadata.num_input_tokens
             # we use synchronous scheduling right now,
             # adding a sync point here should not affect
            # scheduling of the next batch
11 changes: 11 additions & 0 deletions vllm/v1/worker/gpu_model_runner.py
@@ -13,6 +13,7 @@
 from vllm.inputs import INPUT_REGISTRY
 from vllm.logger import init_logger
 from vllm.model_executor.model_loader import get_model
+from vllm.model_executor.models.utils import extract_layer_index
 from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs
 from vllm.sampling_params import SamplingType
 from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler,
@@ -752,3 +753,13 @@ def initialize_kv_cache(self, num_blocks: int) -> None:
                 torch.zeros(kv_cache_shape,
                             dtype=self.kv_cache_dtype,
                             device=self.device))
+        # register kv_cache for forward_context
+        if self.vllm_config.parallel_config.pipeline_parallel_size > 1:
+            # TODO(Chen): In pipeline parallelism, layer_name 'layers.i.xxx'
+            # is mapped to kv_caches[i - start_layer_idx]. Need to implement
+            # and verify after supporting PP in v1
+            raise NotImplementedError("Pipeline parallelism is not supported.")
+        ctx = self.vllm_config.compilation_config.static_forward_context
+        for layer_name, forward_ctx in ctx.items():
+            layer_id = extract_layer_index(layer_name)
+            forward_ctx.kv_cache = self.kv_caches[layer_id]
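Both this runner and the v0 CacheEngine below key the per-layer cache list by extract_layer_index(layer_name). A rough sketch of what that helper is assumed to do (illustrative re-implementation, not the code in vllm/model_executor/models/utils.py): pull the single integer component out of a dotted layer name.

def extract_layer_index_sketch(layer_name: str) -> int:
    # assumed behavior: exactly one dotted component is an integer, and that
    # integer is the layer index used to address kv_caches[layer_id]
    indices = [int(part) for part in layer_name.split(".") if part.isdigit()]
    assert len(indices) == 1, f"expected one layer index in {layer_name!r}"
    return indices[0]


assert extract_layer_index_sketch("model.layers.7.self_attn.attn") == 7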
16 changes: 14 additions & 2 deletions vllm/worker/cache_engine.py
@@ -1,11 +1,13 @@
 """CacheEngine class for managing the KV cache."""
-from typing import List
+from typing import Dict, List
 
 import torch
 
 from vllm.attention import get_attn_backend
-from vllm.config import CacheConfig, DeviceConfig, ModelConfig, ParallelConfig
+from vllm.config import (CacheConfig, CompilationConfig, DeviceConfig,
+                         LayerForwardContext, ModelConfig, ParallelConfig)
 from vllm.logger import init_logger
+from vllm.model_executor.models.utils import extract_layer_index
 from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, LayerBlockType,
                         get_dtype_size, is_pin_memory_available)
 
@@ -26,6 +28,7 @@ def __init__(
         model_config: ModelConfig,
         parallel_config: ParallelConfig,
         device_config: DeviceConfig,
+        compilation_config: CompilationConfig,
     ) -> None:
         self.cache_config = cache_config
         self.model_config = model_config
@@ -62,6 +65,7 @@ def __init__(
         self.gpu_cache = self._allocate_kv_cache(
             self.num_gpu_blocks, self.device_config.device_type)
         self.cpu_cache = self._allocate_kv_cache(self.num_cpu_blocks, "cpu")
+        self._register_gpu_kv_cache(compilation_config.static_forward_context)
 
     def _allocate_kv_cache(
         self,
@@ -84,6 +88,14 @@ def _allocate_kv_cache(
                             device=device))
         return kv_cache
 
+    def _register_gpu_kv_cache(self, ctx: Dict[str,
+                                               LayerForwardContext]) -> None:
+        if self.parallel_config.pipeline_parallel_size > 1:
+            raise NotImplementedError
+        for layer_name, forward_ctx in ctx.items():
+            layer_id = extract_layer_index(layer_name)
+            forward_ctx.kv_cache = self.gpu_cache[layer_id]
+
     def swap_in(self, src_to_dst: torch.Tensor) -> None:
         for i in range(self.num_attention_layers):
             self.attn_backend.swap_blocks(self.cpu_cache[i], self.gpu_cache[i],
3 changes: 2 additions & 1 deletion vllm/worker/hpu_worker.py
@@ -208,7 +208,8 @@ def _init_cache_engine(self):
         assert self.cache_config.num_gpu_blocks is not None
         self.cache_engine = [
             HPUCacheEngine(self.cache_config, self.model_config,
-                           self.parallel_config, self.device_config)
+                           self.parallel_config, self.device_config,
+                           self.compilation_config)
             for _ in range(self.parallel_config.pipeline_parallel_size)
         ]
         self.hpu_cache = [
3 changes: 2 additions & 1 deletion vllm/worker/worker.py
@@ -278,7 +278,8 @@ def _init_cache_engine(self):
         assert self.cache_config.num_gpu_blocks is not None
         self.cache_engine = [
             CacheEngine(self.cache_config, self.model_config,
-                        self.parallel_config, self.device_config)
+                        self.parallel_config, self.device_config,
+                        self.compilation_config)
             for _ in range(self.parallel_config.pipeline_parallel_size)
         ]
         self.gpu_cache = [
1 change: 1 addition & 0 deletions vllm/worker/worker_base.py
@@ -43,6 +43,7 @@ def __init__(
         self.prompt_adapter_config = vllm_config.prompt_adapter_config
         self.observability_config = vllm_config.observability_config
         self.kv_transfer_config = vllm_config.kv_transfer_config
+        self.compilation_config = vllm_config.compilation_config
         from vllm.platforms import current_platform
         self.current_platform = current_platform
 
