[torch.compile] Hide KV cache behind torch.compile boundary #11677

Merged (22 commits, Jan 10, 2025)
Changes from 4 commits
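In short, this change stops passing the KV-cache tensors as arguments to the unified attention custom ops. Each Attention layer now registers a LayerForwardContext (the attention module plus a placeholder cache tensor) under its layer name at construction time; after the caches are allocated, register_kv_cache binds the real tensors to those contexts, and the ops fetch the cache by layer name through the forward context. Below is a minimal, self-contained sketch of that pattern; LayerCtx, STATIC_CTX, register_layer, bind_kv_caches, and attention_op are simplified stand-ins for illustration, not the actual vLLM API.

```python
from dataclasses import dataclass
from typing import Dict, List

import torch


@dataclass
class LayerCtx:
    """Stand-in for vllm.config.LayerForwardContext."""
    attn_module: object
    kv_cache: torch.Tensor


# Stand-in for compilation_config.static_forward_context.
STATIC_CTX: Dict[str, LayerCtx] = {}


def register_layer(name: str, module: object) -> None:
    # At layer-init time the cache is only a placeholder tensor.
    STATIC_CTX[name] = LayerCtx(module, torch.tensor([]))


def bind_kv_caches(kv_caches: List[torch.Tensor]) -> None:
    # Mirrors register_kv_cache: one allocated cache tensor per layer.
    for ctx, cache in zip(STATIC_CTX.values(), kv_caches):
        ctx.kv_cache = cache


def attention_op(query: torch.Tensor, layer_name: str) -> torch.Tensor:
    # The op receives only the layer name; the cache never appears in its
    # signature, so the compiled graph never sees it as an input.
    kv_cache = STATIC_CTX[layer_name].kv_cache
    return query + kv_cache.sum()  # placeholder for the real attention kernel


register_layer("model.layers.0.attn", module=object())
bind_kv_caches([torch.zeros(4)])
print(attention_op(torch.ones(2), "model.layers.0.attn"))
```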
28 changes: 27 additions & 1 deletion tests/test_utils.py
@@ -8,7 +8,7 @@

from vllm.utils import (FlexibleArgumentParser, StoreBoolean, deprecate_kwargs,
get_open_port, memory_profiling, merge_async_iterators,
supports_kw)
register_kv_cache, supports_kw)

from .utils import error_on_warning, fork_new_process_for_each_test

@@ -306,3 +306,29 @@ def test_memory_profiling():
del weights
lib.cudaFree(handle1)
lib.cudaFree(handle2)


def test_register_gpu_kv_cache():
from vllm.attention import Attention
from vllm.config import LayerForwardContext

# example from Jamba PP=2
ctx = {
'model.layers.20.attn':
LayerForwardContext(
attn_module=Attention(32, 128, 0.1),
kv_cache=None,
),
'model.layers.28.attn':
LayerForwardContext(
attn_module=Attention(32, 128, 0.1),
kv_cache=None,
)
}
kv_cache = [
torch.zeros((1, )),
torch.zeros((1, )),
]
register_kv_cache(ctx, kv_cache)
assert ctx['model.layers.20.attn'].kv_cache is kv_cache[0]
assert ctx['model.layers.28.attn'].kv_cache is kv_cache[1]
36 changes: 18 additions & 18 deletions vllm/attention/layer.py
@@ -7,7 +7,8 @@

from vllm.attention import AttentionMetadata, AttentionType
from vllm.attention.selector import backend_name_to_enum, get_attn_backend
from vllm.config import CacheConfig, get_current_vllm_config
from vllm.config import (CacheConfig, LayerForwardContext,
get_current_vllm_config)
from vllm.forward_context import ForwardContext, get_forward_context
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
@@ -117,7 +118,10 @@ def __init__(
compilation_config = get_current_vllm_config().compilation_config
if prefix in compilation_config.static_forward_context:
raise ValueError(f"Duplicate layer name: {prefix}")
compilation_config.static_forward_context[prefix] = self
# use a placeholder kv cache tensor during init, which will be replaced
# after kv cache initialization
compilation_config.static_forward_context[
prefix] = LayerForwardContext(self, torch.tensor([]))
self.layer_name = prefix

def forward(
@@ -152,13 +156,11 @@ def forward(
if value is not None:
value = value.view(-1, self.num_kv_heads, self.head_size)
torch.ops.vllm.unified_attention_with_output(
query, key, value, output, kv_cache, attn_type,
self.layer_name)
query, key, value, output, attn_type, self.layer_name)
return output.view(-1, hidden_size)
else:
return torch.ops.vllm.unified_attention(query, key, value,
kv_cache, attn_type,
self.layer_name)
attn_type, self.layer_name)

def extra_repr(self) -> str:
s = f"head_size={self.impl.head_size}" # type: ignore
@@ -236,17 +238,17 @@ def unified_attention(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
kv_cache: torch.Tensor,
attn_type: str,
layer_name: str,
) -> torch.Tensor:
forward_context: ForwardContext = get_forward_context()
attn_metadata = forward_context.dynamic_forward_context
self = forward_context.static_forward_context[layer_name]
attn_metadata = forward_context.attn_metadata
ctx = forward_context.layers[layer_name]
self = ctx.attn_module
return self.impl.forward(query,
key,
value,
kv_cache,
ctx.kv_cache,
attn_metadata,
self._k_scale,
self._v_scale,
@@ -257,7 +259,6 @@ def unified_attention_fake(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
kv_cache: torch.Tensor,
attn_type: str,
layer_name: str,
) -> torch.Tensor:
@@ -267,7 +268,7 @@
direct_register_custom_op(
op_name="unified_attention",
op_func=unified_attention,
mutates_args=["kv_cache"],
mutates_args=[],
fake_impl=unified_attention_fake,
dispatch_key=current_platform.dispatch_key,
)
@@ -278,17 +279,17 @@ def unified_attention_with_output(
key: torch.Tensor,
value: torch.Tensor,
output: torch.Tensor,
kv_cache: torch.Tensor,
attn_type: str,
layer_name: str,
) -> None:
forward_context: ForwardContext = get_forward_context()
attn_metadata = forward_context.dynamic_forward_context
self = forward_context.static_forward_context[layer_name]
attn_metadata = forward_context.attn_metadata
ctx = forward_context.layers[layer_name]
self = ctx.attn_module
self.impl.forward(query,
key,
value,
kv_cache,
ctx.kv_cache,
attn_metadata,
self._k_scale,
self._v_scale,
@@ -301,7 +302,6 @@ def unified_attention_with_output_fake(
key: torch.Tensor,
value: torch.Tensor,
output: torch.Tensor,
kv_cache: torch.Tensor,
attn_type: str,
layer_name: str,
) -> None:
@@ -311,7 +311,7 @@
direct_register_custom_op(
op_name="unified_attention_with_output",
op_func=unified_attention_with_output,
mutates_args=["kv_cache", "output"],
mutates_args=["output"],
fake_impl=unified_attention_with_output_fake,
dispatch_key=current_platform.dispatch_key,
)
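With the cache gone from the op signature, it also drops out of mutates_args: the in-place cache update now happens on state that the op looks up itself, which the compiled graph treats as opaque. The sketch below illustrates that idea with PyTorch's stock torch.library.custom_op decorator (PyTorch 2.4+) instead of vLLM's direct_register_custom_op helper; the demo::cached_attn op, the _KV side table, and the arithmetic are toy stand-ins, not vLLM code.

```python
from typing import Dict

import torch
from torch.library import custom_op

# Side table holding per-layer cache tensors; it is not part of the op schema.
_KV: Dict[str, torch.Tensor] = {"model.layers.0.attn": torch.zeros(8)}


@custom_op("demo::cached_attn", mutates_args=())
def cached_attn(query: torch.Tensor, layer_name: str) -> torch.Tensor:
    kv_cache = _KV[layer_name]
    kv_cache.copy_(query.mean().expand_as(kv_cache))  # cache write inside the op
    return query * 2.0


@cached_attn.register_fake
def _(query: torch.Tensor, layer_name: str) -> torch.Tensor:
    # Shape/dtype-only implementation used while tracing the compiled graph.
    return torch.empty_like(query)


@torch.compile(fullgraph=True)
def forward(query: torch.Tensor) -> torch.Tensor:
    return torch.ops.demo.cached_attn(query, "model.layers.0.attn")


print(forward(torch.ones(8)))
print(_KV["model.layers.0.attn"])  # updated even though it never crossed the graph boundary
```

The real ops in this diff follow the same shape: a fake implementation that only produces correctly shaped outputs for tracing, and an eager implementation that reads ctx.kv_cache from the per-layer forward context.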
12 changes: 9 additions & 3 deletions vllm/config.py
@@ -2616,6 +2616,12 @@ class CompilationLevel:
PIECEWISE = 3


@dataclass
class LayerForwardContext:
attn_module: Any # vllm.attention.layer.Attention
kv_cache: Any # torch.Tensor


class CompilationConfig(BaseModel):
"""
Configuration for compilation.
@@ -2769,9 +2775,9 @@ def model_post_init(self, __context: Any) -> None:
inductor_hash_cache: Any = PrivateAttr

# Per-model forward context
# Mainly used to store attention cls
# Map from layer name to the attention cls
static_forward_context: Dict[str, Any] = PrivateAttr
# Map from layer name to the layer's forward context, which stores
# attention cls and kv_cache
static_forward_context: Dict[str, LayerForwardContext] = PrivateAttr

def compute_hash(self) -> str:
"""
30 changes: 16 additions & 14 deletions vllm/forward_context.py
@@ -2,14 +2,17 @@
from collections import defaultdict
from contextlib import contextmanager
from dataclasses import dataclass
from typing import Any, Dict, Optional
from typing import TYPE_CHECKING, Any, Dict, Optional

import torch

import vllm.envs as envs
from vllm.config import VllmConfig
from vllm.config import LayerForwardContext, VllmConfig
from vllm.logger import init_logger

if TYPE_CHECKING:
from vllm.attention.backends.abstract import AttentionMetadata

logger = init_logger(__name__)

track_batchsize: bool = envs.VLLM_LOG_BATCHSIZE_INTERVAL >= 0
@@ -21,9 +24,10 @@

@dataclass
class ForwardContext:
static_forward_context: Dict[str, Any]
# copy from vllm_config.compilation_config.static_forward_context
layers: Dict[str, LayerForwardContext]
# TODO: extend to support per-layer dynamic forward context
dynamic_forward_context: Any
attn_metadata: "AttentionMetadata" # set dynamically for each forward pass


_forward_context: Optional[ForwardContext] = None
@@ -38,34 +42,32 @@ def get_forward_context() -> ForwardContext:


@contextmanager
def set_forward_context(context: Any, vllm_config: VllmConfig):
def set_forward_context(attn_metadata: Any, vllm_config: VllmConfig):
"""A context manager that stores the current forward context,
can be attention metadata, etc.
Here we can inject common logic for every model forward pass.
"""
global forward_start_time
need_to_track_batchsize = track_batchsize and context is not None
need_to_track_batchsize = track_batchsize and attn_metadata is not None
if need_to_track_batchsize:
forward_start_time = time.perf_counter()
global _forward_context
prev_context = _forward_context
_forward_context = ForwardContext(
static_forward_context=vllm_config.compilation_config.
static_forward_context,
dynamic_forward_context=context)
layers=vllm_config.compilation_config.static_forward_context,
attn_metadata=attn_metadata)
try:
yield
finally:
global batchsize_counter
global last_logging_time, batchsize_logging_interval
if need_to_track_batchsize:
if hasattr(context, "num_prefill_tokens"):
if hasattr(attn_metadata, "num_prefill_tokens"):
# for v0 attention backends
batchsize = context.num_prefill_tokens + \
context.num_decode_tokens
batchsize = attn_metadata.num_prefill_tokens + \
attn_metadata.num_decode_tokens
else:
# for v1 attention backends
batchsize = context.num_input_tokens
batchsize = attn_metadata.num_input_tokens
# we use synchronous scheduling right now,
# adding a sync point here should not affect
# scheduling of the next batch
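The renamed fields make the split explicit: layers is static per-model state copied from compilation_config.static_forward_context, while attn_metadata is set anew for every forward pass. The following is a small, self-contained sketch of the same save-and-restore context-manager pattern; FwdCtx and _current are simplified stand-ins, and the signature is abbreviated relative to the real set_forward_context.

```python
from contextlib import contextmanager
from dataclasses import dataclass
from typing import Any, Dict, Optional


@dataclass
class FwdCtx:
    layers: Dict[str, Any]   # static: layer name -> per-layer context
    attn_metadata: Any       # dynamic: set for each forward pass


_current: Optional[FwdCtx] = None


def get_forward_context() -> FwdCtx:
    assert _current is not None, "called outside set_forward_context()"
    return _current


@contextmanager
def set_forward_context(attn_metadata: Any, layers: Dict[str, Any]):
    # Install the context for one forward pass and restore the previous
    # one afterwards, so nested calls behave correctly.
    global _current
    prev = _current
    _current = FwdCtx(layers=layers, attn_metadata=attn_metadata)
    try:
        yield
    finally:
        _current = prev


with set_forward_context(attn_metadata={"num_input_tokens": 16},
                         layers={"model.layers.0.attn": None}):
    print(get_forward_context().attn_metadata)
```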
14 changes: 13 additions & 1 deletion vllm/utils.py
@@ -52,7 +52,7 @@
from vllm.logger import enable_trace_function_call, init_logger

if TYPE_CHECKING:
from vllm.config import VllmConfig
from vllm.config import LayerForwardContext, VllmConfig

logger = init_logger(__name__)

@@ -1947,3 +1947,15 @@ def get_mp_context():
_check_multiproc_method()
mp_method = envs.VLLM_WORKER_MULTIPROC_METHOD
return multiprocessing.get_context(mp_method)


def register_kv_cache(ctx: Dict[str, "LayerForwardContext"],
kv_cache: List[torch.Tensor]) -> None:
# Two things need to be handled here:
# 1. Some models have non-attention layers, e.g., Jamba
# 2. With pipeline parallelism, each rank only has a subset of layers
from vllm.model_executor.models.utils import extract_layer_index
layer_name_sorted = sorted(ctx.keys(), key=extract_layer_index)
for i, layer_name in enumerate(layer_name_sorted):
forward_ctx = ctx[layer_name]
forward_ctx.kv_cache = kv_cache[i]
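Sorting by extract_layer_index is what lets a flat list of cache tensors line up with a sparse set of attention layers, for example Jamba's mix of attention and non-attention layers, or a pipeline-parallel rank that only owns layers 20 and 28 as in the test above. A small sketch of that ordering logic follows; the regex-based _layer_index helper is an assumption about what extract_layer_index computes, not its actual implementation.

```python
import re
from dataclasses import dataclass, field
from typing import Dict, List

import torch


@dataclass
class LayerCtx:
    kv_cache: torch.Tensor = field(default_factory=lambda: torch.tensor([]))


def _layer_index(layer_name: str) -> int:
    # Assumed behavior: pull the numeric index out of a dotted layer name,
    # e.g. "model.layers.20.attn" -> 20.
    return int(re.search(r"\.layers\.(\d+)\.", layer_name).group(1))


def bind_caches(ctx: Dict[str, LayerCtx], kv_cache: List[torch.Tensor]) -> None:
    # Sort by layer index so the i-th allocated cache goes to the i-th
    # attention layer this rank owns, regardless of gaps in the indices.
    for i, name in enumerate(sorted(ctx, key=_layer_index)):
        ctx[name].kv_cache = kv_cache[i]


ctx = {"model.layers.28.attn": LayerCtx(), "model.layers.20.attn": LayerCtx()}
caches = [torch.zeros(1), torch.ones(1)]
bind_caches(ctx, caches)
assert ctx["model.layers.20.attn"].kv_cache is caches[0]
assert ctx["model.layers.28.attn"].kv_cache is caches[1]
```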
7 changes: 6 additions & 1 deletion vllm/v1/worker/gpu_model_runner.py
@@ -16,7 +16,8 @@
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs
from vllm.sampling_params import SamplingType
from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler,
LayerBlockType, cdiv, is_pin_memory_available)
LayerBlockType, cdiv, is_pin_memory_available,
register_kv_cache)
from vllm.v1.attention.backends.flash_attn import (FlashAttentionBackend,
FlashAttentionMetadata)
from vllm.v1.engine.mm_input_mapper import MMHasher, MMInputMapperClient
@@ -752,3 +753,7 @@ def initialize_kv_cache(self, num_blocks: int) -> None:
torch.zeros(kv_cache_shape,
dtype=self.kv_cache_dtype,
device=self.device))
# register kv_cache for forward_context
register_kv_cache(
self.vllm_config.compilation_config.static_forward_context,
self.kv_caches)
9 changes: 7 additions & 2 deletions vllm/worker/cache_engine.py
@@ -4,10 +4,12 @@
import torch

from vllm.attention import get_attn_backend
from vllm.config import CacheConfig, DeviceConfig, ModelConfig, ParallelConfig
from vllm.config import (CacheConfig, CompilationConfig, DeviceConfig,
ModelConfig, ParallelConfig)
from vllm.logger import init_logger
from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, LayerBlockType,
get_dtype_size, is_pin_memory_available)
get_dtype_size, is_pin_memory_available,
register_kv_cache)

logger = init_logger(__name__)

@@ -26,6 +28,7 @@ def __init__(
model_config: ModelConfig,
parallel_config: ParallelConfig,
device_config: DeviceConfig,
compilation_config: CompilationConfig,
) -> None:
self.cache_config = cache_config
self.model_config = model_config
@@ -62,6 +65,8 @@ def __init__(
self.gpu_cache = self._allocate_kv_cache(
self.num_gpu_blocks, self.device_config.device_type)
self.cpu_cache = self._allocate_kv_cache(self.num_cpu_blocks, "cpu")
register_kv_cache(compilation_config.static_forward_context,
self.gpu_cache)

def _allocate_kv_cache(
self,
22 changes: 14 additions & 8 deletions vllm/worker/cpu_worker.py
@@ -6,14 +6,14 @@

import vllm.envs as envs
from vllm.attention import get_attn_backend
from vllm.config import (CacheConfig, DeviceConfig, ModelConfig,
ParallelConfig, VllmConfig)
from vllm.config import (CacheConfig, CompilationConfig, DeviceConfig,
ModelConfig, ParallelConfig, VllmConfig)
from vllm.distributed import (ensure_model_parallel_initialized,
init_distributed_environment)
from vllm.logger import init_logger
from vllm.model_executor import set_random_seed
from vllm.sequence import ExecuteModelRequest
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, register_kv_cache
from vllm.worker.cpu_enc_dec_model_runner import CPUEncoderDecoderModelRunner
from vllm.worker.cpu_model_runner import CPUModelRunner, CPUModelRunnerBase
from vllm.worker.cpu_pooling_model_runner import CPUPoolingModelRunner
@@ -33,8 +33,8 @@ class CPUCacheEngine:
"""

def __init__(self, cache_config: CacheConfig, model_config: ModelConfig,
parallel_config: ParallelConfig,
device_config: DeviceConfig) -> None:
parallel_config: ParallelConfig, device_config: DeviceConfig,
compilation_config: CompilationConfig) -> None:
assert device_config.device_type == "cpu"
self.cache_config = cache_config
self.model_config = model_config
@@ -66,6 +66,8 @@ def __init__(self, cache_config: CacheConfig, model_config: ModelConfig,

# Initialize the cache.
self.cpu_cache = self._allocate_kv_cache(self.num_cpu_blocks)
register_kv_cache(compilation_config.static_forward_context,
self.cpu_cache)

def _allocate_kv_cache(
self,
@@ -285,9 +287,13 @@ def _validate_num_cpu_blocks(self, num_cpu_blocks: int) -> None:

def _init_cache_engine(self) -> None:
self.cache_engine = [
CPUCacheEngine(self.cache_config, self.model_config,
self.parallel_config, self.device_config)
for _ in range(self.parallel_config.pipeline_parallel_size)
CPUCacheEngine(
self.cache_config,
self.model_config,
self.parallel_config,
self.device_config,
self.compilation_config,
) for _ in range(self.parallel_config.pipeline_parallel_size)
]
self.cpu_cache = [
self.cache_engine[ve].cpu_cache
3 changes: 2 additions & 1 deletion vllm/worker/hpu_worker.py
@@ -208,7 +208,8 @@ def _init_cache_engine(self):
assert self.cache_config.num_gpu_blocks is not None
self.cache_engine = [
HPUCacheEngine(self.cache_config, self.model_config,
self.parallel_config, self.device_config)
self.parallel_config, self.device_config,
self.compilation_config)
for _ in range(self.parallel_config.pipeline_parallel_size)
]
self.hpu_cache = [