Fix the overlap for xgrammar #2377

Merged (7 commits) on Dec 6, 2024
2 changes: 1 addition & 1 deletion docs/references/supported_models.md
@@ -106,4 +106,4 @@ def import_new_model_classes():
ModelRegistry.models.update(import_new_model_classes())

launch_server(server_args)
```
```
5 changes: 5 additions & 0 deletions python/sglang/srt/constrained/outlines_backend.py
@@ -42,6 +42,7 @@ def __init__(
self.guide = guide
self.jump_forward_map = jump_forward_map
self.state = 0
self.finished = False

def accept_token(self, token: int):
self.state = self.guide.get_next_state(self.state, token)
@@ -84,6 +85,10 @@ def allocate_vocab_mask(
) -> torch.Tensor:
return torch.zeros(batch_size, vocab_size, dtype=torch.bool, device=device)

@staticmethod
def move_vocab_mask(vocab_mask: torch.Tensor, device) -> torch.Tensor:
return vocab_mask

def fill_vocab_mask(self, vocab_mask: torch.Tensor, idx: int) -> None:
tokens = torch.tensor(
self.guide.get_next_instruction(self.state).tokens, dtype=torch.int64
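The outlines change above is mostly interface parity: the grammar object gains the same `finished` flag as the xgrammar one, plus a `move_vocab_mask` that returns the mask unchanged because outlines already allocates it on the sampling device. A minimal sketch of that contract, with the class name and docstring invented for illustration:

```python
import torch


class OutlinesLikeGrammar:
    """Illustrative stand-in for the outlines grammar object after this PR."""

    def __init__(self):
        self.finished = False  # set by the scheduler once the request finishes

    @staticmethod
    def allocate_vocab_mask(vocab_size: int, batch_size: int, device) -> torch.Tensor:
        # Outlines fills a boolean mask directly on the sampling device.
        return torch.zeros(batch_size, vocab_size, dtype=torch.bool, device=device)

    @staticmethod
    def move_vocab_mask(vocab_mask: torch.Tensor, device) -> torch.Tensor:
        # Already on the right device; returning it unchanged keeps the call sites
        # uniform with the xgrammar backend, which does a real copy here.
        return vocab_mask
```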
10 changes: 5 additions & 5 deletions python/sglang/srt/constrained/xgrammar_backend.py
@@ -45,6 +45,7 @@ def __init__(
self.matcher = matcher
self.vocab_size = vocab_size
self.ctx = ctx
self.finished = False

def accept_token(self, token: int):
assert self.matcher.accept_token(token)
@@ -85,12 +86,11 @@ def fill_vocab_mask(self, vocab_mask: torch.Tensor, idx: int) -> None:
self.matcher.fill_next_token_bitmask(vocab_mask, idx)

@staticmethod
def apply_vocab_mask(logits: torch.Tensor, vocab_mask: torch.Tensor) -> None:
if vocab_mask.device.type != logits.device.type:
# vocab_mask must then be on the same device as logits
# when applying the token bitmask, so we check and move if needed
vocab_mask = vocab_mask.to(logits.device)
def move_vocab_mask(vocab_mask: torch.Tensor, device) -> torch.Tensor:
return vocab_mask.to(device, non_blocking=True)

@staticmethod
def apply_vocab_mask(logits: torch.Tensor, vocab_mask: torch.Tensor) -> None:
apply_token_bitmask_inplace(logits, vocab_mask)

def copy(self):
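For xgrammar the new `move_vocab_mask` is a real host-to-device copy issued with `non_blocking=True`: the token bitmask is filled on the CPU and shipped to the logits device asynchronously, and the stream synchronization in the scheduler (below) guarantees the copy has landed before sampling reads it. A rough, self-contained illustration of that pattern, assuming a CUDA device and a bitmask layout of roughly one int32 per 32 vocabulary entries (the exact layout is xgrammar's detail):

```python
import torch

batch_size, vocab_size = 8, 128 * 1024

if torch.cuda.is_available():
    # Pinned host memory is what lets a non_blocking copy run truly asynchronously.
    cpu_bitmask = torch.zeros(
        batch_size, (vocab_size + 31) // 32, dtype=torch.int32
    ).pin_memory()

    # ... each request's grammar fills its row of cpu_bitmask here ...

    # Queued on the current stream; returns immediately instead of blocking the CPU.
    gpu_bitmask = cpu_bitmask.to("cuda", non_blocking=True)

    # Consumers must synchronize (or wait on) this stream before reading gpu_bitmask.
    torch.cuda.current_stream().synchronize()
```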
134 changes: 69 additions & 65 deletions python/sglang/srt/managers/scheduler.py
@@ -114,9 +114,6 @@ def __init__(
self.skip_tokenizer_init = server_args.skip_tokenizer_init
self.enable_metrics = server_args.enable_metrics

# Session info
self.sessions = {}

# Init inter-process communication
context = zmq.Context(2)

@@ -259,6 +256,10 @@ def __init__(
self.num_generated_tokens = 0
self.last_decode_stats_tic = time.time()
self.stream_interval = server_args.stream_interval
self.current_stream = torch.get_device_module(self.device).current_stream()

# Session info
self.sessions = {}

# Init chunked prefill
self.chunked_prefill_size = server_args.chunked_prefill_size
@@ -356,6 +357,7 @@ def __init__(
)

def watchdog_thread(self):
"""A watch dog thread that will try to kill the server itself if one batch takes too long."""
self.watchdog_last_forward_ct = 0
self.watchdog_last_time = time.time()

@@ -433,61 +435,6 @@ def event_loop_overlap(self):

self.last_batch = batch

def prepare_dp_attn_batch(self, local_batch: ScheduleBatch):
# Check if other DP workers have running batches
if local_batch is None:
num_tokens = 0
elif local_batch.forward_mode.is_decode():
num_tokens = local_batch.batch_size()
else:
num_tokens = local_batch.extend_num_tokens

local_num_tokens = torch.tensor([num_tokens], dtype=torch.int64)
global_num_tokens = torch.empty(self.tp_size, dtype=torch.int64)
torch.distributed.all_gather_into_tensor(
global_num_tokens,
local_num_tokens,
group=self.tp_cpu_group,
)

if local_batch is None and global_num_tokens.max().item() > 0:
local_batch = self.get_idle_batch()

if local_batch is not None:
local_batch.global_num_tokens = global_num_tokens.tolist()

# Check forward mode for cuda graph
if not self.server_args.disable_cuda_graph:
forward_mode_state = torch.tensor(
(
1
if local_batch.forward_mode.is_decode()
or local_batch.forward_mode.is_idle()
else 0
),
dtype=torch.int32,
)
torch.distributed.all_reduce(
forward_mode_state,
op=torch.distributed.ReduceOp.MIN,
group=self.tp_cpu_group,
)
local_batch.can_run_dp_cuda_graph = forward_mode_state.item() == 1

return local_batch

def get_idle_batch(self):
idle_batch = ScheduleBatch.init_new(
[],
self.req_to_token_pool,
self.token_to_kv_pool,
self.tree_cache,
self.model_config,
self.enable_overlap,
)
idle_batch.prepare_for_idle()
return idle_batch

def recv_requests(self):
if self.tp_rank == 0 or self.server_args.enable_dp_attention:
recv_reqs = []
@@ -993,7 +940,7 @@ def process_batch_result(self, batch: ScheduleBatch, result):
self.process_batch_result_prefill(batch, result)
elif batch.forward_mode.is_dummy_first():
batch.next_batch_sampling_info.update_regex_vocab_mask()
torch.get_device_module(self.device).current_stream().synchronize()
self.current_stream.synchronize()
batch.next_batch_sampling_info.sampling_info_done.set()

def process_batch_result_prefill(self, batch: ScheduleBatch, result):
@@ -1049,13 +996,14 @@ def process_batch_result_prefill(self, batch: ScheduleBatch, result):

if req.grammar is not None:
req.grammar.accept_token(next_token_id)
req.grammar.finished = req.finished()
else:
# being chunked reqs' prefill is not finished
req.is_being_chunked -= 1

if batch.next_batch_sampling_info:
batch.next_batch_sampling_info.update_regex_vocab_mask()
torch.get_device_module(self.device).current_stream().synchronize()
self.current_stream.synchronize()
batch.next_batch_sampling_info.sampling_info_done.set()

else: # embedding or reward model
@@ -1127,10 +1075,11 @@ def process_batch_result_decode(self, batch: ScheduleBatch, result):

if req.grammar is not None:
req.grammar.accept_token(next_token_id)
req.grammar.finished = req.finished()

if batch.next_batch_sampling_info:
batch.next_batch_sampling_info.update_regex_vocab_mask()
torch.get_device_module(self.device).current_stream().synchronize()
self.current_stream.synchronize()
batch.next_batch_sampling_info.sampling_info_done.set()

self.stream_output(batch.reqs)
@@ -1328,6 +1277,61 @@ def stream_output(self, reqs: List[Req]):
)
)

def prepare_dp_attn_batch(self, local_batch: ScheduleBatch):
# Check if other DP workers have running batches
if local_batch is None:
num_tokens = 0
elif local_batch.forward_mode.is_decode():
num_tokens = local_batch.batch_size()
else:
num_tokens = local_batch.extend_num_tokens

local_num_tokens = torch.tensor([num_tokens], dtype=torch.int64)
global_num_tokens = torch.empty(self.tp_size, dtype=torch.int64)
torch.distributed.all_gather_into_tensor(
global_num_tokens,
local_num_tokens,
group=self.tp_cpu_group,
)

if local_batch is None and global_num_tokens.max().item() > 0:
local_batch = self.get_idle_batch()

if local_batch is not None:
local_batch.global_num_tokens = global_num_tokens.tolist()

# Check forward mode for cuda graph
if not self.server_args.disable_cuda_graph:
forward_mode_state = torch.tensor(
(
1
if local_batch.forward_mode.is_decode()
or local_batch.forward_mode.is_idle()
else 0
),
dtype=torch.int32,
)
torch.distributed.all_reduce(
forward_mode_state,
op=torch.distributed.ReduceOp.MIN,
group=self.tp_cpu_group,
)
local_batch.can_run_dp_cuda_graph = forward_mode_state.item() == 1

return local_batch

def get_idle_batch(self):
idle_batch = ScheduleBatch.init_new(
[],
self.req_to_token_pool,
self.token_to_kv_pool,
self.tree_cache,
self.model_config,
self.enable_overlap,
)
idle_batch.prepare_for_idle()
return idle_batch

def move_ready_grammar_requests(self):
"""Move requests whose grammar objects are ready from grammar_queue to waiting_queue."""
num_ready_reqs = 0
@@ -1469,10 +1473,6 @@ def run_scheduler_process(
dp_rank: Optional[int],
pipe_writer,
):
# set cpu affinity to this gpu process
if get_bool_env_var("SGLANG_SET_CPU_AFFINITY"):
set_gpu_proc_affinity(server_args.tp_size, server_args.nnodes, gpu_id)

# [For Router] if env var "SGLANG_DP_RANK" exist, set dp_rank to the value of the env var
if dp_rank is None and "SGLANG_DP_RANK" in os.environ:
dp_rank = int(os.environ["SGLANG_DP_RANK"])
@@ -1482,6 +1482,10 @@
else:
configure_logger(server_args, prefix=f" DP{dp_rank} TP{tp_rank}")

# set cpu affinity to this gpu process
if get_bool_env_var("SGLANG_SET_CPU_AFFINITY"):
set_gpu_proc_affinity(server_args.tp_size, server_args.nnodes, gpu_id)

suppress_other_loggers()
parent_process = psutil.Process().parent()

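The scheduler-side half of the fix is the `finished` handshake plus a cached stream handle: after each accepted token the grammar is marked finished when its request finishes, so the next batch's mask filling can skip it, and the stream resolved once in `__init__` is synchronized before `sampling_info_done` is set so the asynchronous mask copy has completed. A condensed sketch (class and method names outside the diff are hypothetical):

```python
import torch


class MiniScheduler:
    """Toy condensation of the scheduler changes; not SGLang's real class."""

    def __init__(self, device: str = "cuda"):
        # Resolve the device module (torch.cuda, etc.) and cache the stream handle
        # once, instead of re-resolving it on every batch.
        self.current_stream = torch.get_device_module(device).current_stream()

    def on_token(self, req, next_token_id: int) -> None:
        if req.grammar is not None:
            req.grammar.accept_token(next_token_id)
            # Mark the grammar so update_regex_vocab_mask skips it next step.
            req.grammar.finished = req.finished()

    def finish_step(self, next_batch_sampling_info) -> None:
        if next_batch_sampling_info:
            next_batch_sampling_info.update_regex_vocab_mask()
            # move_vocab_mask may have queued an async copy; wait for it before
            # signalling the sampler that the sampling info is ready.
            self.current_stream.synchronize()
            next_batch_sampling_info.sampling_info_done.set()
```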
3 changes: 2 additions & 1 deletion python/sglang/srt/managers/tp_worker_overlap_thread.py
@@ -80,6 +80,7 @@ def __init__(
)
self.forward_thread.start()
self.parent_process = psutil.Process().parent()
self.scheduler_stream = torch.get_device_module(self.device).current_stream()

def get_worker_info(self):
return self.worker.get_worker_info()
@@ -191,7 +192,7 @@ def forward_batch_generation(self, model_worker_batch: ModelWorkerBatch):
)

# A cuda stream sync here to avoid the cuda illegal memory access error.
torch.get_device_module(self.device).current_stream().synchronize()
self.scheduler_stream.synchronize()

# Push a new batch to the queue
self.input_queue.put((model_worker_batch, self.future_token_ids_ct))
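Caching `self.scheduler_stream` serves the same purpose as `self.current_stream` in the scheduler: the sync before pushing a batch to the overlap worker ensures that work the scheduler queued on its stream has finished before another thread consumes the tensors. The underlying ordering rule is generic PyTorch, sketched below (this is not SGLang code; either a synchronize or a `wait_stream` establishes the ordering):

```python
import torch

if torch.cuda.is_available():
    scheduler_stream = torch.cuda.current_stream()
    worker_stream = torch.cuda.Stream()

    x = torch.zeros(1 << 20, device="cuda")
    x += 1  # work queued on the scheduler's (default) stream

    # Option 1: block the CPU until the producing stream is done (what this PR calls).
    scheduler_stream.synchronize()

    # Option 2: order the consumer stream after the producer without blocking the CPU.
    worker_stream.wait_stream(scheduler_stream)
    with torch.cuda.stream(worker_stream):
        y = x * 2  # safe: ordered after the scheduler's writes to x
```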
17 changes: 9 additions & 8 deletions python/sglang/srt/sampling/sampling_batch_info.py
@@ -158,22 +158,23 @@ def update_regex_vocab_mask(self):
return

# find a grammar from the list
grammar = next(grammar for grammar in self.grammars if grammar)
first_grammar = next(grammar for grammar in self.grammars if grammar)

# maybe we can reuse the existing mask?
self.vocab_mask = grammar.allocate_vocab_mask(
self.vocab_mask = first_grammar.allocate_vocab_mask(
vocab_size=self.vocab_size,
batch_size=len(self.temperatures),
device=self.device,
)
self.apply_mask = type(grammar).apply_vocab_mask # force to use static method
self.apply_mask = first_grammar.apply_vocab_mask # force to use static method

# Apply the mask
for i, grammar in enumerate(self.grammars):
if grammar is not None:
try:
grammar.fill_vocab_mask(self.vocab_mask, i)
except RuntimeError:
continue
if grammar and not grammar.finished:
grammar.fill_vocab_mask(self.vocab_mask, i)

# Move the mask to the device if needed
self.vocab_mask = first_grammar.move_vocab_mask(self.vocab_mask, self.device)

def filter_batch(self, unfinished_indices: List[int], new_indices: torch.Tensor):
self.penalizer_orchestrator.filter(unfinished_indices, new_indices)
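Putting the pieces together, `update_regex_vocab_mask` now delegates all device handling to the backend: allocate one mask for the batch, fill it per request while skipping `finished` grammars (replacing the old `try/except RuntimeError`), then move it to the sampling device in a single, possibly asynchronous copy; `apply_vocab_mask` is called later at sampling time. A simplified, hedged paraphrase of that flow:

```python
from typing import List, Optional

import torch


def build_vocab_mask(
    grammars: List[Optional[object]], vocab_size: int, device
) -> Optional[torch.Tensor]:
    """Sketch of the new update_regex_vocab_mask logic; returns the mask on `device`."""
    first_grammar = next((g for g in grammars if g), None)
    if first_grammar is None:
        return None

    # One allocation for the whole batch, on whatever device the backend prefers.
    vocab_mask = first_grammar.allocate_vocab_mask(
        vocab_size=vocab_size, batch_size=len(grammars), device=device
    )
    for i, grammar in enumerate(grammars):
        # Finished grammars are skipped outright instead of raising at fill time.
        if grammar and not grammar.finished:
            grammar.fill_vocab_mask(vocab_mask, i)

    # No-op for outlines; async host-to-device copy for xgrammar.
    return first_grammar.move_vocab_mask(vocab_mask, device)
```

The static `apply_vocab_mask` then masks the logits in place once the scheduler has synchronized the stream.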