Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support and automatically detect SDXL V-prediction models #16567

Merged
merged 1 commit into from
Oct 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
98 changes: 98 additions & 0 deletions configs/sd_xl_v.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
model:
target: sgm.models.diffusion.DiffusionEngine
params:
scale_factor: 0.13025
disable_first_stage_autocast: True

denoiser_config:
target: sgm.modules.diffusionmodules.denoiser.DiscreteDenoiser
params:
num_idx: 1000

weighting_config:
target: sgm.modules.diffusionmodules.denoiser_weighting.EpsWeighting
scaling_config:
target: sgm.modules.diffusionmodules.denoiser_scaling.VScaling
discretization_config:
target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization

network_config:
target: sgm.modules.diffusionmodules.openaimodel.UNetModel
params:
adm_in_channels: 2816
num_classes: sequential
use_checkpoint: True
in_channels: 4
out_channels: 4
model_channels: 320
attention_resolutions: [4, 2]
num_res_blocks: 2
channel_mult: [1, 2, 4]
num_head_channels: 64
use_spatial_transformer: True
use_linear_in_transformer: True
transformer_depth: [1, 2, 10] # note: the first is unused (due to attn_res starting at 2) 32, 16, 8 --> 64, 32, 16
context_dim: 2048
spatial_transformer_attn_type: softmax-xformers
legacy: False

conditioner_config:
target: sgm.modules.GeneralConditioner
params:
emb_models:
# crossattn cond
- is_trainable: False
input_key: txt
target: sgm.modules.encoders.modules.FrozenCLIPEmbedder
params:
layer: hidden
layer_idx: 11
# crossattn and vector cond
- is_trainable: False
input_key: txt
target: sgm.modules.encoders.modules.FrozenOpenCLIPEmbedder2
params:
arch: ViT-bigG-14
version: laion2b_s39b_b160k
freeze: True
layer: penultimate
always_return_pooled: True
legacy: False
# vector cond
- is_trainable: False
input_key: original_size_as_tuple
target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
params:
outdim: 256 # multiplied by two
# vector cond
- is_trainable: False
input_key: crop_coords_top_left
target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
params:
outdim: 256 # multiplied by two
# vector cond
- is_trainable: False
input_key: target_size_as_tuple
target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
params:
outdim: 256 # multiplied by two

first_stage_config:
target: sgm.models.autoencoder.AutoencoderKLInferenceWrapper
params:
embed_dim: 4
monitor: val/rec_loss
ddconfig:
attn_type: vanilla-xformers
double_z: true
z_channels: 4
resolution: 256
in_channels: 3
out_ch: 3
ch: 128
ch_mult: [1, 2, 4, 4]
num_res_blocks: 2
attn_resolutions: []
dropout: 0.0
lossconfig:
target: torch.nn.Identity
7 changes: 4 additions & 3 deletions modules/sd_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -783,7 +783,7 @@ def get_obj_from_str(string, reload=False):
return getattr(importlib.import_module(module, package=None), cls)


def load_model(checkpoint_info=None, already_loaded_state_dict=None):
def load_model(checkpoint_info=None, already_loaded_state_dict=None, checkpoint_config=None):
from modules import sd_hijack
checkpoint_info = checkpoint_info or select_checkpoint()

Expand All @@ -801,7 +801,8 @@ def load_model(checkpoint_info=None, already_loaded_state_dict=None):
else:
state_dict = get_checkpoint_state_dict(checkpoint_info, timer)

checkpoint_config = sd_models_config.find_checkpoint_config(state_dict, checkpoint_info)
if not checkpoint_config:
checkpoint_config = sd_models_config.find_checkpoint_config(state_dict, checkpoint_info)
clip_is_included_into_sd = any(x for x in [sd1_clip_weight, sd2_clip_weight, sdxl_clip_weight, sdxl_refiner_clip_weight] if x in state_dict)

timer.record("find config")
Expand Down Expand Up @@ -974,7 +975,7 @@ def reload_model_weights(sd_model=None, info=None, forced_reload=False):
if sd_model is not None:
send_model_to_trash(sd_model)

load_model(checkpoint_info, already_loaded_state_dict=state_dict)
load_model(checkpoint_info, already_loaded_state_dict=state_dict, checkpoint_config=checkpoint_config)
return model_data.sd_model

try:
Expand Down
4 changes: 4 additions & 0 deletions modules/sd_models_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
config_sd2v = os.path.join(sd_repo_configs_path, "v2-inference-v.yaml")
config_sd2_inpainting = os.path.join(sd_repo_configs_path, "v2-inpainting-inference.yaml")
config_sdxl = os.path.join(sd_xl_repo_configs_path, "sd_xl_base.yaml")
config_sdxlv = os.path.join(sd_configs_path, "sd_xl_v.yaml")
config_sdxl_refiner = os.path.join(sd_xl_repo_configs_path, "sd_xl_refiner.yaml")
config_sdxl_inpainting = os.path.join(sd_configs_path, "sd_xl_inpaint.yaml")
config_depth_model = os.path.join(sd_repo_configs_path, "v2-midas-inference.yaml")
Expand Down Expand Up @@ -81,6 +82,9 @@ def guess_model_config_from_state_dict(sd, filename):
if diffusion_model_input.shape[1] == 9:
return config_sdxl_inpainting
else:
if ('v_pred' in sd):
del sd['v_pred']
return config_sdxlv
return config_sdxl

if sd.get('conditioner.embedders.0.model.ln_final.weight', None) is not None:
Expand Down