
Fix Flux multiple Lora loading bug #10388

Merged: 10 commits into huggingface:main on Jan 2, 2025

Conversation

maxs-kan
Contributor

What does this PR do?

The current approach of checking for a key with the base_layer suffix can break when multiple LoRA models are loaded. If the first loaded LoRA does not have weights for layer n and the second one does, loading the second model raises an error because the transformer state dict does not yet contain the key n.base_layer.weight. This PR therefore explicitly checks for the presence of a key with the base_layer suffix.
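
For illustration, here is a minimal sketch of the check described above, using hypothetical state-dict keys rather than the exact diffusers code; a module's base weight only sits under a ".base_layer.weight" key once a previously loaded LoRA has wrapped that module.

# Hypothetical transformer state dict after loading one LoRA that wraps x_embedder only.
transformer_state_dict = {
    "x_embedder.base_layer.weight": None,   # wrapped by the already-loaded LoRA
    "context_embedder.weight": None,        # untouched, so only the plain key exists
}
suffix = ".base_layer.weight"
wrapped = {key[: -len(suffix)] for key in transformer_state_dict if key.endswith(suffix)}

for module in ("x_embedder", "context_embedder"):
    base_param_name = f"{module}{suffix}" if module in wrapped else f"{module}.weight"
    assert base_param_name in transformer_state_dict  # both lookups resolve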

@yiyixuxu

@a-r-r-o-w requested a review from hlky on December 26, 2024 12:25
@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@hlky
Collaborator

hlky commented Dec 26, 2024

Hi @maxs-kan, thanks for your contribution. Can you share some example LoRA checkpoints that lead to the bug?

@maxs-kan
Contributor Author

maxs-kan commented Dec 26, 2024

Sure, try in the same order:
pipe.load_lora_weights(hf_hub_download("TTPlanet/Migration_Lora_flux","Migration_Lora_cloth.safetensors"), adapter_name="cloth")
pipe.load_lora_weights("alimama-creative/FLUX.1-Turbo-Alpha", adapter_name="turbo")

Collaborator

@hlky left a comment

Code
from diffusers import FluxPipeline
from huggingface_hub import hf_hub_download
import torch

pipe = FluxPipeline.from_pretrained(
  "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
)
pipe.load_lora_weights(
  hf_hub_download("TTPlanet/Migration_Lora_flux", "Migration_Lora_cloth.safetensors"),
  adapter_name="cloth",
)
pipe.load_lora_weights("alimama-creative/FLUX.1-Turbo-Alpha", adapter_name="turbo")

@sayakpaul
Member

@maxs-kan thanks for this PR. Do you want to also propagate the changes from #10396?

Comment on lines 2463 to 2465
transformer_base_layer_keys = {
    k[: -len(".base_layer.weight")] for k in transformer_state_dict.keys() if ".base_layer.weight" in k
}
Member

Note that the base_layer substring can only be present when the underlying pipeline has at least one LoRA loaded that affects the layer under consideration. So perhaps it's better to have an is_peft_loaded check?

Member

In your PR description you mention:

If the first loaded Lora model does not have weights for layer n, and the second one does, loading the second model will lead to an error since the transformer state dict currently does not have key n.base_layer.weight.

Note that we may also have the opposite situation, i.e., the first LoRA ckpt may have the params while the second LoRA may not. This is what I considered in #10388.

Collaborator

if is_peft_loaded and ".base_layer.weight" in k might make it clearer that this only applies when a LoRA is already loaded.

Collaborator

The case where the first LoRA has more weights than the second is OK on main:

  1. Hyper-FLUX.1-dev-8steps-lora.safetensors
  2. Purz/choose-your-own-adventure

or

  1. alimama-creative/FLUX.1-Turbo-Alpha
  2. TTPlanet/Migration_Lora_flux

In this case base_param_name is set to f"{k.replace(prefix, '')}.base_layer.weight" for the 2nd LoRA and all keys exist.

If loaded in the reverse order f"{k.replace(prefix, '')}.base_layer.weight" doesn't exist for the extra weights.

  1. Purz/choose-your-own-adventure
  2. Hyper-FLUX.1-dev-8steps-lora.safetensors

or

  1. TTPlanet/Migration_Lora_flux
  2. alimama-creative/FLUX.1-Turbo-Alpha

KeyError: context_embedder.base_layer.weight

So for the extra weights we use f"{k.replace(prefix, '')}.weight". If another LoRA were loaded with context_embedder it would then use context_embedder.base_layer.weight.

We could continue if f"{k.replace(prefix, '')}.base_layer.weight" is not found but the extra weights may need to be expanded.
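
A small toy reproduction of the order dependence described above, with hypothetical keys rather than the real loader internals: once any LoRA is loaded, is_peft_loaded is True, so the lookup on main always targets ".base_layer.weight" and fails for modules the first LoRA skipped.

# First LoRA wrapped x_embedder but had no weights for context_embedder.
transformer_state_dict = {
    "x_embedder.base_layer.weight": None,
    "context_embedder.weight": None,
}
is_peft_loaded = True
prefix = "transformer."
k = "transformer.context_embedder"
base_param_name = (
    f"{k.replace(prefix, '')}.base_layer.weight" if is_peft_loaded else f"{k.replace(prefix, '')}.weight"
)
try:
    transformer_state_dict[base_param_name]
except KeyError as err:
    print("KeyError:", err)  # 'context_embedder.base_layer.weight'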

Member

In this case, we are considering that LoRA params for certain modules exist in the first checkpoint while they don't exist in the second checkpoint (or any other subsequent checkpoint).

In this case, we don't want to expand no? Or am I missing something? Perhaps better expressed through a short test case like the one I added here?

Collaborator

The test case passes on main; it should be in the reverse order:

        with tempfile.TemporaryDirectory() as tmpdirname:
            denoiser_state_dict = get_peft_model_state_dict(pipe.transformer)
            self.pipeline_class.save_lora_weights(tmpdirname, transformer_lora_layers=denoiser_state_dict)

            self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")))
            pipe.unload_lora_weights()
            # Modify the state dict to exclude "x_embedder" related LoRA params.
            lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))
            lora_state_dict_without_xembedder = {k: v for k, v in lora_state_dict.items() if "x_embedder" not in k}
            pipe.load_lora_weights(lora_state_dict_without_xembedder, adapter_name="two")

            # Load state dict with `x_embedder`.
            pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"), adapter_name="one")
            base_param_name = (
                f"{k.replace(prefix, '')}.base_layer.weight" if is_peft_loaded else f"{k.replace(prefix, '')}.weight"
            )
>           base_weight_param = transformer_state_dict[base_param_name]
E           KeyError: 'x_embedder.base_layer.weight'

src\diffusers\loaders\lora_pipeline.py:2471: KeyError

I think we still want to check whether the param needs to be expanded
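
As a loose illustration of "expanded" (this is an assumption about the intent, not the diffusers implementation): a checkpoint can expect a wider Linear layer than the base model provides, in which case the base weight would need to be padded out to the larger shape before the LoRA can be applied.

import torch

# Hypothetical shapes: the base weight is narrower than what the incoming checkpoint expects.
base_weight = torch.zeros(3072, 64)      # assumed base x_embedder weight
expected_in_features = 128               # assumed width required by the checkpoint
if base_weight.shape[1] < expected_in_features:
    expanded = torch.zeros(base_weight.shape[0], expected_in_features, dtype=base_weight.dtype)
    expanded[:, : base_weight.shape[1]] = base_weight  # copy the original weight, zero-pad the rest
    base_weight = expanded
print(base_weight.shape)  # torch.Size([3072, 128])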

Member

Cool, I understand it better now. Thanks!

Might be better to ship this PR with proper testing then. Okay with me.

@sayakpaul
Member

Also, I gave @hlky's code snippet here a try on the #10396 branch and it seems to work.


Comment on lines 2470 to 2474
  base_param_name = (
-     f"{k.replace(prefix, '')}.base_layer.weight" if is_peft_loaded else f"{k.replace(prefix, '')}.weight"
+     f"{k.replace(prefix, '')}.base_layer.weight"
+     if k in transformer_base_layer_keys
+     else f"{k.replace(prefix, '')}.weight"
  )
Collaborator

base_param_name = f"{k.replace(prefix, '')}.weight"
base_layer_name = f"{k.replace(prefix, '')}.base_layer.weight"
if is_peft_loaded and base_layer_name in transformer_state_dict:
    base_param_name = base_layer_name

Something like this might be better.
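
As a toy check (hypothetical keys, not the real loader), wrapping the suggestion in a helper shows why the fallback makes the lookup independent of load order:

def resolve(transformer_state_dict, k, prefix="transformer.", is_peft_loaded=True):
    base_param_name = f"{k.replace(prefix, '')}.weight"
    base_layer_name = f"{k.replace(prefix, '')}.base_layer.weight"
    if is_peft_loaded and base_layer_name in transformer_state_dict:
        base_param_name = base_layer_name
    return base_param_name

# A module wrapped by an earlier LoRA resolves to .base_layer.weight; an untouched module
# falls back to the plain .weight key, so neither load order raises a KeyError.
state = {"x_embedder.base_layer.weight": None, "context_embedder.weight": None}
assert resolve(state, "transformer.x_embedder") == "x_embedder.base_layer.weight"
assert resolve(state, "transformer.context_embedder") == "context_embedder.weight"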

Member

@sayakpaul left a comment

@hlky thanks!

Do you wanna propagate your suggestions, too?

Collaborator

@hlky left a comment

Thanks @sayakpaul. I've left some other comments, but this should be good to go.

Thanks for the PR @maxs-kan

src/diffusers/loaders/lora_pipeline.py (outdated, resolved)
tests/lora/test_lora_layers_flux.py (resolved)
Member

@sayakpaul left a comment

Thanks.

I think the test mimics the code here that was producing the error. So, I think we should be good to go.

tests/lora/test_lora_layers_flux.py (outdated, resolved)
@yiyixuxu merged commit 44640c8 into huggingface:main on Jan 2, 2025
12 checks passed
@yiyixuxu
Collaborator

yiyixuxu commented Jan 2, 2025

Thank you all! @maxs-kan @hlky @sayakpaul

@bghira
Contributor

bghira commented Jan 12, 2025

This is a pretty big regression that forced some consumers to pull a custom build of diffusers with this patch included. Can a hotfix perhaps be pushed for v0.32.2?

@sayakpaul
Member

Sorry, this should go into a patch release. @yiyixuxu I am happy to do the patch release if you're okay with it.

@maxs-kan deleted the flux-lora-base_layer-check branch on January 13, 2025 10:57
DN6 pushed a commit that referenced this pull request Jan 15, 2025
* check for base_layer key in transformer state dict

* test_lora_expansion_works_for_absent_keys

* check

* Update tests/lora/test_lora_layers_flux.py

Co-authored-by: Sayak Paul <[email protected]>

* check

* test_lora_expansion_works_for_absent_keys/test_lora_expansion_works_for_extra_keys

* absent->extra

---------

Co-authored-by: hlky <[email protected]>
Co-authored-by: Sayak Paul <[email protected]>