Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Misc]Add BNB quantization for MolmoForCausalLM #11551

Merged
merged 6 commits into from
Dec 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 18 additions & 8 deletions vllm/model_executor/model_loader/loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,8 @@
import warnings
from abc import ABC, abstractmethod
from contextlib import contextmanager
from typing import Any, Dict, Generator, Iterable, List, Optional, Tuple, cast
from typing import (Any, Callable, Dict, Generator, Iterable, List, Optional,
Tuple, cast)

import gguf
import huggingface_hub
Expand Down Expand Up @@ -706,6 +707,8 @@ def __init__(self, load_config: LoadConfig):
# Store all module names (from transformers) that support
# BNB quantization.
self.target_modules: List[str] = []
# mapping weight names from transformers to vllm.
self.weight_mapper: Callable = lambda name: name

def _get_weight_files(
self,
Expand Down Expand Up @@ -763,9 +766,12 @@ def _prepare_weights(self, model_name_or_path: str,

def _hf_weight_iter(self, hf_weights_files, use_safetensors: bool):
if use_safetensors:
return safetensors_weights_iterator(hf_weights_files)
iterator = safetensors_weights_iterator(hf_weights_files)
else:
return pt_weights_iterator(hf_weights_files)
iterator = pt_weights_iterator(hf_weights_files)
for name, param in iterator:
# mapping weight names from transformers to vllm.
yield self.weight_mapper(name), param

def _get_quantized_weights_iterator(
self,
Expand All @@ -782,12 +788,12 @@ def _get_quantized_weights_iterator(
try:
import bitsandbytes

if bitsandbytes.__version__ < "0.44.0":
if bitsandbytes.__version__ < "0.45.0":
raise ImportError("bitsandbytes version is wrong. Please "
"install bitsandbytes>=0.44.0.")
"install bitsandbytes>=0.45.0.")
except ImportError as err:
raise ImportError("Please install bitsandbytes>=0.44.0 via "
"`pip install bitsandbytes>=0.44.0` to use "
raise ImportError("Please install bitsandbytes>=0.45.0 via "
"`pip install bitsandbytes>=0.45.0` to use "
"bitsandbytes quantizer.") from err

hf_weights_files, use_safetensors = self._prepare_weights(
Expand Down Expand Up @@ -991,7 +997,7 @@ def _get_bnb_target_modules(self, model: nn.Module) -> None:
if isinstance(module, (LinearBase, )):
last_name = name.split(".")[-1]
if sub_modules := inverse_stacked_mapping.get(last_name, []):
# Map vllm's names to transformers' names.
# Map vllm's names to transformers's names.
for sub_name in sub_modules:
self.target_modules.append(
name.replace(last_name, sub_name))
Expand All @@ -1013,6 +1019,10 @@ def _load_weights(self, model_config: ModelConfig,
f"Model {type(model).__name__} does not support BitsAndBytes "
"quantization yet.")

# For some models like Molmo, we need to use hf_to_vllm_mapper
# to ensure correct loading of weights.
if hf_to_vllm_mapper := getattr(model, "hf_to_vllm_mapper", None):
self.weight_mapper = lambda name: hf_to_vllm_mapper._map_name(name)
# Modules whose weights might have fused on disk
# we need their output_sizes to make shard in flight correctly with TP
self.maybe_fused_weights_modules: Dict[str, List[int]] = {}
Expand Down
90 changes: 65 additions & 25 deletions vllm/model_executor/models/molmo.py
Original file line number Diff line number Diff line change
Expand Up @@ -461,30 +461,71 @@ def forward(
return output


class MolmoMLP(nn.Module):
class SwiGLU(nn.Module):

def forward(self, x: torch.Tensor) -> torch.Tensor:
x, gate = x.chunk(2, dim=-1)
# Note that the order is reversed compared to
# SiluAndMul.
return x * F.silu(gate)
Isotr0py marked this conversation as resolved.
Show resolved Hide resolved


class LanuageModelMLP(nn.Module):
"""Molmo's LLM mlp."""

def __init__(self,
config: PretrainedConfig,
input_dim: Optional[int] = None,
quant_config: Optional[QuantizationConfig] = None,
proj_name: str = "gate_up_proj") -> None:
quant_config: Optional[QuantizationConfig] = None) -> None:
super().__init__()
self.hidden_size = config.hidden_size
self.intermediate_size = config.intermediate_size // 2

# Molmo's LLM proj weights are already merged into the disk, while
# image_projector proj is separate. If the same proj_name were used, it
# would create ambiguity and make it difficult to support BNB and LoRA.
self.proj_name = proj_name
setattr(
self, proj_name,
MergedColumnParallelLinear(
input_dim or self.hidden_size,
[self.intermediate_size] * 2,
bias=False,
quant_config=quant_config,
))
self.gate_up_proj = MergedColumnParallelLinear(
input_dim or self.hidden_size,
[self.intermediate_size] * 2,
bias=False,
quant_config=quant_config,
)
# Activation function.
self.act_fn = SwiGLU()
# Feed-forward output projection.
self.down_proj = RowParallelLinear(
self.intermediate_size,
self.hidden_size,
bias=False,
quant_config=quant_config,
)

def forward(
self,
x: torch.Tensor,
) -> torch.Tensor:
gate_up, _ = self.gate_up_proj(x)
x = self.act_fn(gate_up)
x, _ = self.down_proj(x)
return x


class ImageProjectorMLP(nn.Module):
"""Molmo's image_projector mlp."""

def __init__(
self,
config: PretrainedConfig,
input_dim: Optional[int] = None,
quant_config: Optional[QuantizationConfig] = None,
) -> None:
super().__init__()
self.hidden_size = config.hidden_size
self.intermediate_size = config.intermediate_size // 2

self.merged_linear = MergedColumnParallelLinear(
input_dim or self.hidden_size,
[self.intermediate_size] * 2,
bias=False,
quant_config=quant_config,
)
# Activation function.
self.act_fn = SiluAndMul()

Expand All @@ -500,7 +541,7 @@ def forward(
self,
x: torch.Tensor,
) -> torch.Tensor:
gate_up, _ = getattr(self, self.proj_name)(x)
gate_up, _ = self.merged_linear(x)
x = self.act_fn(gate_up)
x, _ = self.down_proj(x)
return x
Expand All @@ -523,9 +564,7 @@ def __init__(
prefix=f"{prefix}.self_attn")

# MLP block.
self.mlp = MolmoMLP(config,
quant_config=quant_config,
proj_name="gate_up_proj")
self.mlp = LanuageModelMLP(config, quant_config=quant_config)

# LayerNorm
assert config.layer_norm_type == "rms"
Expand Down Expand Up @@ -617,11 +656,10 @@ def __init__(
vision_config,
nlayers=len(self.vit_layers),
quant_config=quant_config)
self.image_projector = MolmoMLP(
self.image_projector = ImageProjectorMLP(
config,
input_dim=vision_config.image_emb_dim,
quant_config=quant_config,
proj_name="merged_linear",
)

image_dim = vision_config.image_emb_dim * len(self.vit_layers)
Expand Down Expand Up @@ -842,10 +880,6 @@ def load_weights(self, weights: Iterable[Tuple[str,
loaded_params: Set[str] = set()

for name, loaded_weight in weights:
if "gate_up_proj" in name:
up_proj, gate_proj = loaded_weight.chunk(2, dim=0)
loaded_weight = torch.cat([gate_proj, up_proj], dim=0)

if name.endswith(".bias") and name not in params_dict:
continue
if is_pp_missing_parameter(name, self):
Expand Down Expand Up @@ -1157,6 +1191,12 @@ class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
},
)

# BitandBytes specific attributes
bitsandbytes_stacked_params_mapping = {
"gate_proj": ("merged_linear", 0),
"up_proj": ("merged_linear", 1),
}

def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
Expand Down
Loading