[Misc] Fix metrics, weight update lock, request logging (#2543)
merrymercy committed Dec 22, 2024
1 parent 7d672d2 commit 8496701
Showing 11 changed files with 412 additions and 315 deletions.
280 changes: 110 additions & 170 deletions docs/references/production_metrics.md

Large diffs are not rendered by default.

94 changes: 94 additions & 0 deletions python/sglang/srt/aio_rwlock.py
@@ -0,0 +1,94 @@
import asyncio


class RWLock:
"""
A Read-Write Lock for asyncio:
- Multiple readers can hold the lock in parallel if no writer holds it.
- A writer has exclusive access.
"""

def __init__(self):
self._readers = 0 # How many readers currently hold the lock
self._writer_active = False
self._lock = asyncio.Lock() # Internal mutex to protect state
# Conditions associated with _lock:
self._readers_ok = asyncio.Condition(self._lock) # Notify blocked readers
self._writers_ok = asyncio.Condition(self._lock) # Notify blocked writers

# Expose two async context-manager helpers:
self.reader_lock = self._ReaderLock(self)
self.writer_lock = self._WriterLock(self)

async def _acquire_reader(self):
"""
Wait until there is no active writer.
Then increment the count of active readers.
"""
async with self._lock:
# If a writer is active, wait until it's done.
while self._writer_active:
await self._readers_ok.wait()
self._readers += 1

async def _release_reader(self):
"""
Decrement the count of active readers.
If this was the last active reader, wake up a possible waiting writer.
"""
async with self._lock:
self._readers -= 1
# If no more readers, a writer could proceed.
if self._readers == 0:
self._writers_ok.notify()

async def _acquire_writer(self):
"""
Wait until there is no active writer and no active readers.
Then mark a writer as active.
"""
async with self._lock:
while self._writer_active or self._readers > 0:
await self._writers_ok.wait()
self._writer_active = True

async def _release_writer(self):
"""
Mark the writer as done and notify readers and writers.
"""
async with self._lock:
self._writer_active = False
# Allow any waiting readers to proceed:
self._readers_ok.notify_all()
# Allow next waiting writer to proceed:
self._writers_ok.notify()

class _ReaderLock:
"""
A simple async context manager that acquires a reader lock
on entering and releases it on exit.
"""

def __init__(self, parent: "RWLock"):
self._parent = parent

async def __aenter__(self):
await self._parent._acquire_reader()

async def __aexit__(self, exc_type, exc_val, exc_tb):
await self._parent._release_reader()

class _WriterLock:
"""
A simple async context manager that acquires a writer lock
on entering and releases it on exit.
"""

def __init__(self, parent: "RWLock"):
self._parent = parent

async def __aenter__(self):
await self._parent._acquire_writer()

async def __aexit__(self, exc_type, exc_val, exc_tb):
await self._parent._release_writer()
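
For illustration only (not part of the commit): a minimal sketch of how this lock can serialize a weight update against in-flight requests. The coroutine names and sleep timings below are hypothetical.

import asyncio

from sglang.srt.aio_rwlock import RWLock

lock = RWLock()

async def serve_request(i: int):
    # Many readers may hold the lock concurrently.
    async with lock.reader_lock:
        await asyncio.sleep(0.01)
        print(f"request {i} served")

async def update_weights():
    # The writer waits until all readers release, then runs exclusively.
    async with lock.writer_lock:
        await asyncio.sleep(0.05)
        print("weights updated")

async def main():
    await asyncio.gather(*(serve_request(i) for i in range(4)), update_weights())

asyncio.run(main())

Note that readers are admitted whenever no writer is active, so this is a reader-preference lock: a steady stream of incoming requests can delay a pending weight update.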
4 changes: 4 additions & 0 deletions python/sglang/srt/configs/model_config.py
@@ -124,8 +124,12 @@ def __init__(
self.num_hidden_layers = self.hf_text_config.num_hidden_layers
self.vocab_size = self.hf_text_config.vocab_size

        # Verify quantization
self._verify_quantization()

        # Multimodal attrs
self.image_token_id = getattr(self.hf_config, "image_token_id", None)

# adapted from https://github.com/vllm-project/vllm/blob/main/vllm/config.py#L289
def get_total_num_kv_heads(self) -> int:
"""Returns the total number of KV heads."""
54 changes: 49 additions & 5 deletions python/sglang/srt/layers/attention/flashinfer_backend.py
@@ -18,11 +18,7 @@
from sglang.global_config import global_config
from sglang.srt.layers.attention import AttentionBackend
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.utils import (
get_bool_env_var,
is_flashinfer_available,
should_use_tensor_core,
)
from sglang.srt.utils import is_flashinfer_available

if TYPE_CHECKING:
from sglang.srt.layers.radix_attention import RadixAttention
@@ -731,3 +727,51 @@ def create_flashinfer_kv_indices_triton(
mask=mask,
)
tl.store(kv_indices_ptr + kv_indices_offset + offset, data, mask=mask)


def should_use_tensor_core(
kv_cache_dtype: torch.dtype,
num_attention_heads: int,
num_kv_heads: int,
) -> bool:
"""
Determine whether to use tensor cores for attention computation.
Args:
kv_cache_dtype: Data type of the KV cache
num_attention_heads: Number of attention heads
num_kv_heads: Number of key/value heads
Returns:
bool: Whether to use tensor cores
"""
# Try to use environment variable first
env_override = os.environ.get("SGLANG_FLASHINFER_USE_TENSOR_CORE")
if env_override is not None:
return env_override.lower() == "true"

    # Try to use _grouped_size_compiled_for_decode_kernels if available.
    # This check applies to flashinfer <= 0.1.6; without it, there is an
    # accuracy bug.
try:
from flashinfer.decode import _grouped_size_compiled_for_decode_kernels

if not _grouped_size_compiled_for_decode_kernels(
num_attention_heads,
num_kv_heads,
):
return True
else:
return False
except (ImportError, AttributeError):
pass

# Calculate GQA group size
gqa_group_size = num_attention_heads // num_kv_heads

# Determine based on dtype and GQA group size
if kv_cache_dtype in (torch.float8_e4m3fn, torch.float8_e5m2):
return True
elif kv_cache_dtype in (torch.float16, torch.half, torch.bfloat16):
return gqa_group_size > 4
else:
return False
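
For reference, a hypothetical sketch (not part of the commit) of what the decision rule yields, assuming no SGLANG_FLASHINFER_USE_TENSOR_CORE override and a flashinfer build newer than 0.1.6, so the helper import above fails and control falls through to the dtype/group-size rule:

import torch

from sglang.srt.layers.attention.flashinfer_backend import should_use_tensor_core

# GQA group size = 32 // 8 = 4, which is not > 4, so fp16 stays off tensor cores.
print(should_use_tensor_core(torch.float16, 32, 8))  # False
# fp8 KV caches always use tensor cores under this rule.
print(should_use_tensor_core(torch.float8_e4m3fn, 32, 8))  # True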
26 changes: 16 additions & 10 deletions python/sglang/srt/managers/schedule_batch.py
@@ -479,8 +479,22 @@ def jump_forward_and_retokenize(self, jump_forward_str, next_state):

return True

def reset_for_retract(self):
self.prefix_indices = []
self.last_node = None
self.extend_input_len = 0
self.is_retracted = True

# For incremental logprobs
# TODO: Fix the `logprob_start_len`
self.last_update_decode_tokens = 0
self.logprob_start_len = 10**9

def __repr__(self):
return f"rid(n={self.rid}, " f"input_ids={self.origin_input_ids}, "
return (
f"rid(n={self.rid}, "
f"input_ids={self.origin_input_ids}, output_ids={self.output_ids}"
)


bid = 0
@@ -894,15 +908,7 @@ def retract_decode(self):
)
residual_size = max(0, residual_size)
self.tree_cache.evict(residual_size, self.token_to_kv_pool.free)

req.prefix_indices = []
req.last_node = None
req.extend_input_len = 0
req.is_retracted = True

# For incremental logprobs
req.last_update_decode_tokens = 0
req.logprob_start_len = 10**9
req.reset_for_retract()

self.filter_batch(keep_indices=sorted_indices)

Expand Down
4 changes: 2 additions & 2 deletions python/sglang/srt/managers/scheduler.py
@@ -22,7 +22,7 @@
from collections import deque
from concurrent import futures
from types import SimpleNamespace
from typing import List, Optional
from typing import Callable, Dict, List, Optional, Tuple

import psutil
import setproctitle
@@ -260,7 +260,7 @@ def __init__(
self.current_stream = torch.get_device_module(self.device).current_stream()

# Session info
self.sessions = {}
self.sessions: Dict[str, Session] = {}

# Init chunked prefill
self.chunked_prefill_size = server_args.chunked_prefill_size
