[MODEL] LoRA support for Jamba model (#11209)
Signed-off-by: Erez Schwartz <[email protected]>
ErezSC42 authored Dec 27, 2024
1 parent 1014180 commit 55509c2
Showing 5 changed files with 132 additions and 32 deletions.
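
The commit threads an is_lora_enabled flag from the Jamba and Mamba model constructors down into MambaMixer and widens the set of LoRA-targetable modules. A minimal sketch of how the new support is exercised, mirroring the test added below in tests/lora/test_jamba.py (the adapter path is a placeholder; the model name and settings are taken from that test, which additionally uses the Ray executor backend across four GPUs):

import vllm
from vllm.lora.request import LoRARequest

# Settings mirror tests/lora/test_jamba.py; the adapter path is a placeholder.
llm = vllm.LLM(
    "ai21labs/AI21-Jamba-1.5-Mini",
    enable_lora=True,
    max_num_seqs=16,
    max_loras=4,
    tensor_parallel_size=4,
)
outputs = llm.generate(
    ["Write a story about a sheep and a goat."],
    vllm.SamplingParams(temperature=0, max_tokens=40),
    lora_request=LoRARequest("jamba-story-lora", 1, "/path/to/jamba-lora-adapter"),
)
print(outputs[0].outputs[0].text.strip())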
24 changes: 24 additions & 0 deletions tests/lora/conftest.py
@@ -4,6 +4,7 @@
from unittest.mock import MagicMock, patch

import pytest
import safetensors
import torch
import torch.nn as nn
from huggingface_hub import snapshot_download
@@ -169,6 +170,29 @@ def mixtral_lora_files_all_target_modules():
return snapshot_download(repo_id="dyang415/mixtral-lora-v0")


@pytest.fixture(scope="session")
def jamba_lora_files():
# some of the adapters have unnecessary weights for serving,
# hence we remove them
def remove_unnecessary_weights(path):
lora_path = f"{adapter_path}/adapter_model.safetensors"
tensors = safetensors.torch.load_file(lora_path)
nonlora_keys = []
for k in list(tensors.keys()):
if "lora" not in k:
nonlora_keys.append(k)
for k in nonlora_keys:
del tensors[k]
safetensors.torch.save_file(tensors, lora_path)

adapter_path = snapshot_download(
repo_id=
"hf-100/Jamba-1.5-mini-Spellbound-StoryWriter-0.1-6583896-ckpt53-lora")

remove_unnecessary_weights(adapter_path)
return adapter_path


@pytest.fixture(scope="session")
def gemma_lora_files():
return snapshot_download(repo_id="wskwon/gemma-7b-test-lora")
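
The new jamba_lora_files fixture strips tensors whose names do not contain "lora", since only the LoRA weights are needed for serving. A quick way to see what an adapter ships, using the same safetensors call as the fixture (the path is a placeholder):

import safetensors.torch

tensors = safetensors.torch.load_file("/path/to/adapter/adapter_model.safetensors")
for name, tensor in tensors.items():
    # Keys without "lora" in the name are the ones the fixture removes.
    status = "keep" if "lora" in name else "drop"
    print(f"{status}: {name} {tuple(tensor.shape)}")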
54 changes: 54 additions & 0 deletions tests/lora/test_jamba.py
@@ -0,0 +1,54 @@
from typing import List

import pytest
import torch

import vllm
from vllm.lora.request import LoRARequest

MODEL_PATH = "ai21labs/AI21-Jamba-1.5-Mini"

MAX_TOKENS = 40


def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int,
prompts: List[str]) -> List[str]:

sampling_params = vllm.SamplingParams(temperature=0, max_tokens=MAX_TOKENS)
outputs = llm.generate(
prompts,
sampling_params,
lora_request=LoRARequest(str(lora_id), lora_id, lora_path)
if lora_id else None)
# Print the outputs.
generated_texts: List[str] = []
for output in outputs:
prompt = output.prompt
generated_text = output.outputs[0].text.strip()
generated_texts.append(generated_text)
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
return generated_texts


@pytest.mark.parametrize("tp_size", [4])
def test_jamba_lora(jamba_lora_files, tp_size):
"""Original test, the LoRA model has the common target modules, not all"""
if torch.cuda.device_count() < tp_size:
pytest.skip(f"Not enough GPUs for tensor parallelism {tp_size}")

prompts = ["Write a story about a sheep and a goat."]

llm = vllm.LLM(
MODEL_PATH,
enable_lora=True,
max_num_seqs=16,
max_loras=4,
distributed_executor_backend="ray",
tensor_parallel_size=tp_size,
)

expected_jamba_output = [
"""Once upon a time, in a lush green meadow, there lived a sheep named Clara and a goat named Billy. Clara was a gentle creature, always nibbling on the soft grass and humming""" # noqa: E501
]
assert do_sample(llm, jamba_lora_files, lora_id=1,
prompts=prompts) == expected_jamba_output
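
Note that do_sample only attaches a LoRARequest when lora_id is truthy, so the same helper can also produce base-model output for comparison inside the test body, e.g.:

base_output = do_sample(llm, jamba_lora_files, lora_id=0, prompts=prompts)  # no LoRA applied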
22 changes: 18 additions & 4 deletions vllm/model_executor/layers/mamba/mamba_mixer.py
@@ -42,12 +42,14 @@ def __init__(self,
use_rms_norm: bool,
rms_norm_has_weight: bool = True,
rms_norm_eps: float = 1e-5,
activation="silu"):
activation="silu",
is_lora_enabled: bool = False):
super().__init__()
self.time_step_rank = time_step_rank
self.ssm_state_size = ssm_state_size
self.use_rms_norm = use_rms_norm
self.activation = activation
self.is_lora_enabled = is_lora_enabled

self.conv1d = ColumnParallelLinear(
input_size=conv_kernel_size,
@@ -63,6 +65,7 @@ def __init__(self,
self.in_proj = MergedColumnParallelLinear(hidden_size,
[intermediate_size] * 2,
bias=use_bias)

# selective projection used to make dt, B and C input dependent
self.x_proj = RowParallelLinear(
intermediate_size,
@@ -170,7 +173,13 @@ def forward_cuda(self, hidden_states: torch.Tensor,

# 3. State Space Model sequence transformation
# 3.a. input varying initialization of time_step, B and C
ssm_parameters = self.x_proj(hidden_states.transpose(-2, -1))[0]

if self.is_lora_enabled:
# lora kernel requires contiguous tensor
ssm_parameters = self.x_proj(
hidden_states.transpose(-2, -1).contiguous())[0]
else:
ssm_parameters = self.x_proj(hidden_states.transpose(-2, -1))[0]

time_step, B, C = torch.split(
ssm_parameters,
@@ -222,6 +231,11 @@ def forward_cuda(self, hidden_states: torch.Tensor,
scan_outputs = scan_outputs.transpose(0, 1)

# 4. Final linear projection
contextualized_states = self.out_proj(scan_outputs.transpose(-2,
-1))[0]
if self.is_lora_enabled:
# lora kernel requires contiguous tensor
contextualized_states = self.out_proj(
scan_outputs.transpose(-2, -1).contiguous())[0]
else:
contextualized_states = self.out_proj(
scan_outputs.transpose(-2, -1))[0]
return contextualized_states
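
The new branches exist because transpose only swaps strides without copying, so its result is generally non-contiguous, and the fused LoRA kernels expect a contiguous layout. A quick illustration in plain PyTorch:

import torch

x = torch.randn(4, 8)
t = x.transpose(-2, -1)
print(t.is_contiguous())               # False: transpose only changes the strides
print(t.contiguous().is_contiguous())  # True: .contiguous() copies into row-major layout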
50 changes: 26 additions & 24 deletions vllm/model_executor/models/jamba.py
@@ -107,9 +107,11 @@ def __init__(self,
layer_idx: int,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "") -> None:
is_lora_enabled: Optional[bool] = False,
**kwargs) -> None:
super().__init__()
self.config = config
self.is_lora_enabled = is_lora_enabled
self.mamba = MambaMixer(hidden_size= config.hidden_size,
ssm_state_size = config.mamba_d_state,
conv_kernel_size = config.mamba_d_conv,
@@ -120,7 +122,9 @@ def __init__(self,
use_bias = config.mamba_proj_bias,
use_rms_norm=True,
rms_norm_eps=config.rms_norm_eps,
activation=config.hidden_act)
activation=config.hidden_act,
is_lora_enabled = self.is_lora_enabled
)

num_experts = config.layers_num_experts[layer_idx]
ffn_layer_class = JambaMoE if num_experts > 1 else JambaMLP
@@ -156,14 +160,13 @@ def forward(

class JambaAttentionDecoderLayer(nn.Module):

def __init__(
self,
config: JambaConfig,
layer_idx: int,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None:
def __init__(self,
config: JambaConfig,
layer_idx: int,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
**kwargs) -> None:
super().__init__()
self.hidden_size = config.hidden_size
tp_size = get_tensor_model_parallel_world_size()
@@ -287,17 +290,18 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
org_num_embeddings=config.vocab_size,
)

extra_kwargs = {"is_lora_enabled": bool(vllm_config.lora_config)}

def get_layer(prefix: str):
layer_idx = int(prefix.rsplit(".", 1)[1])
layer_class = ALL_DECODER_LAYER_TYPES[
config.layers_block_type[layer_idx]]
return layer_class(
config,
layer_idx,
cache_config,
quant_config=quant_config,
prefix=prefix,
)
return layer_class(config,
layer_idx,
cache_config,
quant_config=quant_config,
prefix=prefix,
**extra_kwargs)

self.start_layer, self.end_layer, self.layers = make_layers(
config.num_hidden_layers, get_layer, prefix=f"{prefix}.layers")
@@ -371,14 +375,13 @@ class JambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
"k_proj",
"v_proj",
],
"in_proj": ["in_proj"],
}

# LoRA specific attributes
supported_lora_modules = [
"qkv_proj",
"o_proj",
"embed_tokens",
"lm_head",
"qkv_proj", "o_proj", "embed_tokens", "lm_head", "up_proj",
"down_proj", "gate_proj", "out_proj", "in_proj", "x_proj"
]
embedding_modules = {
"embed_tokens": "input_embeddings",
@@ -423,9 +426,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
self.make_empty_intermediate_tensors = (
self.model.make_empty_intermediate_tensors)
if self.scheduler_config is not None and \
not self.model_config.enforce_eager:
not self.model_config.enforce_eager:
if self.scheduler_config.max_num_seqs > \
vllm_config.compilation_config.max_capture_size:
vllm_config.compilation_config.max_capture_size:
self.max_batch_size = \
vllm_config.compilation_config.max_capture_size
else:
@@ -446,7 +449,6 @@ def forward(self,
inputs_embeds: Optional[torch.Tensor] = None,
**kwargs):
if self.mamba_cache is None:

num_mamba_layers = self.model_config.get_num_layers_by_block_type(
self.vllm_config.parallel_config, LayerBlockType.mamba)
self.mamba_cache = MambaCacheManager(
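
With the expanded supported_lora_modules list, adapters can now target the Mamba projections (in_proj, out_proj, x_proj) in addition to the attention and MLP projections. A hedged sketch of a PEFT LoraConfig that would produce such an adapter against the Hugging Face Jamba checkpoint (module names follow the list above and standard HF naming; the rank and alpha values are arbitrary):

from peft import LoraConfig

lora_config = LoraConfig(
    r=16,           # arbitrary rank, for illustration only
    lora_alpha=32,
    target_modules=[
        "q_proj", "k_proj", "v_proj", "o_proj",  # attention projections
        "gate_proj", "up_proj", "down_proj",     # MLP projections
        "in_proj", "x_proj", "out_proj",         # Mamba mixer projections
    ],
)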
14 changes: 10 additions & 4 deletions vllm/model_executor/models/mamba.py
@@ -38,10 +38,12 @@ class MambaDecoderLayer(nn.Module):
def __init__(self,
config: MambaConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None) -> None:
quant_config: Optional[QuantizationConfig] = None,
is_lora_enabled: Optional[bool] = False) -> None:
super().__init__()
self.config = config
self.is_falcon_mamba = config.model_type == "falcon_mamba"
self.is_lora_enabled = is_lora_enabled
mixer_rms_eps = config.mixer_rms_eps if self.is_falcon_mamba else None
self.mixer = MambaMixer(hidden_size=config.hidden_size,
ssm_state_size=config.state_size,
@@ -53,7 +55,8 @@ def __init__(self,
use_rms_norm=self.is_falcon_mamba,
rms_norm_has_weight=not self.is_falcon_mamba,
rms_norm_eps=mixer_rms_eps,
activation=config.hidden_act)
activation=config.hidden_act,
is_lora_enabled=self.is_lora_enabled)

self.norm = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)

@@ -85,6 +88,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
lora_config = vllm_config.lora_config
is_lora_enabled = bool(lora_config)

self.config = config
self.padding_idx = config.pad_token_id
@@ -101,8 +105,10 @@

self.start_layer, self.end_layer, self.layers = make_layers(
config.num_hidden_layers,
lambda prefix: MambaDecoderLayer(
config, cache_config=cache_config, quant_config=quant_config),
lambda prefix: MambaDecoderLayer(config,
cache_config=cache_config,
quant_config=quant_config,
is_lora_enabled=is_lora_enabled),
prefix=f"{prefix}.layers")

self.norm_f = RMSNorm(config.hidden_size,