Add Phi3VCausalLM
ravi03071991 committed Jan 8, 2025
1 parent 39a27c5 commit d8b64c5
Showing 5 changed files with 533 additions and 42 deletions.
3 changes: 2 additions & 1 deletion python/sglang/srt/configs/__init__.py
@@ -1,10 +1,11 @@
 from sglang.srt.configs.exaone import ExaoneConfig
-from sglang.srt.configs.phi3v import Phi3VConfig
+from sglang.srt.configs.phi3v import Phi3VCLIPVisionConfig, Phi3VConfig
 from sglang.srt.configs.qwen2vl import Qwen2VLConfig, Qwen2VLVisionConfig
 
 __all__ = [
     "ExaoneConfig",
     "Qwen2VLConfig",
     "Qwen2VLVisionConfig",
     "Phi3VConfig",
+    "Phi3VCLIPVisionConfig",
 ]
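
A note (not part of the commit): with `__all__` updated as above, the new vision config can be imported from the package root alongside the existing Phi3VConfig. A minimal sketch of the import this enables:

```python
# Minimal import sketch enabled by the updated __all__ (illustrative, not part of the diff).
from sglang.srt.configs import Phi3VCLIPVisionConfig, Phi3VConfig
```
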
2 changes: 1 addition & 1 deletion python/sglang/srt/configs/model_config.py
@@ -377,7 +377,7 @@ def is_multimodal_model(model_architectures: List[str]):
         or "LlavaVidForCausalLM" in model_architectures
         or "MllamaForConditionalGeneration" in model_architectures
         or "Qwen2VLForConditionalGeneration" in model_architectures
-        or "LlavaPhi3VForCausalLM" in model_architectures
+        or "Phi3VForCausalLM" in model_architectures
     ):
         return True
     else:
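
As an aside (not part of the diff), a hedged sketch of what this change means for the architecture check. The import path is assumed from the file shown above, and the architecture string would normally come from a model's config.json "architectures" field:

```python
# Illustrative sketch, assuming is_multimodal_model() is a module-level helper
# importable from python/sglang/srt/configs/model_config.py as shown in the hunk above.
from sglang.srt.configs.model_config import is_multimodal_model

# After this commit, the plain Phi-3-V architecture name is treated as multimodal.
print(is_multimodal_model(["Phi3VForCausalLM"]))  # True
```
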
41 changes: 41 additions & 0 deletions python/sglang/srt/configs/phi3v.py
@@ -15,7 +15,9 @@
 
 """ Phi-3-V model configuration"""
 
+from dataclasses import dataclass
 
+from transformers import CLIPVisionConfig
 from transformers.configuration_utils import PretrainedConfig
 from transformers.utils import logging
 
@@ -221,3 +223,42 @@ def _rope_scaling_validation(self):
             raise ValueError(
                 f"`rope_scaling`'s long_factor field must have length {self.hidden_size // self.num_attention_heads // 2}, got {len(rope_scaling_long_factor)}"
             )
+
+
+@dataclass
+class Phi3VCLIPVisionConfig:
+    attention_dropout: float = 0.0
+    dropout: float = 0.0
+    hidden_act: str = "quick_gelu"
+    hidden_size: int = 1024
+    image_size: int = 336
+    initializer_factor: float = 1.0
+    initializer_range: float = 0.02
+    intermediate_size: int = 4096
+    layer_norm_eps: float = 1e-5
+    num_attention_heads: int = 16
+    num_channels: int = 3
+    num_hidden_layers: int = 24
+    patch_size: int = 14
+    projection_dim: int = 768
+
+    def to_transformers_config(self) -> CLIPVisionConfig:
+        """
+        Converts this dataclass into a Hugging Face CLIPVisionConfig object.
+        """
+        return CLIPVisionConfig(
+            attention_dropout=self.attention_dropout,
+            dropout=self.dropout,
+            hidden_act=self.hidden_act,
+            hidden_size=self.hidden_size,
+            image_size=self.image_size,
+            initializer_factor=self.initializer_factor,
+            initializer_range=self.initializer_range,
+            intermediate_size=self.intermediate_size,
+            layer_norm_eps=self.layer_norm_eps,
+            num_attention_heads=self.num_attention_heads,
+            num_channels=self.num_channels,
+            num_hidden_layers=self.num_hidden_layers,
+            patch_size=self.patch_size,
+            projection_dim=self.projection_dim,
+        )
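
For orientation (not part of the commit), a short usage sketch of the new dataclass: instantiate it with the defaults added above and convert it into a Hugging Face CLIPVisionConfig for the vision tower.

```python
# Usage sketch for the new config; the values printed are the dataclass defaults above.
from sglang.srt.configs.phi3v import Phi3VCLIPVisionConfig

vision_cfg = Phi3VCLIPVisionConfig()
clip_cfg = vision_cfg.to_transformers_config()  # a transformers.CLIPVisionConfig

print(clip_cfg.hidden_size, clip_cfg.image_size, clip_cfg.patch_size)  # 1024 336 14
```
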
42 changes: 2 additions & 40 deletions python/sglang/srt/models/llava.py
@@ -30,7 +30,6 @@
 )
 from transformers.models.llava.modeling_llava import LlavaMultiModalProjector
 
-from sglang.srt.configs import Phi3VConfig
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.managers.schedule_batch import ImageInputs
 from sglang.srt.mm_utils import (
@@ -40,7 +39,7 @@
 )
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader.weight_utils import default_weight_loader
-from sglang.srt.models.llama import LlamaForCausalLM, Phi3ForCausalLM
+from sglang.srt.models.llama import LlamaForCausalLM
 from sglang.srt.models.mistral import MistralForCausalLM
 from sglang.srt.models.qwen2 import Qwen2ForCausalLM
 
@@ -555,41 +554,4 @@ def __init__(
             )
 
 
-class LlavaPhi3VForCausalLM(LlavaBaseForCausalLM):
-    def __init__(
-        self,
-        config: LlavaConfig,
-        quant_config: Optional[QuantizationConfig] = None,
-    ) -> None:
-        super().__init__()
-
-        self.config = config
-        self.vision_tower = None
-
-        if getattr(self.config, "vision_config", None) is None:
-            self.config.vision_config = CLIPVisionConfig(self.config.mm_vision_tower)
-        if getattr(self.config, "text_config", None) is None:
-            self.config.text_config = Phi3VConfig(self.config._name_or_path)
-
-        self.config.vision_config.hidden_size = config.mm_hidden_size
-        self.config.text_config.hidden_size = config.hidden_size
-
-        if getattr(self.config, "projector_hidden_act", None) is None:
-            self.config.projector_hidden_act = "gelu"
-        if getattr(self.config, "image_token_index", None) is None:
-            self.config.image_token_index = 32044
-
-        self.multi_modal_projector = LlavaMultiModalProjector(config)
-        self.language_model = Phi3ForCausalLM(config, quant_config=quant_config)
-        if "unpad" in getattr(config, "mm_patch_merge_type", ""):
-            self.language_model.model.image_newline = nn.Parameter(
-                torch.empty(config.text_config.hidden_size, dtype=torch.float16)
-            )
-
-
-EntryClass = [
-    LlavaLlamaForCausalLM,
-    LlavaQwenForCausalLM,
-    LlavaMistralForCausalLM,
-    LlavaPhi3VForCausalLM,
-]
+EntryClass = [LlavaLlamaForCausalLM, LlavaQwenForCausalLM, LlavaMistralForCausalLM]
(The diff for the fifth changed file did not load in this view.)
