-
Notifications
You must be signed in to change notification settings - Fork 11
/
Copy pathcomfy_latent_upscaler.py
105 lines (91 loc) · 2.66 KB
/
comfy_latent_upscaler.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
import os
import torch
import torch.nn as nn
from safetensors.torch import load_file
from huggingface_hub import hf_hub_download
class Upscaler(nn.Module):
    """
    Plain convolutional latent upscaler; layout ported from:
    https://github.com/city96/SD-Latent-Upscaler/blob/main/upscaler.py

    The position of every layer inside ``self.sequential`` determines the
    state_dict key names (``sequential.<index>.weight``), so the order built
    here must not change or the published weight files will no longer load.
    """

    version = 2.1  # network revision; also embedded in the weight filename

    def head(self):
        """Input stage: widen chan -> size channels, then upsample by ``fac``."""
        widen = nn.Conv2d(self.chan, self.size, kernel_size=self.krn, padding=self.pad)
        grow = nn.Upsample(scale_factor=self.fac, mode="nearest")
        return [widen, nn.ReLU(), grow, nn.ReLU()]

    def core(self):
        """Body: ``depth`` repetitions of a size -> size conv followed by ReLU."""
        return [
            layer
            for _ in range(self.depth)
            for layer in (
                nn.Conv2d(self.size, self.size, kernel_size=self.krn, padding=self.pad),
                nn.ReLU(),
            )
        ]

    def tail(self):
        """Output stage: project size -> chan channels."""
        return [nn.Conv2d(self.size, self.chan, kernel_size=self.krn, padding=self.pad)]

    def __init__(self, fac, depth=16):
        super().__init__()
        self.fac = fac          # scale factor handed to nn.Upsample
        self.depth = depth      # number of conv+ReLU pairs in the core
        self.chan = 4           # channel count at the network's input and output
        self.size = 64          # hidden conv width
        self.krn, self.pad = 3, 1  # 3x3 kernels with padding 1 keep H/W unchanged
        stack = self.head() + self.core() + self.tail()
        self.sequential = nn.Sequential(*stack)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run ``x`` through the full layer stack and return the upscaled tensor."""
        return self.sequential(x)
class LatentUpscaler:
    """ComfyUI node that upscales a LATENT with a pretrained Upscaler network."""

    def __init__(self):
        pass

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "samples": ("LATENT", ),
                "latent_ver": (["v1", "xl"],),
                "scale_factor": (["1.25", "1.5", "2.0"],),
            }
        }

    RETURN_TYPES = ("LATENT",)
    FUNCTION = "upscale"
    CATEGORY = "latent"

    @staticmethod
    def _weights_path(filename):
        """Return a path to *filename*: ./models next to this file if present, else the HF Hub copy."""
        local = os.path.join(os.path.dirname(os.path.realpath(__file__)), "models", filename)
        if os.path.isfile(local):
            print("LatentUpscaler: Using local model")
            return local
        print("LatentUpscaler: Using HF Hub model")
        return str(hf_hub_download(
            repo_id="city96/SD-Latent-Upscaler",
            filename=filename)
        )

    @staticmethod
    def _scale_mask(mask, scale_factor):
        """Resize a noise mask by ``scale_factor`` with bicubic interpolation.

        Assumes ``mask`` is a 4-D (N, C, H, W) tensor, which ``interpolate``
        requires for bicubic mode. Bicubic filtering overshoots at sharp edges,
        so the result is clamped back into the input mask's own value range.
        """
        scaled = torch.nn.functional.interpolate(
            mask, scale_factor=float(scale_factor), mode="bicubic"
        )
        return scaled.clamp(mask.min().item(), mask.max().item())

    def upscale(self, samples, latent_ver, scale_factor):
        """Upscale a LATENT dict by ``scale_factor`` using the pretrained network.

        Args:
            samples: LATENT dict; ``samples["samples"]`` is the latent tensor and
                ``samples["noise_mask"]`` is resized too when present. All other
                keys are passed through unchanged.
            latent_ver: "v1" or "xl"; selects which weight file is loaded.
            scale_factor: "1.25", "1.5" or "2.0" (a string, as offered by
                INPUT_TYPES); it is also part of the weight filename.

        Returns:
            A 1-tuple containing a shallow copy of ``samples`` with the upscaled
            latent (and the rescaled noise mask, if one was supplied).
        """
        model = Upscaler(scale_factor)
        filename = f"latent-upscaler-v{model.version}_SD{latent_ver}-x{scale_factor}.safetensors"
        model.load_state_dict(load_file(self._weights_path(filename)))
        model.eval()

        latent = samples["samples"]
        # load_state_dict copies into the module's own parameters, so their dtype
        # (not the file's) is what the convs run in; match the input to it.
        weight_dtype = next(model.parameters()).dtype
        model.to(latent.device)
        # The parameters require grad; without no_grad the output would carry
        # an autograd graph that is never used.
        with torch.no_grad():
            upscaled = model(latent.to(weight_dtype)).to(latent.dtype)
        del model

        out = samples.copy()
        out["samples"] = upscaled
        if "noise_mask" in samples:
            # expand the noise mask to match the upscaled latent
            out["noise_mask"] = self._scale_mask(samples["noise_mask"], scale_factor)
        return (out,)
# Node id -> implementation class (presumably read by ComfyUI's custom-node loader).
NODE_CLASS_MAPPINGS = dict(LatentUpscaler=LatentUpscaler)

# Node id -> human-readable label shown for the node.
NODE_DISPLAY_NAME_MAPPINGS = dict(LatentUpscaler="Latent Upscaler")