From 56198b45d9712bdbb161d226f94b4647738d33f5 Mon Sep 17 00:00:00 2001 From: Lianmin Zheng Date: Mon, 16 Dec 2024 18:49:02 -0800 Subject: [PATCH] Add a benchmark script for in-batch prefix caching (#2494) --- .../bench_in_batch_prefix.py | 130 ++++++++++++++++++ python/sglang/srt/managers/schedule_policy.py | 86 ++++++------ 2 files changed, 177 insertions(+), 39 deletions(-) create mode 100644 benchmark/bench_in_batch_prefix/bench_in_batch_prefix.py diff --git a/benchmark/bench_in_batch_prefix/bench_in_batch_prefix.py b/benchmark/bench_in_batch_prefix/bench_in_batch_prefix.py new file mode 100644 index 00000000000..86648e5ff17 --- /dev/null +++ b/benchmark/bench_in_batch_prefix/bench_in_batch_prefix.py @@ -0,0 +1,130 @@ +# Benchmark with lots of common prefixes. Used to benchmark prefix caching performance. +# +# Launch a server: +# python -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port 30000 --log-level-http warning + +import random +import string +import time + +from tqdm import tqdm +from transformers import AutoTokenizer + +import sglang as sgl +from sglang import set_default_backend +from sglang.lang.backend.runtime_endpoint import RuntimeEndpoint + + +def generate_random_string(token_length: int) -> str: + random_string = "".join( + random.choices(string.ascii_letters + string.digits, k=token_length * 100) + ) + tokenized_output = tokenizer.encode(random_string, add_special_tokens=False)[ + :token_length + ] + + if len(tokenized_output) < token_length: + tokenized_output = tokenized_output + [tokenizer.pad_token_id] * ( + token_length - len(tokenized_output) + ) + + decoded_string = tokenizer.decode(tokenized_output, skip_special_tokens=False) + return decoded_string + + +def generate_unique_prefix(base_text, index): + return str(index) + base_text[len(str(index)) :] + + +@sgl.function +def text_qa(s, question, gen_len): + s += "Q: " + question + "\n" + s += "A:" + sgl.gen("answer", stop="\n", temperature=0, max_tokens=gen_len) + + +def prepare_prompts(num_prefix, num_samples_per_prefix, prefix_length, suffix_length): + base_prefix = generate_random_string(prefix_length) + + tot_input_len = 0 + all_prompts = [] + for i in tqdm(range(num_prefix), desc="prepare prompts"): + unique_prefix = generate_unique_prefix(base_prefix, i) + prompt_list = [] + for j in range(num_samples_per_prefix): + suffix = generate_random_string(suffix_length) + prompt = unique_prefix + suffix + prompt_list.append(prompt) + tot_input_len += len(tokenizer.encode(prompt)) + all_prompts.append(prompt_list) + return all_prompts, tot_input_len + + +def test_batch_by_batch(all_prompts, gen_len): + backend.flush_cache() + + tot_time = 0 + for i in range(len(all_prompts)): + tic = time.time() + text_qa.run_batch( + list(zip(all_prompts[i], [gen_len] * len(all_prompts[i]))), + ) + tot_time += time.time() - tic + + return tot_time + + +def test_batch_by_batch_with_hint(all_prompts, gen_len): + backend.flush_cache() + + tot_time = 0 + for i in range(len(all_prompts)): + tic = time.time() + # Send a hint to cache the prefix + text_qa.run_batch(list(zip(all_prompts[i][:1], [gen_len]))) + # Send the batch + text_qa.run_batch(list(zip(all_prompts[i], [gen_len] * len(all_prompts[i])))) + + tot_time += time.time() - tic + + return tot_time + + +def test_send_all(all_prompts, gen_len): + backend.flush_cache() + + all_prompts = [x for prompt_list in all_prompts for x in prompt_list] + + tic = time.time() + text_qa.run_batch( + list(zip(all_prompts, [gen_len] * len(all_prompts))), + ) + tot_time = 
time.time() - tic + + return tot_time + + +if __name__ == "__main__": + tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer") + backend = RuntimeEndpoint("http://127.0.0.1:30000") + set_default_backend(backend) + + random.seed(0) + num_prefix = 10 + num_samples_per_prefix = 32 + prefix_length = 1024 + suffix_length = 128 + gen_len = 1 + all_prompts, tot_input_len = prepare_prompts( + num_prefix, num_samples_per_prefix, prefix_length, suffix_length + ) + + print(f"Total input token length: {tot_input_len}\n") + + cost = test_batch_by_batch(all_prompts, gen_len) + print(f"Latency of test_batch_by_batch : {cost:.4f} s\n") + + cost = test_batch_by_batch_with_hint(all_prompts, gen_len) + print(f"Latency of test_batch_by_batch_with_hint: {cost:.4f} s\n") + + cost = test_send_all(all_prompts, gen_len) + print(f"Latency of test_send_all : {cost:.4f} s\n") diff --git a/python/sglang/srt/managers/schedule_policy.py b/python/sglang/srt/managers/schedule_policy.py index 1bb872fdf71..70af61ddb6f 100644 --- a/python/sglang/srt/managers/schedule_policy.py +++ b/python/sglang/srt/managers/schedule_policy.py @@ -34,11 +34,19 @@ os.environ.get("SGLANG_CLIP_MAX_NEW_TOKENS_ESTIMATION", "4096") ) -# The threshold to apply in-batch prefix caching. -# If we use too small value, in-batch prefix caching cannot be used. E.g., -# imagine "the" prefix. -IN_BATCH_PREFIX_CACHING_THRESHOLD = int( - os.environ.get("SGLANG_IN_BATCH_PREFIX_CACHING_THRESHOLD", "32") +# Threshold for in-batch prefix cache. +# If a request has a matched prefix length (against existing cache) less than this value, +# the scheduler runs the in-batch prefix caching check for this request. +# If we set it to -1, it means we disable in-batch prefix caching. +IN_BATCH_PREFIX_CACHING_CHECK_THRESHOLD = int( + os.environ.get("IN_BATCH_PREFIX_CACHING_CHECK_THRESHOLD", "32") +) + +# Threshold for in-batch prefix cache. +# If a request has a matched prefix length (within the waiting queue) larger than this value, +# the scheduler deprioritizes this request +IN_BATCH_PREFIX_CACHING_DEPRIORITIZE_THRESHOLD = int( + os.environ.get("IN_BATCH_PREFIX_CACHING_DEPRIORITIZE_THRESHOLD", "32") ) @@ -51,6 +59,11 @@ def __init__(self, policy: str, tree_cache: BasePrefixCache): self.policy = policy self.tree_cache = tree_cache + # It is used to find the matching prefix for in-batch prefix caching. + self.waiting_queue_radix_tree = RadixCache( + req_to_token_pool=None, token_to_kv_pool=None, disable=False + ) + def calc_priority(self, waiting_queue: List[Req]): if len(waiting_queue) > 128 and self.policy == "lpm": # Turn off the expensive prefix matching and sorting when the #queue is large. @@ -60,50 +73,54 @@ def calc_priority(self, waiting_queue: List[Req]): # Compute matched prefix length prefix_computed = False - # rid to deprioritize in the current run. - temporary_deprioritized = {} if policy == "lpm" or policy == "dfs-weight": - # It is used to find the matching prefix for in-batch prefix caching. - temp_radix = RadixCache(None, None, False) + # rid to deprioritize in the current run for in-batch prefix caching. 
+ temporary_deprioritized = set() + self.waiting_queue_radix_tree.reset() + for r in waiting_queue: prefix_ids = r.adjust_max_prefix_ids() + # NOTE: the prefix_indices must always be aligned with last_node r.prefix_indices, r.last_node = self.tree_cache.match_prefix( rid=r.rid, key=prefix_ids ) - # NOTE(sang): This logic is for In-batch prefix caching; + # NOTE(sang): This logic is for in-batch prefix caching; # If there are more than 1 request that have small matching prefix from # existing cache, but all those requests share the same prefix, we prefer # to schedule only one of them so that we can increase the cache hit rate. - # We prefer to set IN_BATCH_PREFIX_CACHING_THRESHOLD > 0 because too small + # We prefer to set IN_BATCH_PREFIX_CACHING_CHECK_THRESHOLD > 0 because too small # threshold means we cannot use in-batch prefix caching for short prefixes. - # It is kind of common when the engine is long running (e.g., imagine "the"). - if len(r.prefix_indices) <= IN_BATCH_PREFIX_CACHING_THRESHOLD: - in_batch_matching_prefixes, _ = temp_radix.match_prefix( - rid=r.rid, key=prefix_ids + # It is kind of common when the engine is long running (e.g., imagine the prefix "the"). + if len(r.prefix_indices) <= IN_BATCH_PREFIX_CACHING_CHECK_THRESHOLD: + in_batch_matching_prefixes, _ = ( + self.waiting_queue_radix_tree.match_prefix( + rid=r.rid, key=prefix_ids + ) ) if ( len(in_batch_matching_prefixes) - >= IN_BATCH_PREFIX_CACHING_THRESHOLD + >= IN_BATCH_PREFIX_CACHING_DEPRIORITIZE_THRESHOLD ): - temporary_deprioritized[r.rid] = r + temporary_deprioritized.add(r.rid) else: - temp_radix.insert(prefix_ids, torch.tensor(prefix_ids)) + # Insert with a dummy key + self.waiting_queue_radix_tree.insert( + prefix_ids, torch.empty(len(prefix_ids), dtype=torch.bool) + ) prefix_computed = True if policy == "lpm": # Longest Prefix Match - def get_priority(r: Req): - score = 0 - if r.rid in temporary_deprioritized: - score = float("inf") - else: - score = -len(r.prefix_indices) - return score - - waiting_queue.sort(key=get_priority) + waiting_queue.sort( + key=lambda r: ( + -len(r.prefix_indices) + if r.rid not in temporary_deprioritized + else float("inf") + ) + ) elif policy == "fcfs": # first come first serve pass @@ -113,11 +130,11 @@ def get_priority(r: Req): elif policy == "random": random.shuffle(waiting_queue) elif policy == "dfs-weight": + # Experimental policy based on custom weights last_node_to_reqs = defaultdict(list) for req in waiting_queue: last_node_to_reqs[req.last_node].append(req) - # node -> # of requests for that node. 
node_to_weight = defaultdict(int) for node in last_node_to_reqs: node_to_weight[node] = len(last_node_to_reqs[node]) @@ -129,9 +146,7 @@ def get_priority(r: Req): node_to_weight, last_node_to_reqs, waiting_queue, - temporary_deprioritized, ) - waiting_queue.extend(temporary_deprioritized.values()) else: raise ValueError(f"Unknown schedule_policy: {policy=}") @@ -148,19 +163,12 @@ def get_dfs_priority( node_to_priority: Dict[TreeNode, int], last_node_to_reqs: Dict[TreeNode, List[Req]], q: List, - temporary_deprioritized: Dict[str, Req], ): childs = [child for child in cur_node.children.values()] childs.sort(key=lambda x: -node_to_priority[x]) for child in childs: - self.get_dfs_priority( - child, node_to_priority, last_node_to_reqs, q, temporary_deprioritized - ) - - for req in last_node_to_reqs[cur_node]: - if req.rid in temporary_deprioritized: - continue - q.append(req) + self.get_dfs_priority(child, node_to_priority, last_node_to_reqs, q) + q.extend(last_node_to_reqs[cur_node]) class AddReqResult(Enum):
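
The sketch below is a minimal, standalone illustration (not SGLang code) of the two-threshold decision this patch adds to the schedule policy: a request whose match against the existing tree cache is at most IN_BATCH_PREFIX_CACHING_CHECK_THRESHOLD is compared against the prefixes of earlier requests in the same waiting queue, and deprioritized if that in-batch match reaches IN_BATCH_PREFIX_CACHING_DEPRIORITIZE_THRESHOLD. The helper names (should_deprioritize, longest_common_prefix) and the plain-list stand-in for the waiting-queue radix tree are illustrative assumptions only; the real implementation uses RadixCache, Req.prefix_indices, and Req.rid.

# Standalone sketch of the in-batch prefix caching decision (illustrative only).

CHECK_THRESHOLD = 32         # mirrors IN_BATCH_PREFIX_CACHING_CHECK_THRESHOLD
DEPRIORITIZE_THRESHOLD = 32  # mirrors IN_BATCH_PREFIX_CACHING_DEPRIORITIZE_THRESHOLD


def longest_common_prefix(a, b):
    # Stand-in for a radix-tree prefix match over token id lists.
    n = 0
    for x, y in zip(a, b):
        if x != y:
            break
        n += 1
    return n


def should_deprioritize(prefix_ids, cached_match_len, waiting_prefixes):
    # cached_match_len plays the role of len(r.prefix_indices) from the tree cache.
    if cached_match_len > CHECK_THRESHOLD:
        # Already well covered by the existing cache; schedule normally.
        return False
    best = max(
        (longest_common_prefix(prefix_ids, p) for p in waiting_prefixes),
        default=0,
    )
    if best >= DEPRIORITIZE_THRESHOLD:
        # Another queued request shares this prefix; let it populate the cache first.
        return True
    # Otherwise register this prefix so later requests in the queue can match it.
    waiting_prefixes.append(prefix_ids)
    return False


if __name__ == "__main__":
    waiting = []
    shared = list(range(100))      # 100 shared prefix token ids
    req_a = shared + [1000, 1001]
    req_b = shared + [2000, 2001]
    print(should_deprioritize(req_a, 0, waiting))  # False: first request keeps its priority
    print(should_deprioritize(req_b, 0, waiting))  # True: shares a long in-batch prefix with req_a

Note that in the patch itself deprioritized requests are not dropped: under the "lpm" policy their sort key becomes float("inf"), so they move to the back of the waiting queue and are scheduled once the shared prefix has been cached by an earlier request.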