[MODEL] Update LoRA modules supported by Jamba #11209
New test file:

@@ -0,0 +1,54 @@
from typing import List

import pytest
import torch

import vllm
from vllm.lora.request import LoRARequest

MODEL_PATH = "ai21labs/AI21-Jamba-1.5-Mini"

MAX_TOKENS = 40


def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int,
              prompts: List[str]) -> List[str]:

    sampling_params = vllm.SamplingParams(temperature=0, max_tokens=MAX_TOKENS)
    outputs = llm.generate(
        prompts,
        sampling_params,
        lora_request=LoRARequest(str(lora_id), lora_id, lora_path)
        if lora_id else None)
    # Print the outputs.
    generated_texts: List[str] = []
    for output in outputs:
        prompt = output.prompt
        generated_text = output.outputs[0].text.strip()
        generated_texts.append(generated_text)
        print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
    return generated_texts


@pytest.mark.parametrize("tp_size", [4])
def test_jamba_lora(jamba_lora_files, tp_size):
    """Original test, the LoRA model has the common target modules, not all"""
    if torch.cuda.device_count() < tp_size:
        pytest.skip(f"Not enough GPUs for tensor parallelism {tp_size}")

    prompts = ["Write a story about a sheep and a goat."]

    llm = vllm.LLM(
        MODEL_PATH,
        enable_lora=True,
        max_num_seqs=16,
        max_loras=4,
        distributed_executor_backend="ray",
        tensor_parallel_size=tp_size,
    )

    expected_jamba_output = [
        """Once upon a time, in a lush green meadow, there lived a sheep named Clara and a goat named Billy. Clara was a gentle creature, always nibbling on the soft grass and humming"""  # noqa: E501
    ]
    assert do_sample(llm, jamba_lora_files, lora_id=1,
                     prompts=prompts) == expected_jamba_output
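
The jamba_lora_files argument is a pytest fixture defined outside this diff (presumably in a conftest.py that is not shown). A hypothetical sketch of what such a fixture might look like, assuming the adapter is fetched from the Hugging Face Hub; the repository id below is a placeholder, not taken from this PR:

# Hypothetical conftest sketch, not part of this PR: the jamba_lora_files
# fixture is assumed to return a local directory containing a Jamba LoRA
# adapter, which do_sample() passes to LoRARequest as lora_path.
import pytest
from huggingface_hub import snapshot_download


@pytest.fixture(scope="session")
def jamba_lora_files():
    # The repo id is a placeholder; substitute a real Jamba LoRA adapter.
    return snapshot_download(repo_id="<org>/<jamba-lora-adapter>")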
Changes to the Mamba model implementation:

@@ -38,10 +38,12 @@ class MambaDecoderLayer(nn.Module):
     def __init__(self,
                  config: MambaConfig,
                  cache_config: Optional[CacheConfig] = None,
-                 quant_config: Optional[QuantizationConfig] = None) -> None:
+                 quant_config: Optional[QuantizationConfig] = None,
+                 is_lora_enabled: Optional[bool] = False) -> None:
         super().__init__()
         self.config = config
         self.is_falcon_mamba = config.model_type == "falcon_mamba"
+        self.is_lora_enabled = is_lora_enabled
Review thread on this line:
- Will you implement LoRA support for Mamba?
- Since I had to make changes in MambaMixer, which is used by both Jamba and Mamba, I made the required changes for Mamba as well.
- @DarkLight1337 I think this PR just wants to update the Jamba LoRA modules.

         mixer_rms_eps = config.mixer_rms_eps if self.is_falcon_mamba else None
         self.mixer = MambaMixer(hidden_size=config.hidden_size,
                                 ssm_state_size=config.state_size,

@@ -53,7 +55,8 @@ def __init__(self,
                                 use_rms_norm=self.is_falcon_mamba,
                                 rms_norm_has_weight=not self.is_falcon_mamba,
                                 rms_norm_eps=mixer_rms_eps,
-                                activation=config.hidden_act)
+                                activation=config.hidden_act,
+                                is_lora_enabled=self.is_lora_enabled)

         self.norm = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)

@@ -85,6 +88,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         cache_config = vllm_config.cache_config
         quant_config = vllm_config.quant_config
         lora_config = vllm_config.lora_config
+        is_lora_enabled = bool(lora_config)

         self.config = config
         self.padding_idx = config.pad_token_id

@@ -101,8 +105,10 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):

         self.start_layer, self.end_layer, self.layers = make_layers(
             config.num_hidden_layers,
-            lambda prefix: MambaDecoderLayer(
-                config, cache_config=cache_config, quant_config=quant_config),
+            lambda prefix: MambaDecoderLayer(config,
+                                             cache_config=cache_config,
+                                             quant_config=quant_config,
+                                             is_lora_enabled=is_lora_enabled),
             prefix=f"{prefix}.layers")

         self.norm_f = RMSNorm(config.hidden_size,

Review comment: Maybe using vllm_config: VllmConfig would be better, rather than adding another arg.
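
A minimal sketch of the reviewer's suggestion, assuming the decoder layer would read the flag from the top-level config itself (illustrative only, not the code in this PR):

# Sketch of the suggested alternative: pass the whole VllmConfig into the
# decoder layer and derive the LoRA flag there, instead of threading a
# separate is_lora_enabled argument through the constructors.
from torch import nn
from transformers import MambaConfig

from vllm.config import VllmConfig


class MambaDecoderLayer(nn.Module):

    def __init__(self, config: MambaConfig, *,
                 vllm_config: VllmConfig) -> None:
        super().__init__()
        self.config = config
        # LoRA support is inferred from the top-level config; no extra
        # constructor argument is needed.
        self.is_lora_enabled = bool(vllm_config.lora_config)
        # cache_config and quant_config could likewise be read from
        # vllm_config.cache_config and vllm_config.quant_config.

Under this shape, the make_layers lambda in the model __init__ would pass vllm_config instead of the individual is_lora_enabled flag.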