forked from pharmapsychotic/clip-interrogator
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtungsten_model.py
102 lines (88 loc) · 2.89 KB
/
tungsten_model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
from typing import List
from tungstenkit import BaseIO, Field, Image, Option, define_model
from clip_interrogator import Config, Interrogator
# CLIP backbones this model can load. ViT-L targets Stable Diffusion 1,
# ViT-H targets SD 2, ViT-bigG targets SDXL (see Input.clip_model_name choices,
# which list the same names and should stay in sync with this constant).
CLIP_MODEL_NAMES = [
    "ViT-L-14/openai",
    "ViT-H-14/laion2b_s32b_b79k",
    "ViT-bigG-14/laion2b_s39b_b160k",
]
class Input(BaseIO):
    """Prediction request: an input image plus CLIP-model and prompt-mode options."""

    input_image: Image = Field(description="Input image")
    clip_model_name: str = Option(
        # Default is the first supported backbone ("ViT-L-14/openai").
        default=CLIP_MODEL_NAMES[0],
        # Reference the module-level constant instead of duplicating the list,
        # so the choices cannot drift out of sync with the models post_build caches.
        choices=CLIP_MODEL_NAMES,
        description="Choose ViT-L for Stable Diffusion 1, ViT-H for Stable Diffusion 2, or ViT-bigG for Stable Diffusion XL.",
    )
    mode: str = Option(
        default="best",
        choices=["best", "classic", "fast", "negative"],
        description="Prompt mode (best takes 10-20 seconds, fast takes 1-2 seconds).",
    )
class Output(BaseIO):
    """Prediction response: the generated prompt text."""
    interrogated: str  # prompt produced by the chosen interrogation mode
@define_model(
    input=Input,
    output=Output,
    gpu=True,
    cuda_version="11.8",
    python_version="3.10",
    system_packages=["libgl1-mesa-glx", "libglib2.0-0"],
    python_packages=[
        "safetensors==0.3.3",
        "tqdm==4.66.1",
        "open_clip_torch==2.20.0",
        "accelerate==0.22.0",
        "transformers==4.33.1",
    ],
    batch_size=1,
)
class CLIPInterrogator:
    """Tungsten model that generates a text prompt describing an input image
    using clip-interrogator, with a selectable CLIP backbone and prompt mode."""

    @staticmethod
    def post_build():
        """Download and cache weights for every supported CLIP model.

        Runs at image-build time on CPU so all checkpoints ship inside the image
        and setup() never hits the network.
        """
        ci = Interrogator(
            Config(
                clip_model_name=CLIP_MODEL_NAMES[0],
                clip_model_path="cache",
                device="cpu",
            )
        )
        for clip_model_name in CLIP_MODEL_NAMES:
            ci.config.clip_model_name = clip_model_name
            ci.load_clip_model()

    def setup(self):
        """Load the default CLIP model from the local cache onto the GPU."""
        self.ci = Interrogator(
            Config(
                clip_model_name=CLIP_MODEL_NAMES[0],
                clip_model_path="cache",
                device="cuda:0",
            )
        )

    def predict(self, inputs: List[Input]) -> List[Output]:
        """Run a single prediction (batch_size=1, so only inputs[0] is used).

        Returns a one-element list of Output.

        Raises:
            RuntimeError: if the requested mode is not one of
                best/classic/fast/negative.
        """
        request = inputs[0]  # avoid shadowing the `input` builtin
        mode = request.mode
        # Validate the mode BEFORE the (potentially slow) model switch.
        # Bug fix: the original raised with f"Unknown mode: {ret}", but `ret`
        # is unbound on the unknown-mode path, producing a NameError.
        interrogate_methods = {
            "best": "interrogate",
            "classic": "interrogate_classic",
            "fast": "interrogate_fast",
            "negative": "interrogate_negative",
        }
        if mode not in interrogate_methods:
            raise RuntimeError(f"Unknown mode: {mode}")

        self.switch_model(request.clip_model_name)
        image = request.input_image.to_pil_image()
        ret = getattr(self.ci, interrogate_methods[mode])(image)
        return [Output(interrogated=ret)]

    def switch_model(self, clip_model_name: str):
        """Swap in a different CLIP backbone; no-op if it is already loaded."""
        if clip_model_name != self.ci.config.clip_model_name:
            self.ci.config.clip_model_name = clip_model_name
            self.ci.load_clip_model()