Fix Flux multiple Lora loading bug #10388
Conversation
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
Hi @maxs-kan, thanks for your contribution, can you share some example lora checkpoints that may lead to a bug?
Sure, try in the same order:
from diffusers import FluxPipeline
from huggingface_hub import hf_hub_download
import torch

pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
)
pipe.load_lora_weights(
    hf_hub_download("TTPlanet/Migration_Lora_flux", "Migration_Lora_cloth.safetensors"),
    adapter_name="cloth",
)
pipe.load_lora_weights("alimama-creative/FLUX.1-Turbo-Alpha", adapter_name="turbo")
transformer_base_layer_keys = {
    k[: -len(".base_layer.weight")] for k in transformer_state_dict.keys() if ".base_layer.weight" in k
}
Note: the base_layer substring can only be present when the underlying pipeline has at least one LoRA loaded that affects the layer under consideration. So, perhaps it's better to have an is_peft_loaded check?
In your PR description you mention:
If the first loaded Lora model does not have weights for layer n, and the second one does, loading the second model will lead to an error since the transformer state dict currently does not have key n.base_layer.weight.
Note that we may also have the opposite situation, i.e., the first LoRA ckpt may have the params while the second LoRA may not. This is what I considered in #10388.
if is_peft_loaded and ".base_layer.weight" in k
might be clearer, since this condition only matters when a LoRA is already loaded.
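A minimal sketch of that suggestion applied to the set comprehension quoted above, assuming transformer_state_dict and is_peft_loaded are the loader's existing variables:

transformer_base_layer_keys = {
    k[: -len(".base_layer.weight")]
    for k in transformer_state_dict.keys()
    if is_peft_loaded and ".base_layer.weight" in k
}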
The case where the first LoRA has extra weights compared to the second is fine on main, e.g. loading Hyper-FLUX.1-dev-8steps-lora.safetensors then Purz/choose-your-own-adventure, or alimama-creative/FLUX.1-Turbo-Alpha then TTPlanet/Migration_Lora_flux. In this case base_param_name is set to f"{k.replace(prefix, '')}.base_layer.weight" for the 2nd LoRA and all keys exist.
If loaded in the reverse order (Purz/choose-your-own-adventure then Hyper-FLUX.1-dev-8steps-lora.safetensors, or TTPlanet/Migration_Lora_flux then alimama-creative/FLUX.1-Turbo-Alpha), f"{k.replace(prefix, '')}.base_layer.weight" doesn't exist for the extra weights and loading fails with KeyError: context_embedder.base_layer.weight.
So for the extra weights we use f"{k.replace(prefix, '')}.weight". If another LoRA were then loaded with context_embedder, it would use context_embedder.base_layer.weight.
We could continue if f"{k.replace(prefix, '')}.base_layer.weight" is not found, but the extra weights may need to be expanded.
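A minimal sketch of the per-key fallback described above, with k, prefix, and transformer_state_dict standing for the loader's existing variables (an illustration, not the exact patch):

base_layer_key = f"{k.replace(prefix, '')}.base_layer.weight"
base_param_name = (
    base_layer_key
    if base_layer_key in transformer_state_dict
    else f"{k.replace(prefix, '')}.weight"
)
base_weight_param = transformer_state_dict[base_param_name]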
In this case, we are considering that LoRA params for certain modules exist in the first checkpoint while they don't exist in the second checkpoint (or any other subsequent checkpoint).
In that case, we don't want to expand, no? Or am I missing something? Perhaps this is better expressed through a short test case like the one I added here?
The test case passes on main; it should be in the reverse order:
with tempfile.TemporaryDirectory() as tmpdirname:
    denoiser_state_dict = get_peft_model_state_dict(pipe.transformer)
    self.pipeline_class.save_lora_weights(tmpdirname, transformer_lora_layers=denoiser_state_dict)
    self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")))
    pipe.unload_lora_weights()

    # Modify the state dict to exclude "x_embedder" related LoRA params.
    lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))
    lora_state_dict_without_xembedder = {k: v for k, v in lora_state_dict.items() if "x_embedder" not in k}
    pipe.load_lora_weights(lora_state_dict_without_xembedder, adapter_name="two")

    # Load state dict with `x_embedder`.
    pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"), adapter_name="one")
    base_param_name = (
        f"{k.replace(prefix, '')}.base_layer.weight" if is_peft_loaded else f"{k.replace(prefix, '')}.weight"
    )
>   base_weight_param = transformer_state_dict[base_param_name]
E   KeyError: 'x_embedder.base_layer.weight'

src\diffusers\loaders\lora_pipeline.py:2471: KeyError
I think we still want to check whether the param needs to be expanded
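For context, a purely illustrative sketch of what such an expansion check could look like; the lora_state_dict key layout and the shape comparison below are assumptions, not the diffusers implementation:

# Assumption: for a Linear layer, lora_A.weight has shape (rank, in_features)
# and the base weight has shape (out_features, in_features).
lora_A_weight = lora_state_dict[f"{k}.lora_A.weight"]  # assumed key layout
if base_weight_param.shape[1] < lora_A_weight.shape[1]:
    # The base layer is narrower than this LoRA expects, so it would need to
    # be expanded before the LoRA could be applied (expansion logic omitted).
    ...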
Cool, I understand it better now. Thanks!
Might be better to ship this PR with proper testing then. Okay with me.
Also, I gave @hlky's code snippet here a try in the #10396 branch and it seems to work.
  base_param_name = (
-     f"{k.replace(prefix, '')}.base_layer.weight" if is_peft_loaded else f"{k.replace(prefix, '')}.weight"
+     f"{k.replace(prefix, '')}.base_layer.weight"
+     if k in transformer_base_layer_keys
+     else f"{k.replace(prefix, '')}.weight"
  )
base_param_name = f"{k.replace(prefix, '')}.weight"
base_layer_name = f"{k.replace(prefix, '')}.base_layer.weight"
if is_peft_loaded and base_layer_name in transformer_state_dict:
    base_param_name = base_layer_name
Something like this might be better.
@hlky thanks!
Do you wanna propagate your suggestions, too?
Thanks @sayakpaul. I've left some other comments but should be good to go
Thanks for the PR @maxs-kan
Thanks.
I think the test mimics the code here that was producing the error. So, I think we should be good to go.
Co-authored-by: Sayak Paul <[email protected]>
Thank you all! @maxs-kan @hlky @sayakpaul
This is a pretty big regression that caused some consumers to need to pull a custom build of diffusers with this patch included. Can there perhaps be a hotfix pushed for v0.32.2?
Sorry, this should go into a patch release. @yiyixuxu I am happy to do the patch release if you're okay.
* check for base_layer key in transformer state dict
* test_lora_expansion_works_for_absent_keys
* check
* Update tests/lora/test_lora_layers_flux.py

Co-authored-by: Sayak Paul <[email protected]>

* check
* test_lora_expansion_works_for_absent_keys/test_lora_expansion_works_for_extra_keys
* absent->extra

---------

Co-authored-by: hlky <[email protected]>
Co-authored-by: Sayak Paul <[email protected]>
What does this PR do?
The current approach of checking for a key with a base_layer suffix may lead to a bug when multiple LoRA models are loaded. If the first loaded LoRA model does not have weights for layer n, and the second one does, loading the second model will lead to an error since the transformer state dict currently does not have the key n.base_layer.weight. So I explicitly check for the presence of a key with the base_layer suffix.
@yiyixuxu
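To make the bug and the fix concrete, here is a toy, self-contained illustration; the layer names and dict values are made up, and plain dicts stand in for the real transformer state dict:

# Toy stand-in for the transformer state dict after LoRA 1 is loaded.
# LoRA 1 only touched "layer_a", so only that layer was wrapped by PEFT and
# exposes a ".base_layer.weight" key; "layer_b" still has its plain key.
transformer_state_dict = {
    "layer_a.base_layer.weight": "layer_a base weight",
    "layer_b.weight": "layer_b base weight",
}

# LoRA 2 has weights for "layer_b". The old logic assumed that once any LoRA
# is loaded, every targeted layer exposes ".base_layer.weight":
k = "layer_b"
old_key = f"{k}.base_layer.weight"  # KeyError on the old code path

# The fix: only use the ".base_layer.weight" key when it actually exists.
new_key = old_key if old_key in transformer_state_dict else f"{k}.weight"
assert new_key == "layer_b.weight"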