With V1 and prefix_caching enabled, when the scheduler allocates slots for a request in the waiting queue, `allocate_slots` calls `_touch(computed_blocks)` right at the top, incrementing the `ref_cnt` of every computed block by 1. If `free_block_queue.num_free_blocks` is insufficient to allocate the new blocks, the function returns `None` and the scheduling loop breaks. At that point the request stays in the waiting queue, but the `ref_cnt` of all blocks in `computed_blocks` has already been incremented and is never decremented. This repeats on every scheduling step, eventually leaving some blocks with a `ref_cnt` greater than 0 even though no request is using them.
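To make the failure mode concrete, here is a minimal, self-contained sketch (a toy `Block` class and a stand-in `allocate_slots_toy`, not vLLM code) of how touching the computed blocks on every failed scheduling attempt inflates `ref_cnt` when allocation keeps failing:

```python
# Toy model of the leak (illustrative only; names are made up, not vLLM APIs).
from dataclasses import dataclass


@dataclass
class Block:
    block_id: int
    ref_cnt: int = 0


def allocate_slots_toy(computed_blocks: list[Block], num_free_blocks: int,
                       num_required_blocks: int) -> list[Block] | None:
    # Touch first, exactly like the buggy path: ref_cnt goes up...
    for blk in computed_blocks:
        blk.ref_cnt += 1
    if num_required_blocks > num_free_blocks:
        # ...but on failure we return None without undoing the touch.
        return None
    return [Block(block_id=100 + i) for i in range(num_required_blocks)]


computed = [Block(block_id=i) for i in range(3)]
# The scheduler retries the same waiting request on every step while the
# pool is exhausted; each retry bumps ref_cnt again.
for _ in range(5):
    assert allocate_slots_toy(computed, num_free_blocks=0,
                              num_required_blocks=4) is None
print([blk.ref_cnt for blk in computed])  # [5, 5, 5] -- yet nothing holds these blocks
```

The relevant code paths in the scheduler and the KV cache manager: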
```python
# =============== scheduler ===============
request = self.waiting[0]
computed_blocks = self.kv_cache_manager.get_computed_blocks(request)
...
new_blocks = self.kv_cache_manager.allocate_slots(
    request, num_new_tokens, computed_blocks)
if new_blocks is None:
    break
self.waiting.popleft()
```
```python
# =============== kv_cache_manager ===============
def allocate_slots(
    self,
    request: Request,
    num_tokens: int,
    computed_blocks: List[KVCacheBlock],
) -> Optional[List[KVCacheBlock]]:
    """Allocate slots for a new request.

    Args:
        request: The request to allocate slots.
        num_tokens: The number of tokens to allocate. Note that this does
            not include the tokens that have already been computed.
        computed_blocks: The blocks that have already been computed.

    Returns:
        A list of new allocated blocks.
    """
    if num_tokens == 0:
        raise ValueError(
            f"num_tokens must be greater than 0, got {num_tokens}")

    # Touch the computed blocks to make sure they won't be evicted.
    num_evictable_computed_blocks = 0
    if self.enable_caching:
        self._touch(computed_blocks)

        # If a computed block of a request is an eviction candidate (in the
        # free queue and ref_cnt == 0), it cannot be counted as a free block
        # when allocating this request.
        num_evictable_computed_blocks = len(
            [blk for blk in computed_blocks if blk.ref_cnt == 0])
    else:
        assert not computed_blocks, (
            "Computed blocks should be empty when "
            "prefix caching is disabled")

    num_required_blocks = cdiv(num_tokens, self.block_size)
    if (num_required_blocks > self.free_block_queue.num_free_blocks -
            num_evictable_computed_blocks):
        # Cannot allocate new blocks.
        return None
```
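One possible mitigation, sketched here purely for illustration (this is not the actual upstream fix), is to do the free-block check before `_touch`, so the early `return None` never leaves an extra reference behind. In terms of the toy model above:

```python
# Sketch only: check capacity first, touch only when allocation will succeed.
def allocate_slots_checked(computed_blocks: list[Block], num_free_blocks: int,
                           num_required_blocks: int) -> list[Block] | None:
    # A failed attempt now leaves ref_cnt untouched.
    if num_required_blocks > num_free_blocks:
        return None
    # Only pin the computed blocks once success is guaranteed.
    for blk in computed_blocks:
        blk.ref_cnt += 1
    return [Block(block_id=200 + i) for i in range(num_required_blocks)]


computed = [Block(block_id=i) for i in range(3)]
for _ in range(5):
    assert allocate_slots_checked(computed, num_free_blocks=0,
                                  num_required_blocks=4) is None
print([blk.ref_cnt for blk in computed])  # [0, 0, 0] -- no leak across retries
```

An equivalent alternative would be to keep the touch where it is but decrement the counts again on the failure path before returning `None`.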