diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py index bb9eb181611..25770c6ae93 100644 --- a/python/sglang/srt/managers/schedule_batch.py +++ b/python/sglang/srt/managers/schedule_batch.py @@ -129,6 +129,7 @@ class ImageInputs: image_hashes: Optional[list] = None image_sizes: Optional[list] = None image_offsets: Optional[list] = None + image_pad_len: Optional[list] = None pad_values: Optional[list] = None modalities: Optional[list] = None num_image_tokens: Optional[int] = None @@ -181,6 +182,7 @@ def merge(self, other): optional_args = [ "image_sizes", "image_offsets", + "image_pad_len", # "modalities", # modalities should be ["multi-images"] (one entry) even for multiple images "aspect_ratio_ids", "aspect_ratio_mask", diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index ebda816dbaf..98b140614cc 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -111,17 +111,20 @@ def __init__( ) if self.is_multimodal: - server_args.chunked_prefill_size = -1 self.mem_fraction_static *= 0.95 - logger.info( - f"Automatically reduce --mem-fraction-static to {self.mem_fraction_static:.3f} " - f"and turn off chunked prefill " - f"because this is a multimodal model." - ) + if self.model_config.hf_config.architectures == [ + "MllamaForConditionalGeneration" + ]: + logger.info("Automatically turn off --chunked-prefill-size for mllama.") + server_args.chunked_prefill_size = -1 # TODO: qwen2-vl does not support radix cache now, set disable_radix_cache=True automatically if self.model_config.hf_config.architectures == [ "Qwen2VLForConditionalGeneration" ]: + logger.info( + "Automatically turn off --chunked-prefill-size and disable radix cache for qwen2-vl." 
+ ) + server_args.chunked_prefill_size = -1 server_args.disable_radix_cache = True # Global vars diff --git a/python/sglang/srt/models/llava.py b/python/sglang/srt/models/llava.py index 4c62dbb25f1..c8ce9302b4f 100644 --- a/python/sglang/srt/models/llava.py +++ b/python/sglang/srt/models/llava.py @@ -57,6 +57,7 @@ def pad_input_ids(self, input_ids: List[int], image_inputs: ImageInputs): else: image_aspect_ratio = "anyres" offset_list = [] + image_inputs.image_pad_len = [] for image_idx, image_s in enumerate(image_sizes): if len(image_sizes) > 16: # 2x2 pooling with stride 2 @@ -103,6 +104,7 @@ def pad_input_ids(self, input_ids: List[int], image_inputs: ImageInputs): + input_ids[offset + 1 :] ) offset_list.append(offset) + image_inputs.image_pad_len.append(new_image_feature_len) image_inputs.image_offsets = offset_list return input_ids @@ -134,6 +136,14 @@ def forward( image_inputs = forward_batch.image_inputs if forward_batch.forward_mode.is_extend(): + # Clamp input ids. This is because the input_ids for the image tokens are + # filled with the hash values of the image for the prefix matching in the radix attention. + # There values are useless because their embeddings will be replaced by vision embeddings anyway. + input_ids.clamp_(min=0, max=self.config.vocab_size - 1) + + # Embed text inputs + input_embeds = self.language_model.model.embed_tokens(input_ids) + # Got List[List[str]] extend it to List[str] # The length of the List should be equal to batch size modalities_list = [] @@ -142,18 +152,12 @@ def forward( if im and im.modalities is not None: modalities_list.extend(im.modalities) if im and im.image_offsets: - max_image_offset.append(max(im.image_offsets)) + max_image_offset.append( + np.max(np.array(im.image_offsets) + np.array(im.image_pad_len)) + ) else: max_image_offset.append(-1) - # Clamp input ids. This is because the input_ids for the image tokens are - # filled with the hash values of the image for the prefix matching in the radix attention. 
- # There values are useless because their embeddings will be replaced by vision embeddings anyway. - input_ids.clamp_(min=0, max=self.config.vocab_size - 1) - - # Embed text inputs - input_embeds = self.language_model.model.embed_tokens(input_ids) - start_positions = positions[forward_batch.extend_start_loc].cpu().numpy() need_vision = start_positions <= np.array(max_image_offset) @@ -350,6 +354,7 @@ def forward( # Fill in the placeholder for the image extend_start_loc_cpu = forward_batch.extend_start_loc.cpu().numpy() + extend_seq_lens = forward_batch.extend_seq_lens.cpu().numpy() prefix_lens_cpu = forward_batch.extend_prefix_lens_cpu pt = 0 for i in range(bs): @@ -357,18 +362,36 @@ def forward( continue start_idx = extend_start_loc_cpu[i] + seq_len = extend_seq_lens[i] prefix_len = prefix_lens_cpu[i] # Multiple images - for j, image_offset in enumerate(image_inputs[i].image_offsets): - if image_offset < prefix_len: + for image_idx, image_offset in enumerate( + image_inputs[i].image_offsets + ): + if ( + image_offset + image_inputs[i].image_pad_len[image_idx] + <= prefix_len + ): continue + if image_offset >= prefix_len + seq_len: + break - tmp_image_feature = image_features[pt][j] + tmp_image_feature = image_features[pt][image_idx] pad_len = tmp_image_feature.shape[0] - left_idx = start_idx + (image_offset - prefix_len) - right_idx = start_idx + (image_offset - prefix_len) + pad_len + input_offset = image_offset - prefix_len + left_idx = start_idx + input_offset + right_idx = left_idx + pad_len + assert right_idx > start_idx + if input_offset < 0: + left_idx = start_idx + tmp_image_feature = tmp_image_feature[-input_offset:] + if right_idx > start_idx + seq_len: + tmp_image_feature = tmp_image_feature[ + : start_idx + seq_len - right_idx + ] + right_idx = start_idx + seq_len try: input_embeds[left_idx:right_idx] = tmp_image_feature except RuntimeError as e: diff --git a/test/srt/run_suite.py b/test/srt/run_suite.py index cb6a60612dd..7737951f824 100644 --- 
a/test/srt/run_suite.py +++ b/test/srt/run_suite.py @@ -40,6 +40,7 @@ "test_triton_attention_kernels.py", "test_triton_attention_backend.py", "test_update_weights_from_disk.py", + "test_vision_chunked_prefill.py", "test_vision_openai_server.py", "test_session_control.py", ],
diff --git a/test/srt/test_vision_chunked_prefill.py b/test/srt/test_vision_chunked_prefill.py
new file mode 100644
index 00000000000..f7725f17bee
--- /dev/null
+++ b/test/srt/test_vision_chunked_prefill.py
@@ -0,0 +1,198 @@
+"""
+Usage:
+python3 -m unittest test_vision_chunked_prefill.TestVisionChunkedPrefill.test_chunked_prefill
+"""
+
+import base64
+import io
+import os
+import unittest
+from concurrent.futures import ThreadPoolExecutor
+from typing import Union
+
+import numpy as np
+import requests
+from decord import VideoReader, cpu
+from PIL import Image
+
+from sglang.srt.utils import kill_process_tree
+from sglang.test.test_utils import (
+    DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
+    DEFAULT_URL_FOR_TEST,
+    popen_launch_server,
+)
+
+
+class TestVisionChunkedPrefill(unittest.TestCase):
+    """Check that chunked prefill leaves multimodal generation unchanged."""
+
+    def prepare_video_messages(self, video_path, max_frames_num=8):
+        """Sample frames from a video into one chat-style user message.
+
+        Args:
+            video_path: Path of the video file to read.
+            max_frames_num: Number of evenly spaced frames to sample.
+
+        Returns:
+            A one-element message list whose content holds one "image_url"
+            entry per sampled frame, followed by the text prompt.
+        """
+        vr = VideoReader(video_path, ctx=cpu(0))
+        total_frame_num = len(vr)
+        uniform_sampled_frames = np.linspace(
+            0, total_frame_num - 1, max_frames_num, dtype=int
+        )
+        frame_idx = uniform_sampled_frames.tolist()
+        frames = vr.get_batch(frame_idx).asnumpy()
+
+        base64_frames = []
+        for frame in frames:
+            pil_img = Image.fromarray(frame)
+            buff = io.BytesIO()
+            pil_img.save(buff, format="JPEG")
+            base64_str = base64.b64encode(buff.getvalue()).decode("utf-8")
+            base64_frames.append(base64_str)
+
+        # Build a fresh entry, including its nested "image_url" dict, for every
+        # frame. Appending shallow copies of one shared template would leave all
+        # entries pointing at the same inner dict, i.e. at the last frame only.
+        content = [
+            {
+                "type": "image_url",
+                "image_url": {"url": f"data:image/jpeg;base64,{base64_frame}"},
+                "modalities": "video",
+            }
+            for base64_frame in base64_frames
+        ]
+        content.append({"type": "text", "text": "Please describe the video briefly."})
+
+        return [{"role": "user", "content": content}]
+
+    def get_prompt_from_messages(self, messages):
+        """Flatten the first message into a raw prompt plus its image payloads.
+
+        Returns:
+            A (text, image_data) tuple: the prompt string and the URLs of the
+            "image_url" entries in their original order.
+        """
+        text = (
+            "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n"
+            "<|im_start|>user\n"
+        )
+        image_data = []
+        for content in messages[0]["content"]:
+            if content["type"] == "image_url":
+                # NOTE(review): only a newline is emitted per image here; confirm
+                # whether the model expects an image placeholder token as well.
+                text += "\n"
+                image_data.append(content["image_url"]["url"])
+        text += "Please describe the video briefly.<|im_end|>\n<|im_start|>assistant\n"
+        return text, image_data
+
+    def generate(self, text, image_data):
+        """Send one /generate request and return the generated text."""
+        response = requests.post(
+            self.base_url + "/generate",
+            json={
+                "text": text,
+                "image_data": image_data,
+                "sampling_params": {
+                    "temperature": 0,
+                    "max_new_tokens": 32,
+                    "no_stop_trim": True,
+                    "skip_special_tokens": False,
+                },
+                "modalities": ["multi-images"],
+            },
+        ).json()
+        return response["text"]
+
+    def generate_for_video(self, batch, num_frame) -> Union[str, list[str]]:
+        """Generate a description of the cached sample video.
+
+        Args:
+            batch: If false, send a single request; if true, send one request
+                per entry of num_frame concurrently.
+            num_frame: Frame count (int) for a single request, or a list of
+                frame counts for a batch.
+
+        Returns:
+            One generated string, or a list of them in num_frame order.
+        """
+        # prepare the video input about Steven introducing ipod nano
+        url = "https://raw.githubusercontent.com/evolvinglmms-lab/sglang/dev/onevision_local/assets/jobs.mp4"
+        cache_dir = os.path.expanduser("~/.cache")
+        file_path = os.path.join(cache_dir, "jobs.mp4")
+        os.makedirs(cache_dir, exist_ok=True)
+        if not os.path.exists(file_path):
+            response = requests.get(url)
+            response.raise_for_status()
+            with open(file_path, "wb") as f:
+                f.write(response.content)
+
+        if not batch:
+            self.assertIsInstance(num_frame, int)
+            messages = self.prepare_video_messages(file_path, max_frames_num=num_frame)
+            text, image_data = self.get_prompt_from_messages(messages)
+            return self.generate(text, image_data)
+        else:
+            self.assertIsInstance(num_frame, list)
+            func_args = []
+            for max_frames_num in num_frame:
+                messages = self.prepare_video_messages(
+                    file_path,
+                    max_frames_num=max_frames_num,
+                )
+                text, image_data = self.get_prompt_from_messages(messages)
+                func_args.append((text, image_data))
+
+            with ThreadPoolExecutor(max_workers=10) as executor:
+                responses = list(executor.map(lambda p: self.generate(*p), func_args))
+
+            return responses
+
+    def run_generate(self, chunked_prefill_size, batch, num_frame):
+        """Launch a server with the given chunk size, generate, then shut it down.
+
+        kill_process_tree is invoked on the launched process even if generation raises.
+        """
+        # launch server
+        model = "lmms-lab/llava-onevision-qwen2-7b-ov"
+        self.base_url = DEFAULT_URL_FOR_TEST
+        process = popen_launch_server(
+            model,
+            self.base_url,
+            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
+            other_args=[
+                "--chunked-prefill-size",
+                f"{chunked_prefill_size}",
+            ],
+        )
+        try:
+            return self.generate_for_video(batch, num_frame)
+        finally:
+            kill_process_tree(process.pid)
+
+    def _check_chunked_matches_unchunked(self, batch, num_frame):
+        """Run the same request(s) with and without chunked prefill and compare."""
+        output_chunked = self.run_generate(
+            chunked_prefill_size=1024, batch=batch, num_frame=num_frame
+        )
+        output_no_chunked = self.run_generate(
+            chunked_prefill_size=-1, batch=batch, num_frame=num_frame
+        )
+
+        print("output with chunked prefill:")
+        print(output_chunked)
+        print("output without chunked prefill:")
+        print(output_no_chunked)
+        self.assertEqual(output_chunked, output_no_chunked)
+
+    def test_chunked_prefill(self):
+        """Outputs must be identical with chunked prefill enabled and disabled."""
+        self._check_chunked_matches_unchunked(batch=False, num_frame=1)
+        self._check_chunked_matches_unchunked(batch=True, num_frame=[2, 6, 8, 10])
+
+
+if __name__ == "__main__":
+    unittest.main()