[LoRA] Support original-format LoRAs for HunyuanVideo (#10376)
* update

* fix make copies

* update

* add relevant markers to the integration test suite.

* add copied-from comments.

* fix-copies

* temporarily add print.

* directly place on CUDA as CPU memory isn't that big on the CI.

* fixes to fuse_lora; Aryan was right.

* fixes

---------

Co-authored-by: Sayak Paul <[email protected]>
2 people authored and DN6 committed Jan 15, 2025
1 parent b3d2dd3 commit 526858c
Showing 3 changed files with 256 additions and 6 deletions.
175 changes: 175 additions & 0 deletions src/diffusers/loaders/lora_conversion_utils.py
@@ -973,3 +973,178 @@ def swap_scale_shift(weight):
converted_state_dict[f"transformer.{key}"] = converted_state_dict.pop(key)

return converted_state_dict


def _convert_hunyuan_video_lora_to_diffusers(original_state_dict):
converted_state_dict = {k: original_state_dict.pop(k) for k in list(original_state_dict.keys())}

def remap_norm_scale_shift_(key, state_dict):
weight = state_dict.pop(key)
shift, scale = weight.chunk(2, dim=0)
new_weight = torch.cat([scale, shift], dim=0)
state_dict[key.replace("final_layer.adaLN_modulation.1", "norm_out.linear")] = new_weight

def remap_txt_in_(key, state_dict):
def rename_key(key):
new_key = key.replace("individual_token_refiner.blocks", "token_refiner.refiner_blocks")
new_key = new_key.replace("adaLN_modulation.1", "norm_out.linear")
new_key = new_key.replace("txt_in", "context_embedder")
new_key = new_key.replace("t_embedder.mlp.0", "time_text_embed.timestep_embedder.linear_1")
new_key = new_key.replace("t_embedder.mlp.2", "time_text_embed.timestep_embedder.linear_2")
new_key = new_key.replace("c_embedder", "time_text_embed.text_embedder")
new_key = new_key.replace("mlp", "ff")
return new_key

if "self_attn_qkv" in key:
weight = state_dict.pop(key)
to_q, to_k, to_v = weight.chunk(3, dim=0)
state_dict[rename_key(key.replace("self_attn_qkv", "attn.to_q"))] = to_q
state_dict[rename_key(key.replace("self_attn_qkv", "attn.to_k"))] = to_k
state_dict[rename_key(key.replace("self_attn_qkv", "attn.to_v"))] = to_v
else:
state_dict[rename_key(key)] = state_dict.pop(key)

def remap_img_attn_qkv_(key, state_dict):
weight = state_dict.pop(key)
if "lora_A" in key:
state_dict[key.replace("img_attn_qkv", "attn.to_q")] = weight
state_dict[key.replace("img_attn_qkv", "attn.to_k")] = weight
state_dict[key.replace("img_attn_qkv", "attn.to_v")] = weight
else:
to_q, to_k, to_v = weight.chunk(3, dim=0)
state_dict[key.replace("img_attn_qkv", "attn.to_q")] = to_q
state_dict[key.replace("img_attn_qkv", "attn.to_k")] = to_k
state_dict[key.replace("img_attn_qkv", "attn.to_v")] = to_v

def remap_txt_attn_qkv_(key, state_dict):
weight = state_dict.pop(key)
if "lora_A" in key:
state_dict[key.replace("txt_attn_qkv", "attn.add_q_proj")] = weight
state_dict[key.replace("txt_attn_qkv", "attn.add_k_proj")] = weight
state_dict[key.replace("txt_attn_qkv", "attn.add_v_proj")] = weight
else:
to_q, to_k, to_v = weight.chunk(3, dim=0)
state_dict[key.replace("txt_attn_qkv", "attn.add_q_proj")] = to_q
state_dict[key.replace("txt_attn_qkv", "attn.add_k_proj")] = to_k
state_dict[key.replace("txt_attn_qkv", "attn.add_v_proj")] = to_v

def remap_single_transformer_blocks_(key, state_dict):
hidden_size = 3072

if "linear1.lora_A.weight" in key or "linear1.lora_B.weight" in key:
linear1_weight = state_dict.pop(key)
if "lora_A" in key:
new_key = key.replace("single_blocks", "single_transformer_blocks").removesuffix(
".linear1.lora_A.weight"
)
state_dict[f"{new_key}.attn.to_q.lora_A.weight"] = linear1_weight
state_dict[f"{new_key}.attn.to_k.lora_A.weight"] = linear1_weight
state_dict[f"{new_key}.attn.to_v.lora_A.weight"] = linear1_weight
state_dict[f"{new_key}.proj_mlp.lora_A.weight"] = linear1_weight
else:
split_size = (hidden_size, hidden_size, hidden_size, linear1_weight.size(0) - 3 * hidden_size)
q, k, v, mlp = torch.split(linear1_weight, split_size, dim=0)
new_key = key.replace("single_blocks", "single_transformer_blocks").removesuffix(
".linear1.lora_B.weight"
)
state_dict[f"{new_key}.attn.to_q.lora_B.weight"] = q
state_dict[f"{new_key}.attn.to_k.lora_B.weight"] = k
state_dict[f"{new_key}.attn.to_v.lora_B.weight"] = v
state_dict[f"{new_key}.proj_mlp.lora_B.weight"] = mlp

elif "linear1.lora_A.bias" in key or "linear1.lora_B.bias" in key:
linear1_bias = state_dict.pop(key)
if "lora_A" in key:
new_key = key.replace("single_blocks", "single_transformer_blocks").removesuffix(
".linear1.lora_A.bias"
)
state_dict[f"{new_key}.attn.to_q.lora_A.bias"] = linear1_bias
state_dict[f"{new_key}.attn.to_k.lora_A.bias"] = linear1_bias
state_dict[f"{new_key}.attn.to_v.lora_A.bias"] = linear1_bias
state_dict[f"{new_key}.proj_mlp.lora_A.bias"] = linear1_bias
else:
split_size = (hidden_size, hidden_size, hidden_size, linear1_bias.size(0) - 3 * hidden_size)
q_bias, k_bias, v_bias, mlp_bias = torch.split(linear1_bias, split_size, dim=0)
new_key = key.replace("single_blocks", "single_transformer_blocks").removesuffix(
".linear1.lora_B.bias"
)
state_dict[f"{new_key}.attn.to_q.lora_B.bias"] = q_bias
state_dict[f"{new_key}.attn.to_k.lora_B.bias"] = k_bias
state_dict[f"{new_key}.attn.to_v.lora_B.bias"] = v_bias
state_dict[f"{new_key}.proj_mlp.lora_B.bias"] = mlp_bias

else:
new_key = key.replace("single_blocks", "single_transformer_blocks")
new_key = new_key.replace("linear2", "proj_out")
new_key = new_key.replace("q_norm", "attn.norm_q")
new_key = new_key.replace("k_norm", "attn.norm_k")
state_dict[new_key] = state_dict.pop(key)

TRANSFORMER_KEYS_RENAME_DICT = {
"img_in": "x_embedder",
"time_in.mlp.0": "time_text_embed.timestep_embedder.linear_1",
"time_in.mlp.2": "time_text_embed.timestep_embedder.linear_2",
"guidance_in.mlp.0": "time_text_embed.guidance_embedder.linear_1",
"guidance_in.mlp.2": "time_text_embed.guidance_embedder.linear_2",
"vector_in.in_layer": "time_text_embed.text_embedder.linear_1",
"vector_in.out_layer": "time_text_embed.text_embedder.linear_2",
"double_blocks": "transformer_blocks",
"img_attn_q_norm": "attn.norm_q",
"img_attn_k_norm": "attn.norm_k",
"img_attn_proj": "attn.to_out.0",
"txt_attn_q_norm": "attn.norm_added_q",
"txt_attn_k_norm": "attn.norm_added_k",
"txt_attn_proj": "attn.to_add_out",
"img_mod.linear": "norm1.linear",
"img_norm1": "norm1.norm",
"img_norm2": "norm2",
"img_mlp": "ff",
"txt_mod.linear": "norm1_context.linear",
"txt_norm1": "norm1.norm",
"txt_norm2": "norm2_context",
"txt_mlp": "ff_context",
"self_attn_proj": "attn.to_out.0",
"modulation.linear": "norm.linear",
"pre_norm": "norm.norm",
"final_layer.norm_final": "norm_out.norm",
"final_layer.linear": "proj_out",
"fc1": "net.0.proj",
"fc2": "net.2",
"input_embedder": "proj_in",
}

TRANSFORMER_SPECIAL_KEYS_REMAP = {
"txt_in": remap_txt_in_,
"img_attn_qkv": remap_img_attn_qkv_,
"txt_attn_qkv": remap_txt_attn_qkv_,
"single_blocks": remap_single_transformer_blocks_,
"final_layer.adaLN_modulation.1": remap_norm_scale_shift_,
}

# Some folks attempt to make their state dict compatible with diffusers by adding "transformer." prefix to all keys
# and use their custom code. To make sure both "original" and "attempted diffusers" loras work as expected, we make
# sure that both follow the same initial format by stripping off the "transformer." prefix.
for key in list(converted_state_dict.keys()):
if key.startswith("transformer."):
converted_state_dict[key[len("transformer.") :]] = converted_state_dict.pop(key)
if key.startswith("diffusion_model."):
converted_state_dict[key[len("diffusion_model.") :]] = converted_state_dict.pop(key)

# Rename and remap the state dict keys
for key in list(converted_state_dict.keys()):
new_key = key[:]
for replace_key, rename_key in TRANSFORMER_KEYS_RENAME_DICT.items():
new_key = new_key.replace(replace_key, rename_key)
converted_state_dict[new_key] = converted_state_dict.pop(key)

for key in list(converted_state_dict.keys()):
for special_key, handler_fn_inplace in TRANSFORMER_SPECIAL_KEYS_REMAP.items():
if special_key not in key:
continue
handler_fn_inplace(key, converted_state_dict)

# Add back the "transformer." prefix
for key in list(converted_state_dict.keys()):
converted_state_dict[f"transformer.{key}"] = converted_state_dict.pop(key)

return converted_state_dict
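
A note on the QKV remapping above: for a fused attention linear, the LoRA update is lora_B @ lora_A, where lora_A (shape rank x in_features) is shared by every output row while lora_B (shape 3*hidden x rank) splits along dim 0. That is why the converter copies lora_A to each of to_q/to_k/to_v but chunks lora_B; remap_single_transformer_blocks_ applies the same idea with torch.split because linear1 also fuses the MLP projection, whose chunk has a different size. A minimal standalone sketch (illustrative only, not part of this commit) verifying the equivalence:

import torch

# Sketch: splitting a fused-QKV LoRA into per-projection LoRAs.
# lora_A is shared; lora_B chunks along the fused output dimension (dim 0).
hidden, in_features, rank = 8, 4, 2
lora_A = torch.randn(rank, in_features)  # "down" projection, shared by q/k/v
lora_B = torch.randn(3 * hidden, rank)   # "up" projection, fused over q/k/v

delta = lora_B @ lora_A                  # full fused LoRA update
dq, dk, dv = delta.chunk(3, dim=0)       # per-projection slices of the update

b_q, b_k, b_v = lora_B.chunk(3, dim=0)   # what the converter stores per layer
assert torch.allclose(dq, b_q @ lora_A)  # chunked B with shared A reproduces
assert torch.allclose(dk, b_k @ lora_A)  # each projection's update
assert torch.allclose(dv, b_v @ lora_A)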
14 changes: 8 additions & 6 deletions src/diffusers/loaders/lora_pipeline.py
@@ -36,6 +36,7 @@
from .lora_base import LORA_WEIGHT_NAME, LORA_WEIGHT_NAME_SAFE, LoraBaseMixin, _fetch_state_dict # noqa
from .lora_conversion_utils import (
_convert_bfl_flux_control_lora_to_diffusers,
+    _convert_hunyuan_video_lora_to_diffusers,
_convert_kohya_flux_lora_to_diffusers,
_convert_non_diffusers_lora_to_diffusers,
_convert_xlabs_flux_lora_to_diffusers,
@@ -4007,7 +4008,6 @@ class HunyuanVideoLoraLoaderMixin(LoraBaseMixin):

@classmethod
@validate_hf_hub_args
-# Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.lora_state_dict
def lora_state_dict(
cls,
pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
@@ -4018,7 +4018,7 @@ def lora_state_dict(
<Tip warning={true}>
-We support loading A1111 formatted LoRA checkpoints in a limited capacity.
+We support loading original format HunyuanVideo LoRA checkpoints.
This function is experimental and might change in the future.
@@ -4101,6 +4101,10 @@ def lora_state_dict(
logger.warning(warn_msg)
state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k}

+is_original_hunyuan_video = any("img_attn_qkv" in k for k in state_dict)
+if is_original_hunyuan_video:
+    state_dict = _convert_hunyuan_video_lora_to_diffusers(state_dict)
+
return state_dict

# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.load_lora_weights
@@ -4239,10 +4243,9 @@ def save_lora_weights(
safe_serialization=safe_serialization,
)

-# Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.fuse_lora with unet->transformer
def fuse_lora(
self,
-    components: List[str] = ["transformer", "text_encoder"],
+    components: List[str] = ["transformer"],
lora_scale: float = 1.0,
safe_fusing: bool = False,
adapter_names: Optional[List[str]] = None,
@@ -4283,8 +4286,7 @@ def fuse_lora(
components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names
)

-# Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.unfuse_lora with unet->transformer
-def unfuse_lora(self, components: List[str] = ["transformer", "text_encoder"], **kwargs):
+def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs):
r"""
Reverses the effect of
[`pipe.fuse_lora()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraBaseMixin.fuse_lora).
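
With this change, an original-format checkpoint is detected by its img_attn_qkv keys and converted on the fly, so loading goes through the standard API. A minimal usage sketch, mirroring the integration test below (same model and LoRA repos; needs a CUDA GPU with enough memory):

import torch
from diffusers import HunyuanVideoPipeline, HunyuanVideoTransformer3DModel

# Load the base pipeline, then an original-format HunyuanVideo LoRA; the
# conversion to the diffusers layout happens inside lora_state_dict().
model_id = "hunyuanvideo-community/HunyuanVideo"
transformer = HunyuanVideoTransformer3DModel.from_pretrained(
    model_id, subfolder="transformer", torch_dtype=torch.bfloat16
)
pipe = HunyuanVideoPipeline.from_pretrained(
    model_id, transformer=transformer, torch_dtype=torch.float16
).to("cuda")
pipe.load_lora_weights(
    "Cseti/HunyuanVideo-LoRA-Arcane_Jinx-v1",
    weight_name="csetiarcane-nfjinx-v1-6000.safetensors",
)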
73 changes: 73 additions & 0 deletions tests/lora/test_lora_layers_hunyuanvideo.py
@@ -12,9 +12,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import gc
import sys
import unittest

import numpy as np
import pytest
import torch
from transformers import CLIPTextModel, CLIPTokenizer, LlamaModel, LlamaTokenizerFast

@@ -26,7 +29,11 @@
)
from diffusers.utils.testing_utils import (
floats_tensor,
nightly,
numpy_cosine_similarity_distance,
require_big_gpu_with_torch_cuda,
require_peft_backend,
require_torch_gpu,
skip_mps,
)

@@ -182,3 +189,69 @@ def test_simple_inference_with_text_lora_fused(self):
@unittest.skip("Text encoder LoRA is not supported in HunyuanVideo.")
def test_simple_inference_with_text_lora_save_load(self):
pass


@nightly
@require_torch_gpu
@require_peft_backend
@require_big_gpu_with_torch_cuda
@pytest.mark.big_gpu_with_torch_cuda
class HunyuanVideoLoRAIntegrationTests(unittest.TestCase):
"""internal note: The integration slices were obtained on DGX.
torch: 2.5.1+cu124 with CUDA 12.5. Need the same setup for the
assertions to pass.
"""

num_inference_steps = 10
seed = 0

def setUp(self):
super().setUp()

gc.collect()
torch.cuda.empty_cache()

model_id = "hunyuanvideo-community/HunyuanVideo"
transformer = HunyuanVideoTransformer3DModel.from_pretrained(
model_id, subfolder="transformer", torch_dtype=torch.bfloat16
)
self.pipeline = HunyuanVideoPipeline.from_pretrained(
model_id, transformer=transformer, torch_dtype=torch.float16
).to("cuda")

def tearDown(self):
super().tearDown()

gc.collect()
torch.cuda.empty_cache()

def test_original_format_cseti(self):
self.pipeline.load_lora_weights(
"Cseti/HunyuanVideo-LoRA-Arcane_Jinx-v1", weight_name="csetiarcane-nfjinx-v1-6000.safetensors"
)
self.pipeline.fuse_lora()
self.pipeline.unload_lora_weights()
self.pipeline.vae.enable_tiling()

prompt = "CSETIARCANE. A cat walks on the grass, realistic"

out = self.pipeline(
prompt=prompt,
height=320,
width=512,
num_frames=9,
num_inference_steps=self.num_inference_steps,
output_type="np",
generator=torch.manual_seed(self.seed),
).frames[0]
out = out.flatten()
out_slice = np.concatenate((out[:8], out[-8:]))

# fmt: off
expected_slice = np.array([0.1013, 0.1924, 0.0078, 0.1021, 0.1929, 0.0078, 0.1023, 0.1919, 0.7402, 0.104, 0.4482, 0.7354, 0.0925, 0.4382, 0.7275, 0.0815])
# fmt: on

max_diff = numpy_cosine_similarity_distance(expected_slice.flatten(), out_slice)

assert max_diff < 1e-3
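
The assertion compares flattened output slices with a cosine-distance metric, conceptually along these lines (a sketch of the idea, not necessarily the exact helper in diffusers.utils.testing_utils):

import numpy as np

def cosine_similarity_distance(a: np.ndarray, b: np.ndarray) -> float:
    # 1 - cosine similarity of the two vectors: 0.0 when they point in the
    # same direction, growing toward 2.0 as they diverge.
    similarity = (a * b).sum() / (np.linalg.norm(a) * np.linalg.norm(b))
    return float(1.0 - similarity)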
