[Misc][LoRA] Support Rank Stabilized LoRA (RSLoRA) (#6909)
Signed-off-by: Jee Jee Li <[email protected]>
Co-authored-by: Jee Jee Li <[email protected]>
JohnGiorgi and jeejeelee authored Dec 31, 2024
1 parent 74fa1d1 commit 82c49d3
Showing 4 changed files with 30 additions and 22 deletions.
20 changes: 13 additions & 7 deletions tests/lora/test_lora_manager.py
@@ -1,4 +1,5 @@
import json
import math
import os
from typing import Dict, List

@@ -50,6 +51,18 @@ def test_peft_helper(sql_lora_files):
"embed_tokens",
"lm_head",
]
scaling = peft_helper.lora_alpha / peft_helper.r
assert abs(peft_helper.vllm_lora_scaling_factor - scaling) < 1e-3

# test RSLoRA
config = dict(r=8,
lora_alpha=16,
target_modules=["gate_proj"],
use_rslora=True)
peft_helper = PEFTHelper.from_dict(config)

scaling = peft_helper.lora_alpha / math.sqrt(peft_helper.r)
assert abs(peft_helper.vllm_lora_scaling_factor - scaling) < 1e-3

expected_error = "vLLM only supports modules_to_save being None."
with pytest.raises(ValueError, match=expected_error):
@@ -60,13 +73,6 @@ def test_peft_helper(sql_lora_files):
modules_to_save=["lm_head"],
)
PEFTHelper.from_dict(config)
expected_error = "vLLM does not yet support RSLoRA."
with pytest.raises(ValueError, match=expected_error):
config = dict(r=8,
lora_alpha=16,
target_modules=["gate_proj"],
use_rslora=True)
PEFTHelper.from_dict(config)

expected_error = "vLLM does not yet support DoRA."
with pytest.raises(ValueError, match=expected_error):
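The new assertions pin down the two scaling conventions: classic LoRA uses lora_alpha / r, while rsLoRA (https://arxiv.org/abs/2312.03732) uses lora_alpha / sqrt(r). A minimal sketch, separate from the commit, that reproduces the arithmetic for the r=8, lora_alpha=16 config used in the test:

import math

r, lora_alpha = 8, 16                        # values from the test config above
classic_scaling = lora_alpha / r             # standard LoRA: 16 / 8 = 2.0
rslora_scaling = lora_alpha / math.sqrt(r)   # rsLoRA: 16 / sqrt(8) ≈ 5.657
print(classic_scaling, rslora_scaling)

The gap widens at higher ranks, which is the point of rank stabilization: the sqrt(r) denominator keeps the magnitude of the adapter update from shrinking as r grows.
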
12 changes: 3 additions & 9 deletions vllm/lora/lora.py
@@ -67,15 +67,9 @@ def from_config(
peft_helper: PEFTHelper,
embeddings_tensor: Optional[torch.Tensor] = None,
) -> "LoRALayerWeights":
return cls(
module_name,
peft_helper.r,
peft_helper.lora_alpha,
None,
None,
None,
embeddings_tensor,
)
return cls(module_name, peft_helper.r, peft_helper.lora_alpha, None,
None, None, embeddings_tensor,
peft_helper.vllm_lora_scaling_factor)

@classmethod
def create_dummy_lora_weights(
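LoRALayerWeights now receives the precomputed vllm_lora_scaling_factor instead of falling back to the default alpha / r. Conceptually, the factor multiplies the low-rank update at apply time; the sketch below is illustrative only (vLLM's layers apply this with optimized batched LoRA kernels, and the names here are not vLLM's):

import torch

def apply_lora(x: torch.Tensor, base_out: torch.Tensor,
               lora_a: torch.Tensor, lora_b: torch.Tensor,
               scaling: float) -> torch.Tensor:
    # base_out: output of the frozen base layer for input x
    # lora_a: (in_features, r), lora_b: (r, out_features)
    # With use_rslora, scaling = lora_alpha / sqrt(r); otherwise lora_alpha / r.
    return base_out + (x @ lora_a @ lora_b) * scaling
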
2 changes: 1 addition & 1 deletion vllm/lora/models.py
@@ -173,7 +173,7 @@ def from_lora_tensors(
return cls(lora_model_id,
peft_helper.r,
loras,
scaling_factor=peft_helper.vllm_scaling_factor)
scaling_factor=peft_helper.vllm_long_context_scaling_factor)

@classmethod
def from_local_checkpoint(
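The old vllm_scaling_factor field is renamed to vllm_long_context_scaling_factor to make clear it belongs to long-context LoRA adapters, not to the new rsLoRA weight scaling. A worked example under assumed values, mirroring the __post_init__ logic shown in peft_helper.py below:

import math

context_length = 16384            # hypothetical value from the adapter config
max_position_embeddings = 4096    # hypothetical base-model limit
scaling_factor = float(math.ceil(context_length / max_position_embeddings))
# scaling_factor == 4.0, stored as vllm_long_context_scaling_factor
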
18 changes: 13 additions & 5 deletions vllm/lora/peft_helper.py
@@ -4,6 +4,8 @@
from dataclasses import MISSING, dataclass, field, fields
from typing import Literal, Optional, Union

from vllm.utils import print_info_once


@dataclass
class PEFTHelper:
@@ -14,21 +16,22 @@ class PEFTHelper:

bias: Literal["none", "all", "lora_only"] = field(default="none")
modules_to_save: Optional[list[str]] = field(default=None)
# True to use Rank-Stabilized LoRA (rsLoRA, see: https://arxiv.org/abs/2312.03732)
use_rslora: bool = field(default=False)
# True to use Weight-Decomposed Low-Rank Adaptation (DoRA, see: https://arxiv.org/abs/2402.09353)
use_dora: bool = field(default=False)
# long lora field
# long context lora field
context_length: int = field(default=0)
# Extra vllm field, start with 'vllm_' to avoid conflict
vllm_lora_scaling_factor: float = field(default=1.0)
vllm_max_position_embeddings: Optional[int] = field(default=False)
vllm_scaling_factor: Optional[float] = field(default=None)
vllm_long_context_scaling_factor: Optional[float] = field(default=None)

def _validate_features(self):
error_msg = []

if self.modules_to_save:
error_msg.append("vLLM only supports modules_to_save being None.")
if self.use_rslora:
error_msg.append("vLLM does not yet support RSLoRA.")

if self.use_dora:
error_msg.append("vLLM does not yet support DoRA.")
@@ -38,10 +41,15 @@ def _validate_features(self):

def __post_init__(self):
self._validate_features()
if self.use_rslora:
print_info_once("Loading LoRA weights trained with rsLoRA.")
self.vllm_lora_scaling_factor = self.lora_alpha / math.sqrt(self.r)
else:
self.vllm_lora_scaling_factor = self.lora_alpha / self.r
if self.context_length:
if self.vllm_max_position_embeddings is None:
self.vllm_max_position_embeddings = self.context_length
self.vllm_scaling_factor = float(
self.vllm_long_context_scaling_factor = float(
math.ceil(self.context_length /
self.vllm_max_position_embeddings))

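With the restriction removed, an adapter whose adapter_config.json sets use_rslora to true loads like any other LoRA, and the corrected scaling is picked up from the config automatically. A usage sketch with placeholder model and adapter paths (not part of this commit):

from vllm import LLM, SamplingParams
from vllm.lora.request import LoRARequest

llm = LLM(model="meta-llama/Llama-2-7b-hf", enable_lora=True)  # placeholder base model

outputs = llm.generate(
    ["Write a SQL query that counts users by country."],
    SamplingParams(temperature=0.0, max_tokens=64),
    # Placeholder adapter directory; its adapter_config.json was saved with use_rslora=true.
    lora_request=LoRARequest("rslora_adapter", 1, "/path/to/rslora_adapter"),
)
print(outputs[0].outputs[0].text)
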
