From 69e2d4fb66e8dd9df7e9472df44ae29afc1320d1 Mon Sep 17 00:00:00 2001 From: HAI Date: Mon, 2 Dec 2024 19:05:58 -0800 Subject: [PATCH 01/60] Relax to include more AMD GPUs (#2319) --- python/sglang/srt/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/sglang/srt/utils.py b/python/sglang/srt/utils.py index c19d521a066..04372bac194 100644 --- a/python/sglang/srt/utils.py +++ b/python/sglang/srt/utils.py @@ -879,7 +879,7 @@ def get_amdgpu_memory_capacity(): # Run rocm-smi and capture the output result = subprocess.run( [ - "rocminfo | grep 'gfx94' -A 100 | grep 'Pool 1' -A 5 | grep 'Size:' | awk '{print $2}'" + "rocminfo | grep 'gfx' -A 100 | grep 'Pool 1' -A 5 | grep 'Size:' | awk '{print $2}'" ], stdout=subprocess.PIPE, stderr=subprocess.PIPE, From 480e38a73350f2af57d003b023fab5cbc9a1e65e Mon Sep 17 00:00:00 2001 From: Ying Sheng Date: Mon, 2 Dec 2024 20:19:02 -0800 Subject: [PATCH 02/60] [feat] Enable chunked prefill for llava-onevision (#2281) --- python/sglang/srt/managers/schedule_batch.py | 1 + .../sglang/srt/model_executor/model_runner.py | 13 +- python/sglang/srt/models/llava.py | 51 ++++-- test/srt/run_suite.py | 1 + test/srt/test_vision_chunked_prefill.py | 173 ++++++++++++++++++ 5 files changed, 221 insertions(+), 18 deletions(-) create mode 100644 test/srt/test_vision_chunked_prefill.py diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py index 28677efeac4..301ef4bb74b 100644 --- a/python/sglang/srt/managers/schedule_batch.py +++ b/python/sglang/srt/managers/schedule_batch.py @@ -128,6 +128,7 @@ class ImageInputs: image_hashes: Optional[list] = None image_sizes: Optional[list] = None image_offsets: Optional[list] = None + image_pad_len: Optional[list] = None pad_values: Optional[list] = None modalities: Optional[list] = None num_image_tokens: Optional[int] = None diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index 73dec4a9cf8..657e0c2ca5d 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -110,15 +110,20 @@ def __init__( ) if self.is_multimodal: - logger.info( - "Automatically turn off --chunked-prefill-size and adjust --mem-fraction-static for multimodal models." - ) - server_args.chunked_prefill_size = -1 self.mem_fraction_static *= 0.95 + if self.model_config.hf_config.architectures == [ + "MllamaForConditionalGeneration" + ]: + logger.info("Automatically turn off --chunked-prefill-size for mllama.") + server_args.chunked_prefill_size = -1 # TODO: qwen2-vl does not support radix cache now, set disable_radix_cache=True automatically if self.model_config.hf_config.architectures == [ "Qwen2VLForConditionalGeneration" ]: + logger.info( + "Automatically turn off --chunked-prefill-size and disable radix cache for qwen2-vl." + ) + server_args.chunked_prefill_size = -1 server_args.disable_radix_cache = True # Global vars diff --git a/python/sglang/srt/models/llava.py b/python/sglang/srt/models/llava.py index 4c62dbb25f1..c8ce9302b4f 100644 --- a/python/sglang/srt/models/llava.py +++ b/python/sglang/srt/models/llava.py @@ -57,6 +57,7 @@ def pad_input_ids(self, input_ids: List[int], image_inputs: ImageInputs): else: image_aspect_ratio = "anyres" offset_list = [] + image_inputs.image_pad_len = [] for image_idx, image_s in enumerate(image_sizes): if len(image_sizes) > 16: # 2x2 pooling with stride 2 @@ -103,6 +104,7 @@ def pad_input_ids(self, input_ids: List[int], image_inputs: ImageInputs): + input_ids[offset + 1 :] ) offset_list.append(offset) + image_inputs.image_pad_len.append(new_image_feature_len) image_inputs.image_offsets = offset_list return input_ids @@ -134,6 +136,14 @@ def forward( image_inputs = forward_batch.image_inputs if forward_batch.forward_mode.is_extend(): + # Clamp input ids. This is because the input_ids for the image tokens are + # filled with the hash values of the image for the prefix matching in the radix attention. + # There values are useless because their embeddings will be replaced by vision embeddings anyway. + input_ids.clamp_(min=0, max=self.config.vocab_size - 1) + + # Embed text inputs + input_embeds = self.language_model.model.embed_tokens(input_ids) + # Got List[List[str]] extend it to List[str] # The length of the List should be equal to batch size modalities_list = [] @@ -142,18 +152,12 @@ def forward( if im and im.modalities is not None: modalities_list.extend(im.modalities) if im and im.image_offsets: - max_image_offset.append(max(im.image_offsets)) + max_image_offset.append( + np.max(np.array(im.image_offsets) + np.array(im.image_pad_len)) + ) else: max_image_offset.append(-1) - # Clamp input ids. This is because the input_ids for the image tokens are - # filled with the hash values of the image for the prefix matching in the radix attention. - # There values are useless because their embeddings will be replaced by vision embeddings anyway. - input_ids.clamp_(min=0, max=self.config.vocab_size - 1) - - # Embed text inputs - input_embeds = self.language_model.model.embed_tokens(input_ids) - start_positions = positions[forward_batch.extend_start_loc].cpu().numpy() need_vision = start_positions <= np.array(max_image_offset) @@ -350,6 +354,7 @@ def forward( # Fill in the placeholder for the image extend_start_loc_cpu = forward_batch.extend_start_loc.cpu().numpy() + extend_seq_lens = forward_batch.extend_seq_lens.cpu().numpy() prefix_lens_cpu = forward_batch.extend_prefix_lens_cpu pt = 0 for i in range(bs): @@ -357,18 +362,36 @@ def forward( continue start_idx = extend_start_loc_cpu[i] + seq_len = extend_seq_lens[i] prefix_len = prefix_lens_cpu[i] # Multiple images - for j, image_offset in enumerate(image_inputs[i].image_offsets): - if image_offset < prefix_len: + for image_idx, image_offset in enumerate( + image_inputs[i].image_offsets + ): + if ( + image_offset + image_inputs[i].image_pad_len[image_idx] + <= prefix_len + ): continue + if image_offset >= prefix_len + seq_len: + break - tmp_image_feature = image_features[pt][j] + tmp_image_feature = image_features[pt][image_idx] pad_len = tmp_image_feature.shape[0] - left_idx = start_idx + (image_offset - prefix_len) - right_idx = start_idx + (image_offset - prefix_len) + pad_len + input_offset = image_offset - prefix_len + left_idx = start_idx + input_offset + right_idx = left_idx + pad_len + assert right_idx > start_idx + if input_offset < 0: + left_idx = start_idx + tmp_image_feature = tmp_image_feature[-input_offset:] + if right_idx > start_idx + seq_len: + tmp_image_feature = tmp_image_feature[ + : start_idx + seq_len - right_idx + ] + right_idx = start_idx + seq_len try: input_embeds[left_idx:right_idx] = tmp_image_feature except RuntimeError as e: diff --git a/test/srt/run_suite.py b/test/srt/run_suite.py index 5035810f86a..89fdacfa231 100644 --- a/test/srt/run_suite.py +++ b/test/srt/run_suite.py @@ -39,6 +39,7 @@ "test_triton_attention_kernels.py", "test_triton_attention_backend.py", "test_update_weights_from_disk.py", + "test_vision_chunked_prefill.py", "test_vision_openai_server.py", "test_session_control.py", ], diff --git a/test/srt/test_vision_chunked_prefill.py b/test/srt/test_vision_chunked_prefill.py new file mode 100644 index 00000000000..f7725f17bee --- /dev/null +++ b/test/srt/test_vision_chunked_prefill.py @@ -0,0 +1,173 @@ +""" +Usage: +python3 -m unittest test_vision_chunked_prefill.TestVisionChunkedPrefill.test_chunked_prefill +""" + +import base64 +import io +import os +import unittest +from concurrent.futures import ThreadPoolExecutor +from typing import Union + +import numpy as np +import requests +from decord import VideoReader, cpu +from PIL import Image + +from sglang.srt.utils import kill_process_tree +from sglang.test.test_utils import ( + DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + DEFAULT_URL_FOR_TEST, + popen_launch_server, +) + + +class TestVisionChunkedPrefill(unittest.TestCase): + def prepare_video_messages(self, video_path, max_frames_num=8): + vr = VideoReader(video_path, ctx=cpu(0)) + total_frame_num = len(vr) + uniform_sampled_frames = np.linspace( + 0, total_frame_num - 1, max_frames_num, dtype=int + ) + frame_idx = uniform_sampled_frames.tolist() + frames = vr.get_batch(frame_idx).asnumpy() + + base64_frames = [] + for frame in frames: + pil_img = Image.fromarray(frame) + buff = io.BytesIO() + pil_img.save(buff, format="JPEG") + base64_str = base64.b64encode(buff.getvalue()).decode("utf-8") + base64_frames.append(base64_str) + + messages = [{"role": "user", "content": []}] + frame_format = { + "type": "image_url", + "image_url": {"url": "data:image/jpeg;base64,{}"}, + "modalities": "video", + } + + for base64_frame in base64_frames: + frame_format["image_url"]["url"] = "data:image/jpeg;base64,{}".format( + base64_frame + ) + messages[0]["content"].append(frame_format.copy()) + + prompt = {"type": "text", "text": "Please describe the video briefly."} + messages[0]["content"].append(prompt) + + return messages + + def get_prompt_from_messages(self, messages): + text = ( + "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n" + "<|im_start|>user\n" + ) + image_data = [] + for content in messages[0]["content"]: + if content["type"] == "image_url": + text += "\n" + image_data.append(content["image_url"]["url"]) + text += "Please describe the video briefly.<|im_end|>\n<|im_start|>assistant\n" + return text, image_data + + def generate(self, text, image_data): + response = requests.post( + self.base_url + "/generate", + json={ + "text": text, + "image_data": image_data, + "sampling_params": { + "temperature": 0, + "max_new_tokens": 32, + "no_stop_trim": True, + "skip_special_tokens": False, + }, + "modalities": ["multi-images"], + }, + ).json() + return response["text"] + + def generate_for_video(self, batch, num_frame) -> Union[str, list[str]]: + # prepare the video input about Steven introducing ipod nano + url = "https://raw.githubusercontent.com/evolvinglmms-lab/sglang/dev/onevision_local/assets/jobs.mp4" + cache_dir = os.path.expanduser("~/.cache") + file_path = os.path.join(cache_dir, "jobs.mp4") + os.makedirs(cache_dir, exist_ok=True) + if not os.path.exists(file_path): + response = requests.get(url) + response.raise_for_status() + with open(file_path, "wb") as f: + f.write(response.content) + + if not batch: + assert isinstance(num_frame, int) + messages = self.prepare_video_messages(file_path, max_frames_num=num_frame) + text, image_data = self.get_prompt_from_messages(messages) + return self.generate(text, image_data) + else: + assert isinstance(num_frame, list) + func_args = [] + for max_frames_num in num_frame: + messages = self.prepare_video_messages( + file_path, + max_frames_num=max_frames_num, + ) + text, image_data = self.get_prompt_from_messages(messages) + func_args.append((text, image_data)) + + with ThreadPoolExecutor(max_workers=10) as executor: + responses = list(executor.map(lambda p: self.generate(*p), func_args)) + + return responses + + def run_generate(self, chunked_prefill_size, batch, num_frame): + # launch server + model = "lmms-lab/llava-onevision-qwen2-7b-ov" + # model = "meta-llama/Llama-3.2-11B-Vision-Instruct" + self.base_url = DEFAULT_URL_FOR_TEST + process = popen_launch_server( + model, + self.base_url, + timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + other_args=[ + "--chunked-prefill-size", + f"{chunked_prefill_size}", + ], + ) + try: + return self.generate_for_video(batch, num_frame) + finally: + kill_process_tree(process.pid) + + def test_chunked_prefill(self): + output_chunked = self.run_generate( + chunked_prefill_size=1024, batch=False, num_frame=1 + ) + output_no_chunked = self.run_generate( + chunked_prefill_size=-1, batch=False, num_frame=1 + ) + + print("output with chunked prefill:") + print(output_chunked) + print("output without chunked prefill:") + print(output_no_chunked) + assert output_chunked == output_no_chunked + + output_chunked = self.run_generate( + chunked_prefill_size=1024, batch=True, num_frame=[2, 6, 8, 10] + ) + output_no_chunked = self.run_generate( + chunked_prefill_size=-1, batch=True, num_frame=[2, 6, 8, 10] + ) + + print("output with chunked prefill:") + print(output_chunked) + print("output without chunked prefill:") + print(output_no_chunked) + assert output_chunked == output_no_chunked + + +if __name__ == "__main__": + unittest.main() From 3ddb1c467979eb13afc629506ea80806935390e8 Mon Sep 17 00:00:00 2001 From: Lianmin Zheng Date: Mon, 2 Dec 2024 20:45:53 -0800 Subject: [PATCH 03/60] [Minor] Fix logger and style (#2325) --- python/sglang/bench_serving.py | 1 - .../sglang/srt/model_executor/model_runner.py | 5 ++++- python/sglang/srt/server_args.py | 19 ++++++++++++------- 3 files changed, 16 insertions(+), 9 deletions(-) diff --git a/python/sglang/bench_serving.py b/python/sglang/bench_serving.py index 3eca72de4aa..1a909caa812 100644 --- a/python/sglang/bench_serving.py +++ b/python/sglang/bench_serving.py @@ -163,7 +163,6 @@ async def async_request_openai_completions( "max_tokens": request_func_input.output_len, "stream": not args.disable_stream, "ignore_eos": not args.disable_ignore_eos, - "lora_path": request_func_input.lora_name, **request_func_input.extra_request_body, } headers = {"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}"} diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index 657e0c2ca5d..24a28595204 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -16,6 +16,7 @@ import gc import json import logging +import time from typing import Optional import torch @@ -129,7 +130,7 @@ def __init__( # Global vars if server_args.show_time_cost: enable_show_time_cost() - if server_args.disable_disk_cache: + if server_args.disable_outlines_disk_cache: from outlines.caching import disable_cache disable_cache() @@ -623,8 +624,10 @@ def init_cuda_graphs(self): if self.server_args.disable_cuda_graph: return + tic = time.time() logger.info("Capture cuda graph begin. This can take up to several minutes.") self.cuda_graph_runner = CudaGraphRunner(self) + logger.info(f"Capture cuda graph end. Time elapsed: {time.time() - tic:.2f}s") def apply_torch_tp(self): logger.info(f"Enabling torch tensor parallelism on {self.tp_size} devices.") diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index 37ad6cfc530..788686a1ee8 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -122,7 +122,7 @@ class ServerArgs: disable_jump_forward: bool = False disable_cuda_graph: bool = False disable_cuda_graph_padding: bool = False - disable_disk_cache: bool = False + disable_outlines_disk_cache: bool = False disable_custom_all_reduce: bool = False disable_mla: bool = False disable_overlap_schedule: bool = False @@ -159,7 +159,7 @@ def __post_init__(self): if self.tp_size >= 16: self.mem_fraction_static = 0.79 elif self.tp_size >= 8: - self.mem_fraction_static = 0.82 + self.mem_fraction_static = 0.81 elif self.tp_size >= 4: self.mem_fraction_static = 0.85 elif self.tp_size >= 2: @@ -192,7 +192,7 @@ def __post_init__(self): ) if self.attention_backend == "torch_native": - logger.info( + logger.warning( "Cuda graph is disabled because of using torch native attention backend" ) self.disable_cuda_graph = True @@ -204,12 +204,12 @@ def __post_init__(self): self.cuda_graph_max_bs = min(self.cuda_graph_max_bs, 96) self.schedule_conservativeness = self.schedule_conservativeness * 0.3 self.disable_overlap_schedule = True - logger.info( + logger.warning( f"DP attention is enabled. The chunked prefill size is adjusted to {self.chunked_prefill_size} to avoid MoE kernel issues. " f"The CUDA graph max batch size is adjusted to {self.cuda_graph_max_bs}. " f"The schedule conservativeness is adjusted to {self.schedule_conservativeness}. " "Data parallel size is adjusted to be the same as tensor parallel size. " - "Overlap schedule is disabled." + "Overlap scheduler is disabled." ) # GGUF @@ -642,9 +642,9 @@ def add_cli_args(parser: argparse.ArgumentParser): help="Disable cuda graph when padding is needed. Still uses cuda graph when padding is not needed.", ) parser.add_argument( - "--disable-disk-cache", + "--disable-outlines-disk-cache", action="store_true", - help="Disable disk cache to avoid possible crashes related to file system or high concurrency.", + help="Disable disk cache of outlines to avoid possible crashes related to file system or high concurrency.", ) parser.add_argument( "--disable-custom-all-reduce", @@ -745,6 +745,11 @@ def add_cli_args(parser: argparse.ArgumentParser): action=DeprecatedAction, help="'--disable-flashinfer-sampling' is deprecated. Please use '--sampling-backend pytroch' instead.", ) + parser.add_argument( + "--disable-disk-cache", + action=DeprecatedAction, + help="'--disable-disk-cache' is deprecated. Please use '--disable-outlines-disk-cache' instead.", + ) @classmethod def from_cli_args(cls, args: argparse.Namespace): From aa47f642230f35269b45d81cba837a30a3015eb3 Mon Sep 17 00:00:00 2001 From: Ying Sheng Date: Mon, 2 Dec 2024 23:11:13 -0800 Subject: [PATCH 04/60] Revert "[feat] Enable chunked prefill for llava-onevision" (#2329) --- python/sglang/srt/managers/schedule_batch.py | 1 - .../sglang/srt/model_executor/model_runner.py | 13 +- python/sglang/srt/models/llava.py | 51 ++---- test/srt/run_suite.py | 1 - test/srt/test_vision_chunked_prefill.py | 173 ------------------ 5 files changed, 18 insertions(+), 221 deletions(-) delete mode 100644 test/srt/test_vision_chunked_prefill.py diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py index 301ef4bb74b..28677efeac4 100644 --- a/python/sglang/srt/managers/schedule_batch.py +++ b/python/sglang/srt/managers/schedule_batch.py @@ -128,7 +128,6 @@ class ImageInputs: image_hashes: Optional[list] = None image_sizes: Optional[list] = None image_offsets: Optional[list] = None - image_pad_len: Optional[list] = None pad_values: Optional[list] = None modalities: Optional[list] = None num_image_tokens: Optional[int] = None diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index 24a28595204..74a7d1fc56c 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -111,20 +111,15 @@ def __init__( ) if self.is_multimodal: + logger.info( + "Automatically turn off --chunked-prefill-size and adjust --mem-fraction-static for multimodal models." + ) + server_args.chunked_prefill_size = -1 self.mem_fraction_static *= 0.95 - if self.model_config.hf_config.architectures == [ - "MllamaForConditionalGeneration" - ]: - logger.info("Automatically turn off --chunked-prefill-size for mllama.") - server_args.chunked_prefill_size = -1 # TODO: qwen2-vl does not support radix cache now, set disable_radix_cache=True automatically if self.model_config.hf_config.architectures == [ "Qwen2VLForConditionalGeneration" ]: - logger.info( - "Automatically turn off --chunked-prefill-size and disable radix cache for qwen2-vl." - ) - server_args.chunked_prefill_size = -1 server_args.disable_radix_cache = True # Global vars diff --git a/python/sglang/srt/models/llava.py b/python/sglang/srt/models/llava.py index c8ce9302b4f..4c62dbb25f1 100644 --- a/python/sglang/srt/models/llava.py +++ b/python/sglang/srt/models/llava.py @@ -57,7 +57,6 @@ def pad_input_ids(self, input_ids: List[int], image_inputs: ImageInputs): else: image_aspect_ratio = "anyres" offset_list = [] - image_inputs.image_pad_len = [] for image_idx, image_s in enumerate(image_sizes): if len(image_sizes) > 16: # 2x2 pooling with stride 2 @@ -104,7 +103,6 @@ def pad_input_ids(self, input_ids: List[int], image_inputs: ImageInputs): + input_ids[offset + 1 :] ) offset_list.append(offset) - image_inputs.image_pad_len.append(new_image_feature_len) image_inputs.image_offsets = offset_list return input_ids @@ -136,14 +134,6 @@ def forward( image_inputs = forward_batch.image_inputs if forward_batch.forward_mode.is_extend(): - # Clamp input ids. This is because the input_ids for the image tokens are - # filled with the hash values of the image for the prefix matching in the radix attention. - # There values are useless because their embeddings will be replaced by vision embeddings anyway. - input_ids.clamp_(min=0, max=self.config.vocab_size - 1) - - # Embed text inputs - input_embeds = self.language_model.model.embed_tokens(input_ids) - # Got List[List[str]] extend it to List[str] # The length of the List should be equal to batch size modalities_list = [] @@ -152,12 +142,18 @@ def forward( if im and im.modalities is not None: modalities_list.extend(im.modalities) if im and im.image_offsets: - max_image_offset.append( - np.max(np.array(im.image_offsets) + np.array(im.image_pad_len)) - ) + max_image_offset.append(max(im.image_offsets)) else: max_image_offset.append(-1) + # Clamp input ids. This is because the input_ids for the image tokens are + # filled with the hash values of the image for the prefix matching in the radix attention. + # There values are useless because their embeddings will be replaced by vision embeddings anyway. + input_ids.clamp_(min=0, max=self.config.vocab_size - 1) + + # Embed text inputs + input_embeds = self.language_model.model.embed_tokens(input_ids) + start_positions = positions[forward_batch.extend_start_loc].cpu().numpy() need_vision = start_positions <= np.array(max_image_offset) @@ -354,7 +350,6 @@ def forward( # Fill in the placeholder for the image extend_start_loc_cpu = forward_batch.extend_start_loc.cpu().numpy() - extend_seq_lens = forward_batch.extend_seq_lens.cpu().numpy() prefix_lens_cpu = forward_batch.extend_prefix_lens_cpu pt = 0 for i in range(bs): @@ -362,36 +357,18 @@ def forward( continue start_idx = extend_start_loc_cpu[i] - seq_len = extend_seq_lens[i] prefix_len = prefix_lens_cpu[i] # Multiple images - for image_idx, image_offset in enumerate( - image_inputs[i].image_offsets - ): - if ( - image_offset + image_inputs[i].image_pad_len[image_idx] - <= prefix_len - ): + for j, image_offset in enumerate(image_inputs[i].image_offsets): + if image_offset < prefix_len: continue - if image_offset >= prefix_len + seq_len: - break - tmp_image_feature = image_features[pt][image_idx] + tmp_image_feature = image_features[pt][j] pad_len = tmp_image_feature.shape[0] - input_offset = image_offset - prefix_len - left_idx = start_idx + input_offset - right_idx = left_idx + pad_len - assert right_idx > start_idx - if input_offset < 0: - left_idx = start_idx - tmp_image_feature = tmp_image_feature[-input_offset:] - if right_idx > start_idx + seq_len: - tmp_image_feature = tmp_image_feature[ - : start_idx + seq_len - right_idx - ] - right_idx = start_idx + seq_len + left_idx = start_idx + (image_offset - prefix_len) + right_idx = start_idx + (image_offset - prefix_len) + pad_len try: input_embeds[left_idx:right_idx] = tmp_image_feature except RuntimeError as e: diff --git a/test/srt/run_suite.py b/test/srt/run_suite.py index 89fdacfa231..5035810f86a 100644 --- a/test/srt/run_suite.py +++ b/test/srt/run_suite.py @@ -39,7 +39,6 @@ "test_triton_attention_kernels.py", "test_triton_attention_backend.py", "test_update_weights_from_disk.py", - "test_vision_chunked_prefill.py", "test_vision_openai_server.py", "test_session_control.py", ], diff --git a/test/srt/test_vision_chunked_prefill.py b/test/srt/test_vision_chunked_prefill.py deleted file mode 100644 index f7725f17bee..00000000000 --- a/test/srt/test_vision_chunked_prefill.py +++ /dev/null @@ -1,173 +0,0 @@ -""" -Usage: -python3 -m unittest test_vision_chunked_prefill.TestVisionChunkedPrefill.test_chunked_prefill -""" - -import base64 -import io -import os -import unittest -from concurrent.futures import ThreadPoolExecutor -from typing import Union - -import numpy as np -import requests -from decord import VideoReader, cpu -from PIL import Image - -from sglang.srt.utils import kill_process_tree -from sglang.test.test_utils import ( - DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, - DEFAULT_URL_FOR_TEST, - popen_launch_server, -) - - -class TestVisionChunkedPrefill(unittest.TestCase): - def prepare_video_messages(self, video_path, max_frames_num=8): - vr = VideoReader(video_path, ctx=cpu(0)) - total_frame_num = len(vr) - uniform_sampled_frames = np.linspace( - 0, total_frame_num - 1, max_frames_num, dtype=int - ) - frame_idx = uniform_sampled_frames.tolist() - frames = vr.get_batch(frame_idx).asnumpy() - - base64_frames = [] - for frame in frames: - pil_img = Image.fromarray(frame) - buff = io.BytesIO() - pil_img.save(buff, format="JPEG") - base64_str = base64.b64encode(buff.getvalue()).decode("utf-8") - base64_frames.append(base64_str) - - messages = [{"role": "user", "content": []}] - frame_format = { - "type": "image_url", - "image_url": {"url": "data:image/jpeg;base64,{}"}, - "modalities": "video", - } - - for base64_frame in base64_frames: - frame_format["image_url"]["url"] = "data:image/jpeg;base64,{}".format( - base64_frame - ) - messages[0]["content"].append(frame_format.copy()) - - prompt = {"type": "text", "text": "Please describe the video briefly."} - messages[0]["content"].append(prompt) - - return messages - - def get_prompt_from_messages(self, messages): - text = ( - "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n" - "<|im_start|>user\n" - ) - image_data = [] - for content in messages[0]["content"]: - if content["type"] == "image_url": - text += "\n" - image_data.append(content["image_url"]["url"]) - text += "Please describe the video briefly.<|im_end|>\n<|im_start|>assistant\n" - return text, image_data - - def generate(self, text, image_data): - response = requests.post( - self.base_url + "/generate", - json={ - "text": text, - "image_data": image_data, - "sampling_params": { - "temperature": 0, - "max_new_tokens": 32, - "no_stop_trim": True, - "skip_special_tokens": False, - }, - "modalities": ["multi-images"], - }, - ).json() - return response["text"] - - def generate_for_video(self, batch, num_frame) -> Union[str, list[str]]: - # prepare the video input about Steven introducing ipod nano - url = "https://raw.githubusercontent.com/evolvinglmms-lab/sglang/dev/onevision_local/assets/jobs.mp4" - cache_dir = os.path.expanduser("~/.cache") - file_path = os.path.join(cache_dir, "jobs.mp4") - os.makedirs(cache_dir, exist_ok=True) - if not os.path.exists(file_path): - response = requests.get(url) - response.raise_for_status() - with open(file_path, "wb") as f: - f.write(response.content) - - if not batch: - assert isinstance(num_frame, int) - messages = self.prepare_video_messages(file_path, max_frames_num=num_frame) - text, image_data = self.get_prompt_from_messages(messages) - return self.generate(text, image_data) - else: - assert isinstance(num_frame, list) - func_args = [] - for max_frames_num in num_frame: - messages = self.prepare_video_messages( - file_path, - max_frames_num=max_frames_num, - ) - text, image_data = self.get_prompt_from_messages(messages) - func_args.append((text, image_data)) - - with ThreadPoolExecutor(max_workers=10) as executor: - responses = list(executor.map(lambda p: self.generate(*p), func_args)) - - return responses - - def run_generate(self, chunked_prefill_size, batch, num_frame): - # launch server - model = "lmms-lab/llava-onevision-qwen2-7b-ov" - # model = "meta-llama/Llama-3.2-11B-Vision-Instruct" - self.base_url = DEFAULT_URL_FOR_TEST - process = popen_launch_server( - model, - self.base_url, - timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, - other_args=[ - "--chunked-prefill-size", - f"{chunked_prefill_size}", - ], - ) - try: - return self.generate_for_video(batch, num_frame) - finally: - kill_process_tree(process.pid) - - def test_chunked_prefill(self): - output_chunked = self.run_generate( - chunked_prefill_size=1024, batch=False, num_frame=1 - ) - output_no_chunked = self.run_generate( - chunked_prefill_size=-1, batch=False, num_frame=1 - ) - - print("output with chunked prefill:") - print(output_chunked) - print("output without chunked prefill:") - print(output_no_chunked) - assert output_chunked == output_no_chunked - - output_chunked = self.run_generate( - chunked_prefill_size=1024, batch=True, num_frame=[2, 6, 8, 10] - ) - output_no_chunked = self.run_generate( - chunked_prefill_size=-1, batch=True, num_frame=[2, 6, 8, 10] - ) - - print("output with chunked prefill:") - print(output_chunked) - print("output without chunked prefill:") - print(output_no_chunked) - assert output_chunked == output_no_chunked - - -if __name__ == "__main__": - unittest.main() From 0639bf15d1077fafe6f1be41dad72d6c87b301a9 Mon Sep 17 00:00:00 2001 From: HAI Date: Mon, 2 Dec 2024 23:20:33 -0800 Subject: [PATCH 05/60] ROCm Container: set SGLANG_SET_CPU_AFFINITY=1 (#2328) --- docker/Dockerfile.rocm | 1 + 1 file changed, 1 insertion(+) diff --git a/docker/Dockerfile.rocm b/docker/Dockerfile.rocm index 7df5e5fcf23..c965d140f06 100644 --- a/docker/Dockerfile.rocm +++ b/docker/Dockerfile.rocm @@ -33,6 +33,7 @@ RUN python -m pip cache purge # Performance environment variable. ENV HIP_FORCE_DEV_KERNARG=1 +ENV SGLANG_SET_CPU_AFFINITY=1 ENV SGLANG_ALLOW_OVERWRITE_LONGER_CONTEXT_LEN=1 ENV NCCL_MIN_NCHANNELS=112 From 83b340e371a0151c9fdefac9f07e0f89ba5e6c37 Mon Sep 17 00:00:00 2001 From: Ata Fatahi Date: Tue, 3 Dec 2024 00:06:25 -0800 Subject: [PATCH 06/60] Add missing license for router wheel (#2324) Signed-off-by: Ata Fatahi --- rust/pyproject.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/rust/pyproject.toml b/rust/pyproject.toml index 35b1dc7d433..a07ba953adf 100644 --- a/rust/pyproject.toml +++ b/rust/pyproject.toml @@ -9,6 +9,7 @@ description = "SGLang router is a standalone module implemented in Rust to achie authors = [{name = "Byron Hsu", email = "byronhsu1230@gmail.com"}] requires-python = ">=3.8" readme = "README.md" +license = { file = "LICENSE" } classifiers = [ "Programming Language :: Python :: Implementation :: CPython", "Programming Language :: Rust", From 07ec07ad1fa59e0f07a4fcd1b1f324123c2e2bd4 Mon Sep 17 00:00:00 2001 From: Lianmin Zheng Date: Tue, 3 Dec 2024 01:58:25 -0800 Subject: [PATCH 07/60] Improve torch compile for fused moe (#2327) --- .../benchmark_torch_compile_fused_moe.py | 7 +++-- python/sglang/srt/layers/fused_moe_patch.py | 31 ++++++++++++------- .../srt/model_executor/cuda_graph_runner.py | 23 +++++++++----- .../sglang/srt/model_executor/model_runner.py | 2 +- test/srt/test_srt_engine.py | 2 +- test/srt/test_torch_compile_moe.py | 4 +-- 6 files changed, 45 insertions(+), 24 deletions(-) diff --git a/benchmark/kernels/fused_moe_triton/benchmark_torch_compile_fused_moe.py b/benchmark/kernels/fused_moe_triton/benchmark_torch_compile_fused_moe.py index 1f54f9f9f49..1bd6eec1645 100644 --- a/benchmark/kernels/fused_moe_triton/benchmark_torch_compile_fused_moe.py +++ b/benchmark/kernels/fused_moe_triton/benchmark_torch_compile_fused_moe.py @@ -6,6 +6,7 @@ from transformers import AutoConfig from sglang.srt.layers.fused_moe_triton.fused_moe import fused_moe as fused_moe_triton +from sglang.srt.model_executor.cuda_graph_runner import set_torch_compile_config def get_model_config(model_name: str, tp_size: int): @@ -64,7 +65,7 @@ def fused_topk_native( return topk_weights, topk_ids -@torch.compile +@torch.compile(dynamic=False) def fused_moe_torch( x, w1, @@ -88,7 +89,8 @@ def fused_moe_torch( w13_weights = w1[topk_ids] w1_weights, w3_weights = torch.chunk(w13_weights, 2, dim=2) w2_weights = w2[topk_ids] - x1 = F.gelu(torch.einsum("ti,taoi -> tao", x, w1_weights)) + x1 = torch.einsum("ti,taoi -> tao", x, w1_weights) + x1 = F.silu(x1) x3 = torch.einsum("ti, taoi -> tao", x, w3_weights) expert_outs = torch.einsum("tao, taio -> tai", (x1 * x3), w2_weights) return torch.einsum("tai,ta -> ti", expert_outs, topk_weights.to(expert_outs.dtype)) @@ -174,6 +176,7 @@ def benchmark(batch_size, provider, model_config, use_fp8=False): print(f"benchmark {provider} with batch_size={batch_size}") torch.set_default_device("cuda") torch.cuda.manual_seed_all(0) + set_torch_compile_config() num_tokens = batch_size num_experts = model_config["num_experts"] diff --git a/python/sglang/srt/layers/fused_moe_patch.py b/python/sglang/srt/layers/fused_moe_patch.py index 400ca03c434..baca2581150 100644 --- a/python/sglang/srt/layers/fused_moe_patch.py +++ b/python/sglang/srt/layers/fused_moe_patch.py @@ -105,20 +105,29 @@ def fused_moe_forward_native( num_expert_group: Optional[int] = None, custom_routing_function: Optional[Callable] = None, ) -> torch.Tensor: - assert custom_routing_function is None - topk_weights, topk_ids = select_experts_native( - hidden_states=x, - router_logits=router_logits, - use_grouped_topk=use_grouped_topk, - top_k=top_k, - renormalize=renormalize, - topk_group=topk_group, - num_expert_group=num_expert_group, - ) + + if use_grouped_topk: + assert num_expert_group is not None and topk_group is not None + topk_weights, topk_ids = grouped_topk( + x, + router_logits, + top_k, + renormalize, + num_expert_group, + topk_group, + ) + elif custom_routing_function is None: + topk_weights, topk_ids = fused_topk_native(x, router_logits, top_k, renormalize) + else: + topk_weights, topk_ids = custom_routing_function( + x, router_logits, top_k, renormalize + ) + w13_weights = layer.w13_weight[topk_ids] w1_weights, w3_weights = torch.chunk(w13_weights, 2, dim=2) w2_weights = layer.w2_weight[topk_ids] - x1 = F.silu(torch.einsum("ti,taoi -> tao", x, w1_weights)) + x1 = torch.einsum("ti,taoi -> tao", x, w1_weights) + x1 = F.silu(x1) x3 = torch.einsum("ti, taoi -> tao", x, w3_weights) expert_outs = torch.einsum("tao, taio -> tai", (x1 * x3), w2_weights) return torch.einsum("tai,ta -> ti", expert_outs, topk_weights.to(expert_outs.dtype)) diff --git a/python/sglang/srt/model_executor/cuda_graph_runner.py b/python/sglang/srt/model_executor/cuda_graph_runner.py index 84f6825c3fd..dd26a77ad65 100644 --- a/python/sglang/srt/model_executor/cuda_graph_runner.py +++ b/python/sglang/srt/model_executor/cuda_graph_runner.py @@ -36,7 +36,7 @@ from sglang.srt.model_executor.model_runner import ModelRunner -def _to_torch(model: torch.nn.Module, reverse: bool = False): +def _to_torch(model: torch.nn.Module, reverse: bool, batch_size: int): for sub in model._modules.values(): if isinstance(sub, CustomOp): if reverse: @@ -45,24 +45,30 @@ def _to_torch(model: torch.nn.Module, reverse: bool = False): else: # NOTE: Temporarily workaround MoE if "FusedMoE" in sub.__class__.__name__: - sub._forward_method = fused_moe_forward_native + if batch_size == 1: + # The performance of torch.compile on this layer is not always good when bs > 1, + # so we decide to skip it for now. + sub._forward_method = fused_moe_forward_native else: sub._forward_method = sub.forward_native setattr(sub, "is_torch_compile", True) if isinstance(sub, torch.nn.Module): - _to_torch(sub, reverse) + _to_torch(sub, reverse, batch_size) @contextmanager def patch_model( - model: torch.nn.Module, enable_compile: bool, tp_group: "GroupCoordinator" + model: torch.nn.Module, + enable_compile: bool, + batch_size: int, + tp_group: "GroupCoordinator", ): """Patch the model to make it compatible with with torch.compile""" backup_ca_comm = None try: if enable_compile: - _to_torch(model) + _to_torch(model, reverse=False, batch_size=batch_size) monkey_patch_vllm_all_gather() backup_ca_comm = tp_group.ca_comm # Use custom-allreduce here. @@ -70,13 +76,15 @@ def patch_model( # even with ENABLE_INTRA_NODE_COMM=1. # tp_group.ca_comm = None yield torch.compile( - torch.no_grad()(model.forward), mode="max-autotune-no-cudagraphs" + torch.no_grad()(model.forward), + mode="max-autotune-no-cudagraphs", + dynamic=False, ) else: yield model.forward finally: if enable_compile: - _to_torch(model, reverse=True) + _to_torch(model, reverse=True, batch_size=batch_size) monkey_patch_vllm_all_gather(reverse=True) tp_group.ca_comm = backup_ca_comm @@ -237,6 +245,7 @@ def capture(self): with patch_model( self.model_runner.model, bs in self.compile_bs, + bs, self.model_runner.tp_group, ) as forward: ( diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index 74a7d1fc56c..fafb8783e5a 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -622,7 +622,7 @@ def init_cuda_graphs(self): tic = time.time() logger.info("Capture cuda graph begin. This can take up to several minutes.") self.cuda_graph_runner = CudaGraphRunner(self) - logger.info(f"Capture cuda graph end. Time elapsed: {time.time() - tic:.2f}s") + logger.info(f"Capture cuda graph end. Time elapsed: {time.time() - tic:.2f} s") def apply_torch_tp(self): logger.info(f"Enabling torch tensor parallelism on {self.tp_size} devices.") diff --git a/test/srt/test_srt_engine.py b/test/srt/test_srt_engine.py index a985c8dda9e..7479b646837 100644 --- a/test/srt/test_srt_engine.py +++ b/test/srt/test_srt_engine.py @@ -188,7 +188,7 @@ def test_8_engine_offline_throughput(self): ) bench_args = BenchArgs(num_prompts=10) result = throughput_test(server_args=server_args, bench_args=bench_args) - self.assertGreater(result["total_throughput"], 3500) + self.assertGreater(result["total_throughput"], 3000) if __name__ == "__main__": diff --git a/test/srt/test_torch_compile_moe.py b/test/srt/test_torch_compile_moe.py index 89d4ed6bdf9..fb78dd7f4b8 100644 --- a/test/srt/test_torch_compile_moe.py +++ b/test/srt/test_torch_compile_moe.py @@ -14,7 +14,7 @@ ) -class TestTorchCompile(unittest.TestCase): +class TestTorchCompileMoe(unittest.TestCase): @classmethod def setUpClass(cls): cls.model = DEFAULT_SMALL_MOE_MODEL_NAME_FOR_TEST @@ -23,7 +23,7 @@ def setUpClass(cls): cls.model, cls.base_url, timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, - other_args=["--enable-torch-compile", "--torch-compile-max-bs", "1"], + other_args=["--enable-torch-compile", "--torch-compile-max-bs", "8"], ) @classmethod From fda628d8f210058b5386d0e6b4eefcd6a8fb8947 Mon Sep 17 00:00:00 2001 From: Yineng Zhang Date: Tue, 3 Dec 2024 21:22:19 +0800 Subject: [PATCH 08/60] fix: resolve cmake url for Dockerfile.dev (#2335) --- docker/Dockerfile.dev | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docker/Dockerfile.dev b/docker/Dockerfile.dev index 8bd05289979..79d91d8cda5 100644 --- a/docker/Dockerfile.dev +++ b/docker/Dockerfile.dev @@ -50,8 +50,8 @@ RUN curl -L https://github.com/clangd/clangd/releases/download/18.1.3/clangd-lin && rm -rf clangd_18.1.3 clangd.zip # Install CMake -RUN curl -L https://cmake.org/download/#:~:text=cmake%2D3.31.1%2Dlinux%2Dx86_64.tar.gz -o cmake.tar.gz \ - && tar -xzf cmake.tar.gz \ +RUN wget https://github.com/Kitware/CMake/releases/download/v3.31.1/cmake-3.31.1-linux-x86_64.tar.gz \ + && tar -xzf cmake-3.31.1-linux-x86_64.tar.gz \ && cp -r cmake-3.31.1-linux-x86_64/bin/* /usr/local/bin/ \ && cp -r cmake-3.31.1-linux-x86_64/share/* /usr/local/share/ \ && rm -rf cmake-3.31.1-linux-x86_64 cmake.tar.gz From 1228f7ca69e6ee3f5076f2381c3a187120e0de00 Mon Sep 17 00:00:00 2001 From: Lianmin Zheng Date: Tue, 3 Dec 2024 07:12:33 -0800 Subject: [PATCH 09/60] Fix gptq for moe layers (#2300) Co-authored-by: root --- .../srt/layers/quantization/__init__.py | 34 +++++++++++++++++++ python/sglang/srt/models/mixtral.py | 12 +++++-- 2 files changed, 44 insertions(+), 2 deletions(-) diff --git a/python/sglang/srt/layers/quantization/__init__.py b/python/sglang/srt/layers/quantization/__init__.py index a1bacdce036..f34a581d657 100644 --- a/python/sglang/srt/layers/quantization/__init__.py +++ b/python/sglang/srt/layers/quantization/__init__.py @@ -117,10 +117,44 @@ def fp8_get_quant_method(self, layer, prefix): return None +def gptq_get_quant_method(self, layer, prefix): + from vllm.model_executor.layers.linear import LinearBase + from vllm.model_executor.layers.quantization.gptq_marlin import ( + GPTQMarlinLinearMethod, + GPTQMarlinMoEMethod, + ) + + from sglang.srt.layers.fused_moe_triton.layer import FusedMoE + + if isinstance(layer, LinearBase): + return GPTQMarlinLinearMethod(self) + elif isinstance(layer, FusedMoE): + return GPTQMarlinMoEMethod(self) + return None + + +def awq_get_quant_method(self, layer, prefix): + from vllm.model_executor.layers.linear import LinearBase + from vllm.model_executor.layers.quantization.awq_marlin import ( + AWQMarlinLinearMethod, + AWQMoEMethod, + ) + + from sglang.srt.layers.fused_moe_triton.layer import FusedMoE + + if isinstance(layer, LinearBase): + return AWQMarlinLinearMethod(self) + elif isinstance(layer, FusedMoE): + return AWQMoEMethod(self) + return None + + def apply_monkey_patches(): """Apply all monkey patches in one place.""" setattr(Fp8MoEMethod, "apply", fp8_moe_apply) setattr(Fp8Config, "get_quant_method", fp8_get_quant_method) + setattr(GPTQMarlinConfig, "get_quant_method", gptq_get_quant_method) + setattr(AWQMarlinConfig, "get_quant_method", awq_get_quant_method) # Apply patches when module is imported diff --git a/python/sglang/srt/models/mixtral.py b/python/sglang/srt/models/mixtral.py index b222387a776..e75dc1288b7 100644 --- a/python/sglang/srt/models/mixtral.py +++ b/python/sglang/srt/models/mixtral.py @@ -339,7 +339,9 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): continue name = name.replace(weight_name, param_name) # Skip loading extra bias for GPTQ models. - if name.endswith(".bias") and name not in params_dict: + if ( + name.endswith(".bias") or name.endswith("_bias") + ) and name not in params_dict: continue param = params_dict[name] @@ -353,6 +355,10 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): continue name = name.replace(weight_name, param_name) + if ( + name.endswith(".bias") or name.endswith("_bias") + ) and name not in params_dict: + continue param = params_dict[name] weight_loader = param.weight_loader weight_loader( @@ -365,7 +371,9 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): break else: # Skip loading extra bias for GPTQ models. - if name.endswith(".bias") and name not in params_dict: + if ( + name.endswith(".bias") or name.endswith("_bias") + ) and name not in params_dict: continue # Skip loading kv_scale from ckpts towards new design. if name.endswith(".kv_scale") and name not in params_dict: From 0495796517a706e6ddf22189359f9da8e6f2b36b Mon Sep 17 00:00:00 2001 From: Byron Hsu Date: Tue, 3 Dec 2024 10:27:43 -0800 Subject: [PATCH 10/60] [router] Copy license when publishing & bump version (#2339) --- .github/workflows/release-pypi-router.yml | 3 ++- rust/pyproject.toml | 2 +- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/.github/workflows/release-pypi-router.yml b/.github/workflows/release-pypi-router.yml index f762aad6ce2..fbbb1f0243b 100644 --- a/.github/workflows/release-pypi-router.yml +++ b/.github/workflows/release-pypi-router.yml @@ -69,9 +69,10 @@ jobs: with: path: sglang-repo - - name: Move rust folder to root and delete sglang-repo + - name: Move rust folder to root, copy the license file, and delete sglang-repo run: | mv sglang-repo/rust/* . + mv sglang-repo/LICENSE . rm -rf sglang-repo ls -alt diff --git a/rust/pyproject.toml b/rust/pyproject.toml index a07ba953adf..d1327d9203e 100644 --- a/rust/pyproject.toml +++ b/rust/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "sglang-router" -version = "0.0.10" +version = "0.0.11" description = "SGLang router is a standalone module implemented in Rust to achieve data parallelism across SGLang instances." authors = [{name = "Byron Hsu", email = "byronhsu1230@gmail.com"}] requires-python = ">=3.8" From f8b0326934bacb7a7d4eba68fb6eddebaa6ff751 Mon Sep 17 00:00:00 2001 From: Yineng Zhang Date: Wed, 4 Dec 2024 03:55:41 +0800 Subject: [PATCH 11/60] chore: bump v0.4.0 (#2338) --- docker/Dockerfile.dev | 2 +- docker/Dockerfile.rocm | 2 +- docs/developer/setup_github_runner.md | 4 ++-- docs/start/install.md | 10 +++++----- python/pyproject.toml | 2 +- python/sglang/version.py | 2 +- 6 files changed, 11 insertions(+), 11 deletions(-) diff --git a/docker/Dockerfile.dev b/docker/Dockerfile.dev index 79d91d8cda5..d79dabeb5fc 100644 --- a/docker/Dockerfile.dev +++ b/docker/Dockerfile.dev @@ -54,7 +54,7 @@ RUN wget https://github.com/Kitware/CMake/releases/download/v3.31.1/cmake-3.31.1 && tar -xzf cmake-3.31.1-linux-x86_64.tar.gz \ && cp -r cmake-3.31.1-linux-x86_64/bin/* /usr/local/bin/ \ && cp -r cmake-3.31.1-linux-x86_64/share/* /usr/local/share/ \ - && rm -rf cmake-3.31.1-linux-x86_64 cmake.tar.gz + && rm -rf cmake-3.31.1-linux-x86_64 cmake-3.31.1-linux-x86_64.tar.gz # Add yank script COPY --chown=root:root <<-"EOF" /usr/local/bin/yank diff --git a/docker/Dockerfile.rocm b/docker/Dockerfile.rocm index c965d140f06..e51afce4d3b 100644 --- a/docker/Dockerfile.rocm +++ b/docker/Dockerfile.rocm @@ -1,5 +1,5 @@ # Usage (to build SGLang ROCm docker image): -# docker build --build-arg SGL_BRANCH=v0.3.6.post3 -t v0.3.6.post3-rocm620 -f Dockerfile.rocm . +# docker build --build-arg SGL_BRANCH=v0.4.0 -t v0.4.0-rocm620 -f Dockerfile.rocm . # default base image ARG BASE_IMAGE="rocm/vllm-dev:20241022" diff --git a/docs/developer/setup_github_runner.md b/docs/developer/setup_github_runner.md index 3c73c0da0af..d9eeb626583 100644 --- a/docs/developer/setup_github_runner.md +++ b/docs/developer/setup_github_runner.md @@ -11,9 +11,9 @@ docker pull nvidia/cuda:12.1.1-devel-ubuntu22.04 # Nvidia docker run --shm-size 128g -it -v /tmp/huggingface:/hf_home --gpus all nvidia/cuda:12.1.1-devel-ubuntu22.04 /bin/bash # AMD -docker run --rm --device=/dev/kfd --device=/dev/dri --group-add video --shm-size 128g -it -v /tmp/huggingface:/hf_home lmsysorg/sglang:v0.3.6.post3-rocm620 /bin/bash +docker run --rm --device=/dev/kfd --device=/dev/dri --group-add video --shm-size 128g -it -v /tmp/huggingface:/hf_home lmsysorg/sglang:v0.4.0-rocm620 /bin/bash # AMD just the last 2 GPUs -docker run --rm --device=/dev/kfd --device=/dev/dri/renderD176 --device=/dev/dri/renderD184 --group-add video --shm-size 128g -it -v /tmp/huggingface:/hf_home lmsysorg/sglang:v0.3.6.post3-rocm620 /bin/bash +docker run --rm --device=/dev/kfd --device=/dev/dri/renderD176 --device=/dev/dri/renderD184 --group-add video --shm-size 128g -it -v /tmp/huggingface:/hf_home lmsysorg/sglang:v0.4.0-rocm620 /bin/bash ``` ### Step 2: Configure the runner by `config.sh` diff --git a/docs/start/install.md b/docs/start/install.md index 3f6a816412a..a5e5d73f561 100644 --- a/docs/start/install.md +++ b/docs/start/install.md @@ -13,7 +13,7 @@ Note: Please check the [FlashInfer installation doc](https://docs.flashinfer.ai/ ## Method 2: From source ``` # Use the last release branch -git clone -b v0.3.6.post3 https://github.com/sgl-project/sglang.git +git clone -b v0.4.0 https://github.com/sgl-project/sglang.git cd sglang pip install --upgrade pip @@ -26,7 +26,7 @@ Note: To AMD ROCm system with Instinct/MI GPUs, do following instead: ``` # Use the last release branch -git clone -b v0.3.6.post3 https://github.com/sgl-project/sglang.git +git clone -b v0.4.0 https://github.com/sgl-project/sglang.git cd sglang pip install --upgrade pip @@ -51,7 +51,7 @@ docker run --gpus all \ Note: To AMD ROCm system with Instinct/MI GPUs, it is recommended to use `docker/Dockerfile.rocm` to build images, example and usage as below: ```bash -docker build --build-arg SGL_BRANCH=v0.3.6.post3 -t v0.3.6.post3-rocm620 -f Dockerfile.rocm . +docker build --build-arg SGL_BRANCH=v0.4.0 -t v0.4.0-rocm620 -f Dockerfile.rocm . alias drun='docker run -it --rm --network=host --device=/dev/kfd --device=/dev/dri --ipc=host \ --shm-size 16G --group-add video --cap-add=SYS_PTRACE --security-opt seccomp=unconfined \ @@ -60,11 +60,11 @@ alias drun='docker run -it --rm --network=host --device=/dev/kfd --device=/dev/d drun -p 30000:30000 \ -v ~/.cache/huggingface:/root/.cache/huggingface \ --env "HF_TOKEN=" \ - v0.3.6.post3-rocm620 \ + v0.4.0-rocm620 \ python3 -m sglang.launch_server --model-path meta-llama/Llama-3.1-8B-Instruct --host 0.0.0.0 --port 30000 # Till flashinfer backend available, --attention-backend triton --sampling-backend pytorch are set by default -drun v0.3.6.post3-rocm620 python3 -m sglang.bench_one_batch --batch-size 32 --input 1024 --output 128 --model amd/Meta-Llama-3.1-8B-Instruct-FP8-KV --tp 8 --quantization fp8 +drun v0.4.0-rocm620 python3 -m sglang.bench_one_batch --batch-size 32 --input 1024 --output 128 --model amd/Meta-Llama-3.1-8B-Instruct-FP8-KV --tp 8 --quantization fp8 ``` ## Method 4: Using docker compose diff --git a/python/pyproject.toml b/python/pyproject.toml index 1ecfc4fa50a..908f556056c 100644 --- a/python/pyproject.toml +++ b/python/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "sglang" -version = "0.3.6.post3" +version = "0.4.0" description = "SGLang is yet another fast serving framework for large language models and vision language models." readme = "README.md" requires-python = ">=3.8" diff --git a/python/sglang/version.py b/python/sglang/version.py index 4f6e29b6637..6a9beea82f6 100644 --- a/python/sglang/version.py +++ b/python/sglang/version.py @@ -1 +1 @@ -__version__ = "0.3.6.post3" +__version__ = "0.4.0" From b2986d7aa5a40740b71c0d2f59a9277cfa10c67f Mon Sep 17 00:00:00 2001 From: HAI Date: Wed, 4 Dec 2024 03:01:33 -0800 Subject: [PATCH 12/60] Adding SGLang FP8 Utils (#2348) --- .../srt/layers/quantization/fp8_utils.py | 27 +++++++++++++++++++ 1 file changed, 27 insertions(+) create mode 100644 python/sglang/srt/layers/quantization/fp8_utils.py diff --git a/python/sglang/srt/layers/quantization/fp8_utils.py b/python/sglang/srt/layers/quantization/fp8_utils.py new file mode 100644 index 00000000000..3ba381a373f --- /dev/null +++ b/python/sglang/srt/layers/quantization/fp8_utils.py @@ -0,0 +1,27 @@ +from typing import Optional, Tuple + +import torch + + +def normalize_e4m3fn_to_e4m3fnuz( + weight: torch.Tensor, + weight_scale: torch.Tensor, + input_scale: Optional[torch.Tensor] = None, +) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: + assert weight.dtype == torch.float8_e4m3fn + # The bits pattern 10000000(-128) represents zero in e4m3fn + # but NaN in e4m3fnuz. So here we set it to 0. + # https://onnx.ai/onnx/technical/float8.html + weight_as_int8 = weight.view(torch.int8) + ROCM_FP8_NAN_AS_INT = -128 + weight_as_int8[weight_as_int8 == ROCM_FP8_NAN_AS_INT] = 0 + weight = weight_as_int8.view(torch.float8_e4m3fnuz) + + # For the same bits representation, e4m3fnuz value is half of + # the e4m3fn value, so we should double the scaling factor to + # get the same dequantized value. + # https://onnx.ai/onnx/technical/float8.html + weight_scale = weight_scale * 2.0 + if input_scale is not None: + input_scale = input_scale * 2.0 + return weight, weight_scale, input_scale From eb0c1f53735c2a6f4c0ae0f0846f7cdc959ebada Mon Sep 17 00:00:00 2001 From: Yineng Zhang Date: Thu, 5 Dec 2024 01:24:51 +0800 Subject: [PATCH 13/60] docs: add SGLang v0.4 blog (#2341) --- README.md | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index 45967ee5838..43c2f8c8808 100644 --- a/README.md +++ b/README.md @@ -16,6 +16,7 @@ [**Join Bi-Weekly Development Meeting**](https://docs.google.com/document/d/1xEow4eIM152xNcRxqZz9VEcOiTQo8-CEuuQ5qTmkt-E/edit?usp=sharing) | [**Slides**](https://github.com/sgl-project/sgl-learning-materials?tab=readme-ov-file#slides) | ## News +- [2024/12] 🔥 SGLang v0.4: Zero-Overhead Batch Scheduler, Cache-Aware Load Balancer, Faster Structured Outputs ([blog](https://lmsys.org/blog/2024-12-04-sglang-v0-4/)). - [2024/10] 🔥 The First SGLang Online Meetup ([slides](https://github.com/sgl-project/sgl-learning-materials?tab=readme-ov-file#the-first-sglang-online-meetup)). - [2024/09] SGLang v0.3 Release: 7x Faster DeepSeek MLA, 1.5x Faster torch.compile, Multi-Image/Video LLaVA-OneVision ([blog](https://lmsys.org/blog/2024-09-04-sglang-v0-3/)). - [2024/07] Faster Llama3 Serving with SGLang Runtime (vs. TensorRT-LLM, vLLM) ([blog](https://lmsys.org/blog/2024-07-25-sglang-llama3/)). @@ -47,7 +48,7 @@ The core features include: - [Frontend: Structured Generation Language (SGLang)](https://sgl-project.github.io/frontend/frontend.html) ## Benchmark And Performance -Learn more in our release blogs: [v0.2 blog](https://lmsys.org/blog/2024-07-25-sglang-llama3/), [v0.3 blog](https://lmsys.org/blog/2024-09-04-sglang-v0-3/) +Learn more in our release blogs: [v0.2 blog](https://lmsys.org/blog/2024-07-25-sglang-llama3/), [v0.3 blog](https://lmsys.org/blog/2024-09-04-sglang-v0-3/), [v0.4 blog](https://lmsys.org/blog/2024-12-04-sglang-v0-4/) ## Roadmap [Development Roadmap (2024 Q4)](https://github.com/sgl-project/sglang/issues/1487) From ec52464ddeabcc70b1fd3117b93adfefd5cb7ed0 Mon Sep 17 00:00:00 2001 From: Ke Bao Date: Thu, 5 Dec 2024 01:50:28 +0800 Subject: [PATCH 14/60] MLA prefill w/o weight absorption (#2349) --- .../sglang/srt/layers/attention/__init__.py | 7 +- .../attention/double_sparsity_backend.py | 30 +++++--- .../layers/attention/flashinfer_backend.py | 25 +++++-- .../layers/attention/torch_native_backend.py | 30 +++++--- .../srt/layers/attention/triton_backend.py | 30 +++++--- .../attention/triton_ops/extend_attention.py | 3 + python/sglang/srt/layers/radix_attention.py | 6 +- python/sglang/srt/models/deepseek_v2.py | 71 ++++++++++++++++++- 8 files changed, 166 insertions(+), 36 deletions(-) diff --git a/python/sglang/srt/layers/attention/__init__.py b/python/sglang/srt/layers/attention/__init__.py index f5d573f5f7b..a70e9537bfe 100644 --- a/python/sglang/srt/layers/attention/__init__.py +++ b/python/sglang/srt/layers/attention/__init__.py @@ -52,12 +52,13 @@ def forward( v: torch.Tensor, layer: RadixAttention, forward_batch: ForwardBatch, + save_kv_cache: bool = True, ): """Run forward on an attention layer.""" if forward_batch.forward_mode.is_decode(): - return self.forward_decode(q, k, v, layer, forward_batch) + return self.forward_decode(q, k, v, layer, forward_batch, save_kv_cache) else: - return self.forward_extend(q, k, v, layer, forward_batch) + return self.forward_extend(q, k, v, layer, forward_batch, save_kv_cache) def forward_decode( self, @@ -66,6 +67,7 @@ def forward_decode( v: torch.Tensor, layer: RadixAttention, forward_batch: ForwardBatch, + save_kv_cache: bool = True, ): """Run a forward for decode.""" raise NotImplementedError() @@ -77,6 +79,7 @@ def forward_extend( v: torch.Tensor, layer: RadixAttention, forward_batch: ForwardBatch, + save_kv_cache: bool = True, ): """Run a forward for extend.""" raise NotImplementedError() diff --git a/python/sglang/srt/layers/attention/double_sparsity_backend.py b/python/sglang/srt/layers/attention/double_sparsity_backend.py index 73c32df8f6e..856aa984c38 100644 --- a/python/sglang/srt/layers/attention/double_sparsity_backend.py +++ b/python/sglang/srt/layers/attention/double_sparsity_backend.py @@ -165,7 +165,13 @@ def get_cuda_graph_seq_len_fill_value(self): return 1 def forward_extend( - self, q, k, v, layer: RadixAttention, forward_batch: ForwardBatch + self, + q, + k, + v, + layer: RadixAttention, + forward_batch: ForwardBatch, + save_kv_cache=True, ): # TODO: reuse the buffer across layers if layer.qk_head_dim != layer.v_head_dim: @@ -181,9 +187,10 @@ def forward_extend( .expand(k.shape[0], -1, -1), ) - forward_batch.token_to_kv_pool.set_kv_buffer( - layer, forward_batch.out_cache_loc, k, v, k_label - ) + if save_kv_cache: + forward_batch.token_to_kv_pool.set_kv_buffer( + layer, forward_batch.out_cache_loc, k, v, k_label + ) ( start_loc, @@ -212,7 +219,13 @@ def forward_extend( return o def forward_decode( - self, q, k, v, layer: RadixAttention, forward_batch: ForwardBatch + self, + q, + k, + v, + layer: RadixAttention, + forward_batch: ForwardBatch, + save_kv_cache=True, ): # During torch.compile, there is a bug in rotary_emb that causes the # output value to have a 3D tensor shape. This reshapes the output correctly. @@ -242,9 +255,10 @@ def forward_decode( .expand(k.shape[0], -1, -1), ) - forward_batch.token_to_kv_pool.set_kv_buffer( - layer, forward_batch.out_cache_loc, k, v, k_label - ) + if save_kv_cache: + forward_batch.token_to_kv_pool.set_kv_buffer( + layer, forward_batch.out_cache_loc, k, v, k_label + ) # NOTE(Andy) shouldn't be used when max_len_in_batch < heavy_token_num # and set a minimum value for sparse_decode diff --git a/python/sglang/srt/layers/attention/flashinfer_backend.py b/python/sglang/srt/layers/attention/flashinfer_backend.py index 258659efa2a..f89bc2ccaa2 100644 --- a/python/sglang/srt/layers/attention/flashinfer_backend.py +++ b/python/sglang/srt/layers/attention/flashinfer_backend.py @@ -221,7 +221,13 @@ def get_cuda_graph_seq_len_fill_value(self): return 0 def forward_extend( - self, q, k, v, layer: RadixAttention, forward_batch: ForwardBatch + self, + q, + k, + v, + layer: RadixAttention, + forward_batch: ForwardBatch, + save_kv_cache=True, ): prefill_wrapper_paged = self.prefill_wrappers_paged[ self._get_wrapper_idx(layer) @@ -237,7 +243,8 @@ def forward_extend( if not use_ragged: if k is not None: assert v is not None - forward_batch.token_to_kv_pool.set_kv_buffer(layer, cache_loc, k, v) + if save_kv_cache: + forward_batch.token_to_kv_pool.set_kv_buffer(layer, cache_loc, k, v) o = prefill_wrapper_paged.forward( q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim), @@ -270,12 +277,19 @@ def forward_extend( o, _ = merge_state(o1, s1, o2, s2) - forward_batch.token_to_kv_pool.set_kv_buffer(layer, cache_loc, k, v) + if save_kv_cache: + forward_batch.token_to_kv_pool.set_kv_buffer(layer, cache_loc, k, v) return o.view(-1, layer.tp_q_head_num * layer.head_dim) def forward_decode( - self, q, k, v, layer: RadixAttention, forward_batch: ForwardBatch + self, + q, + k, + v, + layer: RadixAttention, + forward_batch: ForwardBatch, + save_kv_cache=True, ): decode_wrapper = self.forward_metadata[0][self._get_wrapper_idx(layer)] cache_loc = ( @@ -286,7 +300,8 @@ def forward_decode( if k is not None: assert v is not None - forward_batch.token_to_kv_pool.set_kv_buffer(layer, cache_loc, k, v) + if save_kv_cache: + forward_batch.token_to_kv_pool.set_kv_buffer(layer, cache_loc, k, v) o = decode_wrapper.forward( q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim), diff --git a/python/sglang/srt/layers/attention/torch_native_backend.py b/python/sglang/srt/layers/attention/torch_native_backend.py index 4ccad2216f7..5e7e0e66e22 100644 --- a/python/sglang/srt/layers/attention/torch_native_backend.py +++ b/python/sglang/srt/layers/attention/torch_native_backend.py @@ -216,16 +216,23 @@ def _run_sdpa_forward_decode( return output def forward_extend( - self, q, k, v, layer: RadixAttention, forward_batch: ForwardBatch + self, + q, + k, + v, + layer: RadixAttention, + forward_batch: ForwardBatch, + save_kv_cache=True, ): if layer.qk_head_dim != layer.v_head_dim: o = q.new_empty((q.shape[0], layer.tp_q_head_num * layer.v_head_dim)) else: o = torch.empty_like(q) - forward_batch.token_to_kv_pool.set_kv_buffer( - layer, forward_batch.out_cache_loc, k, v - ) + if save_kv_cache: + forward_batch.token_to_kv_pool.set_kv_buffer( + layer, forward_batch.out_cache_loc, k, v + ) use_gqa = layer.tp_q_head_num != layer.tp_k_head_num @@ -249,7 +256,13 @@ def forward_extend( return o def forward_decode( - self, q, k, v, layer: RadixAttention, forward_batch: ForwardBatch + self, + q, + k, + v, + layer: RadixAttention, + forward_batch: ForwardBatch, + save_kv_cache=True, ): # During torch.compile, there is a bug in rotary_emb that causes the # output value to have a 3D tensor shape. This reshapes the output correctly. @@ -260,9 +273,10 @@ def forward_decode( else: o = torch.empty_like(q) - forward_batch.token_to_kv_pool.set_kv_buffer( - layer, forward_batch.out_cache_loc, k, v - ) + if save_kv_cache: + forward_batch.token_to_kv_pool.set_kv_buffer( + layer, forward_batch.out_cache_loc, k, v + ) use_gqa = layer.tp_q_head_num != layer.tp_k_head_num diff --git a/python/sglang/srt/layers/attention/triton_backend.py b/python/sglang/srt/layers/attention/triton_backend.py index b9597b3ea41..1b7c4c46d26 100644 --- a/python/sglang/srt/layers/attention/triton_backend.py +++ b/python/sglang/srt/layers/attention/triton_backend.py @@ -114,7 +114,13 @@ def get_cuda_graph_seq_len_fill_value(self): return 1 def forward_extend( - self, q, k, v, layer: RadixAttention, forward_batch: ForwardBatch + self, + q, + k, + v, + layer: RadixAttention, + forward_batch: ForwardBatch, + save_kv_cache=True, ): # TODO: reuse the buffer across layers if layer.qk_head_dim != layer.v_head_dim: @@ -122,9 +128,10 @@ def forward_extend( else: o = torch.empty_like(q) - forward_batch.token_to_kv_pool.set_kv_buffer( - layer, forward_batch.out_cache_loc, k, v - ) + if save_kv_cache: + forward_batch.token_to_kv_pool.set_kv_buffer( + layer, forward_batch.out_cache_loc, k, v + ) start_loc, attn_logits, max_seq_len, max_extend_len = self.forward_metadata self.extend_attention_fwd( @@ -146,7 +153,13 @@ def forward_extend( return o def forward_decode( - self, q, k, v, layer: RadixAttention, forward_batch: ForwardBatch + self, + q, + k, + v, + layer: RadixAttention, + forward_batch: ForwardBatch, + save_kv_cache=True, ): # During torch.compile, there is a bug in rotary_emb that causes the # output value to have a 3D tensor shape. This reshapes the output correctly. @@ -160,9 +173,10 @@ def forward_decode( start_loc, attn_logits, max_seq_len, max_extend_len = self.forward_metadata - forward_batch.token_to_kv_pool.set_kv_buffer( - layer, forward_batch.out_cache_loc, k, v - ) + if save_kv_cache: + forward_batch.token_to_kv_pool.set_kv_buffer( + layer, forward_batch.out_cache_loc, k, v + ) self.decode_attention_fwd( q.view(-1, layer.tp_q_head_num, layer.qk_head_dim), diff --git a/python/sglang/srt/layers/attention/triton_ops/extend_attention.py b/python/sglang/srt/layers/attention/triton_ops/extend_attention.py index 56cc439c31e..b7afd62e723 100644 --- a/python/sglang/srt/layers/attention/triton_ops/extend_attention.py +++ b/python/sglang/srt/layers/attention/triton_ops/extend_attention.py @@ -284,6 +284,9 @@ def extend_attention_fwd( elif Lq == 288: BLOCK_DMODEL = 256 BLOCK_DPE = 32 + elif Lq == 192: + BLOCK_DMODEL = 128 + BLOCK_DPE = 64 else: BLOCK_DMODEL = triton.next_power_of_2(Lq) BLOCK_DPE = 0 diff --git a/python/sglang/srt/layers/radix_attention.py b/python/sglang/srt/layers/radix_attention.py index 5d8c6470178..1df29ec68a9 100644 --- a/python/sglang/srt/layers/radix_attention.py +++ b/python/sglang/srt/layers/radix_attention.py @@ -48,11 +48,13 @@ def __init__( self.sliding_window_size = sliding_window_size or -1 self.is_cross_attention = is_cross_attention - def forward(self, q, k, v, forward_batch: ForwardBatch): + def forward(self, q, k, v, forward_batch: ForwardBatch, save_kv_cache=True): if k is not None: # For cross-layer sharing, kv can be None assert v is not None k = k.view(-1, self.tp_k_head_num, self.qk_head_dim) v = v.view(-1, self.tp_v_head_num, self.v_head_dim) - return forward_batch.attn_backend.forward(q, k, v, self, forward_batch) + return forward_batch.attn_backend.forward( + q, k, v, self, forward_batch, save_kv_cache + ) diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py index 424f86aec28..e83774ff55e 100644 --- a/python/sglang/srt/models/deepseek_v2.py +++ b/python/sglang/srt/models/deepseek_v2.py @@ -453,7 +453,7 @@ def __init__( mscale = yarn_get_mscale(scaling_factor, float(mscale_all_dim)) self.scaling = self.scaling * mscale * mscale - self.attn = RadixAttention( + self.attn_mqa = RadixAttention( self.num_local_heads, self.kv_lora_rank + self.qk_rope_head_dim, self.scaling, @@ -462,6 +462,15 @@ def __init__( v_head_dim=self.kv_lora_rank, ) + self.attn_mha = RadixAttention( + self.num_local_heads, + self.qk_nope_head_dim + self.qk_rope_head_dim, + self.scaling, + num_kv_heads=self.num_local_heads, + layer_id=layer_id, + v_head_dim=self.v_head_dim, + ) + self.w_kc = None self.w_vc = None self.w_scale = None @@ -471,6 +480,63 @@ def forward( positions: torch.Tensor, hidden_states: torch.Tensor, forward_batch: ForwardBatch, + ) -> torch.Tensor: + # Use normal computation for prefill and use weight absorption for extend/decode + if ( + forward_batch.forward_mode.is_extend() + and forward_batch.extend_prefix_lens.sum() == 0 + ): + return self.forward_normal(positions, hidden_states, forward_batch) + else: + return self.forward_absorb(positions, hidden_states, forward_batch) + + def forward_normal( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + forward_batch: ForwardBatch, + ) -> torch.Tensor: + if self.q_lora_rank is not None: + q = self.q_a_proj(hidden_states)[0] + q = self.q_a_layernorm(q) + q = self.q_b_proj(q)[0].view(-1, self.num_local_heads, self.qk_head_dim) + else: + q = self.q_proj(hidden_states)[0].view( + -1, self.num_local_heads, self.qk_head_dim + ) + _, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1) + latent_cache = self.kv_a_proj_with_mqa(hidden_states)[0] + kv_a, _ = latent_cache.split([self.kv_lora_rank, self.qk_rope_head_dim], dim=-1) + latent_cache = latent_cache.unsqueeze(1) + kv_a = self.kv_a_layernorm(kv_a.contiguous()) + kv = self.kv_b_proj(kv_a)[0] + kv = kv.view(-1, self.num_local_heads, self.qk_nope_head_dim + self.v_head_dim) + k_nope = kv[..., : self.qk_nope_head_dim] + v = kv[..., self.qk_nope_head_dim :] + k_pe = latent_cache[:, :, self.kv_lora_rank :] + q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe) + q[..., self.qk_nope_head_dim :] = q_pe + k = torch.empty_like(q) + k[..., : self.qk_nope_head_dim] = k_nope + k[..., self.qk_nope_head_dim :] = k_pe + + latent_cache[:, :, : self.kv_lora_rank] = kv_a.unsqueeze(1) + latent_cache[:, :, self.kv_lora_rank :] = k_pe + + # Save latent cache + forward_batch.token_to_kv_pool.set_kv_buffer( + self.attn_mha, forward_batch.out_cache_loc, latent_cache, None + ) + attn_output = self.attn_mha(q, k, v, forward_batch, save_kv_cache=False) + attn_output = attn_output.reshape(-1, self.num_local_heads * self.v_head_dim) + output, _ = self.o_proj(attn_output) + return output + + def forward_absorb( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + forward_batch: ForwardBatch, ) -> torch.Tensor: q_len = hidden_states.shape[0] q_input = hidden_states.new_empty( @@ -508,7 +574,7 @@ def forward( q_input[..., self.kv_lora_rank :] = q_pe k_input[..., self.kv_lora_rank :] = k_pe - attn_output = self.attn(q_input, k_input, v_input, forward_batch) + attn_output = self.attn_mqa(q_input, k_input, v_input, forward_batch) attn_output = attn_output.view(-1, self.num_local_heads, self.kv_lora_rank) if self.w_vc.dtype == torch.float8_e4m3fn: @@ -835,7 +901,6 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): self_attn.w_vc = w_vc.contiguous().transpose(1, 2) if hasattr(self_attn.kv_b_proj, "weight_scale"): self_attn.w_scale = self_attn.kv_b_proj.weight_scale - del self_attn.kv_b_proj EntryClass = DeepseekV2ForCausalLM From ed45e509df91663698f42d132253ae485baba00c Mon Sep 17 00:00:00 2001 From: Ata Fatahi Date: Wed, 4 Dec 2024 09:53:02 -0800 Subject: [PATCH 15/60] Check gpu availability at server args creation (#2340) Signed-off-by: Ata Fatahi --- python/sglang/srt/server_args.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index 788686a1ee8..7b337500fd7 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -20,6 +20,8 @@ import tempfile from typing import List, Optional +import torch + from sglang.srt.hf_transformers_utils import check_gguf_file from sglang.srt.utils import ( get_amdgpu_memory_capacity, @@ -151,8 +153,11 @@ def __post_init__(self): if is_hip(): gpu_mem = get_amdgpu_memory_capacity() - else: + elif torch.cuda.is_available(): gpu_mem = get_nvgpu_memory_capacity() + else: + # GPU memory is not known yet or no GPU is available. + gpu_mem = None # Set mem fraction static, which depends on the tensor parallelism size if self.mem_fraction_static is None: @@ -169,14 +174,14 @@ def __post_init__(self): # Set chunked prefill size, which depends on the gpu memory capacity if self.chunked_prefill_size is None: - if gpu_mem < 25_000: + if gpu_mem is not None and gpu_mem < 25_000: self.chunked_prefill_size = 2048 else: self.chunked_prefill_size = 8192 # Set cuda graph max batch size if self.cuda_graph_max_bs is None: - if gpu_mem < 25_000: + if gpu_mem is not None and gpu_mem < 25_000: self.cuda_graph_max_bs = 8 else: self.cuda_graph_max_bs = 160 From 2db4469808158700036de79bd41a9c463bb89bdc Mon Sep 17 00:00:00 2001 From: Yineng Zhang Date: Thu, 5 Dec 2024 02:00:34 +0800 Subject: [PATCH 16/60] minor: limit the range of vllm versions (#2350) --- python/pyproject.toml | 2 +- python/sglang/__init__.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/python/pyproject.toml b/python/pyproject.toml index 908f556056c..1452fad4ab7 100644 --- a/python/pyproject.toml +++ b/python/pyproject.toml @@ -23,7 +23,7 @@ runtime_common = ["aiohttp", "decord", "fastapi", "psutil", "pydantic", "python-multipart", "pyzmq>=25.1.2", "torchao", "uvicorn", "uvloop", "xgrammar>=0.1.4"] -srt = ["sglang[runtime_common]", "torch", "vllm>=0.6.3.post1", "cuda-python", "flashinfer>=0.1.6"] +srt = ["sglang[runtime_common]", "torch", "vllm>=0.6.3.post1,<=0.6.4.post1", "cuda-python", "flashinfer>=0.1.6"] # HIP (Heterogeneous-computing Interface for Portability) for AMD # => base docker rocm/vllm-dev:20241022, not from public vllm whl diff --git a/python/sglang/__init__.py b/python/sglang/__init__.py index 40fbf17bc9d..de9134857a6 100644 --- a/python/sglang/__init__.py +++ b/python/sglang/__init__.py @@ -66,7 +66,7 @@ __all__ += ["__version__"] -# SGL Backends +# SGLang Backends from sglang.lang.backend.runtime_endpoint import RuntimeEndpoint from sglang.utils import LazyImport From 786be44da52e4994c499fdddbbac0f5d79a9fd6e Mon Sep 17 00:00:00 2001 From: Chayenne Date: Wed, 4 Dec 2024 11:19:46 -0800 Subject: [PATCH 17/60] Fix Docs CI When Compile Error (#2323) --- docs/Makefile | 2 +- docs/backend/native_api.ipynb | 10 +++++----- python/sglang/srt/server.py | 2 +- 3 files changed, 7 insertions(+), 7 deletions(-) diff --git a/docs/Makefile b/docs/Makefile index 50f77a30c09..13d81f4f847 100644 --- a/docs/Makefile +++ b/docs/Makefile @@ -19,7 +19,7 @@ compile: echo "Executing $$nb"; \ jupyter nbconvert --to notebook --execute --inplace "$$nb" \ --ExecutePreprocessor.timeout=600 \ - --ExecutePreprocessor.kernel_name=python3; \ + --ExecutePreprocessor.kernel_name=python3 || exit 1; \ fi; \ done diff --git a/docs/backend/native_api.ipynb b/docs/backend/native_api.ipynb index 7207259ea3c..26758f7f975 100644 --- a/docs/backend/native_api.ipynb +++ b/docs/backend/native_api.ipynb @@ -220,19 +220,19 @@ "metadata": {}, "outputs": [], "source": [ - "# failed update with different parameter size\n", + "# failed update with different parameter size or wrong name\n", "\n", "url = \"http://localhost:30010/update_weights_from_disk\"\n", - "data = {\"model_path\": \"meta-llama/Llama-3.2-3B\"}\n", + "data = {\"model_path\": \"meta-llama/Llama-3.2-1B-wrong\"}\n", "\n", "response = requests.post(url, json=data)\n", "response_json = response.json()\n", "print_highlight(response_json)\n", "assert response_json[\"success\"] is False\n", "assert response_json[\"message\"] == (\n", - " \"Failed to update weights: The size of tensor a (2048) must match \"\n", - " \"the size of tensor b (3072) at non-singleton dimension 1.\\n\"\n", - " \"Rolling back to original weights.\"\n", + " \"Failed to get weights iterator: \"\n", + " \"meta-llama/Llama-3.2-1B-wrong\"\n", + " \" (repository not found).\"\n", ")" ] }, diff --git a/python/sglang/srt/server.py b/python/sglang/srt/server.py index fc8ac150b3f..7b91cb69797 100644 --- a/python/sglang/srt/server.py +++ b/python/sglang/srt/server.py @@ -329,7 +329,7 @@ async def encode_request(obj: EmbeddingReqInput, request: Request): ) -@app.api_route("/encode", methods=["POST", "PUT"]) +@app.api_route("/classify", methods=["POST", "PUT"]) @time_func_latency async def classify_request(obj: EmbeddingReqInput, request: Request): """Handle a reward model request. Now the arguments and return values are the same as embedding models.""" From 18ea841f408c01a28c1a1db92f37ae95cfa12523 Mon Sep 17 00:00:00 2001 From: Chayenne Date: Wed, 4 Dec 2024 15:41:22 -0800 Subject: [PATCH 18/60] Add Docs For SGLang Native Router (#2308) --- docs/index.rst | 7 ++ docs/router/router.md | 110 ++++++++++++++++++ .../trace_and_evaluate_rag_using_parea.ipynb | 76 +----------- 3 files changed, 121 insertions(+), 72 deletions(-) create mode 100644 docs/router/router.md diff --git a/docs/index.rst b/docs/index.rst index 873999d2587..8c6c018c4ce 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -39,6 +39,13 @@ The core features include: frontend/choices_methods.md +.. toctree:: + :maxdepth: 1 + :caption: SGLang Router + + router/router.md + + .. toctree:: :maxdepth: 1 :caption: References diff --git a/docs/router/router.md b/docs/router/router.md new file mode 100644 index 00000000000..11ea8c59065 --- /dev/null +++ b/docs/router/router.md @@ -0,0 +1,110 @@ +# Router for Data Parallelism + +Given multiple GPUs running multiple SGLang Runtimes, SGLang Router distributes the requests to different Runtimes with its unique cache-aware load-balancing algorithm. + +The router is a independent Python package, and it can be used as a drop-in replacement for the OpenAI API. + +## Installation + +```bash +pip install sglang-router +``` + +Detailed usage of the router can be found in [launch_router](https://github.com/sgl-project/sglang/blob/main/rust/py_src/sglang_router/launch_router.py) and [launch_server](https://github.com/sgl-project/sglang/blob/main/rust/py_src/sglang/launch_server.py). Also, you can directly run the following command to see the usage of the router. + +```bash +python -m sglang_router.launch_server --help +python -m sglang_router.launch_routher --help +``` + +The router supports two working modes: + +1. Co-launch Router and Runtimes +2. Launch Runtimes and Router separately + +## Co-launch Router and Runtimes + +This will be a drop-in replacement for the existing `--dp-size` arguement of SGLang Runtime. Under the hood, it uses multi-processes to launch multiple workers, wait for them to be ready, then connect the router to all workers. + +```bash +python -m sglang_router.launch_server --model-path meta-llama/Meta-Llama-3.1-8B-Instruct --dp-size 1 +``` + +After the server is ready, you can directly send requests to the router as the same way as sending requests to each single worker. + +```python +import requests + +url = "http://localhost:30000/generate" +data = {"text": "What is the capital of France?"} + +response = requests.post(url, json=data) +print(response.json()) +``` + +## Launch Runtimes and Router Separately + +This is useful for multi-node DP. First, launch workers on multiple nodes, then launch a router on the main node, and connect the router to all workers. + +```bash +python -m sglang_router.launch_router --worker-urls http://worker_url_1 http://worker_url_2 +``` + +## Strategies + +### Cache-Aware Load-Balancing Router + +The native router combines two strategies to optimize both cache utilization and request distribution: + +1. Cache-Aware Routing (Approximate Tree) +2. Load-Balancing Routing (Shortest Queue with Balance Thresholds) + +The router dynamically switches between these strategies based on load conditions: + +- Uses load balancing when the system is imbalanced +- Uses cache-aware routing when the system is balanced + +A system is considered imbalanced if both conditions are met: + +1. (max_load - min_load) > balance_abs_threshold +2. max_load > balance_rel_threshold * min_load + +***Cache-Aware Routing (Approximate Tree)*** + +When the workers are considered to be balanced, the router maintains an approximate radix tree for each worker based on request history, eliminating the need for direct cache state queries on each worker. The tree stores raw text characters instead of token IDs to avoid tokenization overhead. + +Process: + +1. For each request, find the worker with the highest prefix match. + + - If match rate > cache_threshold, route the request to the worker with highest match (likely has relevant data cached) + - If match rate ≤ cache_threshold, route the request to the worker with smallest tree size (most available cache capacity) + +2. Background maintenance: Periodically evict least recently used leaf nodes on the approximate tree to prevent memory overflow. + +***Load-Balancing (Shortest Queue)*** + +For unbalanced systems, this strategy tracks pending request counts per worker and routes new requests to the least busy worker. This helps maintain optimal load distribution across workers. + +## Configuration Parameters + +1. `cache_threshold`: (float, 0.0 to 1.0, default: 0.5) + - Minimum prefix match ratio to use highest-match routing. + - Below this threshold, the request will be routed to the worker with most available cache space. + +2. `balance_abs_threshold`: (integer, default: 32) + - Absolute difference threshold for load imbalance detection. + - The system is potentially imbalanced if (max_load - min_load) > abs_threshold. + +3. `balance_rel_threshold`: (float, default: 1.0001) + - Relative ratio threshold for load imbalance detection. + - The system is potentially imbalanced if max_load > min_load * rel_threshold. + - Used in conjunction with `balance_abs_threshold` to determine the final imbalance state. + +4. `eviction_interval`: (integer, default: 60) + - Interval in seconds between LRU eviction cycles for the approximate trees. + - Background thread periodically evicts least recently used nodes to maintain tree size. + +5. `max_tree_size`: (integer, default: 16777216) + - Maximum nodes on the approximate tree. + - When exceeded, LRU leaf nodes are evicted during the next eviction cycle. diff --git a/examples/frontend_language/usage/rag_using_parea/trace_and_evaluate_rag_using_parea.ipynb b/examples/frontend_language/usage/rag_using_parea/trace_and_evaluate_rag_using_parea.ipynb index a54f92bd6aa..3c1b2a6c400 100644 --- a/examples/frontend_language/usage/rag_using_parea/trace_and_evaluate_rag_using_parea.ipynb +++ b/examples/frontend_language/usage/rag_using_parea/trace_and_evaluate_rag_using_parea.ipynb @@ -177,18 +177,7 @@ "cell_type": "code", "execution_count": null, "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "'The World Health Organization formally declared an end to the COVID-19 global health emergency'" - ] - }, - "execution_count": null, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ "@trace\n", "def rag_pipeline(question: str) -> str:\n", @@ -307,18 +296,7 @@ "cell_type": "code", "execution_count": null, "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "'The World Health Organization formally declared an end to the COVID-19 global health emergency in May 2023.'" - ] - }, - "execution_count": null, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ "@trace\n", "def rag_pipeline(question: str) -> str:\n", @@ -355,15 +333,7 @@ "cell_type": "code", "execution_count": null, "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Requirement already satisfied: nest-asyncio in /Users/joschkabraun/miniconda3/envs/sglang/lib/python3.10/site-packages (1.6.0)\r\n" - ] - } - ], + "outputs": [], "source": [ "!pip install nest-asyncio\n", "import nest_asyncio\n", @@ -382,45 +352,7 @@ "cell_type": "code", "execution_count": null, "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Run name set to: sneak-weal, since a name was not provided.\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "100%|██████████| 100/100 [00:27<00:00, 3.63it/s]\n", - "Waiting for evaluations to finish: 100%|██████████| 19/19 [00:10<00:00, 1.89it/s]\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Experiment RAG Run sneak-weal stats:\n", - "{\n", - " \"latency\": \"2.69\",\n", - " \"input_tokens\": \"61.26\",\n", - " \"output_tokens\": \"75.88\",\n", - " \"total_tokens\": \"137.14\",\n", - " \"cost\": \"0.00\",\n", - " \"answer_context_faithfulness_statement_level\": \"0.26\",\n", - " \"answer_matches_target_llm_grader\": \"0.22\",\n", - " \"context_query_relevancy\": \"0.27\",\n", - " \"percent_target_supported_by_context\": \"0.40\"\n", - "}\n", - "\n", - "\n", - "View experiment & traces at: https://app.parea.ai/experiments/RAG/30f0244a-d56c-44ff-bdfb-8f47626304b6\n", - "\n" - ] - } - ], + "outputs": [], "source": [ "e = p.experiment(\n", " \"RAG\",\n", From d693ec0427bd70c8676316c634e00bd27514b7ec Mon Sep 17 00:00:00 2001 From: Ke Wen Date: Wed, 4 Dec 2024 17:26:00 -0800 Subject: [PATCH 19/60] Make torch TP composable with torch.compile (#2352) --- python/sglang/srt/model_parallel.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/python/sglang/srt/model_parallel.py b/python/sglang/srt/model_parallel.py index 6817bce0235..778347b8ef3 100644 --- a/python/sglang/srt/model_parallel.py +++ b/python/sglang/srt/model_parallel.py @@ -54,11 +54,7 @@ def _prepare_output_fn(output_layouts, use_local_output, mod, outputs, device_me )._prepare_output_fn( output_layouts, use_local_output, mod, outputs, device_mesh ) - # wait for the output to be ready - if isinstance(outputs, AsyncCollectiveTensor): - return outputs.wait() - else: - return outputs + return torch.distributed._functional_collectives.wait_tensor(outputs) def tensor_parallel( From 9cc733b38ceb4fc9df0daa6aed7335f2f8a4ba82 Mon Sep 17 00:00:00 2001 From: Jerry Zhang Date: Wed, 4 Dec 2024 17:26:42 -0800 Subject: [PATCH 20/60] move apply_torchao_config_ to model_runner (#2342) --- python/sglang/srt/layers/torchao_utils.py | 58 ++++++------------- .../sglang/srt/model_executor/model_runner.py | 8 +++ python/sglang/srt/models/grok.py | 5 -- python/sglang/srt/models/llama.py | 5 -- python/sglang/srt/models/mixtral.py | 5 -- python/sglang/srt/models/phi3_small.py | 5 -- python/sglang/srt/models/qwen2_moe.py | 5 -- .../sglang/srt/models/torch_native_llama.py | 5 -- 8 files changed, 25 insertions(+), 71 deletions(-) diff --git a/python/sglang/srt/layers/torchao_utils.py b/python/sglang/srt/layers/torchao_utils.py index 9395cdf271b..3f886221cca 100644 --- a/python/sglang/srt/layers/torchao_utils.py +++ b/python/sglang/srt/layers/torchao_utils.py @@ -7,13 +7,15 @@ import torch -def torchao_quantize_param_data(param: torch.Tensor, torchao_config: str): - """Quantize a Tensor with torchao quantization specified by torchao_config +def apply_torchao_config_to_model_( + model: torch.nn.Module, torchao_config: str, filter_fn=None +): + """Quantize a modelwith torchao quantization specified by torchao_config Args: - `param`: weight parameter of the linear module - `torchao_config`: type of quantization and their arguments we want to use to - quantize the Tensor, e.g. int4wo-128 means int4 weight only quantization with group_size + `model`: a model to be quantized based on torchao_config + `torchao_config` (str): type of quantization and their arguments we want to use to + quantize the model, e.g. int4wo-128 means int4 weight only quantization with group_size 128 """ # Lazy import to suppress some warnings @@ -26,12 +28,12 @@ def torchao_quantize_param_data(param: torch.Tensor, torchao_config: str): ) from torchao.quantization.observer import PerRow, PerTensor - dummy_linear = torch.nn.Linear(param.shape[1], param.shape[0], bias=False) - dummy_linear.weight = param - if "int8wo" in torchao_config: - quantize_(dummy_linear, int8_weight_only()) + if torchao_config == "" or torchao_config is None: + return model + elif "int8wo" in torchao_config: + quantize_(model, int8_weight_only(), filter_fn=filter_fn) elif "int8dq" in torchao_config: - quantize_(dummy_linear, int8_dynamic_activation_int8_weight()) + quantize_(model, int8_dynamic_activation_int8_weight(), filter_fn=filter_fn) elif "int4wo" in torchao_config: group_size = int(torchao_config.split("-")[-1]) assert group_size in [ @@ -40,13 +42,13 @@ def torchao_quantize_param_data(param: torch.Tensor, torchao_config: str): 128, 256, ], f"int4wo groupsize needs to be one of [32, 64, 128, 256] but got {group_size}" - quantize_(dummy_linear, int4_weight_only(group_size=group_size)) + quantize_(model, int4_weight_only(group_size=group_size), filter_fn=filter_fn) elif "fp8wo" in torchao_config: from torchao.quantization import float8_weight_only # this requires newer hardware # [rank0]: AssertionError: fp8e4nv data type is not supported on CUDA arch < 89 - quantize_(dummy_linear, float8_weight_only()) + quantize_(model, float8_weight_only(), filter_fn=filter_fn) elif "fp8dq" in torchao_config: granularity = torchao_config.split("-")[-1] GRANULARITY_MAP = { @@ -57,39 +59,13 @@ def torchao_quantize_param_data(param: torch.Tensor, torchao_config: str): granularity in GRANULARITY_MAP ), f"Supported granularity are: {GRANULARITY_MAP.keys()}, got {granularity}" quantize_( - dummy_linear, + model, float8_dynamic_activation_float8_weight( granularity=GRANULARITY_MAP[granularity] ), + filter_fn=filter_fn, ) else: raise ValueError(f"Unexpected config: {torchao_config}") - return dummy_linear.weight - - -def apply_torchao_config_( - self: torch.nn.Module, - params_dict: Dict[str, torch.Tensor], - param_suffixes: Set[str], -) -> None: - """A util function used for quantizing the weight parameters after they are loaded if - self.torchao_config is specified - - Args: - `self`: the model we want to quantize - `params_dict`: dictionary mapping from param_name to the parameter Tensor - `param_suffixes`: a set of suffixes, we'll quantize the Tensor matching these suffixes - - Returns: - None, the `params_dict` is modified inplace and the weights of `self` model are quantized - """ - if self.torchao_config: - for param_suffix in param_suffixes: - for name in params_dict: - param = params_dict[name] - if param_suffix in name and param.ndim == 2: - params_dict[name] = torchao_quantize_param_data( - param, self.torchao_config - ) - self.load_state_dict(params_dict, assign=True) + return model diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index fafb8783e5a..6f79afb70ba 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -38,6 +38,7 @@ from sglang.srt.layers.attention.triton_backend import TritonAttnBackend from sglang.srt.layers.logits_processor import LogitsProcessorOutput from sglang.srt.layers.sampler import Sampler +from sglang.srt.layers.torchao_utils import apply_torchao_config_to_model_ from sglang.srt.lora.lora_manager import LoRAManager from sglang.srt.managers.schedule_batch import global_server_args_dict from sglang.srt.mem_cache.memory_pool import ( @@ -159,6 +160,13 @@ def __init__( else: self.torch_tp_applied = False + def filter_fn(module, fqn): + return "proj" in fqn + + apply_torchao_config_to_model_( + self.model, global_server_args_dict["torchao_config"], filter_fn + ) + # Init memory pool and attention backends if server_args.lora_paths is not None: self.init_lora_manager() diff --git a/python/sglang/srt/models/grok.py b/python/sglang/srt/models/grok.py index 956f73b1482..2b52e2b1bcc 100644 --- a/python/sglang/srt/models/grok.py +++ b/python/sglang/srt/models/grok.py @@ -35,12 +35,10 @@ from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.quantization.base_config import QuantizationConfig from sglang.srt.layers.radix_attention import RadixAttention -from sglang.srt.layers.torchao_utils import apply_torchao_config_ from sglang.srt.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding, ) -from sglang.srt.managers.schedule_batch import global_server_args_dict from sglang.srt.model_executor.forward_batch_info import ForwardBatch from sglang.srt.model_loader.loader import DefaultModelLoader from sglang.srt.model_loader.weight_utils import default_weight_loader @@ -290,7 +288,6 @@ def __init__( super().__init__() self.config = config self.quant_config = quant_config - self.torchao_config = global_server_args_dict["torchao_config"] self.model = Grok1Model(config, quant_config=quant_config) self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size) self.logits_processor = LogitsProcessor(config) @@ -374,8 +371,6 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): ) weight_loader(param, loaded_weight) - apply_torchao_config_(self, params_dict, set(["proj.weight"])) - class Grok1ModelForCausalLM(Grok1ForCausalLM): """An alias for backward-compatbility.""" diff --git a/python/sglang/srt/models/llama.py b/python/sglang/srt/models/llama.py index 61409a9eaeb..e3e44ea6ffc 100644 --- a/python/sglang/srt/models/llama.py +++ b/python/sglang/srt/models/llama.py @@ -36,12 +36,10 @@ from sglang.srt.layers.pooler import Pooler, PoolingType from sglang.srt.layers.quantization.base_config import QuantizationConfig from sglang.srt.layers.radix_attention import RadixAttention -from sglang.srt.layers.torchao_utils import apply_torchao_config_ from sglang.srt.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding, ) -from sglang.srt.managers.schedule_batch import global_server_args_dict from sglang.srt.model_executor.forward_batch_info import ForwardBatch from sglang.srt.model_loader.weight_utils import default_weight_loader from sglang.srt.utils import make_layers @@ -304,7 +302,6 @@ def __init__( super().__init__() self.config = config self.quant_config = quant_config - self.torchao_config = global_server_args_dict["torchao_config"] self.model = LlamaModel(config, quant_config=quant_config) # Llama 3.2 1B Insturct set tie_word_embeddings to True # Llama 3.1 8B Insturct set tie_word_embeddings to False @@ -424,8 +421,6 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) - apply_torchao_config_(self, params_dict, set(["proj.weight"])) - def get_weights_by_name( self, name: str, truncate_size: int = 100, tp_size: int = 1 ) -> Optional[torch.Tensor]: diff --git a/python/sglang/srt/models/mixtral.py b/python/sglang/srt/models/mixtral.py index e75dc1288b7..f1ae1f57a3d 100644 --- a/python/sglang/srt/models/mixtral.py +++ b/python/sglang/srt/models/mixtral.py @@ -34,12 +34,10 @@ from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.quantization.base_config import QuantizationConfig from sglang.srt.layers.radix_attention import RadixAttention -from sglang.srt.layers.torchao_utils import apply_torchao_config_ from sglang.srt.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding, ) -from sglang.srt.managers.schedule_batch import global_server_args_dict from sglang.srt.model_executor.forward_batch_info import ForwardBatch from sglang.srt.model_loader.weight_utils import default_weight_loader @@ -295,7 +293,6 @@ def __init__( super().__init__() self.config = config self.quant_config = quant_config - self.torchao_config = global_server_args_dict["torchao_config"] self.model = MixtralModel(config, quant_config=quant_config, prefix="model") self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size) self.logits_processor = LogitsProcessor(config) @@ -387,7 +384,5 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): ) weight_loader(param, loaded_weight) - apply_torchao_config_(self, params_dict, set(["proj.weight"])) - EntryClass = MixtralForCausalLM diff --git a/python/sglang/srt/models/phi3_small.py b/python/sglang/srt/models/phi3_small.py index 6340330774d..1e70c7d7874 100644 --- a/python/sglang/srt/models/phi3_small.py +++ b/python/sglang/srt/models/phi3_small.py @@ -17,13 +17,11 @@ from sglang.srt.layers.pooler import Pooler, PoolingType from sglang.srt.layers.quantization.base_config import QuantizationConfig from sglang.srt.layers.radix_attention import RadixAttention -from sglang.srt.layers.torchao_utils import apply_torchao_config_ from sglang.srt.layers.vocab_parallel_embedding import ( DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding, ) -from sglang.srt.managers.schedule_batch import global_server_args_dict from sglang.srt.model_executor.forward_batch_info import ForwardBatch from sglang.srt.model_loader.weight_utils import default_weight_loader from sglang.srt.utils import make_layers @@ -348,7 +346,6 @@ def __init__( quant_config=quant_config, prefix="model", ) - self.torchao_config = global_server_args_dict["torchao_config"] self.vocab_size = config.vocab_size self.mup_width_multiplier = config.mup_width_multiplier self.lm_head = ParallelLMHead( @@ -441,7 +438,5 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) - apply_torchao_config_(self, params_dict, set(["proj.weight"])) - EntryClass = Phi3SmallForCausalLM diff --git a/python/sglang/srt/models/qwen2_moe.py b/python/sglang/srt/models/qwen2_moe.py index 0094cb8c3e2..62cd3281d03 100644 --- a/python/sglang/srt/models/qwen2_moe.py +++ b/python/sglang/srt/models/qwen2_moe.py @@ -40,12 +40,10 @@ from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.quantization.base_config import QuantizationConfig from sglang.srt.layers.radix_attention import RadixAttention -from sglang.srt.layers.torchao_utils import apply_torchao_config_ from sglang.srt.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding, ) -from sglang.srt.managers.schedule_batch import global_server_args_dict from sglang.srt.model_executor.forward_batch_info import ForwardBatch from sglang.srt.model_loader.weight_utils import default_weight_loader @@ -352,7 +350,6 @@ def __init__( super().__init__() self.config = config self.quant_config = quant_config - self.torchao_config = global_server_args_dict["torchao_config"] self.model = Qwen2MoeModel(config, quant_config) self.lm_head = ParallelLMHead( config.vocab_size, config.hidden_size, quant_config=quant_config @@ -445,7 +442,5 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): ) weight_loader(param, loaded_weight) - apply_torchao_config_(self, params_dict, set(["proj.weight"])) - EntryClass = Qwen2MoeForCausalLM diff --git a/python/sglang/srt/models/torch_native_llama.py b/python/sglang/srt/models/torch_native_llama.py index 25e555484a7..7a55d50457a 100644 --- a/python/sglang/srt/models/torch_native_llama.py +++ b/python/sglang/srt/models/torch_native_llama.py @@ -58,12 +58,10 @@ from sglang.srt.layers.logits_processor import LogitsProcessor, LogitsProcessorOutput from sglang.srt.layers.quantization.base_config import QuantizationConfig from sglang.srt.layers.radix_attention import RadixAttention -from sglang.srt.layers.torchao_utils import apply_torchao_config_ from sglang.srt.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding, ) -from sglang.srt.managers.schedule_batch import global_server_args_dict from sglang.srt.model_executor.forward_batch_info import ForwardBatch from sglang.srt.model_loader.weight_utils import default_weight_loader @@ -392,7 +390,6 @@ def __init__( super().__init__() self.config = config self.quant_config = quant_config - self.torchao_config = global_server_args_dict["torchao_config"] self.supports_torch_tp = True self.model = LlamaModel(config, quant_config=quant_config) if self.config.tie_word_embeddings: @@ -503,8 +500,6 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) - apply_torchao_config_(self, params_dict, set(["proj.weight"])) - class TorchNativePhi3ForCausalLM(TorchNativeLlamaForCausalLM): pass From 2b0fc5941d3d7f3dfe4a56c053ddddf9d4f77670 Mon Sep 17 00:00:00 2001 From: Lianmin Zheng Date: Wed, 4 Dec 2024 19:02:08 -0800 Subject: [PATCH 21/60] [Minor] Code style improvements (#2355) --- python/sglang/srt/layers/torchao_utils.py | 12 +++++++----- .../srt/model_executor/cuda_graph_runner.py | 2 +- .../sglang/srt/model_executor/model_runner.py | 18 ++++++++---------- 3 files changed, 16 insertions(+), 16 deletions(-) diff --git a/python/sglang/srt/layers/torchao_utils.py b/python/sglang/srt/layers/torchao_utils.py index 3f886221cca..910309da973 100644 --- a/python/sglang/srt/layers/torchao_utils.py +++ b/python/sglang/srt/layers/torchao_utils.py @@ -2,12 +2,10 @@ Common utilities for torchao. """ -from typing import Dict, Set - import torch -def apply_torchao_config_to_model_( +def apply_torchao_config_to_model( model: torch.nn.Module, torchao_config: str, filter_fn=None ): """Quantize a modelwith torchao quantization specified by torchao_config @@ -21,6 +19,7 @@ def apply_torchao_config_to_model_( # Lazy import to suppress some warnings from torchao.quantization import ( float8_dynamic_activation_float8_weight, + float8_weight_only, int4_weight_only, int8_dynamic_activation_int8_weight, int8_weight_only, @@ -28,6 +27,11 @@ def apply_torchao_config_to_model_( ) from torchao.quantization.observer import PerRow, PerTensor + if filter_fn is None: + + def filter_fn(module, fqn): + return "proj" in fqn + if torchao_config == "" or torchao_config is None: return model elif "int8wo" in torchao_config: @@ -44,8 +48,6 @@ def apply_torchao_config_to_model_( ], f"int4wo groupsize needs to be one of [32, 64, 128, 256] but got {group_size}" quantize_(model, int4_weight_only(group_size=group_size), filter_fn=filter_fn) elif "fp8wo" in torchao_config: - from torchao.quantization import float8_weight_only - # this requires newer hardware # [rank0]: AssertionError: fp8e4nv data type is not supported on CUDA arch < 89 quantize_(model, float8_weight_only(), filter_fn=filter_fn) diff --git a/python/sglang/srt/model_executor/cuda_graph_runner.py b/python/sglang/srt/model_executor/cuda_graph_runner.py index dd26a77ad65..3aac4965a5d 100644 --- a/python/sglang/srt/model_executor/cuda_graph_runner.py +++ b/python/sglang/srt/model_executor/cuda_graph_runner.py @@ -47,7 +47,7 @@ def _to_torch(model: torch.nn.Module, reverse: bool, batch_size: int): if "FusedMoE" in sub.__class__.__name__: if batch_size == 1: # The performance of torch.compile on this layer is not always good when bs > 1, - # so we decide to skip it for now. + # so we decide to only use torch.compile when bs =1 sub._forward_method = fused_moe_forward_native else: sub._forward_method = sub.forward_native diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index 6f79afb70ba..4eaedbccbff 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -27,7 +27,6 @@ initialize_model_parallel, set_custom_all_reduce, ) -from vllm.distributed.parallel_state import in_the_same_node_as from sglang.srt.configs.device_config import DeviceConfig from sglang.srt.configs.load_config import LoadConfig @@ -38,7 +37,7 @@ from sglang.srt.layers.attention.triton_backend import TritonAttnBackend from sglang.srt.layers.logits_processor import LogitsProcessorOutput from sglang.srt.layers.sampler import Sampler -from sglang.srt.layers.torchao_utils import apply_torchao_config_to_model_ +from sglang.srt.layers.torchao_utils import apply_torchao_config_to_model from sglang.srt.lora.lora_manager import LoRAManager from sglang.srt.managers.schedule_batch import global_server_args_dict from sglang.srt.mem_cache.memory_pool import ( @@ -112,11 +111,13 @@ def __init__( ) if self.is_multimodal: - logger.info( - "Automatically turn off --chunked-prefill-size and adjust --mem-fraction-static for multimodal models." - ) server_args.chunked_prefill_size = -1 self.mem_fraction_static *= 0.95 + logger.info( + f"Automatically reduce --mem-fraction-static to {self.mem_fraction_static} " + f"and turn off chunked prefill " + f"because this is a multimodal model." + ) # TODO: qwen2-vl does not support radix cache now, set disable_radix_cache=True automatically if self.model_config.hf_config.architectures == [ "Qwen2VLForConditionalGeneration" @@ -160,11 +161,8 @@ def __init__( else: self.torch_tp_applied = False - def filter_fn(module, fqn): - return "proj" in fqn - - apply_torchao_config_to_model_( - self.model, global_server_args_dict["torchao_config"], filter_fn + apply_torchao_config_to_model( + self.model, global_server_args_dict["torchao_config"] ) # Init memory pool and attention backends From 4a63c181f19015a0a8812b1fe5c33daf90ec8590 Mon Sep 17 00:00:00 2001 From: Ke Bao Date: Fri, 6 Dec 2024 00:46:48 +0800 Subject: [PATCH 22/60] Fix AWQ with enable MLA (#2364) --- python/sglang/srt/models/deepseek_v2.py | 15 ++++++++++++++- 1 file changed, 14 insertions(+), 1 deletion(-) diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py index e83774ff55e..80db9a35c71 100644 --- a/python/sglang/srt/models/deepseek_v2.py +++ b/python/sglang/srt/models/deepseek_v2.py @@ -21,6 +21,7 @@ import torch from torch import nn from transformers import PretrainedConfig +from vllm import _custom_ops as ops from vllm.distributed import ( get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, @@ -894,7 +895,19 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): if not global_server_args_dict["disable_mla"]: for layer_id in range(self.config.num_hidden_layers): self_attn = self.model.layers[layer_id].self_attn - w_kc, w_vc = self_attn.kv_b_proj.weight.unflatten( + if hasattr(self_attn.kv_b_proj, "qweight"): + # AWQ compatible + w = ops.awq_dequantize( + self_attn.kv_b_proj.qweight, + self_attn.kv_b_proj.scales, + self_attn.kv_b_proj.qzeros, + 0, + 0, + 0, + ).T + else: + w = self_attn.kv_b_proj.weight + w_kc, w_vc = w.unflatten( 0, (-1, self_attn.qk_nope_head_dim + self_attn.v_head_dim) ).split([self_attn.qk_nope_head_dim, self_attn.v_head_dim], dim=1) self_attn.w_kc = w_kc.transpose(1, 2).contiguous().transpose(1, 2) From 71e2a27753fa6908eeaa0151ad27df0b05fd407a Mon Sep 17 00:00:00 2001 From: Lianmin Zheng Date: Thu, 5 Dec 2024 13:42:47 -0800 Subject: [PATCH 23/60] Fix the cuda graph capture range for small #max-running-requests (#2359) --- .../sglang/srt/model_executor/cuda_graph_runner.py | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/python/sglang/srt/model_executor/cuda_graph_runner.py b/python/sglang/srt/model_executor/cuda_graph_runner.py index 3aac4965a5d..27043cc9a7d 100644 --- a/python/sglang/srt/model_executor/cuda_graph_runner.py +++ b/python/sglang/srt/model_executor/cuda_graph_runner.py @@ -130,6 +130,20 @@ def __init__(self, model_runner: "ModelRunner"): self.capture_bs = list(range(1, 32)) + [64, 128] else: self.capture_bs = [1, 2, 4] + [i * 8 for i in range(1, 21)] + + if max(self.capture_bs) > model_runner.req_to_token_pool.size: + # In some case (e.g., with a small GPU or --max-running-requests), the #max-running-requests + # is very samll. We add more values here to make sure we capture the maximum bs. + self.capture_bs = list( + sorted( + set( + self.capture_bs + + [model_runner.req_to_token_pool.size - 1] + + [model_runner.req_to_token_pool.size] + ) + ) + ) + self.capture_bs = [ bs for bs in self.capture_bs From 64fceab8afae962ac2f64b6491d873591a58c051 Mon Sep 17 00:00:00 2001 From: Byron Hsu Date: Thu, 5 Dec 2024 17:46:21 -0800 Subject: [PATCH 24/60] [router] use 2-gpu-runner (#2368) --- .github/workflows/pr-test-rust.yml | 2 +- rust/py_test/test_launch_server.py | 6 +++++- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/.github/workflows/pr-test-rust.yml b/.github/workflows/pr-test-rust.yml index 92bd986a077..0df81b487b5 100644 --- a/.github/workflows/pr-test-rust.yml +++ b/.github/workflows/pr-test-rust.yml @@ -42,7 +42,7 @@ jobs: e2e-rust: if: github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request' - runs-on: 1-gpu-runner + runs-on: 2-gpu-runner steps: - name: Checkout code uses: actions/checkout@v3 diff --git a/rust/py_test/test_launch_server.py b/rust/py_test/test_launch_server.py index f39b341df2b..a7a695aa9f6 100644 --- a/rust/py_test/test_launch_server.py +++ b/rust/py_test/test_launch_server.py @@ -92,7 +92,11 @@ def test_mmlu(self): ) metrics = run_eval(args) - self.assertGreaterEqual(metrics["score"], 0.65) + score = metrics["score"] + THRESHOLD = 0.65 + passed = score >= THRESHOLD + msg = f"MMLU test {'passed' if passed else 'failed'} with score {score:.3f} (threshold: {THRESHOLD})" + self.assertGreaterEqual(score, THRESHOLD, msg) if __name__ == "__main__": From 3d32e4a32c4cd0c29da176bbc9f6b4f018c54fa5 Mon Sep 17 00:00:00 2001 From: xiaobochen <35516720+xiaobochen123@users.noreply.github.com> Date: Fri, 6 Dec 2024 15:05:21 +0800 Subject: [PATCH 25/60] Resubmit MoE-EP (#2371) --- .github/workflows/pr-test.yml | 6 + python/sglang/srt/layers/ep_moe/__init__.py | 0 python/sglang/srt/layers/ep_moe/kernels.py | 349 +++++++++ python/sglang/srt/layers/ep_moe/layer.py | 661 ++++++++++++++++++ python/sglang/srt/managers/schedule_batch.py | 1 + .../sglang/srt/model_executor/model_runner.py | 1 + python/sglang/srt/models/deepseek_v2.py | 8 +- python/sglang/srt/models/mixtral.py | 18 +- python/sglang/srt/server_args.py | 23 + test/srt/test_moe_ep.py | 113 +++ 10 files changed, 1172 insertions(+), 8 deletions(-) create mode 100644 python/sglang/srt/layers/ep_moe/__init__.py create mode 100644 python/sglang/srt/layers/ep_moe/kernels.py create mode 100644 python/sglang/srt/layers/ep_moe/layer.py create mode 100644 test/srt/test_moe_ep.py diff --git a/.github/workflows/pr-test.yml b/.github/workflows/pr-test.yml index 59f0006e128..49c6ec88327 100644 --- a/.github/workflows/pr-test.yml +++ b/.github/workflows/pr-test.yml @@ -105,6 +105,12 @@ jobs: cd test/srt python3 test_update_weights_from_distributed.py + - name: Evaluate MoE EP accuracy (TP=2) + timeout-minutes: 10 + run: | + cd test/srt + python3 test_moe_ep.py + performance-test-1-gpu-part-1: if: github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request' runs-on: 1-gpu-runner diff --git a/python/sglang/srt/layers/ep_moe/__init__.py b/python/sglang/srt/layers/ep_moe/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/python/sglang/srt/layers/ep_moe/kernels.py b/python/sglang/srt/layers/ep_moe/kernels.py new file mode 100644 index 00000000000..e0486891aa7 --- /dev/null +++ b/python/sglang/srt/layers/ep_moe/kernels.py @@ -0,0 +1,349 @@ +import logging +from typing import Optional + +import torch +import triton +import triton.language as tl + +logger = logging.getLogger(__name__) + + +@triton.jit +def compute_seg_indptr_triton_kernel(reorder_topk_ids, seg_indptr, num_toks): + expert = tl.program_id(0) + low = 0 + high = num_toks - 1 + target_location = -1 + while low <= high: + mid = (low + high) // 2 + + if tl.load(reorder_topk_ids + mid) > expert: + high = mid - 1 + else: + low = mid + 1 + target_location = mid + tl.store(seg_indptr + expert + 1, target_location + 1) + + +@triton.jit +def compute_src2dst_triton_kernel( + reorder_ids, src2dst, num_toks, BLOCK_SIZE: tl.constexpr +): + pid = tl.program_id(axis=0) + dst_id = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + mask = dst_id < num_toks + src_id = tl.load(reorder_ids + dst_id, mask=mask) + tl.store(src2dst + src_id, dst_id, mask=mask) + + +def run_moe_ep_preproess(topk_ids: torch.Tensor, num_experts: int): + reorder_topk_ids, reorder_ids = torch.sort(topk_ids.view(-1), stable=True) + seg_indptr = torch.zeros(num_experts + 1, device=topk_ids.device, dtype=torch.int64) + src2dst = torch.empty(topk_ids.numel(), device=topk_ids.device, dtype=torch.int32) + + compute_seg_indptr_triton_kernel[(num_experts,)]( + reorder_topk_ids, seg_indptr, topk_ids.numel() + ) + + BLOCK_SIZE = 512 + grid = (triton.cdiv(topk_ids.numel(), BLOCK_SIZE),) + compute_src2dst_triton_kernel[grid]( + reorder_ids, src2dst, topk_ids.numel(), BLOCK_SIZE + ) + return reorder_topk_ids, src2dst, seg_indptr + + +@triton.jit +def pre_reorder_triton_kernel( + input_ptr, + gateup_input_ptr, + src2dst_ptr, + topk_ids_ptr, + a1_scales_ptr, + start_expert_id, + end_expert_id, + topk, + hidden_size, + BLOCK_SIZE: tl.constexpr, +): + OutDtype = gateup_input_ptr.dtype.element_ty + + src_idx = tl.program_id(0) + src2dst_ptr = src2dst_ptr + src_idx * topk + topk_ids_ptr = topk_ids_ptr + src_idx * topk + + src_ptr = input_ptr + src_idx * hidden_size + for idx in range(topk): + expert_id = tl.load(topk_ids_ptr + idx) + if expert_id >= start_expert_id and expert_id <= end_expert_id: + if a1_scales_ptr is not None: + scale = 1.0 / tl.load(a1_scales_ptr + expert_id - start_expert_id) + else: + scale = 1.0 + + dst_idx = tl.load(src2dst_ptr + idx) + dst_ptr = gateup_input_ptr + dst_idx * hidden_size + for start_offset in tl.range(0, hidden_size, BLOCK_SIZE): + offset = start_offset + tl.arange(0, BLOCK_SIZE) + mask = offset < hidden_size + in_data = tl.load(src_ptr + offset, mask=mask).to(tl.float32) + out_data = (in_data * scale).to(OutDtype) + tl.store(dst_ptr + offset, out_data, mask=mask) + + +@triton.jit +def silu_and_mul_triton_kernel( + gateup_output, + down_input, + hidden_size, + reorder_topk_ids, + scales, + start_expert_id, + end_expert_id, + BLOCK_SIZE: tl.constexpr, +): + InDtype = gateup_output.dtype.element_ty + OutDtype = down_input.dtype.element_ty + + half_hidden_size = hidden_size // 2 + + pid = tl.program_id(0) + expert_id = tl.load(reorder_topk_ids + pid) + if expert_id >= start_expert_id and expert_id <= end_expert_id: + gateup_output_ptr = gateup_output + pid * hidden_size + gate_output_ptr = gateup_output_ptr + up_output_ptr = gateup_output_ptr + half_hidden_size + down_input_ptr = down_input + pid * half_hidden_size + + if scales is not None: + scale = tl.load(scales + expert_id - start_expert_id) + scale = (1 / scale).to(InDtype) + else: + scale = 1 + + for start_offset in tl.range(0, half_hidden_size, BLOCK_SIZE): + offset = start_offset + tl.arange(0, BLOCK_SIZE) + mask = offset < half_hidden_size + + gate_output = tl.load(gate_output_ptr + offset, mask=mask).to(tl.float32) + up_output = tl.load(up_output_ptr + offset, mask=mask) + + # silu & mul & quantize + gate_output = gate_output * tl.sigmoid(gate_output) + gate_output = gate_output.to(InDtype) + + silu_mul_output = gate_output * up_output * scale + silu_mul_output = silu_mul_output.to(OutDtype) + tl.store(down_input_ptr + offset, silu_mul_output, mask=mask) + + +@triton.jit +def post_reorder_triton_kernel( + down_output_ptr, + output_ptr, + src2dst_ptr, + topk_ids_ptr, + topk_weights_ptr, + start_expert_id, + end_expert_id, + topk, + hidden_size, + BLOCK_SIZE: tl.constexpr, +): + InDtype = down_output_ptr.dtype.element_ty + + src_idx = tl.program_id(0) + src2dst_ptr = src2dst_ptr + src_idx * topk + topk_ids_ptr = topk_ids_ptr + src_idx * topk + topk_weights_ptr = topk_weights_ptr + src_idx * topk + + computed = False + store_ptr = output_ptr + src_idx * hidden_size + for start_offset in tl.range(0, hidden_size, BLOCK_SIZE): + offset = start_offset + tl.arange(0, BLOCK_SIZE) + mask = offset < hidden_size + + sum_vec = tl.zeros([BLOCK_SIZE], dtype=InDtype) + for idx in range(topk): + expert_id = tl.load(topk_ids_ptr + idx) + if expert_id >= start_expert_id and expert_id <= end_expert_id: + computed = True + dst_idx = tl.load(src2dst_ptr + idx) + weigh_scale = tl.load(topk_weights_ptr + idx).to(InDtype) + load_ptr = down_output_ptr + dst_idx * hidden_size + in_data = tl.load(load_ptr + offset, mask=mask) + sum_vec += in_data * weigh_scale + tl.store(store_ptr + offset, sum_vec, mask=mask) + + if computed == False: + for start_offset in tl.range(0, hidden_size, BLOCK_SIZE): + offset = start_offset + tl.arange(0, BLOCK_SIZE) + mask = offset < hidden_size + tl.store( + store_ptr + offset, tl.zeros([BLOCK_SIZE], dtype=InDtype), mask=mask + ) + + +@triton.jit +def compute_m_range( + pid, + batch_size, + seg_indptr, + weight_indices, + m_num_tiles_indptr, + BLOCK_SIZE_M: tl.constexpr, +): + idx = 0 + for bs in range(batch_size): + tiles = tl.load(m_num_tiles_indptr + bs) + if pid >= tiles: + idx = bs + + idx_start = tl.load(m_num_tiles_indptr + idx) + + m_range_start = tl.load(seg_indptr + idx) + (pid - idx_start) * BLOCK_SIZE_M + m_range_end = min(tl.load(seg_indptr + idx + 1), m_range_start + BLOCK_SIZE_M) + expert_id = tl.load(weight_indices + idx) + return m_range_start, m_range_end, expert_id + + +@triton.jit +def grouped_gemm_triton_kernel( + a, + b, + c, + batch_size, + N, + K, + seg_indptr, + weight_indices, + m_num_tiles_indptr, + use_fp8_w8a8, + scale_a, + scale_b, + a_stride_0: tl.constexpr, + b_stride_0: tl.constexpr, + b_stride_1: tl.constexpr, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, +): + c_dtype = c.dtype.element_ty + + pid_m = tl.program_id(0) + pid_n = tl.program_id(1) + total_m_block = tl.load(m_num_tiles_indptr + batch_size) + if pid_m >= total_m_block: + return + + m_range_start, m_range_end, expert_id = compute_m_range( + pid_m, batch_size, seg_indptr, weight_indices, m_num_tiles_indptr, BLOCK_SIZE_M + ) + if m_range_end - m_range_start == 0: + return + + n_range_start = pid_n * BLOCK_SIZE_N + n_range_end = min(n_range_start + BLOCK_SIZE_N, N) + + offs_am = tl.arange(0, BLOCK_SIZE_M) + offs_bn = tl.arange(0, BLOCK_SIZE_N) + + offs_am = tl.where(offs_am < m_range_end - m_range_start, offs_am, 0) + offs_bn = tl.where(offs_bn < n_range_end - n_range_start, offs_bn, 0) + offs_am = tl.max_contiguous(tl.multiple_of(offs_am, BLOCK_SIZE_M), BLOCK_SIZE_M) + offs_bn = tl.max_contiguous(tl.multiple_of(offs_bn, BLOCK_SIZE_N), BLOCK_SIZE_N) + offs_k = tl.arange(0, BLOCK_SIZE_K) + + a_ptr = a + (m_range_start + offs_am[:, None]) * a_stride_0 + offs_k[None, :] + b_ptr = b + ( + (expert_id * b_stride_0) + + (n_range_start + offs_bn[:, None]) * b_stride_1 + + offs_k[None, :] + ) + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): + a_tile = tl.load( + a_ptr, mask=offs_k[None, :] < (K - k * BLOCK_SIZE_K), other=0.0 + ) + b_tile = tl.load( + b_ptr, mask=offs_k[None, :] < (K - k * BLOCK_SIZE_K), other=0.0 + ) + accumulator = tl.dot(a_tile, b_tile.T, accumulator) + a_ptr += BLOCK_SIZE_K + b_ptr += BLOCK_SIZE_K + + if use_fp8_w8a8: + scale_a_value = tl.load(scale_a + expert_id) + scale_b_value = tl.load(scale_b + expert_id) + accumulator *= scale_a_value * scale_b_value + c_tile = accumulator.to(c_dtype) + + offs_cm = m_range_start + tl.arange(0, BLOCK_SIZE_M) + offs_cn = n_range_start + tl.arange(0, BLOCK_SIZE_N) + c_ptr = c + offs_cm[:, None] * N + offs_cn[None, :] + c_mask = (offs_cm[:, None] < m_range_end) & (offs_cn[None, :] < n_range_end) + tl.store(c_ptr, c_tile, mask=c_mask) + + +@triton.jit +def compute_m_num_tiles_indptr( + m_num_tiles_indptr, seg_indptr, batch_size: tl.constexpr, BLOCK_SIZE_M: tl.constexpr +): + for bs in range(batch_size): + m = tl.load(seg_indptr + bs + 1) - tl.load(seg_indptr + bs) + cur_num_tiles = tl.cdiv(m, BLOCK_SIZE_M) + pre_num_tiles = tl.load(m_num_tiles_indptr + bs) + tl.store(m_num_tiles_indptr + bs + 1, pre_num_tiles + cur_num_tiles) + + +def grouped_gemm_triton( + a: torch.Tensor, + b: torch.Tensor, + c: torch.Tensor, + batch_size: int, + weight_column_major: bool, + seg_indptr: Optional[torch.Tensor] = None, + weight_indices: Optional[torch.Tensor] = None, + use_fp8_w8a8: bool = False, + scale_a: torch.Tensor = None, + scale_b: torch.Tensor = None, +): + assert weight_column_major == True # TODO: more + if use_fp8_w8a8: + assert scale_a is not None and scale_b is not None + + config = { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + } + + m_num_tiles_indptr = torch.zeros(batch_size + 1, device=a.device, dtype=torch.int64) + compute_m_num_tiles_indptr[(1,)]( + m_num_tiles_indptr, seg_indptr, batch_size, config["BLOCK_SIZE_M"] + ) + + grid = lambda META: ( + triton.cdiv(a.size(0), META["BLOCK_SIZE_M"]) + batch_size, + triton.cdiv(b.size(1), META["BLOCK_SIZE_N"]), + ) + + grouped_gemm_triton_kernel[grid]( + a, + b, + c, + batch_size, + b.size(1), + b.size(2), + seg_indptr, + weight_indices, + m_num_tiles_indptr, + use_fp8_w8a8, + scale_a, + scale_b, + a.stride(0), + b.stride(0), + b.stride(1), + **config, + ) + return c diff --git a/python/sglang/srt/layers/ep_moe/layer.py b/python/sglang/srt/layers/ep_moe/layer.py new file mode 100644 index 00000000000..eca119845a7 --- /dev/null +++ b/python/sglang/srt/layers/ep_moe/layer.py @@ -0,0 +1,661 @@ +import logging +from typing import Callable, List, Optional, Tuple + +import torch +from torch.nn import Module +from vllm import _custom_ops as ops +from vllm.distributed import ( + get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size, +) +from vllm.model_executor.custom_op import CustomOp +from vllm.model_executor.layers.quantization.fp8 import Fp8Config, Fp8MoEMethod + +from sglang.srt.layers.custom_op_util import register_custom_op +from sglang.srt.layers.ep_moe.kernels import ( + grouped_gemm_triton, + post_reorder_triton_kernel, + pre_reorder_triton_kernel, + run_moe_ep_preproess, + silu_and_mul_triton_kernel, +) +from sglang.srt.layers.fused_moe_triton.fused_moe import fused_topk, grouped_topk +from sglang.srt.layers.fused_moe_triton.layer import FusedMoEMethodBase +from sglang.srt.layers.quantization.base_config import ( + QuantizationConfig, + QuantizeMethodBase, +) +from sglang.srt.utils import is_hip, set_weight_attrs + +logger = logging.getLogger(__name__) + + +class GroupedGemmRunner(torch.nn.Module): + flashinfer_gemm_warpper = None + + def __init__(self, device, use_flashinfer: bool = False): + super().__init__() + self.device = device + self.use_flashinfer = use_flashinfer + if self.use_flashinfer and GroupedGemmRunner.flashinfer_gemm_warpper is None: + GroupedGemmRunner._init_flashinfer_wrapper(device) + + @classmethod + def _init_flashinfer_wrapper(cls, device): + from flashinfer import SegmentGEMMWrapper + + workspace_buffer = torch.empty( + 128 * 1024 * 1024, dtype=torch.int8, device=device + ) + cls.flashinfer_gemm_warpper = SegmentGEMMWrapper(workspace_buffer) + + # c = a * b + def forward( + self, + a: torch.Tensor, + b: torch.Tensor, + c: torch.Tensor, + batch_size: int, + weight_column_major: bool, + seg_indptr: Optional[torch.Tensor] = None, + weight_indices: Optional[torch.Tensor] = None, + use_fp8_w8a8: bool = False, + scale_a: torch.Tensor = None, + scale_b: torch.Tensor = None, + ): + if self.use_flashinfer: + # TODO: flashinfer + assert False + assert GroupedGemmRunner.flashinfer_gemm_warpper is not None + c = GroupedGemmRunner.flashinfer_gemm_warpper.run( + x=a, + weights=b, + batch_size=batch_size, + weight_column_major=weight_column_major, + seg_indptr=seg_indptr, + weight_indices=weight_indices, + ) + else: + assert weight_column_major == True + c = grouped_gemm_triton( + a, + b, + c, + batch_size, + weight_column_major, + seg_indptr, + weight_indices, + use_fp8_w8a8, + scale_a, + scale_b, + ) + return c + + +class EPMoE(torch.nn.Module): + """ + MoE Expert Parallel Impl + + + """ + + def __init__( + self, + num_experts: int, + top_k: int, + hidden_size: int, + intermediate_size: int, + params_dtype: Optional[torch.dtype] = None, + renormalize: bool = True, + use_grouped_topk: bool = False, + num_expert_group: Optional[int] = None, + topk_group: Optional[int] = None, + quant_config: Optional[QuantizationConfig] = None, + tp_size: Optional[int] = None, + prefix: str = "", + ): + super().__init__() + + if params_dtype is None: + params_dtype = torch.get_default_dtype() + + self.tp_size = ( + tp_size if tp_size is not None else get_tensor_model_parallel_world_size() + ) + self.tp_rank = get_tensor_model_parallel_rank() + + self.num_experts = num_experts + assert self.num_experts % self.tp_size == 0 + self.num_experts_per_partition = self.num_experts // self.tp_size + self.start_expert_id = self.tp_rank * self.num_experts_per_partition + self.end_expert_id = self.start_expert_id + self.num_experts_per_partition - 1 + + self.top_k = top_k + self.intermediate_size = intermediate_size + self.renormalize = renormalize + self.use_grouped_topk = use_grouped_topk + if self.use_grouped_topk: + assert num_expert_group is not None and topk_group is not None + self.num_expert_group = num_expert_group + self.topk_group = topk_group + + if quant_config is None: + self.quant_method: Optional[QuantizeMethodBase] = UnquantizedEPMoEMethod() + self.use_fp8_w8a8 = False + self.activation_scheme = None + else: + self.quant_method: Optional[QuantizeMethodBase] = Fp8EPMoEMethod( + quant_config + ) + self.use_fp8_w8a8 = True + self.fp8_dtype = torch.float8_e4m3fn + self.activation_scheme = quant_config.activation_scheme + + self.quant_method.create_weights( + layer=self, + num_experts_per_partition=self.num_experts_per_partition, + hidden_size=hidden_size, + intermediate_size=self.intermediate_size, + params_dtype=params_dtype, + weight_loader=self.weight_loader, + ) + + self.grouped_gemm_runner = None + + def forward(self, hidden_states: torch.Tensor, router_logits: torch.Tensor): + assert self.quant_method is not None + + if self.grouped_gemm_runner is None: + self.grouped_gemm_runner = GroupedGemmRunner( + hidden_states.device, use_flashinfer=False # TODO: use flashinfer + ) + + topk_weights, topk_ids = self.select_experts( + hidden_states, + router_logits, + self.top_k, + self.renormalize, + self.topk_group, + self.num_expert_group, + ) + + reorder_topk_ids, src2dst, seg_indptr = run_moe_ep_preproess( + topk_ids, self.num_experts + ) + + gateup_input = torch.empty( + (int(hidden_states.shape[0] * self.top_k), hidden_states.shape[1]), + device=hidden_states.device, + dtype=self.fp8_dtype if self.use_fp8_w8a8 else hidden_states.dtype, + ) + if self.activation_scheme == "dynamic": + max_value = ( + torch.max(hidden_states) + .repeat(self.num_experts_per_partition) + .to(torch.float32) + ) + self.w13_input_scale = max_value / torch.finfo(self.fp8_dtype).max + + # PreReorder + pre_reorder_triton_kernel[(hidden_states.shape[0],)]( + hidden_states, + gateup_input, + src2dst, + topk_ids, + self.w13_input_scale, + self.start_expert_id, + self.end_expert_id, + self.top_k, + hidden_states.shape[1], + BLOCK_SIZE=512, + ) + + seg_indptr_cur_rank = seg_indptr[self.start_expert_id : self.end_expert_id + 2] + weight_indices_cur_rank = torch.arange( + 0, + self.num_experts_per_partition, + device=hidden_states.device, + dtype=torch.int64, + ) + # GroupGemm-0 + gateup_output = torch.empty( + gateup_input.shape[0], + self.w13_weight.shape[1], + device=hidden_states.device, + dtype=hidden_states.dtype, + ) + gateup_output = self.grouped_gemm_runner( + a=gateup_input, + b=self.w13_weight, + c=gateup_output, + batch_size=self.num_experts_per_partition, + weight_column_major=True, + seg_indptr=seg_indptr_cur_rank, + weight_indices=weight_indices_cur_rank, + use_fp8_w8a8=self.use_fp8_w8a8, + scale_a=self.w13_input_scale, + scale_b=self.w13_weight_scale, + ) + + # Act + down_input = torch.empty( + gateup_output.shape[0], + gateup_output.shape[1] // 2, + device=gateup_output.device, + dtype=self.fp8_dtype if self.use_fp8_w8a8 else hidden_states.dtype, + ) + if self.w2_input_scale is None: + self.w2_input_scale = torch.ones( + self.num_experts_per_partition, + dtype=torch.float32, + device=hidden_states.device, + ) + silu_and_mul_triton_kernel[(gateup_output.shape[0],)]( + gateup_output, + down_input, + gateup_output.shape[1], + reorder_topk_ids, + self.w2_input_scale, + self.start_expert_id, + self.end_expert_id, + BLOCK_SIZE=512, + ) + + # GroupGemm-1 + down_output = torch.empty( + down_input.shape[0], + self.w2_weight.shape[1], + device=hidden_states.device, + dtype=hidden_states.dtype, + ) + down_output = self.grouped_gemm_runner( + a=down_input, + b=self.w2_weight, + c=down_output, + batch_size=self.num_experts_per_partition, + weight_column_major=True, + seg_indptr=seg_indptr_cur_rank, + weight_indices=weight_indices_cur_rank, + use_fp8_w8a8=self.use_fp8_w8a8, + scale_a=self.w2_input_scale, + scale_b=self.w2_weight_scale, + ) + + # PostReorder + output = torch.empty_like(hidden_states) + post_reorder_triton_kernel[(hidden_states.size(0),)]( + down_output, + output, + src2dst, + topk_ids, + topk_weights, + self.start_expert_id, + self.end_expert_id, + self.top_k, + hidden_states.size(1), + BLOCK_SIZE=512, + ) + return output + + def select_experts( + self, + hidden_states: torch.Tensor, + router_logits: torch.Tensor, + top_k: int, + renormalize: bool, + topk_group: Optional[int] = None, + num_expert_group: Optional[int] = None, + ): + if self.use_grouped_topk: + assert topk_group is not None + assert num_expert_group is not None + topk_weights, topk_ids = grouped_topk( + hidden_states=hidden_states, + gating_output=router_logits, + topk=top_k, + renormalize=renormalize, + num_expert_group=num_expert_group, + topk_group=topk_group, + ) + else: + topk_weights, topk_ids = fused_topk( + hidden_states=hidden_states, + gating_output=router_logits, + topk=top_k, + renormalize=renormalize, + ) + return topk_weights, topk_ids.to(torch.int32) + + @classmethod + def make_expert_params_mapping( + cls, + ckpt_gate_proj_name: str, + ckpt_down_proj_name: str, + ckpt_up_proj_name: str, + num_experts: int, + ) -> List[Tuple[str, str, int, str]]: + + return [ + # (param_name, weight_name, expert_id, shard_id) + ( + ( + "experts.w13_" + if weight_name in [ckpt_gate_proj_name, ckpt_up_proj_name] + else "experts.w2_" + ), + f"experts.{expert_id}.{weight_name}.", + expert_id, + shard_id, + ) + for expert_id in range(num_experts) + for shard_id, weight_name in [ + ("w1", ckpt_gate_proj_name), + ("w2", ckpt_down_proj_name), + ("w3", ckpt_up_proj_name), + ] + ] + + def weight_loader( + self, + param: torch.nn.Parameter, + loaded_weight: torch.Tensor, + weight_name: str, + shard_id: str, + expert_id: int, + ) -> None: + if expert_id < self.start_expert_id or expert_id > self.end_expert_id: + return + expert_id = expert_id - self.start_expert_id + + if shard_id not in ("w1", "w2", "w3"): + raise ValueError( + f"shard_id must be ['w1','w2','w3'] but " f"got {shard_id}." + ) + + # Special case for fp8 scales. + if "scale" in weight_name: + self._load_fp8_scale( + param.data, loaded_weight, weight_name, shard_id, expert_id + ) + return + + expert_data = param.data[expert_id] + if shard_id == "w2": + param.data[expert_id] = loaded_weight + elif shard_id == "w1": + param.data[expert_id][: self.intermediate_size, :] = loaded_weight + elif shard_id == "w3": + param.data[expert_id][self.intermediate_size :, :] = loaded_weight + else: + raise ValueError(f"Expected shard_id w1,w2 or w3 but got {shard_id}") + + def _load_fp8_scale( + self, + param: torch.nn.Parameter, + loaded_weight: torch.Tensor, + weight_name: str, + shard_id: str, + expert_id: int, + ) -> None: + param_data = param.data + + # Input scales can be loaded directly and should be equal. + if "input_scale" in weight_name: + if ( + param_data[expert_id] != 1 + and (param_data[expert_id] - loaded_weight).abs() > 1e-5 + ): + raise ValueError( + "input_scales of w1 and w3 of a layer " + f"must be equal. But got {param_data[expert_id]} " + f"vs. {loaded_weight}" + ) + param_data[expert_id] = loaded_weight + # Weight scales + elif "weight_scale" in weight_name: + # If we are in merged column case (gate_up_proj) + if shard_id in ("w1", "w3"): + # We have to keep the weight scales of w1 and w3 because + # we need to re-quantize w1/w3 weights after weight loading. + idx = 0 if shard_id == "w1" else 1 + param_data[expert_id][idx] = loaded_weight + # If we are in the row parallel case (down_proj) + else: + param_data[expert_id] = loaded_weight + + +@register_custom_op("sglang_unquantized_ep_moe") +class UnquantizedEPMoEMethod(FusedMoEMethodBase, CustomOp): + def create_weights( + self, + layer: torch.nn.Module, + num_experts_per_partition: int, + hidden_size: int, + intermediate_size: int, + params_dtype: torch.dtype, + **extra_weight_attrs, + ): + # Fused gate_up_proj (column parallel) + w13_weight = torch.nn.Parameter( + torch.empty( + num_experts_per_partition, + 2 * intermediate_size, + hidden_size, + dtype=params_dtype, + ), + requires_grad=False, + ) + layer.register_parameter("w13_weight", w13_weight) + set_weight_attrs(w13_weight, extra_weight_attrs) + + # down_proj (row parallel) + w2_weight = torch.nn.Parameter( + torch.empty( + num_experts_per_partition, + hidden_size, + intermediate_size, + dtype=params_dtype, + ), + requires_grad=False, + ) + layer.register_parameter("w2_weight", w2_weight) + set_weight_attrs(w2_weight, extra_weight_attrs) + + # scale + ones_tensor = torch.ones(num_experts_per_partition, dtype=torch.float32) + w13_input_scale = torch.nn.Parameter( + ones_tensor, + requires_grad=False, + ) + layer.register_parameter("w13_input_scale", w13_input_scale) + set_weight_attrs(w13_input_scale, extra_weight_attrs) + + w2_input_scale = torch.nn.Parameter( + ones_tensor, + requires_grad=False, + ) + layer.register_parameter("w2_input_scale", w2_input_scale) + set_weight_attrs(w2_input_scale, extra_weight_attrs) + + w13_weight_scale = torch.nn.Parameter( + ones_tensor, + requires_grad=False, + ) + layer.register_parameter("w13_weight_scale", w13_weight_scale) + set_weight_attrs(w13_weight_scale, extra_weight_attrs) + + w2_weight_scale = torch.nn.Parameter( + ones_tensor, + requires_grad=False, + ) + layer.register_parameter("w2_weight_scale", w2_weight_scale) + set_weight_attrs(w2_weight_scale, extra_weight_attrs) + + def apply( + self, + layer: torch.nn.Module, + x: torch.Tensor, + router_logits: torch.Tensor, + top_k: int, + renormalize: bool, + use_grouped_topk: bool, + topk_group: Optional[int] = None, + num_expert_group: Optional[int] = None, + custom_routing_function: Optional[Callable] = None, + ) -> torch.Tensor: + raise NotImplementedError + + +class Fp8EPMoEMethod(Fp8MoEMethod): + """MoE method for FP8. + Supports loading FP8 checkpoints with static weight scale and + dynamic/static activation scale. + + Args: + quant_config: The quantization config. + """ + + def __init__(self, quant_config: Fp8Config): + self.quant_config = quant_config + + def create_weights( + self, + layer: Module, + num_experts_per_partition: int, + hidden_size: int, + intermediate_size: int, + params_dtype: torch.dtype, + **extra_weight_attrs, + ): + + if self.quant_config.is_checkpoint_fp8_serialized: + params_dtype = torch.float8_e4m3fn + + # WEIGHTS + w13_weight = torch.nn.Parameter( + torch.empty( + num_experts_per_partition, + 2 * intermediate_size, + hidden_size, + dtype=params_dtype, + ), + requires_grad=False, + ) + layer.register_parameter("w13_weight", w13_weight) + set_weight_attrs(w13_weight, extra_weight_attrs) + + w2_weight = torch.nn.Parameter( + torch.empty( + num_experts_per_partition, + hidden_size, + intermediate_size, + dtype=params_dtype, + ), + requires_grad=False, + ) + layer.register_parameter("w2_weight", w2_weight) + set_weight_attrs(w2_weight, extra_weight_attrs) + + # WEIGHT_SCALES + # Allocate 2 scales for w1 and w3 respectively. + w13_weight_scale = torch.nn.Parameter( + torch.ones(num_experts_per_partition, 2, dtype=torch.float32), + requires_grad=False, + ) + layer.register_parameter("w13_weight_scale", w13_weight_scale) + + w2_weight_scale = torch.nn.Parameter( + torch.ones(num_experts_per_partition, dtype=torch.float32), + requires_grad=False, + ) + layer.register_parameter("w2_weight_scale", w2_weight_scale) + # Add the quantization method used (per tensor/grouped/channel) + # to ensure the weight scales are loaded in properly + extra_weight_attrs.update({"quant_method": "tensor"}) + # If loading fp8 checkpoint, pass the weight loaders. + # If loading an fp16 checkpoint, do not (we will quantize in + # process_weights_after_loading() + if self.quant_config.is_checkpoint_fp8_serialized: + set_weight_attrs(w13_weight_scale, extra_weight_attrs) + set_weight_attrs(w2_weight_scale, extra_weight_attrs) + + # INPUT_SCALES + if self.quant_config.activation_scheme == "static": + if not self.quant_config.is_checkpoint_fp8_serialized: + raise ValueError( + "Found static activation scheme for checkpoint that " + "was not serialized fp8." + ) + + w13_input_scale = torch.nn.Parameter( + torch.ones(num_experts_per_partition, dtype=torch.float32), + requires_grad=False, + ) + layer.register_parameter("w13_input_scale", w13_input_scale) + set_weight_attrs(w13_input_scale, extra_weight_attrs) + + w2_input_scale = torch.nn.Parameter( + torch.ones(num_experts_per_partition, dtype=torch.float32), + requires_grad=False, + ) + layer.register_parameter("w2_input_scale", w2_input_scale) + set_weight_attrs(w2_input_scale, extra_weight_attrs) + + else: + layer.w13_input_scale = None + layer.w2_input_scale = None + + def process_weights_after_loading(self, layer: Module) -> None: + + # If checkpoint is fp16, quantize in place. + if not self.quant_config.is_checkpoint_fp8_serialized: + # If rocm, use float8_e4m3fnuz as dtype + fp8_dtype = torch.float8_e4m3fnuz if is_hip() else torch.float8_e4m3fn + w13_weight = torch.empty_like(layer.w13_weight.data, dtype=fp8_dtype) + w2_weight = torch.empty_like(layer.w2_weight.data, dtype=fp8_dtype) + + layer.w13_weight_scale = torch.nn.Parameter( + torch.ones( + layer.num_experts_per_partition, + dtype=torch.float32, + device=w13_weight.device, + ), + requires_grad=False, + ) + + for expert in range(layer.num_experts_per_partition): + w13_weight[expert, :, :], layer.w13_weight_scale[expert] = ( + ops.scaled_fp8_quant(layer.w13_weight.data[expert, :, :]) + ) + w2_weight[expert, :, :], layer.w2_weight_scale[expert] = ( + ops.scaled_fp8_quant(layer.w2_weight.data[expert, :, :]) + ) + layer.w13_weight = torch.nn.Parameter(w13_weight, requires_grad=False) + layer.w2_weight = torch.nn.Parameter(w2_weight, requires_grad=False) + return + + # If checkpoint is fp8, we need to handle that the + # MoE kernels require single activation scale and single weight + # scale for w13 per expert. + else: + if self.quant_config.activation_scheme == "static": + if layer.w13_input_scale is None or layer.w2_input_scale is None: + raise ValueError( + "QuantConfig has static quantization, but found " + "activation scales are None." + ) + return + + def apply( + self, + layer: torch.nn.Module, + x: torch.Tensor, + router_logits: torch.Tensor, + top_k: int, + renormalize: bool, + use_grouped_topk: bool, + topk_group: Optional[int] = None, + num_expert_group: Optional[int] = None, + custom_routing_function: Optional[Callable] = None, + ) -> torch.Tensor: + raise NotImplementedError diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py index 28677efeac4..5855d4248ff 100644 --- a/python/sglang/srt/managers/schedule_batch.py +++ b/python/sglang/srt/managers/schedule_batch.py @@ -58,6 +58,7 @@ "torchao_config": ServerArgs.torchao_config, "enable_nan_detection": ServerArgs.enable_nan_detection, "enable_dp_attention": ServerArgs.enable_dp_attention, + "enable_ep_moe": ServerArgs.enable_ep_moe, } diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index 4eaedbccbff..3f0cbecac15 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -141,6 +141,7 @@ def __init__( "torchao_config": server_args.torchao_config, "enable_nan_detection": server_args.enable_nan_detection, "enable_dp_attention": server_args.enable_dp_attention, + "enable_ep_moe": server_args.enable_ep_moe, } ) diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py index 80db9a35c71..63cea92c289 100644 --- a/python/sglang/srt/models/deepseek_v2.py +++ b/python/sglang/srt/models/deepseek_v2.py @@ -31,6 +31,7 @@ from vllm.model_executor.layers.rotary_embedding import get_rope from sglang.srt.layers.activation import SiluAndMul +from sglang.srt.layers.ep_moe.layer import EPMoE from sglang.srt.layers.fused_moe_triton import FusedMoE from sglang.srt.layers.layernorm import RMSNorm from sglang.srt.layers.linear import ( @@ -113,12 +114,12 @@ def __init__( "Only silu is supported for now." ) - self.experts = FusedMoE( + MoEImpl = EPMoE if global_server_args_dict["enable_ep_moe"] else FusedMoE + self.experts = MoEImpl( num_experts=config.n_routed_experts, top_k=config.num_experts_per_tok, hidden_size=config.hidden_size, intermediate_size=config.moe_intermediate_size, - reduce_results=False, renormalize=config.norm_topk_prob, quant_config=quant_config, use_grouped_topk=True, @@ -834,7 +835,8 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): # Params for weights, fp8 weight scales, fp8 activation scales # (param_name, weight_name, expert_id, shard_id) - expert_params_mapping = FusedMoE.make_expert_params_mapping( + MoEImpl = EPMoE if global_server_args_dict["enable_ep_moe"] else FusedMoE + expert_params_mapping = MoEImpl.make_expert_params_mapping( ckpt_gate_proj_name="gate_proj", ckpt_down_proj_name="down_proj", ckpt_up_proj_name="up_proj", diff --git a/python/sglang/srt/models/mixtral.py b/python/sglang/srt/models/mixtral.py index f1ae1f57a3d..f3fad226091 100644 --- a/python/sglang/srt/models/mixtral.py +++ b/python/sglang/srt/models/mixtral.py @@ -21,9 +21,13 @@ import torch from torch import nn from transformers import MixtralConfig -from vllm.distributed import get_tensor_model_parallel_world_size +from vllm.distributed import ( + get_tensor_model_parallel_world_size, + tensor_model_parallel_all_reduce, +) from vllm.model_executor.layers.rotary_embedding import get_rope +from sglang.srt.layers.ep_moe.layer import EPMoE from sglang.srt.layers.fused_moe_triton import FusedMoE from sglang.srt.layers.layernorm import RMSNorm from sglang.srt.layers.linear import ( @@ -38,6 +42,7 @@ ParallelLMHead, VocabParallelEmbedding, ) +from sglang.srt.managers.schedule_batch import global_server_args_dict from sglang.srt.model_executor.forward_batch_info import ForwardBatch from sglang.srt.model_loader.weight_utils import default_weight_loader @@ -63,6 +68,7 @@ def __init__( prefix: str = "", ): super().__init__() + self.tp_size = get_tensor_model_parallel_world_size() self.hidden_size = hidden_size # Gate always runs at half / full precision for now. @@ -74,14 +80,13 @@ def __init__( quant_config=None, prefix=f"{prefix}.gate", ) - - self.experts = FusedMoE( + MoEImpl = EPMoE if global_server_args_dict["enable_ep_moe"] else FusedMoE + self.experts = MoEImpl( num_experts=num_experts, top_k=top_k, hidden_size=hidden_size, intermediate_size=intermediate_size, params_dtype=params_dtype, - reduce_results=True, renormalize=True, quant_config=quant_config, tp_size=tp_size, @@ -95,6 +100,8 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: # router_logits: (num_tokens, n_experts) router_logits, _ = self.gate(hidden_states) final_hidden_states = self.experts(hidden_states, router_logits) + if self.tp_size > 1: + final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states) return final_hidden_states.view(orig_shape) @@ -319,7 +326,8 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): # Params for weights, fp8 weight scales, fp8 activation scales # (param_name, weight_name, expert_id, shard_id) - expert_params_mapping = FusedMoE.make_expert_params_mapping( + MoEImpl = EPMoE if global_server_args_dict["enable_ep_moe"] else FusedMoE + expert_params_mapping = MoEImpl.make_expert_params_mapping( ckpt_gate_proj_name="w1", ckpt_down_proj_name="w2", ckpt_up_proj_name="w3", diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index 7b337500fd7..8719d919068 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -93,6 +93,8 @@ class ServerArgs: # Data parallelism dp_size: int = 1 load_balance_method: str = "round_robin" + # Expert parallelism + ep_size: int = 1 # Multi-node distributed serving dist_init_addr: Optional[str] = None @@ -130,6 +132,7 @@ class ServerArgs: disable_overlap_schedule: bool = False enable_mixed_chunk: bool = False enable_dp_attention: bool = False + enable_ep_moe: bool = False enable_torch_compile: bool = False torch_compile_max_bs: int = 32 cuda_graph_max_bs: Optional[int] = None @@ -216,6 +219,12 @@ def __post_init__(self): "Data parallel size is adjusted to be the same as tensor parallel size. " "Overlap scheduler is disabled." ) + # Expert parallelism + if self.enable_ep_moe: + self.ep_size = self.tp_size + logger.info( + f"EP MoE is enabled. The expert parallel size is adjusted to be the same as the tensor parallel size[{self.tp_size}]." + ) # GGUF if ( @@ -526,6 +535,14 @@ def add_cli_args(parser: argparse.ArgumentParser): "shortest_queue", ], ) + # Expert parallelism + parser.add_argument( + "--expert-parallel-size", + "--ep-size", + type=int, + default=ServerArgs.ep_size, + help="The expert parallelism size.", + ) # Multi-node distributed serving parser.add_argument( @@ -681,6 +698,11 @@ def add_cli_args(parser: argparse.ArgumentParser): action="store_true", help="Enabling data parallelism for attention and tensor parallelism for FFN. The dp size should be equal to the tp size. Currently only DeepSeek-V2 is supported.", ) + parser.add_argument( + "--enable-ep-moe", + action="store_true", + help="Enabling expert parallelism for moe. The ep size is equal to the tp size.", + ) parser.add_argument( "--enable-torch-compile", action="store_true", @@ -760,6 +782,7 @@ def add_cli_args(parser: argparse.ArgumentParser): def from_cli_args(cls, args: argparse.Namespace): args.tp_size = args.tensor_parallel_size args.dp_size = args.data_parallel_size + args.ep_size = args.expert_parallel_size attrs = [attr.name for attr in dataclasses.fields(cls)] return cls(**{attr: getattr(args, attr) for attr in attrs}) diff --git a/test/srt/test_moe_ep.py b/test/srt/test_moe_ep.py new file mode 100644 index 00000000000..4d9fd435edb --- /dev/null +++ b/test/srt/test_moe_ep.py @@ -0,0 +1,113 @@ +import unittest +from types import SimpleNamespace + +from sglang.srt.utils import kill_process_tree +from sglang.test.run_eval import run_eval +from sglang.test.test_utils import ( + DEFAULT_MLA_MODEL_NAME_FOR_TEST, + DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + DEFAULT_URL_FOR_TEST, + popen_launch_server, +) + + +class TestEpMoE(unittest.TestCase): + @classmethod + def setUpClass(cls): + cls.model = DEFAULT_MLA_MODEL_NAME_FOR_TEST + cls.base_url = DEFAULT_URL_FOR_TEST + cls.process = popen_launch_server( + cls.model, + cls.base_url, + timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + other_args=[ + "--trust-remote-code", + "--tp", + "2", + "--ep-size", + "2", + "--enable-ep-moe", + ], + ) + + @classmethod + def tearDownClass(cls): + kill_process_tree(cls.process.pid) + + def test_mmlu(self): + args = SimpleNamespace( + base_url=self.base_url, + model=self.model, + eval_name="mmlu", + num_examples=64, + num_threads=32, + ) + + metrics = run_eval(args) + assert metrics["score"] >= 0.5 + + def test_mgsm_en(self): + args = SimpleNamespace( + base_url=self.base_url, + model=self.model, + eval_name="mgsm_en", + num_examples=None, + num_threads=1024, + ) + + metrics = run_eval(args) + assert metrics["score"] >= 0.8 + + +class TestEpMoEFP8(unittest.TestCase): + @classmethod + def setUpClass(cls): + cls.model = DEFAULT_MLA_MODEL_NAME_FOR_TEST + cls.base_url = DEFAULT_URL_FOR_TEST + cls.process = popen_launch_server( + cls.model, + cls.base_url, + timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + other_args=[ + "--trust-remote-code", + "--tp", + "2", + "--ep-size", + "2", + "--enable-ep-moe", + "--quantization", + "fp8", + ], + ) + + @classmethod + def tearDownClass(cls): + kill_process_tree(cls.process.pid) + + def test_mmlu(self): + args = SimpleNamespace( + base_url=self.base_url, + model=self.model, + eval_name="mmlu", + num_examples=64, + num_threads=32, + ) + + metrics = run_eval(args) + assert metrics["score"] >= 0.5 + + def test_mgsm_en(self): + args = SimpleNamespace( + base_url=self.base_url, + model=self.model, + eval_name="mgsm_en", + num_examples=None, + num_threads=1024, + ) + + metrics = run_eval(args) + assert metrics["score"] >= 0.8 + + +if __name__ == "__main__": + unittest.main() From 84d96b3ae52ebf65baa6557647e09488b28eee3b Mon Sep 17 00:00:00 2001 From: Yineng Zhang Date: Fri, 6 Dec 2024 15:42:10 +0800 Subject: [PATCH 26/60] Move FP8 to SGLang (#2370) Co-authored-by: HaiShaw --- .../srt/layers/quantization/__init__.py | 4 +- python/sglang/srt/layers/quantization/fp8.py | 559 ++++++++++++++++++ 2 files changed, 561 insertions(+), 2 deletions(-) create mode 100644 python/sglang/srt/layers/quantization/fp8.py diff --git a/python/sglang/srt/layers/quantization/__init__.py b/python/sglang/srt/layers/quantization/__init__.py index f34a581d657..3e2078c4a4d 100644 --- a/python/sglang/srt/layers/quantization/__init__.py +++ b/python/sglang/srt/layers/quantization/__init__.py @@ -13,7 +13,6 @@ from vllm.model_executor.layers.quantization.deepspeedfp import DeepSpeedFPConfig from vllm.model_executor.layers.quantization.experts_int8 import ExpertsInt8Config from vllm.model_executor.layers.quantization.fbgemm_fp8 import FBGEMMFp8Config -from vllm.model_executor.layers.quantization.fp8 import Fp8Config, Fp8MoEMethod from vllm.model_executor.layers.quantization.gguf import GGUFConfig from vllm.model_executor.layers.quantization.gptq import GPTQConfig from vllm.model_executor.layers.quantization.gptq_marlin import GPTQMarlinConfig @@ -23,6 +22,7 @@ from vllm.model_executor.layers.quantization.tpu_int8 import Int8TpuConfig from sglang.srt.layers.quantization.base_config import QuantizationConfig +from sglang.srt.layers.quantization.fp8 import Fp8Config, Fp8MoEMethod QUANTIZATION_METHODS: Dict[str, Type[QuantizationConfig]] = { "aqlm": AQLMConfig, @@ -100,13 +100,13 @@ def fp8_moe_apply( def fp8_get_quant_method(self, layer, prefix): """Enhanced get_quant_method for FP8 config.""" from vllm.model_executor.layers.linear import LinearBase - from vllm.model_executor.layers.quantization.fp8 import Fp8LinearMethod from vllm.model_executor.layers.quantization.utils.quant_utils import ( is_layer_skipped, ) from sglang.srt.layers.fused_moe_triton.layer import FusedMoE from sglang.srt.layers.linear import UnquantizedLinearMethod + from sglang.srt.layers.quantization.fp8 import Fp8LinearMethod if isinstance(layer, LinearBase): if is_layer_skipped(prefix, self.ignored_layers): diff --git a/python/sglang/srt/layers/quantization/fp8.py b/python/sglang/srt/layers/quantization/fp8.py new file mode 100644 index 00000000000..acdce0b8cbd --- /dev/null +++ b/python/sglang/srt/layers/quantization/fp8.py @@ -0,0 +1,559 @@ +# Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/model_executor/layers/quantization/fp8.py + +import logging +from typing import Any, Callable, Dict, List, Optional + +import torch +from torch.nn import Module +from torch.nn.parameter import Parameter +from vllm import _custom_ops as ops +from vllm.model_executor.layers.linear import LinearBase +from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod +from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import ( + apply_fp8_marlin_linear, + prepare_fp8_layer_for_marlin, +) +from vllm.model_executor.layers.quantization.utils.quant_utils import is_layer_skipped +from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( + all_close_1d, + apply_fp8_linear, + convert_to_channelwise, + cutlass_fp8_supported, + per_tensor_dequantize, + requantize_with_max_scale, +) +from vllm.model_executor.parameter import ModelWeightParameter, PerTensorScaleParameter + +from sglang.srt.layers.fused_moe_triton import ( + FusedMoE, + FusedMoEMethodBase, + FusedMoeWeightScaleSupported, +) +from sglang.srt.layers.linear import LinearMethodBase, UnquantizedLinearMethod +from sglang.srt.layers.quantization.base_config import ( + QuantizationConfig, + QuantizeMethodBase, +) +from sglang.srt.layers.quantization.fp8_utils import normalize_e4m3fn_to_e4m3fnuz +from sglang.srt.utils import ( + get_bool_env_var, + is_hip, + print_warning_once, + set_weight_attrs, +) + +ACTIVATION_SCHEMES = ["static", "dynamic"] + +logger = logging.getLogger(__name__) + + +class Fp8Config(QuantizationConfig): + """Config class for FP8.""" + + def __init__( + self, + is_checkpoint_fp8_serialized: bool = False, + activation_scheme: str = "dynamic", + ignored_layers: Optional[List[str]] = None, + ) -> None: + self.is_checkpoint_fp8_serialized = is_checkpoint_fp8_serialized + if is_checkpoint_fp8_serialized: + logger.warning( + "Detected fp8 checkpoint. Please note that the " + "format is experimental and subject to change." + ) + if activation_scheme not in ACTIVATION_SCHEMES: + raise ValueError(f"Unsupported activation scheme {activation_scheme}") + self.activation_scheme = activation_scheme + self.ignored_layers = ignored_layers or [] + + @classmethod + def get_name(cls) -> str: + return "fp8" + + @classmethod + def get_supported_act_dtypes(cls) -> List[torch.dtype]: + return [torch.bfloat16, torch.half] + + @classmethod + def get_min_capability(cls) -> int: + return 80 + + @classmethod + def get_config_filenames(cls) -> List[str]: + return [] + + @classmethod + def from_config(cls, config: Dict[str, Any]) -> "Fp8Config": + quant_method = cls.get_from_keys(config, ["quant_method"]) + is_checkpoint_fp8_serialized = "fp8" in quant_method + activation_scheme = cls.get_from_keys(config, ["activation_scheme"]) + ignored_layers = cls.get_from_keys_or(config, ["ignored_layers"], None) + return cls( + is_checkpoint_fp8_serialized=is_checkpoint_fp8_serialized, + activation_scheme=activation_scheme, + ignored_layers=ignored_layers, + ) + + def get_quant_method( + self, layer: torch.nn.Module, prefix: str + ) -> Optional["QuantizeMethodBase"]: + from vllm.attention.layer import Attention # Avoid circular import + + if isinstance(layer, LinearBase): + if is_layer_skipped(prefix, self.ignored_layers): + return UnquantizedLinearMethod() + return Fp8LinearMethod(self) + elif isinstance(layer, FusedMoE): + return Fp8MoEMethod(self) + elif isinstance(layer, Attention): + return Fp8KVCacheMethod(self) + return None + + def get_scaled_act_names(self) -> List[str]: + return [] + + +class Fp8LinearMethod(LinearMethodBase): + """Linear method for FP8. + Supports loading FP8 checkpoints with static weight scale and + dynamic/static activation scale. + + Also supports loading quantized FP16/BF16 model checkpoints with dynamic + activation scaling. The weight scaling factor will be initialized after + the model weights are loaded. + + Limitations: + 1. Only support per-tensor quantization due to torch._scaled_mm support. + 2. Only support float8_e4m3fn data type due to the limitation of + torch._scaled_mm (https://github.com/pytorch/pytorch/blob/2e48b39603411a41c5025efbe52f89560b827825/aten/src/ATen/native/cuda/Blas.cpp#L854-L856) + + Args: + quant_config: The quantization config. + """ + + def __init__(self, quant_config: Fp8Config): + self.quant_config = quant_config + self.cutlass_fp8_supported = cutlass_fp8_supported() + + # For GPUs that lack FP8 hardware support, we can leverage the Marlin + # kernel for fast weight-only FP8 quantization + self.use_marlin = get_bool_env_var("SGLANG_FORCE_FP8_MARLIN") + # Disable marlin for ROCm + if is_hip(): + self.use_marlin = False + + def create_weights( + self, + layer: torch.nn.Module, + input_size_per_partition: int, + output_partition_sizes: List[int], + input_size: int, + output_size: int, + params_dtype: torch.dtype, + **extra_weight_attrs, + ): + del input_size, output_size + output_size_per_partition = sum(output_partition_sizes) + weight_loader = extra_weight_attrs.get("weight_loader") + + layer.logical_widths = output_partition_sizes + + layer.input_size_per_partition = input_size_per_partition + layer.output_size_per_partition = output_size_per_partition + layer.orig_dtype = params_dtype + + # WEIGHT + weight_dtype = ( + torch.float8_e4m3fn + if self.quant_config.is_checkpoint_fp8_serialized + else params_dtype + ) + + weight = ModelWeightParameter( + data=torch.empty( + output_size_per_partition, input_size_per_partition, dtype=weight_dtype + ), + input_dim=1, + output_dim=0, + weight_loader=weight_loader, + ) + layer.register_parameter("weight", weight) + + # If checkpoint is serialized fp8, load them. + # Otherwise, wait until process_weights_after_loading. + if self.quant_config.is_checkpoint_fp8_serialized: + # WEIGHT SCALE + scale = PerTensorScaleParameter( + data=torch.empty(len(output_partition_sizes), dtype=torch.float32), + weight_loader=weight_loader, + ) + + scale[:] = torch.finfo(torch.float32).min + layer.register_parameter("weight_scale", scale) + + # INPUT ACTIVATION SCALE + if self.quant_config.activation_scheme == "static": + scale = PerTensorScaleParameter( + data=torch.empty(len(output_partition_sizes), dtype=torch.float32), + weight_loader=weight_loader, + ) + + scale[:] = torch.finfo(torch.float32).min + layer.register_parameter("input_scale", scale) + else: + layer.register_parameter("input_scale", None) + + def process_weights_after_loading(self, layer: Module) -> None: + layer.weight = torch.nn.Parameter(layer.weight.data, requires_grad=False) + # If checkpoint not serialized fp8, quantize the weights. + if not self.quant_config.is_checkpoint_fp8_serialized: + qweight, weight_scale = ops.scaled_fp8_quant(layer.weight, scale=None) + + # If using marlin (w8a16), kernel uses channelwise weights, + # so extend the weight scales to be channelwise. + if self.use_marlin: + assert weight_scale.numel() == 1 + weight_scale = convert_to_channelwise( + weight_scale.expand(len(layer.logical_widths)), layer.logical_widths + ) + + # Update the layer with the new values. + layer.weight = Parameter(qweight.t(), requires_grad=False) + layer.weight_scale = Parameter(weight_scale, requires_grad=False) + layer.input_scale = None + + # If checkpoint is fp8, handle that there are N scales for N + # shards in a fused module + else: + layer.weight_scale = torch.nn.Parameter( + layer.weight_scale.data, requires_grad=False + ) + if self.quant_config.activation_scheme == "static": + layer.input_scale = torch.nn.Parameter( + layer.input_scale.data, requires_grad=False + ) + # If using marlin (w8a16), kernel uses channelwise weights, + # so extend the weight scales to be channelwise. + if self.use_marlin: + weight = layer.weight + weight_scale = convert_to_channelwise( + layer.weight_scale, layer.logical_widths + ) + + # If using w8a8, torch._scaled_mm needs per tensor, so + # requantize the logical shards as a single weight. + else: + # Dequant -> Quant with max scale so we can run per tensor. + weight = layer.weight + weight_scale = layer.weight_scale + + # If ROCm, normalize the weights and scales to e4m3fnuz + if is_hip(): + weight, weight_scale, input_scale = normalize_e4m3fn_to_e4m3fnuz( + weight=weight, + weight_scale=weight_scale, + input_scale=layer.input_scale, + ) + if input_scale is not None: + layer.input_scale = Parameter(input_scale, requires_grad=False) + + weight_scale, weight = requantize_with_max_scale( + weight=weight, + weight_scale=weight_scale, + logical_widths=layer.logical_widths, + ) + + # Update layer with new values. + layer.weight = Parameter(weight.t(), requires_grad=False) + layer.weight_scale = Parameter(weight_scale, requires_grad=False) + if self.quant_config.activation_scheme == "static": + layer.input_scale = Parameter( + layer.input_scale.max(), requires_grad=False + ) + + if self.use_marlin: + prepare_fp8_layer_for_marlin(layer) + # Activations not quantized for marlin. + del layer.input_scale + + def apply( + self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + + if self.use_marlin: + return apply_fp8_marlin_linear( + input=x, + weight=layer.weight, + weight_scale=layer.weight_scale, + workspace=layer.workspace, + size_n=layer.output_size_per_partition, + size_k=layer.input_size_per_partition, + bias=bias, + ) + + return apply_fp8_linear( + input=x, + weight=layer.weight, + weight_scale=layer.weight_scale, + input_scale=layer.input_scale, + bias=bias, + cutlass_fp8_supported=self.cutlass_fp8_supported, + use_per_token_if_dynamic=False, + ) + + +class Fp8MoEMethod(FusedMoEMethodBase): + """MoE method for FP8. + Supports loading FP8 checkpoints with static weight scale and + dynamic/static activation scale. + + Also supports loading quantized FP16/BF16 model checkpoints with dynamic + activation scaling. The weight scaling factor will be initialized after + the model weights are loaded. + + Args: + quant_config: The quantization config. + """ + + def __init__(self, quant_config: Fp8Config): + self.quant_config = quant_config + + def create_weights( + self, + layer: Module, + num_experts: int, + hidden_size: int, + intermediate_size: int, + params_dtype: torch.dtype, + **extra_weight_attrs, + ): + + if self.quant_config.is_checkpoint_fp8_serialized: + params_dtype = torch.float8_e4m3fn + + # WEIGHTS + w13_weight = torch.nn.Parameter( + torch.empty( + num_experts, 2 * intermediate_size, hidden_size, dtype=params_dtype + ), + requires_grad=False, + ) + layer.register_parameter("w13_weight", w13_weight) + set_weight_attrs(w13_weight, extra_weight_attrs) + + w2_weight = torch.nn.Parameter( + torch.empty( + num_experts, hidden_size, intermediate_size, dtype=params_dtype + ), + requires_grad=False, + ) + layer.register_parameter("w2_weight", w2_weight) + set_weight_attrs(w2_weight, extra_weight_attrs) + + # WEIGHT_SCALES + # Allocate 2 scales for w1 and w3 respectively. + # They will be combined to a single scale after weight loading. + w13_weight_scale = torch.nn.Parameter( + torch.ones(num_experts, 2, dtype=torch.float32), requires_grad=False + ) + layer.register_parameter("w13_weight_scale", w13_weight_scale) + + w2_weight_scale = torch.nn.Parameter( + torch.ones(num_experts, dtype=torch.float32), requires_grad=False + ) + layer.register_parameter("w2_weight_scale", w2_weight_scale) + # Add the quantization method used (per tensor/grouped/channel) + # to ensure the weight scales are loaded in properly + extra_weight_attrs.update( + {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value} + ) + # If loading fp8 checkpoint, pass the weight loaders. + # If loading an fp16 checkpoint, do not (we will quantize in + # process_weights_after_loading() + if self.quant_config.is_checkpoint_fp8_serialized: + set_weight_attrs(w13_weight_scale, extra_weight_attrs) + set_weight_attrs(w2_weight_scale, extra_weight_attrs) + + # INPUT_SCALES + if self.quant_config.activation_scheme == "static": + if not self.quant_config.is_checkpoint_fp8_serialized: + raise ValueError( + "Found static activation scheme for checkpoint that " + "was not serialized fp8." + ) + + w13_input_scale = torch.nn.Parameter( + torch.ones(num_experts, dtype=torch.float32), requires_grad=False + ) + layer.register_parameter("w13_input_scale", w13_input_scale) + set_weight_attrs(w13_input_scale, extra_weight_attrs) + + w2_input_scale = torch.nn.Parameter( + torch.ones(num_experts, dtype=torch.float32), requires_grad=False + ) + layer.register_parameter("w2_input_scale", w2_input_scale) + set_weight_attrs(w2_input_scale, extra_weight_attrs) + + else: + layer.w13_input_scale = None + layer.w2_input_scale = None + + def process_weights_after_loading(self, layer: Module) -> None: + + # If checkpoint is fp16, quantize in place. + if not self.quant_config.is_checkpoint_fp8_serialized: + # If ROCm, use float8_e4m3fnuz instead (MI300x HW) + fp8_dtype = torch.float8_e4m3fnuz if is_hip() else torch.float8_e4m3fn + w13_weight = torch.empty_like(layer.w13_weight.data, dtype=fp8_dtype) + w2_weight = torch.empty_like(layer.w2_weight.data, dtype=fp8_dtype) + + # Re-initialize w13_scale because we directly quantize + # merged w13 weights and generate a single scaling factor. + layer.w13_weight_scale = torch.nn.Parameter( + torch.ones( + layer.num_experts, dtype=torch.float32, device=w13_weight.device + ), + requires_grad=False, + ) + for expert in range(layer.num_experts): + w13_weight[expert, :, :], layer.w13_weight_scale[expert] = ( + ops.scaled_fp8_quant(layer.w13_weight.data[expert, :, :]) + ) + w2_weight[expert, :, :], layer.w2_weight_scale[expert] = ( + ops.scaled_fp8_quant(layer.w2_weight.data[expert, :, :]) + ) + layer.w13_weight = torch.nn.Parameter(w13_weight, requires_grad=False) + layer.w2_weight = torch.nn.Parameter(w2_weight, requires_grad=False) + return + + # If checkpoint is fp8, we need to handle that the + # MoE kernels require single activation scale and single weight + # scale for w13 per expert. + else: + # Fp8 moe kernels require a single activation scale. + # We take the max of all the scales in case they differ. + if self.quant_config.activation_scheme == "static": + if layer.w13_input_scale is None or layer.w2_input_scale is None: + raise ValueError( + "QuantConfig has static quantization, but found " + "activation scales are None." + ) + if not all_close_1d(layer.w13_input_scale) or not all_close_1d( + layer.w2_input_scale + ): + print_warning_once( + "Found input_scales that are not equal for " + "fp8 MoE layer. Using the maximum across experts " + "for each layer. " + ) + layer.w13_input_scale = torch.nn.Parameter( + layer.w13_input_scale.max(), requires_grad=False + ) + layer.w2_input_scale = torch.nn.Parameter( + layer.w2_input_scale.max(), requires_grad=False + ) + # If ROCm, normalize the weights and scales to e4m3fnuz + if is_hip(): + # Normalize the weights and scales + w13_weight, w13_weight_scale, w13_input_scale = ( + normalize_e4m3fn_to_e4m3fnuz( + layer.w13_weight, layer.w13_weight_scale, layer.w13_input_scale + ) + ) + w2_weight, w2_weight_scale, w2_input_scale = ( + normalize_e4m3fn_to_e4m3fnuz( + layer.w2_weight, layer.w2_weight_scale, layer.w2_input_scale + ) + ) + # Reset the parameter + layer.w13_weight = torch.nn.Parameter(w13_weight, requires_grad=False) + layer.w13_weight_scale = torch.nn.Parameter( + w13_weight_scale, requires_grad=False + ) + if w13_input_scale is not None: + layer.w13_input_scale = torch.nn.Parameter( + w13_input_scale, requires_grad=False + ) + layer.w2_weight = torch.nn.Parameter(w2_weight, requires_grad=False) + layer.w2_weight_scale = torch.nn.Parameter( + w2_weight_scale, requires_grad=False + ) + if w2_input_scale is not None: + layer.w2_input_scale = torch.nn.Parameter( + w2_input_scale, requires_grad=False + ) + + # Fp8 moe kernel needs single weight scale for w13 per expert. + # We take the max then dequant and requant each expert. + assert layer.w13_weight_scale is not None + shard_size = layer.intermediate_size_per_partition + max_w13_scales = layer.w13_weight_scale.max(dim=1).values + for expert_id in range(layer.num_experts): + start = 0 + for shard_id in range(2): + dq_weight = per_tensor_dequantize( + layer.w13_weight[expert_id][start : start + shard_size, :], + layer.w13_weight_scale[expert_id][shard_id], + ) + layer.w13_weight[expert_id][start : start + shard_size, :], _ = ( + ops.scaled_fp8_quant(dq_weight, max_w13_scales[expert_id]) + ) + start += shard_size + + layer.w13_weight_scale = torch.nn.Parameter( + max_w13_scales, requires_grad=False + ) + return + + def apply( + self, + layer: torch.nn.Module, + x: torch.Tensor, + router_logits: torch.Tensor, + top_k: int, + renormalize: bool, + use_grouped_topk: bool, + topk_group: Optional[int] = None, + num_expert_group: Optional[int] = None, + custom_routing_function: Optional[Callable] = None, + ) -> torch.Tensor: + + from vllm.model_executor.layers.fused_moe import fused_experts + + topk_weights, topk_ids = FusedMoE.select_experts( + hidden_states=x, + router_logits=router_logits, + use_grouped_topk=use_grouped_topk, + top_k=top_k, + renormalize=renormalize, + topk_group=topk_group, + num_expert_group=num_expert_group, + custom_routing_function=custom_routing_function, + ) + + return fused_experts( + x, + layer.w13_weight, + layer.w2_weight, + topk_weights=topk_weights, + topk_ids=topk_ids, + inplace=True, + use_fp8_w8a8=True, + w1_scale=layer.w13_weight_scale, + w2_scale=layer.w2_weight_scale, + a1_scale=layer.w13_input_scale, + a2_scale=layer.w2_input_scale, + ) + + +class Fp8KVCacheMethod(BaseKVCacheMethod): + """ + Supports loading kv-cache scaling factors from FP8 checkpoints. + """ + + def __init__(self, quant_config: Fp8Config): + super().__init__(quant_config) From 34b364e07355f5216babd8c6fac7cb476f85e42c Mon Sep 17 00:00:00 2001 From: Xiaoyu Zhang <35585791+BBuf@users.noreply.github.com> Date: Fri, 6 Dec 2024 17:13:04 +0800 Subject: [PATCH 27/60] optimize cuda graph max_bs_settings on low-end gpus (#2360) --- python/sglang/srt/server_args.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index 8719d919068..3a0a99af922 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -184,8 +184,12 @@ def __post_init__(self): # Set cuda graph max batch size if self.cuda_graph_max_bs is None: + # Based on detailed statistics, when serving TP1/TP2 models on lower-end GPUs with HBM<25G, you can either disable cuda graph or set `cuda_graph_max_bs` to a very small value to reduce the memory overhead of creating cuda graphs, with almost no impact on performance. However, when serving models with TP4 or TP8, we need to enable cuda graph to maintain high performance. In this case, we can set `cuda_graph_max_bs` to 80 (half of the default value 160) to reduce the memory overhead of creating cuda graphs. Looking at the logs from TP4 serving of qwen2-72b, a value of 80 is sufficient and can reduce the memory overhead of creating cuda graphs on lower-end GPUs compared to the original 160, avoiding OOM issues. if gpu_mem is not None and gpu_mem < 25_000: - self.cuda_graph_max_bs = 8 + if self.tp_size < 4: + self.cuda_graph_max_bs = 8 + else: + self.cuda_graph_max_bs = 80 else: self.cuda_graph_max_bs = 160 From 37ee906f616efbd89b80fc2273e85bf8dbdd6682 Mon Sep 17 00:00:00 2001 From: Qun Yang Date: Fri, 6 Dec 2024 17:16:33 +0800 Subject: [PATCH 28/60] Add more support for intel Gaudi accelerators (#2357) --- .../runtime/engine/offline_batch_inference.py | 16 ++++-- python/sglang/srt/layers/sampler.py | 2 + python/sglang/srt/managers/scheduler.py | 6 +-- .../srt/managers/tp_worker_overlap_thread.py | 11 ++-- python/sglang/srt/mem_cache/memory_pool.py | 6 ++- python/sglang/srt/models/commandr.py | 4 +- python/sglang/srt/server_args.py | 7 +++ python/sglang/srt/utils.py | 50 +++++++++++++++++++ 8 files changed, 88 insertions(+), 14 deletions(-) diff --git a/examples/runtime/engine/offline_batch_inference.py b/examples/runtime/engine/offline_batch_inference.py index 7404c7e4e7f..724051eab53 100644 --- a/examples/runtime/engine/offline_batch_inference.py +++ b/examples/runtime/engine/offline_batch_inference.py @@ -1,7 +1,13 @@ +import argparse +import dataclasses + import sglang as sgl +from sglang.srt.server_args import ServerArgs -def main(): +def main( + server_args: ServerArgs, +): # Sample prompts. prompts = [ "Hello, my name is", @@ -13,7 +19,7 @@ def main(): sampling_params = {"temperature": 0.8, "top_p": 0.95} # Create an LLM. - llm = sgl.Engine(model_path="meta-llama/Meta-Llama-3.1-8B-Instruct") + llm = sgl.Engine(**dataclasses.asdict(server_args)) outputs = llm.generate(prompts, sampling_params) # Print the outputs. @@ -25,4 +31,8 @@ def main(): # The __main__ condition is necessary here because we use "spawn" to create subprocesses # Spawn starts a fresh program every time, if there is no __main__, it will run into infinite loop to keep spawning processes from sgl.Engine if __name__ == "__main__": - main() + parser = argparse.ArgumentParser() + ServerArgs.add_cli_args(parser) + args = parser.parse_args() + server_args = ServerArgs.from_cli_args(args) + main(server_args) diff --git a/python/sglang/srt/layers/sampler.py b/python/sglang/srt/layers/sampler.py index d7db6036ca9..b0dfda3e882 100644 --- a/python/sglang/srt/layers/sampler.py +++ b/python/sglang/srt/layers/sampler.py @@ -111,5 +111,7 @@ def top_k_top_p_min_p_sampling_from_probs_torch( probs_sort[probs_sort < min_p_thresholds.view(-1, 1)] = 0.0 probs_sort.div_(probs_sort.max(dim=-1, keepdim=True)[0]) sampled_index = torch.multinomial(probs_sort, num_samples=1) + # int32 range is enough to represent the token ids + probs_idx = probs_idx.to(torch.int32) batch_next_token_ids = torch.gather(probs_idx, dim=1, index=sampled_index).view(-1) return batch_next_token_ids diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py index 3714f19b633..fd4edade92d 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -993,7 +993,7 @@ def process_batch_result(self, batch: ScheduleBatch, result): self.process_batch_result_prefill(batch, result) elif batch.forward_mode.is_dummy_first(): batch.next_batch_sampling_info.update_regex_vocab_mask() - torch.cuda.current_stream().synchronize() + torch.get_device_module(self.device).current_stream().synchronize() batch.next_batch_sampling_info.sampling_info_done.set() def process_batch_result_prefill(self, batch: ScheduleBatch, result): @@ -1055,7 +1055,7 @@ def process_batch_result_prefill(self, batch: ScheduleBatch, result): if batch.next_batch_sampling_info: batch.next_batch_sampling_info.update_regex_vocab_mask() - torch.cuda.current_stream().synchronize() + torch.get_device_module(self.device).current_stream().synchronize() batch.next_batch_sampling_info.sampling_info_done.set() else: # embedding or reward model @@ -1130,7 +1130,7 @@ def process_batch_result_decode(self, batch: ScheduleBatch, result): if batch.next_batch_sampling_info: batch.next_batch_sampling_info.update_regex_vocab_mask() - torch.cuda.current_stream().synchronize() + torch.get_device_module(self.device).current_stream().synchronize() batch.next_batch_sampling_info.sampling_info_done.set() self.stream_output(batch.reqs) diff --git a/python/sglang/srt/managers/tp_worker_overlap_thread.py b/python/sglang/srt/managers/tp_worker_overlap_thread.py index e4e20ad8f64..6a453d2ad6d 100644 --- a/python/sglang/srt/managers/tp_worker_overlap_thread.py +++ b/python/sglang/srt/managers/tp_worker_overlap_thread.py @@ -32,12 +32,13 @@ from sglang.srt.managers.schedule_batch import ModelWorkerBatch from sglang.srt.managers.tp_worker import TpModelWorker from sglang.srt.server_args import ServerArgs +from sglang.srt.utils import get_compiler_backend from sglang.utils import get_exception_traceback logger = logging.getLogger(__name__) -@torch.compile(dynamic=True) +@torch.compile(dynamic=True, backend=get_compiler_backend()) def resolve_future_token_ids(input_ids, future_token_ids_map): input_ids[:] = torch.where( input_ids < 0, @@ -73,7 +74,7 @@ def __init__( # Launch threads self.input_queue = Queue() self.output_queue = Queue() - self.forward_stream = torch.cuda.Stream() + self.forward_stream = torch.get_device_module(self.device).Stream() self.forward_thread = threading.Thread( target=self.forward_thread_func, ) @@ -97,7 +98,7 @@ def get_memory_pool(self): def forward_thread_func(self): try: - with torch.cuda.stream(self.forward_stream): + with torch.get_device_module(self.device).stream(self.forward_stream): self.forward_thread_func_() except Exception: traceback = get_exception_traceback() @@ -122,7 +123,7 @@ def forward_thread_func_(self): # Create event self.launch_done = threading.Event() - copy_done = torch.cuda.Event() + copy_done = torch.get_device_module(self.device).Event() # Resolve future tokens in the input input_ids = model_worker_batch.input_ids @@ -190,7 +191,7 @@ def forward_batch_generation(self, model_worker_batch: ModelWorkerBatch): ) # A cuda stream sync here to avoid the cuda illegal memory access error. - torch.cuda.current_stream().synchronize() + torch.get_device_module(self.device).current_stream().synchronize() # Push a new batch to the queue self.input_queue.put((model_worker_batch, self.future_token_ids_ct)) diff --git a/python/sglang/srt/mem_cache/memory_pool.py b/python/sglang/srt/mem_cache/memory_pool.py index b028309c745..646e71749d8 100644 --- a/python/sglang/srt/mem_cache/memory_pool.py +++ b/python/sglang/srt/mem_cache/memory_pool.py @@ -27,6 +27,7 @@ import torch from sglang.srt.layers.radix_attention import RadixAttention +from sglang.srt.utils import get_compiler_backend logger = logging.getLogger(__name__) @@ -129,6 +130,9 @@ def alloc(self, need_size: int): return select_index.to(self.device, non_blocking=True) def free(self, free_index: torch.Tensor): + if free_index.numel() == 0: + return + if self.is_not_in_free_group: self.free_slots = torch.concat((self.free_slots, free_index.cpu())) else: @@ -234,7 +238,7 @@ def set_kv_buffer( # This compiled version is slower in the unit test # python3 -m unittest test_bench_serving.TestBenchServing.test_offline_throughput_non_stream_small_batch_size -@torch.compile(dynamic=True) +@torch.compile(dynamic=True, backend=get_compiler_backend()) def copy_two_array(loc, dst_1, src_1, dst_2, src_2, dtype, store_dtype): dst_1[loc] = src_1.to(dtype).view(store_dtype) dst_2[loc] = src_2.to(dtype).view(store_dtype) diff --git a/python/sglang/srt/models/commandr.py b/python/sglang/srt/models/commandr.py index a758e4f56b1..83ac3d8671b 100644 --- a/python/sglang/srt/models/commandr.py +++ b/python/sglang/srt/models/commandr.py @@ -62,10 +62,10 @@ from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding from sglang.srt.model_executor.forward_batch_info import ForwardBatch from sglang.srt.model_loader.weight_utils import default_weight_loader -from sglang.srt.utils import set_weight_attrs +from sglang.srt.utils import get_compiler_backend, set_weight_attrs -@torch.compile +@torch.compile(backend=get_compiler_backend()) def layer_norm_func(hidden_states, weight, variance_epsilon): input_dtype = hidden_states.dtype hidden_states = hidden_states.to(torch.float32) diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index 3a0a99af922..c2e75a642bd 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -25,6 +25,7 @@ from sglang.srt.hf_transformers_utils import check_gguf_file from sglang.srt.utils import ( get_amdgpu_memory_capacity, + get_hpu_memory_capacity, get_nvgpu_memory_capacity, is_flashinfer_available, is_hip, @@ -158,6 +159,8 @@ def __post_init__(self): gpu_mem = get_amdgpu_memory_capacity() elif torch.cuda.is_available(): gpu_mem = get_nvgpu_memory_capacity() + elif self.device == "hpu": + gpu_mem = get_hpu_memory_capacity() else: # GPU memory is not known yet or no GPU is available. gpu_mem = None @@ -194,6 +197,10 @@ def __post_init__(self): self.cuda_graph_max_bs = 160 # Choose kernel backends + if self.device == "hpu": + self.attention_backend = "torch_native" + self.sampling_backend = "pytorch" + if self.attention_backend is None: self.attention_backend = ( "flashinfer" if is_flashinfer_available() else "triton" diff --git a/python/sglang/srt/utils.py b/python/sglang/srt/utils.py index 04372bac194..5c310136a21 100644 --- a/python/sglang/srt/utils.py +++ b/python/sglang/srt/utils.py @@ -201,6 +201,18 @@ def get_available_gpu_memory(device, gpu_id, distributed=False): total_gpu_memory = torch.xpu.get_device_properties(gpu_id).total_memory free_gpu_memory = total_gpu_memory - used_memory + elif device == "hpu": + num_gpus = torch.hpu.device_count() + assert gpu_id < num_gpus + + if torch.hpu.current_device() != gpu_id: + print( + f"WARNING: current device is not {gpu_id}, but {torch.hpu.current_device()}, ", + "which may cause useless memory allocation for torch HPU context.", + ) + + free_gpu_memory, total_gpu_memory = torch.hpu.mem_get_info() + if distributed: tensor = torch.tensor(free_gpu_memory, dtype=torch.float32).to( torch.device(device, gpu_id) @@ -939,6 +951,37 @@ def get_nvgpu_memory_capacity(): ) +def get_hpu_memory_capacity(): + try: + # Run hl-smi and capture the output + result = subprocess.run( + ["hl-smi --query | grep 'Total'"], + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + shell=True, + text=True, + ) + + if result.returncode != 0: + raise RuntimeError(f"hl-smi error: {result.stderr.strip()}") + + # Parse the output to extract memory values in MiB + memory_values = [ + float(mem.split(" ")[-2]) for mem in result.stdout.strip().split("\n") + ] + + if not memory_values: + raise ValueError("No GPU memory values found.") + + # Return the minimum memory value + return min(memory_values) + + except FileNotFoundError: + raise RuntimeError( + "hl-smi not found. Ensure Habana drivers are installed and accessible." + ) + + # Copy from pytorch and OpenRLHF to allow creating multiple main groups. # https://github.com/pytorch/pytorch/blob/main/torch/distributed/distributed_c10d.py # https://github.com/OpenRLHF/OpenRLHF/blob/main/openrlhf/utils/distributed_util.py @@ -1062,6 +1105,13 @@ def get_device_capability(device_id: int = 0) -> Tuple[int, int]: return major, minor +def get_compiler_backend() -> str: + if hasattr(torch, "hpu") and torch.hpu.is_available(): + return "hpu_backend" + + return "inductor" + + sglang_lib = Library("sglang", "FRAGMENT") # noqa From 67b657945a1b62bafc0376cda78c91b1ef2a614a Mon Sep 17 00:00:00 2001 From: Byron Hsu Date: Fri, 6 Dec 2024 01:17:04 -0800 Subject: [PATCH 29/60] [router] support `/add_worker` api (#2369) --- rust/py_test/test_launch_server.py | 81 +++++++++++++++++++++++++++++- rust/src/router.rs | 49 ++++++++++++------ rust/src/server.rs | 22 +++++++- 3 files changed, 134 insertions(+), 18 deletions(-) diff --git a/rust/py_test/test_launch_server.py b/rust/py_test/test_launch_server.py index a7a695aa9f6..dcfe423466d 100644 --- a/rust/py_test/test_launch_server.py +++ b/rust/py_test/test_launch_server.py @@ -1,3 +1,4 @@ +import socket import subprocess import time import unittest @@ -49,7 +50,7 @@ def popen_launch_router( # Use current environment env = None - process = subprocess.Popen(command, stdout=None, stderr=None, env=env) + process = subprocess.Popen(command, stdout=None, stderr=None) start_time = time.time() with requests.Session() as session: @@ -57,6 +58,52 @@ def popen_launch_router( try: response = session.get(f"{base_url}/health") if response.status_code == 200: + print(f"Router {base_url} is healthy") + return process + except requests.RequestException: + pass + time.sleep(10) + + raise TimeoutError("Router failed to start within the timeout period.") + + +def find_available_port(): + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + s.bind(("127.0.0.1", 0)) + return s.getsockname()[1] + + +def popen_launch_server( + model: str, + base_url: str, + timeout: float, +): + _, host, port = base_url.split(":") + host = host[2:] + + command = [ + "python3", + "-m", + "sglang.launch_server", + "--model-path", + model, + "--host", + host, + "--port", + port, + "--base-gpu-id", + "1", + ] + + process = subprocess.Popen(command, stdout=None, stderr=None) + + start_time = time.time() + with requests.Session() as session: + while time.time() - start_time < timeout: + try: + response = session.get(f"{base_url}/health") + if response.status_code == 200: + print(f"Server {base_url} is healthy") return process except requests.RequestException: pass @@ -76,10 +123,13 @@ def setUpClass(cls): dp_size=1, timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, ) + cls.other_process = [] @classmethod def tearDownClass(cls): kill_process_tree(cls.process.pid) + for process in cls.other_process: + kill_process_tree(process.pid) def test_mmlu(self): args = SimpleNamespace( @@ -98,6 +148,35 @@ def test_mmlu(self): msg = f"MMLU test {'passed' if passed else 'failed'} with score {score:.3f} (threshold: {THRESHOLD})" self.assertGreaterEqual(score, THRESHOLD, msg) + def test_add_worker(self): + # 1. start a worker, and wait until it is healthy + port = find_available_port() + worker_url = f"http://127.0.0.1:{port}" + worker_process = popen_launch_server( + self.model, worker_url, DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH + ) + self.other_process.append(worker_process) + # 2. use /add_worker api to add it the the router + with requests.Session() as session: + response = session.post(f"{self.base_url}/add_worker?url={worker_url}") + print(f"status code: {response.status_code}, response: {response.text}") + self.assertEqual(response.status_code, 200) + # 3. run mmlu + args = SimpleNamespace( + base_url=self.base_url, + model=self.model, + eval_name="mmlu", + num_examples=64, + num_threads=32, + temperature=0.1, + ) + metrics = run_eval(args) + score = metrics["score"] + THRESHOLD = 0.65 + passed = score >= THRESHOLD + msg = f"MMLU test {'passed' if passed else 'failed'} with score {score:.3f} (threshold: {THRESHOLD})" + self.assertGreaterEqual(score, THRESHOLD, msg) + if __name__ == "__main__": unittest.main() diff --git a/rust/src/router.rs b/rust/src/router.rs index e17cba874c9..74e47209bd7 100644 --- a/rust/src/router.rs +++ b/rust/src/router.rs @@ -7,18 +7,18 @@ use log::{debug, info}; use std::collections::HashMap; use std::fmt::Debug; use std::sync::atomic::AtomicUsize; -use std::sync::{Arc, Mutex}; +use std::sync::{Arc, Mutex, RwLock}; use std::thread; use std::time::Duration; #[derive(Debug)] pub enum Router { RoundRobin { - worker_urls: Vec, + worker_urls: Arc>>, current_index: AtomicUsize, }, Random { - worker_urls: Vec, + worker_urls: Arc>>, }, CacheAware { /* @@ -81,7 +81,7 @@ pub enum Router { Maximum nodes per tree. When exceeded, LRU leaf nodes are evicted during the next eviction cycle. */ - worker_urls: Vec, + worker_urls: Arc>>, tree: Arc>, running_queue: Arc>>, processed_queue: Arc>>, @@ -129,9 +129,11 @@ fn get_text_from_request(body: &Bytes, route: &str) -> String { impl Router { pub fn new(worker_urls: Vec, policy_config: PolicyConfig) -> Self { match policy_config { - PolicyConfig::RandomConfig => Router::Random { worker_urls }, + PolicyConfig::RandomConfig => Router::Random { + worker_urls: Arc::new(RwLock::new(worker_urls)), + }, PolicyConfig::RoundRobinConfig => Router::RoundRobin { - worker_urls, + worker_urls: Arc::new(RwLock::new(worker_urls)), current_index: std::sync::atomic::AtomicUsize::new(0), }, PolicyConfig::CacheAwareConfig { @@ -183,7 +185,7 @@ impl Router { } Router::CacheAware { - worker_urls, + worker_urls: Arc::new(RwLock::new(worker_urls)), tree, running_queue, processed_queue, @@ -201,10 +203,10 @@ impl Router { Router::RoundRobin { worker_urls, .. } | Router::Random { worker_urls } | Router::CacheAware { worker_urls, .. } => { - if worker_urls.is_empty() { + if worker_urls.read().unwrap().is_empty() { None } else { - Some(worker_urls[0].clone()) + Some(worker_urls.read().unwrap()[0].clone()) } } } @@ -228,15 +230,15 @@ impl Router { .fetch_update( std::sync::atomic::Ordering::SeqCst, std::sync::atomic::Ordering::SeqCst, - |x| Some((x + 1) % worker_urls.len()), + |x| Some((x + 1) % worker_urls.read().unwrap().len()), ) .unwrap(); - worker_urls[idx].clone() + worker_urls.read().unwrap()[idx].clone() } - Router::Random { worker_urls } => { - worker_urls[rand::random::() % worker_urls.len()].clone() - } + Router::Random { worker_urls } => worker_urls.read().unwrap() + [rand::random::() % worker_urls.read().unwrap().len()] + .clone(), Router::CacheAware { worker_urls, @@ -277,7 +279,7 @@ impl Router { .iter() .min_by_key(|(_url, &count)| count) .map(|(url, _)| url.clone()) - .unwrap_or_else(|| worker_urls[0].clone()) + .unwrap_or_else(|| worker_urls.read().unwrap()[0].clone()) } else { // Use cache-aware routing when load is balanced let (matched_text, matched_worker) = tree.prefix_match(&text); @@ -333,7 +335,10 @@ impl Router { // For non-streaming requests, get response first let response = match res.bytes().await { Ok(body) => HttpResponse::build(status).body(body.to_vec()), - Err(_) => HttpResponse::InternalServerError().finish(), + Err(e) => { + let error_msg = format!("Failed to get response body: {}", e); + HttpResponse::InternalServerError().body(error_msg) + } }; // Then decrement running queue counter if using CacheAware @@ -379,4 +384,16 @@ impl Router { })) } } + + pub fn add_worker(&self, worker_url: String) { + match self { + Router::RoundRobin { worker_urls, .. } + | Router::Random { worker_urls } + | Router::CacheAware { worker_urls, .. } => { + let mut urls = worker_urls.write().unwrap(); + info!("Added worker: {}", worker_url); + urls.push(worker_url); + } + } + } } diff --git a/rust/src/server.rs b/rust/src/server.rs index 3fbe5c3e895..269214acfef 100644 --- a/rust/src/server.rs +++ b/rust/src/server.rs @@ -1,9 +1,12 @@ use crate::router::PolicyConfig; use crate::router::Router; -use actix_web::{get, post, web, App, HttpRequest, HttpResponse, HttpServer, Responder}; +use actix_web::{ + delete, get, post, put, web, App, HttpRequest, HttpResponse, HttpServer, Responder, +}; use bytes::Bytes; use env_logger::Builder; use log::{info, LevelFilter}; +use std::collections::HashMap; use std::io::Write; #[derive(Debug)] @@ -128,6 +131,22 @@ async fn v1_completions( .await } +#[post("/add_worker")] +async fn add_worker( + query: web::Query>, + data: web::Data, +) -> impl Responder { + let worker_url = match query.get("url") { + Some(url) => url.to_string(), + None => { + return HttpResponse::BadRequest() + .body("Worker URL required. Provide 'url' query parameter") + } + }; + data.router.add_worker(worker_url); + HttpResponse::Ok().finish() +} + pub struct ServerConfig { pub host: String, pub port: u16, @@ -183,6 +202,7 @@ pub async fn startup(config: ServerConfig) -> std::io::Result<()> { .service(health) .service(health_generate) .service(get_server_info) + .service(add_worker) }) .bind((config.host, config.port))? .run() From f68175967cb61983377a634a25994c5c8e9fb7e0 Mon Sep 17 00:00:00 2001 From: Yineng Zhang Date: Fri, 6 Dec 2024 17:59:26 +0800 Subject: [PATCH 30/60] docs: update adoption (Meituan) (#2373) --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 43c2f8c8808..bc8734936cd 100644 --- a/README.md +++ b/README.md @@ -54,7 +54,7 @@ Learn more in our release blogs: [v0.2 blog](https://lmsys.org/blog/2024-07-25-s [Development Roadmap (2024 Q4)](https://github.com/sgl-project/sglang/issues/1487) ## Adoption and Sponsorship -The project is supported by (alphabetically): AMD, Baseten, Etched, Hyperbolic, Jam & Tea Studios, LinkedIn, NVIDIA, RunPod, Stanford, UC Berkeley, xAI and 01.AI. +The project is supported by (alphabetically): AMD, Baseten, Etched, Hyperbolic, Jam & Tea Studios, LinkedIn, Meituan, NVIDIA, RunPod, Stanford, UC Berkeley, xAI and 01.AI. ## Acknowledgment and Citation We learned from the design and reused code from the following projects: [Guidance](https://github.com/guidance-ai/guidance), [vLLM](https://github.com/vllm-project/vllm), [LightLLM](https://github.com/ModelTC/lightllm), [FlashInfer](https://github.com/flashinfer-ai/flashinfer), [Outlines](https://github.com/outlines-dev/outlines), and [LMQL](https://github.com/eth-sri/lmql). From f5b2a3aa67efb10918965b9f3555ff24ef971902 Mon Sep 17 00:00:00 2001 From: Lianmin Zheng Date: Fri, 6 Dec 2024 02:01:23 -0800 Subject: [PATCH 31/60] Use proc.join instead of busy waiting (#2374) --- python/sglang/srt/server.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/sglang/srt/server.py b/python/sglang/srt/server.py index 7b91cb69797..29bc44eb524 100644 --- a/python/sglang/srt/server.py +++ b/python/sglang/srt/server.py @@ -462,8 +462,8 @@ def launch_engine( if server_args.node_rank >= 1: # For other nodes, they do not need to run tokenizer or detokenizer, # so they can just wait here. - while True: - pass + for proc in scheduler_procs: + proc.join() else: # Launch the data parallel controller reader, writer = mp.Pipe(duplex=False) From 3cde5eb62940556b4defbe285170658027fca353 Mon Sep 17 00:00:00 2001 From: vchzls Date: Fri, 6 Dec 2024 20:27:17 +0800 Subject: [PATCH 32/60] docs: Improve instructions for supporting new models (#2363) Co-authored-by: zhaohoulong --- docs/references/supported_models.md | 27 +++++++++++++++++++++++++++ 1 file changed, 27 insertions(+) diff --git a/docs/references/supported_models.md b/docs/references/supported_models.md index dbf4f71a021..13572e437d5 100644 --- a/docs/references/supported_models.md +++ b/docs/references/supported_models.md @@ -80,3 +80,30 @@ To port a model from vLLM to SGLang, you can compare these two files [SGLang Lla - Remove `Sample`. - Change `forward()` functions, and add `forward_batch`. - Add `EntryClass` at the end. + +### Registering an external model implementation + +In addition to the methods described above, you can also register your new model with the `ModelRegistry` before launching the server. This approach is useful if you want to integrate your model without needing to modify the source code. + +Here is how you can do it: + +```python +from sglang.srt.models.registry import ModelRegistry +from sglang.srt.server import launch_server + +# for a single model, you can add it to the registry +ModelRegistry.models[model_name] = model_class + +# for multiple models, you can imitate the import_model_classes() function in sglang/srt/models/registry.py +from functools import lru_cache + +@lru_cache() +def import_new_model_classes(): + model_arch_name_to_cls = {} + ... + return model_arch_name_to_cls + +ModelRegistry.models.update(import_new_model_classes()) + +launch_server(server_args) +``` \ No newline at end of file From 0e7409adb64ac19db2db3583ef3e4077cc569b30 Mon Sep 17 00:00:00 2001 From: Lianmin Zheng Date: Fri, 6 Dec 2024 05:49:29 -0800 Subject: [PATCH 33/60] Fix the overlap for xgrammar (#2377) --- docs/references/supported_models.md | 2 +- .../srt/constrained/outlines_backend.py | 5 + .../srt/constrained/xgrammar_backend.py | 10 +- python/sglang/srt/managers/scheduler.py | 134 +++++++++--------- .../srt/managers/tp_worker_overlap_thread.py | 3 +- .../srt/sampling/sampling_batch_info.py | 17 +-- test/srt/test_json_constrained.py | 107 +++++++------- 7 files changed, 145 insertions(+), 133 deletions(-) diff --git a/docs/references/supported_models.md b/docs/references/supported_models.md index 13572e437d5..bf1044f8498 100644 --- a/docs/references/supported_models.md +++ b/docs/references/supported_models.md @@ -106,4 +106,4 @@ def import_new_model_classes(): ModelRegistry.models.update(import_new_model_classes()) launch_server(server_args) -``` \ No newline at end of file +``` diff --git a/python/sglang/srt/constrained/outlines_backend.py b/python/sglang/srt/constrained/outlines_backend.py index 26c476a0599..4820d473959 100644 --- a/python/sglang/srt/constrained/outlines_backend.py +++ b/python/sglang/srt/constrained/outlines_backend.py @@ -42,6 +42,7 @@ def __init__( self.guide = guide self.jump_forward_map = jump_forward_map self.state = 0 + self.finished = False def accept_token(self, token: int): self.state = self.guide.get_next_state(self.state, token) @@ -84,6 +85,10 @@ def allocate_vocab_mask( ) -> torch.Tensor: return torch.zeros(batch_size, vocab_size, dtype=torch.bool, device=device) + @staticmethod + def move_vocab_mask(vocab_mask: torch.Tensor, device) -> torch.Tensor: + return vocab_mask + def fill_vocab_mask(self, vocab_mask: torch.Tensor, idx: int) -> None: tokens = torch.tensor( self.guide.get_next_instruction(self.state).tokens, dtype=torch.int64 diff --git a/python/sglang/srt/constrained/xgrammar_backend.py b/python/sglang/srt/constrained/xgrammar_backend.py index 1bcc51c6468..ee8e8eb07f4 100644 --- a/python/sglang/srt/constrained/xgrammar_backend.py +++ b/python/sglang/srt/constrained/xgrammar_backend.py @@ -45,6 +45,7 @@ def __init__( self.matcher = matcher self.vocab_size = vocab_size self.ctx = ctx + self.finished = False def accept_token(self, token: int): assert self.matcher.accept_token(token) @@ -85,12 +86,11 @@ def fill_vocab_mask(self, vocab_mask: torch.Tensor, idx: int) -> None: self.matcher.fill_next_token_bitmask(vocab_mask, idx) @staticmethod - def apply_vocab_mask(logits: torch.Tensor, vocab_mask: torch.Tensor) -> None: - if vocab_mask.device.type != logits.device.type: - # vocab_mask must then be on the same device as logits - # when applying the token bitmask, so we check and move if needed - vocab_mask = vocab_mask.to(logits.device) + def move_vocab_mask(vocab_mask: torch.Tensor, device) -> torch.Tensor: + return vocab_mask.to(device, non_blocking=True) + @staticmethod + def apply_vocab_mask(logits: torch.Tensor, vocab_mask: torch.Tensor) -> None: apply_token_bitmask_inplace(logits, vocab_mask) def copy(self): diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py index fd4edade92d..4ca4cd740dc 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -114,9 +114,6 @@ def __init__( self.skip_tokenizer_init = server_args.skip_tokenizer_init self.enable_metrics = server_args.enable_metrics - # Session info - self.sessions = {} - # Init inter-process communication context = zmq.Context(2) @@ -259,6 +256,10 @@ def __init__( self.num_generated_tokens = 0 self.last_decode_stats_tic = time.time() self.stream_interval = server_args.stream_interval + self.current_stream = torch.get_device_module(self.device).current_stream() + + # Session info + self.sessions = {} # Init chunked prefill self.chunked_prefill_size = server_args.chunked_prefill_size @@ -356,6 +357,7 @@ def __init__( ) def watchdog_thread(self): + """A watch dog thread that will try to kill the server itself if one batch takes too long.""" self.watchdog_last_forward_ct = 0 self.watchdog_last_time = time.time() @@ -433,61 +435,6 @@ def event_loop_overlap(self): self.last_batch = batch - def prepare_dp_attn_batch(self, local_batch: ScheduleBatch): - # Check if other DP workers have running batches - if local_batch is None: - num_tokens = 0 - elif local_batch.forward_mode.is_decode(): - num_tokens = local_batch.batch_size() - else: - num_tokens = local_batch.extend_num_tokens - - local_num_tokens = torch.tensor([num_tokens], dtype=torch.int64) - global_num_tokens = torch.empty(self.tp_size, dtype=torch.int64) - torch.distributed.all_gather_into_tensor( - global_num_tokens, - local_num_tokens, - group=self.tp_cpu_group, - ) - - if local_batch is None and global_num_tokens.max().item() > 0: - local_batch = self.get_idle_batch() - - if local_batch is not None: - local_batch.global_num_tokens = global_num_tokens.tolist() - - # Check forward mode for cuda graph - if not self.server_args.disable_cuda_graph: - forward_mode_state = torch.tensor( - ( - 1 - if local_batch.forward_mode.is_decode() - or local_batch.forward_mode.is_idle() - else 0 - ), - dtype=torch.int32, - ) - torch.distributed.all_reduce( - forward_mode_state, - op=torch.distributed.ReduceOp.MIN, - group=self.tp_cpu_group, - ) - local_batch.can_run_dp_cuda_graph = forward_mode_state.item() == 1 - - return local_batch - - def get_idle_batch(self): - idle_batch = ScheduleBatch.init_new( - [], - self.req_to_token_pool, - self.token_to_kv_pool, - self.tree_cache, - self.model_config, - self.enable_overlap, - ) - idle_batch.prepare_for_idle() - return idle_batch - def recv_requests(self): if self.tp_rank == 0 or self.server_args.enable_dp_attention: recv_reqs = [] @@ -993,7 +940,7 @@ def process_batch_result(self, batch: ScheduleBatch, result): self.process_batch_result_prefill(batch, result) elif batch.forward_mode.is_dummy_first(): batch.next_batch_sampling_info.update_regex_vocab_mask() - torch.get_device_module(self.device).current_stream().synchronize() + self.current_stream.synchronize() batch.next_batch_sampling_info.sampling_info_done.set() def process_batch_result_prefill(self, batch: ScheduleBatch, result): @@ -1049,13 +996,14 @@ def process_batch_result_prefill(self, batch: ScheduleBatch, result): if req.grammar is not None: req.grammar.accept_token(next_token_id) + req.grammar.finished = req.finished() else: # being chunked reqs' prefill is not finished req.is_being_chunked -= 1 if batch.next_batch_sampling_info: batch.next_batch_sampling_info.update_regex_vocab_mask() - torch.get_device_module(self.device).current_stream().synchronize() + self.current_stream.synchronize() batch.next_batch_sampling_info.sampling_info_done.set() else: # embedding or reward model @@ -1127,10 +1075,11 @@ def process_batch_result_decode(self, batch: ScheduleBatch, result): if req.grammar is not None: req.grammar.accept_token(next_token_id) + req.grammar.finished = req.finished() if batch.next_batch_sampling_info: batch.next_batch_sampling_info.update_regex_vocab_mask() - torch.get_device_module(self.device).current_stream().synchronize() + self.current_stream.synchronize() batch.next_batch_sampling_info.sampling_info_done.set() self.stream_output(batch.reqs) @@ -1328,6 +1277,61 @@ def stream_output(self, reqs: List[Req]): ) ) + def prepare_dp_attn_batch(self, local_batch: ScheduleBatch): + # Check if other DP workers have running batches + if local_batch is None: + num_tokens = 0 + elif local_batch.forward_mode.is_decode(): + num_tokens = local_batch.batch_size() + else: + num_tokens = local_batch.extend_num_tokens + + local_num_tokens = torch.tensor([num_tokens], dtype=torch.int64) + global_num_tokens = torch.empty(self.tp_size, dtype=torch.int64) + torch.distributed.all_gather_into_tensor( + global_num_tokens, + local_num_tokens, + group=self.tp_cpu_group, + ) + + if local_batch is None and global_num_tokens.max().item() > 0: + local_batch = self.get_idle_batch() + + if local_batch is not None: + local_batch.global_num_tokens = global_num_tokens.tolist() + + # Check forward mode for cuda graph + if not self.server_args.disable_cuda_graph: + forward_mode_state = torch.tensor( + ( + 1 + if local_batch.forward_mode.is_decode() + or local_batch.forward_mode.is_idle() + else 0 + ), + dtype=torch.int32, + ) + torch.distributed.all_reduce( + forward_mode_state, + op=torch.distributed.ReduceOp.MIN, + group=self.tp_cpu_group, + ) + local_batch.can_run_dp_cuda_graph = forward_mode_state.item() == 1 + + return local_batch + + def get_idle_batch(self): + idle_batch = ScheduleBatch.init_new( + [], + self.req_to_token_pool, + self.token_to_kv_pool, + self.tree_cache, + self.model_config, + self.enable_overlap, + ) + idle_batch.prepare_for_idle() + return idle_batch + def move_ready_grammar_requests(self): """Move requests whose grammar objects are ready from grammar_queue to waiting_queue.""" num_ready_reqs = 0 @@ -1469,10 +1473,6 @@ def run_scheduler_process( dp_rank: Optional[int], pipe_writer, ): - # set cpu affinity to this gpu process - if get_bool_env_var("SGLANG_SET_CPU_AFFINITY"): - set_gpu_proc_affinity(server_args.tp_size, server_args.nnodes, gpu_id) - # [For Router] if env var "SGLANG_DP_RANK" exist, set dp_rank to the value of the env var if dp_rank is None and "SGLANG_DP_RANK" in os.environ: dp_rank = int(os.environ["SGLANG_DP_RANK"]) @@ -1482,6 +1482,10 @@ def run_scheduler_process( else: configure_logger(server_args, prefix=f" DP{dp_rank} TP{tp_rank}") + # set cpu affinity to this gpu process + if get_bool_env_var("SGLANG_SET_CPU_AFFINITY"): + set_gpu_proc_affinity(server_args.tp_size, server_args.nnodes, gpu_id) + suppress_other_loggers() parent_process = psutil.Process().parent() diff --git a/python/sglang/srt/managers/tp_worker_overlap_thread.py b/python/sglang/srt/managers/tp_worker_overlap_thread.py index 6a453d2ad6d..a9db1878391 100644 --- a/python/sglang/srt/managers/tp_worker_overlap_thread.py +++ b/python/sglang/srt/managers/tp_worker_overlap_thread.py @@ -80,6 +80,7 @@ def __init__( ) self.forward_thread.start() self.parent_process = psutil.Process().parent() + self.scheduler_stream = torch.get_device_module(self.device).current_stream() def get_worker_info(self): return self.worker.get_worker_info() @@ -191,7 +192,7 @@ def forward_batch_generation(self, model_worker_batch: ModelWorkerBatch): ) # A cuda stream sync here to avoid the cuda illegal memory access error. - torch.get_device_module(self.device).current_stream().synchronize() + self.scheduler_stream.synchronize() # Push a new batch to the queue self.input_queue.put((model_worker_batch, self.future_token_ids_ct)) diff --git a/python/sglang/srt/sampling/sampling_batch_info.py b/python/sglang/srt/sampling/sampling_batch_info.py index 1624fd255f9..a64a84a62dc 100644 --- a/python/sglang/srt/sampling/sampling_batch_info.py +++ b/python/sglang/srt/sampling/sampling_batch_info.py @@ -158,22 +158,23 @@ def update_regex_vocab_mask(self): return # find a grammar from the list - grammar = next(grammar for grammar in self.grammars if grammar) + first_grammar = next(grammar for grammar in self.grammars if grammar) # maybe we can reuse the existing mask? - self.vocab_mask = grammar.allocate_vocab_mask( + self.vocab_mask = first_grammar.allocate_vocab_mask( vocab_size=self.vocab_size, batch_size=len(self.temperatures), device=self.device, ) - self.apply_mask = type(grammar).apply_vocab_mask # force to use static method + self.apply_mask = first_grammar.apply_vocab_mask # force to use static method + # Apply the mask for i, grammar in enumerate(self.grammars): - if grammar is not None: - try: - grammar.fill_vocab_mask(self.vocab_mask, i) - except RuntimeError: - continue + if grammar and not grammar.finished: + grammar.fill_vocab_mask(self.vocab_mask, i) + + # Move the mask to the device if needed + self.vocab_mask = first_grammar.move_vocab_mask(self.vocab_mask, self.device) def filter_batch(self, unfinished_indices: List[int], new_indices: torch.Tensor): self.penalizer_orchestrator.filter(unfinished_indices, new_indices) diff --git a/test/srt/test_json_constrained.py b/test/srt/test_json_constrained.py index 28acdabd9d0..1a857d0da6e 100644 --- a/test/srt/test_json_constrained.py +++ b/test/srt/test_json_constrained.py @@ -1,5 +1,6 @@ """ -python3 -m unittest test_json_constrained.TestJSONConstrained.test_json_generate +python3 -m unittest test_json_constrained.TestJSONConstrainedOutlinesBackend.test_json_generate +python3 -m unittest test_json_constrained.TestJSONConstrainedXGrammarBackend.test_json_generate """ import json @@ -11,38 +12,50 @@ from sglang.srt.utils import kill_process_tree from sglang.test.test_utils import ( - DEFAULT_MODEL_NAME_FOR_TEST, + DEFAULT_SMALL_MODEL_NAME_FOR_TEST, + DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, DEFAULT_URL_FOR_TEST, popen_launch_server, ) +def setup_class(cls, backend: str, disable_overlap: bool): + cls.model = DEFAULT_SMALL_MODEL_NAME_FOR_TEST + cls.base_url = DEFAULT_URL_FOR_TEST + cls.json_schema = json.dumps( + { + "type": "object", + "properties": { + "name": {"type": "string", "pattern": "^[\\w]+$"}, + "population": {"type": "integer"}, + }, + "required": ["name", "population"], + } + ) + + other_args = [ + "--max-running-requests", + "10", + "--grammar-backend", + backend, + ] + + if disable_overlap: + other_args += ["--disable-overlap-schedule"] + + cls.process = popen_launch_server( + cls.model, + cls.base_url, + timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + other_args=other_args, + ) + + class TestJSONConstrainedOutlinesBackend(unittest.TestCase): @classmethod def setUpClass(cls): - cls.model = DEFAULT_MODEL_NAME_FOR_TEST - cls.base_url = DEFAULT_URL_FOR_TEST - cls.json_schema = json.dumps( - { - "type": "object", - "properties": { - "name": {"type": "string", "pattern": "^[\\w]+$"}, - "population": {"type": "integer"}, - }, - "required": ["name", "population"], - } - ) - cls.process = popen_launch_server( - cls.model, - cls.base_url, - timeout=300, - other_args=[ - "--max-running-requests", - "10", - "--grammar-backend", - "outlines", - ], - ) + setup_class(cls, backend="outlines", disable_overlap=False) + cls.check_jump_forward = False @classmethod def tearDownClass(cls): @@ -83,11 +96,13 @@ def run_decode(self, json_schema, return_logprob=False, top_logprobs_num=0, n=1) self.assertIsInstance(js_obj["population"], int) # Make sure jump forward is triggered - # NOTE: This is skipped because overlap scheduler does not support jump forward - # self.assertGreater( - # ret["meta_info"]["completion_tokens"], - # ret["meta_info"]["completion_tokens_wo_jump_forward"], - # ) + # NOTE: The overlap scheduler does not support jump forward so we only do this test + # when --disable-overlap-schedule is set. + if self.check_jump_forward: + self.assertGreater( + ret["meta_info"]["completion_tokens"], + ret["meta_info"]["completion_tokens_wo_jump_forward"], + ) def test_json_generate(self): self.run_decode(json_schema=self.json_schema) @@ -126,32 +141,18 @@ def test_mix_json_and_other(self): list(executor.map(self.run_decode, json_schemas)) +class TestJumpForwardOutlinesBackend(unittest.TestCase): + @classmethod + def setUpClass(cls): + setup_class(cls, backend="outlines", disable_overlap=True) + cls.check_jump_forward = True + + class TestJSONConstrainedXGrammarBackend(TestJSONConstrainedOutlinesBackend): @classmethod def setUpClass(cls): - cls.model = DEFAULT_MODEL_NAME_FOR_TEST - cls.base_url = DEFAULT_URL_FOR_TEST - cls.json_schema = json.dumps( - { - "type": "object", - "properties": { - "name": {"type": "string"}, - "population": {"type": "integer"}, - }, - "required": ["name", "population"], - } - ) - cls.process = popen_launch_server( - cls.model, - cls.base_url, - timeout=300, - other_args=[ - "--max-running-requests", - "10", - "--grammar-backend", - "xgrammar", - ], - ) + setup_class(cls, backend="xgrammar", disable_overlap=False) + cls.check_jump_forward = False if __name__ == "__main__": From e5f227c0ee9f491ed8a625733314e7218988e744 Mon Sep 17 00:00:00 2001 From: Lianmin Zheng Date: Fri, 6 Dec 2024 06:08:19 -0800 Subject: [PATCH 34/60] Release v0.4.0.post1 (#2375) --- docker/Dockerfile.rocm | 2 +- docs/developer/setup_github_runner.md | 4 ++-- docs/start/install.md | 10 +++++----- python/pyproject.toml | 2 +- python/sglang/version.py | 2 +- 5 files changed, 10 insertions(+), 10 deletions(-) diff --git a/docker/Dockerfile.rocm b/docker/Dockerfile.rocm index e51afce4d3b..2c9af6e7b0d 100644 --- a/docker/Dockerfile.rocm +++ b/docker/Dockerfile.rocm @@ -1,5 +1,5 @@ # Usage (to build SGLang ROCm docker image): -# docker build --build-arg SGL_BRANCH=v0.4.0 -t v0.4.0-rocm620 -f Dockerfile.rocm . +# docker build --build-arg SGL_BRANCH=v0.4.0.post1 -t v0.4.0.post1-rocm620 -f Dockerfile.rocm . # default base image ARG BASE_IMAGE="rocm/vllm-dev:20241022" diff --git a/docs/developer/setup_github_runner.md b/docs/developer/setup_github_runner.md index d9eeb626583..c82094f6de3 100644 --- a/docs/developer/setup_github_runner.md +++ b/docs/developer/setup_github_runner.md @@ -11,9 +11,9 @@ docker pull nvidia/cuda:12.1.1-devel-ubuntu22.04 # Nvidia docker run --shm-size 128g -it -v /tmp/huggingface:/hf_home --gpus all nvidia/cuda:12.1.1-devel-ubuntu22.04 /bin/bash # AMD -docker run --rm --device=/dev/kfd --device=/dev/dri --group-add video --shm-size 128g -it -v /tmp/huggingface:/hf_home lmsysorg/sglang:v0.4.0-rocm620 /bin/bash +docker run --rm --device=/dev/kfd --device=/dev/dri --group-add video --shm-size 128g -it -v /tmp/huggingface:/hf_home lmsysorg/sglang:v0.4.0.post1-rocm620 /bin/bash # AMD just the last 2 GPUs -docker run --rm --device=/dev/kfd --device=/dev/dri/renderD176 --device=/dev/dri/renderD184 --group-add video --shm-size 128g -it -v /tmp/huggingface:/hf_home lmsysorg/sglang:v0.4.0-rocm620 /bin/bash +docker run --rm --device=/dev/kfd --device=/dev/dri/renderD176 --device=/dev/dri/renderD184 --group-add video --shm-size 128g -it -v /tmp/huggingface:/hf_home lmsysorg/sglang:v0.4.0.post1-rocm620 /bin/bash ``` ### Step 2: Configure the runner by `config.sh` diff --git a/docs/start/install.md b/docs/start/install.md index a5e5d73f561..e9d3abc8e78 100644 --- a/docs/start/install.md +++ b/docs/start/install.md @@ -13,7 +13,7 @@ Note: Please check the [FlashInfer installation doc](https://docs.flashinfer.ai/ ## Method 2: From source ``` # Use the last release branch -git clone -b v0.4.0 https://github.com/sgl-project/sglang.git +git clone -b v0.4.0.post1 https://github.com/sgl-project/sglang.git cd sglang pip install --upgrade pip @@ -26,7 +26,7 @@ Note: To AMD ROCm system with Instinct/MI GPUs, do following instead: ``` # Use the last release branch -git clone -b v0.4.0 https://github.com/sgl-project/sglang.git +git clone -b v0.4.0.post1 https://github.com/sgl-project/sglang.git cd sglang pip install --upgrade pip @@ -51,7 +51,7 @@ docker run --gpus all \ Note: To AMD ROCm system with Instinct/MI GPUs, it is recommended to use `docker/Dockerfile.rocm` to build images, example and usage as below: ```bash -docker build --build-arg SGL_BRANCH=v0.4.0 -t v0.4.0-rocm620 -f Dockerfile.rocm . +docker build --build-arg SGL_BRANCH=v0.4.0.post1 -t v0.4.0.post1-rocm620 -f Dockerfile.rocm . alias drun='docker run -it --rm --network=host --device=/dev/kfd --device=/dev/dri --ipc=host \ --shm-size 16G --group-add video --cap-add=SYS_PTRACE --security-opt seccomp=unconfined \ @@ -60,11 +60,11 @@ alias drun='docker run -it --rm --network=host --device=/dev/kfd --device=/dev/d drun -p 30000:30000 \ -v ~/.cache/huggingface:/root/.cache/huggingface \ --env "HF_TOKEN=" \ - v0.4.0-rocm620 \ + v0.4.0.post1-rocm620 \ python3 -m sglang.launch_server --model-path meta-llama/Llama-3.1-8B-Instruct --host 0.0.0.0 --port 30000 # Till flashinfer backend available, --attention-backend triton --sampling-backend pytorch are set by default -drun v0.4.0-rocm620 python3 -m sglang.bench_one_batch --batch-size 32 --input 1024 --output 128 --model amd/Meta-Llama-3.1-8B-Instruct-FP8-KV --tp 8 --quantization fp8 +drun v0.4.0.post1-rocm620 python3 -m sglang.bench_one_batch --batch-size 32 --input 1024 --output 128 --model amd/Meta-Llama-3.1-8B-Instruct-FP8-KV --tp 8 --quantization fp8 ``` ## Method 4: Using docker compose diff --git a/python/pyproject.toml b/python/pyproject.toml index 1452fad4ab7..7a19ac649c7 100644 --- a/python/pyproject.toml +++ b/python/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "sglang" -version = "0.4.0" +version = "0.4.0.post1" description = "SGLang is yet another fast serving framework for large language models and vision language models." readme = "README.md" requires-python = ">=3.8" diff --git a/python/sglang/version.py b/python/sglang/version.py index 6a9beea82f6..a21caf9d324 100644 --- a/python/sglang/version.py +++ b/python/sglang/version.py @@ -1 +1 @@ -__version__ = "0.4.0" +__version__ = "0.4.0.post1" From 499c85f1318d5ad914a599050bd3f616a28007e0 Mon Sep 17 00:00:00 2001 From: Byron Hsu Date: Fri, 6 Dec 2024 11:26:07 -0800 Subject: [PATCH 35/60] [Router] remove duplicate char count (#2378) --- rust/py_test/test_launch_server.py | 2 + rust/src/server.rs | 4 +- rust/src/tree.rs | 84 +++++++++++++----------------- 3 files changed, 38 insertions(+), 52 deletions(-) diff --git a/rust/py_test/test_launch_server.py b/rust/py_test/test_launch_server.py index dcfe423466d..0dacc2c9f7d 100644 --- a/rust/py_test/test_launch_server.py +++ b/rust/py_test/test_launch_server.py @@ -45,6 +45,8 @@ def popen_launch_router( port, "--dp", str(dp_size), # Convert dp_size to string + "--router-eviction-interval", + "5", # frequent eviction for testing ] # Use current environment diff --git a/rust/src/server.rs b/rust/src/server.rs index 269214acfef..7197b9a2709 100644 --- a/rust/src/server.rs +++ b/rust/src/server.rs @@ -1,8 +1,6 @@ use crate::router::PolicyConfig; use crate::router::Router; -use actix_web::{ - delete, get, post, put, web, App, HttpRequest, HttpResponse, HttpServer, Responder, -}; +use actix_web::{get, post, web, App, HttpRequest, HttpResponse, HttpServer, Responder}; use bytes::Bytes; use env_logger::Builder; use log::{info, LevelFilter}; diff --git a/rust/src/tree.rs b/rust/src/tree.rs index 3c403676f01..1e39f02da38 100644 --- a/rust/src/tree.rs +++ b/rust/src/tree.rs @@ -24,7 +24,6 @@ struct Node { #[derive(Debug)] pub struct Tree { root: NodeRef, - // TODO: Char Count per tenant pub tenant_char_count: DashMap, } @@ -408,17 +407,9 @@ impl Tree { pub fn evict_tenant_data(&self, max_size: usize) { // Calculate used size and collect leaves let mut stack = vec![Arc::clone(&self.root)]; - let mut used_size_per_tenant: HashMap = HashMap::new(); let mut pq = BinaryHeap::new(); while let Some(curr) = stack.pop() { - for tenant in curr.tenant_last_access_time.iter() { - let size = used_size_per_tenant - .entry(tenant.key().clone()) - .or_insert(0); - *size += curr.text.read().unwrap().chars().count(); - } - for child in curr.children.iter() { stack.push(Arc::clone(child.value())); } @@ -436,64 +427,59 @@ impl Tree { } info!("Before eviction - Used size per tenant:"); - for (tenant, size) in &used_size_per_tenant { - info!("Tenant: {}, Size: {}", tenant, size); + for entry in self.tenant_char_count.iter() { + info!("Tenant: {}, Size: {}", entry.key(), entry.value()); } // Process eviction while let Some(Reverse(entry)) = pq.pop() { let EvictionEntry { tenant, node, .. } = entry; - if let Some(&used_size) = used_size_per_tenant.get(&tenant) { - if used_size <= max_size { + if let Some(used_size) = self.tenant_char_count.get(&tenant) { + if *used_size <= max_size { continue; } + } - // Update used size - if let Some(size) = used_size_per_tenant.get_mut(&tenant) { - *size -= node.text.read().unwrap().chars().count(); - } - - // Decrement when removing tenant from node - if node.tenant_last_access_time.contains_key(&tenant) { - self.tenant_char_count - .entry(tenant.clone()) - .and_modify(|count| { - if *count > 0 { - *count -= node.text.read().unwrap().chars().count(); - } - }); - } + // Decrement when removing tenant from node + if node.tenant_last_access_time.contains_key(&tenant) { + self.tenant_char_count + .entry(tenant.clone()) + .and_modify(|count| { + if *count > 0 { + *count -= node.text.read().unwrap().chars().count(); + } + }); + } - // Remove tenant from node - node.tenant_last_access_time.remove(&tenant); + // Remove tenant from node + node.tenant_last_access_time.remove(&tenant); - // Remove empty nodes - if node.children.is_empty() && node.tenant_last_access_time.is_empty() { - if let Some(parent) = node.parent.write().unwrap().as_ref() { - let first_char = node.text.read().unwrap().chars().next().unwrap(); - parent.children.remove(&first_char); - } + // Remove empty nodes + if node.children.is_empty() && node.tenant_last_access_time.is_empty() { + if let Some(parent) = node.parent.write().unwrap().as_ref() { + let first_char = node.text.read().unwrap().chars().next().unwrap(); + parent.children.remove(&first_char); } + } - // Add parent to queue if it becomes a leaf - if let Some(parent) = node.parent.read().unwrap().as_ref() { - if Tree::leaf_of(parent).contains(&tenant) { - if let Some(timestamp) = parent.tenant_last_access_time.get(&tenant) { - pq.push(Reverse(EvictionEntry { - timestamp: *timestamp, - tenant: tenant.clone(), - node: Arc::clone(parent), - })); - } + // Add parent to queue if it becomes a leaf + if let Some(parent) = node.parent.read().unwrap().as_ref() { + if Tree::leaf_of(parent).contains(&tenant) { + if let Some(timestamp) = parent.tenant_last_access_time.get(&tenant) { + pq.push(Reverse(EvictionEntry { + timestamp: *timestamp, + tenant: tenant.clone(), + node: Arc::clone(parent), + })); } } - } + }; } info!("After eviction - Used size per tenant:"); - for (tenant, size) in &used_size_per_tenant { - info!("Tenant: {}, Size: {}", tenant, size); + for entry in self.tenant_char_count.iter() { + info!("Tenant: {}, Size: {}", entry.key(), entry.value()); } } From 1bf9e34745e8056f9043065f4c485b4aa4d3864b Mon Sep 17 00:00:00 2001 From: Byron Hsu Date: Fri, 6 Dec 2024 11:53:15 -0800 Subject: [PATCH 36/60] [router] add remove tenant method in the radix tree (#2379) --- rust/src/router.rs | 2 +- rust/src/tree.rs | 142 ++++++++++++++++++++++++++++++++++++++++++--- 2 files changed, 136 insertions(+), 8 deletions(-) diff --git a/rust/src/router.rs b/rust/src/router.rs index 74e47209bd7..2b6b8d52cff 100644 --- a/rust/src/router.rs +++ b/rust/src/router.rs @@ -168,7 +168,7 @@ impl Router { let locked_tree_clone = tree_clone.lock().unwrap(); // Run eviction - locked_tree_clone.evict_tenant_data(max_tree_size); + locked_tree_clone.evict_tenant_by_size(max_tree_size); // Print the process queue let locked_processed_queue = processed_queue_clone.lock().unwrap(); diff --git a/rust/src/tree.rs b/rust/src/tree.rs index 1e39f02da38..e8dc8b7a0da 100644 --- a/rust/src/tree.rs +++ b/rust/src/tree.rs @@ -5,6 +5,7 @@ use log::info; use std::cmp::Reverse; use std::collections::BinaryHeap; use std::collections::HashMap; +use std::collections::VecDeque; use std::sync::Arc; use std::sync::RwLock; @@ -404,7 +405,7 @@ impl Tree { .collect() } - pub fn evict_tenant_data(&self, max_size: usize) { + pub fn evict_tenant_by_size(&self, max_size: usize) { // Calculate used size and collect leaves let mut stack = vec![Arc::clone(&self.root)]; let mut pq = BinaryHeap::new(); @@ -483,6 +484,46 @@ impl Tree { } } + pub fn remove_tenant(&self, tenant: &str) { + // 1. Find all the leaves for the tenant + let mut stack = vec![Arc::clone(&self.root)]; + let mut queue = VecDeque::new(); + + while let Some(curr) = stack.pop() { + for child in curr.children.iter() { + stack.push(Arc::clone(child.value())); + } + + if Tree::leaf_of(&curr).contains(&tenant.to_string()) { + queue.push_back(Arc::clone(&curr)); + } + } + + // 2. Start from the leaves and traverse up to the root, removing the tenant from each node + while let Some(curr) = queue.pop_front() { + // remove tenant from node + curr.tenant_last_access_time.remove(&tenant.to_string()); + + // remove empty nodes + if curr.children.is_empty() && curr.tenant_last_access_time.is_empty() { + if let Some(parent) = curr.parent.read().unwrap().as_ref() { + let first_char = curr.text.read().unwrap().chars().next().unwrap(); + parent.children.remove(&first_char); + } + } + + // add parent to queue if it becomes a leaf + if let Some(parent) = curr.parent.read().unwrap().as_ref() { + if Tree::leaf_of(parent).contains(&tenant.to_string()) { + queue.push_back(Arc::clone(&parent)); + } + } + } + + // 3. Remove the tenant from the tenant_char_count map + self.tenant_char_count.remove(&tenant.to_string()); + } + pub fn get_tenant_char_count(&self) -> HashMap { self.tenant_char_count .iter() @@ -673,7 +714,7 @@ mod tests { ); // Test eviction - tree.evict_tenant_data(3); // This should evict tenants with more than 3 chars + tree.evict_tenant_by_size(3); // This should evict tenants with more than 3 chars let post_eviction_smallest = tree.get_smallest_tenant(); println!("Smallest tenant after eviction: {}", post_eviction_smallest); @@ -754,7 +795,7 @@ mod tests { ); // Phase 4: Eviction test - tree.evict_tenant_data(10); + tree.evict_tenant_by_size(10); let computed_sizes = tree.get_used_size_per_tenant(); let maintained_counts: HashMap = tree @@ -1132,7 +1173,7 @@ mod tests { assert_eq!(sizes_before.get("tenant2").unwrap(), &10); // "hello" + "world" = 10 // Evict - should remove "hello" from tenant2 as it's the oldest - tree.evict_tenant_data(max_size); + tree.evict_tenant_by_size(max_size); tree.pretty_print(); @@ -1168,7 +1209,7 @@ mod tests { } // Perform eviction - tree.evict_tenant_data(max_size); + tree.evict_tenant_by_size(max_size); // Check sizes after eviction let sizes_after = tree.get_used_size_per_tenant(); @@ -1200,7 +1241,7 @@ mod tests { let handle = thread::spawn(move || { while start_time.elapsed() < test_duration { // Run eviction - tree.evict_tenant_data(max_size); + tree.evict_tenant_by_size(max_size); // Sleep for 5 seconds thread::sleep(Duration::from_secs(5)); @@ -1245,7 +1286,7 @@ mod tests { } // final eviction - tree.evict_tenant_data(max_size); + tree.evict_tenant_by_size(max_size); // Final size check let final_sizes = tree.get_used_size_per_tenant(); @@ -1352,4 +1393,91 @@ mod tests { assert_eq!(tree.prefix_match_tenant("hello", "tenant3"), ""); // Non-existent tenant assert_eq!(tree.prefix_match_tenant("help", "tenant3"), ""); // Non-existent tenant } + + #[test] + fn test_simple_tenant_eviction() { + let tree = Tree::new(); + + // Insert data for multiple tenants + tree.insert("hello", "tenant1"); + tree.insert("world", "tenant1"); + tree.insert("hello", "tenant2"); + tree.insert("help", "tenant2"); + + // Verify initial state + let initial_sizes = tree.get_used_size_per_tenant(); + assert_eq!(initial_sizes.get("tenant1").unwrap(), &10); // "hello" + "world" + assert_eq!(initial_sizes.get("tenant2").unwrap(), &6); // "hello" + "p" + + // Evict tenant1 + tree.remove_tenant("tenant1"); + + // Verify after eviction + let final_sizes = tree.get_used_size_per_tenant(); + assert!( + !final_sizes.contains_key("tenant1"), + "tenant1 should be completely removed" + ); + assert_eq!( + final_sizes.get("tenant2").unwrap(), + &6, + "tenant2 should be unaffected" + ); + + // Verify tenant1's data is inaccessible + assert_eq!(tree.prefix_match_tenant("hello", "tenant1"), ""); + assert_eq!(tree.prefix_match_tenant("world", "tenant1"), ""); + + // Verify tenant2's data is still accessible + assert_eq!(tree.prefix_match_tenant("hello", "tenant2"), "hello"); + assert_eq!(tree.prefix_match_tenant("help", "tenant2"), "help"); + } + + #[test] + fn test_complex_tenant_eviction() { + let tree = Tree::new(); + + // Create a more complex tree structure with shared prefixes + tree.insert("apple", "tenant1"); + tree.insert("application", "tenant1"); + tree.insert("apple", "tenant2"); + tree.insert("appetite", "tenant2"); + tree.insert("banana", "tenant1"); + tree.insert("banana", "tenant2"); + tree.insert("ball", "tenant2"); + + // Verify initial state + let initial_sizes = tree.get_used_size_per_tenant(); + println!("Initial sizes: {:?}", initial_sizes); + tree.pretty_print(); + + // Evict tenant1 + tree.remove_tenant("tenant1"); + + // Verify final state + let final_sizes = tree.get_used_size_per_tenant(); + println!("Final sizes: {:?}", final_sizes); + tree.pretty_print(); + + // Verify tenant1 is completely removed + assert!( + !final_sizes.contains_key("tenant1"), + "tenant1 should be completely removed" + ); + + // Verify all tenant1's data is inaccessible + assert_eq!(tree.prefix_match_tenant("apple", "tenant1"), ""); + assert_eq!(tree.prefix_match_tenant("application", "tenant1"), ""); + assert_eq!(tree.prefix_match_tenant("banana", "tenant1"), ""); + + // Verify tenant2's data is intact + assert_eq!(tree.prefix_match_tenant("apple", "tenant2"), "apple"); + assert_eq!(tree.prefix_match_tenant("appetite", "tenant2"), "appetite"); + assert_eq!(tree.prefix_match_tenant("banana", "tenant2"), "banana"); + assert_eq!(tree.prefix_match_tenant("ball", "tenant2"), "ball"); + + // Verify the tree structure is still valid for tenant2 + let tenant2_size = final_sizes.get("tenant2").unwrap(); + assert_eq!(tenant2_size, &(5 + 5 + 6 + 2)); // "apple" + "etite" + "banana" + "ll" + } } From c36736c841f735aa3a03bfa0db52c9d603c5fb49 Mon Sep 17 00:00:00 2001 From: Byron Hsu Date: Fri, 6 Dec 2024 17:16:03 -0800 Subject: [PATCH 37/60] [router] Add remove worker api (#2380) --- .github/workflows/pr-test-rust.yml | 2 +- rust/py_test/test_launch_server.py | 42 +++++++++++++++++++++++------- rust/src/router.rs | 19 ++++++++++++++ rust/src/server.rs | 14 ++++++++++ 4 files changed, 67 insertions(+), 10 deletions(-) diff --git a/.github/workflows/pr-test-rust.yml b/.github/workflows/pr-test-rust.yml index 0df81b487b5..b9e8c5bcb6b 100644 --- a/.github/workflows/pr-test-rust.yml +++ b/.github/workflows/pr-test-rust.yml @@ -57,7 +57,7 @@ jobs: cd rust pip install setuptools-rust wheel build python3 -m build - pip install dist/*.whl + pip install --force-reinstall dist/*.whl - name: Run e2e test run: | cd rust/py_test diff --git a/rust/py_test/test_launch_server.py b/rust/py_test/test_launch_server.py index 0dacc2c9f7d..b3f82988354 100644 --- a/rust/py_test/test_launch_server.py +++ b/rust/py_test/test_launch_server.py @@ -114,17 +114,12 @@ def popen_launch_server( raise TimeoutError("Server failed to start within the timeout period.") -class TestEvalAccuracyMini(unittest.TestCase): +class TestLaunchServer(unittest.TestCase): @classmethod def setUpClass(cls): cls.model = DEFAULT_MODEL_NAME_FOR_TEST cls.base_url = DEFAULT_URL_FOR_TEST - cls.process = popen_launch_router( - cls.model, - cls.base_url, - dp_size=1, - timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, - ) + cls.process = None cls.other_process = [] @classmethod @@ -134,6 +129,14 @@ def tearDownClass(cls): kill_process_tree(process.pid) def test_mmlu(self): + # DP size = 2 + TestLaunchServer.process = popen_launch_router( + self.model, + self.base_url, + dp_size=2, + timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + ) + args = SimpleNamespace( base_url=self.base_url, model=self.model, @@ -150,14 +153,21 @@ def test_mmlu(self): msg = f"MMLU test {'passed' if passed else 'failed'} with score {score:.3f} (threshold: {THRESHOLD})" self.assertGreaterEqual(score, THRESHOLD, msg) - def test_add_worker(self): + def test_add_and_remove_worker(self): + # DP size = 1 + TestLaunchServer.process = popen_launch_router( + self.model, + self.base_url, + dp_size=1, + timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + ) # 1. start a worker, and wait until it is healthy port = find_available_port() worker_url = f"http://127.0.0.1:{port}" worker_process = popen_launch_server( self.model, worker_url, DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH ) - self.other_process.append(worker_process) + TestLaunchServer.other_process.append(worker_process) # 2. use /add_worker api to add it the the router with requests.Session() as session: response = session.post(f"{self.base_url}/add_worker?url={worker_url}") @@ -179,6 +189,20 @@ def test_add_worker(self): msg = f"MMLU test {'passed' if passed else 'failed'} with score {score:.3f} (threshold: {THRESHOLD})" self.assertGreaterEqual(score, THRESHOLD, msg) + # 4. use /remove_worker api to remove it from the router + with requests.Session() as session: + response = session.post(f"{self.base_url}/remove_worker?url={worker_url}") + print(f"status code: {response.status_code}, response: {response.text}") + self.assertEqual(response.status_code, 200) + + # 5. run mmlu again + metrics = run_eval(args) + score = metrics["score"] + THRESHOLD = 0.65 + passed = score >= THRESHOLD + msg = f"MMLU test {'passed' if passed else 'failed'} with score {score:.3f} (threshold: {THRESHOLD})" + self.assertGreaterEqual(score, THRESHOLD, msg) + if __name__ == "__main__": unittest.main() diff --git a/rust/src/router.rs b/rust/src/router.rs index 2b6b8d52cff..5641fccbc74 100644 --- a/rust/src/router.rs +++ b/rust/src/router.rs @@ -396,4 +396,23 @@ impl Router { } } } + + pub fn remove_worker(&self, worker_url: String) { + match self { + Router::RoundRobin { worker_urls, .. } + | Router::Random { worker_urls } + | Router::CacheAware { worker_urls, .. } => { + let mut urls = worker_urls.write().unwrap(); + let index = urls.iter().position(|url| url == &worker_url).unwrap(); + urls.remove(index); + info!("Removed worker: {}", worker_url); + } + } + + // if cache aware, remove the worker from the tree + if let Router::CacheAware { tree, .. } = self { + tree.lock().unwrap().remove_tenant(&worker_url); + info!("Removed worker from tree: {}", worker_url); + } + } } diff --git a/rust/src/server.rs b/rust/src/server.rs index 7197b9a2709..d8d2e38e945 100644 --- a/rust/src/server.rs +++ b/rust/src/server.rs @@ -145,6 +145,19 @@ async fn add_worker( HttpResponse::Ok().finish() } +#[post("/remove_worker")] +async fn remove_worker( + query: web::Query>, + data: web::Data, +) -> impl Responder { + let worker_url = match query.get("url") { + Some(url) => url.to_string(), + None => return HttpResponse::BadRequest().finish(), + }; + data.router.remove_worker(worker_url); + HttpResponse::Ok().finish() +} + pub struct ServerConfig { pub host: String, pub port: u16, @@ -201,6 +214,7 @@ pub async fn startup(config: ServerConfig) -> std::io::Result<()> { .service(health_generate) .service(get_server_info) .service(add_worker) + .service(remove_worker) }) .bind((config.host, config.port))? .run() From d332aa3b0c0ac131df4724084fc167f852611503 Mon Sep 17 00:00:00 2001 From: Yineng Zhang Date: Sat, 7 Dec 2024 19:28:53 +0800 Subject: [PATCH 38/60] fix: resolve fp8 moe issue (#2387) --- .../srt/layers/quantization/__init__.py | 49 +------------------ python/sglang/srt/layers/quantization/fp8.py | 34 +++++++++---- 2 files changed, 27 insertions(+), 56 deletions(-) diff --git a/python/sglang/srt/layers/quantization/__init__.py b/python/sglang/srt/layers/quantization/__init__.py index 3e2078c4a4d..48b733fdb78 100644 --- a/python/sglang/srt/layers/quantization/__init__.py +++ b/python/sglang/srt/layers/quantization/__init__.py @@ -22,7 +22,7 @@ from vllm.model_executor.layers.quantization.tpu_int8 import Int8TpuConfig from sglang.srt.layers.quantization.base_config import QuantizationConfig -from sglang.srt.layers.quantization.fp8 import Fp8Config, Fp8MoEMethod +from sglang.srt.layers.quantization.fp8 import Fp8Config QUANTIZATION_METHODS: Dict[str, Type[QuantizationConfig]] = { "aqlm": AQLMConfig, @@ -53,50 +53,6 @@ def get_quantization_config(quantization: str) -> Type[QuantizationConfig]: return QUANTIZATION_METHODS[quantization] -def fp8_moe_apply( - self, - layer: torch.nn.Module, - x: torch.Tensor, - router_logits: torch.Tensor, - top_k: int, - renormalize: bool, - use_grouped_topk: bool, - topk_group: Optional[int] = None, - num_expert_group: Optional[int] = None, - custom_routing_function: Optional[Callable] = None, -) -> torch.Tensor: - """Enhanced apply method for FP8 MoE.""" - from sglang.srt.layers.fused_moe_triton import FusedMoE - from sglang.srt.layers.fused_moe_triton.fused_moe import fused_experts - - # Expert selection - topk_weights, topk_ids = FusedMoE.select_experts( - hidden_states=x, - router_logits=router_logits, - use_grouped_topk=use_grouped_topk, - top_k=top_k, - renormalize=renormalize, - topk_group=topk_group, - num_expert_group=num_expert_group, - custom_routing_function=custom_routing_function, - ) - - # Expert fusion with FP8 quantization - return fused_experts( - x, - layer.w13_weight, - layer.w2_weight, - topk_weights=topk_weights, - topk_ids=topk_ids, - inplace=True, - use_fp8_w8a8=True, - w1_scale=layer.w13_weight_scale, - w2_scale=layer.w2_weight_scale, - a1_scale=layer.w13_input_scale, - a2_scale=layer.w2_input_scale, - ) - - def fp8_get_quant_method(self, layer, prefix): """Enhanced get_quant_method for FP8 config.""" from vllm.model_executor.layers.linear import LinearBase @@ -106,7 +62,7 @@ def fp8_get_quant_method(self, layer, prefix): from sglang.srt.layers.fused_moe_triton.layer import FusedMoE from sglang.srt.layers.linear import UnquantizedLinearMethod - from sglang.srt.layers.quantization.fp8 import Fp8LinearMethod + from sglang.srt.layers.quantization.fp8 import Fp8LinearMethod, Fp8MoEMethod if isinstance(layer, LinearBase): if is_layer_skipped(prefix, self.ignored_layers): @@ -151,7 +107,6 @@ def awq_get_quant_method(self, layer, prefix): def apply_monkey_patches(): """Apply all monkey patches in one place.""" - setattr(Fp8MoEMethod, "apply", fp8_moe_apply) setattr(Fp8Config, "get_quant_method", fp8_get_quant_method) setattr(GPTQMarlinConfig, "get_quant_method", gptq_get_quant_method) setattr(AWQMarlinConfig, "get_quant_method", awq_get_quant_method) diff --git a/python/sglang/srt/layers/quantization/fp8.py b/python/sglang/srt/layers/quantization/fp8.py index acdce0b8cbd..0e3c7abd924 100644 --- a/python/sglang/srt/layers/quantization/fp8.py +++ b/python/sglang/srt/layers/quantization/fp8.py @@ -24,11 +24,6 @@ ) from vllm.model_executor.parameter import ModelWeightParameter, PerTensorScaleParameter -from sglang.srt.layers.fused_moe_triton import ( - FusedMoE, - FusedMoEMethodBase, - FusedMoeWeightScaleSupported, -) from sglang.srt.layers.linear import LinearMethodBase, UnquantizedLinearMethod from sglang.srt.layers.quantization.base_config import ( QuantizationConfig, @@ -100,6 +95,8 @@ def get_quant_method( ) -> Optional["QuantizeMethodBase"]: from vllm.attention.layer import Attention # Avoid circular import + from sglang.srt.layers.fused_moe_triton import FusedMoE + if isinstance(layer, LinearBase): if is_layer_skipped(prefix, self.ignored_layers): return UnquantizedLinearMethod() @@ -306,7 +303,7 @@ def apply( ) -class Fp8MoEMethod(FusedMoEMethodBase): +class Fp8MoEMethod: """MoE method for FP8. Supports loading FP8 checkpoints with static weight scale and dynamic/static activation scale. @@ -319,7 +316,25 @@ class Fp8MoEMethod(FusedMoEMethodBase): quant_config: The quantization config. """ - def __init__(self, quant_config: Fp8Config): + def __new__(cls, *args, **kwargs): + from sglang.srt.layers.fused_moe_triton import FusedMoEMethodBase + + if not hasattr(cls, "_initialized"): + original_init = cls.__init__ + new_cls = type( + cls.__name__, + (FusedMoEMethodBase,), + { + "__init__": original_init, + **{k: v for k, v in cls.__dict__.items() if k != "__dict__"}, + }, + ) + obj = super(new_cls, new_cls).__new__(new_cls) + obj.__init__(*args, **kwargs) + return obj + return super().__new__(cls) + + def __init__(self, quant_config): self.quant_config = quant_config def create_weights( @@ -331,6 +346,7 @@ def create_weights( params_dtype: torch.dtype, **extra_weight_attrs, ): + from sglang.srt.layers.fused_moe_triton import FusedMoeWeightScaleSupported if self.quant_config.is_checkpoint_fp8_serialized: params_dtype = torch.float8_e4m3fn @@ -521,8 +537,8 @@ def apply( num_expert_group: Optional[int] = None, custom_routing_function: Optional[Callable] = None, ) -> torch.Tensor: - - from vllm.model_executor.layers.fused_moe import fused_experts + from sglang.srt.layers.fused_moe_triton import FusedMoE + from sglang.srt.layers.fused_moe_triton.fused_moe import fused_experts topk_weights, topk_ids = FusedMoE.select_experts( hidden_states=x, From aaac33fd8dbc5f11790298d9d1ef325da487f3e4 Mon Sep 17 00:00:00 2001 From: Yineng Zhang Date: Sat, 7 Dec 2024 21:09:16 +0800 Subject: [PATCH 39/60] fix: update xgrammar v0.1.6 (#2390) --- python/pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/pyproject.toml b/python/pyproject.toml index 7a19ac649c7..186405dd7a3 100644 --- a/python/pyproject.toml +++ b/python/pyproject.toml @@ -22,7 +22,7 @@ runtime_common = ["aiohttp", "decord", "fastapi", "packaging", "pillow", "prometheus-client>=0.20.0", "psutil", "pydantic", "python-multipart", "pyzmq>=25.1.2", "torchao", "uvicorn", "uvloop", - "xgrammar>=0.1.4"] + "xgrammar>=0.1.6"] srt = ["sglang[runtime_common]", "torch", "vllm>=0.6.3.post1,<=0.6.4.post1", "cuda-python", "flashinfer>=0.1.6"] # HIP (Heterogeneous-computing Interface for Portability) for AMD From 95f93f493a60a4dfdb30aa3d24ba3fc3b8666d3e Mon Sep 17 00:00:00 2001 From: HAI Date: Sat, 7 Dec 2024 05:18:26 -0800 Subject: [PATCH 40/60] Fp8 MoE optimizations on AMD (#2388) --- .../srt/layers/fused_moe_triton/fused_moe.py | 85 ++++++++++++++----- python/sglang/srt/layers/quantization/fp8.py | 34 +++++++- 2 files changed, 97 insertions(+), 22 deletions(-) diff --git a/python/sglang/srt/layers/fused_moe_triton/fused_moe.py b/python/sglang/srt/layers/fused_moe_triton/fused_moe.py index 4f92512b2d5..e6ce9cb4d39 100644 --- a/python/sglang/srt/layers/fused_moe_triton/fused_moe.py +++ b/python/sglang/srt/layers/fused_moe_triton/fused_moe.py @@ -16,6 +16,7 @@ from sglang.srt.utils import direct_register_custom_op, get_device_name logger = logging.getLogger(__name__) +padding_size = 128 if bool(int(os.getenv("MOE_PADDING", "0"))) else 0 @triton.jit @@ -58,6 +59,7 @@ def fused_moe_kernel( compute_type: tl.constexpr, use_fp8_w8a8: tl.constexpr, use_int8_w8a16: tl.constexpr, + even_Ks: tl.constexpr, ): """ Implements the fused computation for a Mixture of Experts (MOE) using @@ -143,12 +145,21 @@ def fused_moe_kernel( for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): # Load the next block of A and B, generate a mask by checking the # K dimension. - a = tl.load( - a_ptrs, - mask=token_mask[:, None] & (offs_k[None, :] < K - k * BLOCK_SIZE_K), - other=0.0, - ) - b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0) + if even_Ks: + a = tl.load( + a_ptrs, + mask=token_mask[:, None], + other=0.0, + ) + b = tl.load(b_ptrs) + else: + a = tl.load( + a_ptrs, + mask=token_mask[:, None] & (offs_k[None, :] < K - k * BLOCK_SIZE_K), + other=0.0, + ) + b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0) + # We accumulate along the K dimension. if use_int8_w8a16: accumulator = tl.dot(a, b.to(compute_type), acc=accumulator) @@ -254,7 +265,9 @@ def invoke_fused_moe_kernel( assert topk_weights.stride(1) == 1 assert sorted_token_ids.stride(0) == 1 + padded_size = 0 if use_fp8_w8a8: + padded_size = padding_size A, A_scale = ops.scaled_fp8_quant(A, A_scale) assert B_scale is not None elif use_int8_w8a16: @@ -268,6 +281,12 @@ def invoke_fused_moe_kernel( * triton.cdiv(B.shape[1], META["BLOCK_SIZE_N"]), ) + K = B.shape[2] - padded_size + if K % config["BLOCK_SIZE_K"] == 0: + even_Ks = True + else: + even_Ks = False + fused_moe_kernel[grid]( A, B, @@ -279,7 +298,7 @@ def invoke_fused_moe_kernel( expert_ids, num_tokens_post_padded, B.shape[1], - B.shape[2], + B.shape[2] - padded_size, sorted_token_ids.shape[0], topk_ids.numel(), A.stride(0), @@ -296,6 +315,7 @@ def invoke_fused_moe_kernel( compute_type=compute_type, use_fp8_w8a8=use_fp8_w8a8, use_int8_w8a16=use_int8_w8a16, + even_Ks=even_Ks, **config, ) @@ -351,20 +371,39 @@ def get_default_config( dtype: Optional[str], is_marlin: bool, ) -> Dict[str, int]: - config = { - "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 64, - "BLOCK_SIZE_K": 32, - "GROUP_SIZE_M": 8, - } - # A heuristic: fused marlin works faster with this config for small M - if M <= E or (is_marlin and M <= 32): + if dtype == "fp8_w8a8": config = { - "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 32, - "BLOCK_SIZE_K": 64, - "GROUP_SIZE_M": 1, + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 4, } + if M <= E: + config = { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4, + } + else: + config = { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 8, + } + # A heuristic: fused marlin works faster with this config for small M + if M <= E or (is_marlin and M <= 32): + config = { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + } return config @@ -645,8 +684,12 @@ def fused_experts_impl( a1_scale: Optional[torch.Tensor] = None, a2_scale: Optional[torch.Tensor] = None, ): + padded_size = padding_size + if not use_fp8_w8a8: + padded_size = 0 + # Check constraints. - assert hidden_states.shape[1] == w1.shape[2], "Hidden size mismatch" + assert hidden_states.shape[1] == w1.shape[2] - padded_size, "Hidden size mismatch" assert topk_weights.shape == topk_ids.shape, "topk shape mismatch" assert hidden_states.is_contiguous(), "Hidden_states must be contiguous" assert w1.is_contiguous(), "Expert weights1 must be contiguous" @@ -668,7 +711,7 @@ def fused_experts_impl( get_config_func = functools.partial( try_get_optimal_moe_config, w1.shape, - w2.shape, + (w2.shape[0], w2.shape[1], w2.shape[2] - padded_size), topk_ids.shape[1], config_dtype, ) diff --git a/python/sglang/srt/layers/quantization/fp8.py b/python/sglang/srt/layers/quantization/fp8.py index 0e3c7abd924..c5a254b547e 100644 --- a/python/sglang/srt/layers/quantization/fp8.py +++ b/python/sglang/srt/layers/quantization/fp8.py @@ -1,9 +1,11 @@ # Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/model_executor/layers/quantization/fp8.py import logging +import os from typing import Any, Callable, Dict, List, Optional import torch +import torch.nn.functional as F from torch.nn import Module from torch.nn.parameter import Parameter from vllm import _custom_ops as ops @@ -24,6 +26,7 @@ ) from vllm.model_executor.parameter import ModelWeightParameter, PerTensorScaleParameter +from sglang.srt.layers.fused_moe_triton.fused_moe import padding_size from sglang.srt.layers.linear import LinearMethodBase, UnquantizedLinearMethod from sglang.srt.layers.quantization.base_config import ( QuantizationConfig, @@ -420,7 +423,7 @@ def create_weights( def process_weights_after_loading(self, layer: Module) -> None: - # If checkpoint is fp16, quantize in place. + # If checkpoint is fp16 or bfloat16, quantize in place. if not self.quant_config.is_checkpoint_fp8_serialized: # If ROCm, use float8_e4m3fnuz instead (MI300x HW) fp8_dtype = torch.float8_e4m3fnuz if is_hip() else torch.float8_e4m3fn @@ -444,6 +447,19 @@ def process_weights_after_loading(self, layer: Module) -> None: ) layer.w13_weight = torch.nn.Parameter(w13_weight, requires_grad=False) layer.w2_weight = torch.nn.Parameter(w2_weight, requires_grad=False) + + # If ROCm, apply weight padding (min. Mem channel contention) only if set + if is_hip() and bool(int(os.getenv("MOE_PADDING", "0"))): + layer.w13_weight = torch.nn.Parameter( + F.pad(layer.w13_weight.data, (0, padding_size), "constant", 0), + requires_grad=False, + ) + torch.cuda.empty_cache() + layer.w2_weight = torch.nn.Parameter( + F.pad(layer.w2_weight.data, (0, padding_size), "constant", 0), + requires_grad=False, + ) + torch.cuda.empty_cache() return # If checkpoint is fp8, we need to handle that the @@ -472,6 +488,7 @@ def process_weights_after_loading(self, layer: Module) -> None: layer.w2_input_scale = torch.nn.Parameter( layer.w2_input_scale.max(), requires_grad=False ) + # If ROCm, normalize the weights and scales to e4m3fnuz if is_hip(): # Normalize the weights and scales @@ -523,6 +540,19 @@ def process_weights_after_loading(self, layer: Module) -> None: layer.w13_weight_scale = torch.nn.Parameter( max_w13_scales, requires_grad=False ) + + # If ROCm, apply weight padding (min. Mem channel contention) only if set + if is_hip() and bool(int(os.getenv("MOE_PADDING", "0"))): + layer.w13_weight = torch.nn.Parameter( + F.pad(layer.w13_weight.data, (0, padding_size), "constant", 0), + requires_grad=False, + ) + torch.cuda.empty_cache() + layer.w2_weight = torch.nn.Parameter( + F.pad(layer.w2_weight.data, (0, padding_size), "constant", 0), + requires_grad=False, + ) + torch.cuda.empty_cache() return def apply( @@ -540,6 +570,7 @@ def apply( from sglang.srt.layers.fused_moe_triton import FusedMoE from sglang.srt.layers.fused_moe_triton.fused_moe import fused_experts + # Expert selection topk_weights, topk_ids = FusedMoE.select_experts( hidden_states=x, router_logits=router_logits, @@ -551,6 +582,7 @@ def apply( custom_routing_function=custom_routing_function, ) + # Expert fusion with FP8 quantization return fused_experts( x, layer.w13_weight, From 75ae968959566da691a0bde8e6f96f463ce531b3 Mon Sep 17 00:00:00 2001 From: Yineng Zhang Date: Sun, 8 Dec 2024 04:21:00 +0800 Subject: [PATCH 41/60] minor: update killall script (#2391) --- python/sglang/utils.py | 2 +- scripts/killall_sglang.sh | 17 +++++++++++++---- 2 files changed, 14 insertions(+), 5 deletions(-) diff --git a/python/sglang/utils.py b/python/sglang/utils.py index c1bf62ef983..5689b097d23 100644 --- a/python/sglang/utils.py +++ b/python/sglang/utils.py @@ -1,4 +1,4 @@ -"""Common utilities.""" +"""Common utilities""" import base64 import gc diff --git a/scripts/killall_sglang.sh b/scripts/killall_sglang.sh index fcad493c59c..cb187f46a0f 100755 --- a/scripts/killall_sglang.sh +++ b/scripts/killall_sglang.sh @@ -1,5 +1,14 @@ -# Kill all SGLang processes and free the GPU memory. +#!/bin/bash -kill -9 $(ps aux | grep 'multiprocessing.spawn' | grep -v 'grep' | awk '{print $2}') -kill -9 $(ps aux | grep 'sglang.launch_server' | grep -v 'grep' | awk '{print $2}') -kill -9 $(ps aux | grep 'sglang.bench' | grep -v 'grep' | awk '{print $2}') +# Show current GPU status +nvidia-smi + +# Clean SGLang processes +kill -9 $(ps aux | grep 'multiprocessing.spawn' | grep -v 'grep' | awk '{print $2}') 2>/dev/null +kill -9 $(ps aux | grep 'sglang.launch_server' | grep -v 'grep' | awk '{print $2}') 2>/dev/null +kill -9 $(ps aux | grep 'sglang.bench' | grep -v 'grep' | awk '{print $2}') 2>/dev/null + +# Clean all GPU processes if any argument is provided +if [ $# -gt 0 ]; then + kill -9 $(nvidia-smi | sed -n '/Processes:/,$p' | grep " [0-9]" | awk '{print $5}') 2>/dev/null +fi From ef995dae1e9e7cdc7cfa7d78a195a3943d7e3e6b Mon Sep 17 00:00:00 2001 From: Byron Hsu Date: Sat, 7 Dec 2024 15:39:54 -0800 Subject: [PATCH 42/60] [router] Health check on worker before adding to the router (#2392) --- rust/Cargo.lock | 7 +-- rust/Cargo.toml | 1 + rust/py_test/test_launch_server.py | 28 +++++------- rust/src/router.rs | 71 ++++++++++++++++++++++++++---- rust/src/server.rs | 3 +- 5 files changed, 79 insertions(+), 31 deletions(-) diff --git a/rust/Cargo.lock b/rust/Cargo.lock index 37c2733fdc0..8e7f306589f 100644 --- a/rust/Cargo.lock +++ b/rust/Cargo.lock @@ -1,6 +1,6 @@ # This file is automatically @generated by Cargo. # It is not intended for manual editing. -version = 3 +version = 4 [[package]] name = "actix-codec" @@ -2219,6 +2219,7 @@ dependencies = [ "serde", "serde_json", "tokenizers", + "tokio", ] [[package]] @@ -2475,9 +2476,9 @@ dependencies = [ [[package]] name = "tokio" -version = "1.41.0" +version = "1.42.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "145f3413504347a2be84393cc8a7d2fb4d863b375909ea59f2158261aa258bbb" +checksum = "5cec9b21b0450273377fc97bd4c33a8acffc8c996c987a7c5b319a0083707551" dependencies = [ "backtrace", "bytes", diff --git a/rust/Cargo.toml b/rust/Cargo.toml index 5ac77665bcc..d49af81cf56 100644 --- a/rust/Cargo.toml +++ b/rust/Cargo.toml @@ -29,6 +29,7 @@ http = "1.1.0" env_logger = "0.11.5" log = "0.4.22" chrono = "0.4.38" +tokio = "1.42.0" [profile.release] lto = "thin" diff --git a/rust/py_test/test_launch_server.py b/rust/py_test/test_launch_server.py index b3f82988354..68945d8fb52 100644 --- a/rust/py_test/test_launch_server.py +++ b/rust/py_test/test_launch_server.py @@ -20,6 +20,7 @@ def popen_launch_router( base_url: str, dp_size: int, timeout: float, + policy: str = "cache_aware", ): """ Launch the router server process. @@ -29,6 +30,7 @@ def popen_launch_router( base_url: Server base URL dp_size: Data parallel size timeout: Server launch timeout + policy: Router policy, one of "cache_aware", "round_robin", "random" """ _, host, port = base_url.split(":") host = host[2:] @@ -47,11 +49,10 @@ def popen_launch_router( str(dp_size), # Convert dp_size to string "--router-eviction-interval", "5", # frequent eviction for testing + "--router-policy", + policy, ] - # Use current environment - env = None - process = subprocess.Popen(command, stdout=None, stderr=None) start_time = time.time() @@ -99,19 +100,8 @@ def popen_launch_server( process = subprocess.Popen(command, stdout=None, stderr=None) - start_time = time.time() - with requests.Session() as session: - while time.time() - start_time < timeout: - try: - response = session.get(f"{base_url}/health") - if response.status_code == 200: - print(f"Server {base_url} is healthy") - return process - except requests.RequestException: - pass - time.sleep(10) - - raise TimeoutError("Server failed to start within the timeout period.") + # intentionally don't wait and defer the job to the router health check + return process class TestLaunchServer(unittest.TestCase): @@ -135,6 +125,7 @@ def test_mmlu(self): self.base_url, dp_size=2, timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + policy="cache_aware", ) args = SimpleNamespace( @@ -160,6 +151,7 @@ def test_add_and_remove_worker(self): self.base_url, dp_size=1, timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + policy="round_robin", # use round robin to make sure every worker processes requests ) # 1. start a worker, and wait until it is healthy port = find_available_port() @@ -168,11 +160,13 @@ def test_add_and_remove_worker(self): self.model, worker_url, DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH ) TestLaunchServer.other_process.append(worker_process) - # 2. use /add_worker api to add it the the router + + # 2. use /add_worker api to add it the the router. It will be used by router after it is healthy with requests.Session() as session: response = session.post(f"{self.base_url}/add_worker?url={worker_url}") print(f"status code: {response.status_code}, response: {response.text}") self.assertEqual(response.status_code, 200) + # 3. run mmlu args = SimpleNamespace( base_url=self.base_url, diff --git a/rust/src/router.rs b/rust/src/router.rs index 5641fccbc74..acba974972c 100644 --- a/rust/src/router.rs +++ b/rust/src/router.rs @@ -3,13 +3,14 @@ use actix_web::http::header::{HeaderValue, CONTENT_TYPE}; use actix_web::{HttpRequest, HttpResponse}; use bytes::Bytes; use futures_util::{StreamExt, TryStreamExt}; -use log::{debug, info}; +use log::{debug, info, warn}; use std::collections::HashMap; use std::fmt::Debug; use std::sync::atomic::AtomicUsize; use std::sync::{Arc, Mutex, RwLock}; use std::thread; use std::time::Duration; +use tokio; #[derive(Debug)] pub enum Router { @@ -385,14 +386,66 @@ impl Router { } } - pub fn add_worker(&self, worker_url: String) { - match self { - Router::RoundRobin { worker_urls, .. } - | Router::Random { worker_urls } - | Router::CacheAware { worker_urls, .. } => { - let mut urls = worker_urls.write().unwrap(); - info!("Added worker: {}", worker_url); - urls.push(worker_url); + pub async fn add_worker(&self, worker_url: String) -> HttpResponse { + let interval_secs = 10; // check every 10 seconds + let timeout_secs = 300; // 5 minutes + + let start_time = std::time::Instant::now(); + let client = reqwest::Client::new(); + + loop { + if start_time.elapsed() > Duration::from_secs(timeout_secs) { + return HttpResponse::InternalServerError().body(format!( + "Timeout {}s waiting for worker {} to become healthy", + timeout_secs, worker_url + )); + } + + match client.get(&format!("{}/health", worker_url)).send().await { + Ok(res) => { + if res.status().is_success() { + match self { + Router::RoundRobin { worker_urls, .. } + | Router::Random { worker_urls } + | Router::CacheAware { worker_urls, .. } => { + info!("Worker {} health check passed", worker_url); + let mut urls = worker_urls.write().unwrap(); + if urls.contains(&worker_url) { + return HttpResponse::BadRequest() + .body(format!("Worker {} already exists", worker_url)); + } + info!("Added worker: {}", worker_url); + urls.push(worker_url.clone()); + } + } + return HttpResponse::Ok() + .body(format!("Successfully added worker: {}", worker_url)); + } else { + info!( + "Worker {} health check failed with status: {}. The worker might still be starting up.", + worker_url, res.status() + ); + // if the url does not have http or https prefix, warn users + if !worker_url.starts_with("http://") && !worker_url.starts_with("https://") + { + warn!("The worker url {} does not have http or https prefix. Please add the prefix to the url.", worker_url); + } + + tokio::time::sleep(Duration::from_secs(interval_secs)).await; + continue; + } + } + Err(e) => { + info!("Worker {} health check failed: {}", worker_url, e); + + // if the url does not have http or https prefix, warn users + if !worker_url.starts_with("http://") && !worker_url.starts_with("https://") { + warn!("The worker url {} does not have http or https prefix. Please add the prefix to the url.", worker_url); + } + + tokio::time::sleep(Duration::from_secs(interval_secs)).await; + continue; + } } } } diff --git a/rust/src/server.rs b/rust/src/server.rs index d8d2e38e945..d7ec6ebc6e5 100644 --- a/rust/src/server.rs +++ b/rust/src/server.rs @@ -141,8 +141,7 @@ async fn add_worker( .body("Worker URL required. Provide 'url' query parameter") } }; - data.router.add_worker(worker_url); - HttpResponse::Ok().finish() + data.router.add_worker(worker_url).await } #[post("/remove_worker")] From 63dfab1beada0c6800b6694bf28eb8eb85657615 Mon Sep 17 00:00:00 2001 From: "Sangchun Ha (Patrick)" Date: Sun, 8 Dec 2024 18:04:08 +0900 Subject: [PATCH 43/60] Fix shape error that occurred when loading lora weight of gemma2 model. (#2330) --- python/sglang/srt/models/gemma2.py | 34 ++++++++++++++++++++++++++++++ 1 file changed, 34 insertions(+) diff --git a/python/sglang/srt/models/gemma2.py b/python/sglang/srt/models/gemma2.py index dbca7268803..0c0e6155d35 100644 --- a/python/sglang/srt/models/gemma2.py +++ b/python/sglang/srt/models/gemma2.py @@ -355,6 +355,40 @@ def forward( input_ids, hidden_states, self.model.embed_tokens, forward_batch ) + def get_hidden_dim(self, module_name): + # return input_dim, output_dim + if module_name in ["q_proj", "qkv_proj"]: + return ( + self.config.hidden_size, + self.config.head_dim * self.config.num_attention_heads, + ) + elif module_name in ["o_proj"]: + return ( + self.config.head_dim * self.config.num_attention_heads, + self.config.hidden_size, + ) + elif module_name in ["kv_proj"]: + return ( + self.config.hidden_size, + self.config.head_dim * self.config.num_key_value_heads, + ) + elif module_name == "gate_up_proj": + return self.config.hidden_size, self.config.intermediate_size + elif module_name == "down_proj": + return self.config.intermediate_size, self.config.hidden_size + else: + raise NotImplementedError() + + def get_module_name(self, name): + params_mapping = { + "q_proj": "qkv_proj", + "k_proj": "qkv_proj", + "v_proj": "qkv_proj", + "gate_proj": "gate_up_proj", + "up_proj": "gate_up_proj", + } + return params_mapping.get(name, name) + def get_attention_sliding_window_size(self): return get_attention_sliding_window_size(self.config) From 1f09e84b9a31a8fa98fee6cbb9c5d8409967e653 Mon Sep 17 00:00:00 2001 From: SangBin Cho Date: Sun, 8 Dec 2024 01:06:15 -0800 Subject: [PATCH 44/60] nit: Remove busy waiting on scheduler (#2382) --- docs/references/contributor_guide.md | 4 ++++ python/pyproject.toml | 2 +- .../sglang/srt/managers/detokenizer_manager.py | 2 ++ python/sglang/srt/managers/scheduler.py | 17 ++++++++++++----- 4 files changed, 19 insertions(+), 6 deletions(-) diff --git a/docs/references/contributor_guide.md b/docs/references/contributor_guide.md index a9b25163d12..550f267ab1a 100644 --- a/docs/references/contributor_guide.md +++ b/docs/references/contributor_guide.md @@ -1,5 +1,9 @@ # Contributor Guide +# Build SGLang + +See [Install SGLang, Method 2: From Source section](../start/install.md). + ## Format Your Code Use these commands to format your code and pass CI linting tests. diff --git a/python/pyproject.toml b/python/pyproject.toml index 186405dd7a3..8e935528e21 100644 --- a/python/pyproject.toml +++ b/python/pyproject.toml @@ -13,7 +13,7 @@ classifiers = [ "Programming Language :: Python :: 3", "License :: OSI Approved :: Apache Software License", ] -dependencies = ["requests", "tqdm", "numpy", "IPython"] +dependencies = ["requests", "tqdm", "numpy", "IPython", "setproctitle"] [project.optional-dependencies] runtime_common = ["aiohttp", "decord", "fastapi", diff --git a/python/sglang/srt/managers/detokenizer_manager.py b/python/sglang/srt/managers/detokenizer_manager.py index e74ba5026c1..120e990da2a 100644 --- a/python/sglang/srt/managers/detokenizer_manager.py +++ b/python/sglang/srt/managers/detokenizer_manager.py @@ -20,6 +20,7 @@ from typing import List, Union import psutil +import setproctitle import zmq from sglang.srt.hf_transformers_utils import get_tokenizer @@ -194,6 +195,7 @@ def run_detokenizer_process( server_args: ServerArgs, port_args: PortArgs, ): + setproctitle.setproctitle("sglang::detokenizer") configure_logger(server_args) parent_process = psutil.Process().parent() diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py index 4ca4cd740dc..13e8dae2345 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -25,6 +25,7 @@ from typing import List, Optional import psutil +import setproctitle import torch import zmq @@ -439,12 +440,16 @@ def recv_requests(self): if self.tp_rank == 0 or self.server_args.enable_dp_attention: recv_reqs = [] - while True: - try: - recv_req = self.recv_from_tokenizer.recv_pyobj(zmq.NOBLOCK) - except zmq.ZMQError: - break + if self.last_batch is None: + recv_req = self.recv_from_tokenizer.recv_pyobj() recv_reqs.append(recv_req) + else: + while True: + try: + recv_req = self.recv_from_tokenizer.recv_pyobj(zmq.NOBLOCK) + except zmq.ZMQError: + break + recv_reqs.append(recv_req) else: recv_reqs = None @@ -1473,6 +1478,8 @@ def run_scheduler_process( dp_rank: Optional[int], pipe_writer, ): + setproctitle.setproctitle("sglang::scheduler") + # [For Router] if env var "SGLANG_DP_RANK" exist, set dp_rank to the value of the env var if dp_rank is None and "SGLANG_DP_RANK" in os.environ: dp_rank = int(os.environ["SGLANG_DP_RANK"]) From 7dc66fcb40aa693a299bdcf17247f52cc9deeff0 Mon Sep 17 00:00:00 2001 From: Ke Bao Date: Sun, 8 Dec 2024 17:17:37 +0800 Subject: [PATCH 45/60] Optimize Triton decoding kernel for long context (#2394) --- .../srt/layers/attention/triton_backend.py | 21 +- .../attention/triton_ops/decode_attention.py | 629 ++++++++---------- python/sglang/srt/server_args.py | 7 + test/srt/test_triton_attention_kernels.py | 31 +- 4 files changed, 328 insertions(+), 360 deletions(-) diff --git a/python/sglang/srt/layers/attention/triton_backend.py b/python/sglang/srt/layers/attention/triton_backend.py index 1b7c4c46d26..1ea193ae7c3 100644 --- a/python/sglang/srt/layers/attention/triton_backend.py +++ b/python/sglang/srt/layers/attention/triton_backend.py @@ -40,6 +40,9 @@ def __init__(self, model_runner: ModelRunner): else: self.reduce_dtype = torch.float16 + self.num_kv_splits = model_runner.server_args.triton_attention_num_kv_splits + self.v_head_dim = model_runner.token_to_kv_pool.get_value_buffer(0).shape[-1] + self.forward_metadata = None self.cuda_graph_max_seq_len = model_runner.model_config.context_len @@ -53,10 +56,14 @@ def init_forward_metadata(self, forward_batch: ForwardBatch): start_loc = torch.zeros_like(forward_batch.seq_lens, dtype=torch.int32) start_loc[1:] = torch.cumsum(forward_batch.seq_lens[:-1], dim=0) - total_num_tokens = forward_batch.seq_lens_sum attn_logits = torch.empty( - (self.num_head, total_num_tokens), - dtype=self.reduce_dtype, + ( + forward_batch.batch_size, + self.num_head, + self.num_kv_splits, + self.v_head_dim + 1, + ), + dtype=torch.float32, device=self.device, ) @@ -75,11 +82,8 @@ def init_cuda_graph_state(self, max_bs: int): (max_bs,), dtype=torch.int32, device=self.device ) self.cuda_graph_attn_logits = torch.empty( - ( - self.num_head, - self.cuda_graph_max_total_num_tokens, - ), - dtype=self.reduce_dtype, + (max_bs, self.num_head, self.num_kv_splits, self.v_head_dim + 1), + dtype=torch.float32, device="cuda", ) @@ -189,6 +193,7 @@ def forward_decode( forward_batch.seq_lens, attn_logits, max_seq_len, + self.num_kv_splits, layer.scaling, layer.logit_cap, ) diff --git a/python/sglang/srt/layers/attention/triton_ops/decode_attention.py b/python/sglang/srt/layers/attention/triton_ops/decode_attention.py index 56d38693f4f..9eeb98a2963 100644 --- a/python/sglang/srt/layers/attention/triton_ops/decode_attention.py +++ b/python/sglang/srt/layers/attention/triton_ops/decode_attention.py @@ -17,8 +17,8 @@ """ # Adapted from -# https://github.com/ModelTC/lightllm/blob/f2a54f0912293f683bf1d1695fd12c4098a5bf82/lightllm/models/llama/triton_kernel/token_attention_nopad_att1.py -# https://github.com/ModelTC/lightllm/blob/f2a54f0912293f683bf1d1695fd12c4098a5bf82/lightllm/models/llama/triton_kernel/token_attention_softmax_and_reducev.py +# https://github.com/ModelTC/lightllm/blob/96353e868a840db4d103138caf15ed9dbea8c186/lightllm/models/deepseek2/triton_kernel/gqa_flash_decoding_stage1.py +# https://github.com/ModelTC/lightllm/blob/96353e868a840db4d103138caf15ed9dbea8c186/lightllm/models/deepseek2/triton_kernel/gqa_flash_decoding_stage2.py import triton import triton.language as tl @@ -37,10 +37,10 @@ def tanh(x): def _fwd_kernel_stage1( Q, K_Buffer, + V_Buffer, sm_scale, Req_to_tokens, B_req_idx, - B_Start_Loc, B_Seqlen, Att_Out, stride_req_to_tokens_b, @@ -48,152 +48,137 @@ def _fwd_kernel_stage1( stride_qh, stride_buf_kbs, stride_buf_kh, - att_stride_h, + stride_buf_vbs, + stride_buf_vh, + stride_mid_ob, + stride_mid_oh, + stride_mid_os, kv_group_num: tl.constexpr, BLOCK_DMODEL: tl.constexpr, + BLOCK_DV: tl.constexpr, BLOCK_N: tl.constexpr, - SPLIT_K: tl.constexpr, + NUM_KV_SPLITS: tl.constexpr, logit_cap: tl.constexpr, Lk: tl.constexpr, + Lv: tl.constexpr, ): cur_batch = tl.program_id(0) cur_head = tl.program_id(1) - split_k_id = tl.program_id(2) + split_kv_id = tl.program_id(2) - reduce_dtype = Att_Out.dtype.element_ty cur_kv_head = cur_head // kv_group_num offs_d = tl.arange(0, BLOCK_DMODEL) + offs_dv = tl.arange(0, BLOCK_DV) + mask_d = offs_d < Lk + mask_dv = offs_dv < Lv cur_batch_seq_len = tl.load(B_Seqlen + cur_batch) - cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch) cur_batch_req_idx = tl.load(B_req_idx + cur_batch) off_q = cur_batch * stride_qbs + cur_head * stride_qh + offs_d - q = tl.load(Q + off_q).to(reduce_dtype) - - kv_len_per_split = tl.cdiv(cur_batch_seq_len, SPLIT_K) - split_k_start = kv_len_per_split * split_k_id - split_k_end = tl.minimum(split_k_start + kv_len_per_split, cur_batch_seq_len) - - for start_n in range(split_k_start, split_k_end, BLOCK_N): - offs_n = start_n + tl.arange(0, BLOCK_N) - k_loc = tl.load( - Req_to_tokens + stride_req_to_tokens_b * cur_batch_req_idx + offs_n, - mask=offs_n < split_k_end, - other=0, - ) - offs_buf_k = ( - k_loc[:, None] * stride_buf_kbs - + cur_kv_head * stride_buf_kh - + offs_d[None, :] - ) - k = tl.load( - K_Buffer + offs_buf_k, - mask=(offs_n[:, None] < split_k_end) & (offs_d[None, :] < Lk), - other=0.0, - ).to(reduce_dtype) - att_value = tl.sum(q[None, :] * k, 1) - att_value *= sm_scale - - if logit_cap > 0: - att_value = logit_cap * tanh(att_value / logit_cap) + q = tl.load(Q + off_q, mask=mask_d, other=0.0) - off_o = cur_head * att_stride_h + (cur_batch_in_all_start_index + offs_n) - tl.store(Att_Out + off_o, att_value, mask=offs_n < split_k_end) + kv_len_per_split = tl.cdiv(cur_batch_seq_len, NUM_KV_SPLITS) + split_kv_start = kv_len_per_split * split_kv_id + split_kv_end = tl.minimum(split_kv_start + kv_len_per_split, cur_batch_seq_len) + e_max = -float("inf") + e_sum = 0.0 + acc = tl.zeros([BLOCK_DV], dtype=tl.float32) + + if split_kv_end > split_kv_start: + for start_n in range(split_kv_start, split_kv_end, BLOCK_N): + offs_n = start_n + tl.arange(0, BLOCK_N) + kv_loc = tl.load( + Req_to_tokens + stride_req_to_tokens_b * cur_batch_req_idx + offs_n, + mask=offs_n < split_kv_end, + other=0, + ) + offs_buf_k = ( + kv_loc[:, None] * stride_buf_kbs + + cur_kv_head * stride_buf_kh + + offs_d[None, :] + ) + k = tl.load( + K_Buffer + offs_buf_k, + mask=(offs_n[:, None] < split_kv_end) & (mask_d[None, :]), + other=0.0, + ) + qk = tl.sum(q[None, :] * k, 1) + qk *= sm_scale -@triton.jit -def _fwd_kernel_stage2( - logits, - V_Buffer, - Out, - Req_to_tokens, - B_req_idx, - B_Start_Loc, - B_Seqlen, - stride_logic_h, - stride_buf_vbs, - stride_buf_vh, - stride_obs, - stride_oh, - stride_req_to_token_b, - kv_group_num: tl.constexpr, - BLOCK_DMODEL: tl.constexpr, - BLOCK_N: tl.constexpr, - Lv: tl.constexpr, -): - cur_batch = tl.program_id(0) - cur_head = tl.program_id(1) + if logit_cap > 0: + qk = logit_cap * tanh(qk / logit_cap) - cur_kv_head = cur_head // kv_group_num + qk = tl.where(offs_n < split_kv_end, qk, float("-inf")) - cur_batch_seq_len = tl.load(B_Seqlen + cur_batch) - cur_batch_start_loc = tl.load(B_Start_Loc + cur_batch) - cur_batch_req_idx = tl.load(B_req_idx + cur_batch) + offs_buf_v = ( + kv_loc[:, None] * stride_buf_vbs + + cur_kv_head * stride_buf_vh + + offs_dv[None, :] + ) + v = tl.load( + V_Buffer + offs_buf_v, + mask=(offs_n[:, None] < split_kv_end) & (mask_dv[None, :]), + other=0.0, + ) - offs_n = tl.arange(0, BLOCK_N) - offs_d = tl.arange(0, BLOCK_DMODEL) + n_e_max = tl.maximum(tl.max(qk, 0), e_max) + re_scale = tl.exp(e_max - n_e_max) + p = tl.exp(qk - n_e_max) + acc *= re_scale + acc += tl.sum(p[:, None] * v, 0) - offs_buf_v = cur_kv_head * stride_buf_vh + offs_d[None, :] - v_ptrs = V_Buffer + offs_buf_v + e_sum = e_sum * re_scale + tl.sum(p, 0) + e_max = n_e_max - e_max = float("-inf") - e_sum = 0.0 - acc = tl.zeros([BLOCK_DMODEL], dtype=tl.float32) - - for start_n in range(0, cur_batch_seq_len, BLOCK_N): - start_n = tl.multiple_of(start_n, BLOCK_N) - v_index = tl.load( - Req_to_tokens - + cur_batch_req_idx * stride_req_to_token_b - + (start_n + offs_n), - mask=(start_n + offs_n) < cur_batch_seq_len, - other=0, + offs_mid_o = ( + cur_batch * stride_mid_ob + + cur_head * stride_mid_oh + + split_kv_id * stride_mid_os + + offs_dv ) - qk = tl.load( - logits - + cur_head * stride_logic_h - + (cur_batch_start_loc + start_n + offs_n), - mask=start_n + offs_n < cur_batch_seq_len, - other=float("-inf"), + tl.store( + Att_Out + offs_mid_o, + acc / e_sum, + mask=(mask_dv), ) - n_e_max = tl.maximum(tl.max(qk, 0), e_max) - old_scale = tl.exp(e_max - n_e_max) - p = tl.exp(qk - n_e_max) - e_sum = e_sum * old_scale + tl.sum(p, 0) - v = tl.load( - v_ptrs + v_index[:, None] * stride_buf_vbs, mask=(offs_d[None, :] < Lv) + offs_mid_o_1 = ( + cur_batch * stride_mid_ob + + cur_head * stride_mid_oh + + split_kv_id * stride_mid_os + + Lv ) - acc = acc * old_scale + tl.sum(p[:, None] * v, 0) - e_max = n_e_max - acc = acc / e_sum - off_o = cur_batch * stride_obs + cur_head * stride_oh + offs_d - out_ptrs = Out + off_o - tl.store(out_ptrs, acc, mask=(offs_d < Lv)) + tl.store( + Att_Out + offs_mid_o_1, + e_max + tl.log(e_sum), + ) def _decode_att_m_fwd( q, k_buffer, + v_buffer, att_out, Req_to_tokens, B_req_idx, - B_Start_Loc, B_Seqlen, max_len_in_batch, + num_kv_splits, sm_scale, logit_cap, ): - BLOCK = 32 - SPLIT_K = 8 + BLOCK = 64 + NUM_KV_SPLITS = num_kv_splits Lk = k_buffer.shape[-1] + Lv = v_buffer.shape[-1] batch, head_num = B_req_idx.shape[0], q.shape[1] - grid = (batch, head_num, SPLIT_K) + grid = (batch, head_num, NUM_KV_SPLITS) kv_group_num = q.shape[1] // k_buffer.shape[1] if kv_group_num == 1: @@ -202,14 +187,15 @@ def _decode_att_m_fwd( num_warps = 2 BLOCK_DMODEL = triton.next_power_of_2(Lk) + BLOCK_DV = triton.next_power_of_2(Lv) _fwd_kernel_stage1[grid]( q, k_buffer, + v_buffer, sm_scale, Req_to_tokens, B_req_idx, - B_Start_Loc, B_Seqlen, att_out, Req_to_tokens.stride(0), @@ -217,56 +203,20 @@ def _decode_att_m_fwd( q.stride(1), k_buffer.stride(0), k_buffer.stride(1), + v_buffer.stride(0), + v_buffer.stride(1), att_out.stride(0), + att_out.stride(1), + att_out.stride(2), kv_group_num=kv_group_num, BLOCK_DMODEL=BLOCK_DMODEL, + BLOCK_DV=BLOCK_DV, BLOCK_N=BLOCK, - SPLIT_K=SPLIT_K, + NUM_KV_SPLITS=NUM_KV_SPLITS, logit_cap=logit_cap, num_warps=num_warps, - num_stages=1, + num_stages=2, Lk=Lk, - ) - - -def _decode_softmax_reducev_fwd( - logits, - v_buffer, - o, - req_to_tokens, - b_req_idx, - b_start_loc, - b_seq_len, -): - BLOCK = 64 - batch, head = b_seq_len.shape[0], logits.shape[0] - grid = (batch, head, 1) - kv_group_num = logits.shape[0] // v_buffer.shape[1] - - num_warps = 1 - - Lv = v_buffer.shape[-1] - BLOCK_DMODEL = triton.next_power_of_2(Lv) - - _fwd_kernel_stage2[grid]( - logits, - v_buffer, - o, - req_to_tokens, - b_req_idx, - b_start_loc, - b_seq_len, - logits.stride(0), - v_buffer.stride(0), - v_buffer.stride(1), - o.stride(0), - o.stride(1), - req_to_tokens.stride(0), - kv_group_num=kv_group_num, - BLOCK_DMODEL=BLOCK_DMODEL, - BLOCK_N=BLOCK, - num_warps=num_warps, - num_stages=3, Lv=Lv, ) @@ -275,10 +225,10 @@ def _decode_softmax_reducev_fwd( def _fwd_grouped_kernel_stage1( Q, K_Buffer, + V_Buffer, sm_scale, Req_to_tokens, B_req_idx, - B_Start_Loc, B_Seqlen, Att_Out, stride_req_to_tokens_b, @@ -286,23 +236,27 @@ def _fwd_grouped_kernel_stage1( stride_qh, stride_buf_kbs, stride_buf_kh, - att_stride_h, + stride_buf_vbs, + stride_buf_vh, + stride_mid_ob, + stride_mid_oh, + stride_mid_os, kv_group_num: tl.constexpr, q_head_num: tl.constexpr, BLOCK_DMODEL: tl.constexpr, BLOCK_DPE: tl.constexpr, + BLOCK_DV: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_H: tl.constexpr, - SPLIT_K: tl.constexpr, + NUM_KV_SPLITS: tl.constexpr, logit_cap: tl.constexpr, Lk: tl.constexpr, + Lv: tl.constexpr, ): cur_batch = tl.program_id(0) cur_head_id = tl.program_id(1) cur_kv_head = cur_head_id // tl.cdiv(kv_group_num, BLOCK_H) - split_k_id = tl.program_id(2) - - reduce_dtype = Att_Out.dtype.element_ty + split_kv_id = tl.program_id(2) if BLOCK_H < kv_group_num: VALID_BLOCK_H: tl.constexpr = BLOCK_H @@ -313,171 +267,136 @@ def _fwd_grouped_kernel_stage1( mask_h = mask_h & (cur_head < q_head_num) offs_d = tl.arange(0, BLOCK_DMODEL) + offs_dv = tl.arange(0, BLOCK_DV) + mask_d = offs_d < Lk + mask_dv = offs_dv < Lv cur_batch_seq_len = tl.load(B_Seqlen + cur_batch) - cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch) cur_batch_req_idx = tl.load(B_req_idx + cur_batch) offs_q = cur_batch * stride_qbs + cur_head[:, None] * stride_qh + offs_d[None, :] - q = tl.load( - Q + offs_q, mask=(mask_h[:, None]) & (offs_d[None, :] < Lk), other=0.0 - ).to(reduce_dtype) + q = tl.load(Q + offs_q, mask=(mask_h[:, None]) & (mask_d[None, :]), other=0.0) if BLOCK_DPE > 0: offs_dpe = BLOCK_DMODEL + tl.arange(0, BLOCK_DPE) + mask_dpe = offs_dpe < Lk off_qpe = ( cur_batch * stride_qbs + cur_head[:, None] * stride_qh + offs_dpe[None, :] ) - qpe = tl.load(Q + off_qpe, mask=mask_h[:, None], other=0.0).to(reduce_dtype) - - kv_len_per_split = tl.cdiv(cur_batch_seq_len, SPLIT_K) - split_k_start = kv_len_per_split * split_k_id - split_k_end = tl.minimum(split_k_start + kv_len_per_split, cur_batch_seq_len) - - for start_n in range(split_k_start, split_k_end, BLOCK_N): - offs_n = start_n + tl.arange(0, BLOCK_N) - k_loc = tl.load( - Req_to_tokens + stride_req_to_tokens_b * cur_batch_req_idx + offs_n, - mask=offs_n < split_k_end, - other=0, - ) - offs_buf_k = ( - k_loc[None, :] * stride_buf_kbs - + cur_kv_head * stride_buf_kh - + offs_d[:, None] + qpe = tl.load( + Q + off_qpe, mask=(mask_h[:, None]) & (mask_dpe[None, :]), other=0.0 ) - k = tl.load( - K_Buffer + offs_buf_k, - mask=(offs_n[None, :] < split_k_end) & (offs_d[:, None] < Lk), - other=0.0, - ).to(reduce_dtype) - qk = tl.dot(q, k) - if BLOCK_DPE > 0: - offs_buf_kpe = ( - k_loc[None, :] * stride_buf_kbs + + kv_len_per_split = tl.cdiv(cur_batch_seq_len, NUM_KV_SPLITS) + split_kv_start = kv_len_per_split * split_kv_id + split_kv_end = tl.minimum(split_kv_start + kv_len_per_split, cur_batch_seq_len) + + e_max = tl.zeros([BLOCK_H], dtype=tl.float32) - float("inf") + e_sum = tl.zeros([BLOCK_H], dtype=tl.float32) + acc = tl.zeros([BLOCK_H, BLOCK_DV], dtype=tl.float32) + + if split_kv_end > split_kv_start: + for start_n in range(split_kv_start, split_kv_end, BLOCK_N): + offs_n = start_n + tl.arange(0, BLOCK_N) + kv_loc = tl.load( + Req_to_tokens + stride_req_to_tokens_b * cur_batch_req_idx + offs_n, + mask=offs_n < split_kv_end, + other=0, + ) + offs_buf_k = ( + kv_loc[None, :] * stride_buf_kbs + cur_kv_head * stride_buf_kh - + offs_dpe[:, None] + + offs_d[:, None] ) - kpe = tl.load( - K_Buffer + offs_buf_kpe, - mask=offs_n[None, :] < split_k_end, + k = tl.load( + K_Buffer + offs_buf_k, + mask=(offs_n[None, :] < split_kv_end) & (mask_d[:, None]), other=0.0, - ).to(reduce_dtype) - qk += tl.dot(qpe, kpe) - qk *= sm_scale - - if logit_cap > 0: - qk = logit_cap * tanh(qk / logit_cap) - - offs_o = cur_head[:, None] * att_stride_h + ( - cur_batch_in_all_start_index + offs_n[None, :] - ) - - tl.store( - Att_Out + offs_o, - qk, - mask=mask_h[:, None] & (offs_n[None, :] < split_k_end), - ) - - -@triton.jit -def _fwd_grouped_kernel_stage2( - logits, - V_Buffer, - Out, - Req_to_tokens, - B_req_idx, - B_Start_Loc, - B_Seqlen, - stride_logic_h, - stride_buf_vbs, - stride_buf_vh, - stride_obs, - stride_oh, - stride_req_to_token_b, - kv_group_num: tl.constexpr, - q_head_num: tl.constexpr, - BLOCK_DMODEL: tl.constexpr, - BLOCK_N: tl.constexpr, - BLOCK_H: tl.constexpr, - Lv: tl.constexpr, -): - cur_batch = tl.program_id(0) - cur_head_id = tl.program_id(1) - cur_kv_head = cur_head_id // tl.cdiv(kv_group_num, BLOCK_H) - - if BLOCK_H < kv_group_num: - VALID_BLOCK_H: tl.constexpr = BLOCK_H - else: - VALID_BLOCK_H: tl.constexpr = kv_group_num - cur_head = cur_head_id * VALID_BLOCK_H + tl.arange(0, BLOCK_H) - mask_h = cur_head < (cur_head_id + 1) * VALID_BLOCK_H - mask_h = mask_h & (cur_head < q_head_num) + ) + qk = tl.dot(q, k.to(q.dtype)) + if BLOCK_DPE > 0: + offs_buf_kpe = ( + kv_loc[None, :] * stride_buf_kbs + + cur_kv_head * stride_buf_kh + + offs_dpe[:, None] + ) + kpe = tl.load( + K_Buffer + offs_buf_kpe, + mask=(offs_n[None, :] < split_kv_end) & (mask_dpe[:, None]), + other=0.0, + ) + qk += tl.dot(qpe, kpe.to(qpe.dtype)) + qk *= sm_scale + + if logit_cap > 0: + qk = logit_cap * tanh(qk / logit_cap) + + qk = tl.where( + mask_h[:, None] & (offs_n[None, :] < split_kv_end), qk, float("-inf") + ) - cur_batch_seq_len = tl.load(B_Seqlen + cur_batch) - cur_batch_start_loc = tl.load(B_Start_Loc + cur_batch) - cur_batch_req_idx = tl.load(B_req_idx + cur_batch) + offs_buf_v = ( + kv_loc[:, None] * stride_buf_vbs + + cur_kv_head * stride_buf_vh + + offs_dv[None, :] + ) + v = tl.load( + V_Buffer + offs_buf_v, + mask=(offs_n[:, None] < split_kv_end) & (mask_dv[None, :]), + other=0.0, + ) - offs_n = tl.arange(0, BLOCK_N) - offs_d = tl.arange(0, BLOCK_DMODEL) + n_e_max = tl.maximum(tl.max(qk, 1), e_max) + re_scale = tl.exp(e_max - n_e_max) + p = tl.exp(qk - n_e_max[:, None]) + acc *= re_scale[:, None] + acc += tl.dot(p.to(v.dtype), v) - offs_buf_v = cur_kv_head * stride_buf_vh + offs_d[None, :] - v_ptrs = V_Buffer + offs_buf_v + e_sum = e_sum * re_scale + tl.sum(p, 1) + e_max = n_e_max - e_max = tl.zeros([BLOCK_H], dtype=tl.float32) - float("inf") - e_sum = tl.zeros([BLOCK_H], dtype=tl.float32) - acc = tl.zeros([BLOCK_H, BLOCK_DMODEL], dtype=tl.float32) - - for start_n in range(0, cur_batch_seq_len, BLOCK_N): - start_n = tl.multiple_of(start_n, BLOCK_N) - v_index = tl.load( - Req_to_tokens - + cur_batch_req_idx * stride_req_to_token_b - + (start_n + offs_n), - mask=(start_n + offs_n) < cur_batch_seq_len, - other=0, + offs_mid_o = ( + cur_batch * stride_mid_ob + + cur_head[:, None] * stride_mid_oh + + split_kv_id * stride_mid_os + + offs_dv[None, :] ) - offs_qk = cur_head[:, None] * stride_logic_h + ( - cur_batch_start_loc + start_n + offs_n[None, :] + tl.store( + Att_Out + offs_mid_o, + acc / e_sum[:, None], + mask=(mask_h[:, None]) & (mask_dv[None, :]), ) - qk = tl.load( - logits + offs_qk, - mask=mask_h[:, None] & (start_n + offs_n[None, :] < cur_batch_seq_len), - other=float("-inf"), + offs_mid_o_1 = ( + cur_batch * stride_mid_ob + + cur_head * stride_mid_oh + + split_kv_id * stride_mid_os + + Lv ) - n_e_max = tl.maximum(tl.max(qk, 1), e_max) - old_scale = tl.exp(e_max - n_e_max) - p = tl.exp(qk - n_e_max[:, None]) - e_sum = e_sum * old_scale + tl.sum(p, 1) - v = tl.load( - v_ptrs + v_index[:, None] * stride_buf_vbs, mask=(offs_d[None, :] < Lv) + tl.store( + Att_Out + offs_mid_o_1, + e_max + tl.log(e_sum), + mask=mask_h, ) - p = p.to(v.dtype) - acc = acc * old_scale[:, None] + tl.dot(p, v) - e_max = n_e_max - - acc = acc / e_sum[:, None] - off_o = cur_batch * stride_obs + cur_head[:, None] * stride_oh + offs_d[None, :] - out_ptrs = Out + off_o - tl.store(out_ptrs, acc, mask=(mask_h[:, None]) & (offs_d[None, :] < Lv)) def _decode_grouped_att_m_fwd( q, k_buffer, + v_buffer, att_out, Req_to_tokens, B_req_idx, - B_Start_Loc, B_Seqlen, max_len_in_batch, + num_kv_splits, sm_scale, logit_cap, ): - BLOCK = 64 + BLOCK = 32 Lk = k_buffer.shape[-1] + Lv = v_buffer.shape[-1] if Lk == 576: BLOCK_DMODEL = 512 @@ -488,20 +407,19 @@ def _decode_grouped_att_m_fwd( else: BLOCK_DMODEL = triton.next_power_of_2(Lk) BLOCK_DPE = 0 + BLOCK_DV = triton.next_power_of_2(Lv) batch, head_num = B_req_idx.shape[0], q.shape[1] kv_group_num = q.shape[1] // k_buffer.shape[1] - BLOCK_H = max(16, min(64, triton.next_power_of_2(kv_group_num))) - SPLIT_K = 8 + BLOCK_H = 16 + NUM_KV_SPLITS = num_kv_splits grid = ( batch, triton.cdiv(head_num, min(BLOCK_H, kv_group_num)), - SPLIT_K, + NUM_KV_SPLITS, ) - num_warps = 4 - extra_kargs = {} if is_hip_: # https://rocm.docs.amd.com/en/docs-6.2.0/how-to/llm-fine-tuning-optimization/optimizing-triton-kernel.html @@ -511,10 +429,10 @@ def _decode_grouped_att_m_fwd( _fwd_grouped_kernel_stage1[grid]( q, k_buffer, + v_buffer, sm_scale, Req_to_tokens, B_req_idx, - B_Start_Loc, B_Seqlen, att_out, Req_to_tokens.stride(0), @@ -522,41 +440,88 @@ def _decode_grouped_att_m_fwd( q.stride(1), k_buffer.stride(0), k_buffer.stride(1), + v_buffer.stride(0), + v_buffer.stride(1), att_out.stride(0), + att_out.stride(1), + att_out.stride(2), kv_group_num=kv_group_num, q_head_num=head_num, BLOCK_DMODEL=BLOCK_DMODEL, BLOCK_DPE=BLOCK_DPE, + BLOCK_DV=BLOCK_DV, BLOCK_N=BLOCK, BLOCK_H=BLOCK_H, - SPLIT_K=SPLIT_K, + NUM_KV_SPLITS=NUM_KV_SPLITS, logit_cap=logit_cap, - num_warps=num_warps, - num_stages=1, + num_warps=4, + num_stages=2, Lk=Lk, + Lv=Lv, **extra_kargs, ) -def _decode_grouped_softmax_reducev_fwd( - logits, - v_buffer, - o, - req_to_tokens, - b_req_idx, - b_start_loc, - b_seq_len, +@triton.jit +def _fwd_kernel_stage2( + Mid_O, + O, + stride_mid_ob, + stride_mid_oh, + stride_mid_os, + stride_obs, + stride_oh, + NUM_KV_SPLITS: tl.constexpr, + BLOCK_DV: tl.constexpr, + Lv: tl.constexpr, ): - BLOCK = 128 - batch, head_num = b_seq_len.shape[0], logits.shape[0] - kv_group_num = logits.shape[0] // v_buffer.shape[1] - BLOCK_H = max(16, min(64, triton.next_power_of_2(kv_group_num))) - grid = (batch, triton.cdiv(head_num, min(BLOCK_H, kv_group_num)), 1) + cur_batch = tl.program_id(0) + cur_head = tl.program_id(1) + + offs_d = tl.arange(0, BLOCK_DV) + mask_d = offs_d < Lv + + e_sum = 0.0 + e_max = -float("inf") + acc = tl.zeros([BLOCK_DV], dtype=tl.float32) + + offs_v = cur_batch * stride_mid_ob + cur_head * stride_mid_oh + offs_d + offs_logic = cur_batch * stride_mid_ob + cur_head * stride_mid_oh + Lv + + for split_kv_id in range(0, NUM_KV_SPLITS): + tv = tl.load( + Mid_O + offs_v + split_kv_id * stride_mid_os, mask=mask_d, other=0.0 + ) + tlogic = tl.load(Mid_O + offs_logic + split_kv_id * stride_mid_os) + n_e_max = tl.maximum(tlogic, e_max) - num_warps = 8 + old_scale = tl.exp(e_max - n_e_max) + acc *= old_scale + exp_logic = tl.exp(tlogic - n_e_max) + acc += exp_logic * tv + + e_sum = e_sum * old_scale + exp_logic + e_max = n_e_max + tl.store( + O + cur_batch * stride_obs + cur_head * stride_oh + offs_d, + acc / e_sum, + mask=mask_d, + ) + + +def _decode_softmax_reducev_fwd( + logits, + q, + o, + v_buffer, + num_kv_splits, +): + batch, head_num = q.shape[0], q.shape[1] Lv = v_buffer.shape[-1] - BLOCK_DMODEL = triton.next_power_of_2(Lv) + BLOCK_DV = triton.next_power_of_2(Lv) + + NUM_KV_SPLITS = num_kv_splits extra_kargs = {} if is_hip_: @@ -564,28 +529,20 @@ def _decode_grouped_softmax_reducev_fwd( # https://github.com/triton-lang/triton/blob/main/third_party/amd/backend/compiler.py extra_kargs = {"waves_per_eu": 4, "matrix_instr_nonkdim": 16, "kpack": 2} - _fwd_grouped_kernel_stage2[grid]( + grid = (batch, head_num) + _fwd_kernel_stage2[grid]( logits, - v_buffer, o, - req_to_tokens, - b_req_idx, - b_start_loc, - b_seq_len, logits.stride(0), - v_buffer.stride(0), - v_buffer.stride(1), + logits.stride(1), + logits.stride(2), o.stride(0), o.stride(1), - req_to_tokens.stride(0), - kv_group_num=kv_group_num, - q_head_num=head_num, - BLOCK_DMODEL=BLOCK_DMODEL, - BLOCK_N=BLOCK, - BLOCK_H=BLOCK_H, + NUM_KV_SPLITS=NUM_KV_SPLITS, + BLOCK_DV=BLOCK_DV, Lv=Lv, - num_warps=num_warps, - num_stages=1, + num_warps=4, + num_stages=2, **extra_kargs, ) @@ -597,34 +554,27 @@ def decode_attention_fwd_normal( o, req_to_token, b_req_idx, - b_start_loc, b_seq_len, attn_logits, max_len_in_batch, + num_kv_splits, sm_scale, logit_cap=0.0, ): _decode_att_m_fwd( q, k_buffer, + v_buffer, attn_logits, req_to_token, b_req_idx, - b_start_loc, b_seq_len, max_len_in_batch, + num_kv_splits, sm_scale, logit_cap, ) - _decode_softmax_reducev_fwd( - attn_logits, - v_buffer, - o, - req_to_token, - b_req_idx, - b_start_loc, - b_seq_len, - ) + _decode_softmax_reducev_fwd(attn_logits, q, o, v_buffer, num_kv_splits) def decode_attention_fwd_grouped( @@ -634,34 +584,27 @@ def decode_attention_fwd_grouped( o, req_to_token, b_req_idx, - b_start_loc, b_seq_len, attn_logits, max_len_in_batch, + num_kv_splits, sm_scale, logit_cap=0.0, ): _decode_grouped_att_m_fwd( q, k_buffer, + v_buffer, attn_logits, req_to_token, b_req_idx, - b_start_loc, b_seq_len, max_len_in_batch, + num_kv_splits, sm_scale, logit_cap, ) - _decode_grouped_softmax_reducev_fwd( - attn_logits, - v_buffer, - o, - req_to_token, - b_req_idx, - b_start_loc, - b_seq_len, - ) + _decode_softmax_reducev_fwd(attn_logits, q, o, v_buffer, num_kv_splits) def decode_attention_fwd( @@ -675,9 +618,11 @@ def decode_attention_fwd( b_seq_len, attn_logits, max_len_in_batch, + num_kv_splits, sm_scale, logit_cap=0.0, ): + assert num_kv_splits == attn_logits.shape[2] kv_group_num = q.shape[1] // v_buffer.shape[1] if kv_group_num == 1: @@ -689,10 +634,10 @@ def decode_attention_fwd( o, req_to_token, b_req_idx, - b_start_loc, b_seq_len, attn_logits, max_len_in_batch, + num_kv_splits, sm_scale, logit_cap, ) @@ -705,10 +650,10 @@ def decode_attention_fwd( o, req_to_token, b_req_idx, - b_start_loc, b_seq_len, attn_logits, max_len_in_batch, + num_kv_splits, sm_scale, logit_cap, ) diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index c2e75a642bd..fe12d961d3a 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -141,6 +141,7 @@ class ServerArgs: enable_nan_detection: bool = False enable_p2p_check: bool = False triton_attention_reduce_in_fp32: bool = False + triton_attention_num_kv_splits: int = 8 num_continuous_decode_steps: int = 1 delete_ckpt_after_loading: bool = False @@ -753,6 +754,12 @@ def add_cli_args(parser: argparse.ArgumentParser): help="Cast the intermidiate attention results to fp32 to avoid possible crashes related to fp16." "This only affects Triton attention kernels.", ) + parser.add_argument( + "--triton-attention-num-kv-splits", + type=int, + default=ServerArgs.triton_attention_num_kv_splits, + help="The number of KV splits in flash decoding Triton kernel. Larger value is better in longer context scenarios. The default value is 8.", + ) parser.add_argument( "--num-continuous-decode-steps", type=int, diff --git a/test/srt/test_triton_attention_kernels.py b/test/srt/test_triton_attention_kernels.py index 44abfd61bd7..b7917345b5b 100644 --- a/test/srt/test_triton_attention_kernels.py +++ b/test/srt/test_triton_attention_kernels.py @@ -182,6 +182,7 @@ def _test_decode_attention_once(self, B, H_Q, H_KV, D): seq_len = 10 # This represents the number of tokens already in the sequence total_tokens = B * seq_len sm_scale = 1.0 / (D**0.5) + num_kv_splits = 8 # q represents the new token being generated, one per batch q = torch.randn(B, H_Q, D, dtype=dtype, device="cuda") @@ -199,8 +200,8 @@ def _test_decode_attention_once(self, B, H_Q, H_KV, D): b_seq_len = torch.full((B,), seq_len, device="cuda") attn_logits = torch.empty( - (H_Q, total_tokens), - dtype=dtype, + (B, H_Q, num_kv_splits, D + 1), + dtype=torch.float32, device="cuda", ) @@ -215,6 +216,7 @@ def _test_decode_attention_once(self, B, H_Q, H_KV, D): b_seq_len, attn_logits, seq_len, + num_kv_splits, sm_scale, ) @@ -235,9 +237,10 @@ def test_decode_attention(self): def _test_grouped_decode_attention_once(self, B, H_Q, H_KV, D, D_V): dtype = torch.bfloat16 - seq_len = 10 # This represents the number of tokens already in the sequence + seq_len = 128 # This represents the number of tokens already in the sequence total_tokens = B * seq_len sm_scale = 1.0 / (D**0.5) + num_kv_splits = 8 # q represents the new token being generated, one per batch q = torch.randn(B, H_Q, D, dtype=dtype, device="cuda") @@ -247,8 +250,8 @@ def _test_grouped_decode_attention_once(self, B, H_Q, H_KV, D, D_V): v_buffer = torch.randn(total_tokens, H_KV, D_V, dtype=dtype, device="cuda") # o will have the same shape as q - o = torch.zeros(B, H_Q, D, dtype=dtype, device="cuda") - o_grouped = torch.zeros(B, H_Q, D, dtype=dtype, device="cuda") + o = torch.zeros(B, H_Q, D_V, dtype=dtype, device="cuda") + o_grouped = torch.zeros(B, H_Q, D_V, dtype=dtype, device="cuda") req_to_token = torch.arange(total_tokens, device="cuda").reshape(B, seq_len) b_req_idx = torch.arange(B, device="cuda") @@ -256,8 +259,8 @@ def _test_grouped_decode_attention_once(self, B, H_Q, H_KV, D, D_V): b_seq_len = torch.full((B,), seq_len, device="cuda") attn_logits = torch.empty( - (H_Q, total_tokens), - dtype=dtype, + (B, H_Q, num_kv_splits, D_V + 1), + dtype=torch.float32, device="cuda", ) @@ -268,13 +271,19 @@ def _test_grouped_decode_attention_once(self, B, H_Q, H_KV, D, D_V): o, req_to_token, b_req_idx, - b_start_loc, b_seq_len, attn_logits, seq_len, + num_kv_splits, sm_scale, ) + attn_logits1 = torch.empty( + (B, H_Q, num_kv_splits, D_V + 1), + dtype=torch.float32, + device="cuda", + ) + decode_attention_fwd_grouped( q, k_buffer, @@ -282,21 +291,23 @@ def _test_grouped_decode_attention_once(self, B, H_Q, H_KV, D, D_V): o_grouped, req_to_token, b_req_idx, - b_start_loc, b_seq_len, - attn_logits, + attn_logits1, seq_len, + num_kv_splits, sm_scale, ) cos_sim = torch.nn.functional.cosine_similarity( o.flatten(), o_grouped.flatten(), dim=0 ) + print(cos_sim.item()) self.assertTrue(cos_sim.item() > 0.99) self.assertTrue(torch.allclose(o, o_grouped, atol=3e-2)) def test_grouped_decode_attention(self): configs = [ + (2, 16, 16, 64, 64), (2, 16, 1, 64, 64), (2, 64, 1, 13, 13), (2, 128, 1, 80, 80), From 96db0f666d850156555b721ace0e3a9464249f34 Mon Sep 17 00:00:00 2001 From: Lianmin Zheng Date: Sun, 8 Dec 2024 01:56:26 -0800 Subject: [PATCH 46/60] Update killall_sglang.sh (#2397) --- scripts/killall_sglang.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scripts/killall_sglang.sh b/scripts/killall_sglang.sh index cb187f46a0f..3696a1c35f4 100755 --- a/scripts/killall_sglang.sh +++ b/scripts/killall_sglang.sh @@ -4,7 +4,7 @@ nvidia-smi # Clean SGLang processes -kill -9 $(ps aux | grep 'multiprocessing.spawn' | grep -v 'grep' | awk '{print $2}') 2>/dev/null +kill -9 $(ps aux | grep 'sglang::' | grep -v 'grep' | awk '{print $2}') 2>/dev/null kill -9 $(ps aux | grep 'sglang.launch_server' | grep -v 'grep' | awk '{print $2}') 2>/dev/null kill -9 $(ps aux | grep 'sglang.bench' | grep -v 'grep' | awk '{print $2}') 2>/dev/null From 61dec545b0446256b655d4a8aeccb50d3a341ee4 Mon Sep 17 00:00:00 2001 From: Ke Bao Date: Sun, 8 Dec 2024 19:37:03 +0800 Subject: [PATCH 47/60] Remove unused vars in the triton backend (#2401) --- .../srt/layers/attention/triton_backend.py | 21 ++++--------------- .../attention/triton_ops/decode_attention.py | 20 +++++++++--------- test/srt/test_triton_attention_kernels.py | 6 ------ 3 files changed, 14 insertions(+), 33 deletions(-) diff --git a/python/sglang/srt/layers/attention/triton_backend.py b/python/sglang/srt/layers/attention/triton_backend.py index 1ea193ae7c3..1a539ebd75c 100644 --- a/python/sglang/srt/layers/attention/triton_backend.py +++ b/python/sglang/srt/layers/attention/triton_backend.py @@ -35,11 +35,6 @@ def __init__(self, model_runner: ModelRunner): model_runner.model_config.num_attention_heads // model_runner.tp_size ) - if global_server_args_dict.get("triton_attention_reduce_in_fp32", False): - self.reduce_dtype = torch.float32 - else: - self.reduce_dtype = torch.float16 - self.num_kv_splits = model_runner.server_args.triton_attention_num_kv_splits self.v_head_dim = model_runner.token_to_kv_pool.get_value_buffer(0).shape[-1] @@ -53,9 +48,6 @@ def init_forward_metadata(self, forward_batch: ForwardBatch): """Init auxiliary variables for triton attention backend.""" if forward_batch.forward_mode.is_decode(): - start_loc = torch.zeros_like(forward_batch.seq_lens, dtype=torch.int32) - start_loc[1:] = torch.cumsum(forward_batch.seq_lens[:-1], dim=0) - attn_logits = torch.empty( ( forward_batch.batch_size, @@ -67,13 +59,12 @@ def init_forward_metadata(self, forward_batch: ForwardBatch): device=self.device, ) - max_seq_len = torch.max(forward_batch.seq_lens).item() max_extend_len = None else: - start_loc = attn_logits = max_seq_len = None + attn_logits = None max_extend_len = torch.max(forward_batch.extend_seq_lens).item() - self.forward_metadata = start_loc, attn_logits, max_seq_len, max_extend_len + self.forward_metadata = attn_logits, max_extend_len def init_cuda_graph_state(self, max_bs: int): self.cuda_graph_max_total_num_tokens = max_bs * self.cuda_graph_max_seq_len @@ -96,9 +87,7 @@ def init_forward_metadata_capture_cuda_graph( ): # NOTE: encoder_lens expected to be zeros or None self.forward_metadata = ( - self.cuda_graph_start_loc, self.cuda_graph_attn_logits, - self.cuda_graph_max_seq_len, None, ) @@ -137,7 +126,7 @@ def forward_extend( layer, forward_batch.out_cache_loc, k, v ) - start_loc, attn_logits, max_seq_len, max_extend_len = self.forward_metadata + _, max_extend_len = self.forward_metadata self.extend_attention_fwd( q.view(-1, layer.tp_q_head_num, layer.qk_head_dim), k.contiguous(), @@ -175,7 +164,7 @@ def forward_decode( else: o = torch.empty_like(q) - start_loc, attn_logits, max_seq_len, max_extend_len = self.forward_metadata + attn_logits, _ = self.forward_metadata if save_kv_cache: forward_batch.token_to_kv_pool.set_kv_buffer( @@ -189,10 +178,8 @@ def forward_decode( o.view(-1, layer.tp_q_head_num, layer.v_head_dim), forward_batch.req_to_token_pool.req_to_token, forward_batch.req_pool_indices, - start_loc, forward_batch.seq_lens, attn_logits, - max_seq_len, self.num_kv_splits, layer.scaling, layer.logit_cap, diff --git a/python/sglang/srt/layers/attention/triton_ops/decode_attention.py b/python/sglang/srt/layers/attention/triton_ops/decode_attention.py index 9eeb98a2963..d2e856ca605 100644 --- a/python/sglang/srt/layers/attention/triton_ops/decode_attention.py +++ b/python/sglang/srt/layers/attention/triton_ops/decode_attention.py @@ -19,6 +19,9 @@ # Adapted from # https://github.com/ModelTC/lightllm/blob/96353e868a840db4d103138caf15ed9dbea8c186/lightllm/models/deepseek2/triton_kernel/gqa_flash_decoding_stage1.py # https://github.com/ModelTC/lightllm/blob/96353e868a840db4d103138caf15ed9dbea8c186/lightllm/models/deepseek2/triton_kernel/gqa_flash_decoding_stage2.py + +import logging + import triton import triton.language as tl @@ -26,6 +29,13 @@ is_hip_ = is_hip() +logger = logging.getLogger(__name__) + +# TODO: Remove this when triton>=3.2.0. This issue will not affect performance and accuracy. +logger.warn( + "The following error message 'operation scheduled before its operands' can be ignored." +) + @triton.jit def tanh(x): @@ -166,7 +176,6 @@ def _decode_att_m_fwd( Req_to_tokens, B_req_idx, B_Seqlen, - max_len_in_batch, num_kv_splits, sm_scale, logit_cap, @@ -389,7 +398,6 @@ def _decode_grouped_att_m_fwd( Req_to_tokens, B_req_idx, B_Seqlen, - max_len_in_batch, num_kv_splits, sm_scale, logit_cap, @@ -556,7 +564,6 @@ def decode_attention_fwd_normal( b_req_idx, b_seq_len, attn_logits, - max_len_in_batch, num_kv_splits, sm_scale, logit_cap=0.0, @@ -569,7 +576,6 @@ def decode_attention_fwd_normal( req_to_token, b_req_idx, b_seq_len, - max_len_in_batch, num_kv_splits, sm_scale, logit_cap, @@ -586,7 +592,6 @@ def decode_attention_fwd_grouped( b_req_idx, b_seq_len, attn_logits, - max_len_in_batch, num_kv_splits, sm_scale, logit_cap=0.0, @@ -599,7 +604,6 @@ def decode_attention_fwd_grouped( req_to_token, b_req_idx, b_seq_len, - max_len_in_batch, num_kv_splits, sm_scale, logit_cap, @@ -614,10 +618,8 @@ def decode_attention_fwd( o, req_to_token, b_req_idx, - b_start_loc, b_seq_len, attn_logits, - max_len_in_batch, num_kv_splits, sm_scale, logit_cap=0.0, @@ -636,7 +638,6 @@ def decode_attention_fwd( b_req_idx, b_seq_len, attn_logits, - max_len_in_batch, num_kv_splits, sm_scale, logit_cap, @@ -652,7 +653,6 @@ def decode_attention_fwd( b_req_idx, b_seq_len, attn_logits, - max_len_in_batch, num_kv_splits, sm_scale, logit_cap, diff --git a/test/srt/test_triton_attention_kernels.py b/test/srt/test_triton_attention_kernels.py index b7917345b5b..048d27c5658 100644 --- a/test/srt/test_triton_attention_kernels.py +++ b/test/srt/test_triton_attention_kernels.py @@ -196,7 +196,6 @@ def _test_decode_attention_once(self, B, H_Q, H_KV, D): req_to_token = torch.arange(total_tokens, device="cuda").reshape(B, seq_len) b_req_idx = torch.arange(B, device="cuda") - b_start_loc = torch.arange(0, total_tokens, seq_len, device="cuda") b_seq_len = torch.full((B,), seq_len, device="cuda") attn_logits = torch.empty( @@ -212,10 +211,8 @@ def _test_decode_attention_once(self, B, H_Q, H_KV, D): o, req_to_token, b_req_idx, - b_start_loc, b_seq_len, attn_logits, - seq_len, num_kv_splits, sm_scale, ) @@ -255,7 +252,6 @@ def _test_grouped_decode_attention_once(self, B, H_Q, H_KV, D, D_V): req_to_token = torch.arange(total_tokens, device="cuda").reshape(B, seq_len) b_req_idx = torch.arange(B, device="cuda") - b_start_loc = torch.arange(0, total_tokens, seq_len, device="cuda") b_seq_len = torch.full((B,), seq_len, device="cuda") attn_logits = torch.empty( @@ -273,7 +269,6 @@ def _test_grouped_decode_attention_once(self, B, H_Q, H_KV, D, D_V): b_req_idx, b_seq_len, attn_logits, - seq_len, num_kv_splits, sm_scale, ) @@ -293,7 +288,6 @@ def _test_grouped_decode_attention_once(self, B, H_Q, H_KV, D, D_V): b_req_idx, b_seq_len, attn_logits1, - seq_len, num_kv_splits, sm_scale, ) From a2486eb58fa32661965bf66034625155e87cfc05 Mon Sep 17 00:00:00 2001 From: Lianmin Zheng Date: Sun, 8 Dec 2024 03:55:27 -0800 Subject: [PATCH 48/60] Fix a bug with logprob streaming + chunked prefill (#2403) --- python/sglang/bench_serving.py | 9 ++++++++- python/sglang/srt/managers/scheduler.py | 27 ++++++++++++++----------- python/sglang/test/test_utils.py | 1 + 3 files changed, 24 insertions(+), 13 deletions(-) diff --git a/python/sglang/bench_serving.py b/python/sglang/bench_serving.py index 1a909caa812..96e8677bb60 100644 --- a/python/sglang/bench_serving.py +++ b/python/sglang/bench_serving.py @@ -321,6 +321,8 @@ async def async_request_sglang_generate( }, "stream": not args.disable_stream, "lora_path": request_func_input.lora_name, + "return_logprob": args.return_logprob, + "logprob_start_len": -1, **request_func_input.extra_request_body, } headers = {} @@ -911,7 +913,7 @@ async def limited_request_func(request_func_input, pbar): prompt=test_prompt, api_url=api_url, prompt_len=test_prompt_len, - output_len=test_output_len, + output_len=min(test_output_len, 32), lora_name=lora_name, extra_request_body=extra_request_body, ) @@ -1413,6 +1415,11 @@ def set_ulimit(target_soft_limit=65535): action="store_true", help="Disable ignoring EOS.", ) + parser.add_argument( + "--return-logprob", + action="store_true", + help="Return logprob.", + ) parser.add_argument( "--extra-request-body", metavar='{"key1": "value1", "key2": "value2"}', diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py index 13e8dae2345..d98499d6278 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -440,16 +440,11 @@ def recv_requests(self): if self.tp_rank == 0 or self.server_args.enable_dp_attention: recv_reqs = [] - if self.last_batch is None: - recv_req = self.recv_from_tokenizer.recv_pyobj() - recv_reqs.append(recv_req) - else: - while True: - try: - recv_req = self.recv_from_tokenizer.recv_pyobj(zmq.NOBLOCK) - except zmq.ZMQError: - break - recv_reqs.append(recv_req) + while True: + try: + recv_req = self.recv_from_tokenizer.recv_pyobj(zmq.NOBLOCK) + except zmq.ZMQError: + break else: recv_reqs = None @@ -949,6 +944,7 @@ def process_batch_result(self, batch: ScheduleBatch, result): batch.next_batch_sampling_info.sampling_info_done.set() def process_batch_result_prefill(self, batch: ScheduleBatch, result): + skip_stream_req = None if self.is_generation: logits_output, next_token_ids, bid = result @@ -1005,6 +1001,10 @@ def process_batch_result_prefill(self, batch: ScheduleBatch, result): else: # being chunked reqs' prefill is not finished req.is_being_chunked -= 1 + # There is only at most one request being currently chunked. + # Because this request does not finish prefill, + # we don't want to stream the request currently being chunked. + skip_stream_req = req if batch.next_batch_sampling_info: batch.next_batch_sampling_info.update_regex_vocab_mask() @@ -1034,7 +1034,7 @@ def process_batch_result_prefill(self, batch: ScheduleBatch, result): # being chunked reqs' prefill is not finished req.is_being_chunked -= 1 - self.stream_output(batch.reqs) + self.stream_output(batch.reqs, skip_stream_req) def process_batch_result_decode(self, batch: ScheduleBatch, result): logits_output, next_token_ids, bid = result @@ -1179,7 +1179,7 @@ def add_logprob_return_values( return num_input_logprobs - def stream_output(self, reqs: List[Req]): + def stream_output(self, reqs: List[Req], skip_req: Optional[Req] = None): """Stream the output to detokenizer.""" output_rids = [] output_meta_info: List[dict] = [] @@ -1199,6 +1199,9 @@ def stream_output(self, reqs: List[Req]): is_stream_iter = self.forward_ct_decode % self.stream_interval == 0 for req in reqs: + if req is skip_req: + continue + # TODO(lianmin): revisit this for overlap + retract + stream if req.finished() or ( req.stream and (is_stream_iter or len(req.output_ids) == 1) diff --git a/python/sglang/test/test_utils.py b/python/sglang/test/test_utils.py index f97fc12355a..514bf31a68b 100644 --- a/python/sglang/test/test_utils.py +++ b/python/sglang/test/test_utils.py @@ -568,6 +568,7 @@ def run_bench_serving( disable_tqdm=False, disable_stream=disable_stream, disable_ignore_eos=False, + return_logprob=False, lora_name=None, extra_request_body=None, profile=None, From 6128f7cff5e61517f69fafa6aec148d8d40657cf Mon Sep 17 00:00:00 2001 From: Yineng Zhang Date: Sun, 8 Dec 2024 20:07:30 +0800 Subject: [PATCH 49/60] fix: specify dtype with begin_forward aka plan (#2404) --- python/sglang/srt/layers/attention/flashinfer_backend.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/python/sglang/srt/layers/attention/flashinfer_backend.py b/python/sglang/srt/layers/attention/flashinfer_backend.py index f89bc2ccaa2..536358fbc94 100644 --- a/python/sglang/srt/layers/attention/flashinfer_backend.py +++ b/python/sglang/srt/layers/attention/flashinfer_backend.py @@ -678,6 +678,7 @@ def call_begin_forward( self.num_qo_heads, self.num_kv_heads, self.head_dim, + q_data_type=self.q_data_type, ) # cached part @@ -691,6 +692,7 @@ def call_begin_forward( self.num_kv_heads, self.head_dim, 1, + q_data_type=self.q_data_type, ) From cc858953a0b0f99e5b7cf07dcf3335a158097df5 Mon Sep 17 00:00:00 2001 From: Lianmin Zheng Date: Sun, 8 Dec 2024 04:08:04 -0800 Subject: [PATCH 50/60] Fix recv_requests (#2405) --- python/sglang/srt/managers/scheduler.py | 1 + 1 file changed, 1 insertion(+) diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py index d98499d6278..a3316503b9a 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -445,6 +445,7 @@ def recv_requests(self): recv_req = self.recv_from_tokenizer.recv_pyobj(zmq.NOBLOCK) except zmq.ZMQError: break + recv_reqs.append(recv_req) else: recv_reqs = None From 67470bbb28591cc2a82a4cda419cdf6664ce46d2 Mon Sep 17 00:00:00 2001 From: Yineng Zhang Date: Sun, 8 Dec 2024 20:55:04 +0800 Subject: [PATCH 51/60] minor: update correct measurement unit (#2406) --- test/srt/test_bench_serving.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/srt/test_bench_serving.py b/test/srt/test_bench_serving.py index 34a7b6c9670..b882f12f9df 100644 --- a/test/srt/test_bench_serving.py +++ b/test/srt/test_bench_serving.py @@ -125,7 +125,7 @@ def test_online_latency_default(self): if is_in_ci(): write_github_step_summary( f"### test_online_latency_default\n" - f'median_e2e_latency_ms : {res["median_e2e_latency_ms"]:.2f} token/s\n' + f'median_e2e_latency_ms : {res["median_e2e_latency_ms"]:.2f} ms\n' ) self.assertLess(res["median_e2e_latency_ms"], 12000) self.assertLess(res["median_ttft_ms"], 86) From 0f8eb15323ea8776a945d917517990ca7cbfbdcb Mon Sep 17 00:00:00 2001 From: Yineng Zhang Date: Mon, 9 Dec 2024 02:29:55 +0800 Subject: [PATCH 52/60] feat: support custom task runner (#2407) --- .github/workflows/experiment-runner.yml | 26 ++ test/srt/configs/sharegpt_config.yaml | 7 + test/srt/experiment_runner.py | 359 ++++++++++++++++++++++++ 3 files changed, 392 insertions(+) create mode 100644 .github/workflows/experiment-runner.yml create mode 100644 test/srt/configs/sharegpt_config.yaml create mode 100644 test/srt/experiment_runner.py diff --git a/.github/workflows/experiment-runner.yml b/.github/workflows/experiment-runner.yml new file mode 100644 index 00000000000..9cac407df91 --- /dev/null +++ b/.github/workflows/experiment-runner.yml @@ -0,0 +1,26 @@ +name: Experiment Runner + +on: + workflow_dispatch: + +concurrency: + group: experiment-runner-${{ github.ref }} + cancel-in-progress: true + +jobs: + experiment-runner-1-gpu: + if: github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request' + runs-on: 1-gpu-runner + steps: + - name: Checkout code + uses: actions/checkout@v3 + + - name: Install dependencies + run: | + bash scripts/ci_install_dependency.sh + + - name: Test experiment runner + timeout-minutes: 10 + run: | + cd test/srt + python3 experiment_runner.py --config configs/sharegpt_config.yaml diff --git a/test/srt/configs/sharegpt_config.yaml b/test/srt/configs/sharegpt_config.yaml new file mode 100644 index 00000000000..a80b96c8eae --- /dev/null +++ b/test/srt/configs/sharegpt_config.yaml @@ -0,0 +1,7 @@ +tasks: + - name: sglang-benchmark + server_cmd: python3 -m sglang.launch_server --model-path meta-llama/Llama-3.1-8B-Instruct --disable-radix-cache + client_cmd: python3 -m sglang.bench_serving --backend sglang --dataset-name sharegpt --request-rate 16 + - name: vllm-benchmark + server_cmd: python3 -m vllm.entrypoints.openai.api_server --model meta-llama/Llama-3.1-8B-Instruct --disable-log-requests + client_cmd: python3 -m sglang.bench_serving --backend vllm --dataset-name sharegpt --request-rate 16 diff --git a/test/srt/experiment_runner.py b/test/srt/experiment_runner.py new file mode 100644 index 00000000000..c4966dc77ba --- /dev/null +++ b/test/srt/experiment_runner.py @@ -0,0 +1,359 @@ +import argparse +import logging +import os +import queue +import re +import subprocess +import threading +import time +from dataclasses import dataclass +from datetime import datetime +from pathlib import Path +from typing import Dict, List, Optional, Tuple + +import psutil +import requests +import yaml + + +@dataclass +class ServerConfig: + command: str + process_names: List[str] + default_port: int + + +@dataclass +class TaskConfig: + server_cmd: str + client_cmd: str + name: Optional[str] = None + server_type: Optional[str] = None + + +@dataclass +class TaskResult: + name: str + success: bool + output: str + runtime: float + timestamp: str + + +SERVER_DEFAULTS = { + "sglang": ServerConfig( + command="sglang.launch_server", + process_names=["sglang.launch_server"], + default_port=30000, + ), + "vllm": ServerConfig( + command="vllm.entrypoints.openai.api_server", + process_names=["vllm.entrypoints.openai.api_server"], + default_port=8000, + ), +} + + +def parse_key_info(output: str) -> str: + """Extract and format key information from the output""" + key_info = [] + + # Extract Args namespace + args_match = re.search(r"Namespace\(.*?\)", output, re.DOTALL) + if args_match: + key_info.append(args_match.group(0)) + + # Extract input/output token counts + token_matches = re.findall(r"#(Input|Output) tokens: \d+", output) + key_info.extend(token_matches) + + # Extract benchmark result section + result_match = re.search( + r"============ Serving Benchmark Result ============.*?={50,}", + output, + re.DOTALL, + ) + if result_match: + key_info.append(result_match.group(0)) + + return "\n\n".join(key_info) + + +def extract_port_from_command(cmd: str, server_type: str) -> int: + port_match = re.search(r"--port[= ](\d+)", cmd) + if port_match: + return int(port_match.group(1)) + return SERVER_DEFAULTS.get(server_type, ServerConfig("", [], 8000)).default_port + + +def detect_server_type(cmd: str) -> str: + for server_type, config in SERVER_DEFAULTS.items(): + if config.command in cmd: + return server_type + return "unknown" + + +def stream_output( + process: subprocess.Popen, prefix: str, logger: logging.Logger +) -> queue.Queue: + output_queue = queue.Queue() + + def stream_pipe(pipe, prefix): + for line in iter(pipe.readline, ""): + if prefix == "CLIENT": + output_queue.put(line.rstrip()) + logger.debug(f"{prefix} | {line.rstrip()}") + + stdout_thread = threading.Thread( + target=stream_pipe, args=(process.stdout, prefix), daemon=True + ) + stderr_thread = threading.Thread( + target=stream_pipe, args=(process.stderr, prefix), daemon=True + ) + + stdout_thread.start() + stderr_thread.start() + return output_queue, (stdout_thread, stderr_thread) + + +class ProcessManager: + def __init__(self): + self.server_process: Optional[subprocess.Popen] = None + self.client_process: Optional[subprocess.Popen] = None + self.logger = logging.getLogger(__name__) + + def start_process( + self, command: str, prefix: str + ) -> Tuple[subprocess.Popen, queue.Queue]: + process = subprocess.Popen( + command, + shell=True, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + text=True, + bufsize=1, + ) + + output_queue, threads = stream_output(process, prefix, self.logger) + return process, output_queue, threads + + def kill_process_tree(self, process: subprocess.Popen): + try: + parent = psutil.Process(process.pid) + children = parent.children(recursive=True) + + for child in children: + try: + child.kill() + except psutil.NoSuchProcess: + pass + + parent.kill() + gone, alive = psutil.wait_procs(children + [parent], timeout=3) + + for p in alive: + try: + p.kill() + except psutil.NoSuchProcess: + pass + + except psutil.NoSuchProcess: + pass + + def cleanup(self, process_names: List[str]): + if self.client_process: + self.kill_process_tree(self.client_process) + self.client_process = None + + if self.server_process: + self.kill_process_tree(self.server_process) + self.server_process = None + + for proc in psutil.process_iter(["pid", "name", "cmdline"]): + try: + cmdline = " ".join(proc.cmdline()) + if any(name in cmdline for name in process_names): + proc.kill() + except (psutil.NoSuchProcess, psutil.AccessDenied): + continue + + +class ExperimentRunner: + def __init__(self): + self.process_manager = ProcessManager() + self.logger = logging.getLogger(__name__) + + def wait_for_server(self, port: int, timeout: int = 300) -> bool: + start_time = time.time() + + while time.time() - start_time < timeout: + try: + response = requests.get(f"http://localhost:{port}/health") + if response.status_code == 200: + self.logger.debug(f"Server ready on port {port}") + return True + except requests.RequestException: + time.sleep(2) + return False + + def run_task(self, config: TaskConfig) -> TaskResult: + start_time = time.time() + client_output = [] + + try: + if not config.server_type: + config.server_type = detect_server_type(config.server_cmd) + + server_config = SERVER_DEFAULTS.get(config.server_type) + if not server_config: + raise ValueError(f"Unknown server type: {config.server_type}") + + port = extract_port_from_command(config.server_cmd, config.server_type) + + self.process_manager.cleanup(server_config.process_names) + + self.logger.debug(f"Starting server: {config.name}") + self.process_manager.server_process, _, server_threads = ( + self.process_manager.start_process(config.server_cmd, "SERVER") + ) + + if not self.wait_for_server(port): + raise TimeoutError("Server startup timeout") + + time.sleep(10) + + self.logger.debug("Starting client") + self.process_manager.client_process, output_queue, client_threads = ( + self.process_manager.start_process(config.client_cmd, "CLIENT") + ) + + returncode = self.process_manager.client_process.wait() + + while True: + try: + line = output_queue.get_nowait() + client_output.append(line) + except queue.Empty: + break + + if returncode != 0: + raise RuntimeError(f"Client failed with code {returncode}") + + # Parse and format the output + full_output = "\n".join(client_output) + formatted_output = parse_key_info(full_output) + + return TaskResult( + name=config.name, + success=True, + output=formatted_output, + runtime=time.time() - start_time, + timestamp=datetime.now().isoformat(), + ) + + except Exception as e: + return TaskResult( + name=config.name, + success=False, + output=str(e), + runtime=time.time() - start_time, + timestamp=datetime.now().isoformat(), + ) + + finally: + if config.server_type in SERVER_DEFAULTS: + self.process_manager.cleanup( + SERVER_DEFAULTS[config.server_type].process_names + ) + time.sleep(10) + + +def load_config(config_path: str) -> List[TaskConfig]: + with open(config_path, "r") as f: + config_data = yaml.safe_load(f) + + configs = [] + for idx, entry in enumerate(config_data.get("tasks", [])): + if not isinstance(entry, dict): + raise ValueError(f"Invalid entry at index {idx}") + + config = TaskConfig( + server_cmd=entry.get("server_cmd"), + client_cmd=entry.get("client_cmd"), + name=entry.get("name", f"task-{idx+1}"), + server_type=entry.get("server_type"), + ) + + if not config.server_cmd or not config.client_cmd: + raise ValueError(f"Missing commands in {config.name}") + + configs.append(config) + + return configs + + +def setup_logging(debug: bool = False): + level = logging.DEBUG if debug else logging.INFO + logging.basicConfig( + level=level, + format="%(asctime)s - %(levelname)s - %(message)s", + handlers=[logging.StreamHandler(), logging.FileHandler("experiment.log")], + ) + + +def format_results(results: List[TaskResult]) -> str: + """Format experiment results in Markdown for GitHub step summary.""" + output = ["# Experiment Results\n"] + + for result in results: + output.append(f"## {result.name}") + output.append(f"**Status**: {'✅ Success' if result.success else '❌ Failed'}") + output.append(f"**Runtime**: {result.runtime:.2f} seconds") + output.append(f"**Timestamp**: {result.timestamp}") + output.append("\n**Output**:\n```") + output.append(result.output) + output.append("```\n") + + return "\n".join(output) + + +def write_in_github_step_summary(results: List[TaskResult]): + """Write formatted results to GitHub step summary.""" + if not os.environ.get("GITHUB_STEP_SUMMARY"): + logging.warning("GITHUB_STEP_SUMMARY environment variable not set") + return + + formatted_content = format_results(results) + with open(os.environ["GITHUB_STEP_SUMMARY"], "a") as f: + f.write(formatted_content) + + +def main(): + parser = argparse.ArgumentParser(description="Experiment Runner") + parser.add_argument( + "--config", type=str, required=True, help="Path to YAML config file" + ) + parser.add_argument("--debug", action="store_true", help="Enable debug output") + args = parser.parse_args() + + setup_logging(args.debug) + logger = logging.getLogger(__name__) + results = [] + + try: + configs = load_config(args.config) + runner = ExperimentRunner() + + for config in configs: + logger.info(f"Running {config.name}") + result = runner.run_task(config) + results.append(result) + + write_in_github_step_summary(results) + except Exception as e: + logger.error(f"Error: {e}") + raise + + +if __name__ == "__main__": + main() From 74bc9184c3eb8fcd2135a665424d484a652fe50a Mon Sep 17 00:00:00 2001 From: Yineng Zhang Date: Mon, 9 Dec 2024 03:21:35 +0800 Subject: [PATCH 53/60] minor: add random use case (#2408) --- .github/workflows/experiment-runner.yml | 8 ++++++-- test/srt/configs/random_config.yaml | 25 +++++++++++++++++++++++++ 2 files changed, 31 insertions(+), 2 deletions(-) create mode 100644 test/srt/configs/random_config.yaml diff --git a/.github/workflows/experiment-runner.yml b/.github/workflows/experiment-runner.yml index 9cac407df91..5ccb8ad28ff 100644 --- a/.github/workflows/experiment-runner.yml +++ b/.github/workflows/experiment-runner.yml @@ -2,6 +2,10 @@ name: Experiment Runner on: workflow_dispatch: + inputs: + script: + description: "Experiment Runner Script" + default: "configs/sharegpt_config.yaml" concurrency: group: experiment-runner-${{ github.ref }} @@ -20,7 +24,7 @@ jobs: bash scripts/ci_install_dependency.sh - name: Test experiment runner - timeout-minutes: 10 + timeout-minutes: 120 run: | cd test/srt - python3 experiment_runner.py --config configs/sharegpt_config.yaml + python3 experiment_runner.py --config ${{ inputs.script }} diff --git a/test/srt/configs/random_config.yaml b/test/srt/configs/random_config.yaml new file mode 100644 index 00000000000..eae8c27f41c --- /dev/null +++ b/test/srt/configs/random_config.yaml @@ -0,0 +1,25 @@ +tasks: + - name: sglang-128-4 + server_cmd: python3 -m sglang.launch_server --model-path meta-llama/Llama-3.1-8B-Instruct --disable-radix-cache + client_cmd: python3 -m sglang.bench_serving --backend sglang --dataset-name random --random-input 128 --random-output 4 --request-rate 24 --num-prompt 1440 + - name: vllm-128-4 + server_cmd: python3 -m vllm.entrypoints.openai.api_server --model meta-llama/Llama-3.1-8B-Instruct --disable-log-requests + client_cmd: python3 -m sglang.bench_serving --backend vllm --dataset-name random --random-input 128 --random-output 4 --request-rate 24 --num-prompt 1440 + - name: sglang-2000-100 + server_cmd: python3 -m sglang.launch_server --model-path meta-llama/Llama-3.1-8B-Instruct --disable-radix-cache + client_cmd: python3 -m sglang.bench_serving --backend sglang --dataset-name random --random-input 2000 --random-output 100 --request-rate 2 --num-prompt 120 + - name: vllm-2000-100 + server_cmd: python3 -m vllm.entrypoints.openai.api_server --model meta-llama/Llama-3.1-8B-Instruct --disable-log-requests + client_cmd: python3 -m sglang.bench_serving --backend vllm --dataset-name random --random-input 2000 --random-output 100 --request-rate 2 --num-prompt 120 + - name: sglang-4000-200 + server_cmd: python3 -m sglang.launch_server --model-path meta-llama/Llama-3.1-8B-Instruct --disable-radix-cache + client_cmd: python3 -m sglang.bench_serving --backend sglang --dataset-name random --random-input 4000 --random-output 200 --request-rate 8 --num-prompt 480 + - name: vllm-4000-200 + server_cmd: python3 -m vllm.entrypoints.openai.api_server --model meta-llama/Llama-3.1-8B-Instruct --disable-log-requests + client_cmd: python3 -m sglang.bench_serving --backend vllm --dataset-name random --random-input 4000 --random-output 200 --request-rate 8 --num-prompt 480 + - name: sglang-32000-100 + server_cmd: python3 -m sglang.launch_server --model-path meta-llama/Llama-3.1-8B-Instruct --disable-radix-cache + client_cmd: python3 -m sglang.bench_serving --backend sglang --dataset-name random --random-input 32000 --random-output 100 --request-rate 1 --num-prompt 60 + - name: vllm-32000-100 + server_cmd: python3 -m vllm.entrypoints.openai.api_server --model meta-llama/Llama-3.1-8B-Instruct --disable-log-requests + client_cmd: python3 -m sglang.bench_serving --backend vllm --dataset-name random --random-input 32000 --random-output 100 --request-rate 1 --num-prompt 60 From f62055b528c2cac6cebdb6303e00bb479d7d2402 Mon Sep 17 00:00:00 2001 From: Yineng Zhang Date: Mon, 9 Dec 2024 04:15:21 +0800 Subject: [PATCH 54/60] minor: add random flashinfer vs triton use case (#2409) --- .../random_flashinfer_vs_triton_config.yaml | 25 +++++++++++++++++++ 1 file changed, 25 insertions(+) create mode 100644 test/srt/configs/random_flashinfer_vs_triton_config.yaml diff --git a/test/srt/configs/random_flashinfer_vs_triton_config.yaml b/test/srt/configs/random_flashinfer_vs_triton_config.yaml new file mode 100644 index 00000000000..7f4a386ddcf --- /dev/null +++ b/test/srt/configs/random_flashinfer_vs_triton_config.yaml @@ -0,0 +1,25 @@ +tasks: + - name: sglang-128-4 + server_cmd: python3 -m sglang.launch_server --model-path meta-llama/Llama-3.1-8B-Instruct --disable-radix-cache + client_cmd: python3 -m sglang.bench_serving --backend sglang --dataset-name random --random-input 128 --random-output 4 --request-rate 24 --num-prompt 1440 + - name: sglang-triton-128-4 + server_cmd: python3 -m sglang.launch_server --model-path meta-llama/Llama-3.1-8B-Instruct --disable-radix-cache --attention-backend triton + client_cmd: python3 -m sglang.bench_serving --backend sglang --dataset-name random --random-input 128 --random-output 4 --request-rate 24 --num-prompt 1440 + - name: sglang-2000-100 + server_cmd: python3 -m sglang.launch_server --model-path meta-llama/Llama-3.1-8B-Instruct --disable-radix-cache + client_cmd: python3 -m sglang.bench_serving --backend sglang --dataset-name random --random-input 2000 --random-output 100 --request-rate 2 --num-prompt 120 + - name: sglang-triton-2000-100 + server_cmd: python3 -m sglang.launch_server --model-path meta-llama/Llama-3.1-8B-Instruct --disable-radix-cache --attention-backend triton + client_cmd: python3 -m sglang.bench_serving --backend sglang --dataset-name random --random-input 2000 --random-output 100 --request-rate 2 --num-prompt 120 + - name: sglang-4000-200 + server_cmd: python3 -m sglang.launch_server --model-path meta-llama/Llama-3.1-8B-Instruct --disable-radix-cache + client_cmd: python3 -m sglang.bench_serving --backend sglang --dataset-name random --random-input 4000 --random-output 200 --request-rate 8 --num-prompt 480 + - name: sglang-triton-4000-200 + server_cmd: python3 -m sglang.launch_server --model-path meta-llama/Llama-3.1-8B-Instruct --disable-radix-cache --attention-backend triton + client_cmd: python3 -m sglang.bench_serving --backend sglang --dataset-name random --random-input 4000 --random-output 200 --request-rate 8 --num-prompt 480 + - name: sglang-32000-100 + server_cmd: python3 -m sglang.launch_server --model-path meta-llama/Llama-3.1-8B-Instruct --disable-radix-cache + client_cmd: python3 -m sglang.bench_serving --backend sglang --dataset-name random --random-input 32000 --random-output 100 --request-rate 1 --num-prompt 60 + - name: sglang-triton-32000-100 + server_cmd: python3 -m sglang.launch_server --model-path meta-llama/Llama-3.1-8B-Instruct --disable-radix-cache --attention-backend triton + client_cmd: python3 -m sglang.bench_serving --backend sglang --dataset-name random --random-input 32000 --random-output 100 --request-rate 1 --num-prompt 60 From a6ca736c8e35b308ecb9d8e21c53692ef5c7ac4f Mon Sep 17 00:00:00 2001 From: Lianmin Zheng Date: Sun, 8 Dec 2024 12:27:13 -0800 Subject: [PATCH 55/60] Simplify stream_output (#2398) --- python/sglang/srt/layers/logits_processor.py | 94 +++--- .../srt/managers/detokenizer_manager.py | 51 ++- python/sglang/srt/managers/io_struct.py | 49 ++- python/sglang/srt/managers/schedule_batch.py | 46 ++- python/sglang/srt/managers/scheduler.py | 297 ++++++++++-------- .../sglang/srt/managers/tokenizer_manager.py | 155 +++++---- .../srt/model_executor/cuda_graph_runner.py | 9 +- python/sglang/test/test_utils.py | 4 +- test/srt/test_json_constrained.py | 9 - 9 files changed, 425 insertions(+), 289 deletions(-) diff --git a/python/sglang/srt/layers/logits_processor.py b/python/sglang/srt/layers/logits_processor.py index 274c4c311ec..915cb47d271 100644 --- a/python/sglang/srt/layers/logits_processor.py +++ b/python/sglang/srt/layers/logits_processor.py @@ -39,10 +39,12 @@ class LogitsProcessorOutput: # The logprobs of input tokens. shape: [#token, vocab_size] input_token_logprobs: torch.Tensor = None - # The logprob and id of the top-k tokens in input positions. shape [#seq, #token, k] of Tuple(logprob, token_id) - input_top_logprobs: List = None - # The logprob and id of the top-k tokens in output positions. shape [#seq, #token, k] of Tuple(logprob, token_id) - output_top_logprobs: List = None + # The logprob and id of the top-k tokens in input positions. shape [#seq, #token, k] + input_top_logprobs_val: List = None + input_top_logprobs_idx: List = None + # The logprob and id of the top-k tokens in output positions. shape [#seq, #token, k] + output_top_logprobs_val: List = None + output_top_logprobs_idx: List = None @dataclasses.dataclass @@ -125,12 +127,15 @@ def get_top_logprobs(all_logprobs: torch.Tensor, logits_metadata: LogitsMetadata indices = ret.indices.tolist() if logits_metadata.forward_mode.is_decode(): - output_top_logprobs = [] + output_top_logprobs_val = [] + output_top_logprobs_idx = [] for i, k in enumerate(logits_metadata.top_logprobs_nums): - output_top_logprobs.append(list(zip(values[i][:k], indices[i][:k]))) - return None, output_top_logprobs + output_top_logprobs_val.append(values[i][:k]) + output_top_logprobs_idx.append(indices[i][:k]) + return None, None, output_top_logprobs_val, output_top_logprobs_idx else: - input_top_logprobs, output_top_logprobs = [], [] + input_top_logprobs_val, input_top_logprobs_idx = [], [] + output_top_logprobs_val, output_top_logprobs_idx = [], [] pt = 0 for k, pruned_len in zip( @@ -138,27 +143,36 @@ def get_top_logprobs(all_logprobs: torch.Tensor, logits_metadata: LogitsMetadata logits_metadata.extend_logprob_pruned_lens_cpu, ): if pruned_len <= 0: - input_top_logprobs.append([]) - output_top_logprobs.append([]) + input_top_logprobs_val.append([]) + input_top_logprobs_idx.append([]) + output_top_logprobs_val.append([]) + output_top_logprobs_idx.append([]) continue - input_top_logprobs.append( - [ - list(zip(values[pt + j][:k], indices[pt + j][:k])) - for j in range(pruned_len - 1) - ] + input_top_logprobs_val.append( + [values[pt + j][:k] for j in range(pruned_len - 1)] ) - output_top_logprobs.append( + input_top_logprobs_idx.append( + [indices[pt + j][:k] for j in range(pruned_len - 1)] + ) + output_top_logprobs_val.append( + list( + values[pt + pruned_len - 1][:k], + ) + ) + output_top_logprobs_idx.append( list( - zip( - values[pt + pruned_len - 1][:k], - indices[pt + pruned_len - 1][:k], - ) + indices[pt + pruned_len - 1][:k], ) ) pt += pruned_len - return input_top_logprobs, output_top_logprobs + return ( + input_top_logprobs_val, + input_top_logprobs_idx, + output_top_logprobs_val, + output_top_logprobs_idx, + ) def forward( self, @@ -193,29 +207,22 @@ def forward( if not logits_metadata.return_logprob: return LogitsProcessorOutput( next_token_logits=last_logits, - next_token_logprobs=None, - normalized_prompt_logprobs=None, - input_token_logprobs=None, - input_top_logprobs=None, - output_top_logprobs=None, ) else: last_logprobs = torch.nn.functional.log_softmax(last_logits, dim=-1) if logits_metadata.forward_mode.is_decode(): if logits_metadata.return_top_logprob: - output_top_logprobs = self.get_top_logprobs( - last_logprobs, logits_metadata - )[1] + output_top_logprobs_val, output_top_logprobs_idx = ( + self.get_top_logprobs(last_logprobs, logits_metadata)[2:4] + ) else: - output_top_logprobs = None + output_top_logprobs_val = output_top_logprobs_idx = None return LogitsProcessorOutput( next_token_logits=last_logits, next_token_logprobs=last_logprobs, - normalized_prompt_logprobs=None, - input_token_logprobs=None, - input_top_logprobs=None, - output_top_logprobs=output_top_logprobs, + output_top_logprobs_val=output_top_logprobs_val, + output_top_logprobs_idx=output_top_logprobs_idx, ) else: # Slice the requested tokens to compute logprob @@ -246,11 +253,16 @@ def forward( # Get the logprob of top-k tokens if logits_metadata.return_top_logprob: - input_top_logprobs, output_top_logprobs = self.get_top_logprobs( - all_logprobs, logits_metadata - ) + ( + input_top_logprobs_val, + input_top_logprobs_idx, + output_top_logprobs_val, + output_top_logprobs_idx, + ) = self.get_top_logprobs(all_logprobs, logits_metadata) else: - input_top_logprobs = output_top_logprobs = None + input_top_logprobs_val = input_top_logprobs_idx = ( + output_top_logprobs_val + ) = output_top_logprobs_idx = None # Compute the normalized logprobs for the requested tokens. # Note that we pad a zero at the end for easy batching. @@ -273,8 +285,10 @@ def forward( next_token_logprobs=last_logprobs, normalized_prompt_logprobs=normalized_prompt_logprobs, input_token_logprobs=input_token_logprobs, - input_top_logprobs=input_top_logprobs, - output_top_logprobs=output_top_logprobs, + input_top_logprobs_val=input_top_logprobs_val, + input_top_logprobs_idx=input_top_logprobs_idx, + output_top_logprobs_val=output_top_logprobs_val, + output_top_logprobs_idx=output_top_logprobs_idx, ) def _get_logits( diff --git a/python/sglang/srt/managers/detokenizer_manager.py b/python/sglang/srt/managers/detokenizer_manager.py index 120e990da2a..bc9e4a53b5c 100644 --- a/python/sglang/srt/managers/detokenizer_manager.py +++ b/python/sglang/srt/managers/detokenizer_manager.py @@ -17,7 +17,7 @@ import logging import signal from collections import OrderedDict -from typing import List, Union +from typing import Dict, List, Union import psutil import setproctitle @@ -76,17 +76,25 @@ def __init__( self.decode_status = LimitedCapacityDict() - def trim_eos(self, output: Union[str, List[int]], finished_reason, no_stop_trim): - if no_stop_trim: + def trim_matched_stop( + self, output: Union[str, List[int]], finished_reason: Dict, no_stop_trim: bool + ): + if no_stop_trim or not finished_reason: + return output + + matched = finished_reason.get("matched", None) + if not matched: return output - # Trim stop str. TODO(lmzheng): handle the case where multiple stop strs are hit - if isinstance(finished_reason, FINISH_MATCHED_STR) and isinstance(output, str): - pos = output.find(finished_reason.matched) + # TODO(lmzheng): handle the case where multiple stop strs are hit + + # Trim stop str. + if isinstance(matched, str) and isinstance(output, str): + pos = output.find(matched) return output[:pos] if pos != -1 else output - if isinstance(finished_reason, FINISH_MATCHED_TOKEN) and isinstance( - output, list - ): + + # Trim stop token. + if isinstance(matched, int) and isinstance(output, list): assert len(output) > 0 return output[:-1] return output @@ -125,9 +133,9 @@ def event_loop(self): s.decode_ids = recv_obj.decode_ids[i] read_ids.append( - self.trim_eos( + self.trim_matched_stop( s.decode_ids[s.surr_offset :], - recv_obj.finished_reason[i], + recv_obj.finished_reasons[i], recv_obj.no_stop_trim[i], ) ) @@ -150,7 +158,7 @@ def event_loop(self): for i in range(bs): s = self.decode_status[recv_obj.rids[i]] new_text = read_texts[i][len(surr_texts[i]) :] - if recv_obj.finished_reason[i] is None: + if recv_obj.finished_reasons[i] is None: # Streaming chunk: update the decode status if len(new_text) > 0 and not new_text.endswith("�"): s.decoded_text = s.decoded_text + new_text @@ -161,9 +169,9 @@ def event_loop(self): new_text = find_printable_text(new_text) output_strs.append( - self.trim_eos( + self.trim_matched_stop( s.decoded_text + new_text, - recv_obj.finished_reason[i], + recv_obj.finished_reasons[i], recv_obj.no_stop_trim[i], ) ) @@ -171,9 +179,20 @@ def event_loop(self): self.send_to_tokenizer.send_pyobj( BatchStrOut( rids=recv_obj.rids, + finished_reasons=recv_obj.finished_reasons, output_strs=output_strs, - meta_info=recv_obj.meta_info, - finished_reason=recv_obj.finished_reason, + prompt_tokens=recv_obj.prompt_tokens, + completion_tokens=recv_obj.completion_tokens, + cached_tokens=recv_obj.cached_tokens, + input_token_logprobs_val=recv_obj.input_token_logprobs_val, + input_token_logprobs_idx=recv_obj.input_token_logprobs_idx, + output_token_logprobs_val=recv_obj.output_token_logprobs_val, + output_token_logprobs_idx=recv_obj.output_token_logprobs_idx, + input_top_logprobs_val=recv_obj.input_top_logprobs_val, + input_top_logprobs_idx=recv_obj.input_top_logprobs_idx, + output_top_logprobs_val=recv_obj.output_top_logprobs_val, + output_top_logprobs_idx=recv_obj.output_top_logprobs_idx, + normalized_prompt_logprob=recv_obj.normalized_prompt_logprob, ) ) diff --git a/python/sglang/srt/managers/io_struct.py b/python/sglang/srt/managers/io_struct.py index 27bf5a4bdb1..c5884b5f0f6 100644 --- a/python/sglang/srt/managers/io_struct.py +++ b/python/sglang/srt/managers/io_struct.py @@ -308,6 +308,9 @@ class TokenizedEmbeddingReqInput: class BatchTokenIDOut: # The request id rids: List[str] + # The finish reason + finished_reasons: List[BaseFinishReason] + # For incremental decoding # The version id to sync decode status with in detokenizer_manager vids: List[int] decoded_texts: List[str] @@ -315,35 +318,61 @@ class BatchTokenIDOut: read_offsets: List[int] # Only used when `--skip-tokenizer-init` output_ids: Optional[List[int]] + # Detokenization configs skip_special_tokens: List[bool] spaces_between_special_tokens: List[bool] - meta_info: List[Dict] - finished_reason: List[BaseFinishReason] no_stop_trim: List[bool] + # Token counts + prompt_tokens: List[int] + completion_tokens: List[int] + cached_tokens: List[int] + # Logprobs + input_token_logprobs_val: List[float] + input_token_logprobs_idx: List[int] + output_token_logprobs_val: List[float] + output_token_logprobs_idx: List[int] + input_top_logprobs_val: List[List] + input_top_logprobs_idx: List[List] + output_top_logprobs_val: List[List] + output_top_logprobs_idx: List[List] + normalized_prompt_logprob: List[float] @dataclass class BatchStrOut: # The request id rids: List[str] + # The finish reason + finished_reasons: List[dict] # The output decoded strings output_strs: List[str] - # The meta info - meta_info: List[Dict] - # The finish reason - finished_reason: List[BaseFinishReason] + + # Token counts + prompt_tokens: List[int] + completion_tokens: List[int] + cached_tokens: List[int] + # Logprobs + input_token_logprobs_val: List[float] + input_token_logprobs_idx: List[int] + output_token_logprobs_val: List[float] + output_token_logprobs_idx: List[int] + input_top_logprobs_val: List[List] + input_top_logprobs_idx: List[List] + output_top_logprobs_val: List[List] + output_top_logprobs_idx: List[List] + normalized_prompt_logprob: List[float] @dataclass class BatchEmbeddingOut: # The request id rids: List[str] + # The finish reason + finished_reasons: List[BaseFinishReason] # The output embedding embeddings: List[List[float]] - # The meta info - meta_info: List[Dict] - # The finish reason - finished_reason: List[BaseFinishReason] + # Token counts + prompt_tokens: List[int] @dataclass diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py index 5855d4248ff..bb9eb181611 100644 --- a/python/sglang/srt/managers/schedule_batch.py +++ b/python/sglang/srt/managers/schedule_batch.py @@ -200,6 +200,9 @@ def __init__( origin_input_text: str, origin_input_ids: Tuple[int], sampling_params: SamplingParams, + return_logprob: bool = False, + top_logprobs_num: int = 0, + stream: bool = False, origin_input_ids_unpadded: Optional[Tuple[int]] = None, lora_path: Optional[str] = None, input_embeds: Optional[List[List[float]]] = None, @@ -217,10 +220,11 @@ def __init__( self.output_ids = [] # Each decode stage's output ids self.fill_ids = None # fill_ids = origin_input_ids + output_ids self.session_id = session_id + self.input_embeds = input_embeds + # Sampling info self.sampling_params = sampling_params self.lora_path = lora_path - self.input_embeds = input_embeds # Memory pool info self.req_pool_idx = None @@ -228,8 +232,8 @@ def __init__( # Check finish self.tokenizer = None self.finished_reason = None - self.stream = False self.to_abort = False + self.stream = stream # For incremental decoding # ----- | --------- read_ids -------| @@ -241,13 +245,9 @@ def __init__( # 2: read_offset # 3: last token self.vid = 0 # version id to sync decode status with in detokenizer_manager - self.decoded_text = "" self.surr_offset = None # Surrounding offset to defeat the cleanup algorithm self.read_offset = None - - # The number of decoded tokens for token usage report. Note that - # this does not include the jump forward tokens. - self.completion_tokens_wo_jump_forward = 0 + self.decoded_text = "" # For multimodal inputs self.image_inputs: Optional[ImageInputs] = None @@ -256,22 +256,34 @@ def __init__( self.prefix_indices = [] self.extend_input_len = 0 self.last_node = None + + # Chunked prefill self.is_being_chunked = 0 # For retraction self.is_retracted = False # Logprobs (arguments) - self.return_logprob = False + self.return_logprob = return_logprob self.logprob_start_len = 0 - self.top_logprobs_num = 0 + self.top_logprobs_num = top_logprobs_num # Logprobs (return value) self.normalized_prompt_logprob = None - self.input_token_logprobs = None - self.input_top_logprobs = None - self.output_token_logprobs = [] - self.output_top_logprobs = [] + self.input_token_logprobs_val = None + self.input_token_logprobs_idx = None + self.input_top_logprobs_val = None + self.input_top_logprobs_idx = None + + if return_logprob: + self.output_token_logprobs_val = [] + self.output_token_logprobs_idx = [] + self.output_top_logprobs_val = [] + self.output_top_logprobs_idx = [] + else: + self.output_token_logprobs_val = self.output_token_logprobs_idx = ( + self.output_top_logprobs_val + ) = self.output_top_logprobs_idx = None # Logprobs (internal values) # The tokens is prefilled but need to be considered as decode tokens @@ -295,8 +307,8 @@ def extend_image_inputs(self, image_inputs): else: self.image_inputs.merge(image_inputs) - # whether request reached finished condition def finished(self) -> bool: + # Whether request reached finished condition return self.finished_reason is not None def init_next_round_input(self, tree_cache: Optional[BasePrefixCache] = None): @@ -454,8 +466,10 @@ def jump_forward_and_retokenize(self, jump_forward_str, next_state): k = k + 1 else: break - self.output_token_logprobs = self.output_token_logprobs[:k] - self.output_top_logprobs = self.output_top_logprobs[:k] + self.output_token_logprobs_val = self.output_token_logprobs_val[:k] + self.output_token_logprobs_idx = self.output_token_logprobs_idx[:k] + self.output_top_logprobs_val = self.output_top_logprobs_val[:k] + self.output_top_logprobs_idx = self.output_top_logprobs_idx[:k] self.logprob_start_len = prompt_tokens + k self.last_update_decode_tokens = len(self.output_ids) - k diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py index a3316503b9a..4ece8786878 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -515,6 +515,9 @@ def handle_generate_request( recv_req.input_text, recv_req.input_ids, recv_req.sampling_params, + return_logprob=recv_req.return_logprob, + top_logprobs_num=recv_req.top_logprobs_num, + stream=recv_req.stream, lora_path=recv_req.lora_path, input_embeds=recv_req.input_embeds, ) @@ -558,9 +561,6 @@ def handle_generate_request( return # Copy more attributes - req.return_logprob = recv_req.return_logprob - req.top_logprobs_num = recv_req.top_logprobs_num - req.stream = recv_req.stream req.logprob_start_len = recv_req.logprob_start_len if req.logprob_start_len == -1: @@ -982,7 +982,6 @@ def process_batch_result_prefill(self, batch: ScheduleBatch, result): continue if req.is_being_chunked <= 0: - req.completion_tokens_wo_jump_forward += 1 req.output_ids.append(next_token_id) req.check_finished() @@ -1035,7 +1034,7 @@ def process_batch_result_prefill(self, batch: ScheduleBatch, result): # being chunked reqs' prefill is not finished req.is_being_chunked -= 1 - self.stream_output(batch.reqs, skip_stream_req) + self.stream_output(batch.reqs, batch.return_logprob, skip_stream_req) def process_batch_result_decode(self, batch: ScheduleBatch, result): logits_output, next_token_ids, bid = result @@ -1065,7 +1064,6 @@ def process_batch_result_decode(self, batch: ScheduleBatch, result): self.token_to_kv_pool.free(batch.out_cache_loc[i : i + 1]) continue - req.completion_tokens_wo_jump_forward += 1 req.output_ids.append(next_token_id) req.check_finished() @@ -1073,11 +1071,15 @@ def process_batch_result_decode(self, batch: ScheduleBatch, result): self.tree_cache.cache_finished_req(req) if req.return_logprob: - req.output_token_logprobs.append( - (next_token_logprobs[i], next_token_id) - ) + req.output_token_logprobs_val.append(next_token_logprobs[i]) + req.output_token_logprobs_idx.append(next_token_id) if req.top_logprobs_num > 0: - req.output_top_logprobs.append(logits_output.output_top_logprobs[i]) + req.output_top_logprobs_val.append( + logits_output.output_top_logprobs_val[i] + ) + req.output_top_logprobs_idx.append( + logits_output.output_top_logprobs_idx[i] + ) if req.grammar is not None: req.grammar.accept_token(next_token_id) @@ -1088,7 +1090,7 @@ def process_batch_result_decode(self, batch: ScheduleBatch, result): self.current_stream.synchronize() batch.next_batch_sampling_info.sampling_info_done.set() - self.stream_output(batch.reqs) + self.stream_output(batch.reqs, batch.return_logprob) self.token_to_kv_pool.free_group_end() @@ -1108,9 +1110,8 @@ def add_logprob_return_values( output: LogitsProcessorOutput, ): """Attach logprobs to the return values.""" - req.output_token_logprobs.append( - (output.next_token_logprobs[i], next_token_ids[i]) - ) + req.output_token_logprobs_val.append(output.next_token_logprobs[i]) + req.output_token_logprobs_idx.append(next_token_ids[i]) # If logprob_start_len > 0, then first logprob_start_len prompt tokens will be ignored. num_input_logprobs = req.extend_input_len - req.extend_logprob_start_len @@ -1118,173 +1119,195 @@ def add_logprob_return_values( if req.normalized_prompt_logprob is None: req.normalized_prompt_logprob = output.normalized_prompt_logprobs[i] - if req.input_token_logprobs is None: - input_token_logprobs = output.input_token_logprobs[ + if req.input_token_logprobs_val is None: + input_token_logprobs_val = output.input_token_logprobs[ pt : pt + num_input_logprobs - 1 - req.last_update_decode_tokens ] - input_token_ids = req.fill_ids[ + + input_token_logprobs_idx = req.fill_ids[ len(req.fill_ids) - num_input_logprobs + 1 : len(req.fill_ids) - req.last_update_decode_tokens ] - # Clip the padded hash values from image tokens. # Otherwise, it will lead to detokenization errors. - input_token_ids = [ + input_token_logprobs_idx = [ x if x < self.model_config.vocab_size - 1 else 0 - for x in input_token_ids + for x in input_token_logprobs_idx ] - req.input_token_logprobs = list(zip(input_token_logprobs, input_token_ids)) - if ( req.logprob_start_len == 0 ): # The first token does not have logprob, pad it. - req.input_token_logprobs = [ - (None, req.fill_ids[0]) - ] + req.input_token_logprobs + input_token_logprobs_val = [None] + input_token_logprobs_val + input_token_logprobs_idx = [req.fill_ids[0]] + input_token_logprobs_idx + + req.input_token_logprobs_val = input_token_logprobs_val + req.input_token_logprobs_idx = input_token_logprobs_idx if req.last_update_decode_tokens != 0: # Some decode tokens are re-computed in an extend batch - req.output_token_logprobs.extend( - list( - zip( - output.input_token_logprobs[ - pt - + num_input_logprobs - - 1 - - req.last_update_decode_tokens : pt - + num_input_logprobs - - 1 - ], - req.fill_ids[ - len(req.fill_ids) - - req.last_update_decode_tokens : len(req.fill_ids) - ], - ) - ) + req.output_token_logprobs_val.extend( + output.input_token_logprobs[ + pt + + num_input_logprobs + - 1 + - req.last_update_decode_tokens : pt + + num_input_logprobs + - 1 + ], + ) + req.output_token_logprobs_idx.extend( + req.fill_ids[ + len(req.fill_ids) + - req.last_update_decode_tokens : len(req.fill_ids) + ] ) if req.top_logprobs_num > 0: - if req.input_top_logprobs is None: - req.input_top_logprobs = output.input_top_logprobs[i] + if req.input_top_logprobs_val is None: + req.input_top_logprobs_val = output.input_top_logprobs_val[i] + req.input_top_logprobs_idx = output.input_top_logprobs_idx[i] if req.logprob_start_len == 0: - req.input_top_logprobs = [None] + req.input_top_logprobs + req.input_top_logprobs_val = [None] + req.input_top_logprobs_val + req.input_top_logprobs_idx = [None] + req.input_top_logprobs_idx if req.last_update_decode_tokens != 0: - req.output_top_logprobs.extend( - output.input_top_logprobs[i][-req.last_update_decode_tokens :] + req.output_top_logprobs_val.extend( + output.input_top_logprobs_val[i][-req.last_update_decode_tokens :] ) - req.output_top_logprobs.append(output.output_top_logprobs[i]) + req.output_top_logprobs_idx.extend( + output.input_top_logprobs_idx[i][-req.last_update_decode_tokens :] + ) + req.output_top_logprobs_val.append(output.output_top_logprobs_val[i]) + req.output_top_logprobs_idx.append(output.output_top_logprobs_idx[i]) return num_input_logprobs - def stream_output(self, reqs: List[Req], skip_req: Optional[Req] = None): + def stream_output( + self, reqs: List[Req], return_logprob: bool, skip_req: Optional[Req] = None + ): """Stream the output to detokenizer.""" - output_rids = [] - output_meta_info: List[dict] = [] - output_finished_reason: List[BaseFinishReason] = [] + rids = [] + finished_reasons: List[BaseFinishReason] = [] + if self.is_generation: - output_vids = [] + vids = [] decoded_texts = [] - output_read_ids = [] - output_read_offsets = [] + decode_ids_list = [] + read_offsets = [] output_ids = [] - output_skip_special_tokens = [] - output_spaces_between_special_tokens = [] - output_no_stop_trim = [] - else: # embedding or reward model - output_embeddings = [] - - is_stream_iter = self.forward_ct_decode % self.stream_interval == 0 - - for req in reqs: - if req is skip_req: - continue + skip_special_tokens = [] + spaces_between_special_tokens = [] + no_stop_trim = [] + prompt_tokens = [] + completion_tokens = [] + cached_tokens = [] + + if return_logprob: + input_token_logprobs_val = [] + input_token_logprobs_idx = [] + output_token_logprobs_val = [] + output_token_logprobs_idx = [] + input_top_logprobs_val = [] + input_top_logprobs_idx = [] + output_top_logprobs_val = [] + output_top_logprobs_idx = [] + normalized_prompt_logprob = [] + else: + input_token_logprobs_val = input_token_logprobs_idx = ( + output_token_logprobs_val + ) = output_token_logprobs_idx = input_top_logprobs_val = ( + input_top_logprobs_idx + ) = output_top_logprobs_val = output_top_logprobs_idx = ( + normalized_prompt_logprob + ) = None + + for req in reqs: + if req is skip_req: + continue - # TODO(lianmin): revisit this for overlap + retract + stream - if req.finished() or ( - req.stream and (is_stream_iter or len(req.output_ids) == 1) - ): - output_rids.append(req.rid) - output_finished_reason.append(req.finished_reason) - if self.is_generation: - output_vids.append(req.vid) + # TODO(lianmin): revisit this for overlap + retract + stream + if ( + req.finished() + # If stream, follow the given stream_interval + or (req.stream and len(req.output_ids) % self.stream_interval == 0) + # If not stream, we still want to output some tokens to get the benefit of incremental decoding. + or (not req.stream and len(req.output_ids) % 50 == 0) + ): + rids.append(req.rid) + finished_reasons.append( + req.finished_reason.to_json() if req.finished_reason else None + ) + vids.append(req.vid) decoded_texts.append(req.decoded_text) - read_ids, read_offset = req.init_incremental_detokenize() - output_read_ids.append(read_ids) - output_read_offsets.append(read_offset) + decode_ids, read_offset = req.init_incremental_detokenize() + decode_ids_list.append(decode_ids) + read_offsets.append(read_offset) if self.skip_tokenizer_init: output_ids.append(req.output_ids) - output_skip_special_tokens.append( - req.sampling_params.skip_special_tokens - ) - output_spaces_between_special_tokens.append( + skip_special_tokens.append(req.sampling_params.skip_special_tokens) + spaces_between_special_tokens.append( req.sampling_params.spaces_between_special_tokens ) - output_no_stop_trim.append(req.sampling_params.no_stop_trim) - - meta_info = { - "prompt_tokens": len(req.origin_input_ids), - "completion_tokens": len(req.output_ids), - "completion_tokens_wo_jump_forward": req.completion_tokens_wo_jump_forward, - "cached_tokens": req.cached_tokens, - "finish_reason": ( - req.finished_reason.to_json() - if req.finished_reason is not None - else None - ), - } - if req.return_logprob: - ( - meta_info["input_token_logprobs"], - meta_info["output_token_logprobs"], - meta_info["input_top_logprobs"], - meta_info["output_top_logprobs"], - meta_info["normalized_prompt_logprob"], - ) = ( - req.input_token_logprobs, - req.output_token_logprobs, - req.input_top_logprobs, - req.output_top_logprobs, - req.normalized_prompt_logprob, - ) - output_meta_info.append(meta_info) - else: # embedding or reward model - output_embeddings.append(req.embedding) - meta_info = { - "prompt_tokens": len(req.origin_input_ids), - } - output_meta_info.append(meta_info) - - # Send to detokenizer - if output_rids: - if self.is_generation: + no_stop_trim.append(req.sampling_params.no_stop_trim) + + prompt_tokens.append(len(req.origin_input_ids)) + completion_tokens.append(len(req.output_ids)) + cached_tokens.append(req.cached_tokens) + + if return_logprob: + input_token_logprobs_val.append(req.input_token_logprobs_val) + input_token_logprobs_idx.append(req.input_token_logprobs_idx) + output_token_logprobs_val.append(req.output_token_logprobs_val) + output_token_logprobs_idx.append(req.output_token_logprobs_idx) + input_top_logprobs_val.append(req.input_top_logprobs_val) + input_top_logprobs_idx.append(req.input_top_logprobs_idx) + output_top_logprobs_val.append(req.output_top_logprobs_val) + output_top_logprobs_idx.append(req.output_top_logprobs_idx) + normalized_prompt_logprob.append(req.normalized_prompt_logprob) + + # Send to detokenizer + if rids: self.send_to_detokenizer.send_pyobj( BatchTokenIDOut( - output_rids, - output_vids, + rids, + finished_reasons, + vids, decoded_texts, - output_read_ids, - output_read_offsets, + decode_ids_list, + read_offsets, output_ids, - output_skip_special_tokens, - output_spaces_between_special_tokens, - output_meta_info, - output_finished_reason, - output_no_stop_trim, - ) - ) - else: # embedding or reward model - self.send_to_detokenizer.send_pyobj( - BatchEmbeddingOut( - output_rids, - output_embeddings, - output_meta_info, - output_finished_reason, + skip_special_tokens, + spaces_between_special_tokens, + no_stop_trim, + prompt_tokens, + completion_tokens, + cached_tokens, + input_token_logprobs_val, + input_token_logprobs_idx, + output_token_logprobs_val, + output_token_logprobs_idx, + input_top_logprobs_val, + input_top_logprobs_idx, + output_top_logprobs_val, + output_top_logprobs_idx, + normalized_prompt_logprob, ) ) + else: # embedding or reward model + embeddings = [] + prompt_tokens = [] + for req in reqs: + assert req.finished() + rids.append(req.rid) + finished_reasons.append(req.finished_reason.to_json()) + embeddings.append(req.embedding) + prompt_tokens.append(len(req.origin_input_ids)) + self.send_to_detokenizer.send_pyobj( + BatchEmbeddingOut(rids, finished_reasons, embeddings, prompt_tokens) + ) def prepare_dp_attn_batch(self, local_batch: ScheduleBatch): # Check if other DP workers have running batches diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py index 56e01528add..4788565ac01 100644 --- a/python/sglang/srt/managers/tokenizer_manager.py +++ b/python/sglang/srt/managers/tokenizer_manager.py @@ -22,7 +22,7 @@ import sys import time import uuid -from typing import Dict, List, Optional, Tuple, Union +from typing import Any, Dict, List, Optional, Tuple, Union import fastapi import uvloop @@ -76,6 +76,7 @@ class ReqState: out_list: List finished: bool event: asyncio.Event + obj: Any # For metrics created_time: float @@ -283,7 +284,7 @@ async def _wait_one_response( ): """Wait for the response of one request.""" event = asyncio.Event() - state = ReqState([], False, event, created_time=created_time) + state = ReqState([], False, event, obj, created_time=created_time) self.rid_to_state[obj.rid] = state while True: @@ -295,15 +296,7 @@ async def _wait_one_response( raise ValueError(f"Abort request {obj.rid}") continue - if isinstance(obj, GenerateReqInput): - out = self.convert_logprob_style( - state.out_list[-1], - obj.return_logprob, - obj.top_logprobs_num, - obj.return_text_in_logprobs, - ) - else: # isinstance(obj, (EmbeddingReqInput,)) - out = state.out_list[-1] + out = state.out_list[-1] state.out_list = [] if state.finished: @@ -315,7 +308,13 @@ async def _wait_one_response( break state.event.clear() - yield out + + if obj.stream: + yield out + else: + if request is not None and await request.is_disconnected(): + self.abort_request(obj.rid) + raise ValueError(f"Abort request {obj.rid}") async def _handle_batch_request( self, @@ -609,29 +608,55 @@ async def handle_loop(self): if state is None: continue - recv_obj.meta_info[i]["id"] = rid + meta_info = { + "id": rid, + "finish_reason": recv_obj.finished_reasons[i], + "prompt_tokens": recv_obj.prompt_tokens[i], + } + + if getattr(state.obj, "return_logprob", False): + self.convert_logprob_style( + meta_info, + state.obj.top_logprobs_num, + state.obj.return_text_in_logprobs, + recv_obj, + i, + ) + if isinstance(recv_obj, BatchStrOut): out_dict = { "text": recv_obj.output_strs[i], - "meta_info": recv_obj.meta_info[i], + "meta_info": { + **meta_info, + "completion_tokens": recv_obj.completion_tokens[i], + "cached_tokens": recv_obj.cached_tokens[i], + }, } elif isinstance(recv_obj, BatchTokenIDOut): out_dict = { "token_ids": recv_obj.output_ids[i], - "meta_info": recv_obj.meta_info[i], + "meta_info": { + **meta_info, + "completion_tokens": recv_obj.completion_tokens[i], + "cached_tokens": recv_obj.cached_tokens[i], + }, } else: assert isinstance(recv_obj, BatchEmbeddingOut) out_dict = { "embedding": recv_obj.embeddings[i], - "meta_info": recv_obj.meta_info[i], + "meta_info": meta_info, } state.out_list.append(out_dict) - state.finished = recv_obj.finished_reason[i] is not None + state.finished = recv_obj.finished_reasons[i] is not None state.event.set() if self.enable_metrics: - completion_tokens = recv_obj.meta_info[i]["completion_tokens"] + completion_tokens = ( + recv_obj.completion_tokens[i] + if recv_obj.completion_tokens + else 0 + ) if state.first_token_time is None: state.first_token_time = time.time() @@ -647,7 +672,7 @@ async def handle_loop(self): if state.finished: self.metrics_collector.inc_prompt_tokens( - recv_obj.meta_info[i]["prompt_tokens"] + recv_obj.prompt_tokens[i] ) self.metrics_collector.inc_generation_tokens( completion_tokens @@ -696,57 +721,73 @@ async def handle_loop(self): def convert_logprob_style( self, - ret: dict, - return_logprob: bool, + meta_info: dict, top_logprobs_num: int, return_text_in_logprobs: bool, + recv_obj: BatchStrOut, + recv_obj_index: int, ): - if return_logprob: - ret["meta_info"]["input_token_logprobs"] = self.detokenize_logprob_tokens( - ret["meta_info"]["input_token_logprobs"], return_text_in_logprobs + meta_info["input_token_logprobs"] = self.detokenize_logprob_tokens( + recv_obj.input_token_logprobs_val[recv_obj_index], + recv_obj.input_token_logprobs_idx[recv_obj_index], + return_text_in_logprobs, + ) + meta_info["output_token_logprobs"] = self.detokenize_logprob_tokens( + recv_obj.output_token_logprobs_val[recv_obj_index], + recv_obj.output_token_logprobs_idx[recv_obj_index], + return_text_in_logprobs, + ) + meta_info["normalized_prompt_logprob"] = recv_obj.normalized_prompt_logprob[ + recv_obj_index + ] + + if top_logprobs_num > 0: + meta_info["input_top_logprobs"] = self.detokenize_top_logprobs_tokens( + recv_obj.input_top_logprobs_val[recv_obj_index], + recv_obj.input_top_logprobs_idx[recv_obj_index], + return_text_in_logprobs, ) - ret["meta_info"]["output_token_logprobs"] = self.detokenize_logprob_tokens( - ret["meta_info"]["output_token_logprobs"], return_text_in_logprobs + meta_info["output_top_logprobs"] = self.detokenize_top_logprobs_tokens( + recv_obj.output_top_logprobs_val[recv_obj_index], + recv_obj.output_top_logprobs_idx[recv_obj_index], + return_text_in_logprobs, ) - if top_logprobs_num > 0: - ret["meta_info"]["input_top_logprobs"] = ( - self.detokenize_top_logprobs_tokens( - ret["meta_info"]["input_top_logprobs"], - return_text_in_logprobs, - ) - ) - ret["meta_info"]["output_top_logprobs"] = ( - self.detokenize_top_logprobs_tokens( - ret["meta_info"]["output_top_logprobs"], return_text_in_logprobs - ) - ) - return ret - def detokenize_logprob_tokens( - self, token_logprobs: List[Tuple[float, int]], decode_to_text: bool + self, + token_logprobs_val: List[float], + token_logprobs_idx: List[int], + decode_to_text: bool, ): - # TODO(lianmin): This should run on DetokenizerManager if not decode_to_text: - return [(logprob, token_id, None) for logprob, token_id in token_logprobs] - - assert self.tokenizer is not None - token_ids = [tid for _, tid in token_logprobs] - token_texts = self.tokenizer.batch_decode(token_ids) - return [ - (logprob, token_id, token_text) - for (logprob, token_id), token_text in zip(token_logprobs, token_texts) - ] + return [ + (logprob, token_id, None) + for logprob, token_id in zip(token_logprobs_val, token_logprobs_idx) + ] + else: + assert self.tokenizer is not None + token_texts = self.tokenizer.batch_decode(token_logprobs_idx) + return list(zip(token_logprobs_val, token_logprobs_idx, token_texts)) - def detokenize_top_logprobs_tokens(self, top_logprobs, decode_to_text: bool): + def detokenize_top_logprobs_tokens( + self, + token_logprobs_val: List[float], + token_logprobs_idx: List[int], + decode_to_text: bool, + ): # TODO: The current implementation only batches the detokenization for top-k tokens per single position. # We should batch all top-k tokens in all positions. - for i, token_top_logprobs in enumerate(top_logprobs): - if token_top_logprobs: - top_logprobs[i] = self.detokenize_logprob_tokens( - token_top_logprobs, decode_to_text + ret = [] + for i in range(len(token_logprobs_val)): + if token_logprobs_val[i]: + ret.append( + self.detokenize_logprob_tokens( + token_logprobs_val[i], token_logprobs_idx[i], decode_to_text + ) ) - return top_logprobs + else: + ret.append(None) + return ret class SignalHandler: diff --git a/python/sglang/srt/model_executor/cuda_graph_runner.py b/python/sglang/srt/model_executor/cuda_graph_runner.py index 27043cc9a7d..77efba89212 100644 --- a/python/sglang/srt/model_executor/cuda_graph_runner.py +++ b/python/sglang/srt/model_executor/cuda_graph_runner.py @@ -400,9 +400,14 @@ def replay(self, forward_batch: ForwardBatch): forward_mode=ForwardMode.DECODE, top_logprobs_nums=forward_batch.top_logprobs_nums, ) - logits_output.output_top_logprobs = LogitsProcessor.get_top_logprobs( + ( + logits_output.output_top_logprobs_val, + logits_output.output_top_logprobs_idx, + ) = LogitsProcessor.get_top_logprobs( next_token_logprobs, logits_metadata - )[1] + )[ + 2:4 + ] else: logits_output = LogitsProcessorOutput( next_token_logits=next_token_logits, diff --git a/python/sglang/test/test_utils.py b/python/sglang/test/test_utils.py index 514bf31a68b..32c6e08b69f 100644 --- a/python/sglang/test/test_utils.py +++ b/python/sglang/test/test_utils.py @@ -720,13 +720,13 @@ def run_and_check_memory_leak( # Clean up everything kill_process_tree(process.pid) - kill_process_tree(process.pid) stdout.close() stderr.close() if os.path.exists(STDOUT_FILENAME): os.remove(STDOUT_FILENAME) if os.path.exists(STDERR_FILENAME): os.remove(STDERR_FILENAME) + kill_process_tree(process.pid) t.join() # Assert success @@ -734,7 +734,7 @@ def run_and_check_memory_leak( has_leak = False has_abort = False for line in output_lines: - if "The server is fired" in line: + if "Uvicorn running" in line: has_new_server = True if "leak" in line: has_leak = True diff --git a/test/srt/test_json_constrained.py b/test/srt/test_json_constrained.py index 1a857d0da6e..adb5c18fbe2 100644 --- a/test/srt/test_json_constrained.py +++ b/test/srt/test_json_constrained.py @@ -95,15 +95,6 @@ def run_decode(self, json_schema, return_logprob=False, top_logprobs_num=0, n=1) self.assertIsInstance(js_obj["name"], str) self.assertIsInstance(js_obj["population"], int) - # Make sure jump forward is triggered - # NOTE: The overlap scheduler does not support jump forward so we only do this test - # when --disable-overlap-schedule is set. - if self.check_jump_forward: - self.assertGreater( - ret["meta_info"]["completion_tokens"], - ret["meta_info"]["completion_tokens_wo_jump_forward"], - ) - def test_json_generate(self): self.run_decode(json_schema=self.json_schema) From a1e697b25b31287b67afe009a61f803b2fd6592f Mon Sep 17 00:00:00 2001 From: Byron Hsu Date: Sun, 8 Dec 2024 15:24:02 -0800 Subject: [PATCH 56/60] [router] Improve cleanup logic (#2411) --- rust/py_src/sglang_router/launch_server.py | 117 +++++++-------------- rust/py_test/test_launch_server.py | 67 ++++++++---- 2 files changed, 85 insertions(+), 99 deletions(-) diff --git a/rust/py_src/sglang_router/launch_server.py b/rust/py_src/sglang_router/launch_server.py index ec86e8b2adb..9c482e48986 100644 --- a/rust/py_src/sglang_router/launch_server.py +++ b/rust/py_src/sglang_router/launch_server.py @@ -10,12 +10,12 @@ from typing import List import requests +from setproctitle import setproctitle from sglang_router.launch_router import RouterArgs, launch_router from sglang.srt.server import launch_server from sglang.srt.server_args import ServerArgs from sglang.srt.utils import is_port_available -from sglang.utils import get_exception_traceback def setup_logger(): @@ -34,10 +34,12 @@ def setup_logger(): return logger +logger = setup_logger() + + # Create new process group def run_server(server_args, dp_rank): - os.setpgrp() # Create new process group - + setproctitle(f"sglang::server") # Set SGLANG_DP_RANK environment variable os.environ["SGLANG_DP_RANK"] = str(dp_rank) @@ -58,36 +60,6 @@ def launch_server_process( return proc -def cleanup_processes(processes: List[mp.Process]): - logger = logging.getLogger("router") - logger.info("Cleaning up processes...") - for proc in processes: - if proc.is_alive(): - try: - os.killpg(os.getpgid(proc.pid), signal.SIGTERM) - proc.join(timeout=3) - if proc.is_alive(): - logger.warning( - f"Process {proc.pid} did not terminate gracefully, force killing..." - ) - os.killpg(os.getpgid(proc.pid), signal.SIGKILL) - except ProcessLookupError: - pass - - -def setup_signal_handlers(cleanup_func): - """Setup handlers for various termination signals.""" - - def signal_handler(signum, frame): - cleanup_func() - sys.exit(1) - - signal.signal(signal.SIGTERM, signal_handler) - signal.signal(signal.SIGINT, signal_handler) - if hasattr(signal, "SIGQUIT"): - signal.signal(signal.SIGQUIT, signal_handler) - - def wait_for_server_health(host: str, port: int, timeout: int = 300) -> bool: """Wait for server to be healthy by checking /health endpoint.""" start_time = time.time() @@ -117,8 +89,12 @@ def find_available_ports(base_port: int, count: int) -> List[int]: return available_ports +def cleanup_processes(processes: List[mp.Process]): + for process in processes: + process.terminate() + + def main(): - logger = setup_logger() # CUDA runtime isn't fork-safe, which can lead to subtle bugs or crashes mp.set_start_method("spawn") @@ -148,52 +124,33 @@ def main(): # Start server processes server_processes = [] - try: - for i, worker_port in enumerate(worker_ports): - logger.info(f"Launching DP server process {i} on port {worker_port}") - proc = launch_server_process(server_args, worker_port, i) - server_processes.append(proc) - - # Setup cleanup handler - setup_signal_handlers(lambda: cleanup_processes(server_processes)) - - # Wait for all servers to be healthy - all_healthy = True - - for port in worker_ports: - if not wait_for_server_health(server_args.host, port): - logger.error(f"Server on port {port} failed to become healthy") - all_healthy = False - break - - if not all_healthy: - logger.error("Not all servers are healthy. Shutting down...") - cleanup_processes(server_processes) - sys.exit(1) - - logger.info("All servers are healthy. Starting router...") - - # Update router args with worker URLs - router_args.worker_urls = [ - f"http://{server_args.host}:{port}" for port in worker_ports - ] - - # Start the router - router = launch_router(router_args) - - if router is None: - logger.error("Failed to start router. Shutting down...") - cleanup_processes(server_processes) - sys.exit(1) - - except KeyboardInterrupt: - logger.info("Received shutdown signal...") - except Exception as e: - logger.error(f"Error occurred: {e}") - logger.error(get_exception_traceback()) - finally: - logger.info("Cleaning up processes...") - cleanup_processes(server_processes) + for i, worker_port in enumerate(worker_ports): + logger.info(f"Launching DP server process {i} on port {worker_port}") + proc = launch_server_process(server_args, worker_port, i) + server_processes.append(proc) + + signal.signal(signal.SIGINT, lambda sig, frame: cleanup_processes(server_processes)) + signal.signal( + signal.SIGTERM, lambda sig, frame: cleanup_processes(server_processes) + ) + signal.signal( + signal.SIGQUIT, lambda sig, frame: cleanup_processes(server_processes) + ) + + for port in worker_ports: + if not wait_for_server_health(server_args.host, port): + logger.error(f"Server on port {port} failed to become healthy") + break + + logger.info("All servers are healthy. Starting router...") + + # Update router args with worker URLs + router_args.worker_urls = [ + f"http://{server_args.host}:{port}" for port in worker_ports + ] + + # Start the router + router = launch_router(router_args) if __name__ == "__main__": diff --git a/rust/py_test/test_launch_server.py b/rust/py_test/test_launch_server.py index 68945d8fb52..2591abb5cdf 100644 --- a/rust/py_test/test_launch_server.py +++ b/rust/py_test/test_launch_server.py @@ -6,7 +6,6 @@ import requests -from sglang.srt.utils import kill_process_tree from sglang.test.run_eval import run_eval from sglang.test.test_utils import ( DEFAULT_MODEL_NAME_FOR_TEST, @@ -104,23 +103,52 @@ def popen_launch_server( return process +def terminate_and_wait(process, timeout=300): + """Terminate a process and wait until it is terminated. + + Args: + process: subprocess.Popen object + timeout: maximum time to wait in seconds + + Raises: + TimeoutError: if process does not terminate within timeout + """ + if process is None: + return + + process.terminate() + start_time = time.time() + + while process.poll() is None: + print(f"Terminating process {process.pid}") + if time.time() - start_time > timeout: + raise TimeoutError( + f"Process {process.pid} failed to terminate within {timeout}s" + ) + time.sleep(1) + + print(f"Process {process.pid} is successfully terminated") + + class TestLaunchServer(unittest.TestCase): - @classmethod - def setUpClass(cls): - cls.model = DEFAULT_MODEL_NAME_FOR_TEST - cls.base_url = DEFAULT_URL_FOR_TEST - cls.process = None - cls.other_process = [] - - @classmethod - def tearDownClass(cls): - kill_process_tree(cls.process.pid) - for process in cls.other_process: - kill_process_tree(process.pid) - - def test_mmlu(self): + def setUp(self): + self.model = DEFAULT_MODEL_NAME_FOR_TEST + self.base_url = DEFAULT_URL_FOR_TEST + self.process = None + self.other_process = [] + + def tearDown(self): + print("Running tearDown...") + if self.process: + terminate_and_wait(self.process) + for process in self.other_process: + terminate_and_wait(process) + print("tearDown done") + + def test_1_mmlu(self): + print("Running test_1_mmlu...") # DP size = 2 - TestLaunchServer.process = popen_launch_router( + self.process = popen_launch_router( self.model, self.base_url, dp_size=2, @@ -144,9 +172,10 @@ def test_mmlu(self): msg = f"MMLU test {'passed' if passed else 'failed'} with score {score:.3f} (threshold: {THRESHOLD})" self.assertGreaterEqual(score, THRESHOLD, msg) - def test_add_and_remove_worker(self): + def test_2_add_and_remove_worker(self): + print("Running test_2_add_and_remove_worker...") # DP size = 1 - TestLaunchServer.process = popen_launch_router( + self.process = popen_launch_router( self.model, self.base_url, dp_size=1, @@ -159,7 +188,7 @@ def test_add_and_remove_worker(self): worker_process = popen_launch_server( self.model, worker_url, DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH ) - TestLaunchServer.other_process.append(worker_process) + self.other_process.append(worker_process) # 2. use /add_worker api to add it the the router. It will be used by router after it is healthy with requests.Session() as session: From 2a717c5078ed5feb7c8df70943e25d27e50a89eb Mon Sep 17 00:00:00 2001 From: Byron Hsu Date: Sun, 8 Dec 2024 16:58:41 -0800 Subject: [PATCH 57/60] [Router] fix interrupt from terminal (#2413) --- rust/py_src/sglang_router/launch_server.py | 32 +++++++++++++++++++++- 1 file changed, 31 insertions(+), 1 deletion(-) diff --git a/rust/py_src/sglang_router/launch_server.py b/rust/py_src/sglang_router/launch_server.py index 9c482e48986..52f323acc47 100644 --- a/rust/py_src/sglang_router/launch_server.py +++ b/rust/py_src/sglang_router/launch_server.py @@ -39,6 +39,35 @@ def setup_logger(): # Create new process group def run_server(server_args, dp_rank): + """ + Note: + + 1. Without os.setpgrp(), all processes share the same PGID. When you press Ctrl+C, the terminal sends SIGINT to all processes in the group simultaneously. + This can cause leaf processes to terminate first, which messes up the cleaning order and produces orphaned processes. + + Terminal (PGID=100) + └── Main Python Process (PGID=100) + └── Server Process 1 (PGID=100) + └── Scheduler 1 + └── Detokenizer 1 + └── Server Process 2 (PGID=100) + └── Scheduler 2 + └── Detokenizer 2 + + 2. With os.setpgrp(), the main Python process and its children are in a separate group. Now: + + Terminal (PGID=100) + └── Main Python Process (PGID=200) + └── Server Process 1 (PGID=300) + └── Scheduler 1 + └── Detokenizer 1 + └── Server Process 2 (PGID=400) + └── Scheduler 2 + └── Detokenizer 2 + """ + # create new process group + os.setpgrp() + setproctitle(f"sglang::server") # Set SGLANG_DP_RANK environment variable os.environ["SGLANG_DP_RANK"] = str(dp_rank) @@ -91,11 +120,12 @@ def find_available_ports(base_port: int, count: int) -> List[int]: def cleanup_processes(processes: List[mp.Process]): for process in processes: + logger.info(f"Terminating process {process.pid}") process.terminate() + logger.info("All processes terminated") def main(): - # CUDA runtime isn't fork-safe, which can lead to subtle bugs or crashes mp.set_start_method("spawn") From 6387098f5f98101ee103732efe8da9d6cb54d92d Mon Sep 17 00:00:00 2001 From: Byron Hsu Date: Sun, 8 Dec 2024 17:17:37 -0800 Subject: [PATCH 58/60] [router] add health checking in router init (#2393) --- rust/Cargo.lock | 2 + rust/Cargo.toml | 2 +- rust/py_src/sglang_router/launch_server.py | 7 -- rust/src/router.rs | 123 ++++++++++++++++++--- rust/src/server.rs | 23 ++-- 5 files changed, 126 insertions(+), 31 deletions(-) diff --git a/rust/Cargo.lock b/rust/Cargo.lock index 8e7f306589f..dc9c46a7146 100644 --- a/rust/Cargo.lock +++ b/rust/Cargo.lock @@ -851,6 +851,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2dff15bf788c671c1934e366d07e30c1814a8ef514e1af724a602e8a2fbe1b10" dependencies = [ "futures-core", + "futures-sink", ] [[package]] @@ -1986,6 +1987,7 @@ dependencies = [ "base64 0.22.1", "bytes", "encoding_rs", + "futures-channel", "futures-core", "futures-util", "h2 0.4.6", diff --git a/rust/Cargo.toml b/rust/Cargo.toml index d49af81cf56..d20a381ee7b 100644 --- a/rust/Cargo.toml +++ b/rust/Cargo.toml @@ -19,7 +19,7 @@ serde = { version = "1.0", features = ["derive"] } clap = { version = "4.4", features = ["derive"] } bytes = "1.8.0" rand = "0.8.5" -reqwest = { version = "0.12.8", features = ["stream"] } +reqwest = { version = "0.12.8", features = ["stream", "blocking"] } futures-util = "0.3" serde_json = "1.0" pyo3 = { version = "0.22.5", features = ["extension-module"] } diff --git a/rust/py_src/sglang_router/launch_server.py b/rust/py_src/sglang_router/launch_server.py index 52f323acc47..6ee19241542 100644 --- a/rust/py_src/sglang_router/launch_server.py +++ b/rust/py_src/sglang_router/launch_server.py @@ -167,13 +167,6 @@ def main(): signal.SIGQUIT, lambda sig, frame: cleanup_processes(server_processes) ) - for port in worker_ports: - if not wait_for_server_health(server_args.host, port): - logger.error(f"Server on port {port} failed to become healthy") - break - - logger.info("All servers are healthy. Starting router...") - # Update router args with worker URLs router_args.worker_urls = [ f"http://{server_args.host}:{port}" for port in worker_ports diff --git a/rust/src/router.rs b/rust/src/router.rs index acba974972c..615a9550ef3 100644 --- a/rust/src/router.rs +++ b/rust/src/router.rs @@ -93,7 +93,7 @@ pub enum Router { }, } -#[derive(Debug)] +#[derive(Debug, Clone)] pub enum PolicyConfig { RandomConfig, RoundRobinConfig, @@ -127,9 +127,14 @@ fn get_text_from_request(body: &Bytes, route: &str) -> String { return "".to_string(); } + impl Router { - pub fn new(worker_urls: Vec, policy_config: PolicyConfig) -> Self { - match policy_config { + pub fn new(worker_urls: Vec, policy_config: PolicyConfig) -> Result { + // Wait until all workers are healthy + Self::wait_for_healthy_workers(&worker_urls, 300, 10)?; + + // Create router based on policy... + Ok(match policy_config { PolicyConfig::RandomConfig => Router::Random { worker_urls: Arc::new(RwLock::new(worker_urls)), }, @@ -196,7 +201,7 @@ impl Router { _eviction_thread: Some(eviction_thread), } } - } + }) } pub fn get_first(&self) -> Option { @@ -213,6 +218,59 @@ impl Router { } } + fn wait_for_healthy_workers( + worker_urls: &[String], + timeout_secs: u64, + interval_secs: u64, + ) -> Result<(), String> { + let start_time = std::time::Instant::now(); + let sync_client = reqwest::blocking::Client::new(); + + loop { + if start_time.elapsed() > Duration::from_secs(timeout_secs) { + return Err(format!( + "Timeout {}s waiting for workers to become healthy", + timeout_secs + )); + } + + let mut all_healthy = true; + let mut unhealthy_workers = Vec::new(); + + for url in worker_urls { + match sync_client.get(&format!("{}/health", url)).send() { + Ok(res) => { + if !res.status().is_success() { + info!( + "Worker {} health check is pending with status: {}.", + url, + res.status() + ); + all_healthy = false; + unhealthy_workers.push((url, format!("Status: {}", res.status()))); + } + } + Err(e) => { + info!("Worker {} health check is pending with error: {}", url, e); + all_healthy = false; + unhealthy_workers.push((url, format!("Error: {}", e))); + } + } + } + + if all_healthy { + info!("All workers are healthy"); + return Ok(()); + } else { + info!("Unhealthy workers:"); + for (url, reason) in &unhealthy_workers { + info!(" {} - {}", url, reason); + } + thread::sleep(Duration::from_secs(interval_secs)); + } + } + } + pub async fn dispatch( &self, client: &reqwest::Client, @@ -386,7 +444,7 @@ impl Router { } } - pub async fn add_worker(&self, worker_url: String) -> HttpResponse { + pub async fn add_worker(&self, worker_url: String) -> Result { let interval_secs = 10; // check every 10 seconds let timeout_secs = 300; // 5 minutes @@ -395,7 +453,7 @@ impl Router { loop { if start_time.elapsed() > Duration::from_secs(timeout_secs) { - return HttpResponse::InternalServerError().body(format!( + return Err(format!( "Timeout {}s waiting for worker {} to become healthy", timeout_secs, worker_url )); @@ -411,19 +469,40 @@ impl Router { info!("Worker {} health check passed", worker_url); let mut urls = worker_urls.write().unwrap(); if urls.contains(&worker_url) { - return HttpResponse::BadRequest() - .body(format!("Worker {} already exists", worker_url)); + return Err(format!("Worker {} already exists", worker_url)); } info!("Added worker: {}", worker_url); urls.push(worker_url.clone()); } } - return HttpResponse::Ok() - .body(format!("Successfully added worker: {}", worker_url)); + + // If cache aware, initialize the queues for the new worker + if let Router::CacheAware { + running_queue, + processed_queue, + tree, + .. + } = self + { + // Add worker to running queue with initial count of 0 + running_queue.lock().unwrap().insert(worker_url.clone(), 0); + + // Add worker to processed queue with initial count of 0 + processed_queue + .lock() + .unwrap() + .insert(worker_url.clone(), 0); + + // Add worker to tree + tree.lock().unwrap().insert(&"".to_string(), &worker_url); + } + + return Ok(format!("Successfully added worker: {}", worker_url)); } else { info!( - "Worker {} health check failed with status: {}. The worker might still be starting up.", - worker_url, res.status() + "Worker {} health check is pending with status: {}.", + worker_url, + res.status() ); // if the url does not have http or https prefix, warn users if !worker_url.starts_with("http://") && !worker_url.starts_with("https://") @@ -436,7 +515,10 @@ impl Router { } } Err(e) => { - info!("Worker {} health check failed: {}", worker_url, e); + info!( + "Worker {} health check is pending with error: {}", + worker_url, e + ); // if the url does not have http or https prefix, warn users if !worker_url.starts_with("http://") && !worker_url.starts_with("https://") { @@ -463,9 +545,20 @@ impl Router { } // if cache aware, remove the worker from the tree - if let Router::CacheAware { tree, .. } = self { + if let Router::CacheAware { + tree, + running_queue, + processed_queue, + .. + } = self + { tree.lock().unwrap().remove_tenant(&worker_url); - info!("Removed worker from tree: {}", worker_url); + running_queue.lock().unwrap().remove(&worker_url); + processed_queue.lock().unwrap().remove(&worker_url); + info!( + "Removed worker from tree and cleaned up queues: {}", + worker_url + ); } } } diff --git a/rust/src/server.rs b/rust/src/server.rs index d7ec6ebc6e5..8a0eb1547d6 100644 --- a/rust/src/server.rs +++ b/rust/src/server.rs @@ -20,7 +20,10 @@ impl AppState { policy_config: PolicyConfig, ) -> Self { // Create router based on policy - let router = Router::new(worker_urls, policy_config); + let router = match Router::new(worker_urls, policy_config) { + Ok(router) => router, + Err(error) => panic!("Failed to create router: {}", error), + }; Self { router, client } } @@ -141,7 +144,11 @@ async fn add_worker( .body("Worker URL required. Provide 'url' query parameter") } }; - data.router.add_worker(worker_url).await + + match data.router.add_worker(worker_url).await { + Ok(message) => HttpResponse::Ok().body(message), + Err(error) => HttpResponse::BadRequest().body(error), + } } #[post("/remove_worker")] @@ -187,20 +194,20 @@ pub async fn startup(config: ServerConfig) -> std::io::Result<()> { ) .init(); - info!("Starting server on {}:{}", config.host, config.port); - info!("Worker URLs: {:?}", config.worker_urls); - info!("Policy Config: {:?}", config.policy_config); - let client = reqwest::Client::builder() .build() .expect("Failed to create HTTP client"); let app_state = web::Data::new(AppState::new( - config.worker_urls, + config.worker_urls.clone(), client, - config.policy_config, + config.policy_config.clone(), )); + info!("✅ Starting router on {}:{}", config.host, config.port); + info!("✅ Serving Worker URLs: {:?}", config.worker_urls); + info!("✅ Policy Config: {:?}", config.policy_config); + HttpServer::new(move || { App::new() .app_data(app_state.clone()) From 27f7bed7a75b52538a2a4de69054f1dd19e1455c Mon Sep 17 00:00:00 2001 From: Byron Hsu Date: Sun, 8 Dec 2024 21:17:31 -0800 Subject: [PATCH 59/60] reduce watchdog interval to 5s (#2410) --- python/sglang/srt/managers/tokenizer_manager.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py index 4788565ac01..29b98df2efa 100644 --- a/python/sglang/srt/managers/tokenizer_manager.py +++ b/python/sglang/srt/managers/tokenizer_manager.py @@ -572,7 +572,7 @@ def create_handle_loop(self): async def sigterm_watchdog(self): while not self.gracefully_exit: - await asyncio.sleep(60) + await asyncio.sleep(5) # drain requests while True: From 3844feb9bb1cdd1ee59653b85e3b40e8a4d107d1 Mon Sep 17 00:00:00 2001 From: Xiaoyu Zhang <35585791+BBuf@users.noreply.github.com> Date: Mon, 9 Dec 2024 14:46:10 +0800 Subject: [PATCH 60/60] Add a unittest for fused_moe (#2416) --- benchmark/kernels/fused_moe_triton/README.md | 8 +- ...280,device_name=NVIDIA_A800-SXM4-80GB.json | 146 ++++++++++++++++++ ...640,device_name=NVIDIA_A800-SXM4-80GB.json | 146 ++++++++++++++++++ test/srt/run_suite.py | 1 + test/srt/test_fused_moe.py | 126 +++++++++++++++ 5 files changed, 425 insertions(+), 2 deletions(-) create mode 100644 python/sglang/srt/layers/fused_moe_triton/configs/E=64,N=1280,device_name=NVIDIA_A800-SXM4-80GB.json create mode 100644 python/sglang/srt/layers/fused_moe_triton/configs/E=64,N=640,device_name=NVIDIA_A800-SXM4-80GB.json create mode 100644 test/srt/test_fused_moe.py diff --git a/benchmark/kernels/fused_moe_triton/README.md b/benchmark/kernels/fused_moe_triton/README.md index ba29ede5099..2a3e37f6874 100644 --- a/benchmark/kernels/fused_moe_triton/README.md +++ b/benchmark/kernels/fused_moe_triton/README.md @@ -10,7 +10,7 @@ Example usage: ```bash # Tune Qwen2-57B with FP8 and TP=4 python benchmark/kernels/fused_moe_triton/tuning_fused_moe_triton.py \ - --model Qwen/Qwen2-57B-A14B-Instruct-FP8 \ + --model Qwen/Qwen2-57B-A14B-Instruct \ --tp-size 4 \ --dtype fp8_w8a8 \ --tune @@ -34,7 +34,7 @@ python benchmark/kernels/fused_moe_triton/benchmark_vllm_vs_sglang_fused_moe_tri # Compare with FP8 mode for Qwen2-57B python benchmark/kernels/fused_moe_triton/benchmark_vllm_vs_sglang_fused_moe_triton.py \ - --model Qwen/Qwen2-57B-A14B-Instruct-FP8 \ + --model Qwen/Qwen2-57B-A14B-Instruct \ --use-fp8 # Compare with custom TP size @@ -43,3 +43,7 @@ python benchmark/kernels/fused_moe_triton/benchmark_vllm_vs_sglang_fused_moe_tri ``` The benchmark results will be saved as plots and data files in the specified output directory (default: `./configs/benchmark_ops/vllm_sglang_fused_moe/`). + +- `benchmark_torch_compile_fused_moe.py`: A tool for benchmarking the performance of the fused MoE kernel with `torch.compile` and original fused MoE kernel. + +Usage is the same as `benchmark_vllm_vs_sglang_fused_moe_triton.py`. diff --git a/python/sglang/srt/layers/fused_moe_triton/configs/E=64,N=1280,device_name=NVIDIA_A800-SXM4-80GB.json b/python/sglang/srt/layers/fused_moe_triton/configs/E=64,N=1280,device_name=NVIDIA_A800-SXM4-80GB.json new file mode 100644 index 00000000000..283ffd8ff1d --- /dev/null +++ b/python/sglang/srt/layers/fused_moe_triton/configs/E=64,N=1280,device_name=NVIDIA_A800-SXM4-80GB.json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 5 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 5 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "24": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "96": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + } +} diff --git a/python/sglang/srt/layers/fused_moe_triton/configs/E=64,N=640,device_name=NVIDIA_A800-SXM4-80GB.json b/python/sglang/srt/layers/fused_moe_triton/configs/E=64,N=640,device_name=NVIDIA_A800-SXM4-80GB.json new file mode 100644 index 00000000000..8a18afe7d6d --- /dev/null +++ b/python/sglang/srt/layers/fused_moe_triton/configs/E=64,N=640,device_name=NVIDIA_A800-SXM4-80GB.json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 5 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "32": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "48": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "96": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + } +} diff --git a/test/srt/run_suite.py b/test/srt/run_suite.py index 5035810f86a..cb6a60612dd 100644 --- a/test/srt/run_suite.py +++ b/test/srt/run_suite.py @@ -15,6 +15,7 @@ "test_double_sparsity.py", "test_embedding_openai_server.py", "test_eval_accuracy_mini.py", + "test_fused_moe.py", "test_get_weights_by_name.py", "test_gguf.py", "test_input_embeddings.py", diff --git a/test/srt/test_fused_moe.py b/test/srt/test_fused_moe.py new file mode 100644 index 00000000000..7b50c551a82 --- /dev/null +++ b/test/srt/test_fused_moe.py @@ -0,0 +1,126 @@ +import unittest + +import torch +from vllm.model_executor.layers.fused_moe import fused_moe as fused_moe_vllm + +from sglang.srt.layers.activation import SiluAndMul +from sglang.srt.layers.fused_moe_triton.fused_moe import fused_moe + + +class TestFusedMOE(unittest.TestCase): + NUM_EXPERTS = [8, 64] + TOP_KS = [2, 6] + + def torch_naive_moe(self, a, w1, w2, score, topk): + B, D = a.shape + a = a.view(B, -1, D).repeat(1, topk, 1).reshape(-1, D) + out = torch.zeros(B * topk, w2.shape[1], dtype=a.dtype, device=a.device) + score = torch.softmax(score, dim=-1, dtype=torch.float32) + topk_weight, topk_ids = torch.topk(score, topk) + topk_weight = topk_weight.view(-1) + topk_ids = topk_ids.view(-1) + for i in range(w1.shape[0]): + mask = topk_ids == i + if mask.sum(): + out[mask] = SiluAndMul()(a[mask] @ w1[i].transpose(0, 1)) @ w2[ + i + ].transpose(0, 1) + return ( + out.view(B, -1, w2.shape[1]) * topk_weight.view(B, -1, 1).to(out.dtype) + ).sum(dim=1) + + def _test_case(self, m, n, k, e, topk, dtype, use_fp8_w8a8=False): + if use_fp8_w8a8: + # AssertionError: fp8e4nv data type is not supported on CUDA arch < 89 + capability = torch.cuda.get_device_capability() + if not (capability[0] >= 9 or capability == (8, 9)): + return + + a = torch.randn((m, k), device="cuda", dtype=dtype) / 10 + w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10 + w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10 + w1 = w1.to(torch.float8_e4m3fn) + w2 = w2.to(torch.float8_e4m3fn) + score = torch.randn((m, e), device="cuda", dtype=dtype) + + w1_scale = torch.randn(e, dtype=torch.float32, device="cuda") + w2_scale = torch.randn(e, dtype=torch.float32, device="cuda") + a1_scale = torch.randn(1, dtype=torch.float32, device="cuda") + a2_scale = torch.randn(1, dtype=torch.float32, device="cuda") + + sglang_output = fused_moe( + a, + w1, + w2, + score, + topk, + renormalize=False, + use_fp8_w8a8=True, + w1_scale=w1_scale, + w2_scale=w2_scale, + a1_scale=a1_scale, + a2_scale=a2_scale, + ) + + vllm_output = fused_moe_vllm( + a, + w1, + w2, + score, + topk, + renormalize=False, + use_fp8_w8a8=True, + w1_scale=w1_scale, + w2_scale=w2_scale, + a1_scale=a1_scale, + a2_scale=a2_scale, + ) + + torch.testing.assert_close(sglang_output, vllm_output, atol=2e-2, rtol=0) + + else: + a = torch.randn((m, k), device="cuda", dtype=dtype) / 10 + w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10 + w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10 + score = torch.randn((m, e), device="cuda", dtype=dtype) + + triton_output = fused_moe(a, w1, w2, score, topk, renormalize=False) + torch_output = self.torch_naive_moe(a, w1, w2, score, topk) + torch.testing.assert_close(triton_output, torch_output, atol=2e-2, rtol=0) + + def test_various_configurations(self): + m_values = [1, 33, 64, 222, 1024 * 128] + n_values = [128, 1024, 2048] + k_values = [128, 511, 1024] + dtypes = [torch.float16, torch.bfloat16] + fp8_modes = [False, True] + + for m in m_values: + for n in n_values: + for k in k_values: + for e in self.NUM_EXPERTS: + for topk in self.TOP_KS: + for dtype in dtypes: + for use_fp8_w8a8 in fp8_modes: + with self.subTest( + m=m, + n=n, + k=k, + e=e, + topk=topk, + dtype=dtype, + fp8=use_fp8_w8a8, + ): + self._test_case( + m, + n, + k, + e, + topk, + dtype, + use_fp8_w8a8=use_fp8_w8a8, + ) + + +if __name__ == "__main__": + unittest.main()