
Handling mixed precision for DreamBooth Flux LoRA training #9565

Merged · 10 commits · Nov 1, 2024
examples/dreambooth/train_dreambooth_lora_flux.py (6 changes: 4 additions & 2 deletions)
@@ -177,7 +177,7 @@ def log_validation(
f"Running validation... \n Generating {args.num_validation_images} images with prompt:"
f" {args.validation_prompt}."
)
- pipeline = pipeline.to(accelerator.device, dtype=torch_dtype)
+ pipeline = pipeline.to(accelerator.device)
Member

Why are we doing this?

We should keep the pipeline's model-level components (such as the text encoders, VAE, etc.) in reduced precision, no?

Contributor Author

The text encoders and VAE are already in reduced precision :)
As I described in the PR description, passing dtype here changes the dtype of the transformer.
For mixed-precision training, the transformer's trainable parameters are upcast to fp32 when training with fp16,
but this call cast them back to fp16, which led to an fp16 unscale error during gradient clipping.
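
A minimal, self-contained sketch of the failure mode described above (my own toy example, not the training script; assumes a CUDA device and fp16 mixed precision): the trainable LoRA parameters are kept in fp32 so that GradScaler can unscale their gradients, and downcasting them to fp16, as the old pipeline.to(accelerator.device, dtype=torch_dtype) call did as a side effect, makes the unscale step raise.

    import torch

    model = torch.nn.Linear(8, 8).cuda().to(torch.float32)    # stands in for the fp32 LoRA layers
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
    scaler = torch.cuda.amp.GradScaler()

    # Side effect of the old validation code: casting the whole pipeline to fp16
    # also downcasts the trainable parameters.
    model.to(torch.float16)

    x = torch.randn(2, 8, device="cuda", dtype=torch.float16)
    loss = model(x).float().sum()
    scaler.scale(loss).backward()
    scaler.unscale_(optimizer)   # ValueError: Attempting to unscale FP16 gradients.

In the training script the same error surfaces inside accelerator.clip_grad_norm_, which unscales gradients before clipping; moving the pipeline to the device without passing a dtype leaves the LoRA parameters in fp32 and avoids it.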

Member

Would something like this work?
#9549 (comment)

Contributor Author

Thank you for the suggestion! But in this thread I was concerned with the unwanted switch from fp32 to fp16 after validation, not with the computation of T5 :)

Member

Ah okay. Can you provide an example command for us to verify this? Maybe @linoytsaban could give it a try?

Member

@icsl-Jeon a friendly reminder :)

Contributor Author

This can be reproduced with any of the launch commands in the README:

    accelerate launch ... --mixed_precision="fp16" ...

I checked the LoRA precision with:

    # Inspect the dtype of the trainable LoRA parameters (e.g. right after validation).
    for name, param in transformer.named_parameters():
        if 'lora' in name:
            print(f"Layer: {name}, dtype: {param.dtype}, requires_grad: {param.requires_grad}")

Hope this helps you reproduce!

pipeline.set_progress_bar_config(disable=True)

# run inference
@@ -1706,7 +1706,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
)

# handle guidance
- if transformer.config.guidance_embeds:
+ if accelerator.unwrap_model(transformer).config.guidance_embeds:
guidance = torch.tensor([args.guidance_scale], device=accelerator.device)
guidance = guidance.expand(model_input.shape[0])
else:
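
For context on this change: after accelerator.prepare(), the transformer may be wrapped (e.g. in DistributedDataParallel under multi-GPU training), and the wrapper does not expose the inner model's config attribute, so it must be read from the unwrapped model. A minimal stand-in sketch (my own illustration, not diffusers or accelerate code):

    import torch

    class Transformer(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.config = {"guidance_embeds": True}

    class Wrapper(torch.nn.Module):  # plays the role of the DDP wrapper added by accelerator.prepare()
        def __init__(self, module):
            super().__init__()
            self.module = module

    wrapped = Wrapper(Transformer())
    print(hasattr(wrapped, "config"))   # False: attribute lookup stops at the wrapper
    print(wrapped.module.config)        # {'guidance_embeds': True}; accelerator.unwrap_model reaches this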
@@ -1819,6 +1819,8 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
# create pipeline
if not args.train_text_encoder:
text_encoder_one, text_encoder_two = load_text_encoders(text_encoder_cls_one, text_encoder_cls_two)
+ text_encoder_one.to(weight_dtype)
+ text_encoder_two.to(weight_dtype)
pipeline = FluxPipeline.from_pretrained(
args.pretrained_model_name_or_path,
vae=vae,