diff --git a/python/sglang/srt/managers/schedule_policy.py b/python/sglang/srt/managers/schedule_policy.py
index 71cf68c8fb8..03634bd3913 100644
--- a/python/sglang/srt/managers/schedule_policy.py
+++ b/python/sglang/srt/managers/schedule_policy.py
@@ -20,9 +20,11 @@
 from enum import Enum, auto
 from typing import Dict, List, Optional
 
+import torch
+
 from sglang.srt.managers.schedule_batch import Req, ScheduleBatch
 from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
-from sglang.srt.mem_cache.radix_cache import TreeNode
+from sglang.srt.mem_cache.radix_cache import RadixCache, TreeNode
 
 # Clip the estimation of max_new_tokens for the request whose max_new_tokens is very large.
 # This can prevent the server from being too conservative.
@@ -32,6 +34,21 @@
     os.environ.get("SGLANG_CLIP_MAX_NEW_TOKENS_ESTIMATION", "4096")
 )
 
+# Threshold for in-batch prefix caching.
+# If a request's matched prefix length against the existing cache is less than
+# this value, the scheduler runs the in-batch prefix caching check for it.
+# Setting it to -1 disables in-batch prefix caching.
+IN_BATCH_PREFIX_CACHING_CHECK_THRESHOLD = int(
+    os.environ.get("IN_BATCH_PREFIX_CACHING_CHECK_THRESHOLD", "32")
+)
+
+# Threshold for in-batch prefix caching.
+# If a request's matched prefix length within the waiting queue is larger than
+# this value, the scheduler deprioritizes the request.
+IN_BATCH_PREFIX_CACHING_DEPRIORITIZE_THRESHOLD = int(
+    os.environ.get("IN_BATCH_PREFIX_CACHING_DEPRIORITIZE_THRESHOLD", "32")
+)
+
 
 class SchedulePolicy:
     def __init__(self, policy: str, tree_cache: BasePrefixCache):
@@ -42,6 +59,11 @@ def __init__(self, policy: str, tree_cache: BasePrefixCache):
         self.policy = policy
         self.tree_cache = tree_cache
 
+        # Used to find matching prefixes among waiting requests for in-batch prefix caching.
+        self.waiting_queue_radix_tree = RadixCache(
+            req_to_token_pool=None, token_to_kv_pool=None, disable=False
+        )
+
     def calc_priority(self, waiting_queue: List[Req]):
         if len(waiting_queue) > 128 and self.policy == "lpm":
             # Turn off the expensive prefix matching and sorting when the #queue is large.
@@ -52,17 +74,49 @@ def calc_priority(self, waiting_queue: List[Req]):
         # Compute matched prefix length
         prefix_computed = False
         if policy == "lpm" or policy == "dfs-weight":
+            # Maps rid -> request to deprioritize in the current run for in-batch prefix caching.
+            temporary_deprioritized = {}
+            self.waiting_queue_radix_tree.reset()
+
             for r in waiting_queue:
+                prefix_ids = r.adjust_max_prefix_ids()
+
+                # NOTE: prefix_indices must always be aligned with last_node.
                 r.prefix_indices, r.last_node = self.tree_cache.match_prefix(
-                    rid=r.rid, key=r.adjust_max_prefix_ids()
+                    rid=r.rid, key=prefix_ids
                 )
+
+                # NOTE(sang): This logic is for in-batch prefix caching.
+                # If multiple requests have only a small prefix match against the
+                # existing cache but share the same prefix with each other, we prefer
+                # to schedule just one of them so as to increase the cache hit rate.
+                if len(r.prefix_indices) <= IN_BATCH_PREFIX_CACHING_CHECK_THRESHOLD:
+                    in_batch_matching_prefixes, _ = (
+                        self.waiting_queue_radix_tree.match_prefix(
+                            rid=r.rid, key=prefix_ids
+                        )
+                    )
+                    if (
+                        len(in_batch_matching_prefixes)
+                        >= IN_BATCH_PREFIX_CACHING_DEPRIORITIZE_THRESHOLD
+                    ):
+                        temporary_deprioritized[r.rid] = r
+                    else:
+                        self.waiting_queue_radix_tree.insert(
+                            prefix_ids, torch.empty(len(prefix_ids), dtype=torch.bool)
+                        )
+
             prefix_computed = True
 
         if policy == "lpm":
             # Longest Prefix Match
-            waiting_queue.sort(key=lambda x: -len(x.prefix_indices))
+            waiting_queue.sort(
+                key=lambda r: (
+                    -len(r.prefix_indices)
+                    if r.rid not in temporary_deprioritized
+                    else float("inf")
+                )
+            )
         elif policy == "fcfs":
             # first come first serve
             pass
@@ -72,6 +126,7 @@ def calc_priority(self, waiting_queue: List[Req]):
         elif policy == "random":
             random.shuffle(waiting_queue)
         elif policy == "dfs-weight":
+            # Experimental policy based on custom weights
            last_node_to_reqs = defaultdict(list)
             for req in waiting_queue:
                 last_node_to_reqs[req.last_node].append(req)
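
For reviewers, here is a minimal standalone sketch of the deprioritization idea in this diff. `ToyRadixTree` and `schedule_order` are hypothetical stand-ins for sglang's `RadixCache` and `SchedulePolicy.calc_priority` (the real `match_prefix` returns KV indices and a tree node, not just a length); the thresholds mirror the 32-token env-var defaults above.

```python
CHECK_THRESHOLD = 32         # mirrors IN_BATCH_PREFIX_CACHING_CHECK_THRESHOLD
DEPRIORITIZE_THRESHOLD = 32  # mirrors IN_BATCH_PREFIX_CACHING_DEPRIORITIZE_THRESHOLD


class ToyRadixTree:
    """A trie over token ids; match_prefix returns the longest matched length."""

    def __init__(self):
        self.root = {}

    def insert(self, token_ids):
        node = self.root
        for t in token_ids:
            node = node.setdefault(t, {})

    def match_prefix(self, token_ids):
        node, matched = self.root, 0
        for t in token_ids:
            if t not in node:
                break
            node, matched = node[t], matched + 1
        return matched


def schedule_order(waiting_queue, cache_match_lens):
    """Longest-prefix-match order, but requests whose prefix is already
    covered by an earlier in-batch request are pushed to the back."""
    in_batch_tree = ToyRadixTree()
    deprioritized = set()
    for rid, token_ids in waiting_queue:
        # Only requests with a short match against the existing cache are checked.
        if cache_match_lens[rid] <= CHECK_THRESHOLD:
            if in_batch_tree.match_prefix(token_ids) >= DEPRIORITIZE_THRESHOLD:
                # Another queued request shares this prefix; let that one run
                # first and populate the cache, then this one gets a cache hit.
                deprioritized.add(rid)
            else:
                in_batch_tree.insert(token_ids)
    return sorted(
        waiting_queue,
        key=lambda r: float("inf")
        if r[0] in deprioritized
        else -cache_match_lens[r[0]],
    )


# "a" and "b" share a 40-token prefix with no existing cache hit: "a" is kept
# and "b" is deprioritized; "c" shares no prefix, so it keeps its rank.
queue = [("a", list(range(40))), ("b", list(range(40))), ("c", [7] * 8)]
print([rid for rid, _ in schedule_order(queue, {"a": 0, "b": 0, "c": 0})])
# -> ['a', 'c', 'b']
```

The `float("inf")` sort key is the same trick the diff uses: deprioritized requests sort after every real match length, so they are scheduled once the shared prefix has been prefilled and cached by the request that kept its priority.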