[Misc][LoRA] Support Rank Stabilized LoRA (RSLoRA) #6909

Merged · 10 commits · Dec 31, 2024
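For context: rank-stabilized LoRA (rsLoRA, https://arxiv.org/abs/2312.03732) differs from standard LoRA only in the adapter scaling factor, lora_alpha / sqrt(r) instead of lora_alpha / r, which keeps the magnitude of the low-rank update stable as the rank grows. A minimal sketch of the rule this PR wires into PEFTHelper (the standalone helper below is illustrative, not vLLM API):

```python
import math

def lora_scaling(lora_alpha: float, r: int, use_rslora: bool = False) -> float:
    """Adapter scaling factor: lora_alpha / r for standard LoRA,
    lora_alpha / sqrt(r) for rank-stabilized LoRA (rsLoRA)."""
    if use_rslora:
        return lora_alpha / math.sqrt(r)
    return lora_alpha / r
```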
20 changes: 13 additions & 7 deletions tests/lora/test_lora_manager.py
@@ -1,4 +1,5 @@
import json
import math
import os
from typing import Dict, List

@@ -50,6 +51,18 @@ def test_peft_helper(sql_lora_files):
"embed_tokens",
"lm_head",
]
scaling = peft_helper.lora_alpha / peft_helper.r
assert abs(peft_helper.vllm_lora_scaling_factor - scaling) < 1e-3

# test RSLoRA
config = dict(r=8,
lora_alpha=16,
target_modules=["gate_proj"],
use_rslora=True)
peft_helper = PEFTHelper.from_dict(config)

scaling = peft_helper.lora_alpha / math.sqrt(peft_helper.r)
assert abs(peft_helper.vllm_lora_scaling_factor - scaling) < 1e-3

expected_error = "vLLM only supports modules_to_save being None."
with pytest.raises(ValueError, match=expected_error):
@@ -60,13 +73,6 @@ def test_peft_helper(sql_lora_files):
modules_to_save=["lm_head"],
)
PEFTHelper.from_dict(config)
expected_error = "vLLM does not yet support RSLoRA."
with pytest.raises(ValueError, match=expected_error):
config = dict(r=8,
lora_alpha=16,
target_modules=["gate_proj"],
use_rslora=True)
PEFTHelper.from_dict(config)

expected_error = "vLLM does not yet support DoRA."
with pytest.raises(ValueError, match=expected_error):
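With the config used in the rsLoRA test above (r=8, lora_alpha=16), the two formulas give clearly different factors, which is what the assertions pin down:

```python
>>> import math
>>> 16 / 8               # standard LoRA: lora_alpha / r
2.0
>>> 16 / math.sqrt(8)    # rsLoRA: lora_alpha / sqrt(r)
5.656854249492381
```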
12 changes: 3 additions & 9 deletions vllm/lora/lora.py
@@ -67,15 +67,9 @@ def from_config(
peft_helper: PEFTHelper,
embeddings_tensor: Optional[torch.Tensor] = None,
) -> "LoRALayerWeights":
return cls(
module_name,
peft_helper.r,
peft_helper.lora_alpha,
None,
None,
None,
embeddings_tensor,
)
return cls(module_name, peft_helper.r, peft_helper.lora_alpha, None,
None, None, embeddings_tensor,
peft_helper.vllm_lora_scaling_factor)

@classmethod
def create_dummy_lora_weights(
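The from_config change above forwards the precomputed peft_helper.vllm_lora_scaling_factor into LoRALayerWeights instead of letting the layer rederive it from lora_alpha / r, so standard and rsLoRA adapters share one code path. For reference, the scaling factor is the scalar applied to the low-rank update; a generic sketch of the math, not vLLM's actual kernel (tensor names and layouts are illustrative):

```python
import torch

def apply_lora(base_out: torch.Tensor, x: torch.Tensor,
               lora_a: torch.Tensor,   # [in_features, r]
               lora_b: torch.Tensor,   # [r, out_features]
               scaling: float) -> torch.Tensor:
    # y = W x + scaling * (x A) B, where scaling is lora_alpha / r,
    # or lora_alpha / sqrt(r) when the adapter was trained with rsLoRA.
    return base_out + scaling * (x @ lora_a @ lora_b)
```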
2 changes: 1 addition & 1 deletion vllm/lora/models.py
@@ -173,7 +173,7 @@ def from_lora_tensors(
return cls(lora_model_id,
peft_helper.r,
loras,
scaling_factor=peft_helper.vllm_scaling_factor)
scaling_factor=peft_helper.vllm_long_context_scaling_factor)

@classmethod
def from_local_checkpoint(
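Note the rename here: vllm_long_context_scaling_factor is the long-context scaling factor derived from the adapter's context_length, distinct from the new per-adapter vllm_lora_scaling_factor. A small sketch of how it is derived, following the __post_init__ logic shown below (the concrete numbers are illustrative):

```python
import math

context_length = 32768           # from the LoRA adapter config (illustrative)
max_position_embeddings = 4096   # from the base model (illustrative)

# Matches the rounding in PEFTHelper.__post_init__:
long_context_scaling_factor = float(
    math.ceil(context_length / max_position_embeddings))
assert long_context_scaling_factor == 8.0
```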
18 changes: 13 additions & 5 deletions vllm/lora/peft_helper.py
@@ -4,6 +4,8 @@
from dataclasses import MISSING, dataclass, field, fields
from typing import Literal, Optional, Union

from vllm.utils import print_info_once


@dataclass
class PEFTHelper:
@@ -14,21 +16,22 @@ class PEFTHelper:

bias: Literal["none", "all", "lora_only"] = field(default="none")
modules_to_save: Optional[list[str]] = field(default=None)
# True to use Rank-Stabilized LoRA (rsLoRA, see: https://arxiv.org/abs/2312.03732)
use_rslora: bool = field(default=False)
# True to use Weight-Decomposed Low-Rank Adaptation (DoRA, see: https://arxiv.org/abs/2402.09353)
use_dora: bool = field(default=False)
# long lora field
# long context lora field
context_length: int = field(default=0)
# Extra vllm field, start with 'vllm_' to avoid conflict
vllm_lora_scaling_factor: float = field(default=1.0)
vllm_max_position_embeddings: Optional[int] = field(default=False)
vllm_scaling_factor: Optional[float] = field(default=None)
vllm_long_context_scaling_factor: Optional[float] = field(default=None)

def _validate_features(self):
error_msg = []

if self.modules_to_save:
error_msg.append("vLLM only supports modules_to_save being None.")
if self.use_rslora:
error_msg.append("vLLM does not yet support RSLoRA.")

if self.use_dora:
error_msg.append("vLLM does not yet support DoRA.")
@@ -38,10 +41,15 @@ def _validate_features(self):

def __post_init__(self):
self._validate_features()
if self.use_rslora:
print_info_once("Loading LoRA weights trained with rsLoRA.")
self.vllm_lora_scaling_factor = self.lora_alpha / math.sqrt(self.r)
else:
self.vllm_lora_scaling_factor = self.lora_alpha / self.r
if self.context_length:
if self.vllm_max_position_embeddings is None:
self.vllm_max_position_embeddings = self.context_length
self.vllm_scaling_factor = float(
self.vllm_long_context_scaling_factor = float(
math.ceil(self.context_length /
self.vllm_max_position_embeddings))
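With this in place, a PEFT adapter config that sets use_rslora=True now loads instead of raising, and the resulting factor is exposed on the helper. A minimal usage sketch mirroring the updated test:

```python
from vllm.lora.peft_helper import PEFTHelper

config = dict(r=8, lora_alpha=16, target_modules=["gate_proj"], use_rslora=True)
peft_helper = PEFTHelper.from_dict(config)

# 16 / sqrt(8) ≈ 5.657, versus 16 / 8 = 2.0 without use_rslora
print(peft_helper.vllm_lora_scaling_factor)
```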
