diff --git a/examples/dreambooth/train_dreambooth_lora_sd3.py b/examples/dreambooth/train_dreambooth_lora_sd3.py index 2c66c341f78f..3aad7216f6aa 100644 --- a/examples/dreambooth/train_dreambooth_lora_sd3.py +++ b/examples/dreambooth/train_dreambooth_lora_sd3.py @@ -962,7 +962,7 @@ def encode_prompt( prompt=prompt, device=device if device is not None else text_encoder.device, num_images_per_prompt=num_images_per_prompt, - text_input_ids=text_input_ids_list[i], + text_input_ids=text_input_ids_list[i] if text_input_ids_list else None, ) clip_prompt_embeds_list.append(prompt_embeds) clip_pooled_prompt_embeds_list.append(pooled_prompt_embeds) @@ -976,7 +976,7 @@ def encode_prompt( max_sequence_length, prompt=prompt, num_images_per_prompt=num_images_per_prompt, - text_input_ids=text_input_ids_list[:-1], + text_input_ids=text_input_ids_list[-1] if text_input_ids_list else None, device=device if device is not None else text_encoders[-1].device, ) @@ -1491,6 +1491,9 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers): ) = accelerator.prepare( transformer, text_encoder_one, text_encoder_two, optimizer, train_dataloader, lr_scheduler ) + assert text_encoder_one is not None + assert text_encoder_two is not None + assert text_encoder_three is not None else: transformer, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( transformer, optimizer, train_dataloader, lr_scheduler @@ -1598,7 +1601,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): tokens_three = tokenize_prompt(tokenizer_three, prompts) prompt_embeds, pooled_prompt_embeds = encode_prompt( text_encoders=[text_encoder_one, text_encoder_two, text_encoder_three], - tokenizers=[None, None, tokenizer_three], + tokenizers=[None, None, None], prompt=prompts, max_sequence_length=args.max_sequence_length, text_input_ids_list=[tokens_one, tokens_two, tokens_three], @@ -1608,7 +1611,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): prompt_embeds, pooled_prompt_embeds = encode_prompt( text_encoders=[text_encoder_one, text_encoder_two, text_encoder_three], tokenizers=[None, None, tokenizer_three], - prompt=prompts, + prompt=args.instance_prompt, max_sequence_length=args.max_sequence_length, text_input_ids_list=[tokens_one, tokens_two, tokens_three], ) @@ -1685,10 +1688,12 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): accelerator.backward(loss) if accelerator.sync_gradients: - params_to_clip = itertools.chain( - transformer_lora_parameters, - text_lora_parameters_one, - text_lora_parameters_two if args.train_text_encoder else transformer_lora_parameters, + params_to_clip = ( + itertools.chain( + transformer_lora_parameters, text_lora_parameters_one, text_lora_parameters_two + ) + if args.train_text_encoder + else transformer_lora_parameters ) accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm) @@ -1741,13 +1746,6 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): text_encoder_one, text_encoder_two, text_encoder_three = load_text_encoders( text_encoder_cls_one, text_encoder_cls_two, text_encoder_cls_three ) - else: - text_encoder_three = text_encoder_cls_three.from_pretrained( - args.pretrained_model_name_or_path, - subfolder="text_encoder_3", - revision=args.revision, - variant=args.variant, - ) pipeline = StableDiffusion3Pipeline.from_pretrained( args.pretrained_model_name_or_path, vae=vae, @@ -1767,7 +1765,9 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): pipeline_args=pipeline_args, epoch=epoch, ) - del text_encoder_one, text_encoder_two, text_encoder_three + if not args.train_text_encoder: + del text_encoder_one, text_encoder_two, text_encoder_three + torch.cuda.empty_cache() gc.collect()