Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Bugfix] Move the _touch(computed_blocks) call in the allocate_slots method to after the check for allocating new blocks. #11565

Merged
merged 6 commits into from
Dec 31, 2024
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 61 additions & 2 deletions tests/v1/core/test_prefix_caching.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,9 +98,9 @@ def test_prefill():
# Incomplete 1 block (6 tokens)
unique_token_ids = [3] * 6
req2 = make_request("2", common_token_ids + unique_token_ids)
computed_block = manager.get_computed_blocks(req2)
computed_blocks = manager.get_computed_blocks(req2)
assert len(req2.kv_block_hashes) == 3
assert [b.block_id for b in computed_block] == [0, 1, 2]
assert [b.block_id for b in computed_blocks] == [0, 1, 2]
num_new_tokens = 53 - 3 * 16
blocks = manager.allocate_slots(req2, num_new_tokens, computed_blocks)
assert [b.block_id for b in blocks] == [7, 8]
Expand Down Expand Up @@ -500,3 +500,62 @@ def test_mm_prefix_caching():
mm_hashes=mm_hashes)
computed_blocks = manager.get_computed_blocks(req1)
assert len(computed_blocks) == 3


def test_prefill_not_enough_free_blocks_with_computed_blocks():
"""
This is a unit test that tests the correctness of the allocate_slots
when there is not enough free blocks. Specifically, when a request
has computed blocks but cannot be allocated due to not enough free blocks,
the computed blocks should not be touched.
"""
block_size = 16
manager = KVCacheManager(
block_size=block_size,
num_gpu_blocks=10,
max_model_len=8192,
sliding_window=None,
enable_caching=True,
num_preallocate_tokens=0,
)
# Complete 3 blocks (48 tokens)
# | Common-0 | Common-1 | Common-2 | ... |
common_token_ids = [i for i in range(3) for _ in range(16)]
req0 = make_request("0", common_token_ids)
computed_blocks = manager.get_computed_blocks(req0)
assert not computed_blocks
manager.allocate_slots(req0, 48, computed_blocks)
block_part0 = manager.req_to_blocks[req0.request_id]

# | Common-0 | Common-1 | Common-2 | Req1-3 | Req1-4 | Req1-5 | ... |
req1 = make_request("1", common_token_ids * 2)
computed_blocks = manager.get_computed_blocks(req1)
assert computed_blocks == block_part0
manager.allocate_slots(req1, 48, computed_blocks)
block_part1 = manager.req_to_blocks[req1.request_id]
# | Common-0 | Common-1 | Common-2 | Req1-3 (F) | Req1-4 (F) |
# | Req1-5(F)| ... |
manager.free(req1)
assert {block.ref_cnt for block in block_part1[:3]} == {1}
assert {block.ref_cnt for block in block_part1[3:]} == {0}

# | Common-0 | Common-1 | Common-2 | Req1-3 (F) | Req1-4 (F) |
# | Req1-5(F)| Req2-0 | Req2-1 | ... |
req2 = make_request("2", [7] * block_size * 2)
computed_blocks = manager.get_computed_blocks(req2)
assert not computed_blocks
manager.allocate_slots(req2, block_size * 2, computed_blocks)

# Req3 is Req2 + 3 new blocks, so the first 6 blocks are computed,
# but it cannot be allocated due to insufficient free blocks (2).
# In this case, the ref_cnt of the computed blocks should not be changed.
assert manager.free_block_queue.num_free_blocks == 5
req3 = make_request("3", common_token_ids * 3)
computed_blocks = manager.get_computed_blocks(req3)
assert computed_blocks == block_part1
# Req3 cannot be allocated.
assert manager.allocate_slots(req3, 48, computed_blocks) is None
# Block 0-2 are used by Req 1.
assert {block.ref_cnt for block in block_part1[:3]} == {1}
# Block 3-5 are free.
assert {block.ref_cnt for block in block_part1[3:]} == {0}
19 changes: 13 additions & 6 deletions vllm/v1/core/kv_cache_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,7 +191,7 @@ def allocate_slots(
request: The request to allocate slots.
num_tokens: The number of tokens to allocate. Note that this does
not include the tokens that have already been computed.
computed_blocks: The blocks that have already been computed.
computed_blocks: A list of computed blocks.

Returns:
A list of new allocated blocks.
Expand All @@ -200,6 +200,18 @@ def allocate_slots(
raise ValueError(
f"num_tokens must be greater than 0, got {num_tokens}")

# If a computed block of a request is an eviction candidate (in the
# free queue and ref_cnt == 0), it cannot be counted as a free block
# when allocating this request.
num_evictable_computed_blocks = len(
[blk for blk in computed_blocks if blk.ref_cnt == 0])
comaniac marked this conversation as resolved.
Show resolved Hide resolved
comaniac marked this conversation as resolved.
Show resolved Hide resolved

num_required_blocks = cdiv(num_tokens, self.block_size)
if (num_required_blocks > self.free_block_queue.num_free_blocks -
num_evictable_computed_blocks):
# Cannot allocate new blocks.
return None

# Touch the computed blocks to make sure they won't be evicted.
if self.enable_caching:
self._touch(computed_blocks)
Expand All @@ -208,11 +220,6 @@ def allocate_slots(
"Computed blocks should be empty when "
"prefix caching is disabled")

num_required_blocks = cdiv(num_tokens, self.block_size)
if (num_required_blocks > self.free_block_queue.num_free_blocks):
# Cannot allocate new blocks.
return None

# Determine the number of new blocks to allocate considering
# preallocated blocks.
num_new_blocks = min(
Expand Down
Loading