From 0e7409adb64ac19db2db3583ef3e4077cc569b30 Mon Sep 17 00:00:00 2001 From: Lianmin Zheng Date: Fri, 6 Dec 2024 05:49:29 -0800 Subject: [PATCH] Fix the overlap for xgrammar (#2377) --- docs/references/supported_models.md | 2 +- .../srt/constrained/outlines_backend.py | 5 + .../srt/constrained/xgrammar_backend.py | 10 +- python/sglang/srt/managers/scheduler.py | 134 +++++++++--------- .../srt/managers/tp_worker_overlap_thread.py | 3 +- .../srt/sampling/sampling_batch_info.py | 17 +-- test/srt/test_json_constrained.py | 107 +++++++------- 7 files changed, 145 insertions(+), 133 deletions(-) diff --git a/docs/references/supported_models.md b/docs/references/supported_models.md index 13572e437d5..bf1044f8498 100644 --- a/docs/references/supported_models.md +++ b/docs/references/supported_models.md @@ -106,4 +106,4 @@ def import_new_model_classes(): ModelRegistry.models.update(import_new_model_classes()) launch_server(server_args) -``` \ No newline at end of file +``` diff --git a/python/sglang/srt/constrained/outlines_backend.py b/python/sglang/srt/constrained/outlines_backend.py index 26c476a0599..4820d473959 100644 --- a/python/sglang/srt/constrained/outlines_backend.py +++ b/python/sglang/srt/constrained/outlines_backend.py @@ -42,6 +42,7 @@ def __init__( self.guide = guide self.jump_forward_map = jump_forward_map self.state = 0 + self.finished = False def accept_token(self, token: int): self.state = self.guide.get_next_state(self.state, token) @@ -84,6 +85,10 @@ def allocate_vocab_mask( ) -> torch.Tensor: return torch.zeros(batch_size, vocab_size, dtype=torch.bool, device=device) + @staticmethod + def move_vocab_mask(vocab_mask: torch.Tensor, device) -> torch.Tensor: + return vocab_mask + def fill_vocab_mask(self, vocab_mask: torch.Tensor, idx: int) -> None: tokens = torch.tensor( self.guide.get_next_instruction(self.state).tokens, dtype=torch.int64 diff --git a/python/sglang/srt/constrained/xgrammar_backend.py b/python/sglang/srt/constrained/xgrammar_backend.py index 1bcc51c6468..ee8e8eb07f4 100644 --- a/python/sglang/srt/constrained/xgrammar_backend.py +++ b/python/sglang/srt/constrained/xgrammar_backend.py @@ -45,6 +45,7 @@ def __init__( self.matcher = matcher self.vocab_size = vocab_size self.ctx = ctx + self.finished = False def accept_token(self, token: int): assert self.matcher.accept_token(token) @@ -85,12 +86,11 @@ def fill_vocab_mask(self, vocab_mask: torch.Tensor, idx: int) -> None: self.matcher.fill_next_token_bitmask(vocab_mask, idx) @staticmethod - def apply_vocab_mask(logits: torch.Tensor, vocab_mask: torch.Tensor) -> None: - if vocab_mask.device.type != logits.device.type: - # vocab_mask must then be on the same device as logits - # when applying the token bitmask, so we check and move if needed - vocab_mask = vocab_mask.to(logits.device) + def move_vocab_mask(vocab_mask: torch.Tensor, device) -> torch.Tensor: + return vocab_mask.to(device, non_blocking=True) + @staticmethod + def apply_vocab_mask(logits: torch.Tensor, vocab_mask: torch.Tensor) -> None: apply_token_bitmask_inplace(logits, vocab_mask) def copy(self): diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py index fd4edade92d..4ca4cd740dc 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -114,9 +114,6 @@ def __init__( self.skip_tokenizer_init = server_args.skip_tokenizer_init self.enable_metrics = server_args.enable_metrics - # Session info - self.sessions = {} - # Init inter-process communication context = 
zmq.Context(2) @@ -259,6 +256,10 @@ def __init__( self.num_generated_tokens = 0 self.last_decode_stats_tic = time.time() self.stream_interval = server_args.stream_interval + self.current_stream = torch.get_device_module(self.device).current_stream() + + # Session info + self.sessions = {} # Init chunked prefill self.chunked_prefill_size = server_args.chunked_prefill_size @@ -356,6 +357,7 @@ def __init__( ) def watchdog_thread(self): + """A watch dog thread that will try to kill the server itself if one batch takes too long.""" self.watchdog_last_forward_ct = 0 self.watchdog_last_time = time.time() @@ -433,61 +435,6 @@ def event_loop_overlap(self): self.last_batch = batch - def prepare_dp_attn_batch(self, local_batch: ScheduleBatch): - # Check if other DP workers have running batches - if local_batch is None: - num_tokens = 0 - elif local_batch.forward_mode.is_decode(): - num_tokens = local_batch.batch_size() - else: - num_tokens = local_batch.extend_num_tokens - - local_num_tokens = torch.tensor([num_tokens], dtype=torch.int64) - global_num_tokens = torch.empty(self.tp_size, dtype=torch.int64) - torch.distributed.all_gather_into_tensor( - global_num_tokens, - local_num_tokens, - group=self.tp_cpu_group, - ) - - if local_batch is None and global_num_tokens.max().item() > 0: - local_batch = self.get_idle_batch() - - if local_batch is not None: - local_batch.global_num_tokens = global_num_tokens.tolist() - - # Check forward mode for cuda graph - if not self.server_args.disable_cuda_graph: - forward_mode_state = torch.tensor( - ( - 1 - if local_batch.forward_mode.is_decode() - or local_batch.forward_mode.is_idle() - else 0 - ), - dtype=torch.int32, - ) - torch.distributed.all_reduce( - forward_mode_state, - op=torch.distributed.ReduceOp.MIN, - group=self.tp_cpu_group, - ) - local_batch.can_run_dp_cuda_graph = forward_mode_state.item() == 1 - - return local_batch - - def get_idle_batch(self): - idle_batch = ScheduleBatch.init_new( - [], - self.req_to_token_pool, - self.token_to_kv_pool, - self.tree_cache, - self.model_config, - self.enable_overlap, - ) - idle_batch.prepare_for_idle() - return idle_batch - def recv_requests(self): if self.tp_rank == 0 or self.server_args.enable_dp_attention: recv_reqs = [] @@ -993,7 +940,7 @@ def process_batch_result(self, batch: ScheduleBatch, result): self.process_batch_result_prefill(batch, result) elif batch.forward_mode.is_dummy_first(): batch.next_batch_sampling_info.update_regex_vocab_mask() - torch.get_device_module(self.device).current_stream().synchronize() + self.current_stream.synchronize() batch.next_batch_sampling_info.sampling_info_done.set() def process_batch_result_prefill(self, batch: ScheduleBatch, result): @@ -1049,13 +996,14 @@ def process_batch_result_prefill(self, batch: ScheduleBatch, result): if req.grammar is not None: req.grammar.accept_token(next_token_id) + req.grammar.finished = req.finished() else: # being chunked reqs' prefill is not finished req.is_being_chunked -= 1 if batch.next_batch_sampling_info: batch.next_batch_sampling_info.update_regex_vocab_mask() - torch.get_device_module(self.device).current_stream().synchronize() + self.current_stream.synchronize() batch.next_batch_sampling_info.sampling_info_done.set() else: # embedding or reward model @@ -1127,10 +1075,11 @@ def process_batch_result_decode(self, batch: ScheduleBatch, result): if req.grammar is not None: req.grammar.accept_token(next_token_id) + req.grammar.finished = req.finished() if batch.next_batch_sampling_info: 
batch.next_batch_sampling_info.update_regex_vocab_mask() - torch.get_device_module(self.device).current_stream().synchronize() + self.current_stream.synchronize() batch.next_batch_sampling_info.sampling_info_done.set() self.stream_output(batch.reqs) @@ -1328,6 +1277,61 @@ def stream_output(self, reqs: List[Req]): ) ) + def prepare_dp_attn_batch(self, local_batch: ScheduleBatch): + # Check if other DP workers have running batches + if local_batch is None: + num_tokens = 0 + elif local_batch.forward_mode.is_decode(): + num_tokens = local_batch.batch_size() + else: + num_tokens = local_batch.extend_num_tokens + + local_num_tokens = torch.tensor([num_tokens], dtype=torch.int64) + global_num_tokens = torch.empty(self.tp_size, dtype=torch.int64) + torch.distributed.all_gather_into_tensor( + global_num_tokens, + local_num_tokens, + group=self.tp_cpu_group, + ) + + if local_batch is None and global_num_tokens.max().item() > 0: + local_batch = self.get_idle_batch() + + if local_batch is not None: + local_batch.global_num_tokens = global_num_tokens.tolist() + + # Check forward mode for cuda graph + if not self.server_args.disable_cuda_graph: + forward_mode_state = torch.tensor( + ( + 1 + if local_batch.forward_mode.is_decode() + or local_batch.forward_mode.is_idle() + else 0 + ), + dtype=torch.int32, + ) + torch.distributed.all_reduce( + forward_mode_state, + op=torch.distributed.ReduceOp.MIN, + group=self.tp_cpu_group, + ) + local_batch.can_run_dp_cuda_graph = forward_mode_state.item() == 1 + + return local_batch + + def get_idle_batch(self): + idle_batch = ScheduleBatch.init_new( + [], + self.req_to_token_pool, + self.token_to_kv_pool, + self.tree_cache, + self.model_config, + self.enable_overlap, + ) + idle_batch.prepare_for_idle() + return idle_batch + def move_ready_grammar_requests(self): """Move requests whose grammar objects are ready from grammar_queue to waiting_queue.""" num_ready_reqs = 0 @@ -1469,10 +1473,6 @@ def run_scheduler_process( dp_rank: Optional[int], pipe_writer, ): - # set cpu affinity to this gpu process - if get_bool_env_var("SGLANG_SET_CPU_AFFINITY"): - set_gpu_proc_affinity(server_args.tp_size, server_args.nnodes, gpu_id) - # [For Router] if env var "SGLANG_DP_RANK" exist, set dp_rank to the value of the env var if dp_rank is None and "SGLANG_DP_RANK" in os.environ: dp_rank = int(os.environ["SGLANG_DP_RANK"]) @@ -1482,6 +1482,10 @@ def run_scheduler_process( else: configure_logger(server_args, prefix=f" DP{dp_rank} TP{tp_rank}") + # set cpu affinity to this gpu process + if get_bool_env_var("SGLANG_SET_CPU_AFFINITY"): + set_gpu_proc_affinity(server_args.tp_size, server_args.nnodes, gpu_id) + suppress_other_loggers() parent_process = psutil.Process().parent() diff --git a/python/sglang/srt/managers/tp_worker_overlap_thread.py b/python/sglang/srt/managers/tp_worker_overlap_thread.py index 6a453d2ad6d..a9db1878391 100644 --- a/python/sglang/srt/managers/tp_worker_overlap_thread.py +++ b/python/sglang/srt/managers/tp_worker_overlap_thread.py @@ -80,6 +80,7 @@ def __init__( ) self.forward_thread.start() self.parent_process = psutil.Process().parent() + self.scheduler_stream = torch.get_device_module(self.device).current_stream() def get_worker_info(self): return self.worker.get_worker_info() @@ -191,7 +192,7 @@ def forward_batch_generation(self, model_worker_batch: ModelWorkerBatch): ) # A cuda stream sync here to avoid the cuda illegal memory access error. 
- torch.get_device_module(self.device).current_stream().synchronize() + self.scheduler_stream.synchronize() # Push a new batch to the queue self.input_queue.put((model_worker_batch, self.future_token_ids_ct)) diff --git a/python/sglang/srt/sampling/sampling_batch_info.py b/python/sglang/srt/sampling/sampling_batch_info.py index 1624fd255f9..a64a84a62dc 100644 --- a/python/sglang/srt/sampling/sampling_batch_info.py +++ b/python/sglang/srt/sampling/sampling_batch_info.py @@ -158,22 +158,23 @@ def update_regex_vocab_mask(self): return # find a grammar from the list - grammar = next(grammar for grammar in self.grammars if grammar) + first_grammar = next(grammar for grammar in self.grammars if grammar) # maybe we can reuse the existing mask? - self.vocab_mask = grammar.allocate_vocab_mask( + self.vocab_mask = first_grammar.allocate_vocab_mask( vocab_size=self.vocab_size, batch_size=len(self.temperatures), device=self.device, ) - self.apply_mask = type(grammar).apply_vocab_mask # force to use static method + self.apply_mask = first_grammar.apply_vocab_mask # force to use static method + # Apply the mask for i, grammar in enumerate(self.grammars): - if grammar is not None: - try: - grammar.fill_vocab_mask(self.vocab_mask, i) - except RuntimeError: - continue + if grammar and not grammar.finished: + grammar.fill_vocab_mask(self.vocab_mask, i) + + # Move the mask to the device if needed + self.vocab_mask = first_grammar.move_vocab_mask(self.vocab_mask, self.device) def filter_batch(self, unfinished_indices: List[int], new_indices: torch.Tensor): self.penalizer_orchestrator.filter(unfinished_indices, new_indices) diff --git a/test/srt/test_json_constrained.py b/test/srt/test_json_constrained.py index 28acdabd9d0..1a857d0da6e 100644 --- a/test/srt/test_json_constrained.py +++ b/test/srt/test_json_constrained.py @@ -1,5 +1,6 @@ """ -python3 -m unittest test_json_constrained.TestJSONConstrained.test_json_generate +python3 -m unittest test_json_constrained.TestJSONConstrainedOutlinesBackend.test_json_generate +python3 -m unittest test_json_constrained.TestJSONConstrainedXGrammarBackend.test_json_generate """ import json @@ -11,38 +12,50 @@ from sglang.srt.utils import kill_process_tree from sglang.test.test_utils import ( - DEFAULT_MODEL_NAME_FOR_TEST, + DEFAULT_SMALL_MODEL_NAME_FOR_TEST, + DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, DEFAULT_URL_FOR_TEST, popen_launch_server, ) +def setup_class(cls, backend: str, disable_overlap: bool): + cls.model = DEFAULT_SMALL_MODEL_NAME_FOR_TEST + cls.base_url = DEFAULT_URL_FOR_TEST + cls.json_schema = json.dumps( + { + "type": "object", + "properties": { + "name": {"type": "string", "pattern": "^[\\w]+$"}, + "population": {"type": "integer"}, + }, + "required": ["name", "population"], + } + ) + + other_args = [ + "--max-running-requests", + "10", + "--grammar-backend", + backend, + ] + + if disable_overlap: + other_args += ["--disable-overlap-schedule"] + + cls.process = popen_launch_server( + cls.model, + cls.base_url, + timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + other_args=other_args, + ) + + class TestJSONConstrainedOutlinesBackend(unittest.TestCase): @classmethod def setUpClass(cls): - cls.model = DEFAULT_MODEL_NAME_FOR_TEST - cls.base_url = DEFAULT_URL_FOR_TEST - cls.json_schema = json.dumps( - { - "type": "object", - "properties": { - "name": {"type": "string", "pattern": "^[\\w]+$"}, - "population": {"type": "integer"}, - }, - "required": ["name", "population"], - } - ) - cls.process = popen_launch_server( - cls.model, - cls.base_url, - timeout=300, - other_args=[ 
- "--max-running-requests", - "10", - "--grammar-backend", - "outlines", - ], - ) + setup_class(cls, backend="outlines", disable_overlap=False) + cls.check_jump_forward = False @classmethod def tearDownClass(cls): @@ -83,11 +96,13 @@ def run_decode(self, json_schema, return_logprob=False, top_logprobs_num=0, n=1) self.assertIsInstance(js_obj["population"], int) # Make sure jump forward is triggered - # NOTE: This is skipped because overlap scheduler does not support jump forward - # self.assertGreater( - # ret["meta_info"]["completion_tokens"], - # ret["meta_info"]["completion_tokens_wo_jump_forward"], - # ) + # NOTE: The overlap scheduler does not support jump forward so we only do this test + # when --disable-overlap-schedule is set. + if self.check_jump_forward: + self.assertGreater( + ret["meta_info"]["completion_tokens"], + ret["meta_info"]["completion_tokens_wo_jump_forward"], + ) def test_json_generate(self): self.run_decode(json_schema=self.json_schema) @@ -126,32 +141,18 @@ def test_mix_json_and_other(self): list(executor.map(self.run_decode, json_schemas)) +class TestJumpForwardOutlinesBackend(unittest.TestCase): + @classmethod + def setUpClass(cls): + setup_class(cls, backend="outlines", disable_overlap=True) + cls.check_jump_forward = True + + class TestJSONConstrainedXGrammarBackend(TestJSONConstrainedOutlinesBackend): @classmethod def setUpClass(cls): - cls.model = DEFAULT_MODEL_NAME_FOR_TEST - cls.base_url = DEFAULT_URL_FOR_TEST - cls.json_schema = json.dumps( - { - "type": "object", - "properties": { - "name": {"type": "string"}, - "population": {"type": "integer"}, - }, - "required": ["name", "population"], - } - ) - cls.process = popen_launch_server( - cls.model, - cls.base_url, - timeout=300, - other_args=[ - "--max-running-requests", - "10", - "--grammar-backend", - "xgrammar", - ], - ) + setup_class(cls, backend="xgrammar", disable_overlap=False) + cls.check_jump_forward = False if __name__ == "__main__":