From 305f3cb0ffc1f26d36e79b55f188075735ca5314 Mon Sep 17 00:00:00 2001 From: ErezSC42 Date: Fri, 27 Dec 2024 19:58:21 +0200 Subject: [PATCH] [MODEL] LoRA support for Jamba model (#11209) Signed-off-by: Erez Schwartz Signed-off-by: xcnick --- tests/lora/conftest.py | 24 +++++++++ tests/lora/test_jamba.py | 54 +++++++++++++++++++ .../layers/mamba/mamba_mixer.py | 22 ++++++-- vllm/model_executor/models/jamba.py | 50 ++++++++--------- vllm/model_executor/models/mamba.py | 14 +++-- 5 files changed, 132 insertions(+), 32 deletions(-) create mode 100644 tests/lora/test_jamba.py diff --git a/tests/lora/conftest.py b/tests/lora/conftest.py index 8b247fb9b2388..57ebaa424fc59 100644 --- a/tests/lora/conftest.py +++ b/tests/lora/conftest.py @@ -4,6 +4,7 @@ from unittest.mock import MagicMock, patch import pytest +import safetensors import torch import torch.nn as nn from huggingface_hub import snapshot_download @@ -169,6 +170,29 @@ def mixtral_lora_files_all_target_modules(): return snapshot_download(repo_id="dyang415/mixtral-lora-v0") +@pytest.fixture(scope="session") +def jamba_lora_files(): + # some of the adapters have unnecessary weights for serving, + # hence we remove them + def remove_unnecessary_weights(path): + lora_path = f"{adapter_path}/adapter_model.safetensors" + tensors = safetensors.torch.load_file(lora_path) + nonlora_keys = [] + for k in list(tensors.keys()): + if "lora" not in k: + nonlora_keys.append(k) + for k in nonlora_keys: + del tensors[k] + safetensors.torch.save_file(tensors, lora_path) + + adapter_path = snapshot_download( + repo_id= + "hf-100/Jamba-1.5-mini-Spellbound-StoryWriter-0.1-6583896-ckpt53-lora") + + remove_unnecessary_weights(adapter_path) + return adapter_path + + @pytest.fixture(scope="session") def gemma_lora_files(): return snapshot_download(repo_id="wskwon/gemma-7b-test-lora") diff --git a/tests/lora/test_jamba.py b/tests/lora/test_jamba.py new file mode 100644 index 0000000000000..6aa33926cb6b8 --- /dev/null +++ b/tests/lora/test_jamba.py @@ -0,0 +1,54 @@ +from typing import List + +import pytest +import torch + +import vllm +from vllm.lora.request import LoRARequest + +MODEL_PATH = "ai21labs/AI21-Jamba-1.5-Mini" + +MAX_TOKENS = 40 + + +def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int, + prompts: List[str]) -> List[str]: + + sampling_params = vllm.SamplingParams(temperature=0, max_tokens=MAX_TOKENS) + outputs = llm.generate( + prompts, + sampling_params, + lora_request=LoRARequest(str(lora_id), lora_id, lora_path) + if lora_id else None) + # Print the outputs. + generated_texts: List[str] = [] + for output in outputs: + prompt = output.prompt + generated_text = output.outputs[0].text.strip() + generated_texts.append(generated_text) + print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") + return generated_texts + + +@pytest.mark.parametrize("tp_size", [4]) +def test_jamba_lora(jamba_lora_files, tp_size): + """Original test, the LoRA model has the common target modules, not all""" + if torch.cuda.device_count() < tp_size: + pytest.skip(f"Not enough GPUs for tensor parallelism {tp_size}") + + prompts = ["Write a story about a sheep and a goat."] + + llm = vllm.LLM( + MODEL_PATH, + enable_lora=True, + max_num_seqs=16, + max_loras=4, + distributed_executor_backend="ray", + tensor_parallel_size=tp_size, + ) + + expected_jamba_output = [ + """Once upon a time, in a lush green meadow, there lived a sheep named Clara and a goat named Billy. 
Clara was a gentle creature, always nibbling on the soft grass and humming""" # noqa: E501 + ] + assert do_sample(llm, jamba_lora_files, lora_id=1, + prompts=prompts) == expected_jamba_output diff --git a/vllm/model_executor/layers/mamba/mamba_mixer.py b/vllm/model_executor/layers/mamba/mamba_mixer.py index 10bec75f49fdf..606c796d503cf 100644 --- a/vllm/model_executor/layers/mamba/mamba_mixer.py +++ b/vllm/model_executor/layers/mamba/mamba_mixer.py @@ -42,12 +42,14 @@ def __init__(self, use_rms_norm: bool, rms_norm_has_weight: bool = True, rms_norm_eps: float = 1e-5, - activation="silu"): + activation="silu", + is_lora_enabled: bool = False): super().__init__() self.time_step_rank = time_step_rank self.ssm_state_size = ssm_state_size self.use_rms_norm = use_rms_norm self.activation = activation + self.is_lora_enabled = is_lora_enabled self.conv1d = ColumnParallelLinear( input_size=conv_kernel_size, @@ -63,6 +65,7 @@ def __init__(self, self.in_proj = MergedColumnParallelLinear(hidden_size, [intermediate_size] * 2, bias=use_bias) + # selective projection used to make dt, B and C input dependent self.x_proj = RowParallelLinear( intermediate_size, @@ -170,7 +173,13 @@ def forward_cuda(self, hidden_states: torch.Tensor, # 3. State Space Model sequence transformation # 3.a. input varying initialization of time_step, B and C - ssm_parameters = self.x_proj(hidden_states.transpose(-2, -1))[0] + + if self.is_lora_enabled: + # lora kernel requires contiguous tensor + ssm_parameters = self.x_proj( + hidden_states.transpose(-2, -1).contiguous())[0] + else: + ssm_parameters = self.x_proj(hidden_states.transpose(-2, -1))[0] time_step, B, C = torch.split( ssm_parameters, @@ -222,6 +231,11 @@ def forward_cuda(self, hidden_states: torch.Tensor, scan_outputs = scan_outputs.transpose(0, 1) # 4. 
Final linear projection - contextualized_states = self.out_proj(scan_outputs.transpose(-2, - -1))[0] + if self.is_lora_enabled: + # lora kernel requires contiguous tensor + contextualized_states = self.out_proj( + scan_outputs.transpose(-2, -1).contiguous())[0] + else: + contextualized_states = self.out_proj( + scan_outputs.transpose(-2, -1))[0] return contextualized_states diff --git a/vllm/model_executor/models/jamba.py b/vllm/model_executor/models/jamba.py index 91786db5ddc96..890b5530b97d6 100644 --- a/vllm/model_executor/models/jamba.py +++ b/vllm/model_executor/models/jamba.py @@ -107,9 +107,11 @@ def __init__(self, layer_idx: int, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, - prefix: str = "") -> None: + is_lora_enabled: Optional[bool] = False, + **kwargs) -> None: super().__init__() self.config = config + self.is_lora_enabled = is_lora_enabled self.mamba = MambaMixer(hidden_size= config.hidden_size, ssm_state_size = config.mamba_d_state, conv_kernel_size = config.mamba_d_conv, @@ -120,7 +122,9 @@ def __init__(self, use_bias = config.mamba_proj_bias, use_rms_norm=True, rms_norm_eps=config.rms_norm_eps, - activation=config.hidden_act) + activation=config.hidden_act, + is_lora_enabled = self.is_lora_enabled + ) num_experts = config.layers_num_experts[layer_idx] ffn_layer_class = JambaMoE if num_experts > 1 else JambaMLP @@ -156,14 +160,13 @@ def forward( class JambaAttentionDecoderLayer(nn.Module): - def __init__( - self, - config: JambaConfig, - layer_idx: int, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, - prefix: str = "", - ) -> None: + def __init__(self, + config: JambaConfig, + layer_idx: int, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + **kwargs) -> None: super().__init__() self.hidden_size = config.hidden_size tp_size = get_tensor_model_parallel_world_size() @@ -287,17 +290,18 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): org_num_embeddings=config.vocab_size, ) + extra_kwargs = {"is_lora_enabled": bool(vllm_config.lora_config)} + def get_layer(prefix: str): layer_idx = int(prefix.rsplit(".", 1)[1]) layer_class = ALL_DECODER_LAYER_TYPES[ config.layers_block_type[layer_idx]] - return layer_class( - config, - layer_idx, - cache_config, - quant_config=quant_config, - prefix=prefix, - ) + return layer_class(config, + layer_idx, + cache_config, + quant_config=quant_config, + prefix=prefix, + **extra_kwargs) self.start_layer, self.end_layer, self.layers = make_layers( config.num_hidden_layers, get_layer, prefix=f"{prefix}.layers") @@ -371,14 +375,13 @@ class JambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP, "k_proj", "v_proj", ], + "in_proj": ["in_proj"], } # LoRA specific attributes supported_lora_modules = [ - "qkv_proj", - "o_proj", - "embed_tokens", - "lm_head", + "qkv_proj", "o_proj", "embed_tokens", "lm_head", "up_proj", + "down_proj", "gate_proj", "out_proj", "in_proj", "x_proj" ] embedding_modules = { "embed_tokens": "input_embeddings", @@ -423,9 +426,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.make_empty_intermediate_tensors = ( self.model.make_empty_intermediate_tensors) if self.scheduler_config is not None and \ - not self.model_config.enforce_eager: + not self.model_config.enforce_eager: if self.scheduler_config.max_num_seqs > \ - vllm_config.compilation_config.max_capture_size: + 
vllm_config.compilation_config.max_capture_size: self.max_batch_size = \ vllm_config.compilation_config.max_capture_size else: @@ -446,7 +449,6 @@ def forward(self, inputs_embeds: Optional[torch.Tensor] = None, **kwargs): if self.mamba_cache is None: - num_mamba_layers = self.model_config.get_num_layers_by_block_type( self.vllm_config.parallel_config, LayerBlockType.mamba) self.mamba_cache = MambaCacheManager( diff --git a/vllm/model_executor/models/mamba.py b/vllm/model_executor/models/mamba.py index 06c8d9723cd01..553bc9c28cb21 100644 --- a/vllm/model_executor/models/mamba.py +++ b/vllm/model_executor/models/mamba.py @@ -38,10 +38,12 @@ class MambaDecoderLayer(nn.Module): def __init__(self, config: MambaConfig, cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None) -> None: + quant_config: Optional[QuantizationConfig] = None, + is_lora_enabled: Optional[bool] = False) -> None: super().__init__() self.config = config self.is_falcon_mamba = config.model_type == "falcon_mamba" + self.is_lora_enabled = is_lora_enabled mixer_rms_eps = config.mixer_rms_eps if self.is_falcon_mamba else None self.mixer = MambaMixer(hidden_size=config.hidden_size, ssm_state_size=config.state_size, @@ -53,7 +55,8 @@ def __init__(self, use_rms_norm=self.is_falcon_mamba, rms_norm_has_weight=not self.is_falcon_mamba, rms_norm_eps=mixer_rms_eps, - activation=config.hidden_act) + activation=config.hidden_act, + is_lora_enabled=self.is_lora_enabled) self.norm = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon) @@ -85,6 +88,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): cache_config = vllm_config.cache_config quant_config = vllm_config.quant_config lora_config = vllm_config.lora_config + is_lora_enabled = bool(lora_config) self.config = config self.padding_idx = config.pad_token_id @@ -101,8 +105,10 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.start_layer, self.end_layer, self.layers = make_layers( config.num_hidden_layers, - lambda prefix: MambaDecoderLayer( - config, cache_config=cache_config, quant_config=quant_config), + lambda prefix: MambaDecoderLayer(config, + cache_config=cache_config, + quant_config=quant_config, + is_lora_enabled=is_lora_enabled), prefix=f"{prefix}.layers") self.norm_f = RMSNorm(config.hidden_size,
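
For reference, a minimal end-to-end sketch of the Jamba LoRA path this patch enables, mirroring tests/lora/test_jamba.py above. The model path, adapter repo, tensor-parallel size, and prompt are taken from that test and are illustrative only; they are not a canonical serving recipe and may need adjusting for other setups.

    # Illustrative sketch based on tests/lora/test_jamba.py; the adapter repo and
    # TP size come from the test fixture, not from a recommended configuration.
    from huggingface_hub import snapshot_download

    import vllm
    from vllm.lora.request import LoRARequest

    MODEL_PATH = "ai21labs/AI21-Jamba-1.5-Mini"
    ADAPTER_REPO = (
        "hf-100/Jamba-1.5-mini-Spellbound-StoryWriter-0.1-6583896-ckpt53-lora")

    # Download the LoRA adapter. The test fixture additionally strips non-LoRA
    # weights from adapter_model.safetensors before serving.
    lora_path = snapshot_download(repo_id=ADAPTER_REPO)

    # enable_lora=True sets is_lora_enabled on the decoder layers, which routes
    # the MambaMixer x_proj/out_proj inputs through the contiguous-tensor path
    # added in this patch.
    llm = vllm.LLM(
        MODEL_PATH,
        enable_lora=True,
        max_num_seqs=16,
        max_loras=4,
        tensor_parallel_size=4,
    )

    outputs = llm.generate(
        ["Write a story about a sheep and a goat."],
        vllm.SamplingParams(temperature=0, max_tokens=40),
        lora_request=LoRARequest("jamba-lora", 1, lora_path),
    )
    print(outputs[0].outputs[0].text.strip())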