[RFC]: BatchLLM for better shared prefix utilization in offline scenarios #12080
Comments
Thanks for the RFC; this is indeed a useful feature for batch inference with a common prefix. For the proposed changes:
I feel we could wrap a new engine called
I'm not sure if we need an additional manager. The current kv-cache manager with prefix caching should be able to achieve this goal. It would be ideal not to introduce another layer of complexity if we don't have to. If your desired functionality can be achieved by slightly improving the current kv-cache manager, we should consider that first.
This should be achievable using FlashInfer?
Hi @comaniac, thanks for your comments!
When dealing with multiple requests, some of which share a common prefix, it is important to maintain the order of outputs (aligned with the order of inputs). Consider the following requests:
Requests BatchLLM builds a prefix-sharing group for
Besides the queue
When inference for one shared prefix request is completed:
We're not quite sure whether the scheduling is OK when it comes to async scenarios; the test is ongoing.
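As a rough sketch of the bookkeeping described above (the names and structure are assumptions, not the actual BatchLLM implementation): each prefix-sharing group tracks which non-shared requests are still pending, and the shared-prefix KV cache is released only once the whole group has finished.

```python
from dataclasses import dataclass, field


@dataclass
class PrefixSharingGroup:
    shared_prefix_tokens: list[int]   # tokens of the common prefix X
    pending_request_ids: set[str] = field(default_factory=set)

    def mark_finished(self, request_id: str) -> bool:
        """Mark one non-shared request as done; return True when the whole
        group is complete and the shared-prefix KV cache can be freed."""
        self.pending_request_ids.discard(request_id)
        return not self.pending_request_ids


def on_request_finished(groups, group_id, request_id, free_prefix_kv_cache):
    # Outputs are still emitted in the original input order elsewhere; this
    # helper only decides when the shared-prefix blocks may be released.
    group = groups[group_id]
    if group.mark_finished(request_id):
        free_prefix_kv_cache(group.shared_prefix_tokens)
        del groups[group_id]
```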
cc @youkaichao @zhuohan123 @simon-mo for more comments.
I understand what you did to the requests.
For the attention kernel, @WoosukKwon should comment more, but in general it would be better to avoid introducing an attention backend for a particular functionality, because it will come with lots of follow-up issues, such as whether this backend supports other hardware, and whether it supports this or that feature (e.g., chunked prefill, some particular attention variant, etc.).
@comaniac Thanks for your advice!
You mean the application for
We'll test
For the backend part, our implementation is currently based on
How can this feature be extended to multi-modal inputs & LoRA? LoRA should be easy, but it will be difficult to get the common prefix length of multi-modal inputs before
We have priority scheduling (in v0, not v1), but I'm not sure if you could leverage that entirely to achieve your goal. Other than that, there's no way to add another queue outside the scheduler, because the scheduler is supposed to be in charge of "scheduling" requests.
@heheda12345 For now, we have to get all input tokens at the beginning so that we can identify the common prefix.
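For illustration, a minimal sketch of how the common prefix could be identified and requests grouped once all input token IDs are available up front (the function names, grouping policy, and threshold are assumptions, not BatchLLM's actual code):

```python
def common_prefix_len(a: list[int], b: list[int]) -> int:
    """Length of the longest common token-ID prefix of two requests."""
    n = 0
    for x, y in zip(a, b):
        if x != y:
            break
        n += 1
    return n


def group_by_prefix(token_lists: list[list[int]],
                    min_prefix_len: int = 32) -> list[list[int]]:
    """Greedily group pre-sorted requests whose common prefix with the
    group's first request is long enough to be worth sharing.
    Returns lists of request indices (illustrative policy only)."""
    groups: list[list[int]] = []
    for i, toks in enumerate(token_lists):
        if groups and common_prefix_len(
                token_lists[groups[-1][0]], toks) >= min_prefix_len:
            groups[-1].append(i)
        else:
            groups.append([i])
    return groups
```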
@comaniac After measuring the effort, it's quite difficult for us to run BatchLLM without extra queues holding the non-shared requests. Under chunked-prefill scenarios,
The
We'll create our PR ASAP for further advice.
Motivation.
This request is mainly for offline inference scenarios, based on the paper BatchLLM.
TL;DR: Currently, vLLM performs implicit (or just-in-time) shared-prefix identification and metadata collection, and then performs cascade attention when there is one single shared prefix for all requests, according to PR #11635. However, it does not fully utilize the shared prefix in offline scenarios, where there are many requests with different shared prefixes. This PR tries to alleviate the following pain points of vLLM's inference.
Point 1: Currently, vLLM's inference with prefix caching and cascade attention cannot gather all requests with the same common prefix together (this is essential, since all query tokens with the same common prefix have to be treated as if they came from the same request for the attention calculation).
Point 2: Under offline scenarios, it's not necessary to perform implicit shared-prefix identification, since all requests are ready before inference starts. Implicit prefix caching is not the best way to manage the KV cache of shared prefix tokens.
Point 3: vLLM's cascade attention cannot support different common prefixes for different requests in one batch.
How BatchLLM tries to alleviate them
For Point 1, one simple and easy way is to sort all samples in a dataset (e.g., with Python's sorted()) before inference starts. Here we try to gather the requests with the same prefix together, identify the shared prefixes of different requests explicitly, and enlarge the shared prefix as much as possible.
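A minimal sketch of this sorting step, assuming plain string prompts (the example data and variable names are illustrative, not part of the actual BatchLLM code):

```python
prompts = [
    "Document A ... Question: what is X?",
    "Document B ... Question: what is Y?",
    "Document A ... Question: what is Z?",
]

# Lexicographic sorting places requests that share a prefix next to each
# other; remember the original positions so the outputs can be restored
# to the input order afterwards.
order = sorted(range(len(prompts)), key=lambda i: prompts[i])
sorted_prompts = [prompts[i] for i in order]
```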
For Point 2, we introduce the concept of a "prefix-sharing group", where a mini-set of requests shares the same common prefix. If the original requests look like X + Y1 and X + Y2 (two requests sharing the common prefix X), a prefix-sharing group looks like [X, Y1, Y2],
where we put the common prefix X as the first element, followed by the non-shared contexts Y1 & Y2. If there are 2 requests sharing the same prefix, we separate them into the common prefix and the 2 non-shared contexts. In this way, vLLM handles 1 + 2 = 3 requests: BatchLLM first runs inference for and saves the KV cache of the common prefix (as a single request without any decoding), then generates tokens for the other 2 non-shared contexts/requests. Finally, when inference for all requests in a prefix-sharing group is done, the KV cache of the common part is released.
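As a concrete illustration of this layout, assuming plain text prompts rather than the real request objects (X, Y1, Y2 are made-up contents):

```python
# Two original requests that share a long common prefix X.
X = "You are a helpful assistant. Here is a long shared document ..."
Y1 = " Question 1: summarize the document."
Y2 = " Question 2: list the key entities."

original_requests = [X + Y1, X + Y2]

# The prefix-sharing group: the shared prefix X becomes one prefill-only
# request whose KV cache is kept, and Y1 / Y2 are the non-shared parts
# that decode on top of it, so vLLM handles 1 + 2 = 3 items in total.
prefix_sharing_group = [X, Y1, Y2]
```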
See below for the performance improvement. Cascade inference (v1) is enabled by default after the PR [V1] Implement Cascade Attention #11635 (we found that it needs VLLM_USE_V1, so we added that experiment too).

Proposed Change.
High level
llm.py
.FlashAttnBackend
, according to the reviewer here. Here we need to collect the meta-info of the different prefix-sharing groups and perform the attention calculation. Currently we use the Triton kernels we've implemented.

Feedback Period.
No response
CC List.
@WoosukKwon
@comaniac
@pavanimajety
Any Other Things.
The command is:
python vllm_baseline.py --model_path /workspace/llama3_8b --request_num 6400 --context_len 2000 --prompt_len 200 --generate_len 100 --sharing_degree 16 --rand --use_prefix_caching --chunk