From e308f378f391ef05b61cf6a6eb93dc540f6d8f71 Mon Sep 17 00:00:00 2001
From: John Giorgi <john@abridge.com>
Date: Mon, 29 Jul 2024 15:05:47 -0400
Subject: [PATCH 1/8] feat: support rslora

---
 vllm/lora/models.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/vllm/lora/models.py b/vllm/lora/models.py
index e1ede7d4d710a..144d6a6d3d834 100644
--- a/vllm/lora/models.py
+++ b/vllm/lora/models.py
@@ -368,7 +368,7 @@ def from_local_checkpoint(
             embeddings = torch.load(new_embeddings_bin_file_path)
 
         rank = config["r"]
-        lora_alpha = config["lora_alpha"]
+        lora_alpha = config["lora_alpha"] * math.sqrt(rank) if config["use_rslora"] else config["lora_alpha"]
         context_length = config.get("context_length", None)
         scaling_factor = None
         if context_length:

From 6690e8b1bda1b615c070c31cabaaaa2e6917b569 Mon Sep 17 00:00:00 2001
From: John Giorgi <john@abridge.com>
Date: Mon, 29 Jul 2024 15:15:18 -0400
Subject: [PATCH 2/8] fix: lint the codebase

---
 vllm/lora/models.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/vllm/lora/models.py b/vllm/lora/models.py
index 144d6a6d3d834..20f62a0eace36 100644
--- a/vllm/lora/models.py
+++ b/vllm/lora/models.py
@@ -368,7 +368,8 @@ def from_local_checkpoint(
             embeddings = torch.load(new_embeddings_bin_file_path)
 
         rank = config["r"]
-        lora_alpha = config["lora_alpha"] * math.sqrt(rank) if config["use_rslora"] else config["lora_alpha"]
+        lora_alpha = config["lora_alpha"] * math.sqrt(
+            rank) if config["use_rslora"] else config["lora_alpha"]
         context_length = config.get("context_length", None)
         scaling_factor = None
         if context_length:

From e2cf4205845a398448f6d3649e133954bb181929 Mon Sep 17 00:00:00 2001
From: John Giorgi <john@abridge.com>
Date: Mon, 29 Jul 2024 16:54:19 -0400
Subject: [PATCH 3/8] fix: default to not using rslora

---
 vllm/lora/models.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/vllm/lora/models.py b/vllm/lora/models.py
index 20f62a0eace36..ce8fd65a8d5da 100644
--- a/vllm/lora/models.py
+++ b/vllm/lora/models.py
@@ -368,8 +368,8 @@ def from_local_checkpoint(
             embeddings = torch.load(new_embeddings_bin_file_path)
 
         rank = config["r"]
-        lora_alpha = config["lora_alpha"] * math.sqrt(
-            rank) if config["use_rslora"] else config["lora_alpha"]
+        lora_alpha = config["lora_alpha"] * math.sqrt(rank) if config.get(
+            "use_rslora", False) else config["lora_alpha"]
         context_length = config.get("context_length", None)
         scaling_factor = None
         if context_length:

From a3206bba5483087bba89bb9cbce6ccf6af59a9a7 Mon Sep 17 00:00:00 2001
From: John Giorgi <john@abridge.com>
Date: Mon, 23 Dec 2024 09:38:44 -0500
Subject: [PATCH 4/8] fix: move rslora scaling to peft_helper

---
 vllm/lora/peft_helper.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/vllm/lora/peft_helper.py b/vllm/lora/peft_helper.py
index edf4ba5659575..e487ffe13fd2c 100644
--- a/vllm/lora/peft_helper.py
+++ b/vllm/lora/peft_helper.py
@@ -38,6 +38,8 @@ def _validate_features(self):
 
     def __post_init__(self):
         self._validate_features()
+        if self.use_rslora:
+            self.lora_alpha = self.lora_alpha * math.sqrt(self.r)
         if self.context_length:
             if self.vllm_max_position_embeddings is None:
                 self.vllm_max_position_embeddings = self.context_length

From 42e04aab9835ab2f55777dcee9e4013ae9d41f1e Mon Sep 17 00:00:00 2001
From: John Giorgi <john@abridge.com>
Date: Mon, 23 Dec 2024 09:43:01 -0500
Subject: [PATCH 5/8] fix: set scaling factor directly, instead of modifying alpha

---
 vllm/lora/peft_helper.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/vllm/lora/peft_helper.py b/vllm/lora/peft_helper.py
index e487ffe13fd2c..3c79ae5f65873 100644
--- a/vllm/lora/peft_helper.py
+++ b/vllm/lora/peft_helper.py
@@ -39,7 +39,7 @@ def _validate_features(self):
     def __post_init__(self):
         self._validate_features()
         if self.use_rslora:
-            self.lora_alpha = self.lora_alpha * math.sqrt(self.r)
+            self.vllm_scaling_factor = self.lora_alpha / math.sqrt(self.r)
         if self.context_length:
             if self.vllm_max_position_embeddings is None:
                 self.vllm_max_position_embeddings = self.context_length
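For reference, the math behind patches 1-5: standard LoRA scales the adapter
update B @ A by lora_alpha / r, while rsLoRA (https://arxiv.org/abs/2312.03732)
scales it by lora_alpha / sqrt(r), which keeps the update magnitude stable as
the rank r grows. Patches 1-3 reach the rsLoRA factor indirectly, by inflating
lora_alpha with sqrt(r) so the downstream alpha / r division works out; patch 5
computes the factor directly. A minimal sketch of the equivalence, in plain
Python with no vLLM dependencies:

    import math

    r, lora_alpha = 8, 16

    # Patches 1-3: inflate alpha so the usual alpha / r division
    # downstream yields the rsLoRA scaling.
    scaling_via_inflated_alpha = (lora_alpha * math.sqrt(r)) / r

    # Patch 5: compute the rsLoRA scaling directly.
    scaling_direct = lora_alpha / math.sqrt(r)

    # Both equal alpha / sqrt(r) (= 16 / sqrt(8), about 5.657).
    assert abs(scaling_via_inflated_alpha - scaling_direct) < 1e-9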
From f73a3db7857ac0b26723e5167d65e5adb7fac49d Mon Sep 17 00:00:00 2001
From: John Giorgi <john@abridge.com>
Date: Mon, 23 Dec 2024 09:50:04 -0500
Subject: [PATCH 6/8] fix: remove error message about RSLoRA not being supported

---
 vllm/lora/peft_helper.py | 2 --
 1 file changed, 2 deletions(-)

diff --git a/vllm/lora/peft_helper.py b/vllm/lora/peft_helper.py
index 3c79ae5f65873..5ddf8abdf0c2d 100644
--- a/vllm/lora/peft_helper.py
+++ b/vllm/lora/peft_helper.py
@@ -27,8 +27,6 @@ def _validate_features(self):
 
         if self.modules_to_save:
             error_msg.append("vLLM only supports modules_to_save being None.")
-        if self.use_rslora:
-            error_msg.append("vLLM does not yet support RSLoRA.")
 
         if self.use_dora:
             error_msg.append("vLLM does not yet support DoRA.")

From 419a6fb4a477165234bd68fd39ed3cf990b20bb7 Mon Sep 17 00:00:00 2001
From: John Giorgi <john@abridge.com>
Date: Mon, 23 Dec 2024 09:55:15 -0500
Subject: [PATCH 7/8] docs: add comments with arxiv links for rsLoRA and DoRA

---
 vllm/lora/peft_helper.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/vllm/lora/peft_helper.py b/vllm/lora/peft_helper.py
index 5ddf8abdf0c2d..820ebc94016ea 100644
--- a/vllm/lora/peft_helper.py
+++ b/vllm/lora/peft_helper.py
@@ -14,7 +14,9 @@ class PEFTHelper:
     bias: Literal["none", "all", "lora_only"] = field(default="none")
     modules_to_save: Optional[list[str]] = field(default=None)
 
+    # True to use Rank-Stabilized LoRA (rsLoRA, see: https://arxiv.org/abs/2312.03732)
     use_rslora: bool = field(default=False)
+    # True to use Weight-Decomposed Low-Rank Adaptation (DoRA, see: https://arxiv.org/abs/2402.09353)
     use_dora: bool = field(default=False)
     # long lora field
     context_length: int = field(default=0)
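The final patch below separates two unrelated quantities that patch 5 briefly
stored in the same vllm_scaling_factor field: the per-adapter weight scaling
(lora_alpha / r, or lora_alpha / sqrt(r) under rsLoRA) and the long-context
scaling factor (how far the adapter's trained context extends past the base
model's maximum position embeddings). A standalone sketch of the two
computations, mirroring (not quoting) the PEFTHelper logic in this series:

    import math
    from typing import Optional, Tuple

    def compute_scaling_factors(
            r: int, lora_alpha: float, use_rslora: bool,
            context_length: int,
            max_position_embeddings: int) -> Tuple[float, Optional[float]]:
        # Per-adapter scaling applied to the LoRA delta (B @ A).
        if use_rslora:
            lora_scaling = lora_alpha / math.sqrt(r)
        else:
            lora_scaling = lora_alpha / r
        # Long-context scaling: ratio of the adapter's trained context
        # length to the base model's maximum position embeddings.
        long_context_scaling: Optional[float] = None
        if context_length:
            long_context_scaling = float(
                math.ceil(context_length / max_position_embeddings))
        return lora_scaling, long_context_scaling

    # e.g. r=8, alpha=16, rsLoRA, a 32k adapter on a 16k base model:
    print(compute_scaling_factors(8, 16, True, 32768, 16384))
    # -> (5.656854249492381, 2.0)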
From 7ace76c7ae4132615e63c18d2ef6965879ea10e5 Mon Sep 17 00:00:00 2001
From: Jee Jee Li <pandaleefree@gmail.com>
Date: Tue, 24 Dec 2024 08:15:54 +0000
Subject: [PATCH 8/8] Done

Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>
---
 tests/lora/test_lora_manager.py | 20 +++++++++++++-------
 vllm/lora/lora.py               | 12 +++---------
 vllm/lora/models.py             |  2 +-
 vllm/lora/peft_helper.py        | 14 ++++++++++----
 4 files changed, 27 insertions(+), 21 deletions(-)

diff --git a/tests/lora/test_lora_manager.py b/tests/lora/test_lora_manager.py
index 0b76f466702fc..a099f36b0a465 100644
--- a/tests/lora/test_lora_manager.py
+++ b/tests/lora/test_lora_manager.py
@@ -1,4 +1,5 @@
 import json
+import math
 import os
 from typing import Dict, List
 
@@ -50,6 +51,18 @@ def test_peft_helper(sql_lora_files):
         "embed_tokens",
         "lm_head",
     ]
+    scaling = peft_helper.lora_alpha / peft_helper.r
+    assert abs(peft_helper.vllm_lora_scaling_factor - scaling) < 1e-3
+
+    # test RSLoRA
+    config = dict(r=8,
+                  lora_alpha=16,
+                  target_modules=["gate_proj"],
+                  use_rslora=True)
+    peft_helper = PEFTHelper.from_dict(config)
+
+    scaling = peft_helper.lora_alpha / math.sqrt(peft_helper.r)
+    assert abs(peft_helper.vllm_lora_scaling_factor - scaling) < 1e-3
 
     expected_error = "vLLM only supports modules_to_save being None."
     with pytest.raises(ValueError, match=expected_error):
@@ -60,13 +73,6 @@ def test_peft_helper(sql_lora_files):
             modules_to_save=["lm_head"],
         )
         PEFTHelper.from_dict(config)
-    expected_error = "vLLM does not yet support RSLoRA."
-    with pytest.raises(ValueError, match=expected_error):
-        config = dict(r=8,
-                      lora_alpha=16,
-                      target_modules=["gate_proj"],
-                      use_rslora=True)
-        PEFTHelper.from_dict(config)
 
     expected_error = "vLLM does not yet support DoRA."
     with pytest.raises(ValueError, match=expected_error):
diff --git a/vllm/lora/lora.py b/vllm/lora/lora.py
index dde347b78bf81..93ad4651f4b77 100644
--- a/vllm/lora/lora.py
+++ b/vllm/lora/lora.py
@@ -67,15 +67,9 @@ def from_config(
         peft_helper: PEFTHelper,
         embeddings_tensor: Optional[torch.Tensor] = None,
     ) -> "LoRALayerWeights":
-        return cls(
-            module_name,
-            peft_helper.r,
-            peft_helper.lora_alpha,
-            None,
-            None,
-            None,
-            embeddings_tensor,
-        )
+        return cls(module_name, peft_helper.r, peft_helper.lora_alpha, None,
+                   None, None, embeddings_tensor,
+                   peft_helper.vllm_lora_scaling_factor)
 
     @classmethod
     def create_dummy_lora_weights(
diff --git a/vllm/lora/models.py b/vllm/lora/models.py
index 70806a77b9fff..b60da8bbfdbb0 100644
--- a/vllm/lora/models.py
+++ b/vllm/lora/models.py
@@ -172,7 +172,7 @@ def from_lora_tensors(
         return cls(lora_model_id,
                    peft_helper.r,
                    loras,
-                   scaling_factor=peft_helper.vllm_scaling_factor)
+                   scaling_factor=peft_helper.vllm_long_context_scaling_factor)
 
     @classmethod
     def from_local_checkpoint(
diff --git a/vllm/lora/peft_helper.py b/vllm/lora/peft_helper.py
index 820ebc94016ea..ddd42ae93d290 100644
--- a/vllm/lora/peft_helper.py
+++ b/vllm/lora/peft_helper.py
@@ -4,6 +4,8 @@
 from dataclasses import MISSING, dataclass, field, fields
 from typing import Literal, Optional, Union
 
+from vllm.utils import print_info_once
+
 
 @dataclass
 class PEFTHelper:
@@ -18,11 +20,12 @@ class PEFTHelper:
     use_rslora: bool = field(default=False)
     # True to use Weight-Decomposed Low-Rank Adaptation (DoRA, see: https://arxiv.org/abs/2402.09353)
     use_dora: bool = field(default=False)
-    # long lora field
+    # long context lora field
    context_length: int = field(default=0)
     # Extra vllm field, start with 'vllm_' to avoid conflict
+    vllm_lora_scaling_factor: float = field(default=1.0)
     vllm_max_position_embeddings: Optional[int] = field(default=False)
-    vllm_scaling_factor: Optional[float] = field(default=None)
+    vllm_long_context_scaling_factor: Optional[float] = field(default=None)
 
     def _validate_features(self):
         error_msg = []
@@ -39,11 +42,14 @@ def __post_init__(self):
         self._validate_features()
         if self.use_rslora:
-            self.vllm_scaling_factor = self.lora_alpha / math.sqrt(self.r)
+            print_info_once("Loading LoRA weights trained with rsLoRA.")
+            self.vllm_lora_scaling_factor = self.lora_alpha / math.sqrt(self.r)
+        else:
+            self.vllm_lora_scaling_factor = self.lora_alpha / self.r
         if self.context_length:
             if self.vllm_max_position_embeddings is None:
                 self.vllm_max_position_embeddings = self.context_length
-            self.vllm_scaling_factor = float(
+            self.vllm_long_context_scaling_factor = float(
                 math.ceil(self.context_length /
                           self.vllm_max_position_embeddings))
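End to end, the behavior added by this series can be exercised the same way
the updated test does. A minimal sketch, assuming the import path used in the
diffs above:

    import math

    from vllm.lora.peft_helper import PEFTHelper

    # A config shaped like a PEFT adapter_config.json with rsLoRA enabled.
    config = dict(r=8,
                  lora_alpha=16,
                  target_modules=["gate_proj"],
                  use_rslora=True)
    peft_helper = PEFTHelper.from_dict(config)

    # rsLoRA scaling: alpha / sqrt(r) = 16 / sqrt(8), about 5.657.
    assert abs(peft_helper.vllm_lora_scaling_factor -
               peft_helper.lora_alpha / math.sqrt(peft_helper.r)) < 1e-3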