From e308f378f391ef05b61cf6a6eb93dc540f6d8f71 Mon Sep 17 00:00:00 2001
From: John Giorgi <john@abridge.com>
Date: Mon, 29 Jul 2024 15:05:47 -0400
Subject: [PATCH 1/8] feat: support rslora

---
 vllm/lora/models.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/vllm/lora/models.py b/vllm/lora/models.py
index e1ede7d4d710a..144d6a6d3d834 100644
--- a/vllm/lora/models.py
+++ b/vllm/lora/models.py
@@ -368,7 +368,7 @@ def from_local_checkpoint(
             embeddings = torch.load(new_embeddings_bin_file_path)
 
         rank = config["r"]
-        lora_alpha = config["lora_alpha"]
+        lora_alpha = config["lora_alpha"] * math.sqrt(rank) if config["use_rslora"] else config["lora_alpha"]
         context_length = config.get("context_length", None)
         scaling_factor = None
         if context_length:
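
For context: rsLoRA (https://arxiv.org/abs/2312.03732) replaces LoRA's usual scaling of lora_alpha / r with lora_alpha / sqrt(r). At this point in the series the loader only forwards lora_alpha, and the per-layer scaling is still computed downstream as lora_alpha divided by the rank, so pre-multiplying alpha by sqrt(r) reproduces the rsLoRA value. A minimal sketch of that arithmetic (illustration only, not part of the patch):

    import math

    r, lora_alpha = 8, 16
    standard_scaling = lora_alpha / r            # 2.0, what was computed before
    rslora_scaling = lora_alpha / math.sqrt(r)   # ~5.66, what rsLoRA calls for
    adjusted_alpha = lora_alpha * math.sqrt(r)   # what this patch stores instead
    # (alpha * sqrt(r)) / r == alpha / sqrt(r)
    assert math.isclose(adjusted_alpha / r, rslora_scaling)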

From 6690e8b1bda1b615c070c31cabaaaa2e6917b569 Mon Sep 17 00:00:00 2001
From: John Giorgi <john@abridge.com>
Date: Mon, 29 Jul 2024 15:15:18 -0400
Subject: [PATCH 2/8] fix: lint the codebase

---
 vllm/lora/models.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/vllm/lora/models.py b/vllm/lora/models.py
index 144d6a6d3d834..20f62a0eace36 100644
--- a/vllm/lora/models.py
+++ b/vllm/lora/models.py
@@ -368,7 +368,8 @@ def from_local_checkpoint(
             embeddings = torch.load(new_embeddings_bin_file_path)
 
         rank = config["r"]
-        lora_alpha = config["lora_alpha"] * math.sqrt(rank) if config["use_rslora"] else config["lora_alpha"]
+        lora_alpha = config["lora_alpha"] * math.sqrt(
+            rank) if config["use_rslora"] else config["lora_alpha"]
         context_length = config.get("context_length", None)
         scaling_factor = None
         if context_length:

From e2cf4205845a398448f6d3649e133954bb181929 Mon Sep 17 00:00:00 2001
From: John Giorgi <john@abridge.com>
Date: Mon, 29 Jul 2024 16:54:19 -0400
Subject: [PATCH 3/8] fix: default to not using rslora

---
 vllm/lora/models.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/vllm/lora/models.py b/vllm/lora/models.py
index 20f62a0eace36..ce8fd65a8d5da 100644
--- a/vllm/lora/models.py
+++ b/vllm/lora/models.py
@@ -368,8 +368,8 @@ def from_local_checkpoint(
             embeddings = torch.load(new_embeddings_bin_file_path)
 
         rank = config["r"]
-        lora_alpha = config["lora_alpha"] * math.sqrt(
-            rank) if config["use_rslora"] else config["lora_alpha"]
+        lora_alpha = config["lora_alpha"] * math.sqrt(rank) if config.get(
+            "use_rslora", False) else config["lora_alpha"]
         context_length = config.get("context_length", None)
         scaling_factor = None
         if context_length:
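
Adapter configs written before the rsLoRA option existed may simply omit the use_rslora key, so indexing it directly would raise a KeyError; config.get with a False default keeps those adapters loading as plain LoRA. Illustration only, with a hypothetical minimal config:

    config = {"r": 8, "lora_alpha": 16}            # no "use_rslora" key present
    use_rslora = config.get("use_rslora", False)   # falls back to plain LoRA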

From a3206bba5483087bba89bb9cbce6ccf6af59a9a7 Mon Sep 17 00:00:00 2001
From: John Giorgi <john@abridge.com>
Date: Mon, 23 Dec 2024 09:38:44 -0500
Subject: [PATCH 4/8] fix: move rslora scaling to peft_helper

---
 vllm/lora/peft_helper.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/vllm/lora/peft_helper.py b/vllm/lora/peft_helper.py
index edf4ba5659575..e487ffe13fd2c 100644
--- a/vllm/lora/peft_helper.py
+++ b/vllm/lora/peft_helper.py
@@ -38,6 +38,8 @@ def _validate_features(self):
 
     def __post_init__(self):
         self._validate_features()
+        if self.use_rslora:
+            self.lora_alpha = self.lora_alpha * math.sqrt(self.r)
         if self.context_length:
             if self.vllm_max_position_embeddings is None:
                 self.vllm_max_position_embeddings = self.context_length

From 42e04aab9835ab2f55777dcee9e4013ae9d41f1e Mon Sep 17 00:00:00 2001
From: John Giorgi <john@abridge.com>
Date: Mon, 23 Dec 2024 09:43:01 -0500
Subject: [PATCH 5/8] fix: set scaling factor directly, instead of modifying
 alpha

---
 vllm/lora/peft_helper.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/vllm/lora/peft_helper.py b/vllm/lora/peft_helper.py
index e487ffe13fd2c..3c79ae5f65873 100644
--- a/vllm/lora/peft_helper.py
+++ b/vllm/lora/peft_helper.py
@@ -39,7 +39,7 @@ def _validate_features(self):
     def __post_init__(self):
         self._validate_features()
         if self.use_rslora:
-            self.lora_alpha = self.lora_alpha * math.sqrt(self.r)
+            self.vllm_scaling_factor = self.lora_alpha / math.sqrt(self.r)
         if self.context_length:
             if self.vllm_max_position_embeddings is None:
                 self.vllm_max_position_embeddings = self.context_length
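
Together, patches 4 and 5 compute the rsLoRA factor once in PEFTHelper.__post_init__ and store it in a separate vllm_-prefixed field instead of silently rewriting the lora_alpha that came from the adapter config. A simplified sketch of that dataclass pattern (field names trimmed; not the real class):

    import math
    from dataclasses import dataclass
    from typing import Optional

    @dataclass
    class _PEFTSketch:                          # stand-in for PEFTHelper
        r: int
        lora_alpha: float
        use_rslora: bool = False
        vllm_scaling_factor: Optional[float] = None

        def __post_init__(self):
            # Derived value is computed once when the adapter config is parsed;
            # lora_alpha itself is left untouched.
            if self.use_rslora:
                self.vllm_scaling_factor = self.lora_alpha / math.sqrt(self.r)

    print(_PEFTSketch(r=8, lora_alpha=16, use_rslora=True).vllm_scaling_factor)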

From f73a3db7857ac0b26723e5167d65e5adb7fac49d Mon Sep 17 00:00:00 2001
From: John Giorgi <john@abridge.com>
Date: Mon, 23 Dec 2024 09:50:04 -0500
Subject: [PATCH 6/8] fix: remove error message about RSLoRA not being
 supported

---
 vllm/lora/peft_helper.py | 2 --
 1 file changed, 2 deletions(-)

diff --git a/vllm/lora/peft_helper.py b/vllm/lora/peft_helper.py
index 3c79ae5f65873..5ddf8abdf0c2d 100644
--- a/vllm/lora/peft_helper.py
+++ b/vllm/lora/peft_helper.py
@@ -27,8 +27,6 @@ def _validate_features(self):
 
         if self.modules_to_save:
             error_msg.append("vLLM only supports modules_to_save being None.")
-        if self.use_rslora:
-            error_msg.append("vLLM does not yet support RSLoRA.")
 
         if self.use_dora:
             error_msg.append("vLLM does not yet support DoRA.")

From 419a6fb4a477165234bd68fd39ed3cf990b20bb7 Mon Sep 17 00:00:00 2001
From: John Giorgi <john@abridge.com>
Date: Mon, 23 Dec 2024 09:55:15 -0500
Subject: [PATCH 7/8] docs: add comments with arxiv links for rsLoRA and DoRA

---
 vllm/lora/peft_helper.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/vllm/lora/peft_helper.py b/vllm/lora/peft_helper.py
index 5ddf8abdf0c2d..820ebc94016ea 100644
--- a/vllm/lora/peft_helper.py
+++ b/vllm/lora/peft_helper.py
@@ -14,7 +14,9 @@ class PEFTHelper:
 
     bias: Literal["none", "all", "lora_only"] = field(default="none")
     modules_to_save: Optional[list[str]] = field(default=None)
+    # True to use Rank-Stabilized LoRA (rsLoRA, see: https://arxiv.org/abs/2312.03732)
     use_rslora: bool = field(default=False)
+    # True to use Weight-Decomposed Low-Rank Adaptation (DoRA, see: https://arxiv.org/abs/2402.09353)
     use_dora: bool = field(default=False)
     # long lora field
     context_length: int = field(default=0)

From 7ace76c7ae4132615e63c18d2ef6965879ea10e5 Mon Sep 17 00:00:00 2001
From: Jee Jee Li <pandaleefree@gmail.com>
Date: Tue, 24 Dec 2024 08:15:54 +0000
Subject: [PATCH 8/8] Done

Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>
---
 tests/lora/test_lora_manager.py | 20 +++++++++++++-------
 vllm/lora/lora.py               | 12 +++---------
 vllm/lora/models.py             |  2 +-
 vllm/lora/peft_helper.py        | 14 ++++++++++----
 4 files changed, 27 insertions(+), 21 deletions(-)

diff --git a/tests/lora/test_lora_manager.py b/tests/lora/test_lora_manager.py
index 0b76f466702fc..a099f36b0a465 100644
--- a/tests/lora/test_lora_manager.py
+++ b/tests/lora/test_lora_manager.py
@@ -1,4 +1,5 @@
 import json
+import math
 import os
 from typing import Dict, List
 
@@ -50,6 +51,18 @@ def test_peft_helper(sql_lora_files):
         "embed_tokens",
         "lm_head",
     ]
+    scaling = peft_helper.lora_alpha / peft_helper.r
+    assert abs(peft_helper.vllm_lora_scaling_factor - scaling) < 1e-3
+
+    # test RSLoRA
+    config = dict(r=8,
+                  lora_alpha=16,
+                  target_modules=["gate_proj"],
+                  use_rslora=True)
+    peft_helper = PEFTHelper.from_dict(config)
+
+    scaling = peft_helper.lora_alpha / math.sqrt(peft_helper.r)
+    assert abs(peft_helper.vllm_lora_scaling_factor - scaling) < 1e-3
 
     expected_error = "vLLM only supports modules_to_save being None."
     with pytest.raises(ValueError, match=expected_error):
@@ -60,13 +73,6 @@ def test_peft_helper(sql_lora_files):
             modules_to_save=["lm_head"],
         )
         PEFTHelper.from_dict(config)
-    expected_error = "vLLM does not yet support RSLoRA."
-    with pytest.raises(ValueError, match=expected_error):
-        config = dict(r=8,
-                      lora_alpha=16,
-                      target_modules=["gate_proj"],
-                      use_rslora=True)
-        PEFTHelper.from_dict(config)
 
     expected_error = "vLLM does not yet support DoRA."
     with pytest.raises(ValueError, match=expected_error):
diff --git a/vllm/lora/lora.py b/vllm/lora/lora.py
index dde347b78bf81..93ad4651f4b77 100644
--- a/vllm/lora/lora.py
+++ b/vllm/lora/lora.py
@@ -67,15 +67,9 @@ def from_config(
         peft_helper: PEFTHelper,
         embeddings_tensor: Optional[torch.Tensor] = None,
     ) -> "LoRALayerWeights":
-        return cls(
-            module_name,
-            peft_helper.r,
-            peft_helper.lora_alpha,
-            None,
-            None,
-            None,
-            embeddings_tensor,
-        )
+        return cls(module_name, peft_helper.r, peft_helper.lora_alpha, None,
+                   None, None, embeddings_tensor,
+                   peft_helper.vllm_lora_scaling_factor)
 
     @classmethod
     def create_dummy_lora_weights(
diff --git a/vllm/lora/models.py b/vllm/lora/models.py
index 70806a77b9fff..b60da8bbfdbb0 100644
--- a/vllm/lora/models.py
+++ b/vllm/lora/models.py
@@ -172,7 +172,7 @@ def from_lora_tensors(
         return cls(lora_model_id,
                    peft_helper.r,
                    loras,
-                   scaling_factor=peft_helper.vllm_scaling_factor)
+                   scaling_factor=peft_helper.vllm_long_context_scaling_factor)
 
     @classmethod
     def from_local_checkpoint(
diff --git a/vllm/lora/peft_helper.py b/vllm/lora/peft_helper.py
index 820ebc94016ea..ddd42ae93d290 100644
--- a/vllm/lora/peft_helper.py
+++ b/vllm/lora/peft_helper.py
@@ -4,6 +4,8 @@
 from dataclasses import MISSING, dataclass, field, fields
 from typing import Literal, Optional, Union
 
+from vllm.utils import print_info_once
+
 
 @dataclass
 class PEFTHelper:
@@ -18,11 +20,12 @@ class PEFTHelper:
     use_rslora: bool = field(default=False)
     # True to use Weight-Decomposed Low-Rank Adaptation (DoRA, see: https://arxiv.org/abs/2402.09353)
     use_dora: bool = field(default=False)
-    # long lora field
+    # long context lora field
     context_length: int = field(default=0)
     # Extra vllm field, start with 'vllm_' to avoid conflict
+    vllm_lora_scaling_factor: float = field(default=1.0)
     vllm_max_position_embeddings: Optional[int] = field(default=False)
-    vllm_scaling_factor: Optional[float] = field(default=None)
+    vllm_long_context_scaling_factor: Optional[float] = field(default=None)
 
     def _validate_features(self):
         error_msg = []
@@ -39,11 +42,14 @@ def _validate_features(self):
     def __post_init__(self):
         self._validate_features()
         if self.use_rslora:
-            self.vllm_scaling_factor = self.lora_alpha / math.sqrt(self.r)
+            print_info_once("Loading LoRA weights trained with rsLoRA.")
+            self.vllm_lora_scaling_factor = self.lora_alpha / math.sqrt(self.r)
+        else:
+            self.vllm_lora_scaling_factor = self.lora_alpha / self.r
         if self.context_length:
             if self.vllm_max_position_embeddings is None:
                 self.vllm_max_position_embeddings = self.context_length
-            self.vllm_scaling_factor = float(
+            self.vllm_long_context_scaling_factor = float(
                 math.ceil(self.context_length /
                           self.vllm_max_position_embeddings))
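
After the final patch, the per-adapter scaling always lives in vllm_lora_scaling_factor (lora_alpha / r for plain LoRA, lora_alpha / sqrt(r) for rsLoRA), while the long-context factor is kept separately as vllm_long_context_scaling_factor. A short usage sketch mirroring the new test, assuming vllm.lora.peft_helper is importable in your environment:

    import math

    from vllm.lora.peft_helper import PEFTHelper

    # Plain LoRA: scaling falls back to lora_alpha / r.
    plain = PEFTHelper.from_dict(
        dict(r=8, lora_alpha=16, target_modules=["gate_proj"]))
    assert math.isclose(plain.vllm_lora_scaling_factor, 16 / 8)

    # rsLoRA: scaling becomes lora_alpha / sqrt(r).
    rslora = PEFTHelper.from_dict(
        dict(r=8, lora_alpha=16, target_modules=["gate_proj"], use_rslora=True))
    assert math.isclose(rslora.vllm_lora_scaling_factor, 16 / math.sqrt(8))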