-
Notifications
You must be signed in to change notification settings - Fork 8
/
Copy pathmidas_model.py
122 lines (107 loc) · 3.81 KB
/
midas_model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
import cv2
import pandas as pd
import gc
import torch
import lpips
from PIL import Image, ImageOps
import requests
import torch
from torch import nn
from torch.nn import functional as F
import torchvision
import torchvision.transforms as T
import torchvision.transforms.functional as TF
from tqdm import tqdm
from resize_right import resize
from guided_diffusion.script_util import create_model_and_diffusion, model_and_diffusion_defaults
import numpy as np
from numpy import asarray
from midas.dpt_depth import DPTDepthModel
from midas.midas_net import MidasNet
from midas.midas_net_custom import MidasNet_small
from midas.transforms import Resize, NormalizeImage, PrepareForNet
import comfy.model_management
# Registry mapping a MiDaS model-type name (the `midas_model_type` argument of
# init_midas_depth_model) to the weight path passed to the model constructor.
# Left empty here, so any lookup raises KeyError until it is populated.
default_models = dict()
def init_midas_depth_model(midas_model_type="dpt_large", optimize=True):
    """Build a MiDaS depth model and the transform that prepares its input.

    Args:
        midas_model_type: MiDaS variant to build. Supported values are
            "dpt_large", "dpt_hybrid", "dpt_hybrid_nyu", "midas_v21" and
            "midas_v21_small". It is also the key looked up in
            ``default_models`` to obtain the weight path.
        optimize: if truthy and the ComfyUI torch device compares equal to
            ``torch.device("cuda")``, convert the model to channels_last
            memory format and half precision.

    Returns:
        ``(midas_model, midas_transform, net_w, net_h, resize_mode,
        normalization)``: the model in eval mode on the ComfyUI torch
        device, the ``T.Compose`` preprocessing pipeline (Resize ->
        NormalizeImage -> PrepareForNet), the network input width/height,
        the resize strategy name, and the ``NormalizeImage`` step used in
        the pipeline.

    Raises:
        KeyError: ``midas_model_type`` has no entry in ``default_models``.
        NotImplementedError: raised on every call that gets past the weight
            lookup, because model loading is still disabled (see TODO).
        ValueError: ``midas_model_type`` is not a supported variant (only
            reachable once the TODO guard is removed).
    """
    print(f"Initializing MiDaS '{midas_model_type}' depth model...")

    # load network -- `default_models` is only read, so no `global` is needed.
    midas_model_path = default_models[midas_model_type]

    # TODO: loading is not wired up yet. This guard used to be `assert False`,
    # which `python -O` strips, letting execution fall through into the code
    # below; raise explicitly so the guard holds under every interpreter flag.
    raise NotImplementedError(
        f"MiDaS '{midas_model_type}' depth model loading is not implemented yet"
    )

    if midas_model_type in ("dpt_large", "dpt_hybrid", "dpt_hybrid_nyu"):
        # The three DPT variants share size, resize mode and normalization;
        # only the backbone differs (both hybrid variants use "vitb_rn50_384").
        backbone = (
            "vitl16_384" if midas_model_type == "dpt_large" else "vitb_rn50_384"
        )
        midas_model = DPTDepthModel(
            path=midas_model_path,
            backbone=backbone,
            non_negative=True,
        )
        net_w, net_h = 384, 384
        resize_mode = "minimal"
        normalization = NormalizeImage(
            mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]
        )
    elif midas_model_type == "midas_v21":
        midas_model = MidasNet(midas_model_path, non_negative=True)
        net_w, net_h = 384, 384
        resize_mode = "upper_bound"
        normalization = NormalizeImage(
            mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
        )
    elif midas_model_type == "midas_v21_small":
        midas_model = MidasNet_small(
            midas_model_path,
            features=64,
            backbone="efficientnet_lite3",
            exportable=True,
            non_negative=True,
            blocks={'expand': True},
        )
        net_w, net_h = 256, 256
        resize_mode = "upper_bound"
        normalization = NormalizeImage(
            mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
        )
    else:
        # Previously `print(...)` + `assert False`; under `python -O` that left
        # midas_model as None and failed later with an unrelated AttributeError.
        raise ValueError(f"midas_model_type '{midas_model_type}' not implemented")

    midas_transform = T.Compose(
        [
            Resize(
                net_w,
                net_h,
                resize_target=None,
                keep_aspect_ratio=True,
                ensure_multiple_of=32,
                resize_method=resize_mode,
                image_interpolation_method=cv2.INTER_CUBIC,
            ),
            normalization,
            PrepareForNet(),
        ]
    )

    midas_model.eval()
    device = comfy.model_management.get_torch_device()
    # NOTE(review): torch.device equality compares both type and index, so an
    # indexed device such as torch.device("cuda:0") does NOT equal
    # torch.device("cuda") and the fp16 path would be skipped. Kept as-is
    # because the inference code presumably has to cast its input to half to
    # match a half model -- confirm the caller before switching this to
    # `device.type == "cuda"`.
    if optimize and device == torch.device("cuda"):
        midas_model = midas_model.to(memory_format=torch.channels_last)
        midas_model = midas_model.half()
    midas_model = midas_model.to(device)
    print(f"MiDaS '{midas_model_type}' depth model initialized.")
    return midas_model, midas_transform, net_w, net_h, resize_mode, normalization