Skip to content

Commit

Permalink
[LoRA] feat: support loading loras into 4bit quantized Flux models. (#…
Browse files Browse the repository at this point in the history
…10578)

* feat: support loading loras into 4bit quantized models.

* updates

* update

* remove weight check.
  • Loading branch information
sayakpaul authored and DN6 committed Jan 15, 2025
1 parent a663a67 commit 263b973
Show file tree
Hide file tree
Showing 4 changed files with 71 additions and 4 deletions.
39 changes: 36 additions & 3 deletions src/diffusers/loaders/lora_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from ..utils import (
USE_PEFT_BACKEND,
deprecate,
get_submodule_by_name,
is_peft_available,
is_peft_version,
is_torch_version,
Expand Down Expand Up @@ -1981,10 +1982,17 @@ def _maybe_expand_transformer_param_shape_or_error_(
in_features = state_dict[lora_A_weight_name].shape[1]
out_features = state_dict[lora_B_weight_name].shape[0]

# Model maybe loaded with different quantization schemes which may flatten the params.
# `bitsandbytes`, for example, flatten the weights when using 4bit. 8bit bnb models
# preserve weight shape.
module_weight_shape = cls._calculate_module_shape(model=transformer, base_module=module)

# This means there's no need for an expansion in the params, so we simply skip.
if tuple(module_weight.shape) == (out_features, in_features):
if tuple(module_weight_shape) == (out_features, in_features):
continue

# TODO (sayakpaul): We still need to consider if the module we're expanding is
# quantized and handle it accordingly if that is the case.
module_out_features, module_in_features = module_weight.shape
debug_message = ""
if in_features > module_in_features:
Expand Down Expand Up @@ -2080,13 +2088,16 @@ def _maybe_expand_lora_state_dict(cls, transformer, lora_state_dict):
base_weight_param = transformer_state_dict[base_param_name]
lora_A_param = lora_state_dict[f"{prefix}{k}.lora_A.weight"]

if base_weight_param.shape[1] > lora_A_param.shape[1]:
# TODO (sayakpaul): Handle the cases when we actually need to expand when using quantization.
base_module_shape = cls._calculate_module_shape(model=transformer, base_weight_param_name=base_param_name)

if base_module_shape[1] > lora_A_param.shape[1]:
shape = (lora_A_param.shape[0], base_weight_param.shape[1])
expanded_state_dict_weight = torch.zeros(shape, device=base_weight_param.device)
expanded_state_dict_weight[:, : lora_A_param.shape[1]].copy_(lora_A_param)
lora_state_dict[f"{prefix}{k}.lora_A.weight"] = expanded_state_dict_weight
expanded_module_names.add(k)
elif base_weight_param.shape[1] < lora_A_param.shape[1]:
elif base_module_shape[1] < lora_A_param.shape[1]:
raise NotImplementedError(
f"This LoRA param ({k}.lora_A.weight) has an incompatible shape {lora_A_param.shape}. Please open an issue to file for a feature request - https://github.com/huggingface/diffusers/issues/new."
)
Expand All @@ -2098,6 +2109,28 @@ def _maybe_expand_lora_state_dict(cls, transformer, lora_state_dict):

return lora_state_dict

@staticmethod
def _calculate_module_shape(
model: "torch.nn.Module",
base_module: "torch.nn.Linear" = None,
base_weight_param_name: str = None,
) -> "torch.Size":
def _get_weight_shape(weight: torch.Tensor):
return weight.quant_state.shape if weight.__class__.__name__ == "Params4bit" else weight.shape

if base_module is not None:
return _get_weight_shape(base_module.weight)
elif base_weight_param_name is not None:
if not base_weight_param_name.endswith(".weight"):
raise ValueError(
f"Invalid `base_weight_param_name` passed as it does not end with '.weight' {base_weight_param_name=}."
)
module_path = base_weight_param_name.rsplit(".weight", 1)[0]
submodule = get_submodule_by_name(model, module_path)
return _get_weight_shape(submodule.weight)

raise ValueError("Either `base_module` or `base_weight_param_name` must be provided.")


# The reason why we subclass from `StableDiffusionLoraLoaderMixin` here is because Amused initially
# relied on `StableDiffusionLoraLoaderMixin` for its LoRA support.
Expand Down
2 changes: 1 addition & 1 deletion src/diffusers/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@
is_xformers_available,
requires_backends,
)
from .loading_utils import get_module_from_name, load_image, load_video
from .loading_utils import get_module_from_name, get_submodule_by_name, load_image, load_video
from .logging import get_logger
from .outputs import BaseOutput
from .peft_utils import (
Expand Down
12 changes: 12 additions & 0 deletions src/diffusers/utils/loading_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,3 +148,15 @@ def get_module_from_name(module, tensor_name: str) -> Tuple[Any, str]:
module = new_module
tensor_name = splits[-1]
return module, tensor_name


def get_submodule_by_name(root_module, module_path: str):
current = root_module
parts = module_path.split(".")
for part in parts:
if part.isdigit():
idx = int(part)
current = current[idx] # e.g., for nn.ModuleList or nn.Sequential
else:
current = getattr(current, part)
return current
22 changes: 22 additions & 0 deletions tests/quantization/bnb/test_4bit.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import numpy as np
import pytest
import safetensors.torch
from huggingface_hub import hf_hub_download

from diffusers import BitsAndBytesConfig, DiffusionPipeline, FluxTransformer2DModel, SD3Transformer2DModel
from diffusers.utils import is_accelerate_version, logging
Expand Down Expand Up @@ -568,6 +569,27 @@ def test_quality(self):
max_diff = numpy_cosine_similarity_distance(expected_slice, out_slice)
self.assertTrue(max_diff < 1e-3)

def test_lora_loading(self):
self.pipeline_4bit.load_lora_weights(
hf_hub_download("ByteDance/Hyper-SD", "Hyper-FLUX.1-dev-8steps-lora.safetensors"), adapter_name="hyper-sd"
)
self.pipeline_4bit.set_adapters("hyper-sd", adapter_weights=0.125)

output = self.pipeline_4bit(
prompt=self.prompt,
height=256,
width=256,
max_sequence_length=64,
output_type="np",
num_inference_steps=8,
generator=torch.Generator().manual_seed(42),
).images
out_slice = output[0, -3:, -3:, -1].flatten()
expected_slice = np.array([0.5347, 0.5342, 0.5283, 0.5093, 0.4988, 0.5093, 0.5044, 0.5015, 0.4946])

max_diff = numpy_cosine_similarity_distance(expected_slice, out_slice)
self.assertTrue(max_diff < 1e-3)


@slow
class BaseBnb4BitSerializationTests(Base4bitTests):
Expand Down

0 comments on commit 263b973

Please sign in to comment.