-
Notifications
You must be signed in to change notification settings - Fork 16
/
Copy pathsampler_tonemap.py
44 lines (32 loc) · 1.47 KB
/
sampler_tonemap.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
import torch
class ModelSamplerTonemapNoiseTest:
    """ComfyUI node that patches a model's CFG step with Reinhard tonemapping.

    The classifier-free-guidance difference ``cond - uncond`` is split into a
    per-position direction (unit vector along dim 1) and its length.  The
    lengths are compressed with the Reinhard curve ``m * top / (m + top)``,
    where the white point ``top = (3 * std + mean) * multiplier`` is computed
    per batch item.  Lengths far above ``top`` are squashed toward it, while
    lengths far below it pass through almost unchanged.
    """

    @classmethod
    def INPUT_TYPES(cls):
        """Declare the node inputs: the model to patch and the tonemap multiplier."""
        return {
            "required": {
                "model": ("MODEL",),
                "multiplier": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 100.0, "step": 0.01}),
            }
        }

    RETURN_TYPES = ("MODEL",)
    FUNCTION = "patch"
    CATEGORY = "custom_node_experiments"

    @staticmethod
    def _reinhard_tonemap(cond, uncond, cond_scale, multiplier):
        """Return ``uncond`` plus the Reinhard-tonemapped guidance ``cond - uncond``.

        Args:
            cond, uncond: conditional / unconditional tensors of equal shape.
                Dim 0 is treated as the batch and dim 1 as the vector axis
                whose norm is tonemapped.  The original code only worked for
                rank-4 input (presumably ``(batch, channels, H, W)``).  Any
                rank >= 3 now works, because the statistics are taken over
                all non-batch dims.
            cond_scale: CFG scale applied to the tonemapped guidance.
            multiplier: scales the white point ``top``.  ``0`` yields zero
                guidance (the result equals ``uncond``) instead of NaN.

        Note:
            ``torch.std`` is unbiased by default, so an input with a single
            non-batch position still produces NaN (unchanged behavior).
        """
        guidance = cond - uncond
        # Length of the guidance vector at each position along dim 1.  The
        # size-1 axis is kept so it broadcasts back over dim 1.  The epsilon
        # keeps the direction division below finite for zero guidance.
        magnitude = torch.linalg.vector_norm(guidance, dim=1, keepdim=True) + 1e-10
        direction = guidance / magnitude

        stat_dims = tuple(range(1, magnitude.ndim))
        mean = torch.mean(magnitude, dim=stat_dims, keepdim=True)
        std = torch.std(magnitude, dim=stat_dims, keepdim=True)
        top = (std * 3 + mean) * multiplier

        # Reinhard: top * x / (x + 1) with x = magnitude / top.  This is written
        # as magnitude * top / (magnitude + top) so that top == 0 gives 0 rather
        # than inf / inf = NaN.  The denominator is > 0 because magnitude >= 1e-10
        # and top >= 0.
        new_magnitude = magnitude * top / (magnitude + top)
        return uncond + direction * new_magnitude * cond_scale

    def patch(self, model, multiplier):
        """Return a clone of ``model`` whose sampler CFG function is tonemapped.

        Args:
            model: object providing ``clone()`` and
                ``set_model_sampler_cfg_function()`` (presumably a ComfyUI
                ``ModelPatcher``).  The CFG function is installed on the clone,
                not on ``model`` itself.
            multiplier: white-point scale forwarded to ``_reinhard_tonemap``.

        Returns:
            A 1-tuple ``(patched_model,)`` matching ``RETURN_TYPES``.
        """
        def sampler_tonemap_reinhard(args):
            # Only these three keys of the sampler-supplied dict are read.
            return self._reinhard_tonemap(
                args["cond"], args["uncond"], args["cond_scale"], multiplier
            )

        m = model.clone()
        m.set_model_sampler_cfg_function(sampler_tonemap_reinhard)
        return (m, )
# Registry read by ComfyUI: node type name -> node class.
NODE_CLASS_MAPPINGS = dict(
    ModelSamplerTonemapNoiseTest=ModelSamplerTonemapNoiseTest,
)