From 0b276b30e1459cbef379800fc59ed842401b0e71 Mon Sep 17 00:00:00 2001 From: Erez Schwartz Date: Sun, 15 Dec 2024 11:47:51 +0200 Subject: [PATCH 01/11] changes to support lora with jamba Signed-off-by: Erez Schwartz --- .../layers/mamba/mamba_mixer.py | 36 +++++-- vllm/model_executor/models/jamba.py | 96 +++++++++++-------- vllm/model_executor/models/mamba.py | 66 ++++++++----- 3 files changed, 127 insertions(+), 71 deletions(-) diff --git a/vllm/model_executor/layers/mamba/mamba_mixer.py b/vllm/model_executor/layers/mamba/mamba_mixer.py index 10bec75f49fdf..a3bb7c97f51ea 100644 --- a/vllm/model_executor/layers/mamba/mamba_mixer.py +++ b/vllm/model_executor/layers/mamba/mamba_mixer.py @@ -8,7 +8,6 @@ from vllm.model_executor.custom_op import CustomOp from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import (ColumnParallelLinear, - MergedColumnParallelLinear, RowParallelLinear) from vllm.model_executor.layers.mamba.ops.causal_conv1d import ( causal_conv1d_fn, causal_conv1d_update) @@ -42,12 +41,14 @@ def __init__(self, use_rms_norm: bool, rms_norm_has_weight: bool = True, rms_norm_eps: float = 1e-5, - activation="silu"): + activation="silu", + is_lora_enabled: bool = False): super().__init__() self.time_step_rank = time_step_rank self.ssm_state_size = ssm_state_size self.use_rms_norm = use_rms_norm self.activation = activation + self.is_lora_enabled = is_lora_enabled self.conv1d = ColumnParallelLinear( input_size=conv_kernel_size, @@ -60,9 +61,13 @@ def __init__(self, # doesn't allow to override it self.conv1d.weight.data = self.conv1d.weight.data.unsqueeze(1) - self.in_proj = MergedColumnParallelLinear(hidden_size, - [intermediate_size] * 2, - bias=use_bias) + self.in_proj_lin = ColumnParallelLinear(hidden_size, + intermediate_size, + bias=use_bias) + self.in_proj_gate = ColumnParallelLinear(hidden_size, + intermediate_size, + bias=use_bias) + # selective projection used to make dt, B and C input dependent self.x_proj = RowParallelLinear( intermediate_size, @@ -134,8 +139,8 @@ def forward_cuda(self, hidden_states: torch.Tensor, mamba_cache_params: MambaCacheParams): # 1. Gated MLP's linear projection - projected_states = self.in_proj(hidden_states)[0].transpose(-2, -1) - hidden_states, gate = projected_states.chunk(2, dim=-2) + gate = self.in_proj_gate(hidden_states)[0].transpose(-2, -1) + hidden_states = self.in_proj_lin(hidden_states)[0].transpose(-2, -1) # 2. Convolution sequence transformation conv_weights = self.conv1d.weight.view(self.conv1d.weight.size(0), @@ -170,7 +175,13 @@ def forward_cuda(self, hidden_states: torch.Tensor, # 3. State Space Model sequence transformation # 3.a. input varying initialization of time_step, B and C - ssm_parameters = self.x_proj(hidden_states.transpose(-2, -1))[0] + + if self.is_lora_enabled: + # lora kernel requires contiguous tensor + ssm_parameters = self.x_proj( + hidden_states.transpose(-2, -1).contiguous())[0] + else: + ssm_parameters = self.x_proj(hidden_states.transpose(-2, -1))[0] time_step, B, C = torch.split( ssm_parameters, @@ -222,6 +233,11 @@ def forward_cuda(self, hidden_states: torch.Tensor, scan_outputs = scan_outputs.transpose(0, 1) # 4. 
Final linear projection - contextualized_states = self.out_proj(scan_outputs.transpose(-2, - -1))[0] + if self.is_lora_enabled: + # lora kernel requires contiguous tensor + contextualized_states = self.out_proj( + scan_outputs.transpose(-2, -1).contiguous())[0] + else: + contextualized_states = self.out_proj( + scan_outputs.transpose(-2, -1))[0] return contextualized_states diff --git a/vllm/model_executor/models/jamba.py b/vllm/model_executor/models/jamba.py index 831db2ae52d74..723e983eb5705 100644 --- a/vllm/model_executor/models/jamba.py +++ b/vllm/model_executor/models/jamba.py @@ -7,7 +7,7 @@ from vllm.attention.backends.abstract import AttentionMetadata from vllm.attention.layer import Attention -from vllm.config import CacheConfig, VllmConfig +from vllm.config import _BATCH_SIZES_TO_CAPTURE, CacheConfig, VllmConfig from vllm.distributed import get_tensor_model_parallel_world_size from vllm.distributed.parallel_state import get_pp_group from vllm.model_executor.layers.fused_moe import FusedMoE @@ -105,9 +105,11 @@ def __init__(self, layer_idx: int, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, - prefix: str = "") -> None: + is_lora_enabled: Optional[bool] = False, + **kwargs) -> None: super().__init__() self.config = config + self.is_lora_enabled = is_lora_enabled self.mamba = MambaMixer(hidden_size= config.hidden_size, ssm_state_size = config.mamba_d_state, conv_kernel_size = config.mamba_d_conv, @@ -118,7 +120,9 @@ def __init__(self, use_bias = config.mamba_proj_bias, use_rms_norm=True, rms_norm_eps=config.rms_norm_eps, - activation=config.hidden_act) + activation=config.hidden_act, + is_lora_enabled = self.is_lora_enabled + ) num_experts = config.layers_num_experts[layer_idx] ffn_layer_class = JambaMoE if num_experts > 1 else JambaMLP @@ -154,14 +158,13 @@ def forward( class JambaAttentionDecoderLayer(nn.Module): - def __init__( - self, - config: JambaConfig, - layer_idx: int, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, - prefix: str = "", - ) -> None: + def __init__(self, + config: JambaConfig, + layer_idx: int, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + **kwargs) -> None: super().__init__() self.hidden_size = config.hidden_size tp_size = get_tensor_model_parallel_world_size() @@ -285,17 +288,18 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): org_num_embeddings=config.vocab_size, ) + extra_kwargs = {"is_lora_enabled": bool(vllm_config)} + def get_layer(prefix: str): layer_idx = int(prefix.rsplit(".", 1)[1]) layer_class = ALL_DECODER_LAYER_TYPES[ config.layers_block_type[layer_idx]] - return layer_class( - config, - layer_idx, - cache_config, - quant_config=quant_config, - prefix=prefix, - ) + return layer_class(config, + layer_idx, + cache_config, + quant_config=quant_config, + prefix=prefix, + **extra_kwargs) self.start_layer, self.end_layer, self.layers = make_layers( config.num_hidden_layers, get_layer, prefix=f"{prefix}.layers") @@ -373,10 +377,8 @@ class JambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP, # LoRA specific attributes supported_lora_modules = [ - "qkv_proj", - "o_proj", - "embed_tokens", - "lm_head", + "qkv_proj", "o_proj", "embed_tokens", "lm_head", "up_proj", + "down_proj", "gate_proj", "out_proj", "in_proj", "x_proj" ] embedding_modules = { "embed_tokens": "input_embeddings", @@ -420,17 +422,6 @@ def __init__(self, *, vllm_config: 
VllmConfig, prefix: str = ""): self.make_empty_intermediate_tensors = ( self.model.make_empty_intermediate_tensors) - if self.scheduler_config is not None and \ - not self.model_config.enforce_eager: - if self.scheduler_config.max_num_seqs > \ - vllm_config.compilation_config.max_capture_size: - self.max_batch_size = \ - vllm_config.compilation_config.max_capture_size - else: - self.max_batch_size = vllm_config.pad_for_cudagraph( - self.scheduler_config.max_num_seqs) - else: - self.max_batch_size = 8192 + 2 def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.model.get_input_embeddings(input_ids) @@ -444,12 +435,15 @@ def forward(self, inputs_embeds: Optional[torch.Tensor] = None, **kwargs): if self.mamba_cache is None: + max_batch_size = (VllmConfig.get_graph_batch_size( + self.scheduler_config.max_num_seqs) if self.scheduler_config + else max(_BATCH_SIZES_TO_CAPTURE) + 2) num_mamba_layers = self.model_config.get_num_layers_by_block_type( self.vllm_config.parallel_config, LayerBlockType.mamba) self.mamba_cache = MambaCacheManager( - self.lm_head.weight.dtype, num_mamba_layers, - self.max_batch_size, *self._get_mamba_cache_shape()) + self.lm_head.weight.dtype, num_mamba_layers, max_batch_size, + *self._get_mamba_cache_shape()) ( mamba_cache_tensors, state_indices_tensor, @@ -579,10 +573,34 @@ def load_weights(self, weights: Iterable[Tuple[str, if is_pp_missing_parameter(name, self): continue - param = params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) - weight_loader(param, loaded_weight) + if "in_proj" in name: + # To support LoRA, in_proj weight needs to be split to + # two separate tensors, and here we load it manually + # manually splits in_proj_lin and in_proj_gate + name_lin = name.replace("in_proj", "in_proj_lin") + name_gate = name.replace("in_proj", "in_proj_gate") + + # need to split the loaded weight of in_proj + param = params_dict[name_lin] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + + weight_loader(param, + loaded_weight[:loaded_weight.shape[0] // + 2, :]) # the lin split + + param = params_dict[name_gate] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + + weight_loader(param, + loaded_weight[loaded_weight.shape[0] // + 2:, :]) # the lin split + else: + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) loaded_params.add(name) return loaded_params diff --git a/vllm/model_executor/models/mamba.py b/vllm/model_executor/models/mamba.py index 06c8d9723cd01..9bb5940124842 100644 --- a/vllm/model_executor/models/mamba.py +++ b/vllm/model_executor/models/mamba.py @@ -6,7 +6,7 @@ from transformers import MambaConfig from vllm.attention.backends.abstract import AttentionMetadata -from vllm.config import CacheConfig, VllmConfig +from vllm.config import _BATCH_SIZES_TO_CAPTURE, CacheConfig, VllmConfig from vllm.distributed import get_tensor_model_parallel_world_size from vllm.distributed.parallel_state import get_pp_group from vllm.model_executor.layers.layernorm import RMSNorm @@ -38,10 +38,12 @@ class MambaDecoderLayer(nn.Module): def __init__(self, config: MambaConfig, cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None) -> None: + quant_config: Optional[QuantizationConfig] = None, + is_lora_enabled: Optional[bool] = False) -> None: super().__init__() self.config = config self.is_falcon_mamba = config.model_type == 
"falcon_mamba" + self.is_lora_enabled = is_lora_enabled mixer_rms_eps = config.mixer_rms_eps if self.is_falcon_mamba else None self.mixer = MambaMixer(hidden_size=config.hidden_size, ssm_state_size=config.state_size, @@ -53,7 +55,8 @@ def __init__(self, use_rms_norm=self.is_falcon_mamba, rms_norm_has_weight=not self.is_falcon_mamba, rms_norm_eps=mixer_rms_eps, - activation=config.hidden_act) + activation=config.hidden_act, + is_lora_enabled=self.is_lora_enabled) self.norm = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon) @@ -85,6 +88,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): cache_config = vllm_config.cache_config quant_config = vllm_config.quant_config lora_config = vllm_config.lora_config + is_lora_enabled = bool(lora_config) self.config = config self.padding_idx = config.pad_token_id @@ -101,8 +105,10 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.start_layer, self.end_layer, self.layers = make_layers( config.num_hidden_layers, - lambda prefix: MambaDecoderLayer( - config, cache_config=cache_config, quant_config=quant_config), + lambda prefix: MambaDecoderLayer(config, + cache_config=cache_config, + quant_config=quant_config, + is_lora_enabled=is_lora_enabled), prefix=f"{prefix}.layers") self.norm_f = RMSNorm(config.hidden_size, @@ -195,17 +201,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.make_empty_intermediate_tensors = ( self.backbone.make_empty_intermediate_tensors) - if self.scheduler_config is not None and \ - not self.model_config.enforce_eager: - if self.scheduler_config.max_num_seqs > \ - vllm_config.compilation_config.max_capture_size: - self.max_batch_size = \ - vllm_config.compilation_config.max_capture_size - else: - self.max_batch_size = vllm_config.pad_for_cudagraph( - self.scheduler_config.max_num_seqs) - else: - self.max_batch_size = 8192 + 2 def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.backbone.get_input_embeddings(input_ids) @@ -219,11 +214,15 @@ def forward(self, inputs_embeds: Optional[torch.Tensor] = None, **kwargs): if self.mamba_cache is None: + max_batch_size = (VllmConfig.get_graph_batch_size( + self.scheduler_config.max_num_seqs) if self.scheduler_config + else max(_BATCH_SIZES_TO_CAPTURE) + 2) + num_mamba_layers = self.model_config.get_num_layers_by_block_type( self.vllm_config.parallel_config, LayerBlockType.mamba) self.mamba_cache = MambaCacheManager( - self.lm_head.weight.dtype, num_mamba_layers, - self.max_batch_size, *self._get_mamba_cache_shape()) + self.lm_head.weight.dtype, num_mamba_layers, max_batch_size, + *self._get_mamba_cache_shape()) ( mamba_cache_tensors, @@ -288,9 +287,32 @@ def load_weights(self, weights: Iterable[Tuple[str, if is_pp_missing_parameter(name, self): continue - param = params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) - weight_loader(param, loaded_weight) + if "in_proj" in name: + # To support LoRA, in_proj weight needs to be split to + # two separate tensors, and here we load it manually + + # manually splits in_proj_lin and in_proj_gate + name_lin = name.replace("in_proj", "in_proj_lin") + name_gate = name.replace("in_proj", "in_proj_gate") + + # need to split the loaded weight of in_proj + param = params_dict[name_lin] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + + weight_loader(param, loaded_weight[:loaded_weight.shape[0] // + 2, :]) # the lin split + + param = params_dict[name_gate] + weight_loader = 
getattr(param, "weight_loader", + default_weight_loader) + + weight_loader(param, loaded_weight[loaded_weight.shape[0] // + 2:, :]) # the lin split + else: + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) loaded_params.add(name) return loaded_params From a80bf8d7c144ebb688270fe4f12fc0db945685b8 Mon Sep 17 00:00:00 2001 From: Erez Schwartz Date: Sun, 15 Dec 2024 18:23:26 +0200 Subject: [PATCH 02/11] fixes to jamba modeling after merge from main --- vllm/model_executor/models/jamba.py | 28 ++++++++++++++++++++-------- 1 file changed, 20 insertions(+), 8 deletions(-) diff --git a/vllm/model_executor/models/jamba.py b/vllm/model_executor/models/jamba.py index 723e983eb5705..474052203b357 100644 --- a/vllm/model_executor/models/jamba.py +++ b/vllm/model_executor/models/jamba.py @@ -7,7 +7,7 @@ from vllm.attention.backends.abstract import AttentionMetadata from vllm.attention.layer import Attention -from vllm.config import _BATCH_SIZES_TO_CAPTURE, CacheConfig, VllmConfig +from vllm.config import CacheConfig, VllmConfig from vllm.distributed import get_tensor_model_parallel_world_size from vllm.distributed.parallel_state import get_pp_group from vllm.model_executor.layers.fused_moe import FusedMoE @@ -25,6 +25,8 @@ from vllm.model_executor.models.mamba_cache import (MambaCacheManager, MambaCacheParams) from vllm.model_executor.sampling_metadata import SamplingMetadata +# from vllm.worker.model_runner import (_BATCH_SIZES_TO_CAPTURE, +# _get_graph_batch_size) from vllm.sequence import IntermediateTensors from vllm.utils import LayerBlockType @@ -422,6 +424,17 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.make_empty_intermediate_tensors = ( self.model.make_empty_intermediate_tensors) + if self.scheduler_config is not None and \ + not self.model_config.enforce_eager: + if self.scheduler_config.max_num_seqs > \ + vllm_config.compilation_config.max_capture_size: + self.max_batch_size = \ + vllm_config.compilation_config.max_capture_size + else: + self.max_batch_size = vllm_config.pad_for_cudagraph( + self.scheduler_config.max_num_seqs) + else: + self.max_batch_size = 8192 + 2 def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.model.get_input_embeddings(input_ids) @@ -435,15 +448,11 @@ def forward(self, inputs_embeds: Optional[torch.Tensor] = None, **kwargs): if self.mamba_cache is None: - max_batch_size = (VllmConfig.get_graph_batch_size( - self.scheduler_config.max_num_seqs) if self.scheduler_config - else max(_BATCH_SIZES_TO_CAPTURE) + 2) - num_mamba_layers = self.model_config.get_num_layers_by_block_type( self.vllm_config.parallel_config, LayerBlockType.mamba) self.mamba_cache = MambaCacheManager( - self.lm_head.weight.dtype, num_mamba_layers, max_batch_size, - *self._get_mamba_cache_shape()) + self.lm_head.weight.dtype, num_mamba_layers, + self.max_batch_size, *self._get_mamba_cache_shape()) ( mamba_cache_tensors, state_indices_tensor, @@ -589,6 +598,7 @@ def load_weights(self, weights: Iterable[Tuple[str, loaded_weight[:loaded_weight.shape[0] // 2, :]) # the lin split + loaded_params.add(name_lin) param = params_dict[name_gate] weight_loader = getattr(param, "weight_loader", default_weight_loader) @@ -596,12 +606,14 @@ def load_weights(self, weights: Iterable[Tuple[str, weight_loader(param, loaded_weight[loaded_weight.shape[0] // 2:, :]) # the lin split + loaded_params.add(name_gate) else: param = params_dict[name] weight_loader = getattr(param, 
"weight_loader", default_weight_loader) weight_loader(param, loaded_weight) - loaded_params.add(name) + if "in_proj" not in name: + loaded_params.add(name) return loaded_params From 623046d2af6d049d9b98cc35c48bc8242c803114 Mon Sep 17 00:00:00 2001 From: Erez Schwartz Date: Sun, 15 Dec 2024 18:29:50 +0200 Subject: [PATCH 03/11] removed commented code, formatting --- vllm/model_executor/models/jamba.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/vllm/model_executor/models/jamba.py b/vllm/model_executor/models/jamba.py index 474052203b357..97a70a427e491 100644 --- a/vllm/model_executor/models/jamba.py +++ b/vllm/model_executor/models/jamba.py @@ -25,8 +25,6 @@ from vllm.model_executor.models.mamba_cache import (MambaCacheManager, MambaCacheParams) from vllm.model_executor.sampling_metadata import SamplingMetadata -# from vllm.worker.model_runner import (_BATCH_SIZES_TO_CAPTURE, -# _get_graph_batch_size) from vllm.sequence import IntermediateTensors from vllm.utils import LayerBlockType From 8140c34b2851bbf827b646e152a909e98c80575f Mon Sep 17 00:00:00 2001 From: Erez Schwartz Date: Sun, 15 Dec 2024 18:35:03 +0200 Subject: [PATCH 04/11] changes to mamba modeling to support as well --- vllm/model_executor/models/mamba.py | 39 ++++++++++++++++++----------- 1 file changed, 25 insertions(+), 14 deletions(-) diff --git a/vllm/model_executor/models/mamba.py b/vllm/model_executor/models/mamba.py index 9bb5940124842..43db3049c504b 100644 --- a/vllm/model_executor/models/mamba.py +++ b/vllm/model_executor/models/mamba.py @@ -6,7 +6,7 @@ from transformers import MambaConfig from vllm.attention.backends.abstract import AttentionMetadata -from vllm.config import _BATCH_SIZES_TO_CAPTURE, CacheConfig, VllmConfig +from vllm.config import CacheConfig, VllmConfig from vllm.distributed import get_tensor_model_parallel_world_size from vllm.distributed.parallel_state import get_pp_group from vllm.model_executor.layers.layernorm import RMSNorm @@ -201,6 +201,17 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.make_empty_intermediate_tensors = ( self.backbone.make_empty_intermediate_tensors) + if self.scheduler_config is not None and \ + not self.model_config.enforce_eager: + if self.scheduler_config.max_num_seqs > \ + vllm_config.compilation_config.max_capture_size: + self.max_batch_size = \ + vllm_config.compilation_config.max_capture_size + else: + self.max_batch_size = vllm_config.pad_for_cudagraph( + self.scheduler_config.max_num_seqs) + else: + self.max_batch_size = 8192 + 2 def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.backbone.get_input_embeddings(input_ids) @@ -214,15 +225,11 @@ def forward(self, inputs_embeds: Optional[torch.Tensor] = None, **kwargs): if self.mamba_cache is None: - max_batch_size = (VllmConfig.get_graph_batch_size( - self.scheduler_config.max_num_seqs) if self.scheduler_config - else max(_BATCH_SIZES_TO_CAPTURE) + 2) - num_mamba_layers = self.model_config.get_num_layers_by_block_type( self.vllm_config.parallel_config, LayerBlockType.mamba) self.mamba_cache = MambaCacheManager( - self.lm_head.weight.dtype, num_mamba_layers, max_batch_size, - *self._get_mamba_cache_shape()) + self.lm_head.weight.dtype, num_mamba_layers, + self.max_batch_size, *self._get_mamba_cache_shape()) ( mamba_cache_tensors, @@ -290,8 +297,7 @@ def load_weights(self, weights: Iterable[Tuple[str, if "in_proj" in name: # To support LoRA, in_proj weight needs to be split to # two separate tensors, and here we load it manually - - # manually splits 
in_proj_lin and in_proj_gate + # manually splits in_proj_lin and in_proj_gate name_lin = name.replace("in_proj", "in_proj_lin") name_gate = name.replace("in_proj", "in_proj_gate") @@ -300,19 +306,24 @@ def load_weights(self, weights: Iterable[Tuple[str, weight_loader = getattr(param, "weight_loader", default_weight_loader) - weight_loader(param, loaded_weight[:loaded_weight.shape[0] // - 2, :]) # the lin split + weight_loader(param, + loaded_weight[:loaded_weight.shape[0] // + 2, :]) # the lin split + loaded_params.add(name_lin) param = params_dict[name_gate] weight_loader = getattr(param, "weight_loader", default_weight_loader) - weight_loader(param, loaded_weight[loaded_weight.shape[0] // - 2:, :]) # the lin split + weight_loader(param, + loaded_weight[loaded_weight.shape[0] // + 2:, :]) # the lin split + loaded_params.add(name_gate) else: param = params_dict[name] weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) - loaded_params.add(name) + if "in_proj" not in name: + loaded_params.add(name) return loaded_params From 00f9a14f3520d80db1b85ff96fd8eb2a8f273025 Mon Sep 17 00:00:00 2001 From: Erez Schwartz Date: Sun, 15 Dec 2024 18:42:29 +0200 Subject: [PATCH 05/11] foramtting --- vllm/model_executor/models/mamba.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/vllm/model_executor/models/mamba.py b/vllm/model_executor/models/mamba.py index 43db3049c504b..f784833766c42 100644 --- a/vllm/model_executor/models/mamba.py +++ b/vllm/model_executor/models/mamba.py @@ -306,18 +306,16 @@ def load_weights(self, weights: Iterable[Tuple[str, weight_loader = getattr(param, "weight_loader", default_weight_loader) - weight_loader(param, - loaded_weight[:loaded_weight.shape[0] // - 2, :]) # the lin split + weight_loader(param, loaded_weight[:loaded_weight.shape[0] // + 2, :]) # the lin split loaded_params.add(name_lin) param = params_dict[name_gate] weight_loader = getattr(param, "weight_loader", default_weight_loader) - weight_loader(param, - loaded_weight[loaded_weight.shape[0] // - 2:, :]) # the lin split + weight_loader(param, loaded_weight[loaded_weight.shape[0] // + 2:, :]) # the lin split loaded_params.add(name_gate) else: param = params_dict[name] From 1ccb594ec4bb0460d5ca9609a9551bc231a27206 Mon Sep 17 00:00:00 2001 From: Erez Schwartz Date: Mon, 16 Dec 2024 15:35:34 +0200 Subject: [PATCH 06/11] minor changes, added unittest for jamba lora Signed-off-by: Erez Schwartz --- tests/lora/test_jamba.py | 0 1 file changed, 0 insertions(+), 0 deletions(-) create mode 100644 tests/lora/test_jamba.py diff --git a/tests/lora/test_jamba.py b/tests/lora/test_jamba.py new file mode 100644 index 0000000000000..e69de29bb2d1d From 467c11f9fca1b577155c8293f214aa0077012d61 Mon Sep 17 00:00:00 2001 From: Erez Schwartz Date: Mon, 16 Dec 2024 15:36:14 +0200 Subject: [PATCH 07/11] minor changes, added unittest for jamba lora Signed-off-by: Erez Schwartz --- tests/lora/conftest.py | 24 ++++++++++++++++++ tests/lora/test_jamba.py | 54 ++++++++++++++++++++++++++++++++++++++++ 2 files changed, 78 insertions(+) diff --git a/tests/lora/conftest.py b/tests/lora/conftest.py index 29ecf37808205..73bc0a616b50b 100644 --- a/tests/lora/conftest.py +++ b/tests/lora/conftest.py @@ -4,6 +4,7 @@ from unittest.mock import MagicMock, patch import pytest +import safetensors import torch import torch.nn as nn from huggingface_hub import snapshot_download @@ -169,6 +170,29 @@ def mixtral_lora_files_all_target_modules(): return 
snapshot_download(repo_id="dyang415/mixtral-lora-v0") +@pytest.fixture(scope="session") +def jamba_lora_files(): + # some of the adapters have unnecessary weights for serving, + # hence we remove them + def remove_unnecessary_weights(path): + lora_path = f"{adapter_path}/adapter_model.safetensors" + tensors = safetensors.torch.load_file(lora_path) + nonlora_keys = [] + for k in list(tensors.keys()): + if "lora" not in k: + nonlora_keys.append(k) + for k in nonlora_keys: + del tensors[k] + safetensors.torch.save_file(tensors, lora_path) + + adapter_path = snapshot_download( + repo_id= + "hf-100/Jamba-1.5-mini-Spellbound-StoryWriter-0.1-6583896-ckpt53-lora") + + remove_unnecessary_weights(adapter_path) + return adapter_path + + @pytest.fixture(scope="session") def gemma_lora_files(): return snapshot_download(repo_id="wskwon/gemma-7b-test-lora") diff --git a/tests/lora/test_jamba.py b/tests/lora/test_jamba.py index e69de29bb2d1d..67bb9ceed1c1f 100644 --- a/tests/lora/test_jamba.py +++ b/tests/lora/test_jamba.py @@ -0,0 +1,54 @@ +from typing import List + +import pytest +import torch + +import vllm +from vllm.lora.request import LoRARequest + +MODEL_PATH = "ai21labs/AI21-Jamba-1.5-Mini" + +MAX_TOKENS = 40 + + +def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int, + prompts: List[str]) -> List[str]: + + sampling_params = vllm.SamplingParams(temperature=0, max_tokens=MAX_TOKENS) + outputs = llm.generate( + prompts, + sampling_params, + lora_request=LoRARequest(str(lora_id), lora_id, lora_path) + if lora_id else None) + # Print the outputs. + generated_texts: List[str] = [] + for output in outputs: + prompt = output.prompt + generated_text = output.outputs[0].text.strip() + generated_texts.append(generated_text) + print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") + return generated_texts + + +@pytest.mark.parametrize("tp_size", [2]) +def test_jamba_lora(jamba_lora_files, tp_size): + """Original test, the LoRA model has the common target modules, not all""" + if torch.cuda.device_count() < tp_size: + pytest.skip(f"Not enough GPUs for tensor parallelism {tp_size}") + + prompts = ["Write a story about a sheep and a goat."] + + llm = vllm.LLM( + MODEL_PATH, + enable_lora=True, + max_num_seqs=16, + max_loras=4, + distributed_executor_backend="ray", + tensor_parallel_size=tp_size, + ) + + expected_jamba_output = [ + """Once upon a time, in a lush green meadow, there lived a sheep named Clara and a goat named Billy. 
Clara was a gentle and kind-hearted sheep, always ready to help others""", # noqa: E501 + ] + assert do_sample(llm, jamba_lora_files, lora_id=1, + prompts=prompts) == expected_jamba_output From 3f47f5a1fbf49cce0765c331efe636090b493359 Mon Sep 17 00:00:00 2001 From: Erez Schwartz Date: Mon, 16 Dec 2024 16:19:35 +0200 Subject: [PATCH 08/11] minor fix to is_lora_enabled Signed-off-by: Erez Schwartz --- vllm/model_executor/models/jamba.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/model_executor/models/jamba.py b/vllm/model_executor/models/jamba.py index 97a70a427e491..17a9d65c10869 100644 --- a/vllm/model_executor/models/jamba.py +++ b/vllm/model_executor/models/jamba.py @@ -288,7 +288,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): org_num_embeddings=config.vocab_size, ) - extra_kwargs = {"is_lora_enabled": bool(vllm_config)} + extra_kwargs = {"is_lora_enabled": bool(vllm_config.lora_config)} def get_layer(prefix: str): layer_idx = int(prefix.rsplit(".", 1)[1]) From 9e7ec54ea25b67e00ef58c70961e62fd2434608b Mon Sep 17 00:00:00 2001 From: Erez Schwartz Date: Thu, 26 Dec 2024 14:35:37 +0200 Subject: [PATCH 09/11] fixes after suggestions from pr Signed-off-by: Erez Schwartz --- tests/lora/test_jamba.py | 4 +- .../layers/mamba/mamba_mixer.py | 14 +++---- vllm/model_executor/models/jamba.py | 38 +++---------------- vllm/model_executor/models/mamba.py | 35 +++-------------- 4 files changed, 19 insertions(+), 72 deletions(-) diff --git a/tests/lora/test_jamba.py b/tests/lora/test_jamba.py index 67bb9ceed1c1f..6aa33926cb6b8 100644 --- a/tests/lora/test_jamba.py +++ b/tests/lora/test_jamba.py @@ -30,7 +30,7 @@ def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int, return generated_texts -@pytest.mark.parametrize("tp_size", [2]) +@pytest.mark.parametrize("tp_size", [4]) def test_jamba_lora(jamba_lora_files, tp_size): """Original test, the LoRA model has the common target modules, not all""" if torch.cuda.device_count() < tp_size: @@ -48,7 +48,7 @@ def test_jamba_lora(jamba_lora_files, tp_size): ) expected_jamba_output = [ - """Once upon a time, in a lush green meadow, there lived a sheep named Clara and a goat named Billy. Clara was a gentle and kind-hearted sheep, always ready to help others""", # noqa: E501 + """Once upon a time, in a lush green meadow, there lived a sheep named Clara and a goat named Billy. 
Clara was a gentle creature, always nibbling on the soft grass and humming""" # noqa: E501 ] assert do_sample(llm, jamba_lora_files, lora_id=1, prompts=prompts) == expected_jamba_output diff --git a/vllm/model_executor/layers/mamba/mamba_mixer.py b/vllm/model_executor/layers/mamba/mamba_mixer.py index a3bb7c97f51ea..606c796d503cf 100644 --- a/vllm/model_executor/layers/mamba/mamba_mixer.py +++ b/vllm/model_executor/layers/mamba/mamba_mixer.py @@ -8,6 +8,7 @@ from vllm.model_executor.custom_op import CustomOp from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import (ColumnParallelLinear, + MergedColumnParallelLinear, RowParallelLinear) from vllm.model_executor.layers.mamba.ops.causal_conv1d import ( causal_conv1d_fn, causal_conv1d_update) @@ -61,12 +62,9 @@ def __init__(self, # doesn't allow to override it self.conv1d.weight.data = self.conv1d.weight.data.unsqueeze(1) - self.in_proj_lin = ColumnParallelLinear(hidden_size, - intermediate_size, - bias=use_bias) - self.in_proj_gate = ColumnParallelLinear(hidden_size, - intermediate_size, - bias=use_bias) + self.in_proj = MergedColumnParallelLinear(hidden_size, + [intermediate_size] * 2, + bias=use_bias) # selective projection used to make dt, B and C input dependent self.x_proj = RowParallelLinear( @@ -139,8 +137,8 @@ def forward_cuda(self, hidden_states: torch.Tensor, mamba_cache_params: MambaCacheParams): # 1. Gated MLP's linear projection - gate = self.in_proj_gate(hidden_states)[0].transpose(-2, -1) - hidden_states = self.in_proj_lin(hidden_states)[0].transpose(-2, -1) + projected_states = self.in_proj(hidden_states)[0].transpose(-2, -1) + hidden_states, gate = projected_states.chunk(2, dim=-2) # 2. Convolution sequence transformation conv_weights = self.conv1d.weight.view(self.conv1d.weight.size(0), diff --git a/vllm/model_executor/models/jamba.py b/vllm/model_executor/models/jamba.py index 17a9d65c10869..c3c8afaa618e4 100644 --- a/vllm/model_executor/models/jamba.py +++ b/vllm/model_executor/models/jamba.py @@ -373,6 +373,7 @@ class JambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP, "k_proj", "v_proj", ], + "in_proj": ["in_proj"], } # LoRA specific attributes @@ -580,38 +581,11 @@ def load_weights(self, weights: Iterable[Tuple[str, if is_pp_missing_parameter(name, self): continue - if "in_proj" in name: - # To support LoRA, in_proj weight needs to be split to - # two separate tensors, and here we load it manually - # manually splits in_proj_lin and in_proj_gate - name_lin = name.replace("in_proj", "in_proj_lin") - name_gate = name.replace("in_proj", "in_proj_gate") - - # need to split the loaded weight of in_proj - param = params_dict[name_lin] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) - - weight_loader(param, - loaded_weight[:loaded_weight.shape[0] // - 2, :]) # the lin split - - loaded_params.add(name_lin) - param = params_dict[name_gate] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) - - weight_loader(param, - loaded_weight[loaded_weight.shape[0] // - 2:, :]) # the lin split - loaded_params.add(name_gate) - else: - param = params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) - weight_loader(param, loaded_weight) - if "in_proj" not in name: - loaded_params.add(name) + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + loaded_params.add(name) return loaded_params diff --git 
a/vllm/model_executor/models/mamba.py b/vllm/model_executor/models/mamba.py index f784833766c42..553bc9c28cb21 100644 --- a/vllm/model_executor/models/mamba.py +++ b/vllm/model_executor/models/mamba.py @@ -294,34 +294,9 @@ def load_weights(self, weights: Iterable[Tuple[str, if is_pp_missing_parameter(name, self): continue - if "in_proj" in name: - # To support LoRA, in_proj weight needs to be split to - # two separate tensors, and here we load it manually - # manually splits in_proj_lin and in_proj_gate - name_lin = name.replace("in_proj", "in_proj_lin") - name_gate = name.replace("in_proj", "in_proj_gate") - - # need to split the loaded weight of in_proj - param = params_dict[name_lin] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) - - weight_loader(param, loaded_weight[:loaded_weight.shape[0] // - 2, :]) # the lin split - - loaded_params.add(name_lin) - param = params_dict[name_gate] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) - - weight_loader(param, loaded_weight[loaded_weight.shape[0] // - 2:, :]) # the lin split - loaded_params.add(name_gate) - else: - param = params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) - weight_loader(param, loaded_weight) - if "in_proj" not in name: - loaded_params.add(name) + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + loaded_params.add(name) return loaded_params From 492bcf346f2f58e60f8740abb177aac436bcc204 Mon Sep 17 00:00:00 2001 From: Erez Schwartz Date: Thu, 26 Dec 2024 15:29:22 +0200 Subject: [PATCH 10/11] fixes after suggestions from pr Signed-off-by: Erez Schwartz --- tests/lora/test_jamba.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/lora/test_jamba.py b/tests/lora/test_jamba.py index 6aa33926cb6b8..7b0af6787337c 100644 --- a/tests/lora/test_jamba.py +++ b/tests/lora/test_jamba.py @@ -47,6 +47,7 @@ def test_jamba_lora(jamba_lora_files, tp_size): tensor_parallel_size=tp_size, ) + expected_jamba_output = [ """Once upon a time, in a lush green meadow, there lived a sheep named Clara and a goat named Billy. Clara was a gentle creature, always nibbling on the soft grass and humming""" # noqa: E501 ] From cf7bfa408559165ace2b50ac0b6182b493b327dc Mon Sep 17 00:00:00 2001 From: Erez Schwartz Date: Thu, 26 Dec 2024 15:52:28 +0200 Subject: [PATCH 11/11] fixes after suggestions from pr Signed-off-by: Erez Schwartz --- tests/lora/test_jamba.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/lora/test_jamba.py b/tests/lora/test_jamba.py index 7b0af6787337c..6aa33926cb6b8 100644 --- a/tests/lora/test_jamba.py +++ b/tests/lora/test_jamba.py @@ -47,7 +47,6 @@ def test_jamba_lora(jamba_lora_files, tp_size): tensor_parallel_size=tp_size, ) - expected_jamba_output = [ """Once upon a time, in a lush green meadow, there lived a sheep named Clara and a goat named Billy. Clara was a gentle creature, always nibbling on the soft grass and humming""" # noqa: E501 ]
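
A note on the contiguity guards added in mamba_mixer.py: the mixer keeps its activations laid out as (..., dim, seq_len), so they are transposed back before the x_proj and out_proj projections. That transpose is only a strided view, and the LoRA kernels expect a contiguous buffer, hence the `.contiguous()` call when `is_lora_enabled` is set. A minimal sketch of the pattern, using plain torch only (the shapes are illustrative, not taken from a real config):

    import torch

    is_lora_enabled = True                     # mirrors the new constructor flag
    hidden_states = torch.randn(2, 16, 8)      # (batch, intermediate_size, seq_len)

    # transpose(-2, -1) returns a view with swapped strides, not a copy
    transposed = hidden_states.transpose(-2, -1)
    assert not transposed.is_contiguous()

    # the base parallel linears accept strided input, but the LoRA kernels
    # need a flat contiguous buffer, so the projection input is materialized
    ssm_input = transposed.contiguous() if is_lora_enabled else transposed
    assert ssm_input.is_contiguous()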
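
The intermediate revisions (PATCH 01 through 08) split the merged in_proj into separate in_proj_lin and in_proj_gate layers and divided the checkpoint tensor by rows in load_weights; PATCH 09 reverts to MergedColumnParallelLinear and instead registers the layer through packed_modules_mapping ("in_proj": ["in_proj"]). A sketch of the row split the reverted loader performed, assuming the merged weight is stored as (2 * intermediate_size, hidden_size) with the SSM half first, matching the forward pass's chunk(2, dim=-2); the helper name and sizes here are illustrative:

    import torch

    def split_merged_in_proj(loaded_weight: torch.Tensor):
        # (2 * intermediate_size, hidden_size) -> two (intermediate_size, hidden_size)
        half = loaded_weight.shape[0] // 2
        lin_weight = loaded_weight[:half, :]   # drives the conv/SSM branch
        gate_weight = loaded_weight[half:, :]  # drives the gating branch
        return lin_weight, gate_weight

    merged = torch.randn(2 * 16, 8)            # illustrative sizes only
    lin_w, gate_w = split_merged_in_proj(merged)
    assert lin_w.shape == gate_w.shape == (16, 8)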
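
End to end, the new test exercises the adapter through the regular LoRA request path. A condensed usage sketch of the same flow; the adapter path and LoRA name are placeholders, while the repo id, prompt, and engine arguments come from tests/lora/test_jamba.py (the tensor-parallel size must match the available GPUs):

    import vllm
    from vllm.lora.request import LoRARequest

    llm = vllm.LLM(
        "ai21labs/AI21-Jamba-1.5-Mini",
        enable_lora=True,
        max_num_seqs=16,
        max_loras=4,
        tensor_parallel_size=4,
    )
    outputs = llm.generate(
        ["Write a story about a sheep and a goat."],
        vllm.SamplingParams(temperature=0, max_tokens=40),
        lora_request=LoRARequest("jamba-story-lora", 1, "/path/to/adapter"),
    )
    print(outputs[0].outputs[0].text)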