feat: Shared Diffusion Model Loader (Inspire)
feat: Shared Text Encoder Loader (Inspire)

#205
ltdrdata committed Jan 15, 2025
1 parent bcbed07 commit 0284075
Showing 4 changed files with 179 additions and 2 deletions.
3 changes: 3 additions & 0 deletions README.md
@@ -147,6 +147,9 @@ This repository offers various extension nodes for ComfyUI. Nodes here have diff
* `Shared Checkpoint Loader (Inspire)`: When a checkpoint is loaded through this loader, it is automatically cached in the backend cache. If it is already cached, it is retrieved from the cache instead of being loaded anew. (A minimal sketch of this caching pattern follows this list.)
  * When `key_opt` is empty, the `ckpt_name` is used as the cache key. The cache key output can be used to delete the cached entry via `Remove Backend Data (Inspire)`.
  * This node avoids reloading checkpoints when switching between workflows.
* `Shared Diffusion Model Loader (Inspire)`: Similar to `Shared Checkpoint Loader (Inspire)`, but loads diffusion models instead of checkpoints.
* `Shared Text Encoder Loader (Inspire)`: Similar to `Shared Checkpoint Loader (Inspire)`, but loads text encoder models instead of checkpoints.
  * This node also functions as a unified node for `CLIPLoader`, `DualCLIPLoader`, and `TripleCLIPLoader`.
* `Stable Cascade Checkpoint Loader (Inspire)`: This node provides a feature that allows you to load the `stage_b` and `stage_c` checkpoints of Stable Cascade at once, and it also provides a backend caching feature, optionally.
* `Is Cached (Inspire)`: Returns whether the cache exists.
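
The shared loaders above all follow the same backend-cache pattern. Below is a minimal sketch of that pattern, assuming a plain dict in place of the pack's `TaggedCache`/`update_cache` internals (see `inspire/backend_support.py` in this commit); `load_shared` and `load_fn` are hypothetical names.

```python
# Minimal sketch of the shared-loader caching pattern (simplified; the real
# nodes use the Inspire Pack backend cache, not a module-level dict).
_cache = {}

def load_shared(name, key_opt="", mode="Auto", load_fn=None):
    if mode == "Read Only" and key_opt.strip() == "":
        raise ValueError("key_opt cannot be omitted if mode is 'Read Only'")
    key = key_opt.strip() or name            # empty key_opt: derive the key from the name
    if key not in _cache or mode == "Override Cache":
        _cache[key] = load_fn(name)          # cache miss or forced reload: load and store
    return _cache[key], key                  # cache hit: reuse without reloading from disk
```

Repeated executions with the same key skip the disk load entirely; `Override Cache` forces a reload, and `Read Only` requires an explicit key.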

2 changes: 1 addition & 1 deletion __init__.py
@@ -7,7 +7,7 @@

import importlib

-version_code = [1, 9, 1]
+version_code = [1, 10]
version_str = f"V{version_code[0]}.{version_code[1]}" + (f'.{version_code[2]}' if len(version_code) > 2 else '')
print(f"### Loading: ComfyUI-Inspire-Pack ({version_str})")

174 changes: 174 additions & 0 deletions inspire/backend_support.py
@@ -8,6 +8,8 @@

from .libs.utils import TaggedCache, any_typ

import logging

root_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
settings_file = os.path.join(root_dir, 'cache_settings.json')
try:
@@ -401,6 +403,174 @@ def IS_CHANGED(ckpt_name, key_opt, mode='Auto'):
return (None, cache_weak_hash(key))


class LoadDiffusionModelShared(nodes.UNETLoader):
@classmethod
def INPUT_TYPES(s):
return {"required": { "model_name": (folder_paths.get_filename_list("diffusion_models"), {"tooltip": "Diffusion Model Name"}),
"weight_dtype": (["default", "fp8_e4m3fn", "fp8_e4m3fn_fast", "fp8_e5m2"],),
"key_opt": ("STRING", {"multiline": False, "placeholder": "If empty, use 'model_name' as the key."}),
"mode": (['Auto', 'Override Cache', 'Read Only'],),
}
}
RETURN_TYPES = ("MODEL", "STRING")
RETURN_NAMES = ("model", "cache key")

FUNCTION = "doit"

CATEGORY = "InspirePack/Backend"

def doit(self, model_name, weight_dtype, key_opt, mode='Auto'):
if mode == 'Read Only':
if key_opt.strip() == '':
raise Exception("[LoadDiffusionModelShared] key_opt cannot be omit if mode is 'Read Only'")
key = key_opt.strip()
elif key_opt.strip() == '':
key = f"{model_name}_{weight_dtype}"
else:
key = key_opt.strip()

if key not in cache or mode == 'Override Cache':
model = self.load_unet(model_name, weight_dtype)[0]
update_cache(key, "diffusion", (False, model))
print(f"[Inspire Pack] LoadDiffusionModelShared: diffusion model '{model_name}' is cached to '{key}'.")
else:
_, (_, model) = cache[key]
print(f"[Inspire Pack] LoadDiffusionModelShared: Cached diffusion model '{key}' is loaded. (Loading skip)")

return model, key

@staticmethod
def IS_CHANGED(model_name, weight_dtype, key_opt, mode='Auto'):
if mode == 'Read Only':
if key_opt.strip() == '':
raise Exception("[LoadDiffusionModelShared] key_opt cannot be omit if mode is 'Read Only'")
key = key_opt.strip()
elif key_opt.strip() == '':
key = f"{model_name}_{weight_dtype}"
else:
key = key_opt.strip()

if mode == 'Read Only':
return None, cache_weak_hash(key)
elif mode == 'Override Cache':
return model_name, key

return None, cache_weak_hash(key)
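# The key-resolution rule shared by doit() and IS_CHANGED() above, restated as a
# hypothetical standalone helper for illustration (not part of this commit):
#
#     def resolve_key(model_name, weight_dtype, key_opt, mode='Auto'):
#         if mode == 'Read Only' and key_opt.strip() == '':
#             raise ValueError("key_opt cannot be omitted if mode is 'Read Only'")
#         return key_opt.strip() or f"{model_name}_{weight_dtype}"
#
#     resolve_key("model.safetensors", "fp8_e4m3fn", "")     # -> "model.safetensors_fp8_e4m3fn"
#     resolve_key("model.safetensors", "default", "my_key")  # -> "my_key"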


class LoadTextEncoderShared:
@classmethod
def INPUT_TYPES(s):
return {"required": { "model_name1": (folder_paths.get_filename_list("text_encoders"), ),
"model_name2": (["None"] + folder_paths.get_filename_list("text_encoders"), ),
"model_name3": (["None"] + folder_paths.get_filename_list("text_encoders"), ),
"type": (["stable_diffusion", "stable_cascade", "sd3", "stable_audio", "mochi", "ltxv", "pixart", "cosmos", "sdxl", "flux", "hunyuan_video"], ),
"key_opt": ("STRING", {"multiline": False, "placeholder": "If empty, use 'model_name' as the key."}),
"mode": (['Auto', 'Override Cache', 'Read Only'],),
},
"optional": { "device": (["default", "cpu"], {"advanced": True}), }
}
RETURN_TYPES = ("CLIP", "STRING")
RETURN_NAMES = ("clip", "cache key")

FUNCTION = "doit"

CATEGORY = "InspirePack/Backend"

DESCRIPTION = \
("[Recipes single]\n"
"stable_diffusion: clip-l\n"
"stable_cascade: clip-g\n"
"sd3: t5 / clip-g / clip-l\n"
"stable_audio: t5\n"
"mochi: t5\n"
"cosmos: old t5 xxl\n\n"
"[Recipes dual]\n"
"sdxl: clip-l, clip-g\n"
"sd3: clip-l, clip-g / clip-l, t5 / clip-g, t5\n"
"flux: clip-l, t5\n\n"
"[Recipes triple]\n"
"sd3: clip-l, clip-g, t5")

def doit(self, model_name1, model_name2, model_name3, type, key_opt, mode='Auto', device="default"):
if mode == 'Read Only':
if key_opt.strip() == '':
raise Exception("[LoadTextEncoderShared] key_opt cannot be omit if mode is 'Read Only'")
key = key_opt.strip()
elif key_opt.strip() == '':
key = model_name1
if model_name2 != "None":
key += f"_{model_name2}"
if model_name3 != "None":
key += f"_{model_name3}"
key += f"_{type}_{device}"
else:
key = key_opt.strip()

if key not in cache or mode == 'Override Cache':
if model_name2 != "None" and model_name3 != "None": # triple text encoder
if len({model_name1, model_name2, model_name3}) < 3:
logging.error("[LoadTextEncoderShared] The same model has been selected multiple times.")
raise ValueError("The same model has been selected multiple times.")

if type not in ["sd3"]:
logging.error("[LoadTextEncoderShared] Currently, the triple text encoder is only supported in `sd3`.")
raise ValueError("Currently, the triple text encoder is only supported in `sd3`.")

res = nodes.NODE_CLASS_MAPPINGS["TripleCLIPLoader"]().load_clip(model_name1, model_name2, model_name3)[0]

elif model_name2 != "None" or model_name3 != "None": # dual text encoder
second_model = model_name2 if model_name2 != "None" else model_name3

if model_name1 == second_model:
logging.error("[LoadTextEncoderShared] You have selected the same model for both.")
raise ValueError("[LoadTextEncoderShared] You have selected the same model for both.")

if type not in ["sdxl", "sd3", "flux", "hunyuan_video"]:
logging.error("[LoadTextEncoderShared] Currently, the triple text encoder is only supported in `sdxl, sd3, flux, hunyuan_video`.")
raise ValueError("Currently, the triple text encoder is only supported in `sdxl, sd3, flux, hunyuan_video`.")

res = nodes.NODE_CLASS_MAPPINGS["DualCLIPLoader"]().load_clip(model_name1, second_model, type=type, device=device)[0]

else: # single text encoder
if type not in ["stable_diffusion", "stable_cascade", "sd3", "stable_audio", "mochi", "ltxv", "pixart", "cosmos"]:
logging.error("[LoadTextEncoderShared] Currently, the single text encoder is only supported in `stable_diffusion, stable_cascade, sd3, stable_audio, mochi, ltxv, pixart, cosmos`.")
raise ValueError("Currently, the single text encoder is only supported in `stable_diffusion, stable_cascade, sd3, stable_audio, mochi, ltxv, pixart, cosmos`.")

res = nodes.NODE_CLASS_MAPPINGS["CLIPLoader"]().load_clip(model_name1, type=type, device=device)[0]

update_cache(key, "diffusion", (False, res))
print(f"[Inspire Pack] LoadTextEncoderShared: text encoder model set is cached to '{key}'.")
else:
_, (_, res) = cache[key]
print(f"[Inspire Pack] LoadTextEncoderShared: Cached text encoder model set '{key}' is loaded. (Loading skip)")

return res, key

@staticmethod
def IS_CHANGED(model_name1, model_name2, model_name3, type, key_opt, mode='Auto', device="default"):
if mode == 'Read Only':
if key_opt.strip() == '':
raise Exception("[LoadTextEncoderShared] key_opt cannot be omit if mode is 'Read Only'")
key = key_opt.strip()
elif key_opt.strip() == '':
key = model_name1
if model_name2 != "None":
key += f"_{model_name2}"
if model_name3 != "None":
key += f"_{model_name3}"
key += f"_{type}_{device}"
else:
key = key_opt.strip()

if mode == 'Read Only':
return None, cache_weak_hash(key)
elif mode == 'Override Cache':
return f"{model_name1}_{model_name2}_{model_name3}_{type}_{device}", key

return None, cache_weak_hash(key)
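# Hypothetical programmatic usage (assumes a running ComfyUI environment and that
# the illustrative filenames exist under models/text_encoders):
#
#     clip, key = LoadTextEncoderShared().doit(
#         model_name1="clip_l.safetensors", model_name2="t5xxl_fp16.safetensors",
#         model_name3="None", type="flux", key_opt="", mode="Auto")
#     # A second call with the same inputs returns the cached CLIP without reloading it.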


class StableCascade_CheckpointLoader:
@classmethod
def INPUT_TYPES(s):
@@ -558,6 +728,8 @@ def doit(self, value, mode, unique_id):
"RemoveBackendDataNumberKey //Inspire": RemoveBackendDataNumberKey,
"ShowCachedInfo //Inspire": ShowCachedInfo,
"CheckpointLoaderSimpleShared //Inspire": CheckpointLoaderSimpleShared,
"LoadDiffusionModelShared //Inspire": LoadDiffusionModelShared,
"LoadTextEncoderShared //Inspire": LoadTextEncoderShared,
"StableCascade_CheckpointLoader //Inspire": StableCascade_CheckpointLoader,
"IsCached //Inspire": IsCached,
# "CacheBridge //Inspire": CacheBridge,
@@ -574,6 +746,8 @@ def doit(self, value, mode, unique_id):
"RemoveBackendDataNumberKey //Inspire": "Remove Backend Data [NumberKey] (Inspire)",
"ShowCachedInfo //Inspire": "Show Cached Info (Inspire)",
"CheckpointLoaderSimpleShared //Inspire": "Shared Checkpoint Loader (Inspire)",
"LoadDiffusionModelShared //Inspire": "Shared Diffusion Model Loader (Inspire)",
"LoadTextEncoderShared //Inspire": "Shared Text Encoder Loader (Inspire)",
"StableCascade_CheckpointLoader //Inspire": "Stable Cascade Checkpoint Loader (Inspire)",
"IsCached //Inspire": "Is Cached (Inspire)",
# "CacheBridge //Inspire": "Cache Bridge (Inspire)"
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -1,7 +1,7 @@
[project]
name = "comfyui-inspire-pack"
description = "This extension provides various nodes to support Lora Block Weight, Regional Nodes, Backend Cache, Prompt Utils, List Utils, Noise(Seed) Utils, ... and the Impact Pack."
version = "1.9.1"
version = "1.10"
license = { file = "LICENSE" }
dependencies = ["matplotlib", "cachetools"]

