[RFC]: BatchLLM for better shared prefix utilization in offline scenarios #12080

Open
xinji1 opened this issue Jan 15, 2025 · 8 comments

xinji1 commented Jan 15, 2025

Motivation.

This request is mainly for offline inference scenarios, based on the paper BatchLLM.

TL;DR: Currently, vLLM performs implicit (or just-in-time) shared-prefix identification and metadata collection, and then performs cascade attention when there is one single shared prefix for all requests, according to PR #11635. However, it does not fully utilize the shared prefix in offline scenarios, where there are many requests with different shared prefixes. This RFC tries to alleviate the following pain points of vLLM's inference.

  • Point 1: Currently, vLLM's inference with prefix caching and cascade attention cannot gather all requests with the same common prefix together (this is essential, since all query tokens with the same common prefix have to be treated as if they came from the same request for the attention calculation).

  • Point 2: In offline scenarios, it is not necessary to perform implicit shared-prefix identification, since all requests are ready before inference starts. Implicit prefix caching is not the best way to manage the kv-cache of shared-prefix tokens.

  • Point 3: vLLM's cascade attention cannot support different common prefixes for different requests in one batch.

How BatchLLM tries to alleviate them

  • For Point 1, one simple and easy way is to sort all samples in a dataset (e.g., with Python's built-in sorted()) before inference starts. Here we try to gather the requests with the same prefix together, identify the shared prefixes of different requests explicitly, and enlarge the shared prefix as much as possible.

  • For Point 2, we introduce the concept of a "prefix-sharing group", where a mini-set of requests shares the same common prefix. If the original requests look like:

a,b,c,d,e, f,g,h...
{---X---},{--Y1--}

a,b,c,d,e, i,j,k...
{---X---},{--Y2--}

...

here is what a prefix-sharing group looks like:

List[List, List[List]]: [X, [Y1, Y2,...]]

where we put the common prefix X as the first element, followed by a list of all the non-shared contexts Y1 & Y2. If there are 2 requests sharing the same prefix, we separate them into the common prefix and the 2 non-shared contexts. In this way, vLLM handles 1+2=3 requests: BatchLLM first runs inference and saves the kv-cache of the common prefix (as a single request without any decoding), then generates tokens for the 2 non-shared contexts/requests. Finally, once inference for all requests in a prefix-sharing group is done, the kv-cache of the common part is released. (A minimal sketch of the grouping step follows this list.)

  • For Point 3, based on the above changes, it is much easier to collect the meta-info related to cascade attention. However, the current flash-attn kernels cannot support the case where there are different common prefixes for different requests in one single batch. We have implemented a Triton version of the common/distinct/merge_2 kernels (similar to PR #11635), showing good performance even with some extra Triton overhead.
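To make the grouping concrete, here is a minimal sketch of the explicit grouping step (sort the requests, then greedily merge neighbours that share a sufficiently long prefix). The function names and the min_shared_len threshold are our own illustrative choices, not the actual BatchLLM code.

from typing import List, Tuple

def common_prefix_len(a: List[int], b: List[int]) -> int:
    # Length of the longest common token prefix of two requests.
    n = 0
    for x, y in zip(a, b):
        if x != y:
            break
        n += 1
    return n

def build_prefix_sharing_groups(requests: List[List[int]],
                                min_shared_len: int = 256) -> List[Tuple[List[int], List[List[int]]]]:
    # Sort requests so that shared prefixes become adjacent, then greedily group
    # consecutive requests whose common prefix is at least min_shared_len tokens.
    # Each group is (shared prefix X, [non-shared contexts Y1, Y2, ...]);
    # a single-member group degenerates to an ordinary request.
    order = sorted(range(len(requests)), key=lambda i: requests[i])
    groups: List[Tuple[List[int], List[List[int]]]] = []
    i = 0
    while i < len(order):
        prefix = list(requests[order[i]])
        members = [order[i]]
        j = i + 1
        while j < len(order):
            l = common_prefix_len(prefix, requests[order[j]])
            if l < min_shared_len:
                break
            prefix = prefix[:l]          # shrink X to the prefix shared by all members
            members.append(order[j])
            j += 1
        groups.append((prefix, [requests[k][len(prefix):] for k in members]))
        i = j
    return groups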

See below for the performance improvement

  • model: llama-3.1-8b
  • GPU: single A100
  • setting:
    • CUDA graph and multi-step decoding are disabled
    • for the vLLM baseline, chunked-prefill (max_tokens in one batch is 2048) & prefix caching are both enabled; cascade inference is enabled by default in v1 after PR [V1] Implement Cascade Attention #11635 (we found that it requires VLLM_USE_V1, so we added that experiment as well)
    • 6400 requests; every 400 of them share the same common prefix
    • each request has 2200 tokens: the shared prefix is 2000 tokens and the non-shared context is 200 tokens
  • result:

setting | throughput (requests/s)
--- | ---
vLLM + chunked-prefill + prefix-caching, after v0.6.6.post1 (commit: 5340a30) | 6.62
vLLM + chunked-prefill + prefix-caching + python.sorted(), after v0.6.6.post1 | 13.17
vLLM + chunked-prefill + prefix-caching + python.sorted(), after v0.6.6.post1, VLLM_USE_V1=1 | 10.78
Our implementation based on v0.6.4 | 18.01
  • Results under different sharing degrees and different shared-prefix lengths (the following tests are based on vLLM v0.6.4 and SGLang v0.4.1):

[Figures: throughput under different sharing degrees and shared-prefix length settings, comparing our implementation with vLLM v0.6.4 and SGLang v0.4.1]

Proposed Change.

High level

  1. A preprocessing step that builds the "prefix-sharing groups": BatchLLM gathers the requests with the same prefix together, identifies the shared prefixes of different requests explicitly, and enlarges the shared prefix as much as possible. We put the preprocessing code in llm.py.
  2. A new manager for the shared-prefix / non-shared-context requests. For example, it releases all blocks of a shared prefix once all requests in its prefix-sharing group have finished inference (a rough sketch of such a manager follows this list).
  3. A new backend based on FlashAttnBackend, as suggested by the reviewer here. It needs to collect the meta-info of the different prefix-sharing groups and perform the attention calculation. Currently we use the Triton kernels we've implemented.
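A rough, hypothetical sketch of how small the group manager in item 2 could be (names are illustrative; free_blocks stands in for whatever block-release hook the kv-cache manager exposes, not an actual vLLM API):

class PrefixGroupManager:
    # Tracks, per prefix-sharing group, how many non-shared context requests are
    # still running, and frees the shared prefix's kv blocks when the last one finishes.
    def __init__(self, free_blocks):
        self._free_blocks = free_blocks   # callback that releases the prefix's kv blocks
        self._remaining = {}              # group_id -> number of unfinished non-shared requests
        self._prefix_req = {}             # group_id -> request id of the shared-prefix request

    def register_group(self, group_id, prefix_req_id, num_members):
        self._prefix_req[group_id] = prefix_req_id
        self._remaining[group_id] = num_members

    def on_request_finished(self, group_id):
        # Called whenever one non-shared context request of the group finishes.
        self._remaining[group_id] -= 1
        if self._remaining[group_id] == 0:
            self._free_blocks(self._prefix_req.pop(group_id))
            del self._remaining[group_id]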

Feedback Period.

No response

CC List.

@WoosukKwon
@comaniac
@pavanimajety

Any Other Things.

  • the script used in Motivation

The command used is: python vllm_baseline.py --model_path /workspace/llama3_8b --request_num 6400 --context_len 2000 --prompt_len 200 --generate_len 100 --sharing_degree 16 --rand --use_prefix_caching --chunk

import torch
from vllm import LLM, SamplingParams
from transformers import AutoTokenizer
import time
import datetime
from pandas import read_table
import math
import argparse

import os

os.environ['CUDA_VISIBLE_DEVICES'] = '0'
os.environ['TOKENIZERS_PARALLELISM'] = 'false'

import random
random.seed(123)
torch.manual_seed(123)

def parse_args():
    parser = argparse.ArgumentParser(description='vLLM performance test')
    # argparse's type=bool treats any non-empty string as True; use a flag instead
    parser.add_argument('--use_cuda_graph', action="store_true",
                        help='use_cuda_graph')
    parser.add_argument('--model_path', type=str, default='/workspace/llama-3.1-8b',
                        help='model_path')
    parser.add_argument('--request_num', type=int, default=6400,
                        help='request_num')
    parser.add_argument('--context_len', type=int, default=2000,
                        help='context_len')
    parser.add_argument('--prompt_len', type=int, default=200,
                        help='prompt_len')
    parser.add_argument('--generate_len', type=int, default=100,
                        help='generate_len')
    parser.add_argument('--use_prefix_caching', action="store_true",
                        help='use_prefix_caching')
    parser.add_argument('--chunk', action="store_true",
                        help='chunk')
    parser.add_argument('--rand', action="store_true",
                        help='whether the input is randomly shuffled')
    parser.add_argument('--sharing_degree', type=int, default=2,
                        help='how many requests share the same prefix')
    parser.add_argument('--random_generate', action="store_true",
                        help='not ignore eos')

    return parser.parse_args()


def prepare_tokens(tokenizer, context_len, prompt_len, group_num, batch_size):
    # Build group_num groups of batch_size requests each: every group shares a random
    # common prefix of context_len tokens, followed by a distinct prompt of prompt_len
    # tokens per request.
    share = []
    all_t = []
    group_idx = []
    for i in range(group_num):
        context_this_group = torch.randint(1, 20000, (context_len,))
        share.append(context_this_group.tolist())
        for j in range(batch_size):
            prompt_this_request = torch.randint(1, 20000, (prompt_len,))
            all_t.append(
                torch.concat((context_this_group[0:context_len], prompt_this_request[0:prompt_len]), 0).tolist())
            group_idx.append(i)
    return all_t, group_num, group_num * batch_size, share, group_idx


def build_pipeline(engine_path, prefix_caching, chunk=False):
    kwargs = {}
    if "70B" in engine_path or "70b" in engine_path:
        # 70B models: use the GPTQ-quantized checkpoint and cap the context length.
        print("70B")
        kwargs = {"quantization": "gptq",
                  "max_model_len": 8192}
        # os.environ["VLLM_ALLOW_LONG_MAX_MODEL_LEN"] = "1"

    if chunk:
        pipe = LLM(model=engine_path,
                   dtype="float16",
                   tensor_parallel_size=1,
                   enable_prefix_caching=prefix_caching,
                   enforce_eager=True,
                   disable_sliding_window=True,
                   enable_chunked_prefill=True,
                   max_num_batched_tokens=2048,
                   **kwargs,
                   )
    else:
        if "llama" not in engine_path:
            kwargs = {
                **kwargs,
                "max_num_batched_tokens": 32768,
            }
        pipe = LLM(model=engine_path,
                   dtype="float16",
                   tensor_parallel_size=1,
                   enable_prefix_caching=prefix_caching,
                   enforce_eager=True,
                   disable_sliding_window=True,
                   **kwargs,
                   )
    return pipe


def run_pipeline(input_tokens, pipe, group_id_list=None, max_tokens=100, random_generate=False):
    sampling_params = SamplingParams(temperature=0.01,
                                     top_p=0.1,
                                     max_tokens=max_tokens,
                                     ignore_eos=(not random_generate)
                                     )
    t1 = time.time()
    output = pipe.generate(prompt_token_ids=input_tokens,
                           sampling_params=sampling_params
                           )
    t2 = time.time()
    return output, t2 - t1


def prepare_baseline_caption_token(tokenizer, file_path):
    # Read the caption file
    f = read_table(file_path)
    f_shape = f.shape
    tokens = []
    for i in range(f_shape[0]):
        token1 = tokenizer.encode(f.iloc[i, 7], add_special_tokens=False, return_tensors='pt')
        token2 = tokenizer.encode(f.iloc[i, 8], add_special_tokens=False, return_tensors='pt')
        tokens.append(torch.concat((token1[0, :], token2[0, :]), 0).tolist())
    return tokens, f_shape[0], f_shape[0]


args = parse_args()
batch_size_list = [int(args.sharing_degree)]
engine_path = args.model_path
request_num = args.request_num
context_len = args.context_len
prompt_len = args.prompt_len
generate_len = args.generate_len
use_cuda_graph = args.use_cuda_graph
prefix_caching = args.use_prefix_caching
tokenizer = AutoTokenizer.from_pretrained(engine_path)

file_name = f'baseline_066_random_{args.rand}_sd_{args.sharing_degree}_baseline_{request_num}_{context_len}_{prompt_len}_{generate_len}_chunk_{args.chunk}_{args.model_path}_prefix_caching_{str(prefix_caching)}.txt'.replace("/","_")
f = open(file_name, 'w')
print(f'////////////////////////////', file=f)
print(f' {datetime.datetime.now()} vllm performace test', file=f)

pipe = build_pipeline(engine_path, prefix_caching, args.chunk)
print(f'engine_path:{engine_path},prefix_caching:{prefix_caching}', file=f)
print(f'group_num\tprompt_num\ttime\tthroughput', file=f)


# benchmark test
for batch_size in batch_size_list:
    group_num = request_num // batch_size
    input_tokens, group_num, prompt_num, share_tokens, group_idx = prepare_tokens(tokenizer, context_len, prompt_len,
                                                                                  group_num, batch_size)
    if args.rand:
        random.shuffle(input_tokens)
    
    gen_share_time = 0

    final_output, gen_time = run_pipeline(input_tokens, pipe, group_idx, generate_len, args.random_generate)


    print(f'{group_num:9}\t{prompt_num:10}\t{gen_time:8.2f}\t{prompt_num / (gen_time):10.2f}')
    print(f'{group_num:9}\t{prompt_num:10}\t{gen_time:8.2f}\t{prompt_num / (gen_time):10.2f}', file=f)


@comaniac
Collaborator

Thanks for the RFC; this is indeed a useful feature for batch inference with common prefixes. For the proposed changes:

  1. A preprocessing step that builds the "prefix-sharing groups": BatchLLM gathers the requests with the same prefix together, identifies the shared prefixes of different requests explicitly, and enlarges the shared prefix as much as possible. We put the preprocessing code in llm.py.

I feel we could wrap a new engine called BatchLLM to cover all offline batching use cases. BatchLLM could wrap AsyncLLM with the preprocessing logic you mentioned. It could then be initialized by LLM based on a certain configuration. This design should have better modularization, and could consolidate all possible batch inference optimizations in the future.
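Roughly the layering I have in mind, as a purely illustrative sketch (the class, the preprocess hook, and engine_generate are hypothetical names; engine_generate stands in for whatever AsyncLLM/LLM entry point the wrapper ends up delegating to):

class BatchLLM:
    # Thin offline-batching wrapper: all batch-level preprocessing (e.g. prefix-sharing
    # group building) lives here, and the underlying engine stays unchanged.
    def __init__(self, engine_generate, preprocess):
        self._generate = engine_generate
        self._preprocess = preprocess

    def generate(self, requests, sampling_params):
        prepared = self._preprocess(requests)   # reorder / group the offline requests
        return self._generate(prepared, sampling_params)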

  2. A new manager for the shared-prefix / non-shared-context requests. For example, it releases all blocks of a shared prefix once all requests in its prefix-sharing group have finished inference.

I'm not sure we need an additional manager. The current kv-cache manager with prefix caching should be able to achieve this goal, and it would be ideal not to introduce another layer of complexity if we don't have to. If your desired functionality can be achieved by slightly improving the current kv-cache manager, we should consider that first.

  3. A new backend based on FlashAttnBackend, as suggested by the reviewer in Adding cascade inference to vLLM #10011 (comment). It needs to collect the meta-info of the different prefix-sharing groups and perform the attention calculation. Currently we use the Triton kernels we've implemented.

Shouldn't this be achievable using FlashInfer?

@xinji1
Author

xinji1 commented Jan 16, 2025

Hi @comaniac, thanks for your comments!

  1. Currently we implement BatchLLM based on LLMEngine. We're not quite sure whether the scheduling of prefix-sharing groups would be compatible with AsyncLLM. Let me illustrate the scheduling here.

When dealing with multiple requests, some of which share a common prefix, it is important to keep the order of outputs aligned with the order of inputs.

Consider the following requests:

  • A: [x, y1]
  • B: [x, y2]
  • C: [c]

Requests A and B share the same prefix x, while C is a single request.

BatchLLM builds a prefix-sharing group for A and B, extracting the shared prefix x as a separate request. This results in four requests in total:

  • Shared prefix request: [x]
  • Non-shared context requests: [y1], [y2]
  • Single request: [c] (without any shared prefix)

Besides the queue self.waiting in the file vllm/core/scheduler.py, we need two more queues, self.non_shared_ready and self.non_shared_waiting. Before BatchLLM inference starts:

  • The shared prefix request and single request are placed into the self.waiting queue of the scheduler. For example, self.waiting would be [[x], [c]]
  • All non-shared context requests are placed into the self.non_shared_waiting queue. For example, self.non_shared_waiting would be [[y1], [y2]]

When inference for one shared prefix request is completed:

  • Non-shared context requests for this prefix-sharing group are moved to the self.non_shared_ready queue of the scheduler.
  • All requests in the self.non_shared_ready queue at this step are prioritized over requests in the self.waiting queue.
    Like:
        # self.waiting:            [  [c]       ] -> [ [c]        ] -> [ [y1], [y2], [c]]
        #                                                               |
        #                                                               |
        # self.non_shared_ready:   [            ] -> [ [y1], [y2] ] -> []       
        #                                             |
        #                                             |
        # self.non_shared_waiting: [ [y1], [y2] ] -> [            ] -> []

We're not quite sure whether this scheduling also works in async scenarios; testing is ongoing.
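A condensed sketch of the queue bookkeeping above (the queue names mirror the proposal; everything else about the scheduler is omitted, and the request/group attributes are illustrative):

from collections import deque

class BatchLLMQueues:
    def __init__(self):
        self.waiting = deque()             # shared-prefix requests and single requests
        self.non_shared_waiting = deque()  # non-shared contexts whose prefix is not ready yet
        self.non_shared_ready = deque()    # non-shared contexts whose prefix kv-cache is ready

    def on_shared_prefix_done(self, group_id):
        # The shared-prefix request of group_id finished its prefill:
        # move that group's non-shared contexts so they get scheduled next.
        still_waiting = deque()
        for req in self.non_shared_waiting:
            (self.non_shared_ready if req.group_id == group_id else still_waiting).append(req)
        self.non_shared_waiting = still_waiting

    def next_request(self):
        # Ready non-shared contexts are prioritized over the ordinary waiting queue.
        if self.non_shared_ready:
            return self.non_shared_ready.popleft()
        return self.waiting.popleft() if self.waiting else None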

  2. I think it's ok. What the new manager does is maintain some dictionaries mapping request_ids between shared-prefix requests and non-shared context requests. Some block-freeing functions are also involved.

  3. Some slight changes to the FlashAttention/FlashInfer kernels are needed (just some parameters, not the computation logic). The point is that not all requests in one batch need cascade attention (e.g., requests without any shared prefix). Take the kernel vllm.vllm_flash_attn.flash_attn_varlen_func as an example: we need to change the parameter cu_seqlens_q into seqlens_q and q_start_loc, so that some requests can be skipped in the attention calculation over the shared prefix. (Updated on 01/17/2025) For requests without any shared prefix, we can set the length of the shared-prefix kv to 0. This may add some overhead for launching redundant blocks, but it should be acceptable compared with the performance improvement.
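For reference, the metadata change amounts to roughly the following bookkeeping (a sketch only, not the kernel; the tensor names are ours, not vllm_flash_attn's):

import torch

# Instead of the cumulative cu_seqlens_q, keep per-request seqlens_q plus an explicit
# q_start_loc, so that requests whose shared-prefix kv length is 0 can be no-ops in the
# shared-prefix (common) attention pass.
cu_seqlens_q  = torch.tensor([0, 200, 400, 450])      # 3 requests in the batch
seqlens_q     = cu_seqlens_q[1:] - cu_seqlens_q[:-1]  # [200, 200, 50]
q_start_loc   = cu_seqlens_q[:-1]                     # [0, 200, 400]
shared_kv_len = torch.tensor([2000, 2000, 0])         # the 3rd request has no shared prefix
# A request with shared_kv_len == 0 contributes nothing to the common pass; its blocks
# just exit early (some launch overhead, but no change to the results).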

@xinji1
Author

xinji1 commented Jan 16, 2025

cc @youkaichao @zhuohan123 @simon-mo for more comments.

@comaniac
Collaborator

I understand what you did to the requests.

  1. The logic of figuring out the shared prefix and generating the common-prefix request can be done outside of the scheduler.
  2. I'd highly recommend making your approach compatible with async, because continuous batching is also beneficial in offline batching.
  3. I'm a bit concerned about adding more queues to the scheduler, because it makes the scheduler more complicated and not all workloads need this optimization. In this case, it's ideal to figure out a way to isolate this logic and only enable it by configuration. There are two options I can think of atm:
    1. Make the scheduler pluggable, so that we could have a "GroupPrefixSharedScheduler", but this requires refactoring the scheduler significantly to extract common logic and design a unified interface. In the long term this might be the right way to go, because I can imagine other optimizations (e.g., inference-time reasoning) may also want to introduce more scheduler queues to schedule requests in a certain way. However, I can also imagine this would take a long time.
    2. Control the request order outside the engine. This cannot achieve the optimal performance you benchmarked, but should still be decent (not sure how good or how bad it would be, though). This may be a short-term solution that can be done easily first.

For the attention kernel, @WoosukKwon should comment more, but in general it would be better if we could avoid introducing an attention backend for a particular functionality, because it will come with lots of follow-up issues, such as whether the backend supports other hardware and whether it supports OO and XX features (e.g., chunked prefill, XXAttention, etc.).

@xinji1
Author

xinji1 commented Jan 17, 2025

@comaniac Thanks for your advice!

I'd highly recommend making your approach compatible with async, because continuous batching is also beneficial in offline batching.

Do you mean online scenarios or offline scenarios? If offline, I think our implementation with chunked-prefill support does not involve any notion of "async". If online, one potential scenario is that a batch of requests with a shared prefix arrives at the same time; then we could let BatchLLM perform the group-building process, otherwise it's quite difficult.

I'm a bit concerned about adding more queues to the scheduler, because it makes the scheduler more complicated and not all workloads need this optimization. In this case, it's ideal to figure out a way to isolate this logic and only enable it by configuration. There are two options I can think of atm:
Make the scheduler pluggable, so that we could have a "GroupPrefixSharedScheduler", but this requires refactoring the scheduler significantly to extract common logic and design a unified interface. In the long term this might be the right way to go, because I can imagine other optimizations (e.g., inference-time reasoning) may also want to introduce more scheduler queues to schedule requests in a certain way. However, I can also imagine this would take a long time.
Control the request order outside the engine. This cannot achieve the optimal performance you benchmarked, but should still be decent (not sure how good or how bad it would be, though). This may be a short-term solution that can be done easily first.

We'll test:

  • A: only use the queue self.waiting to keep the shared-prefix / non-shared context / single requests, to see if it can maintain the output order (since some operations like swapping the shared-prefix request may affect the order).
  • B: if the issue cannot be alleviated by method A, we'll try to control the request order outside the engine. BTW, could changing the scheduler's self.waiting from other classes (in the prefix-caching manager, for example) be regarded as "outside the engine"? If not, is there any implementation controlling the request order outside the engine that we could take as a reference?

For the backend part, our implementation is currently based on FlashAttentionBackend and supports chunked-prefill mode. Further support such as CUDA graphs & multi-step decoding won't be included in the first version of BatchLLM.

@heheda12345
Collaborator

How would this feature extend to multi-modal inputs & LoRA? LoRA should be easy, but it will be difficult to get the common prefix length of multi-modal inputs before the MultiModalProcessor.

@comaniac
Collaborator

Could changing the scheduler's self.waiting from other classes (in the prefix-caching manager, for example) be regarded as "outside the engine"? If not, is there any implementation controlling the request order outside the engine that we could take as a reference?

We have priority scheduling (in v0, not v1), but I'm not sure whether you could leverage that entirely to achieve your goal. Other than that, there's no way to add another queue outside the scheduler, because the scheduler is supposed to be in charge of "scheduling" requests.

@xinji1
Author

xinji1 commented Jan 21, 2025

How would this feature extend to multi-modal inputs & LoRA? LoRA should be easy, but it will be difficult to get the common prefix length of multi-modal inputs before the MultiModalProcessor.

@heheda12345 For now we have to get all input tokens up front so that we can identify the common prefix.

We have priority scheduling (in v0, not v1), but I'm not sure whether you could leverage that entirely to achieve your goal. Other than that, there's no way to add another queue outside the scheduler, because the scheduler is supposed to be in charge of "scheduling" requests.

@comaniac After measuring the effort, it's quite difficult for us to run BatchLLM without extra queues for the non-shared requests. Under chunked-prefill scenarios, with

Shared prefix request: [x]
Non-shared context requests: [y1], [y2]
Single request: [c] (without any shared prefix)

the self.waiting queue will be [[x], [y1], [y2], [c]]. We always want to fill the batch with shared-prefix and/or single requests, but that requires skipping the non-shared context requests in order to gather the requests [x] and [c] together. We hold the view that:

  • Making the scheduler pluggable (as you mentioned) could be the best way to implement BatchLLM's scheduling; we'll take it as a TODO for the future.

  • Another way to alleviate it: extract all shared-prefix requests and compute their kv-caches first, then run the non-shared context / single requests (roughly the two-phase driver sketched below).

    • Pros: it doesn't add any extra queue.
    • Cons: not all kv-caches of shared-prefix requests can be kept in the block table because of limited GPU memory, so how many shared-prefix requests to extract first has to be taken into account, and the strategy for releasing the kv-cache of shared-prefix requests is sub-optimal.
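The second option boils down to a two-phase driver, roughly as below (a sketch against the public LLM API with prefix caching enabled; the max_tokens=1 calls are only an approximation of a prefill-only pass, and real code would also have to respect the GPU-memory limit when deciding how many prefixes to warm at once):

from vllm import LLM, SamplingParams

def two_phase_generate(llm, shared_prefixes, full_requests, params):
    # Phase 1: populate the prefix cache with the shared prefixes (prefill only,
    # approximated here by generating a single token per prefix).
    warmup = SamplingParams(max_tokens=1)
    llm.generate(prompt_token_ids=shared_prefixes, sampling_params=warmup)
    # Phase 2: the full requests now hit the cached shared-prefix blocks.
    return llm.generate(prompt_token_ids=full_requests, sampling_params=params)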

We'll create our PR ASAP for further advice.
