
[Flux LoRA] support parsing alpha from a flux lora state dict. #9236

Merged: 9 commits merged into main from flux-lora-alpha on Aug 22, 2024

Conversation

@sayakpaul (Member)

What does this PR do?

See https://huggingface.slack.com/archives/C03UQJENJTV/p1724183517221439.

@apolinario can you test with this PR if the LoRA loads successfully?

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@apolinario (Collaborator)

With this PR, loading this LoRA just works!

# Install diffusers at this PR's commit (hash as given in the original comment):
# pip install --upgrade git+https://github.com/huggingface/diffusers.git@c8bca51c06a477d693fb65bd9862e00a67bba5a
from diffusers import AutoPipelineForText2Image
import torch

pipeline = AutoPipelineForText2Image.from_pretrained(
    "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
).to("cuda")
# Load the LoRA weights; alpha is now parsed from the state dict.
pipeline.load_lora_weights("TheLastBen/Jon_Snow_Flux_LoRA", weight_name="jon_snow.safetensors")
image = pipeline("jon snow eating pizza with ketchup").images[0]
image.save("image.jpg")

[Attached image: sample generation of "jon snow eating pizza with ketchup"]

@BenjaminBossan (Member) left a comment:

Nice solution, using the existing mechanism to deal with the alphas.

I have some comments but no blockers.

src/diffusers/loaders/lora_pipeline.py (resolved review thread)
import os
import safetensors.torch

# Load the LoRA weights just serialized to the temporary directory.
state_dict_with_alpha = safetensors.torch.load_file(
    os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")
)
alpha_dict = {}
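The thread below discusses where this goes next; presumably the test then attaches randomized alphas before reloading, roughly along these lines (a sketch, not the exact test code; the key naming is an assumption):

import torch

# Hypothetical continuation: attach a randomized alpha for every LoRA "A"
# matrix so the reloaded adapter gets rescaled per-module.
for key in state_dict_with_alpha:
    if "lora_A" in key:
        module_name = key.split(".lora_A")[0]
        alpha_dict[f"{module_name}.alpha"] = torch.rand(()) + 0.5
state_dict_with_alpha.update(alpha_dict)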
@BenjaminBossan (Member)

Would it also be possible to update the scales in this model according to the alphas below and create an image, and then assert below that this image is identical to images_lora_with_alpha?

@sayakpaul (Member, Author)

I don't fully understand it. Would it be possible to provide a schematic of what you mean?

@BenjaminBossan (Member)

So IIUC, you dump the state dict, then edit the alphas in the state dict, then load the state dict with the edited alphas below and check that the image has changed. My suggestion is that if the scalings of the transformer's LoRAs were also edited here, using the same random changes as in the state dict, we could create an image with these altered alphas. Then, below, we could assert that this image is identical to images_lora_with_alpha.

My reasoning is that a check that two images are unequal is always a bit weaker than an equality check, so the test would be stronger. However, I see how this is extra effort, so I totally understand if it's not worth it.

@sayakpaul (Member, Author)

Sure, happy to give this a try but I still don't understand how we can update the scaling. How are you envisioning that?

@BenjaminBossan (Member)

Hmm, yes, I checked and it seems there is no easy way. It is possible to pass a scale argument to forward, but that would always be the same value. Up to you whether you think this is worth adding to the test.
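(For reference, a rough sketch of what updating the scalings would take; there is no public API for it, so this reaches into PEFT internals. set_lora_alphas is a hypothetical helper, and it relies on PEFT's convention that a LoraLayer's per-adapter scaling equals lora_alpha / r.)

from peft.tuners.lora import LoraLayer

def set_lora_alphas(transformer, alpha_by_module, adapter_name="default"):
    # Each PEFT LoraLayer stores a per-adapter scaling factor equal to
    # lora_alpha / r; overwrite it to mimic the edited alphas.
    for name, module in transformer.named_modules():
        if isinstance(module, LoraLayer) and name in alpha_by_module:
            rank = module.r[adapter_name]
            module.scaling[adapter_name] = alpha_by_module[name] / rank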

@sayakpaul (Member, Author)

We could probably configure the alpha value in the LoraConfig and then generate images with that config, but I think it's okay to leave that for now since we're testing it in #9143. When either of the two PRs is merged, we can revisit. What say?
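(For reference, configuring alpha through PEFT's LoraConfig would look roughly like this; the values and target_modules here are illustrative, not taken from the PR.)

from peft import LoraConfig

# With lora_alpha != r, the adapter's scaling (lora_alpha / r) is non-trivial,
# so the alpha round-trip actually gets exercised.
lora_config = LoraConfig(r=4, lora_alpha=8, target_modules=["to_q", "to_k", "to_v"])
pipeline.transformer.add_adapter(lora_config)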

@BenjaminBossan (Member)

Sounds good

@sayakpaul requested a review from @yiyixuxu on August 21, 2024 at 15:03
@sayakpaul (Member, Author)

@yiyixuxu it would be great if you could give this a check.

@yiyixuxu (Collaborator) left a comment:

thanks!

def lora_state_dict(
    cls,
    pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
    return_alphas: bool = False,
@yiyixuxu (Collaborator)

Is this argument there to prevent breaking changes? I.e., would lora_state_dict() be used outside our load_lora_weights() method?

Contributor

yes

@sayakpaul (Member, Author)

Yes. We use lora_state_dict() in the training scripts (an example of usage outside of load_lora_weights()).
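(A sketch of that standalone usage, reusing the repo and weight name from the example above; with return_alphas=True, a (state_dict, network_alphas) tuple is returned.)

from diffusers import FluxPipeline

# network_alphas may be None when the checkpoint carries no alpha keys.
state_dict, network_alphas = FluxPipeline.lora_state_dict(
    "TheLastBen/Jon_Snow_Flux_LoRA",
    weight_name="jon_snow.safetensors",
    return_alphas=True,
)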

@bghira (Contributor) commented on Aug 21, 2024

Amazing, thank you! Also, you choose the most mesmerising images for demoing features :D I am now hungry.

@sayakpaul merged commit 5090b09 into main on Aug 22, 2024; 18 checks passed.
@sayakpaul deleted the flux-lora-alpha branch on August 22, 2024 at 01:31.

yiyixuxu pushed a commit that referenced this pull request Aug 24, 2024
* support parsing alpha from a flux lora state dict.

* conditional import.

* fix breaking changes.

* safeguard alpha.

* fix
sayakpaul added a commit that referenced this pull request Dec 23, 2024
* support parsing alpha from a flux lora state dict.

* conditional import.

* fix breaking changes.

* safeguard alpha.

* fix
@bghira (Contributor) commented on Jan 28, 2025

@sayakpaul interestingly, this doesn't seem to solve the problem of PEFT LoRAs created with Diffusers lacking the alpha property.

@sayakpaul (Member, Author)

Yeah, it's not supposed to. The underlying checkpoints this PR was tested with were presumably not trained with PEFT.

I opened PRs in the past that would allow parsing the metadata from a given LoRA checkpoint (and hence allow loading of alphas), but it didn't seem worth the effort. Will try revisiting later.

@bghira (Contributor) commented on Jan 28, 2025

Keeping alpha fixed at 1 really simplifies the LR scaling and some other hyperparameter searches, so it's definitely worth having.
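(For background: LoRA scales its update by alpha / r, so pinning alpha removes one hyperparameter that would otherwise interact with the learning rate and rank. A minimal illustration with made-up shapes:)

import torch

r, alpha = 16, 1.0
A = torch.randn(r, 64)       # down-projection
B = torch.zeros(64, r)       # up-projection (zero-initialized)
scaling = alpha / r          # PEFT-style per-adapter scaling factor
delta_w = scaling * (B @ A)  # the update added to the frozen weight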
