[SD3 LoRA Training] Fix errors when not training text encoders (#8743)
* fix

* fix things.

Co-authored-by: Linoy Tsaban <[email protected]>

* remove patch

* apply suggestions

---------

Co-authored-by: Linoy Tsaban <[email protected]>
Co-authored-by: sayakpaul <[email protected]>
Co-authored-by: Linoy Tsaban <[email protected]>
4 people committed Dec 23, 2024
1 parent 7c25331 commit 4d12d76
Showing 1 changed file with 16 additions and 16 deletions.
32 changes: 16 additions & 16 deletions examples/dreambooth/train_dreambooth_lora_sd3.py
@@ -962,7 +962,7 @@ def encode_prompt(
             prompt=prompt,
             device=device if device is not None else text_encoder.device,
             num_images_per_prompt=num_images_per_prompt,
-            text_input_ids=text_input_ids_list[i],
+            text_input_ids=text_input_ids_list[i] if text_input_ids_list else None,
         )
         clip_prompt_embeds_list.append(prompt_embeds)
         clip_pooled_prompt_embeds_list.append(pooled_prompt_embeds)
@@ -976,7 +976,7 @@ def encode_prompt(
         max_sequence_length,
         prompt=prompt,
         num_images_per_prompt=num_images_per_prompt,
-        text_input_ids=text_input_ids_list[:-1],
+        text_input_ids=text_input_ids_list[-1] if text_input_ids_list else None,
         device=device if device is not None else text_encoders[-1].device,
     )

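The two hunks above guard the `text_input_ids` lookups in `encode_prompt`: when the text encoders are kept frozen, callers may not pass pre-tokenized ids, so `text_input_ids_list` can be `None` or empty and indexing it would raise. The second hunk also swaps the slice `[:-1]` for the index `[-1]`, so the T5 path receives its own token ids rather than the CLIP encoders' ids. A minimal standalone sketch of the guarded-lookup pattern; the function name `pick_token_ids` and the toy data are made up for illustration and are not part of the script:

# Illustrative only: the fallback pattern used by the fix, shown with toy data.
def pick_token_ids(text_input_ids_list, i):
    # Return the i-th entry when token ids were passed in, otherwise None
    # (e.g. when prompt embeddings are precomputed and the encoders stay frozen).
    return text_input_ids_list[i] if text_input_ids_list else None

print(pick_token_ids([[101, 102], [201, 202]], 1))  # [201, 202]
print(pick_token_ids([], 1))                         # None
print(pick_token_ids(None, 1))                       # None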
@@ -1491,6 +1491,9 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers):
         ) = accelerator.prepare(
             transformer, text_encoder_one, text_encoder_two, optimizer, train_dataloader, lr_scheduler
         )
+        assert text_encoder_one is not None
+        assert text_encoder_two is not None
+        assert text_encoder_three is not None
     else:
         transformer, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
             transformer, optimizer, train_dataloader, lr_scheduler
@@ -1598,7 +1601,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
                         tokens_three = tokenize_prompt(tokenizer_three, prompts)
                         prompt_embeds, pooled_prompt_embeds = encode_prompt(
                             text_encoders=[text_encoder_one, text_encoder_two, text_encoder_three],
-                            tokenizers=[None, None, tokenizer_three],
+                            tokenizers=[None, None, None],
                             prompt=prompts,
                             max_sequence_length=args.max_sequence_length,
                             text_input_ids_list=[tokens_one, tokens_two, tokens_three],
@@ -1608,7 +1611,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
                         prompt_embeds, pooled_prompt_embeds = encode_prompt(
                             text_encoders=[text_encoder_one, text_encoder_two, text_encoder_three],
                             tokenizers=[None, None, tokenizer_three],
-                            prompt=prompts,
+                            prompt=args.instance_prompt,
                             max_sequence_length=args.max_sequence_length,
                             text_input_ids_list=[tokens_one, tokens_two, tokens_three],
                         )
@@ -1685,10 +1688,12 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):

                 accelerator.backward(loss)
                 if accelerator.sync_gradients:
-                    params_to_clip = itertools.chain(
-                        transformer_lora_parameters,
-                        text_lora_parameters_one,
-                        text_lora_parameters_two if args.train_text_encoder else transformer_lora_parameters,
+                    params_to_clip = (
+                        itertools.chain(
+                            transformer_lora_parameters, text_lora_parameters_one, text_lora_parameters_two
+                        )
+                        if args.train_text_encoder
+                        else transformer_lora_parameters
                     )
                     accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)

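The gradient-clipping hunk above turns the parameter selection into a conditional expression: the three LoRA parameter lists are chained only when `--train_text_encoder` is set, otherwise only the transformer's LoRA parameters are clipped; the old code referenced the text-encoder LoRA parameter lists unconditionally and chained `transformer_lora_parameters` in twice as a fallback. A small self-contained sketch of the same select-then-clip pattern, with made-up parameter lists standing in for the script's:

import itertools
import torch

# Hypothetical stand-ins for the script's LoRA parameter lists.
transformer_lora_parameters = [torch.nn.Parameter(torch.randn(4, 4))]
text_lora_parameters_one = [torch.nn.Parameter(torch.randn(4, 4))]
text_lora_parameters_two = [torch.nn.Parameter(torch.randn(4, 4))]
train_text_encoder = False  # mirrors args.train_text_encoder

# Give every parameter a gradient so clipping has something to act on.
loss = sum(
    (p ** 2).sum()
    for p in transformer_lora_parameters + text_lora_parameters_one + text_lora_parameters_two
)
loss.backward()

# Chain all three lists only when the text encoders are trained;
# otherwise clip just the transformer's LoRA parameters.
params_to_clip = (
    itertools.chain(transformer_lora_parameters, text_lora_parameters_one, text_lora_parameters_two)
    if train_text_encoder
    else transformer_lora_parameters
)
total_norm = torch.nn.utils.clip_grad_norm_(params_to_clip, max_norm=1.0)
print(total_norm)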
@@ -1741,13 +1746,6 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
                     text_encoder_one, text_encoder_two, text_encoder_three = load_text_encoders(
                         text_encoder_cls_one, text_encoder_cls_two, text_encoder_cls_three
                     )
-                else:
-                    text_encoder_three = text_encoder_cls_three.from_pretrained(
-                        args.pretrained_model_name_or_path,
-                        subfolder="text_encoder_3",
-                        revision=args.revision,
-                        variant=args.variant,
-                    )
                 pipeline = StableDiffusion3Pipeline.from_pretrained(
                     args.pretrained_model_name_or_path,
                     vae=vae,
@@ -1767,7 +1765,9 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
                     pipeline_args=pipeline_args,
                     epoch=epoch,
                 )
-                del text_encoder_one, text_encoder_two, text_encoder_three
+                if not args.train_text_encoder:
+                    del text_encoder_one, text_encoder_two, text_encoder_three

                 torch.cuda.empty_cache()
                 gc.collect()

