More precise profiling
Signed-off-by: DarkLight1337 <[email protected]>
DarkLight1337 committed Jan 3, 2025
1 parent b5020c2 commit 347f718
Showing 3 changed files with 23 additions and 9 deletions.
6 changes: 3 additions & 3 deletions vllm/model_executor/models/llava_next_video.py
@@ -75,10 +75,10 @@ def _get_num_frame_tokens(
 
     def _get_max_frame_tokens(self) -> int:
         hf_config = self._get_hf_config()
-        vision_encoder_info = self._vision_encoder_info
+        spatial_pool_stride = hf_config.spatial_pool_stride
 
-        patch_grid_length = vision_encoder_info.get_patch_grid_length()
-        pooled_grid_length = patch_grid_length / hf_config.spatial_pool_stride
+        patch_grid_length = self._vision_encoder_info.get_patch_grid_length()
+        pooled_grid_length = patch_grid_length / spatial_pool_stride
 
         return int(pooled_grid_length * pooled_grid_length)
 
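For reference, the per-frame token arithmetic in this hunk works out as below; the grid size and stride are illustrative assumptions, not values taken from the commit.

# Per-frame tokens after spatial pooling (illustrative values only)
patch_grid_length = 24        # assumed ViT patch grid side, not from the commit
spatial_pool_stride = 2       # assumed hf_config.spatial_pool_stride

pooled_grid_length = patch_grid_length / spatial_pool_stride      # 12.0
max_frame_tokens = int(pooled_grid_length * pooled_grid_length)   # 144 tokens per frame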
16 changes: 11 additions & 5 deletions vllm/model_executor/models/llava_onevision.py
@@ -156,11 +156,9 @@ def _get_num_frame_tokens(
 
     def _get_max_frame_tokens(self) -> int:
         hf_config = self._get_hf_config()
-        vision_encoder_info = self._vision_encoder_info
-
-        patch_grid_length = vision_encoder_info.get_patch_grid_length()
-
         spatial_pool_stride = getattr(hf_config, "spatial_pool_stride", 2)
+
+        patch_grid_length = self._vision_encoder_info.get_patch_grid_length()
         pooled_grid_length = patch_grid_length / spatial_pool_stride
 
         return math.ceil(pooled_grid_length) * math.ceil(pooled_grid_length)
@@ -185,9 +183,17 @@ def _get_max_video_frames(
         num_images: int = 0,
         num_videos: int = 1,
     ) -> int:
+        hf_config = self._get_hf_config()
+        spatial_pool_stride = getattr(hf_config, "spatial_pool_stride", 2)
+
         max_total_tokens = self.ctx.model_config.max_model_len
         max_total_frames = int(max_total_tokens / self._get_max_frame_tokens())
-        return (max_total_frames - num_images) // max(num_videos, 1)
+
+        # How many tokens are one image worth relative to one video frame
+        i2f = spatial_pool_stride * spatial_pool_stride
+        max_total_frames -= num_images * i2f
+
+        return max(max_total_frames, 0) // max(num_videos, 1)
 
     def _get_max_video_tokens(self) -> int:
         return self._get_max_frame_tokens() * self._get_max_video_frames()
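The new _get_max_video_frames budget can be traced with standalone numbers; everything below is an illustrative assumption rather than a value from the commit.

# Frame budget for the LLaVA-OneVision hunk (illustrative values only)
max_model_len = 8192          # assumed model context length
max_frame_tokens = 196        # assumed worst-case tokens per video frame
spatial_pool_stride = 2       # assumed pooling stride
num_images, num_videos = 1, 1

max_total_frames = int(max_model_len / max_frame_tokens)            # 41
i2f = spatial_pool_stride * spatial_pool_stride                     # one image counted as 4 frames' worth of tokens
max_total_frames -= num_images * i2f                                # 37
frames_per_video = max(max_total_frames, 0) // max(num_videos, 1)   # 37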
10 changes: 9 additions & 1 deletion vllm/model_executor/models/qwen2_vl.py
@@ -772,9 +772,17 @@ def _get_max_video_frames(
         num_images: int = 0,
         num_videos: int = 1,
     ) -> int:
+        hf_config = self.ctx.get_hf_config(Qwen2VLConfig)
+        temporal_patch_size = hf_config.vision_config.temporal_patch_size
+
         max_total_tokens = self.ctx.model_config.max_model_len
         max_total_frames = int(max_total_tokens / self._get_max_image_tokens())
-        return (max_total_frames - num_images) // max(num_videos, 1)
+
+        # How many tokens are one image worth relative to one video frame
+        i2f = temporal_patch_size * temporal_patch_size
+        max_total_frames -= num_images * i2f
+
+        return max(max_total_frames, 0) // max(num_videos, 1)
 
     def _get_max_video_tokens(self) -> int:
         return self._get_max_image_tokens() * self._get_max_video_frames()
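The qwen2_vl change follows the same shape, with temporal_patch_size standing in for the pooling stride; the numbers below are assumptions for illustration only.

# Frame budget for the Qwen2-VL hunk (illustrative values only)
max_model_len = 32768         # assumed model context length
max_image_tokens = 1024       # assumed worst-case tokens per image
temporal_patch_size = 2       # assumed vision_config.temporal_patch_size
num_images, num_videos = 1, 1

max_total_frames = int(max_model_len / max_image_tokens)            # 32
i2f = temporal_patch_size * temporal_patch_size                     # 4
max_total_frames -= num_images * i2f                                # 28 with one image in the prompt
frames_per_video = max(max_total_frames, 0) // max(num_videos, 1)   # 28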
