Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SD3 LoRA Training] Fix errors when not training text encoders #8743

Merged
merged 6 commits into from
Jul 2, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 16 additions & 16 deletions examples/dreambooth/train_dreambooth_lora_sd3.py
Original file line number Diff line number Diff line change
Expand Up @@ -962,7 +962,7 @@ def encode_prompt(
prompt=prompt,
device=device if device is not None else text_encoder.device,
num_images_per_prompt=num_images_per_prompt,
text_input_ids=text_input_ids_list[i],
text_input_ids=text_input_ids_list[i] if text_input_ids_list else None,
)
clip_prompt_embeds_list.append(prompt_embeds)
clip_pooled_prompt_embeds_list.append(pooled_prompt_embeds)
Expand All @@ -976,7 +976,7 @@ def encode_prompt(
max_sequence_length,
prompt=prompt,
num_images_per_prompt=num_images_per_prompt,
text_input_ids=text_input_ids_list[:-1],
text_input_ids=text_input_ids_list[-1] if text_input_ids_list else None,
device=device if device is not None else text_encoders[-1].device,
)

Expand Down Expand Up @@ -1491,6 +1491,9 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers):
) = accelerator.prepare(
transformer, text_encoder_one, text_encoder_two, optimizer, train_dataloader, lr_scheduler
)
assert text_encoder_one is not None
assert text_encoder_two is not None
assert text_encoder_three is not None
else:
transformer, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
transformer, optimizer, train_dataloader, lr_scheduler
Expand Down Expand Up @@ -1598,7 +1601,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
tokens_three = tokenize_prompt(tokenizer_three, prompts)
prompt_embeds, pooled_prompt_embeds = encode_prompt(
text_encoders=[text_encoder_one, text_encoder_two, text_encoder_three],
tokenizers=[None, None, tokenizer_three],
tokenizers=[None, None, None],
prompt=prompts,
max_sequence_length=args.max_sequence_length,
text_input_ids_list=[tokens_one, tokens_two, tokens_three],
Expand All @@ -1608,7 +1611,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
prompt_embeds, pooled_prompt_embeds = encode_prompt(
text_encoders=[text_encoder_one, text_encoder_two, text_encoder_three],
tokenizers=[None, None, tokenizer_three],
prompt=prompts,
prompt=args.instance_prompt,
max_sequence_length=args.max_sequence_length,
text_input_ids_list=[tokens_one, tokens_two, tokens_three],
)
Expand Down Expand Up @@ -1685,10 +1688,12 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):

accelerator.backward(loss)
if accelerator.sync_gradients:
params_to_clip = itertools.chain(
transformer_lora_parameters,
text_lora_parameters_one,
text_lora_parameters_two if args.train_text_encoder else transformer_lora_parameters,
params_to_clip = (
itertools.chain(
transformer_lora_parameters, text_lora_parameters_one, text_lora_parameters_two
)
if args.train_text_encoder
else transformer_lora_parameters
)
accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)

Expand Down Expand Up @@ -1741,13 +1746,6 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
text_encoder_one, text_encoder_two, text_encoder_three = load_text_encoders(
text_encoder_cls_one, text_encoder_cls_two, text_encoder_cls_three
)
else:
text_encoder_three = text_encoder_cls_three.from_pretrained(
args.pretrained_model_name_or_path,
subfolder="text_encoder_3",
revision=args.revision,
variant=args.variant,
)
pipeline = StableDiffusion3Pipeline.from_pretrained(
args.pretrained_model_name_or_path,
vae=vae,
Expand All @@ -1767,7 +1765,9 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
pipeline_args=pipeline_args,
epoch=epoch,
)
del text_encoder_one, text_encoder_two, text_encoder_three
if not args.train_text_encoder:
del text_encoder_one, text_encoder_two, text_encoder_three

torch.cuda.empty_cache()
gc.collect()

Expand Down
Loading