Skip to content

Commit

Permalink
fixed: LBW - Certain text encoders, such as T5, were not being reflected in CLIP.
Browse files Browse the repository at this point in the history

#185
  • Loading branch information
ltdrdata committed Nov 15, 2024
1 parent 3958690 commit 3bcc1d4
Show file tree
Hide file tree
Showing 3 changed files with 35 additions and 16 deletions.
2 changes: 1 addition & 1 deletion __init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

import importlib

version_code = [1, 6]
version_code = [1, 6, 1]
version_str = f"V{version_code[0]}.{version_code[1]}" + (f'.{version_code[2]}' if len(version_code) > 2 else '')
print(f"### Loading: ComfyUI-Inspire-Pack ({version_str})")

Expand Down
47 changes: 33 additions & 14 deletions inspire/lora_block_weight.py
Original file line number Diff line number Diff line change
Expand Up @@ -484,7 +484,7 @@ def load_lora_for_models(model, clip, lora, strength_model, strength_clip, inver

if k in muted_weights:
pass
elif 'text' in k:
elif 'text' in k or 'encoder' in k:
new_clip.add_patches({k: weights}, strength_clip * ratio)
else:
new_modelpatcher.add_patches({k: weights}, strength_model * ratio)
Expand Down Expand Up @@ -546,7 +546,7 @@ def doit(model, clip, strength_model, strength_clip, lbw_model):

if k in muted_weights:
pass
elif 'text' in k:
elif 'text' in k or 'encoder' in k:
new_clip.add_patches({k: weights}, strength_clip * ratio)
else:
new_modelpatcher.add_patches({k: weights}, strength_model * ratio)
Expand Down Expand Up @@ -822,9 +822,13 @@ def parse_unet_num(s):
output_blocks = []
output_blocks_map = {}

text_block_count = set()
text_blocks = []
text_blocks_map = {}
text_block_count1 = set()
text_blocks1 = []
text_blocks_map1 = {}

text_block_count2 = set()
text_blocks2 = []
text_blocks_map2 = {}

double_block_count = set()
double_blocks = []
Expand Down Expand Up @@ -902,12 +906,23 @@ def parse_unet_num(s):
k_unet_num = k_unet[len("er.text_model.encoder.layers."):len("er.text_model.encoder.layers.")+2]
k_unet_int = parse_unet_num(k_unet_num)

text_block_count.add(k_unet_int)
text_blocks.append(k_unet)
if k_unet_int in text_blocks_map:
text_blocks_map[k_unet_int].append(k_unet)
text_block_count1.add(k_unet_int)
text_blocks1.append(k_unet)
if k_unet_int in text_blocks_map1:
text_blocks_map1[k_unet_int].append(k_unet)
else:
text_blocks_map1[k_unet_int] = [k_unet]

elif k_unet.startswith("r.encoder.block."):
k_unet_num = k_unet[len("r.encoder.block."):len("r.encoder.block.")+2]
k_unet_int = parse_unet_num(k_unet_num)

text_block_count2.add(k_unet_int)
text_blocks2.append(k_unet)
if k_unet_int in text_blocks_map2:
text_blocks_map2[k_unet_int].append(k_unet)
else:
text_blocks_map[k_unet_int] = [k_unet]
text_blocks_map2[k_unet_int] = [k_unet]

else:
others.append(k_unet)
Expand Down Expand Up @@ -951,10 +966,14 @@ def parse_unet_num(s):
for x in single_keys:
text += f" SINGLE{x}: {len(single_blocks_map[x])}\n"

text += f"\n-------[Base blocks] ({len(text_block_count) + len(others)}, Subs={len(text_blocks) + len(others)})-------\n"
text_keys = sorted(text_blocks_map.keys())
for x in text_keys:
text += f" TXT_ENC{x}: {len(text_blocks_map[x])}\n"
text += f"\n-------[Base blocks] ({len(text_block_count1) + len(text_block_count2) + len(others)}, Subs={len(text_blocks1) + len(text_blocks2) + len(others)})-------\n"
text_keys1 = sorted(text_blocks_map1.keys())
for x in text_keys1:
text += f" TXT_ENC{x}: {len(text_blocks_map1[x])}\n"

text_keys2 = sorted(text_blocks_map2.keys())
for x in text_keys2:
text += f" TXT_ENC{x} [B]: {len(text_blocks_map2[x])}\n"

for x in others:
text += f" {x}\n"
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
[project]
name = "comfyui-inspire-pack"
description = "This extension provides various nodes to support Lora Block Weight and the Impact Pack. Provides many easily applicable regional features and applications for Variation Seed."
version = "1.6"
version = "1.6.1"
license = { file = "LICENSE" }
dependencies = ["matplotlib", "cachetools"]

Expand Down

0 comments on commit 3bcc1d4

Please sign in to comment.