
[MODEL] Update LoRA modules supported by Jamba #11209

Merged (11 commits) on Dec 27, 2024
24 changes: 24 additions & 0 deletions tests/lora/conftest.py
@@ -4,6 +4,7 @@
from unittest.mock import MagicMock, patch

import pytest
import safetensors
import torch
import torch.nn as nn
from huggingface_hub import snapshot_download
@@ -169,6 +170,29 @@ def mixtral_lora_files_all_target_modules():
return snapshot_download(repo_id="dyang415/mixtral-lora-v0")


@pytest.fixture(scope="session")
def jamba_lora_files():
# some of the adapters have unnecessary weights for serving,
# hence we remove them
def remove_unnecessary_weights(path):
lora_path = f"{adapter_path}/adapter_model.safetensors"
tensors = safetensors.torch.load_file(lora_path)
nonlora_keys = []
for k in list(tensors.keys()):
if "lora" not in k:
nonlora_keys.append(k)
for k in nonlora_keys:
del tensors[k]
safetensors.torch.save_file(tensors, lora_path)

adapter_path = snapshot_download(
repo_id=
"hf-100/Jamba-1.5-mini-Spellbound-StoryWriter-0.1-6583896-ckpt53-lora")

remove_unnecessary_weights(adapter_path)
return adapter_path


@pytest.fixture(scope="session")
def gemma_lora_files():
return snapshot_download(repo_id="wskwon/gemma-7b-test-lora")
54 changes: 54 additions & 0 deletions tests/lora/test_jamba.py
@@ -0,0 +1,54 @@
from typing import List

import pytest
import torch

import vllm
from vllm.lora.request import LoRARequest

MODEL_PATH = "ai21labs/AI21-Jamba-1.5-Mini"

MAX_TOKENS = 40


def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int,
              prompts: List[str]) -> List[str]:

    sampling_params = vllm.SamplingParams(temperature=0, max_tokens=MAX_TOKENS)
    outputs = llm.generate(
        prompts,
        sampling_params,
        lora_request=LoRARequest(str(lora_id), lora_id, lora_path)
        if lora_id else None)
    # Print the outputs.
    generated_texts: List[str] = []
    for output in outputs:
        prompt = output.prompt
        generated_text = output.outputs[0].text.strip()
        generated_texts.append(generated_text)
        print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
    return generated_texts


@pytest.mark.parametrize("tp_size", [4])
def test_jamba_lora(jamba_lora_files, tp_size):
    """Original test, the LoRA model has the common target modules, not all"""
    if torch.cuda.device_count() < tp_size:
        pytest.skip(f"Not enough GPUs for tensor parallelism {tp_size}")

    prompts = ["Write a story about a sheep and a goat."]

    llm = vllm.LLM(
        MODEL_PATH,
        enable_lora=True,
        max_num_seqs=16,
        max_loras=4,
        distributed_executor_backend="ray",
        tensor_parallel_size=tp_size,
    )

    expected_jamba_output = [
        """Once upon a time, in a lush green meadow, there lived a sheep named Clara and a goat named Billy. Clara was a gentle creature, always nibbling on the soft grass and humming"""  # noqa: E501
    ]
    assert do_sample(llm, jamba_lora_files, lora_id=1,
                     prompts=prompts) == expected_jamba_output
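
For context: whether vLLM can apply a given adapter depends on which modules that adapter targets. One way to check is to read the adapter's adapter_config.json, as in the short sketch below (standard PEFT adapter layout assumed; this is illustrative and not part of the PR):

import json

from huggingface_hub import hf_hub_download

# Fetch only the adapter config and list the modules the LoRA targets.
config_path = hf_hub_download(
    repo_id="hf-100/Jamba-1.5-mini-Spellbound-StoryWriter-0.1-6583896-ckpt53-lora",
    filename="adapter_config.json")
with open(config_path) as f:
    adapter_config = json.load(f)
print(adapter_config["target_modules"])
# Each name listed must be covered by JambaForCausalLM.supported_lora_modules
# (expanded in vllm/model_executor/models/jamba.py below) for vLLM to load the adapter.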
22 changes: 18 additions & 4 deletions vllm/model_executor/layers/mamba/mamba_mixer.py
@@ -42,12 +42,14 @@ def __init__(self,
                  use_rms_norm: bool,
                  rms_norm_has_weight: bool = True,
                  rms_norm_eps: float = 1e-5,
-                 activation="silu"):
+                 activation="silu",
+                 is_lora_enabled: bool = False):
         super().__init__()
         self.time_step_rank = time_step_rank
         self.ssm_state_size = ssm_state_size
         self.use_rms_norm = use_rms_norm
         self.activation = activation
+        self.is_lora_enabled = is_lora_enabled
 
         self.conv1d = ColumnParallelLinear(
             input_size=conv_kernel_size,
@@ -63,6 +65,7 @@ def __init__(self,
         self.in_proj = MergedColumnParallelLinear(hidden_size,
                                                   [intermediate_size] * 2,
                                                   bias=use_bias)
+
         # selective projection used to make dt, B and C input dependent
         self.x_proj = RowParallelLinear(
             intermediate_size,
@@ -170,7 +173,13 @@ def forward_cuda(self, hidden_states: torch.Tensor,
 
         # 3. State Space Model sequence transformation
         # 3.a. input varying initialization of time_step, B and C
-        ssm_parameters = self.x_proj(hidden_states.transpose(-2, -1))[0]
+
+        if self.is_lora_enabled:
+            # lora kernel requires contiguous tensor
+            ssm_parameters = self.x_proj(
+                hidden_states.transpose(-2, -1).contiguous())[0]
+        else:
+            ssm_parameters = self.x_proj(hidden_states.transpose(-2, -1))[0]
 
         time_step, B, C = torch.split(
             ssm_parameters,
@@ -222,6 +231,11 @@ def forward_cuda(self, hidden_states: torch.Tensor,
         scan_outputs = scan_outputs.transpose(0, 1)
 
         # 4. Final linear projection
-        contextualized_states = self.out_proj(scan_outputs.transpose(-2,
-                                                                     -1))[0]
+        if self.is_lora_enabled:
+            # lora kernel requires contiguous tensor
+            contextualized_states = self.out_proj(
+                scan_outputs.transpose(-2, -1).contiguous())[0]
+        else:
+            contextualized_states = self.out_proj(
+                scan_outputs.transpose(-2, -1))[0]
         return contextualized_states
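
For context on the branches above: transpose only swaps strides and returns a view, so its result is generally not contiguous in memory, which the LoRA kernel (per the comments in the diff) cannot handle; calling .contiguous() materializes a row-major copy. A standalone PyTorch illustration, independent of vLLM:

import torch

x = torch.randn(4, 8, 16)               # (batch, seq, hidden)
t = x.transpose(-2, -1)                 # a view with swapped strides; no data is copied
print(t.is_contiguous())                # False: memory layout is no longer row-major
print(t.contiguous().is_contiguous())   # True: .contiguous() materializes a compact copy
print(t.contiguous().shape == t.shape)  # True: same shape, different memory layout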
50 changes: 26 additions & 24 deletions vllm/model_executor/models/jamba.py
@@ -105,9 +105,11 @@ def __init__(self,
                  layer_idx: int,
                  cache_config: Optional[CacheConfig] = None,
                  quant_config: Optional[QuantizationConfig] = None,
-                 prefix: str = "") -> None:
+                 is_lora_enabled: Optional[bool] = False,
+                 **kwargs) -> None:
         super().__init__()
         self.config = config
+        self.is_lora_enabled = is_lora_enabled
         self.mamba = MambaMixer(hidden_size= config.hidden_size,
                                 ssm_state_size = config.mamba_d_state,
                                 conv_kernel_size = config.mamba_d_conv,
@@ -118,7 +120,9 @@ def __init__(self,
                                 use_bias = config.mamba_proj_bias,
                                 use_rms_norm=True,
                                 rms_norm_eps=config.rms_norm_eps,
-                                activation=config.hidden_act)
+                                activation=config.hidden_act,
+                                is_lora_enabled = self.is_lora_enabled
+                                )
 
         num_experts = config.layers_num_experts[layer_idx]
         ffn_layer_class = JambaMoE if num_experts > 1 else JambaMLP
@@ -154,14 +158,13 @@ def forward(
 
 class JambaAttentionDecoderLayer(nn.Module):
 
-    def __init__(
-        self,
-        config: JambaConfig,
-        layer_idx: int,
-        cache_config: Optional[CacheConfig] = None,
-        quant_config: Optional[QuantizationConfig] = None,
-        prefix: str = "",
-    ) -> None:
+    def __init__(self,
+                 config: JambaConfig,
+                 layer_idx: int,
+                 cache_config: Optional[CacheConfig] = None,
+                 quant_config: Optional[QuantizationConfig] = None,
+                 prefix: str = "",
+                 **kwargs) -> None:
         super().__init__()
         self.hidden_size = config.hidden_size
         tp_size = get_tensor_model_parallel_world_size()
@@ -285,17 +288,18 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
             org_num_embeddings=config.vocab_size,
         )
 
+        extra_kwargs = {"is_lora_enabled": bool(vllm_config.lora_config)}
+
         def get_layer(prefix: str):
             layer_idx = int(prefix.rsplit(".", 1)[1])
             layer_class = ALL_DECODER_LAYER_TYPES[
                 config.layers_block_type[layer_idx]]
-            return layer_class(
-                config,
-                layer_idx,
-                cache_config,
-                quant_config=quant_config,
-                prefix=prefix,
-            )
+            return layer_class(config,
+                               layer_idx,
+                               cache_config,
+                               quant_config=quant_config,
+                               prefix=prefix,
+                               **extra_kwargs)
 
         self.start_layer, self.end_layer, self.layers = make_layers(
             config.num_hidden_layers, get_layer, prefix=f"{prefix}.layers")
@@ -369,14 +373,13 @@ class JambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
             "k_proj",
             "v_proj",
         ],
+        "in_proj": ["in_proj"],
     }
 
     # LoRA specific attributes
     supported_lora_modules = [
-        "qkv_proj",
-        "o_proj",
-        "embed_tokens",
-        "lm_head",
+        "qkv_proj", "o_proj", "embed_tokens", "lm_head", "up_proj",
+        "down_proj", "gate_proj", "out_proj", "in_proj", "x_proj"
     ]
     embedding_modules = {
         "embed_tokens": "input_embeddings",
@@ -421,9 +424,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         self.make_empty_intermediate_tensors = (
             self.model.make_empty_intermediate_tensors)
         if self.scheduler_config is not None and \
-            not self.model_config.enforce_eager:
+                not self.model_config.enforce_eager:
             if self.scheduler_config.max_num_seqs > \
-                vllm_config.compilation_config.max_capture_size:
+                    vllm_config.compilation_config.max_capture_size:
                 self.max_batch_size = \
                     vllm_config.compilation_config.max_capture_size
             else:
@@ -444,7 +447,6 @@ def forward(self,
                 inputs_embeds: Optional[torch.Tensor] = None,
                 **kwargs):
         if self.mamba_cache is None:
-
             num_mamba_layers = self.model_config.get_num_layers_by_block_type(
                 self.vllm_config.parallel_config, LayerBlockType.mamba)
             self.mamba_cache = MambaCacheManager(
14 changes: 10 additions & 4 deletions vllm/model_executor/models/mamba.py
@@ -38,10 +38,12 @@ class MambaDecoderLayer(nn.Module):
     def __init__(self,
                  config: MambaConfig,
                  cache_config: Optional[CacheConfig] = None,
-                 quant_config: Optional[QuantizationConfig] = None) -> None:
+                 quant_config: Optional[QuantizationConfig] = None,
+                 is_lora_enabled: Optional[bool] = False) -> None:
Collaborator comment: Maybe using vllm_config: VllmConfig would be better, rather than adding another arg.
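
For illustration, the suggestion could look roughly like the sketch below; this is not code from the PR, only a sketch of the reviewer's idea, and it uses the vllm_config.lora_config accessor that already appears elsewhere in this diff:

import torch.nn as nn
from transformers import MambaConfig

from vllm.config import VllmConfig


class MambaDecoderLayer(nn.Module):

    def __init__(self, config: MambaConfig, vllm_config: VllmConfig) -> None:
        super().__init__()
        self.config = config
        self.is_falcon_mamba = config.model_type == "falcon_mamba"
        # Derived from the global config instead of an extra constructor argument.
        self.is_lora_enabled = bool(vllm_config.lora_config)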

         super().__init__()
         self.config = config
         self.is_falcon_mamba = config.model_type == "falcon_mamba"
+        self.is_lora_enabled = is_lora_enabled
Collaborator comment: Will you implement LoRA support for MambaForCausalLM in this PR?

Contributor (author) reply: Since I had to make changes in MambaMixer, which is used by both Jamba and Mamba, I made the required changes for Mamba as well.

Collaborator (@jeejeelee, Dec 27, 2024): @DarkLight1337 I think this PR just wants to update the Jamba LoRA modules.

         mixer_rms_eps = config.mixer_rms_eps if self.is_falcon_mamba else None
         self.mixer = MambaMixer(hidden_size=config.hidden_size,
                                 ssm_state_size=config.state_size,
@@ -53,7 +55,8 @@ def __init__(self,
                                 use_rms_norm=self.is_falcon_mamba,
                                 rms_norm_has_weight=not self.is_falcon_mamba,
                                 rms_norm_eps=mixer_rms_eps,
-                                activation=config.hidden_act)
+                                activation=config.hidden_act,
+                                is_lora_enabled=self.is_lora_enabled)
 
         self.norm = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)
 
@@ -85,6 +88,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         cache_config = vllm_config.cache_config
         quant_config = vllm_config.quant_config
         lora_config = vllm_config.lora_config
+        is_lora_enabled = bool(lora_config)
 
         self.config = config
         self.padding_idx = config.pad_token_id
@@ -101,8 +105,10 @@
 
         self.start_layer, self.end_layer, self.layers = make_layers(
             config.num_hidden_layers,
-            lambda prefix: MambaDecoderLayer(
-                config, cache_config=cache_config, quant_config=quant_config),
+            lambda prefix: MambaDecoderLayer(config,
+                                             cache_config=cache_config,
+                                             quant_config=quant_config,
+                                             is_lora_enabled=is_lora_enabled),
             prefix=f"{prefix}.layers")
 
         self.norm_f = RMSNorm(config.hidden_size,