
Handling mixed precision for dreambooth flux lora training #9565

Merged (10 commits into huggingface:main on Nov 1, 2024)

Conversation

@icsl-Jeon (Contributor) commented Oct 1, 2024

What does this PR do?

Hello 😄 Thank you for the awesome example!
Here is a PR with the changes that helped me train a DreamBooth LoRA successfully.

Who can review?

@linoytsaban @sayakpaul

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@icsl-Jeon changed the title from "Handling mixed precision and add unwarp" to "Handling mixed precision for dreambooth flux lora training" on Oct 1, 2024
@sayakpaul (Member) left a comment

Thank you! Just a single comment.

@linoytsaban could you also give this a look?

- pipeline = pipeline.to(accelerator.device, dtype=torch_dtype)
+ pipeline = pipeline.to(accelerator.device)
Member:

Why are we doing it?

We should keep the pipeline's model-level components (such as the text encoders, VAE, etc.) in reduced precision, no?

Contributor Author (@icsl-Jeon):

The text encoders and VAE are already in reduced precision :)
As I described in the PR description, this line changes the dtype of the transformer.
For mixed-precision training, the transformer is upcast to fp32 when training in fp16, but this call casts it back to fp16, which leads to the fp16 unscale error during gradient clipping.
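For context, here is a minimal, self-contained sketch of the dtype handling being discussed. It assumes the cast_training_params helper from diffusers.training_utils that the DreamBooth LoRA scripts use for the fp32 upcast; the toy module and parameter names below are illustrative, not the actual Flux transformer:

    import torch
    from diffusers.training_utils import cast_training_params

    # Toy stand-in for the transformer: frozen base weights in fp16 plus a
    # couple of trainable "LoRA" parameters (illustrative names and shapes).
    transformer = torch.nn.Linear(16, 16).to(torch.float16)
    transformer.requires_grad_(False)
    transformer.lora_A = torch.nn.Parameter(torch.zeros(4, 16, dtype=torch.float16))
    transformer.lora_B = torch.nn.Parameter(torch.zeros(16, 4, dtype=torch.float16))

    # With --mixed_precision="fp16", only the trainable params are upcast to fp32
    # so that GradScaler.unscale_() (run before clip_grad_norm_) succeeds.
    cast_training_params(transformer, dtype=torch.float32)
    print({n: p.dtype for n, p in transformer.named_parameters() if p.requires_grad})
    # {'lora_A': torch.float32, 'lora_B': torch.float32}

    # The problem: casting the whole validation pipeline (and hence this module)
    # back to fp16 also downcasts the trainable params ...
    transformer.to(torch.float16)
    print({n: p.dtype for n, p in transformer.named_parameters() if p.requires_grad})
    # {'lora_A': torch.float16, 'lora_B': torch.float16}
    # ... and the next gradient-clipping step then hits the fp16 unscale error.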

Member:

Would something like this work?
#9549 (comment)

Contributor Author (@icsl-Jeon):

Thank you for the suggestion! But in this thread I was interested in the unwanted switch from fp32 to fp16 after validation, not in the computation of T5 :)

Member:

Ah okay. Can you provide an example command for us to verify this? Maybe @linoytsaban could give it a try?

Member:

@icsl-Jeon a friendly reminder :)

Contributor Author (@icsl-Jeon):

This can be reproduced with any of the launch commands in the README:

    accelerate launch ... --mixed_precision="fp16" ...

I checked the LoRA precision with:

    for name, param in transformer.named_parameters():
        if 'lora' in name:
            print(f"Layer: {name}, dtype: {param.dtype}, requires_grad: {param.requires_grad}")

Hope this helps you reproduce!
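For reference, with --mixed_precision="fp16" the trainable LoRA parameters are expected to report dtype: torch.float32 (with requires_grad: True) here; before this change they flipped back to torch.float16 after the first log_validation call, which then caused the fp16 unscale error during gradient clipping.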

@sayakpaul (Member) commented:

> Avoid dtype change for transformer after log_validation (especially for fp16). For mixed-precision training, the original code upcasts fp16 to fp32. However, after switching the pipeline dtype in log_validation, the transformer dtype returns to fp16, which can lead to #6442. Actually, I had this problem when I used the fp16 option. (For some reason, T5 yielded NaN outputs in bf16, which is why I came to use fp16.)

We could avoid it by running inference in autocast, no? Here's an example:

autocast_ctx = torch.autocast(accelerator.device.type)
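For illustration, here is a sketch of how that context manager could wrap the validation inference instead of moving the pipeline with an explicit dtype; pipeline, accelerator, args, prompt, and generator stand in for the variables already defined in log_validation, so this is not the exact script code:

    import torch

    # Move to the device without forcing a dtype, so fp32 trainable params are untouched.
    pipeline = pipeline.to(accelerator.device)

    # Run validation inference under autocast instead of casting the weights to fp16.
    autocast_ctx = torch.autocast(accelerator.device.type)
    with autocast_ctx:
        images = [
            pipeline(prompt=prompt, generator=generator).images[0]
            for _ in range(args.num_validation_images)
        ]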


@sayakpaul (Member) left a comment

LGTM actually!

@linoytsaban could you also review?

@icsl-Jeon (Contributor Author) commented:

@linoytsaban thank you in advance

@linoytsaban (Collaborator) left a comment

Thanks @icsl-Jeon, LGTM!

@icsl-Jeon (Contributor Author) commented:

Is there any action needed from me for the merge?

@sayakpaul (Member) commented:

Will merge this once the CI run is complete. Thanks a ton!

linoytsaban added two commits to linoytsaban/diffusers that referenced this pull request on Oct 15, 2024
linoytsaban added a commit that referenced this pull request Oct 16, 2024
* add latent caching + smol updates

* update license

* replace with free_memory

* add --upcast_before_saving to allow saving transformer weights in lower precision

* fix models to accumulate

* fix mixed precision issue as proposed in #9565

* smol update to readme

* style

* fix caching latents

* style

* add tests for latent caching

* style

* fix latent caching

---------

Co-authored-by: Sayak Paul <[email protected]>
linoytsaban added a commit that referenced this pull request Oct 28, 2024
… + small bug fix (#9646)

* make lora target modules configurable and change the default

* style

* make lora target modules configurable and change the default

* fix bug when using prodigy and training te

* fix mixed precision training as  proposed in #9565 for full dreambooth as well

* add test and notes

* style

* address sayaks comments

* style

* fix test

---------

Co-authored-by: Sayak Paul <[email protected]>
@sayakpaul sayakpaul merged commit 3deed72 into huggingface:main Nov 1, 2024
8 checks passed
@sayakpaul (Member) commented:

Thanks for your contributions!

a-r-r-o-w pushed a commit that referenced this pull request Nov 1, 2024
Handling mixed precision and add unwarp

Co-authored-by: Sayak Paul <[email protected]>
Co-authored-by: Linoy Tsaban <[email protected]>
sayakpaul added a commit that referenced this pull request Dec 23, 2024 (same commit message as the Oct 16 commit above)
sayakpaul added a commit that referenced this pull request Dec 23, 2024 (same commit message as the Oct 28 commit above)
sayakpaul added a commit that referenced this pull request Dec 23, 2024 ("Handling mixed precision and add unwarp", same as the Nov 1 commit above)