[Flux LoRA] support parsing alpha from a flux lora state dict. #9236
Conversation
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
With this PR, loading this LoRA just works!

pip install --upgrade git+https://github.com/huggingface/diffusers.git@c8bca51c06a477d693fb65bd9862e00a67bba5a

from diffusers import AutoPipelineForText2Image
import torch

pipeline = AutoPipelineForText2Image.from_pretrained('black-forest-labs/FLUX.1-dev', torch_dtype=torch.bfloat16).to('cuda')
pipeline.load_lora_weights('TheLastBen/Jon_Snow_Flux_LoRA', weight_name='jon_snow.safetensors')
image = pipeline('jon snow eating pizza with ketchup').images[0]
image.save("image.jpg")
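[Editor's note] For readers unfamiliar with the alpha convention, a rough sketch of what "parsing alpha" means here: many LoRA checkpoints store a scalar "<module>.alpha" alongside each low-rank pair, and the update is applied as W + (alpha / rank) * up @ down. The snippet below is only an illustration under that assumption, not code from this PR; key naming differs between formats (Kohya-style "lora_down" vs. peft-style "lora_A"), so both are tried.

# Hedged, illustrative sketch: list the stored alphas in a LoRA file and the
# effective per-module scale alpha / rank. The filename reuses the example above.
import safetensors.torch

sd = safetensors.torch.load_file("jon_snow.safetensors")
for key, value in sd.items():
    if not key.endswith(".alpha"):
        continue
    prefix = key[: -len(".alpha")]
    for down_suffix in (".lora_down.weight", ".lora_A.weight"):
        down_key = prefix + down_suffix
        if down_key in sd:
            rank = sd[down_key].shape[0]
            print(f"{key}: alpha={value.item()}, effective scale={value.item() / rank:.3f}")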
Nice solution, using the existing mechanism to deal with the alphas.
I have some comments but no blockers.
state_dict_with_alpha = safetensors.torch.load_file(
    os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")
)
alpha_dict = {}
Would it also be possible to update the scales in this model according to the alphas below and create an image, then below assert that this image is identical to images_lora_with_alpha?
I don't fully understand it. Would it be possible to provide a schematic of what you mean?
So IIUC, you dump the state dict and then edit the alphas in the state dict, then load the state dict with the edited alphas below and check that the image has changed. My suggestion is that if the scalings of the LoRAs of the transformer were edited here, using the same random changes as in the state dict, we could create an image with these altered alphas. Then below we can assert that this image is identical to images_lora_with_alpha.
My reasoning for this is that a check that the image is unequal is always a bit weaker than an equals check, so the test would be stronger. However, I see how this is extra effort, so totally understand if it's not worth it.
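[Editor's note] To make the suggestion concrete, a rough sketch of what editing the scalings could look like, continuing the test snippet above (pipe and alpha_dict come from that test context). This leans on peft's LoraLayer internals; the scaling/r attributes, the "default_0" adapter name, and the alpha_dict key format are assumptions, not something this PR adds.

# Hedged sketch of the suggested stronger check: mirror the randomized alphas from
# alpha_dict onto the live transformer's LoRA scalings (scaling = alpha / rank),
# regenerate, and assert the image equals images_lora_with_alpha.
from peft.tuners.lora import LoraLayer

adapter_name = "default_0"  # assumption: adapter name assigned when the LoRA was loaded
for name, module in pipe.transformer.named_modules():
    if isinstance(module, LoraLayer) and adapter_name in module.r:
        new_alpha = alpha_dict.get(f"transformer.{name}.alpha")  # assumed key format
        if new_alpha is not None:
            module.scaling[adapter_name] = new_alpha / module.r[adapter_name]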
Sure, happy to give this a try but I still don't understand how we can update the scaling. How are you envisioning that?
Hmm, yes, I checked and it seems there is no easy way. It is possible to pass a scale argument to forward, but this would always be the same value. Up to you if you think this could be worth adding to the test or not.
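[Editor's note] For reference, that single global scale is usually passed through the pipeline call; the joint_attention_kwargs route below is the common diffusers pattern for Flux, shown only to illustrate why it cannot reproduce per-module alphas, and it reuses the pipeline from the usage example above.

# Hedged sketch: one global LoRA scale applied to every layer. Per-module alphas
# cannot be expressed this way, which is the limitation being discussed.
image = pipeline(
    "jon snow eating pizza with ketchup",
    joint_attention_kwargs={"scale": 0.5},
).images[0]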
We could probably configure the alpha value in the LoraConfig and then generate images with that config, but I think it's okay to leave that for now since we're testing it here: #9143. When either of the two PRs is merged, we can revisit. What say?
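[Editor's note] For what it's worth, configuring alpha on the adapter side might look roughly like this; the rank, alpha, and target_modules values are placeholders, not what #9143 uses.

# Hedged sketch: peft's LoraConfig exposes lora_alpha, so a test adapter can be
# created whose scaling (lora_alpha / r) differs from 1. Values are illustrative.
from peft import LoraConfig

lora_config = LoraConfig(
    r=4,
    lora_alpha=8,  # effective scaling of 8 / 4 = 2.0
    target_modules=["to_q", "to_k", "to_v", "to_out.0"],
    init_lora_weights=False,
)
pipeline.transformer.add_adapter(lora_config)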
Sounds good
@yiyixuxu it would be great if you could give this a check.
thanks!
def lora_state_dict(
    cls,
    pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
    return_alphas: bool = False,
Is this argument there to prevent breaking changes? I.e., would lora_state_dict() be used outside our load_lora_weights() method?
yes
Yes. We use lora_state_dict() in the training scripts (an example of usage outside of load_lora_weights()).
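[Editor's note] Concretely, the two call patterns look roughly like this; the tuple return with return_alphas=True reflects a reading of the diff and should be treated as an assumption.

# Hedged sketch: the default keeps the existing single-return behaviour for
# external callers such as training scripts, while return_alphas=True (used
# inside load_lora_weights) also surfaces the parsed alphas.
from diffusers import FluxPipeline

# Existing external usage keeps working unchanged.
state_dict = FluxPipeline.lora_state_dict(
    "TheLastBen/Jon_Snow_Flux_LoRA", weight_name="jon_snow.safetensors"
)

# Opt-in usage that also returns the parsed alphas.
state_dict, network_alphas = FluxPipeline.lora_state_dict(
    "TheLastBen/Jon_Snow_Flux_LoRA", weight_name="jon_snow.safetensors", return_alphas=True
)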
amazing! thank you. also you choose the most mesmerising images for demoing features :D i am now hungry
How are you calculating memorization here? There's a very nice paper from @YuxinWenRick about this topic: https://openreview.net/forum?id=84n3UwkH7b (which also happens to be an ICLR oral).
* support parsing alpha from a flux lora state dict.
* conditional import.
* fix breaking changes.
* safeguard alpha.
* fix
@sayakpaul interestingly this doesn't seem to solve the problem of PEFT LoRAs created w/ Diffusers lacking the alpha property.
Yeah, it's not supposed to. The underlying checkpoints this PR was tested with supposedly weren't trained with PEFT. I opened PRs in the past that would allow parsing the metadata from a given LoRA ckpt (and hence allow loading of alphas), but it didn't seem worth the effort. Will try revisiting later.
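[Editor's note] For context, the metadata route mentioned above would amount to reading the safetensors header, roughly as sketched below; whether a given checkpoint actually records rank/alpha there is an assumption.

# Hedged sketch: safetensors files can carry free-form string metadata in their
# header, which is where a trainer could record rank/alpha for later parsing.
from safetensors import safe_open

with safe_open("pytorch_lora_weights.safetensors", framework="pt") as f:
    metadata = f.metadata()  # None, or an arbitrary str -> str mapping
print(metadata)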
Keeping alpha fixed at 1 really simplifies the LR scaling and some other hyperparam searches, so it's definitely worth having.
What does this PR do?
See https://huggingface.slack.com/archives/C03UQJENJTV/p1724183517221439.
@apolinario can you test with this PR whether the LoRA loads successfully?