[Model] Modify MolmoForCausalLM MLP (vllm-project#11510)
Signed-off-by: Jee Jee Li <[email protected]>
jeejeelee authored and BKitor committed Dec 30, 2024
1 parent a402ea4 commit 7cfa589
Showing 1 changed file with 24 additions and 18 deletions.
42 changes: 24 additions & 18 deletions vllm/model_executor/models/molmo.py
@@ -464,24 +464,27 @@ def forward(
 class MolmoMLP(nn.Module):
     """Molmo's LLM mlp."""
 
-    def __init__(
-        self,
-        config: PretrainedConfig,
-        input_dim: Optional[int] = None,
-        quant_config: Optional[QuantizationConfig] = None,
-    ) -> None:
+    def __init__(self,
+                 config: PretrainedConfig,
+                 input_dim: Optional[int] = None,
+                 quant_config: Optional[QuantizationConfig] = None,
+                 proj_name: str = "gate_up_proj") -> None:
         super().__init__()
         self.hidden_size = config.hidden_size
         self.intermediate_size = config.intermediate_size // 2
 
-        # Feed-forward input projection.
-        self.gate_up_proj = MergedColumnParallelLinear(
-            input_dim or self.hidden_size,
-            [self.intermediate_size] * 2,
-            bias=False,
-            quant_config=quant_config,
-        )
-
+        # Molmo's LLM proj weights are already merged into the disk, while
+        # image_projector proj is separate. If the same proj_name were used, it
+        # would create ambiguity and make it difficult to support BNB and LoRA.
+        self.proj_name = proj_name
+        setattr(
+            self, proj_name,
+            MergedColumnParallelLinear(
+                input_dim or self.hidden_size,
+                [self.intermediate_size] * 2,
+                bias=False,
+                quant_config=quant_config,
+            ))
         # Activation function.
         self.act_fn = SiluAndMul()

@@ -497,7 +500,7 @@ def forward(
         self,
         x: torch.Tensor,
     ) -> torch.Tensor:
-        gate_up, _ = self.gate_up_proj(x)
+        gate_up, _ = getattr(self, self.proj_name)(x)
         x = self.act_fn(gate_up)
         x, _ = self.down_proj(x)
         return x
@@ -520,7 +523,9 @@ def __init__(
                                         prefix=f"{prefix}.self_attn")
 
         # MLP block.
-        self.mlp = MolmoMLP(config, quant_config=quant_config)
+        self.mlp = MolmoMLP(config,
+                            quant_config=quant_config,
+                            proj_name="gate_up_proj")
 
         # LayerNorm
         assert config.layer_norm_type == "rms"
@@ -616,6 +621,7 @@ def __init__(
             config,
             input_dim=vision_config.image_emb_dim,
             quant_config=quant_config,
+            proj_name="merged_linear",
         )
 
         image_dim = vision_config.image_emb_dim * len(self.vit_layers)
@@ -714,8 +720,8 @@ def load_weights(self, weights: Iterable[Tuple[str,
                                                     torch.Tensor]]) -> Set[str]:
         stacked_params_mapping = [
             # (param_name, shard_name, shard_id)
-            ("gate_up_proj", "gate_proj", 0),
-            ("gate_up_proj", "up_proj", 1),
+            ("merged_linear", "gate_proj", 0),
+            ("merged_linear", "up_proj", 1),
         ]
         params_dict = dict(self.named_parameters())
         loaded_params: Set[str] = set()
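
The change in the first two hunks stores the merged projection under a configurable attribute name (proj_name), so the LLM MLP keeps the "gate_up_proj" parameter name while the vision backbone's image_projector gets "merged_linear", avoiding the name clash described in the new comment. A minimal, self-contained sketch of that pattern, assuming plain nn.Linear as a stand-in for vLLM's MergedColumnParallelLinear (the DynamicProjMLP class and its sizes are illustrative, not from the commit):

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    class DynamicProjMLP(nn.Module):
        # Toy stand-in for MolmoMLP: the merged gate/up projection lives under
        # a configurable attribute name, so different instances expose
        # different parameter names in their state dicts.
        def __init__(self, hidden_size: int, intermediate_size: int,
                     proj_name: str = "gate_up_proj") -> None:
            super().__init__()
            self.proj_name = proj_name
            # nn.Linear stands in for MergedColumnParallelLinear here.
            setattr(self, proj_name,
                    nn.Linear(hidden_size, 2 * intermediate_size, bias=False))
            self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False)

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            gate_up = getattr(self, self.proj_name)(x)
            gate, up = gate_up.chunk(2, dim=-1)
            return self.down_proj(F.silu(gate) * up)  # SiluAndMul equivalent

    llm_mlp = DynamicProjMLP(16, 32)                                # keys: gate_up_proj.weight, down_proj.weight
    image_proj = DynamicProjMLP(16, 32, proj_name="merged_linear")  # keys: merged_linear.weight, down_proj.weight
    print(list(llm_mlp.state_dict()), list(image_proj.state_dict()))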

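The last hunk keeps weight loading consistent with the rename: the checkpoint still stores separate gate_proj and up_proj shards for the image projector, and stacked_params_mapping now points them at the merged_linear parameter. A simplified sketch of how such a mapping is typically consumed (plain tensors and made-up checkpoint keys stand in for vLLM's parameters and their weight_loader; this illustrates the mapping, not the exact load_weights code):

    import torch

    # Illustrative stand-in for a merged parameter: its two halves are filled
    # from separate checkpoint shards selected by shard_id (0 = gate, 1 = up).
    merged = torch.zeros(2 * 4, 8)
    params_dict = {"image_projector.merged_linear.weight": merged}

    stacked_params_mapping = [
        # (param_name, shard_name, shard_id)
        ("merged_linear", "gate_proj", 0),
        ("merged_linear", "up_proj", 1),
    ]

    checkpoint = [  # hypothetical checkpoint entries
        ("image_projector.gate_proj.weight", torch.ones(4, 8)),
        ("image_projector.up_proj.weight", torch.full((4, 8), 2.0)),
    ]

    for name, loaded_weight in checkpoint:
        for param_name, shard_name, shard_id in stacked_params_mapping:
            if shard_name not in name:
                continue
            # Rewrite the checkpoint key to the merged parameter's name ...
            name = name.replace(shard_name, param_name)
            param = params_dict[name]
            # ... and copy the shard into its slice (vLLM delegates this to the
            # parameter's weight_loader; plain slicing stands in for it here).
            out_dim = loaded_weight.shape[0]
            param[shard_id * out_dim:(shard_id + 1) * out_dim].copy_(loaded_weight)
            break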