From a1be960d7caf6b8b62b908cba7485968146b79a1 Mon Sep 17 00:00:00 2001
From: Lianmin Zheng
Date: Sun, 8 Dec 2024 03:59:12 -0800
Subject: [PATCH] Fix stream

---
 python/sglang/srt/managers/scheduler.py | 59 ++++++++++++-------------
 1 file changed, 28 insertions(+), 31 deletions(-)

diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py
index d98499d6278..e92e5191eb4 100644
--- a/python/sglang/srt/managers/scheduler.py
+++ b/python/sglang/srt/managers/scheduler.py
@@ -1193,22 +1193,16 @@ def stream_output(self, reqs: List[Req], skip_req: Optional[Req] = None):
             output_skip_special_tokens = []
             output_spaces_between_special_tokens = []
             output_no_stop_trim = []
-        else:  # embedding or reward model
-            output_embeddings = []
-
-        is_stream_iter = self.forward_ct_decode % self.stream_interval == 0
-
-        for req in reqs:
-            if req is skip_req:
-                continue
+
+            for req in reqs:
+                if req is skip_req:
+                    continue
 
-            # TODO(lianmin): revisit this for overlap + retract + stream
-            if req.finished() or (
-                req.stream and (is_stream_iter or len(req.output_ids) == 1)
-            ):
-                output_rids.append(req.rid)
-                output_finished_reason.append(req.finished_reason)
-                if self.is_generation:
+                # TODO(lianmin): revisit this for overlap + retract + stream
+                is_stream_iter = len(req.output_ids) % self.stream_interval == 0
+                if req.finished() or (req.stream and is_stream_iter):
+                    output_rids.append(req.rid)
+                    output_finished_reason.append(req.finished_reason)
                     output_vids.append(req.vid)
                     decoded_texts.append(req.decoded_text)
                     read_ids, read_offset = req.init_incremental_detokenize()
@@ -1250,16 +1244,9 @@ def stream_output(self, reqs: List[Req], skip_req: Optional[Req] = None):
                             req.normalized_prompt_logprob,
                         )
                     output_meta_info.append(meta_info)
-                else:  # embedding or reward model
-                    output_embeddings.append(req.embedding)
-                    meta_info = {
-                        "prompt_tokens": len(req.origin_input_ids),
-                    }
-                    output_meta_info.append(meta_info)
 
-        # Send to detokenizer
-        if output_rids:
-            if self.is_generation:
+            # Send to detokenizer
+            if output_rids:
                 self.send_to_detokenizer.send_pyobj(
                     BatchTokenIDOut(
                         output_rids,
@@ -1275,15 +1262,25 @@ def stream_output(self, reqs: List[Req], skip_req: Optional[Req] = None):
                         output_no_stop_trim,
                     )
                 )
-            else:  # embedding or reward model
-                self.send_to_detokenizer.send_pyobj(
-                    BatchEmbeddingOut(
-                        output_rids,
-                        output_embeddings,
-                        output_meta_info,
-                        output_finished_reason,
-                    )
+        else:  # embedding or reward model
+            output_embeddings = []
+            for req in reqs:
+                assert req.finished()
+                output_rids.append(req.rid)
+                output_finished_reason.append(req.finished_reason)
+                output_embeddings.append(req.embedding)
+                meta_info = {
+                    "prompt_tokens": len(req.origin_input_ids),
+                }
+                output_meta_info.append(meta_info)
+            self.send_to_detokenizer.send_pyobj(
+                BatchEmbeddingOut(
+                    output_rids,
+                    output_embeddings,
+                    output_meta_info,
+                    output_finished_reason,
                 )
+            )
 
     def prepare_dp_attn_batch(self, local_batch: ScheduleBatch):
         # Check if other DP workers have running batches
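Reviewer note: the functional change is twofold. First, the stream-flush test moves from the scheduler-global counter (self.forward_ct_decode % self.stream_interval == 0, shared by all requests, with a special case for the first token via len(req.output_ids) == 1) to a per-request test on the request's own token count (len(req.output_ids) % self.stream_interval == 0), so each request flushes on its own cadence regardless of when it joined the running batch. Second, the embedding/reward path no longer shares the generation loop; it collects outputs in a dedicated branch that asserts every request is finished. Below is a minimal standalone sketch of the first point; STREAM_INTERVAL and the driver loop are hypothetical, and only the two modulo predicates mirror the diff.

    # Standalone sketch (hypothetical driver); only the two predicates
    # correspond to the old and new `is_stream_iter` in the patch.
    STREAM_INTERVAL = 4  # stand-in for self.stream_interval

    def old_is_stream_iter(forward_ct_decode: int) -> bool:
        # Old rule: one global decode counter gates every request, so the
        # flush phase depends on when the request joined the batch.
        return forward_ct_decode % STREAM_INTERVAL == 0

    def new_is_stream_iter(num_output_ids: int) -> bool:
        # New rule: a request flushes every STREAM_INTERVAL of its own tokens.
        return num_output_ids % STREAM_INTERVAL == 0

    # A request whose first token lands when the global counter is already 3:
    for step in range(3, 11):  # global forward_ct_decode values
        tokens = step - 2      # this request's len(req.output_ids)
        print(f"ct={step} tokens={tokens} "
              f"old={old_is_stream_iter(step)} new={new_is_stream_iter(tokens)}")

With the old rule this request flushes at ct=4 and ct=8 (after 2 and 6 of its own tokens); with the new rule it flushes after exactly 4 and 8 tokens, independent of the global counter's phase.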