Skip to content

Commit

Permalink
Fix the overlap for xgrammar (#2377)
Browse files Browse the repository at this point in the history
  • Loading branch information
merrymercy authored Dec 6, 2024
1 parent 3cde5eb commit 0e7409a
Show file tree
Hide file tree
Showing 7 changed files with 145 additions and 133 deletions.
2 changes: 1 addition & 1 deletion docs/references/supported_models.md
Original file line number Diff line number Diff line change
Expand Up @@ -106,4 +106,4 @@ def import_new_model_classes():
ModelRegistry.models.update(import_new_model_classes())

launch_server(server_args)
```
```
5 changes: 5 additions & 0 deletions python/sglang/srt/constrained/outlines_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ def __init__(
self.guide = guide
self.jump_forward_map = jump_forward_map
self.state = 0
self.finished = False

def accept_token(self, token: int):
self.state = self.guide.get_next_state(self.state, token)
Expand Down Expand Up @@ -84,6 +85,10 @@ def allocate_vocab_mask(
) -> torch.Tensor:
return torch.zeros(batch_size, vocab_size, dtype=torch.bool, device=device)

@staticmethod
def move_vocab_mask(vocab_mask: torch.Tensor, device) -> torch.Tensor:
return vocab_mask

def fill_vocab_mask(self, vocab_mask: torch.Tensor, idx: int) -> None:
tokens = torch.tensor(
self.guide.get_next_instruction(self.state).tokens, dtype=torch.int64
Expand Down
10 changes: 5 additions & 5 deletions python/sglang/srt/constrained/xgrammar_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ def __init__(
self.matcher = matcher
self.vocab_size = vocab_size
self.ctx = ctx
self.finished = False

def accept_token(self, token: int):
assert self.matcher.accept_token(token)
Expand Down Expand Up @@ -85,12 +86,11 @@ def fill_vocab_mask(self, vocab_mask: torch.Tensor, idx: int) -> None:
self.matcher.fill_next_token_bitmask(vocab_mask, idx)

@staticmethod
def apply_vocab_mask(logits: torch.Tensor, vocab_mask: torch.Tensor) -> None:
if vocab_mask.device.type != logits.device.type:
# vocab_mask must then be on the same device as logits
# when applying the token bitmask, so we check and move if needed
vocab_mask = vocab_mask.to(logits.device)
def move_vocab_mask(vocab_mask: torch.Tensor, device) -> torch.Tensor:
return vocab_mask.to(device, non_blocking=True)

@staticmethod
def apply_vocab_mask(logits: torch.Tensor, vocab_mask: torch.Tensor) -> None:
apply_token_bitmask_inplace(logits, vocab_mask)

def copy(self):
Expand Down
134 changes: 69 additions & 65 deletions python/sglang/srt/managers/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,9 +114,6 @@ def __init__(
self.skip_tokenizer_init = server_args.skip_tokenizer_init
self.enable_metrics = server_args.enable_metrics

# Session info
self.sessions = {}

# Init inter-process communication
context = zmq.Context(2)

Expand Down Expand Up @@ -259,6 +256,10 @@ def __init__(
self.num_generated_tokens = 0
self.last_decode_stats_tic = time.time()
self.stream_interval = server_args.stream_interval
self.current_stream = torch.get_device_module(self.device).current_stream()

# Session info
self.sessions = {}

# Init chunked prefill
self.chunked_prefill_size = server_args.chunked_prefill_size
Expand Down Expand Up @@ -356,6 +357,7 @@ def __init__(
)

def watchdog_thread(self):
"""A watch dog thread that will try to kill the server itself if one batch takes too long."""
self.watchdog_last_forward_ct = 0
self.watchdog_last_time = time.time()

Expand Down Expand Up @@ -433,61 +435,6 @@ def event_loop_overlap(self):

self.last_batch = batch

def prepare_dp_attn_batch(self, local_batch: ScheduleBatch):
# Check if other DP workers have running batches
if local_batch is None:
num_tokens = 0
elif local_batch.forward_mode.is_decode():
num_tokens = local_batch.batch_size()
else:
num_tokens = local_batch.extend_num_tokens

local_num_tokens = torch.tensor([num_tokens], dtype=torch.int64)
global_num_tokens = torch.empty(self.tp_size, dtype=torch.int64)
torch.distributed.all_gather_into_tensor(
global_num_tokens,
local_num_tokens,
group=self.tp_cpu_group,
)

if local_batch is None and global_num_tokens.max().item() > 0:
local_batch = self.get_idle_batch()

if local_batch is not None:
local_batch.global_num_tokens = global_num_tokens.tolist()

# Check forward mode for cuda graph
if not self.server_args.disable_cuda_graph:
forward_mode_state = torch.tensor(
(
1
if local_batch.forward_mode.is_decode()
or local_batch.forward_mode.is_idle()
else 0
),
dtype=torch.int32,
)
torch.distributed.all_reduce(
forward_mode_state,
op=torch.distributed.ReduceOp.MIN,
group=self.tp_cpu_group,
)
local_batch.can_run_dp_cuda_graph = forward_mode_state.item() == 1

return local_batch

def get_idle_batch(self):
idle_batch = ScheduleBatch.init_new(
[],
self.req_to_token_pool,
self.token_to_kv_pool,
self.tree_cache,
self.model_config,
self.enable_overlap,
)
idle_batch.prepare_for_idle()
return idle_batch

def recv_requests(self):
if self.tp_rank == 0 or self.server_args.enable_dp_attention:
recv_reqs = []
Expand Down Expand Up @@ -993,7 +940,7 @@ def process_batch_result(self, batch: ScheduleBatch, result):
self.process_batch_result_prefill(batch, result)
elif batch.forward_mode.is_dummy_first():
batch.next_batch_sampling_info.update_regex_vocab_mask()
torch.get_device_module(self.device).current_stream().synchronize()
self.current_stream.synchronize()
batch.next_batch_sampling_info.sampling_info_done.set()

def process_batch_result_prefill(self, batch: ScheduleBatch, result):
Expand Down Expand Up @@ -1049,13 +996,14 @@ def process_batch_result_prefill(self, batch: ScheduleBatch, result):

if req.grammar is not None:
req.grammar.accept_token(next_token_id)
req.grammar.finished = req.finished()
else:
# being chunked reqs' prefill is not finished
req.is_being_chunked -= 1

if batch.next_batch_sampling_info:
batch.next_batch_sampling_info.update_regex_vocab_mask()
torch.get_device_module(self.device).current_stream().synchronize()
self.current_stream.synchronize()
batch.next_batch_sampling_info.sampling_info_done.set()

else: # embedding or reward model
Expand Down Expand Up @@ -1127,10 +1075,11 @@ def process_batch_result_decode(self, batch: ScheduleBatch, result):

if req.grammar is not None:
req.grammar.accept_token(next_token_id)
req.grammar.finished = req.finished()

if batch.next_batch_sampling_info:
batch.next_batch_sampling_info.update_regex_vocab_mask()
torch.get_device_module(self.device).current_stream().synchronize()
self.current_stream.synchronize()
batch.next_batch_sampling_info.sampling_info_done.set()

self.stream_output(batch.reqs)
Expand Down Expand Up @@ -1328,6 +1277,61 @@ def stream_output(self, reqs: List[Req]):
)
)

def prepare_dp_attn_batch(self, local_batch: ScheduleBatch):
# Check if other DP workers have running batches
if local_batch is None:
num_tokens = 0
elif local_batch.forward_mode.is_decode():
num_tokens = local_batch.batch_size()
else:
num_tokens = local_batch.extend_num_tokens

local_num_tokens = torch.tensor([num_tokens], dtype=torch.int64)
global_num_tokens = torch.empty(self.tp_size, dtype=torch.int64)
torch.distributed.all_gather_into_tensor(
global_num_tokens,
local_num_tokens,
group=self.tp_cpu_group,
)

if local_batch is None and global_num_tokens.max().item() > 0:
local_batch = self.get_idle_batch()

if local_batch is not None:
local_batch.global_num_tokens = global_num_tokens.tolist()

# Check forward mode for cuda graph
if not self.server_args.disable_cuda_graph:
forward_mode_state = torch.tensor(
(
1
if local_batch.forward_mode.is_decode()
or local_batch.forward_mode.is_idle()
else 0
),
dtype=torch.int32,
)
torch.distributed.all_reduce(
forward_mode_state,
op=torch.distributed.ReduceOp.MIN,
group=self.tp_cpu_group,
)
local_batch.can_run_dp_cuda_graph = forward_mode_state.item() == 1

return local_batch

def get_idle_batch(self):
idle_batch = ScheduleBatch.init_new(
[],
self.req_to_token_pool,
self.token_to_kv_pool,
self.tree_cache,
self.model_config,
self.enable_overlap,
)
idle_batch.prepare_for_idle()
return idle_batch

def move_ready_grammar_requests(self):
"""Move requests whose grammar objects are ready from grammar_queue to waiting_queue."""
num_ready_reqs = 0
Expand Down Expand Up @@ -1469,10 +1473,6 @@ def run_scheduler_process(
dp_rank: Optional[int],
pipe_writer,
):
# set cpu affinity to this gpu process
if get_bool_env_var("SGLANG_SET_CPU_AFFINITY"):
set_gpu_proc_affinity(server_args.tp_size, server_args.nnodes, gpu_id)

# [For Router] if env var "SGLANG_DP_RANK" exist, set dp_rank to the value of the env var
if dp_rank is None and "SGLANG_DP_RANK" in os.environ:
dp_rank = int(os.environ["SGLANG_DP_RANK"])
Expand All @@ -1482,6 +1482,10 @@ def run_scheduler_process(
else:
configure_logger(server_args, prefix=f" DP{dp_rank} TP{tp_rank}")

# set cpu affinity to this gpu process
if get_bool_env_var("SGLANG_SET_CPU_AFFINITY"):
set_gpu_proc_affinity(server_args.tp_size, server_args.nnodes, gpu_id)

suppress_other_loggers()
parent_process = psutil.Process().parent()

Expand Down
3 changes: 2 additions & 1 deletion python/sglang/srt/managers/tp_worker_overlap_thread.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@ def __init__(
)
self.forward_thread.start()
self.parent_process = psutil.Process().parent()
self.scheduler_stream = torch.get_device_module(self.device).current_stream()

def get_worker_info(self):
return self.worker.get_worker_info()
Expand Down Expand Up @@ -191,7 +192,7 @@ def forward_batch_generation(self, model_worker_batch: ModelWorkerBatch):
)

# A cuda stream sync here to avoid the cuda illegal memory access error.
torch.get_device_module(self.device).current_stream().synchronize()
self.scheduler_stream.synchronize()

# Push a new batch to the queue
self.input_queue.put((model_worker_batch, self.future_token_ids_ct))
Expand Down
17 changes: 9 additions & 8 deletions python/sglang/srt/sampling/sampling_batch_info.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,22 +158,23 @@ def update_regex_vocab_mask(self):
return

# find a grammar from the list
grammar = next(grammar for grammar in self.grammars if grammar)
first_grammar = next(grammar for grammar in self.grammars if grammar)

# maybe we can reuse the existing mask?
self.vocab_mask = grammar.allocate_vocab_mask(
self.vocab_mask = first_grammar.allocate_vocab_mask(
vocab_size=self.vocab_size,
batch_size=len(self.temperatures),
device=self.device,
)
self.apply_mask = type(grammar).apply_vocab_mask # force to use static method
self.apply_mask = first_grammar.apply_vocab_mask # force to use static method

# Apply the mask
for i, grammar in enumerate(self.grammars):
if grammar is not None:
try:
grammar.fill_vocab_mask(self.vocab_mask, i)
except RuntimeError:
continue
if grammar and not grammar.finished:
grammar.fill_vocab_mask(self.vocab_mask, i)

# Move the mask to the device if needed
self.vocab_mask = first_grammar.move_vocab_mask(self.vocab_mask, self.device)

def filter_batch(self, unfinished_indices: List[int], new_indices: torch.Tensor):
self.penalizer_orchestrator.filter(unfinished_indices, new_indices)
Expand Down
Loading

0 comments on commit 0e7409a

Please sign in to comment.