[Core] in batch prefix caching by delay scheduling #2442

Merged · 4 commits · Dec 11, 2024
1 change: 1 addition & 0 deletions python/sglang/lang/backend/runtime_endpoint.py
@@ -55,6 +55,7 @@ def flush_cache(self):
            self.base_url + "/flush_cache",
            api_key=self.api_key,
            verify=self.verify,
            method="POST",
        )
        self._assert_success(res)

2 changes: 2 additions & 0 deletions python/sglang/srt/managers/schedule_batch.py
@@ -256,6 +256,7 @@ def __init__(

        # Prefix info
        self.prefix_indices = []
        # Tokens to run prefill. input_tokens - shared_prefix_tokens.
        self.extend_input_len = 0
        self.last_node = None

@@ -316,6 +317,7 @@ def finished(self) -> bool:
    def init_next_round_input(self, tree_cache: Optional[BasePrefixCache] = None):
        self.fill_ids = self.origin_input_ids + self.output_ids
        if tree_cache is not None:
            # `tree_cache` is None when the prefix should not be computed with the tree cache.
            self.prefix_indices, self.last_node = tree_cache.match_prefix(
                rid=self.rid, key=self.adjust_max_prefix_ids()
            )
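A quick illustration of the extend_input_len bookkeeping added above; the token counts are made up for illustration and are not taken from the PR:

# extend_input_len = input tokens - tokens already covered by the shared prefix.
origin_input_ids = list(range(100))   # 100 prompt tokens (hypothetical)
prefix_indices = list(range(60))      # 60 tokens already cached as a shared prefix
extend_input_len = len(origin_input_ids) - len(prefix_indices)
assert extend_input_len == 40         # only 40 tokens need the prefill forward pass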
65 changes: 58 additions & 7 deletions python/sglang/srt/managers/schedule_policy.py
@@ -20,9 +20,11 @@
from enum import Enum, auto
from typing import Dict, List, Optional

import torch

from sglang.srt.managers.schedule_batch import Req, ScheduleBatch
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
from sglang.srt.mem_cache.radix_cache import TreeNode
from sglang.srt.mem_cache.radix_cache import RadixCache, TreeNode

# Clip the estimation of max_new_tokens for the request whose max_new_tokens is very large.
# This can prevent the server from being too conservative.
@@ -32,6 +34,13 @@
    os.environ.get("SGLANG_CLIP_MAX_NEW_TOKENS_ESTIMATION", "4096")
)

# The threshold for applying in-batch prefix caching.
# If the value is too small, in-batch prefix caching cannot be used for
# short but common prefixes (e.g., the prefix "the").
IN_BATCH_PREFIX_CACHING_THRESHOLD = int(
    os.environ.get("SGLANG_IN_BATCH_PREFIX_CACHING_THRESHOLD", "32")
)


class SchedulePolicy:
    def __init__(self, policy: str, tree_cache: BasePrefixCache):
@@ -51,18 +60,50 @@ def calc_priority(self, waiting_queue: List[Req]):

        # Compute matched prefix length
        prefix_computed = False
        # Map of rid -> request to deprioritize in the current run.
        temporary_deprioritized = {}
        if policy == "lpm" or policy == "dfs-weight":
            # A temporary radix tree used to find matching prefixes for in-batch prefix caching.
            temp_radix = RadixCache(None, None, False)
            for r in waiting_queue:
                prefix_ids = r.adjust_max_prefix_ids()
                # NOTE: the prefix_indices must always be aligned with last_node
                r.prefix_indices, r.last_node = self.tree_cache.match_prefix(
                    rid=r.rid, key=r.adjust_max_prefix_ids()
                    rid=r.rid, key=prefix_ids
                )

                # NOTE(sang): This logic implements in-batch prefix caching.
                # If more than one request has only a short matching prefix in the
                # existing cache, but those requests share the same prefix with each
                # other, we prefer to schedule only one of them so that the others
                # can hit the cache later, which increases the cache hit rate.
                # We prefer IN_BATCH_PREFIX_CACHING_THRESHOLD > 0 because a too-small
                # threshold means in-batch prefix caching cannot be used for short
                # prefixes, which are common when the engine is long running
                # (e.g., imagine the prefix "the").

Collaborator: Why `IN_BATCH_PREFIX_CACHING_THRESHOLD > 0`? Should this be > 32? And what does `imagine "the"` mean?

Contributor (author): If the threshold is 0, this optimization is not applied to prefixes like "the", which are common. Regarding the comment, I just meant that == 0 is not ideal because it misses cases like the "the" prefix.

Contributor (author): 32 is also an arbitrary value, actually; I didn't do much tuning here.

Collaborator: Makes sense.
                if len(r.prefix_indices) <= IN_BATCH_PREFIX_CACHING_THRESHOLD:
                    in_batch_matching_prefixes, _ = temp_radix.match_prefix(
                        rid=r.rid, key=prefix_ids
                    )
                    if (
                        len(in_batch_matching_prefixes)
                        >= IN_BATCH_PREFIX_CACHING_THRESHOLD
                    ):
                        temporary_deprioritized[r.rid] = r
                    else:
                        temp_radix.insert(prefix_ids, torch.tensor(prefix_ids))

            prefix_computed = True

        if policy == "lpm":
            # Longest Prefix Match
            waiting_queue.sort(key=lambda x: -len(x.prefix_indices))
            def get_priority(r: Req):
                score = 0
                if r.rid in temporary_deprioritized:
                    score = float("inf")
                else:
                    score = -len(r.prefix_indices)
                return score

            waiting_queue.sort(key=get_priority)
        elif policy == "fcfs":
            # first come first serve
            pass
@@ -76,6 +117,7 @@ def calc_priority(self, waiting_queue: List[Req]):
            for req in waiting_queue:
                last_node_to_reqs[req.last_node].append(req)

            # node -> # of requests for that node.
            node_to_weight = defaultdict(int)
            for node in last_node_to_reqs:
                node_to_weight[node] = len(last_node_to_reqs[node])
@@ -87,7 +129,9 @@
                node_to_weight,
                last_node_to_reqs,
                waiting_queue,
                temporary_deprioritized,
            )
            waiting_queue.extend(temporary_deprioritized.values())
        else:
            raise ValueError(f"Unknown schedule_policy: {policy=}")

@@ -101,15 +145,22 @@ def calc_weight(self, cur_node: TreeNode, node_to_weight: Dict):
    def get_dfs_priority(
        self,
        cur_node: TreeNode,
        node_to_priority: Dict,
        last_node_to_reqs: Dict,
        node_to_priority: Dict[TreeNode, int],
        last_node_to_reqs: Dict[TreeNode, List[Req]],
        q: List,
        temporary_deprioritized: Dict[str, Req],
    ):
        childs = [child for child in cur_node.children.values()]
        childs.sort(key=lambda x: -node_to_priority[x])
        for child in childs:
            self.get_dfs_priority(child, node_to_priority, last_node_to_reqs, q)
            self.get_dfs_priority(
                child, node_to_priority, last_node_to_reqs, q, temporary_deprioritized
            )
        q.extend(last_node_to_reqs[cur_node])

        for req in last_node_to_reqs[cur_node]:
            if req.rid in temporary_deprioritized:
                continue
            q.append(req)


class AddReqResult(Enum):
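The core of this file's change is easier to see in isolation. Below is a self-contained sketch of the in-batch prefix caching idea under simplifying assumptions: the FakeReq record and the list-based longest_shared_prefix() matcher are stand-ins for sglang's Req and the temporary RadixCache, and only the "lpm" ordering is shown. Requests whose prompt is barely covered by the persistent cache but shares a long prefix with an earlier request in the same batch are pushed to the back of the queue, so they can reuse that request's KV cache on a later scheduling round.

import os
from dataclasses import dataclass, field
from typing import Dict, List

# Same env var and default as the diff above.
IN_BATCH_PREFIX_CACHING_THRESHOLD = int(
    os.environ.get("SGLANG_IN_BATCH_PREFIX_CACHING_THRESHOLD", "32")
)


@dataclass
class FakeReq:
    """Stand-in for sglang's Req with only the fields this sketch needs."""

    rid: str
    input_ids: List[int]
    prefix_indices: List[int] = field(default_factory=list)  # match in persistent cache


def longest_shared_prefix(key: List[int], seen: List[List[int]]) -> int:
    """Length of the longest common prefix between `key` and any earlier key.

    Plays the role of temp_radix.match_prefix() in the PR, with plain lists
    instead of a radix tree.
    """
    best = 0
    for other in seen:
        n = 0
        for a, b in zip(key, other):
            if a != b:
                break
            n += 1
        best = max(best, n)
    return best


def sort_by_in_batch_prefix(waiting_queue: List[FakeReq]) -> List[FakeReq]:
    """Longest-prefix-match ordering with in-batch prefix deprioritization."""
    temporary_deprioritized: Dict[str, FakeReq] = {}
    seen_prefixes: List[List[int]] = []

    for r in waiting_queue:
        # Only requests that the persistent cache barely helps are considered.
        if len(r.prefix_indices) <= IN_BATCH_PREFIX_CACHING_THRESHOLD:
            shared = longest_shared_prefix(r.input_ids, seen_prefixes)
            if shared >= IN_BATCH_PREFIX_CACHING_THRESHOLD:
                # Another request in this batch already carries this prefix;
                # delay this one so it can hit the cache on a later round.
                temporary_deprioritized[r.rid] = r
            else:
                seen_prefixes.append(r.input_ids)

    def get_priority(r: FakeReq):
        # Delayed requests sort to the very end; the rest follow
        # longest-prefix-match order, as in the PR's get_priority().
        if r.rid in temporary_deprioritized:
            return float("inf")
        return -len(r.prefix_indices)

    return sorted(waiting_queue, key=get_priority)


if __name__ == "__main__":
    shared = list(range(40))  # a 40-token prefix shared by two requests
    queue = [
        FakeReq("a", shared + [100, 101]),
        FakeReq("b", shared + [200, 201]),       # delayed behind "a"
        FakeReq("c", list(range(1000, 1010))),   # unrelated prompt
    ]
    print([r.rid for r in sort_by_in_batch_prefix(queue)])  # ['a', 'c', 'b']

Delayed requests are not dropped: in the PR, the "lpm" policy sorts them to the end of the queue, and the "dfs-weight" policy appends them after the DFS ordering via waiting_queue.extend(temporary_deprioritized.values()).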
2 changes: 1 addition & 1 deletion python/sglang/srt/managers/scheduler.py
@@ -713,7 +713,7 @@ def check_memory(self):
            if crash_on_warnings():
                raise ValueError(msg)

    def get_next_batch_to_run(self):
    def get_next_batch_to_run(self) -> Optional[ScheduleBatch]:
        # Merge the prefill batch into the running batch
        if self.last_batch and self.last_batch.forward_mode.is_extend():
            if self.being_chunked_req:
4 changes: 2 additions & 2 deletions python/sglang/srt/mem_cache/base_prefix_cache.py
@@ -1,5 +1,5 @@
from abc import ABC, abstractmethod
from typing import Callable
from typing import Callable, List, Tuple


class BasePrefixCache(ABC):
@@ -10,7 +10,7 @@ def reset(self):
        pass

    @abstractmethod
    def match_prefix(self, **kwargs):
    def match_prefix(self, **kwargs) -> Tuple[List[int], int]:
        pass

    @abstractmethod
4 changes: 2 additions & 2 deletions python/sglang/srt/mem_cache/chunk_cache.py
@@ -2,7 +2,7 @@

"""Cache for chunked prefill, used when RadixCache is disabled."""

from typing import TYPE_CHECKING, Callable, List, Optional
from typing import TYPE_CHECKING, Callable, List, Optional, Tuple

from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool
@@ -30,7 +30,7 @@ def __init__(
    def reset(self):
        self.entries = {}

    def match_prefix(self, rid: int, key: List[int]):
    def match_prefix(self, rid: int, key: List[int]) -> Tuple[List[int], int]:
        if rid not in self.entries:
            return [], None

14 changes: 12 additions & 2 deletions python/sglang/srt/mem_cache/radix_cache.py
@@ -22,7 +22,7 @@
import heapq
import time
from collections import defaultdict
from typing import TYPE_CHECKING, Callable, List, Optional
from typing import TYPE_CHECKING, Callable, List, Optional, Tuple

import torch

@@ -76,7 +76,17 @@ def reset(self):
        self.root_node.lock_ref = 1
        self.evictable_size_ = 0

    def match_prefix(self, key: List, **kwargs):
    def match_prefix(self, key: List[int], **kwargs) -> Tuple[torch.Tensor, int]:
        """Find the matching prefix from the radix tree.

        Args:
            key: A list of token IDs to find a matching prefix.
        Returns:
            A tuple of a tensor of matching prefix token IDs and
            the last node that contains the prefix values. Note that
            this API can modify the internal state of the radix tree.
            The last node creates a new child if the prefix is shorter
            than the last node's value.
        """
        if self.disable:
            return [], self.root_node

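A minimal usage sketch of the match_prefix() API documented above. It mirrors how the scheduler builds its temporary tree with RadixCache(None, None, False) (no KV pools, not disabled); the token IDs and request id are made up, and per the docstring the first return value is a tensor of the matched prefix values:

import torch

from sglang.srt.mem_cache.radix_cache import RadixCache

# Temporary, in-memory radix tree, constructed the same way as in schedule_policy.py.
temp_radix = RadixCache(None, None, False)

prompt_a = [1, 2, 3, 4, 5, 6]
temp_radix.insert(prompt_a, torch.tensor(prompt_a))

# A second prompt sharing the first four tokens with prompt_a.
prompt_b = [1, 2, 3, 4, 99, 100]
matched, last_node = temp_radix.match_prefix(rid="req-b", key=prompt_b)

print(len(matched))  # 4: the shared prefix [1, 2, 3, 4] was found in the tree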
11 changes: 9 additions & 2 deletions python/sglang/utils.py
@@ -79,7 +79,14 @@ def status_code(self):
        return self.resp.status


def http_request(url, json=None, stream=False, api_key=None, verify=None):
def http_request(
    url,
    json=None,
    stream=False,
    api_key=None,
    verify=None,
    method: Optional[str] = None,
):
    """A faster version of requests.post with low-level urllib API."""
    headers = {"Content-Type": "application/json; charset=utf-8"}

@@ -90,7 +97,7 @@ def http_request(url, json=None, stream=False, api_key=None, verify=None):
    if stream:
        return requests.post(url, json=json, stream=True, headers=headers)
    else:
        req = urllib.request.Request(url, headers=headers)
        req = urllib.request.Request(url, headers=headers, method=method)
        if json is None:
            data = None
        else:
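Finally, a short sketch of how the new method argument is used by the runtime_endpoint.py change at the top of this PR; the URL is a placeholder for a locally running sglang server:

from sglang.utils import http_request

base_url = "http://127.0.0.1:30000"  # placeholder address of a running sglang server

# Issue the flush_cache call as an explicit POST; with no body, urllib's
# Request would otherwise default to GET.
res = http_request(base_url + "/flush_cache", method="POST")
print(res.status_code)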