[V1][Prefix Cache] Move the logic of num_computed_tokens into KVCacheManager #12003

Merged
4 commits merged into vllm-project:main on Jan 15, 2025

Conversation

heheda12345 (Collaborator) commented on Jan 13, 2025

This PR is a preparation for the hybrid memory allocator (Ref: #11382).

  1. Move the logic of full cache hit into KVCacheManager to prepare for supporting sliding window.
    For sliding window, if the sliding window size is 2 and block_size is 1, then when the computed blocks are [1,2,x,4,5], where x is an evicted block, the computed blocks should be truncated to [1,2] instead of [1,2,x,4]. This can't be achieved with the current computed_blocks.pop(), and will be implemented by introducing a new HybridKVCacheManager in a follow-up PR.
  2. Move num_computed_tokens into KVCacheManager. When there are multiple block_tables (possibly with different block_size, to support Jamba), num_computed_tokens becomes more complex than num_computed_tokens = len(computed_blocks) * self.block_size (see the sketch after this list).
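
As a rough illustration of the new contract, here is a toy sketch (simplified stand-in classes, not the actual vLLM code):

from typing import Dict, List, Tuple

class KVCacheBlock:
    """Toy stand-in for vLLM's KVCacheBlock, only to make the sketch runnable."""
    def __init__(self, block_id: int):
        self.block_id = block_id

class ToyKVCacheManager:
    def __init__(self, block_size: int):
        self.block_size = block_size
        # Placeholder for the real prefix cache: block hash -> cached block.
        self.cached_blocks: Dict[int, KVCacheBlock] = {}

    def get_computed_blocks(
            self, block_hashes: List[int]) -> Tuple[List[KVCacheBlock], int]:
        """Return the longest cached prefix and the number of tokens it covers."""
        computed_blocks: List[KVCacheBlock] = []
        for block_hash in block_hashes:
            cached = self.cached_blocks.get(block_hash)
            if cached is None:
                break  # the cached prefix ends at the first miss
            computed_blocks.append(cached)
        # With a single block table this is a plain multiplication; a future
        # HybridKVCacheManager can compute it per block table / block size.
        num_computed_tokens = len(computed_blocks) * self.block_size
        return computed_blocks, num_computed_tokens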

CC @comaniac

Signed-off-by: Chen Zhang <[email protected]>

👋 Hi! Thank you for contributing to the vLLM project.
Just a reminder: PRs do not trigger a full CI run by default. Instead, only the fastcheck CI runs, which covers a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of those by going to your fastcheck build on the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can do one of these:

  • Add the ready label to the PR
  • Enable auto-merge.

🚀

Comment on lines 97 to 105
if len(block_hashes) * self.block_size == request.num_tokens:
    # When prompt length is divisible by the block size and all blocks
    # are cached, we need to recompute the last token. This has to be
    # achieved by re-computing an entire block because allocate_slots()
    # assumes num_computed_tokens is always a multiple of the block
    # size. This limitation can potentially be removed in the future to
    # slightly improve the performance. To achieve this, the last block
    # is removed from the computed block_hashes.
    block_hashes = ConstantList(block_hashes[:-1])
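
To make the condition concrete, a small worked example (hypothetical numbers, not from the PR):

block_size = 16
num_tokens = 32          # prompt length, fully cached
num_blocks = 2           # len(block_hashes)
assert num_blocks * block_size == num_tokens
num_blocks -= 1          # drop the last block hash
num_computed_tokens = num_blocks * block_size  # 16: the last block is recomputed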
Collaborator

I was making a similar change, but @WoosukKwon prefers to have it in the scheduler, because this limitation is in the model runner instead of the KV cache manager.

Collaborator Author

Do you have any suggestions on supporting sliding window? I don't want to put the complex special handling of sliding window in the scheduler.

Collaborator

What handling do you plan to add for sliding window in the scheduler? In general we should aim for modularization. For this particular logic, I think we could just do:

class Scheduler:
    ...

    def maybe_recompute_last_block(self, computed_blocks, num_computed_tokens):
        ...

    def schedule(self):
        ...
        computed_blocks, num_computed_tokens = self.kv_cache_manager.get_computed_blocks(...)
        computed_blocks, num_computed_tokens = self.maybe_recompute_last_block(computed_blocks, num_computed_tokens)
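
A hypothetical body for such a helper, mirroring the logic quoted above (illustrative only, assuming the scheduler has access to the block size; this interface was not adopted in the end):

    def maybe_recompute_last_block(self, request, computed_blocks,
                                   num_computed_tokens):
        # If the cached blocks cover the entire prompt, drop the last block so
        # that at least one token is recomputed: allocate_slots() assumes
        # num_computed_tokens is always a multiple of the block size.
        if num_computed_tokens == request.num_tokens:
            computed_blocks = computed_blocks[:-1]
            num_computed_tokens -= self.block_size
        return computed_blocks, num_computed_tokens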

Collaborator Author

This interface is not very friendly to sliding window, since removing the last block would require redoing get_computed_blocks. I'm happy to discuss which way is better after implementing maybe_recompute_last_block for sliding window.
Is it OK to only change computed_blocks = self.kv_cache_manager.get_computed_blocks(request) to computed_blocks, num_computed_tokens = self.kv_cache_manager.get_computed_blocks(request) in this PR?

Collaborator

Sure, I'm OK with it.

Signed-off-by: Chen Zhang <[email protected]>
Signed-off-by: Chen Zhang <[email protected]>
heheda12345 changed the title from "[V1][Prefix Cache] Move the logic of full cache hit and num_computed_tokens into KVCacheManager" to "[V1][Prefix Cache] Move the logic of num_computed_tokens into KVCacheManager" on Jan 14, 2025
heheda12345 (Collaborator Author)

@comaniac I've modified this PR to only change num_computed_tokens. Could you take a look?

comaniac (Collaborator) left a comment

Otherwise LGTM

Comment on lines +72 to +73
def get_computed_blocks(
        self, request: Request) -> Tuple[List[KVCacheBlock], int]:
Collaborator

Please update the docstring about the return type.

Signed-off-by: Chen Zhang <[email protected]>
heheda12345 (Collaborator Author)

Thanks for pointing it out. I've updated the docstring.
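
For reference, a plausible shape for the updated docstring (a sketch, not necessarily the exact merged text):

    def get_computed_blocks(
            self, request: Request) -> Tuple[List[KVCacheBlock], int]:
        """Get the computed (cached) blocks for the request.

        Args:
            request: The request to get the computed blocks for.

        Returns:
            A tuple of:
                - The computed blocks that are cached for the request.
                - The number of computed tokens covered by those blocks.
        """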

comaniac added the ready label (ONLY add when PR is ready to merge / full CI is needed) on Jan 15, 2025
comaniac enabled auto-merge (squash) on January 15, 2025 at 03:54
comaniac merged commit 994fc65 into vllm-project:main on Jan 15, 2025
63 of 64 checks passed
ice-tong pushed a commit to ice-tong/vllm that referenced this pull request Jan 18, 2025
joennlae pushed a commit to 44ai-labs/vllm that referenced this pull request Jan 19, 2025
jikunshang pushed a commit to jikunshang/vllm that referenced this pull request Jan 21, 2025
abmfy pushed a commit to abmfy/vllm-flashinfer that referenced this pull request Jan 24, 2025