
fix
Signed-off-by: Chen Zhang <[email protected]>
heheda12345 committed Jan 8, 2025
1 parent 2cb84f2 commit 76712f8
Showing 2 changed files with 30 additions and 13 deletions.
40 changes: 29 additions & 11 deletions tests/test_utils.py
@@ -6,6 +6,7 @@
 import pytest
 import torch
 
+from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config
 from vllm.utils import (FlexibleArgumentParser, StoreBoolean, bind_kv_cache,
                         deprecate_kwargs, get_open_port, memory_profiling,
                         merge_async_iterators, supports_kw)
@@ -323,11 +324,11 @@ def test_bind_kv_cache():
         torch.zeros((1, )),
         torch.zeros((1, )),
     ]
-    bind_kv_cache(ctx, kv_cache)
-    assert ctx['layers.0.self_attn'].kv_cache is kv_cache[0]
-    assert ctx['layers.1.self_attn'].kv_cache is kv_cache[1]
-    assert ctx['layers.2.self_attn'].kv_cache is kv_cache[2]
-    assert ctx['layers.3.self_attn'].kv_cache is kv_cache[3]
+    bind_kv_cache(ctx, [kv_cache])
+    assert ctx['layers.0.self_attn'].kv_cache[0] is kv_cache[0]
+    assert ctx['layers.1.self_attn'].kv_cache[0] is kv_cache[1]
+    assert ctx['layers.2.self_attn'].kv_cache[0] is kv_cache[2]
+    assert ctx['layers.3.self_attn'].kv_cache[0] is kv_cache[3]
 
 def test_bind_kv_cache_non_attention():
     from vllm.attention import Attention
@@ -341,9 +342,9 @@ def test_bind_kv_cache_non_attention():
         torch.zeros((1, )),
         torch.zeros((1, )),
     ]
-    bind_kv_cache(ctx, kv_cache)
-    assert ctx['model.layers.20.attn'].kv_cache is kv_cache[0]
-    assert ctx['model.layers.28.attn'].kv_cache is kv_cache[1]
+    bind_kv_cache(ctx, [kv_cache])
+    assert ctx['model.layers.20.attn'].kv_cache[0] is kv_cache[0]
+    assert ctx['model.layers.28.attn'].kv_cache[0] is kv_cache[1]
 
 
 def test_bind_kv_cache_encoder_decoder():
@@ -364,7 +365,24 @@ def test_bind_kv_cache_encoder_decoder():
     ]
     encoder_kv_cache = ctx['encoder.layers.0.self_attn.attn'].kv_cache
 
-    bind_kv_cache(ctx, kv_cache)
+    bind_kv_cache(ctx, [kv_cache])
     assert ctx['encoder.layers.0.self_attn.attn'].kv_cache is encoder_kv_cache
-    assert ctx['decoder.layers.0.encoder_attn.attn'].kv_cache is kv_cache[0]
-    assert ctx['decoder.layers.0.self_attn.attn'].kv_cache is kv_cache[0]
+    assert ctx['decoder.layers.0.encoder_attn.attn'].kv_cache[0] is kv_cache[0]
+    assert ctx['decoder.layers.0.self_attn.attn'].kv_cache[0] is kv_cache[0]
+
+
+def test_bind_kv_cache_pp():
+    cfg = VllmConfig(parallel_config=ParallelConfig(pipeline_parallel_size=2))
+    with set_current_vllm_config(cfg):
+        from vllm.attention import Attention
+
+        ctx = {
+            'layers.0.self_attn': Attention(32, 128, 0.1),
+        }
+    kv_cache = [
+        [torch.zeros((1, ))],
+        [torch.zeros((1, ))]
+    ]
+    bind_kv_cache(ctx, kv_cache)
+    assert ctx['layers.0.self_attn'].kv_cache[0] is kv_cache[0][0]
+    assert ctx['layers.0.self_attn'].kv_cache[1] is kv_cache[1][0]
3 changes: 1 addition & 2 deletions vllm/worker/hpu_worker.py
@@ -208,8 +208,7 @@ def _init_cache_engine(self):
         assert self.cache_config.num_gpu_blocks is not None
         self.cache_engine = [
             HPUCacheEngine(self.cache_config, self.model_config,
-                           self.parallel_config, self.device_config,
-                           self.compilation_config)
+                           self.parallel_config, self.device_config)
             for _ in range(self.parallel_config.pipeline_parallel_size)
         ]
         self.hpu_cache = [
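
Note on the bind_kv_cache change exercised by the tests above: the second argument is now a nested list indexed as kv_cache[virtual_engine][layer_index], and each attention layer's kv_cache attribute becomes a list with one entry per pipeline-parallel virtual engine. The snippet below is a minimal sketch of that binding, inferred from the test assertions only; it is not the vllm.utils implementation, _extract_layer_index and _FakeAttention are hypothetical stand-ins, and the non-attention and encoder/decoder cases covered by the other tests are omitted.

# Sketch only: inferred from the assertions in the tests above, not copied
# from vllm.utils.bind_kv_cache. Assumes layer names look like
# 'layers.<i>.self_attn' or 'model.layers.<i>.attn'.
import re
from typing import Dict, List

import torch


def _extract_layer_index(layer_name: str) -> int:
    # Hypothetical helper: pull the numeric layer index out of the module name.
    match = re.search(r'layers\.(\d+)\.', layer_name)
    assert match is not None, f"no layer index in {layer_name!r}"
    return int(match.group(1))


def bind_kv_cache_sketch(ctx: Dict[str, object],
                         kv_cache: List[List[torch.Tensor]]) -> None:
    # kv_cache is indexed as kv_cache[virtual_engine][layer_index]; each layer
    # ends up with one cache entry per pipeline-parallel virtual engine.
    layer_names = sorted(ctx, key=_extract_layer_index)
    for layer_index, layer_name in enumerate(layer_names):
        ctx[layer_name].kv_cache = [
            engine_caches[layer_index] for engine_caches in kv_cache
        ]


# Example mirroring test_bind_kv_cache_pp: two virtual engines, one layer.
class _FakeAttention:  # stand-in for vllm.attention.Attention in this sketch
    pass


ctx = {'layers.0.self_attn': _FakeAttention()}
kv_cache = [[torch.zeros((1, ))], [torch.zeros((1, ))]]
bind_kv_cache_sketch(ctx, kv_cache)
assert ctx['layers.0.self_attn'].kv_cache[0] is kv_cache[0][0]
assert ctx['layers.0.self_attn'].kv_cache[1] is kv_cache[1][0]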
