-
Notifications
You must be signed in to change notification settings - Fork 5
/
main.py
1989 lines (1803 loc) · 83.1 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
#!/usr/bin/env python
# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import functools
import gc
import logging
import math
import os
import random
import shutil
from copy import deepcopy
from datetime import datetime
from pathlib import Path
from typing import List, Optional, Union
import accelerate
import diffusers
import numpy as np
import open_clip
import torch
import torch.nn.functional as F
import torch.utils.checkpoint
import transformers
from accelerate import Accelerator
from accelerate.logging import get_logger
from accelerate.utils import ProjectConfiguration, set_seed
from diffusers import (
AnimateDiffPipeline,
AutoencoderKL,
DDIMScheduler,
DDPMScheduler,
DiffusionPipeline,
DPMSolverMultistepScheduler,
EulerAncestralDiscreteScheduler,
LCMScheduler,
MotionAdapter,
StableDiffusionPipeline,
TextToVideoSDPipeline,
UNet2DConditionModel,
UNet3DConditionModel,
UNetMotionModel,
)
from diffusers.optimization import get_scheduler
# from diffusers.pipelines.animatediff.pipeline_animatediff import tensor2vid
from diffusers.utils import check_min_version, export_to_video, is_wandb_available
from diffusers.utils.import_utils import is_xformers_available
from einops import rearrange, repeat
from huggingface_hub import create_repo, upload_folder
from packaging import version
from peft import LoraConfig, PeftModel, get_peft_model, get_peft_model_state_dict
from safetensors.torch import load_file
from tabulate import tabulate
from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from torchvision.transforms import Normalize, RandomCrop
from tqdm.auto import tqdm
from transformers import AutoTokenizer, CLIPTextModel, PretrainedConfig
from args import parse_args
from dataset.webvid_dataset_wbd import Text2VideoDataset
from models.discriminator_handcraft import (
ProjectedDiscriminator,
get_dino_features,
preprocess_dino_input,
)
from models.spatial_head import IdentitySpatialHead, SpatialHead
from utils.diffusion_misc import *
from utils.dist import dist_init, dist_init_wo_accelerate, get_deepspeed_config
from utils.misc import *
from utils.wandb import setup_wandb
# Token-length cap for prompts; 77 presumably mirrors the CLIP text-encoder
# context size used by the SD-based teachers — confirm at the call sites.
MAX_SEQ_LENGTH = 77

# wandb is an optional dependency: bind the module name only when it is installed.
# log_validation's wandb branch (wandb.Image / wandb.Video) relies on this name.
if is_wandb_available():
    import wandb

_LOG_FORMAT = "%(asctime)s - %(levelname)s - [%(filename)s:%(name)s] - %(message)s"
_LOG_DATEFMT = "%m/%d/%Y %H:%M:%S"
logging.basicConfig(format=_LOG_FORMAT, datefmt=_LOG_DATEFMT, level=logging.INFO)

# accelerate's logger wrapper; call sites pass main_process_only=... to control
# which ranks emit a record.
logger = get_logger(__name__)
def save_to_local(save_dir: str, prompt: str, video):
if len(prompt) > 256:
prompt = prompt[:256]
prompt = prompt.replace(" ", "_")
logger.info(f"Saving images to {save_dir}")
export_to_video(video, os.path.join(save_dir, f"{prompt}.mp4"))
# Prompts rendered by log_validation when the caller does not supply its own.
DEFAULT_VALIDATION_PROMPTS = (
    "portrait photo of a girl, photograph, highly detailed face, depth of field, moody light, golden hour, style by Dan Winters, Russell James, Steve McCurry, centered, extremely detailed, Nikon D850, award winning photography",
    "Self-portrait oil painting, a beautiful cyborg with golden hair, 8k",
    "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k",
    "A photo of beautiful mountain with realistic sunset and blue lake, highly detailed, masterpiece",
    "Cute small corgi sitting in a movie theater eating popcorn, unreal engine.",
    "A Pikachu with an angry expression and red eyes, with lightning around it, hyper realistic style.",
    "A dog is reading a thick book.",
    "Three cats having dinner at a table at new years eve, cinematic shot, 8k.",
    "An astronaut riding a pig, highly realistic dslr photo, cinematic shot.",
)


def _build_validation_scheduler(args, scheduler: str):
    """Load the sampling scheduler named ``scheduler`` from the teacher checkpoint.

    For AnimateDiff the config is overridden with a linear beta schedule, no
    sample clipping and linspace timestep spacing, following
    https://huggingface.co/wangfuyun/AnimateLCM.

    Raises:
        ValueError: If ``scheduler`` is not one of "lcm", "ddim" or "euler".
    """
    scheduler_classes = {
        "lcm": LCMScheduler,
        "ddim": DDIMScheduler,
        "euler": EulerAncestralDiscreteScheduler,
    }
    if scheduler not in scheduler_classes:
        raise ValueError(f"Scheduler {scheduler} is not supported.")

    scheduler_additional_kwargs = {}
    if args.base_model_name == "animatediff":
        scheduler_additional_kwargs["beta_schedule"] = "linear"
        scheduler_additional_kwargs["clip_sample"] = False
        scheduler_additional_kwargs["timestep_spacing"] = "linspace"

    return scheduler_classes[scheduler].from_pretrained(
        args.pretrained_teacher_model,
        subfolder="scheduler",
        **scheduler_additional_kwargs,
    )


def _log_validation_to_trackers(
    accelerator,
    args,
    image_logs,
    step,
    name,
    num_inference_steps,
    disc_gt_images,
    logger_prefix,
):
    """Push validation videos (and optional discriminator GT images) to trackers.

    Supports tensorboard and wandb; any other tracker only gets a warning. This
    is best-effort: any exception is logged and swallowed, so a logging failure
    never aborts training.
    """
    try:
        for tracker in accelerator.trackers:
            if tracker.name == "tensorboard":
                for log in image_logs:
                    validation_prompt = (
                        f"{logger_prefix}{num_inference_steps} steps/"
                        + log["validation_prompt"]
                    )
                    # Each frame becomes one image of an NHWC batch.
                    formatted_images = np.stack(
                        [np.asarray(image) for image in log["video"]]
                    )
                    tracker.writer.add_images(
                        validation_prompt,
                        formatted_images,
                        step,
                        dataformats="NHWC",
                    )
                if disc_gt_images is not None:
                    for i, image in enumerate(disc_gt_images):
                        tracker.writer.add_image(
                            f"discriminator gt image/{i}",
                            image,
                            step,
                            dataformats="HWC",
                        )
            elif tracker.name == "wandb":
                # NOTE(review): keyed on args.use_lora, not log_validation's
                # use_lora parameter — kept as-is so wandb keys do not change.
                key_suffix = "" if args.use_lora else f"/{name}"

                # log image for comparison (first frame of every video)
                formatted_images = [
                    wandb.Image(log["video"][0], caption=log["validation_prompt"])
                    for log in image_logs
                ]
                tracker.log(
                    {
                        f"{logger_prefix}validation image {num_inference_steps} steps{key_suffix}": formatted_images
                    },
                    step=step,
                )

                # log video; wandb.Video expects uint8 (t, c, h, w) and does not
                # support a caption here.
                formatted_video = [
                    wandb.Video(
                        np.transpose(
                            (log["video"] * 255).astype(np.uint8), (0, 3, 1, 2)
                        ),
                        fps=4,
                        format="mp4",
                    )
                    for log in image_logs
                ]
                tracker.log(
                    {
                        f"{logger_prefix}validation video {num_inference_steps} steps{key_suffix}": formatted_video
                    },
                    step=step,
                )

                # log discriminator ground truth images
                if disc_gt_images is not None:
                    formatted_disc_gt_images = [
                        wandb.Image(image, caption=f"discriminator gt image {i}")
                        for i, image in enumerate(disc_gt_images)
                    ]
                    tracker.log(
                        {"discriminator gt images": formatted_disc_gt_images},
                        step=step,
                    )
            else:
                logger.warning(f"image logging not implemented for {tracker.name}")
    except Exception as e:
        logger.error(f"Failed to log images: {e}")


def log_validation(
    vae,
    unet,
    args,
    accelerator,
    weight_dtype,
    step,
    name="target",
    scheduler: str = "lcm",
    num_inference_steps: int = 4,
    add_to_trackers: bool = True,
    use_lora: bool = False,
    disc_gt_images: Optional[List] = None,
    guidance_scale: float = 1.0,
    spatial_head: Optional = None,
    logger_prefix: str = "",
    validation_prompts: Optional[List[str]] = None,
):
    """Render validation videos with the current student UNet and log them.

    Builds an AnimateDiff or ModelScope text-to-video pipeline from
    ``args.pretrained_teacher_model``. The pipeline uses either a copy of
    ``unet`` or, with ``use_lora``, the teacher UNet with the student's LoRA
    weights fused in. It samples one video per prompt. The main process saves
    each video as an mp4 under ``{args.output_dir}/output/{name}-step-{step}``;
    the videos can optionally be pushed to tensorboard/wandb.

    Call this on every process: it contains an ``accelerator.wait_for_everyone()``
    barrier.

    Args:
        vae: VAE passed to the pipeline.
        unet: Student UNet, possibly wrapped by accelerate. A deep copy is used,
            so the training model itself is not modified.
        args: Parsed training arguments.
        accelerator: The ``accelerate.Accelerator`` in use.
        weight_dtype: dtype for the pipeline and autocast.
        step: Global training step, used in the output path and tracker logs.
        name: Tag used in the output directory and wandb keys.
        scheduler: One of "lcm", "ddim" or "euler".
        num_inference_steps: Sampling steps per video.
        add_to_trackers: Whether to log results to the accelerator's trackers.
        use_lora: Whether ``unet`` carries LoRA adapters to fuse into the teacher.
        disc_gt_images: Optional images logged as discriminator ground truth.
        guidance_scale: Classifier-free guidance scale.
        spatial_head: Optional module applied to the latents before decoding.
        logger_prefix: Prefix for tracker keys.
        validation_prompts: Prompts to render. Defaults to
            ``DEFAULT_VALIDATION_PROMPTS``.

    Returns:
        A list of ``{"validation_prompt": str, "video": array}`` dicts, one per
        prompt.

    Raises:
        ValueError: If ``scheduler`` or ``args.base_model_name`` is unsupported.
    """
    logger.info("Running validation... ")
    noise_scheduler = _build_validation_scheduler(args, scheduler)

    if args.base_model_name == "animatediff":
        pipeline_cls = AnimateDiffPipeline
    elif args.base_model_name == "modelscope":
        pipeline_cls = TextToVideoSDPipeline
    else:
        # Previously this fell through and failed later with UnboundLocalError.
        raise ValueError(f"Base model {args.base_model_name} is not supported.")

    # Work on a detached copy so the pipeline's dtype/device changes never touch
    # the weights being trained.
    unet = deepcopy(accelerator.unwrap_model(unet))

    pipeline_kwargs = dict(
        vae=vae,
        scheduler=noise_scheduler,
        revision=args.revision,
        torch_dtype=weight_dtype,
        safety_checker=None,
    )
    if use_lora:
        # Load the teacher UNet from the checkpoint and fuse the student's LoRA into it.
        pipeline = pipeline_cls.from_pretrained(
            args.pretrained_teacher_model, **pipeline_kwargs
        )
        lora_state_dict = get_module_kohya_state_dict(unet, "lora_unet", weight_dtype)
        pipeline.load_lora_weights(lora_state_dict)
        pipeline.fuse_lora()
    else:
        pipeline = pipeline_cls.from_pretrained(
            args.pretrained_teacher_model, unet=unet, **pipeline_kwargs
        )

    pipeline.set_progress_bar_config(disable=True)
    pipeline = pipeline.to(accelerator.device, dtype=weight_dtype)

    if (
        args.enable_xformers_memory_efficient_attention
        and args.base_model_name != "animatediff"
    ):
        if is_xformers_available():
            pipeline.enable_xformers_memory_efficient_attention()
        else:
            logger.warning(
                "xformers is not available. Make sure it is installed correctly"
            )

    if args.seed is None:
        generator = None
    else:
        generator = torch.Generator(device=accelerator.device).manual_seed(args.seed)

    if validation_prompts is None:
        validation_prompts = DEFAULT_VALIDATION_PROMPTS

    save_dir = os.path.join(args.output_dir, "output", f"{name}-step-{step}")
    if accelerator.is_main_process:
        os.makedirs(save_dir, exist_ok=True)
    accelerator.wait_for_everyone()

    image_logs = []
    for prompt in validation_prompts:
        with torch.autocast("cuda", dtype=weight_dtype):
            latents = pipeline(
                prompt=prompt,
                num_inference_steps=num_inference_steps,
                height=args.resolution,
                width=args.resolution,
                generator=generator,
                guidance_scale=guidance_scale,
                output_type="latent",
            ).frames
            if spatial_head is not None:
                latents = spatial_head(latents)
            decoded = pipeline.decode_latents(latents)
            # video should be a tensor of shape (t, h, w, 3), min 0, max 1
            video = tensor2vid(decoded, pipeline.image_processor, output_type="np")[0]

        image_logs.append({"validation_prompt": prompt, "video": video})
        # Every rank renders the same prompts, so only one process writes the
        # files; concurrent writes to the same path could corrupt them.
        if accelerator.is_main_process:
            save_to_local(save_dir, prompt, video)

    if add_to_trackers:
        _log_validation_to_trackers(
            accelerator,
            args,
            image_logs,
            step,
            name,
            num_inference_steps,
            disc_gt_images,
            logger_prefix,
        )

    del pipeline
    del unet
    gc.collect()
    torch.cuda.empty_cache()
    return image_logs
def main(args):
# torch.multiprocessing.set_sharing_strategy("file_system")
dist_init()
if args.report_to == "wandb" and args.hub_token is not None:
raise ValueError(
"You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token."
" Please use `huggingface-cli login` to authenticate with the Hub."
)
setup_wandb()
logging_dir = Path(args.output_dir, args.logging_dir)
accelerator_project_config = ProjectConfiguration(
project_dir=args.output_dir, logging_dir=logging_dir
)
accelerator = Accelerator(
gradient_accumulation_steps=args.gradient_accumulation_steps,
mixed_precision=args.mixed_precision,
log_with=args.report_to,
project_config=accelerator_project_config,
split_batches=True, # It's important to set this to True when using webdataset to get the right number of steps for lr scheduling. If set to False, the number of steps will be devide by the number of processes assuming batches are multiplied by the number of processes
# deepspeed_plugin=deepspeed_plugin,
)
total_batch_size = (
args.train_batch_size
* accelerator.num_processes
* args.gradient_accumulation_steps
)
# Make one log on every process with the configuration for debugging.
logger.info("Printing accelerate state", main_process_only=False)
logger.info(accelerator.state, main_process_only=False)
if accelerator.is_local_main_process:
transformers.utils.logging.set_verbosity_warning()
diffusers.utils.logging.set_verbosity_info()
else:
transformers.utils.logging.set_verbosity_error()
diffusers.utils.logging.set_verbosity_error()
if args.scale_lr:
args.learning_rate = args.learning_rate * total_batch_size / 128
args.disc_learning_rate = (
args.disc_learning_rate * total_batch_size * args.disc_tsn_num_frames / 128
)
logger.info(f"Scaling learning rate to {args.learning_rate}")
logger.info(f"Scaling discriminator learning rate to {args.disc_learning_rate}")
sorted_args = sorted(vars(args).items())
logger.info(
"\n" + tabulate(sorted_args, headers=["key", "value"], tablefmt="rounded_grid"),
main_process_only=True,
)
# If passed along, set the training seed now.
if args.seed is not None:
set_seed(args.seed)
# Handle the repository creation
if accelerator.is_main_process:
if args.output_dir is not None:
os.makedirs(args.output_dir, exist_ok=True)
if args.push_to_hub:
repo_id = create_repo(
repo_id=args.hub_model_id or Path(args.output_dir).name,
exist_ok=True,
token=args.hub_token,
private=True,
).repo_id
try:
accelerator.wait_for_everyone()
except Exception as e:
logger.error(f"Failed to wait for everyone: {e}")
dist_init_wo_accelerate()
accelerator.wait_for_everyone()
# 1. Create the noise scheduler and the desired noise schedule.
try:
noise_scheduler = DDPMScheduler.from_pretrained(
args.pretrained_teacher_model,
subfolder="scheduler",
revision=args.teacher_revision,
rescale_betas_zero_snr=True if args.zero_snr else False,
beta_schedule=args.beta_schedule,
)
except Exception as e:
logger.error(f"Failed to load the noise scheduler: {e}")
logger.info("Switching to online pretrained checkpoint")
args.pretrained_teacher_model = args.online_pretrained_teacher_model
args.motion_adapter_path = args.online_motion_adapter_path
noise_scheduler = DDPMScheduler.from_pretrained(
args.pretrained_teacher_model,
subfolder="scheduler",
revision=args.teacher_revision,
rescale_betas_zero_snr=True if args.zero_snr else False,
beta_schedule=args.beta_schedule,
)
# DDPMScheduler calculates the alpha and sigma noise schedules (based on the alpha bars) for us
alpha_schedule = torch.sqrt(noise_scheduler.alphas_cumprod)
sigma_schedule = torch.sqrt(1 - noise_scheduler.alphas_cumprod)
# Initialize the DDIM ODE solver for distillation.
solver = DDIMSolver(
noise_scheduler.alphas_cumprod.numpy(),
timesteps=noise_scheduler.config.num_train_timesteps,
ddim_timesteps=args.num_ddim_timesteps,
)
# 2. Load tokenizers from SD 1.X/2.X checkpoint.
tokenizer = AutoTokenizer.from_pretrained(
args.pretrained_teacher_model,
subfolder="tokenizer",
revision=args.teacher_revision,
use_fast=False,
)
# 3. Load text encoders from SD 1.X/2.X checkpoint.
# import correct text encoder classes
text_encoder = CLIPTextModel.from_pretrained(
args.pretrained_teacher_model,
subfolder="text_encoder",
revision=args.teacher_revision,
)
# 4. Load VAE from SD 1.X/2.X checkpoint
vae = AutoencoderKL.from_pretrained(
args.pretrained_teacher_model,
subfolder="vae",
revision=args.teacher_revision,
)
# 5. Load teacher U-Net from SD 1.X/2.X checkpoint
if args.base_model_name == "animatediff":
teacher_unet = UNet2DConditionModel.from_pretrained(
args.pretrained_teacher_model,
subfolder="unet",
revision=args.teacher_revision,
)
teacher_motion_adapter = MotionAdapter.from_pretrained(args.motion_adapter_path)
teacher_unet = UNetMotionModel.from_unet2d(teacher_unet, teacher_motion_adapter)
elif args.base_model_name == "modelscope":
teacher_unet = UNet3DConditionModel.from_pretrained(
args.pretrained_teacher_model,
subfolder="unet",
revision=args.teacher_revision,
)
# 5.1 Load DINO
dino = torch.hub.load(
"facebookresearch/dinov2",
"dinov2_vits14",
)
ckpt_path = "weights/dinov2_vits14_pretrain.pth"
state_dict = torch.load(ckpt_path, map_location="cpu")
dino.load_state_dict(state_dict)
logger.info(f"Loaded DINO model from {ckpt_path}")
dino.eval()
# 5.2 Load sentence-level CLIP
open_clip_model, *_ = open_clip.create_model_and_transforms(
"ViT-g-14",
pretrained="weights/open_clip_pytorch_model.bin",
)
open_clip_tokenizer = open_clip.get_tokenizer("ViT-g-14")
# 6. Freeze teacher vae, text_encoder, and teacher_unet
vae.requires_grad_(False)
text_encoder.requires_grad_(False)
teacher_unet.requires_grad_(False)
dino.requires_grad_(False)
open_clip_model.requires_grad_(False)
normalize_fn = Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)
# 7. Create online student U-Net.
# For whole model fine-tuning, this will be updated by the optimizer (e.g.,
# via backpropagation.)
# Add `time_cond_proj_dim` to the student U-Net if `teacher_unet.config.time_cond_proj_dim` is None
if args.use_lora:
if args.base_model_name == "animatediff":
unet = UNet2DConditionModel.from_pretrained(
args.pretrained_teacher_model,
subfolder="unet",
revision=args.teacher_revision,
)
motion_adapter = MotionAdapter.from_pretrained(args.motion_adapter_path)
unet = UNetMotionModel.from_unet2d(unet, motion_adapter)
elif args.base_model_name == "modelscope":
unet = UNet3DConditionModel.from_pretrained(
args.pretrained_teacher_model,
subfolder="unet",
revision=args.teacher_revision,
)
else:
assert (
args.base_model_name == "animatediff"
), f"Please use LoRA for {args.base_model_name}"
time_cond_proj_dim = (
teacher_unet.config.time_cond_proj_dim
if "time_cond_proj_dim" in teacher_unet.config
and teacher_unet.config.time_cond_proj_dim is not None
else args.unet_time_cond_proj_dim
)
if args.base_model_name == "animatediff":
unet = UNetMotionModel.from_config(
teacher_unet.config, time_cond_proj_dim=time_cond_proj_dim
)
# 8. Create target student U-Net. This will be updated via EMA updates (polyak averaging).
# Initialize from (online) unet
target_unet = UNetMotionModel.from_config(
teacher_unet.config, time_cond_proj_dim=time_cond_proj_dim
)
elif args.base_model_name == "modelscope":
raise NotImplementedError
unet = UNet3DConditionModel.from_config(
teacher_unet.config, time_cond_proj_dim=time_cond_proj_dim
)
# 8. Create target student U-Net. This will be updated via EMA updates (polyak averaging).
# Initialize from (online) unet
target_unet = UNet3DConditionModel.from_config(
teacher_unet.config, time_cond_proj_dim=time_cond_proj_dim
)
# load teacher_unet weights into unet
unet.load_state_dict(teacher_unet.state_dict(), strict=False)
target_unet.load_state_dict(unet.state_dict())
target_unet.train()
target_unet.requires_grad_(False)
# freeze non-motion module parameters
for param_name, param in unet.named_parameters():
if "motion_modules" not in param_name.lower():
param.requires_grad_(False)
# count trainable parameters
trainable_params = 0
all_param = 0
for _, param in unet.named_parameters():
num_params = param.numel()
all_param += num_params
if param.requires_grad:
trainable_params += num_params
print(
f"trainable params: {trainable_params:,d} || all params: {all_param:,d} || trainable%: {100 * trainable_params / all_param}"
)
if args.cd_target in ["learn", "hlearn"]:
if args.cd_target == "learn":
spatial_head = SpatialHead(num_channels=4, num_layers=2, kernel_size=1)
target_spatial_head = SpatialHead(
num_channels=4, num_layers=2, kernel_size=1
)
logger.info("Using SpatialHead for spatial head")
elif args.cd_target == "hlearn":
spatial_head = SpatialHead(num_channels=4, num_layers=5, kernel_size=3)
target_spatial_head = SpatialHead(
num_channels=4, num_layers=5, kernel_size=3
)
logger.info("Using SpatialHead for spatial head")
else:
raise ValueError(f"cd_target {args.cd_target} is not supported.")
spatial_head.train()
target_spatial_head.load_state_dict(spatial_head.state_dict())
target_spatial_head.train()
target_spatial_head.requires_grad_(False)
else:
spatial_head = None
target_spatial_head = None
unet.train()
# Check that all trainable models are in full precision
low_precision_error_string = (
" Please make sure to always have all model weights in full float32 precision when starting training - even if"
" doing mixed precision training, copy of the weights should still be float32."
)
if accelerator.unwrap_model(unet).dtype != torch.float32:
raise ValueError(
f"Controlnet loaded as datatype {accelerator.unwrap_model(unet).dtype}. {low_precision_error_string}"
)
# 8. Add LoRA to the student U-Net, only the LoRA projection matrix will be updated by the optimizer.
if args.use_lora:
if args.lora_target_modules is not None:
logger.warning(
"We are currently ignoring the `lora_target_modules` argument. As of now, LoRa does not support Conv3D layers."
)
lora_target_modules = [
module_key.strip() for module_key in args.lora_target_modules.split(",")
]
else:
lora_target_modules = [
"to_q",
"to_k",
"to_v",
"to_out.0",
"proj_in",
"proj_out",
"ff.net.0.proj",
"ff.net.2",
"conv1",
"conv2",
"conv_shortcut",
"downsamplers.0.conv",
"upsamplers.0.conv",
"time_emb_proj",
]
# Currently LoRA does not support Conv3D, thus removing the Conv3D
# layers from the list of target modules.
key_list = []
for name, module in unet.named_modules():
if any([name.endswith(module_key) for module_key in lora_target_modules]):
if args.base_model_name == "modelscope" and not (
"temp" in name and "conv" in name
):
key_list.append(name)
elif args.base_model_name == "animatediff":
key_list.append(name)
lora_config = LoraConfig(
r=args.lora_rank,
target_modules=key_list,
lora_alpha=args.lora_alpha,
lora_dropout=args.lora_dropout,
)
unet = get_peft_model(unet, lora_config)
if (
args.from_pretrained_unet is not None
and args.from_pretrained_unet != "None"
):
# TODO currently only supports LoRA
logger.info(f"Loading pretrained UNet from {args.from_pretrained_unet}")
unet.load_adapter(
args.from_pretrained_unet,
"default",
is_trainable=True,
torch_device="cpu",
)
unet.print_trainable_parameters()
# 8.1. Create discriminator for the student U-Net.
c_dim = 1024
discriminator = ProjectedDiscriminator(
embed_dim=dino.embed_dim, c_dim=c_dim
) # TODO add dino name and patch size
if args.from_pretrained_disc is not None and args.from_pretrained_disc != "None":
try:
disc_state_dict = load_file(
os.path.join(
args.from_pretrained_disc,
"discriminator",
"diffusion_pytorch_model.safetensors",
)
)
discriminator.load_state_dict(disc_state_dict)
logger.info(
f"Loaded pretrained discriminator from {args.from_pretrained_disc}"
)
except Exception as e:
logger.error(f"Failed to load pretrained discriminator: {e}")
discriminator.train()
# 9. Handle mixed precision and device placement
# For mixed precision training we cast all non-trainable weigths to half-precision
# as these weights are only used for inference, keeping weights in full precision is not required.
weight_dtype = torch.float32
if accelerator.mixed_precision == "fp16":
weight_dtype = torch.float16
elif accelerator.mixed_precision == "bf16":
weight_dtype = torch.bfloat16
# Move unet, vae and text_encoder to device and cast to weight_dtype
# The VAE is in float32 to avoid NaN losses.
vae.to(accelerator.device)
if args.pretrained_vae_model_name_or_path is not None:
vae.to(dtype=weight_dtype)
text_encoder.to(accelerator.device, dtype=weight_dtype)
dino.to(accelerator.device, dtype=weight_dtype)
open_clip_model.to(accelerator.device)
# Move teacher_unet to device, optionally cast to weight_dtype
if not args.use_lora:
target_unet.to(accelerator.device)
teacher_unet.to(accelerator.device)
if args.cast_teacher_unet:
teacher_unet.to(dtype=weight_dtype)
if args.cd_target in ["learn", "hlearn"]:
target_spatial_head.to(accelerator.device)
# Also move the alpha and sigma noise schedules to accelerator.device.
alpha_schedule = alpha_schedule.to(accelerator.device)
sigma_schedule = sigma_schedule.to(accelerator.device)
# Move the ODE solver to accelerator.device.
solver = solver.to(accelerator.device)
# 10. Handle saving and loading of checkpoints
# `accelerate` 0.16.0 will have better support for customized saving
if version.parse(accelerate.__version__) >= version.parse("0.16.0"):
# create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
if args.use_lora:
def save_model_hook(models, weights, output_dir):
if accelerator.is_main_process:
unet_ = accelerator.unwrap_model(unet)
lora_state_dict = get_peft_model_state_dict(
unet_, adapter_name="default"
)
StableDiffusionPipeline.save_lora_weights(
os.path.join(output_dir, "unet_lora"), lora_state_dict
)
# save weights in peft format to be able to load them back
unet_.save_pretrained(output_dir)
discriminator_ = accelerator.unwrap_model(discriminator)
discriminator_.save_pretrained(
os.path.join(output_dir, "discriminator")
)
if args.cd_target in ["learn", "hlearn"]:
spatial_head_ = accelerator.unwrap_model(spatial_head)
spatial_head_.save_pretrained(
os.path.join(output_dir, "spatial_head")
)
target_spatial_head_ = accelerator.unwrap_model(
target_spatial_head
)
target_spatial_head_.save_pretrained(
os.path.join(output_dir, "target_spatial_head")
)
for _, model in enumerate(models):
# make sure to pop weight so that corresponding model is not saved again
if len(weights) > 0:
weights.pop()
else:
# only support finetune motion module for AnimateDiff
def save_model_hook(models, weights, output_dir):
if accelerator.is_main_process:
target_unet_ = accelerator.unwrap_model(target_unet)
target_unet_.save_motion_modules(
os.path.join(output_dir, "target_motion_modules")
)
unet_ = accelerator.unwrap_model(unet)
unet_.save_motion_modules(
os.path.join(output_dir, "motion_modules")
)
discriminator_ = accelerator.unwrap_model(discriminator)
discriminator_.save_pretrained(
os.path.join(output_dir, "discriminator")
)
if args.cd_target in ["learn", "hlearn"]:
spatial_head_ = accelerator.unwrap_model(spatial_head)
spatial_head_.save_pretrained(
os.path.join(output_dir, "spatial_head")
)
target_spatial_head_ = accelerator.unwrap_model(
target_spatial_head
)
target_spatial_head_.save_pretrained(
os.path.join(output_dir, "target_spatial_head")
)
for i, model in enumerate(models):
# make sure to pop weight so that corresponding model is not saved again
if len(weights) > 0:
weights.pop()
if args.use_lora:
def load_model_hook(models, input_dir):
# load the LoRA into the model
unet_ = accelerator.unwrap_model(unet)
unet_.load_adapter(
input_dir, "default", is_trainable=True, torch_device="cpu"
)
disc_state_dict = load_file(
os.path.join(
input_dir,
"discriminator",
"diffusion_pytorch_model.safetensors",
)
)
disc_ = accelerator.unwrap_model(discriminator)
disc_.load_state_dict(disc_state_dict)
del disc_state_dict
if args.cd_target in ["learn", "hlearn"]:
spatial_head_state_dict = load_file(
os.path.join(
input_dir,
"spatial_head",
"diffusion_pytorch_model.safetensors",
)
)
spatial_head_ = accelerator.unwrap_model(spatial_head)
spatial_head_.load_state_dict(spatial_head_state_dict)
del spatial_head_state_dict
target_spatial_head_state_dict = load_file(
os.path.join(
input_dir,
"target_spatial_head",
"diffusion_pytorch_model.safetensors",
)
)
target_spatial_head_ = accelerator.unwrap_model(target_spatial_head)
target_spatial_head_.load_state_dict(target_spatial_head_state_dict)
del target_spatial_head_state_dict
for _ in range(len(models)):
# pop models so that they are not loaded again
models.pop()
else:
# only support finetune motion module for AnimateDiff
def load_model_hook(models, input_dir):
target_motion_module = MotionAdapter.from_pretrained(
os.path.join(input_dir, "target_motion_modules")
)
target_unet.load_motion_modules(target_motion_module)
del target_motion_module
student_motion_module = MotionAdapter.from_pretrained(
os.path.join(input_dir, "motion_modules")
)
unet_ = accelerator.unwrap_model(unet)
unet_.load_motion_modules(student_motion_module)
del student_motion_module
state_dict = load_file(
os.path.join(
input_dir,
"discriminator",
"diffusion_pytorch_model.safetensors",
)
)
disc_ = accelerator.unwrap_model(discriminator)
disc_.load_state_dict(state_dict)
del state_dict
for i in range(len(models)):
# pop models so that they are not loaded again
model = models.pop()
# # load diffusers style into model
# load_model = UNet3DConditionModel.from_pretrained(
# input_dir, subfolder="unet"
# )
# model.register_to_config(**load_model.config)
# model.load_state_dict(load_model.state_dict())
# del load_model
accelerator.register_save_state_pre_hook(save_model_hook)
accelerator.register_load_state_pre_hook(load_model_hook)
# 11. Enable optimizations
if (
args.enable_xformers_memory_efficient_attention
and args.base_model_name != "animatediff"
):
if is_xformers_available():
import xformers
xformers_version = version.parse(xformers.__version__)
if xformers_version == version.parse("0.0.16"):
logger.warning(
"xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
)
unet.enable_xformers_memory_efficient_attention()
teacher_unet.enable_xformers_memory_efficient_attention()
if not args.use_lora:
target_unet.enable_xformers_memory_efficient_attention()
else:
logger.warning(
"xformers is not available. Make sure it is installed correctly"
)
# raise ValueError("xformers is not available. Make sure it is installed correctly")
# Enable TF32 for faster training on Ampere GPUs,
# cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
if args.allow_tf32:
torch.backends.cuda.matmul.allow_tf32 = True
if args.gradient_checkpointing:
unet.enable_gradient_checkpointing()
# Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs
if args.use_8bit_adam:
try:
import bitsandbytes as bnb
except ImportError:
raise ImportError(
"To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`."
)
optimizer_class = bnb.optim.AdamW8bit
else:
optimizer_class = torch.optim.AdamW
# 12. Optimizer creation
if args.cd_target in ["learn", "hlearn"]:
unet_params = list(unet.parameters()) + list(spatial_head.parameters())
else:
unet_params = unet.parameters()
optimizer = optimizer_class(
unet_params,
lr=args.learning_rate,
betas=(args.adam_beta1, args.adam_beta2),
weight_decay=args.adam_weight_decay,
eps=args.adam_epsilon,
)
disc_optimizer = optimizer_class(
discriminator.parameters(),
lr=args.disc_learning_rate,
betas=(args.disc_adam_beta1, args.disc_adam_beta2),
weight_decay=args.adam_weight_decay,
eps=args.adam_epsilon,
)
# 13. Dataset creation and data processing
# Here, we compute not just the text embeddings but also the additional embeddings
# needed for the SD XL UNet to operate.
def compute_embeddings(
prompt_batch, proportion_empty_prompts, text_encoder, tokenizer, is_train=True
):
prompt_embeds = encode_prompt(
prompt_batch, text_encoder, tokenizer, proportion_empty_prompts, is_train
)
return prompt_embeds
WEBVID_DATA_SIZE = 2467378
dataset = Text2VideoDataset(
args.dataset_path,
num_train_examples=args.max_train_samples or WEBVID_DATA_SIZE,
per_gpu_batch_size=args.train_batch_size,
global_batch_size=args.train_batch_size * accelerator.num_processes,