
Commit

Fix
merrymercy committed Dec 17, 2024
1 parent fac93ee commit 441ece9
Showing 1 changed file with 58 additions and 3 deletions.
61 changes: 58 additions & 3 deletions python/sglang/srt/managers/schedule_policy.py
@@ -20,9 +20,11 @@
 from enum import Enum, auto
 from typing import Dict, List, Optional
 
+import torch
+
 from sglang.srt.managers.schedule_batch import Req, ScheduleBatch
 from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
-from sglang.srt.mem_cache.radix_cache import TreeNode
+from sglang.srt.mem_cache.radix_cache import RadixCache, TreeNode
 
 # Clip the estimation of max_new_tokens for the request whose max_new_tokens is very large.
 # This can prevent the server from being too conservative.
@@ -32,6 +34,21 @@
os.environ.get("SGLANG_CLIP_MAX_NEW_TOKENS_ESTIMATION", "4096")
)

# Threshold for in-batch prefix cache.
# If a request has a matched prefix length (against existing cache) less than this value,
# the scheduler runs the in-batch prefix caching check for this request.
# If we set it to -1, it means we disable in-batch prefix caching.
IN_BATCH_PREFIX_CACHING_CHECK_THRESHOLD = int(
os.environ.get("IN_BATCH_PREFIX_CACHING_CHECK_THRESHOLD", "32")
)

# Threshold for in-batch prefix cache.
# If a request has a matched prefix length (within the waiting queue) larger than this value,
# the scheduler deprioritizes this request
IN_BATCH_PREFIX_CACHING_DEPRIORITIZE_THRESHOLD = int(
os.environ.get("IN_BATCH_PREFIX_CACHING_DEPRIORITIZE_THRESHOLD", "32")
)


class SchedulePolicy:
def __init__(self, policy: str, tree_cache: BasePrefixCache):
@@ -42,6 +59,11 @@ def __init__(self, policy: str, tree_cache: BasePrefixCache):
         self.policy = policy
         self.tree_cache = tree_cache
 
+        # It is used to find the matching prefix for in-batch prefix caching.
+        self.waiting_queue_radix_tree = RadixCache(
+            req_to_token_pool=None, token_to_kv_pool=None, disable=False
+        )
+
     def calc_priority(self, waiting_queue: List[Req]):
         if len(waiting_queue) > 128 and self.policy == "lpm":
             # Turn off the expensive prefix matching and sorting when the #queue is large.
@@ -52,17 +74,49 @@ def calc_priority(self, waiting_queue: List[Req]):
         # Compute matched prefix length
         prefix_computed = False
         if policy == "lpm" or policy == "dfs-weight":
+            # rids to deprioritize in the current run for in-batch prefix caching.
+            temporary_deprioritized = {}
+            self.waiting_queue_radix_tree.reset()
+
             for r in waiting_queue:
+                prefix_ids = r.adjust_max_prefix_ids()
+
                 # NOTE: the prefix_indices must always be aligned with last_node
                 r.prefix_indices, r.last_node = self.tree_cache.match_prefix(
-                    rid=r.rid, key=r.adjust_max_prefix_ids()
+                    rid=r.rid, key=prefix_ids
                 )
+
+                # NOTE(sang): This logic is for in-batch prefix caching.
+                # If more than one request has only a small matching prefix against the
+                # existing cache but those requests all share the same prefix, we prefer
+                # to schedule only one of them so that we can increase the cache hit rate.
+                if len(r.prefix_indices) <= IN_BATCH_PREFIX_CACHING_CHECK_THRESHOLD:
+                    in_batch_matching_prefixes, _ = (
+                        self.waiting_queue_radix_tree.match_prefix(
+                            rid=r.rid, key=prefix_ids
+                        )
+                    )
+                    if (
+                        len(in_batch_matching_prefixes)
+                        >= IN_BATCH_PREFIX_CACHING_DEPRIORITIZE_THRESHOLD
+                    ):
+                        temporary_deprioritized[r.rid] = r
+                    else:
+                        self.waiting_queue_radix_tree.insert(
+                            prefix_ids, torch.empty(len(prefix_ids), dtype=torch.bool)
+                        )
 
             prefix_computed = True
 
         if policy == "lpm":
             # Longest Prefix Match
-            waiting_queue.sort(key=lambda x: -len(x.prefix_indices))
+            waiting_queue.sort(
+                key=lambda r: (
+                    -len(r.prefix_indices)
+                    if r.rid not in temporary_deprioritized
+                    else float("inf")
+                )
+            )
         elif policy == "fcfs":
             # first come first serve
             pass
@@ -72,6 +126,7 @@ def calc_priority(self, waiting_queue: List[Req]):
elif policy == "random":
random.shuffle(waiting_queue)
elif policy == "dfs-weight":
# Experimental policy based on custom weights
last_node_to_reqs = defaultdict(list)
for req in waiting_queue:
last_node_to_reqs[req.last_node].append(req)
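For readers who want the gist without stepping through the diff: a request whose prefix barely matches the existing radix cache is additionally matched against a second radix tree built over the waiting queue itself; if another queued request already covers the same prefix, the newcomer is temporarily deprioritized in the LPM order so the first request can populate the cache before the duplicates are scheduled. The following is a minimal, self-contained sketch of that idea, not sglang's implementation: SimplePrefixIndex and lpm_sort_with_in_batch_dedup are hypothetical stand-ins for RadixCache and SchedulePolicy.calc_priority, the (rid, prefix_ids, cache_hit_len) tuples replace Req objects, and the two module-level constants mirror the environment variables introduced in the diff.

# Minimal sketch (not sglang's implementation) of in-batch prefix caching:
# deprioritize queued requests whose prefix is already covered by another
# request in the same waiting queue.

CHECK_THRESHOLD = 32         # mirrors IN_BATCH_PREFIX_CACHING_CHECK_THRESHOLD
DEPRIORITIZE_THRESHOLD = 32  # mirrors IN_BATCH_PREFIX_CACHING_DEPRIORITIZE_THRESHOLD


class SimplePrefixIndex:
    """Toy stand-in for the waiting-queue radix tree: stores token-id tuples
    and reports the longest prefix shared between a key and any stored entry."""

    def __init__(self):
        self._keys = []

    def insert(self, key):
        self._keys.append(tuple(key))

    def match_prefix_len(self, key):
        best = 0
        for stored in self._keys:
            n = 0
            for a, b in zip(stored, key):
                if a != b:
                    break
                n += 1
            best = max(best, n)
        return best


def lpm_sort_with_in_batch_dedup(waiting_queue):
    """waiting_queue holds (rid, prefix_ids, cache_hit_len) tuples, where
    cache_hit_len is the prefix length already matched against the existing cache."""
    index = SimplePrefixIndex()
    deprioritized = set()
    for rid, prefix_ids, cache_hit_len in waiting_queue:
        if cache_hit_len <= CHECK_THRESHOLD:
            # Small hit against the existing cache: check whether another queued
            # request already covers (enough of) the same prefix.
            if index.match_prefix_len(prefix_ids) >= DEPRIORITIZE_THRESHOLD:
                deprioritized.add(rid)  # let the first request warm the cache
            else:
                index.insert(prefix_ids)  # first request with this prefix keeps priority
    # Longest-prefix-match order; deprioritized requests go to the back.
    waiting_queue.sort(
        key=lambda item: float("inf") if item[0] in deprioritized else -item[2]
    )
    return deprioritized


# Example: "a" and "b" share a 60-token prefix and neither hits the existing cache,
# so "b" is deprioritized and ends up after the unrelated request "c".
queue = [
    ("a", [1, 2, 3] * 20, 0),
    ("b", [1, 2, 3] * 20, 0),
    ("c", [9, 8, 7], 0),
]
print(lpm_sort_with_in_batch_dedup(queue))  # {'b'}
print([rid for rid, _, _ in queue])         # ['a', 'c', 'b']

Under the same assumptions, raising DEPRIORITIZE_THRESHOLD makes deprioritization rarer, and setting CHECK_THRESHOLD to -1 skips the in-batch check entirely, matching the behavior described by the comments on the two environment variables in the diff.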
