[feat] Enable chunked prefill for llava-onevision (#2412)
Ying1123 authored Dec 9, 2024
1 parent 641b7d0 commit 8586b72
Showing 5 changed files with 222 additions and 20 deletions.
2 changes: 2 additions & 0 deletions python/sglang/srt/managers/schedule_batch.py
@@ -129,6 +129,7 @@ class ImageInputs:
image_hashes: Optional[list] = None
image_sizes: Optional[list] = None
image_offsets: Optional[list] = None
image_pad_len: Optional[list] = None
pad_values: Optional[list] = None
modalities: Optional[list] = None
num_image_tokens: Optional[int] = None
@@ -181,6 +182,7 @@ def merge(self, other):
optional_args = [
"image_sizes",
"image_offsets",
"image_pad_len",
# "modalities", # modalities should be ["multi-images"] (one entry) even for multiple images
"aspect_ratio_ids",
"aspect_ratio_mask",
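For orientation, a minimal sketch (not part of the commit; the concrete numbers are made up) of how the new image_pad_len field complements the existing image_offsets: together they describe each image's placeholder-token span in the padded prompt, which the chunked-prefill logic in llava.py below uses to decide whether an image overlaps the current chunk.

```python
# Illustrative values only.
image_offsets = [5, 800]     # first placeholder-token index of each image
image_pad_len = [729, 729]   # number of placeholder tokens per image

# Each image covers the half-open token span [offset, offset + pad_len).
image_spans = [(o, o + p) for o, p in zip(image_offsets, image_pad_len)]
print(image_spans)  # [(5, 734), (800, 1529)]
```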
15 changes: 9 additions & 6 deletions python/sglang/srt/model_executor/model_runner.py
@@ -111,17 +111,20 @@ def __init__(
)

if self.is_multimodal:
server_args.chunked_prefill_size = -1
self.mem_fraction_static *= 0.95
logger.info(
f"Automatically reduce --mem-fraction-static to {self.mem_fraction_static:.3f} "
f"and turn off chunked prefill "
f"because this is a multimodal model."
)
if self.model_config.hf_config.architectures == [
"MllamaForConditionalGeneration"
]:
logger.info("Automatically turn off --chunked-prefill-size for mllama.")
server_args.chunked_prefill_size = -1
# TODO: qwen2-vl does not support radix cache now, set disable_radix_cache=True automatically
if self.model_config.hf_config.architectures == [
"Qwen2VLForConditionalGeneration"
]:
logger.info(
"Automatically turn off --chunked-prefill-size and disable radix cache for qwen2-vl."
)
server_args.chunked_prefill_size = -1
server_args.disable_radix_cache = True

# Global vars
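In summary, chunked prefill is no longer force-disabled for every multimodal model; only architectures that cannot handle it yet (currently mllama and qwen2-vl) fall back to chunked_prefill_size = -1, and qwen2-vl additionally disables the radix cache. A hedged sketch of that decision follows; the helper name and structure are hypothetical and do not appear in the server code.

```python
# Hypothetical helper mirroring the per-architecture overrides in the diff above.
CHUNKED_PREFILL_UNSUPPORTED = {
    "MllamaForConditionalGeneration",
    "Qwen2VLForConditionalGeneration",
}
RADIX_CACHE_UNSUPPORTED = {"Qwen2VLForConditionalGeneration"}

def apply_multimodal_overrides(architectures, server_args):
    """Disable features only for architectures that do not yet support them."""
    if any(a in CHUNKED_PREFILL_UNSUPPORTED for a in architectures):
        server_args.chunked_prefill_size = -1
    if any(a in RADIX_CACHE_UNSUPPORTED for a in architectures):
        server_args.disable_radix_cache = True
```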
51 changes: 37 additions & 14 deletions python/sglang/srt/models/llava.py
@@ -57,6 +57,7 @@ def pad_input_ids(self, input_ids: List[int], image_inputs: ImageInputs):
else:
image_aspect_ratio = "anyres"
offset_list = []
image_inputs.image_pad_len = []
for image_idx, image_s in enumerate(image_sizes):
if len(image_sizes) > 16:
# 2x2 pooling with stride 2
@@ -103,6 +104,7 @@ def pad_input_ids(self, input_ids: List[int], image_inputs: ImageInputs):
+ input_ids[offset + 1 :]
)
offset_list.append(offset)
image_inputs.image_pad_len.append(new_image_feature_len)

image_inputs.image_offsets = offset_list
return input_ids
@@ -134,6 +136,14 @@ def forward(
image_inputs = forward_batch.image_inputs

if forward_batch.forward_mode.is_extend():
# Clamp input ids. This is because the input_ids for the image tokens are
# filled with the hash values of the image for the prefix matching in the radix attention.
# Their values are useless because their embeddings will be replaced by vision embeddings anyway.
input_ids.clamp_(min=0, max=self.config.vocab_size - 1)

# Embed text inputs
input_embeds = self.language_model.model.embed_tokens(input_ids)

# Got List[List[str]]; extend it to List[str]
# The length of the List should be equal to batch size
modalities_list = []
@@ -142,18 +152,12 @@ def forward(
if im and im.modalities is not None:
modalities_list.extend(im.modalities)
if im and im.image_offsets:
max_image_offset.append(max(im.image_offsets))
max_image_offset.append(
np.max(np.array(im.image_offsets) + np.array(im.image_pad_len))
)
else:
max_image_offset.append(-1)

# Clamp input ids. This is because the input_ids for the image tokens are
# filled with the hash values of the image for the prefix matching in the radix attention.
# Their values are useless because their embeddings will be replaced by vision embeddings anyway.
input_ids.clamp_(min=0, max=self.config.vocab_size - 1)

# Embed text inputs
input_embeds = self.language_model.model.embed_tokens(input_ids)

start_positions = positions[forward_batch.extend_start_loc].cpu().numpy()
need_vision = start_positions <= np.array(max_image_offset)

@@ -350,25 +354,44 @@ def forward(

# Fill in the placeholder for the image
extend_start_loc_cpu = forward_batch.extend_start_loc.cpu().numpy()
extend_seq_lens = forward_batch.extend_seq_lens.cpu().numpy()
prefix_lens_cpu = forward_batch.extend_prefix_lens_cpu
pt = 0
for i in range(bs):
if not need_vision[i]:
continue

start_idx = extend_start_loc_cpu[i]
seq_len = extend_seq_lens[i]
prefix_len = prefix_lens_cpu[i]

# Multiple images
for j, image_offset in enumerate(image_inputs[i].image_offsets):
if image_offset < prefix_len:
for image_idx, image_offset in enumerate(
image_inputs[i].image_offsets
):
if (
image_offset + image_inputs[i].image_pad_len[image_idx]
<= prefix_len
):
continue
if image_offset >= prefix_len + seq_len:
break

tmp_image_feature = image_features[pt][j]
tmp_image_feature = image_features[pt][image_idx]
pad_len = tmp_image_feature.shape[0]

left_idx = start_idx + (image_offset - prefix_len)
right_idx = start_idx + (image_offset - prefix_len) + pad_len
input_offset = image_offset - prefix_len
left_idx = start_idx + input_offset
right_idx = left_idx + pad_len
assert right_idx > start_idx
if input_offset < 0:
left_idx = start_idx
tmp_image_feature = tmp_image_feature[-input_offset:]
if right_idx > start_idx + seq_len:
tmp_image_feature = tmp_image_feature[
: start_idx + seq_len - right_idx
]
right_idx = start_idx + seq_len
try:
input_embeds[left_idx:right_idx] = tmp_image_feature
except RuntimeError as e:
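To make the boundary handling in the new placeholder-filling code concrete, here is a small numeric sketch (an illustration, not part of the commit) that reruns the same index arithmetic with toy values: an image whose 729 placeholder tokens start at prompt position 100, processed by a chunk covering prompt positions 512 through 1023.

```python
# Illustrative recomputation of the slicing arithmetic used above when an
# image's placeholder span straddles a chunk boundary.
import numpy as np

prefix_len, seq_len, start_idx = 512, 512, 0  # chunk covers prompt positions [512, 1024)
image_offset, pad_len = 100, 729              # image placeholders cover positions [100, 829)

feature = np.arange(pad_len, dtype=np.float32)      # stand-in for the image's vision embeddings
chunk_embeds = np.zeros(seq_len, dtype=np.float32)  # stand-in for this chunk's input_embeds rows

input_offset = image_offset - prefix_len  # -412: the image started before this chunk
left_idx = start_idx + input_offset
right_idx = left_idx + pad_len            # 317: rows of the image that fall inside this chunk
if input_offset < 0:
    left_idx = start_idx
    feature = feature[-input_offset:]     # drop the 412 rows already consumed by earlier chunks
if right_idx > start_idx + seq_len:
    feature = feature[: start_idx + seq_len - right_idx]
    right_idx = start_idx + seq_len
chunk_embeds[left_idx:right_idx] = feature  # chunk rows [0, 317) get feature rows [412, 729)
```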
1 change: 1 addition & 0 deletions test/srt/run_suite.py
@@ -39,6 +39,7 @@
"test_triton_attention_kernels.py",
"test_triton_attention_backend.py",
"test_update_weights_from_disk.py",
"test_vision_chunked_prefill.py",
"test_vision_openai_server.py",
"test_session_control.py",
],
173 changes: 173 additions & 0 deletions test/srt/test_vision_chunked_prefill.py
@@ -0,0 +1,173 @@
"""
Usage:
python3 -m unittest test_vision_chunked_prefill.TestVisionChunkedPrefill.test_chunked_prefill
"""

import base64
import io
import os
import unittest
from concurrent.futures import ThreadPoolExecutor
from typing import Union

import numpy as np
import requests
from decord import VideoReader, cpu
from PIL import Image

from sglang.srt.utils import kill_process_tree
from sglang.test.test_utils import (
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
DEFAULT_URL_FOR_TEST,
popen_launch_server,
)


class TestVisionChunkedPrefill(unittest.TestCase):
def prepare_video_messages(self, video_path, max_frames_num=8):
vr = VideoReader(video_path, ctx=cpu(0))
total_frame_num = len(vr)
uniform_sampled_frames = np.linspace(
0, total_frame_num - 1, max_frames_num, dtype=int
)
frame_idx = uniform_sampled_frames.tolist()
frames = vr.get_batch(frame_idx).asnumpy()

base64_frames = []
for frame in frames:
pil_img = Image.fromarray(frame)
buff = io.BytesIO()
pil_img.save(buff, format="JPEG")
base64_str = base64.b64encode(buff.getvalue()).decode("utf-8")
base64_frames.append(base64_str)

messages = [{"role": "user", "content": []}]
for base64_frame in base64_frames:
# Build a fresh content entry per frame; reusing a shared template with
# dict.copy() would alias the nested image_url dict across all frames.
messages[0]["content"].append(
{
"type": "image_url",
"image_url": {
"url": "data:image/jpeg;base64,{}".format(base64_frame)
},
"modalities": "video",
}
)

prompt = {"type": "text", "text": "Please describe the video briefly."}
messages[0]["content"].append(prompt)

return messages

def get_prompt_from_messages(self, messages):
text = (
"<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n"
"<|im_start|>user\n"
)
image_data = []
for content in messages[0]["content"]:
if content["type"] == "image_url":
text += "<image>\n"
image_data.append(content["image_url"]["url"])
text += "Please describe the video briefly.<|im_end|>\n<|im_start|>assistant\n"
return text, image_data

def generate(self, text, image_data):
response = requests.post(
self.base_url + "/generate",
json={
"text": text,
"image_data": image_data,
"sampling_params": {
"temperature": 0,
"max_new_tokens": 32,
"no_stop_trim": True,
"skip_special_tokens": False,
},
"modalities": ["multi-images"],
},
).json()
return response["text"]

def generate_for_video(self, batch, num_frame) -> Union[str, list[str]]:
# prepare the video input about Steve Jobs introducing the iPod nano
url = "https://raw.githubusercontent.com/evolvinglmms-lab/sglang/dev/onevision_local/assets/jobs.mp4"
cache_dir = os.path.expanduser("~/.cache")
file_path = os.path.join(cache_dir, "jobs.mp4")
os.makedirs(cache_dir, exist_ok=True)
if not os.path.exists(file_path):
response = requests.get(url)
response.raise_for_status()
with open(file_path, "wb") as f:
f.write(response.content)

if not batch:
assert isinstance(num_frame, int)
messages = self.prepare_video_messages(file_path, max_frames_num=num_frame)
text, image_data = self.get_prompt_from_messages(messages)
return self.generate(text, image_data)
else:
assert isinstance(num_frame, list)
func_args = []
for max_frames_num in num_frame:
messages = self.prepare_video_messages(
file_path,
max_frames_num=max_frames_num,
)
text, image_data = self.get_prompt_from_messages(messages)
func_args.append((text, image_data))

with ThreadPoolExecutor(max_workers=10) as executor:
responses = list(executor.map(lambda p: self.generate(*p), func_args))

return responses

def run_generate(self, chunked_prefill_size, batch, num_frame):
# launch server
model = "lmms-lab/llava-onevision-qwen2-7b-ov"
# model = "meta-llama/Llama-3.2-11B-Vision-Instruct"
self.base_url = DEFAULT_URL_FOR_TEST
process = popen_launch_server(
model,
self.base_url,
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
other_args=[
"--chunked-prefill-size",
f"{chunked_prefill_size}",
],
)
try:
return self.generate_for_video(batch, num_frame)
finally:
kill_process_tree(process.pid)

def test_chunked_prefill(self):
output_chunked = self.run_generate(
chunked_prefill_size=1024, batch=False, num_frame=1
)
output_no_chunked = self.run_generate(
chunked_prefill_size=-1, batch=False, num_frame=1
)

print("output with chunked prefill:")
print(output_chunked)
print("output without chunked prefill:")
print(output_no_chunked)
assert output_chunked == output_no_chunked

output_chunked = self.run_generate(
chunked_prefill_size=1024, batch=True, num_frame=[2, 6, 8, 10]
)
output_no_chunked = self.run_generate(
chunked_prefill_size=-1, batch=True, num_frame=[2, 6, 8, 10]
)

print("output with chunked prefill:")
print(output_chunked)
print("output without chunked prefill:")
print(output_no_chunked)
assert output_chunked == output_no_chunked


if __name__ == "__main__":
unittest.main()
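As a condensed usage sketch distilled from the test above (the /generate endpoint, payload fields, and --chunked-prefill-size flag all appear in this commit; the helper below is otherwise illustrative), the property being asserted is that a server launched with chunked prefill enabled and one with it disabled produce identical greedy outputs:

```python
# Illustrative client helper; assumes a server launched as in run_generate above.
import requests

def generate(base_url, text, image_data):
    payload = {
        "text": text,
        "image_data": image_data,
        "sampling_params": {"temperature": 0, "max_new_tokens": 32},
    }
    return requests.post(base_url + "/generate", json=payload).json()["text"]

# With one server started with --chunked-prefill-size 1024 and another with -1,
# the two outputs are expected to match exactly under greedy decoding.
```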
