Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Core] Sliding window for block manager v2 #4545

Merged
merged 58 commits into from
May 28, 2024

Conversation

mmoskal
Copy link
Contributor

@mmoskal mmoskal commented May 2, 2024

This implements sliding window in v2 block manager.

First commit comes from #3967 by @ruthe98, but the actual change was somewhat more complex including the concept of a null block.

It passes correctness tests with starcoder3b (the smallest model with sliding window I could find). The test does a bunch of assignments "x1 = 10; x2 = 33; ..." and then asks for value of one of them (which is outside the sliding window). If we tell it upfront which we are going to be looking for, then it answers correctly.

When using chunked prefill all the blocks for prompt are allocated immediately, while we could only allocate enough blocks for the chunk, and free any blocks that are no longer needed. After processing the prompt however, it does free the beginning of prompt at the first generation step.

This can be fixed later. The main problem with fixing this, is that if we're generating more than one sequence, they are all forked in BlockSpaceManagerV2.allocate(), but they really should only be forked after the prompt is fully computed. (see aborted attempt at fixing this)

CC @cadedaniel @ruthe98

FIX #3665
FIX #4057

BEFORE SUBMITTING, PLEASE READ THE CHECKLIST BELOW AND FILL IN THE DESCRIPTION ABOVE


PR Checklist (Click to Expand)

Thank you for your contribution to vLLM! Before submitting the pull request, please ensure the PR meets the following criteria. This helps vLLM maintain the code quality and improve the efficiency of the review process.

PR Title and Classification

Only specific types of PRs will be reviewed. The PR title is prefixed appropriately to indicate the type of change. Please use one of the following:

  • [Bugfix] for bug fixes.
  • [CI/Build] for build or continuous integration improvements.
  • [Doc] for documentation fixes and improvements.
  • [Model] for adding a new model or improving an existing model. Model name should appear in the title.
  • [Frontend] For changes on the vLLM frontend (e.g., OpenAI API server, LLM class, etc.)
  • [Kernel] for changes affecting CUDA kernels or other compute kernels.
  • [Core] for changes in the core vLLM logic (e.g., LLMEngine, AsyncLLMEngine, Scheduler, etc.)
  • [Hardware][Vendor] for hardware-specific changes. Vendor name should appear in the prefix (e.g., [Hardware][AMD]).
  • [Misc] for PRs that do not fit the above categories. Please use this sparingly.

Note: If the PR spans more than one category, please include all relevant prefixes.

Code Quality

The PR need to meet the following code quality standards:

  • We adhere to Google Python style guide and Google C++ style guide.
  • Pass all linter checks. Please use format.sh to format your code.
  • The code need to be well-documented to ensure future contributors can easily understand the code.
  • Include sufficient tests to ensure the project to stay correct and robust. This includes both unit tests and integration tests.
  • Please add documentation to docs/source/ if the PR modifies the user-facing behaviors of vLLM. It helps vLLM user understand and utilize the new features or changes.

Notes for Large Changes

Please keep the changes as concise as possible. For major architectural changes (>500 LOC excluding kernel/data/config/test), we would expect a GitHub issue (RFC) discussing the technical design and justification. Otherwise, we will tag it with rfc-required and might not go through the PR.

What to Expect for the Reviews

The goal of the vLLM team is to be a transparent reviewing machine. We would like to make the review process transparent and efficient and make sure no contributor feel confused or frustrated. However, the vLLM team is small, so we need to prioritize some PRs over others. Here is what you can expect from the review process:

  • After the PR is submitted, the PR will be assigned to a reviewer. Every reviewer will pick up the PRs based on their expertise and availability.
  • After the PR is assigned, the reviewer will provide status update every 2-3 days. If the PR is not reviewed within 7 days, please feel free to ping the reviewer or the vLLM team.
  • After the review, the reviewer will put an action-required label on the PR if there are changes required. The contributor should address the comments and ping the reviewer to re-review the PR.
  • Please respond to all comments within a reasonable time frame. If a comment isn't clear or you disagree with a suggestion, feel free to ask for clarification or discuss the suggestion.

Thank You

Finally, thank you for taking the time to read these guidelines and for your interest in contributing to vLLM. Your contributions make vLLM a great tool for everyone!

@cadedaniel
Copy link
Collaborator

  • FYI prefix caching + sliding window doesn't work with block manager v1 yet. so it's ok to prioritize that deprioritize that for this PR

    if enable_caching and sliding_window is not None:
    raise NotImplementedError(
    "Sliding window is not allowed with prefix caching enabled!")

  • personally I think it's OK to allocate entire prompt len for chunked prefill. it's a compute optimization, not a memory capacity optimization, after all

  • for testing, we can copy this test except run a sliding window model. it's important to go over the sliding window boundary (although in general the sliding window test coverage is pretty poor). I think one can mock the sliding window size to be much smaller for test convenience, FYI.

    def test_v1_v2_greedy_equality_with_preemption(baseline_llm_generator,

@mmoskal mmoskal marked this pull request as ready for review May 3, 2024 20:57
@mmoskal
Copy link
Contributor Author

mmoskal commented May 3, 2024

@cadedaniel @rkooo567 @simon-mo this should be ready for review

@cadedaniel
Copy link
Collaborator

WTAL on Monday

@rkooo567
Copy link
Collaborator

rkooo567 commented May 4, 2024

QQ: is this the last feature that's needed before enabling block manager v2?

@cadedaniel
Copy link
Collaborator

See #4537

@rkooo567 rkooo567 self-assigned this May 4, 2024
@rkooo567
Copy link
Collaborator

rkooo567 commented May 4, 2024

Hmm maybe I can help getting cpu swapping done

@mmoskal
Copy link
Contributor Author

mmoskal commented May 6, 2024

@cadedaniel let me know if you need any more info from my side!

Copy link
Collaborator

@cadedaniel cadedaniel left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Did a pass over the correctness test. I'll do a pass over the implementation now.

tests/core/block/e2e/test_correctness_sliding_window.py Outdated Show resolved Hide resolved
tests/core/block/e2e/test_correctness_sliding_window.py Outdated Show resolved Hide resolved
@pytest.mark.parametrize("seed", [1])
def test_sliding_window_chunked_prefill(test_llm_generator, batch_size, seed):
sampling_params = SamplingParams(
max_tokens=10,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm afraid these tests won't catch issues with the block mapping. E.g. I expect error to accumulate over many tokens before we see a significant divergence in attention scores. 10/4096 tokens is not very much, same for 128/4096 although it's better.

WDYT? Is my intuition right? Should we test with larger generation size? Another option is to patch sliding_window to be smaller (e.g. two blocks) so the impact of any error is larger. If we go with patching sliding_window we could even use one of the 68m models for a faster test.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

+1.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There was an issue with the block tables passed to single token decode, which was causing different output with 1024 tokens. I have now fixed that and bumped test size to 1024.

However, it's still slightly incorrect because the decode kernel does not support sliding window natively - the way it works now it just takes all the blocks passed in (up to seq_len). With v1 manager, the sliding window uses blocks in a "ring buffer" fashion, so this is not a problem. With the new block manager we need potentially to start attention computation in the middle of a block, otherwise we pay attention to a few tokens too many. It doesn't seem to affect this test though.

I have started fixing the decode kernel, but I think that should be a separate PR.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I haven't looked at the changes but want to say that yes, this problem is known and we should fix it eventually (awesome if you want to do it). let's get this PR in with good tests for where we're at and future PR can fix the decode kernel.

tests/core/block/e2e/test_correctness_sliding_window.py Outdated Show resolved Hide resolved
"""
Generate prompts which a bunch of assignments,
then asking for the value of one of them.
The prompt is just under 10k tokens; sliding window is 4k
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

any good way to assert the return prompt token len > 4k?

Copy link
Collaborator

@cadedaniel cadedaniel left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM. thanks.

# however, https://github.com/vllm-project/vllm/issues/3385#issuecomment-1995924290
# states that xformers and flash_attn have different ideas about the window
# size anyways
assert sum(cmp) > 0.7 * len(cmp)
Copy link
Collaborator

@cadedaniel cadedaniel May 23, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yeah, the reason why this is so hard to test is because the semantics of blocks are changed in V2. E.g. we now have clear distinction between mutable and immutable blocks. so the kernels that previously would overwrite blocks (causing U.B. with copy-on-write in V1) now don't, but the downside is they capture additional context since we don't yet have masking.

The tests in this PR are not really good enough to catch correctness issues with the sliding window block mapping. The error tolerance in this test is very high and the unit test test_sliding_window only checks num consumed is correct.

That said, this is an improvement over the previous sliding window tests and I think we can merge and follow up later..

FWIW the way I'd test this is one/both of the following:

  • Modify this test to use block_size=1. This avoids the masking issue entirely and we should expect exact equality between v1 and v2 for most prompts.
  • Add a stronger unit test for block_manager_v2 or block_allocator that verifies the correct sliding window block mapping

@cadedaniel
Copy link
Collaborator

@mmoskal can you check the merge conflict? will merge after

@mmoskal
Copy link
Contributor Author

mmoskal commented May 24, 2024

Unfortunately, after merge the tests stopped working. The problem is that it's also the baseline tests (not using v2 block manager) that are not working. I'm getting the first token of the output correct, and the remaining tokens not complete gibberish but also not correct - so this is a problem with the decode phase or maybe kv cache entry arrangement?

I tried reverting all my changes in model_runner.py and the baseline tests still fail, which suggests it's something in the recent changes.

@mmoskal
Copy link
Contributor Author

mmoskal commented May 25, 2024

OK should work now - I fixed slot_mapping computation in _prepare_model_input

@cadedaniel
Copy link
Collaborator

as soon are tests are green let's merge. cc @rkooo567 for next week.

@mmoskal
Copy link
Contributor Author

mmoskal commented May 25, 2024

the failing tests don't look related to what I'm doing; I just tried pushing a random change to re-run

@mmoskal
Copy link
Contributor Author

mmoskal commented May 27, 2024

@rkooo567 @cadedaniel tests are green, please merge!

@rkooo567 rkooo567 merged commit d4f3985 into vllm-project:main May 28, 2024
63 checks passed
@rkooo567
Copy link
Collaborator

Thanks for the contribution! Should we next resume the paged attn PR?

@mmoskal
Copy link
Contributor Author

mmoskal commented May 28, 2024

Thank you for merging! I probably won't have time to work on the paged attn kernel PR in the next few weeks :/ The thing is, with this PR the paged attention is almost correct, it just pays attention to a few tokens too many.

dtrifiro pushed a commit to opendatahub-io/vllm that referenced this pull request May 31, 2024
robertgshaw2-redhat pushed a commit to neuralmagic/nm-vllm that referenced this pull request Jun 8, 2024
joerunde pushed a commit to joerunde/vllm that referenced this pull request Jun 17, 2024
robertgshaw2-redhat pushed a commit to neuralmagic/nm-vllm that referenced this pull request Jul 14, 2024
Temirulan pushed a commit to Temirulan/vllm-whisper that referenced this pull request Sep 6, 2024
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
3 participants