-
Notifications
You must be signed in to change notification settings - Fork 11
/
Copy pathvae.py
48 lines (47 loc) · 1.24 KB
/
vae.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
import torch
from diffusers import AutoencoderKL
def get_vae(version, file_path=None, fp16=False):
    """Load a VAE from a checkpoint file or from a default Hugging Face repo.

    Args:
        version: Model family selector; one of "v1", "v2" or "xl".
        file_path: Optional path to a single-file checkpoint. When truthy it
            takes precedence over the default repo and is loaded with
            ``AutoencoderKL.from_single_file``.
        fp16: Request float16 weights. Only honoured when loading from a
            Hugging Face repo; single-file loads are not given a dtype.

    Returns:
        Whatever ``AutoencoderKL.from_single_file`` /
        ``AutoencoderKL.from_pretrained`` returns for the selected source.

    Raises:
        SystemExit: With exit code 1 when ``version`` is not recognised
            (after prompting on stdin).
    """
    if version == "v1":
        image_size, repo = 512, "runwayml/stable-diffusion-v1-5"
    elif version == "v2":
        image_size, repo = 768, "stabilityai/stable-diffusion-2-1"
    elif version == "xl":
        image_size, repo = 1024, "stabilityai/stable-diffusion-xl-base-1.0"
    else:
        # Local import: `exit()` is a helper injected by the `site` module and
        # is absent under `python -S` or in frozen builds; sys.exit always
        # exists and raises the same SystemExit.
        import sys

        try:
            input("Invalid VAE version. Press any key to exit")
        except EOFError:
            # No interactive stdin (pipe, closed stream): skip the pause but
            # still terminate with the intended status code.
            pass
        sys.exit(1)

    if file_path:
        # Checkpoint files are loaded without a dtype; fp16 is ignored here.
        return AutoencoderKL.from_single_file(file_path, image_size=image_size)

    if version == "xl":
        if fp16:
            # Half precision for XL uses a dedicated repo instead of casting
            # the base model's VAE.
            return AutoencoderKL.from_pretrained(
                "madebyollin/sdxl-vae-fp16-fix",
                torch_dtype=torch.float16,
            )
        return AutoencoderKL.from_pretrained(repo, subfolder="vae")

    dtype = torch.float16 if fp16 else torch.float32
    return AutoencoderKL.from_pretrained(
        repo,
        subfolder="vae",
        torch_dtype=dtype,
    )