From 042e1d5d0b1fc0bfd358e3a90db7d163934bd238 Mon Sep 17 00:00:00 2001 From: Uminosachi <49424133+Uminosachi@users.noreply.github.com> Date: Sun, 20 Aug 2023 15:00:14 +0900 Subject: [PATCH 1/5] Fix SD VAE switch error after model reuse --- modules/sd_models.py | 22 ++++++++++++++++++++-- 1 file changed, 20 insertions(+), 2 deletions(-) diff --git a/modules/sd_models.py b/modules/sd_models.py index 685585b1cbc..2c976561e6d 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -462,6 +462,7 @@ class SdModelData: def __init__(self): self.sd_model = None self.loaded_sd_models = [] + self.loaded_vae_states = {} self.was_loaded_at_least_once = False self.lock = threading.Lock() @@ -485,16 +486,27 @@ def get_sd_model(self): return self.sd_model - def set_sd_model(self, v): + def set_sd_model(self, v, already_loaded=False): self.sd_model = v + if already_loaded: + sd_vae_state = self.loaded_vae_states.get(v.sd_model_hash, {}) + sd_vae.base_vae = sd_vae_state.get("base_vae", None) + sd_vae.loaded_vae_file = sd_vae_state.get("loaded_vae_file", None) + sd_vae.checkpoint_info = sd_vae_state.get("checkpoint_info", None) try: self.loaded_sd_models.remove(v) + self.loaded_vae_states.pop(v.sd_model_hash, {}).clear() except ValueError: pass if v is not None: self.loaded_sd_models.insert(0, v) + self.loaded_vae_states[v.sd_model_hash] = dict( + base_vae=sd_vae.base_vae, + loaded_vae_file=sd_vae.loaded_vae_file, + checkpoint_info=sd_vae.checkpoint_info, + ) model_data = SdModelData() @@ -649,6 +661,7 @@ def reuse_model_from_already_loaded(sd_model, checkpoint_info, timer): if len(model_data.loaded_sd_models) > shared.opts.sd_checkpoints_limit > 0: print(f"Unloading model {len(model_data.loaded_sd_models)} over the limit of {shared.opts.sd_checkpoints_limit}: {loaded_model.sd_checkpoint_info.title}") model_data.loaded_sd_models.pop() + model_data.loaded_vae_states.pop(loaded_model.sd_model_hash, {}).clear() send_model_to_trash(loaded_model) timer.record("send model to trash") @@ -660,7 +673,7 @@ def reuse_model_from_already_loaded(sd_model, checkpoint_info, timer): send_model_to_device(already_loaded) timer.record("send model to device") - model_data.set_sd_model(already_loaded) + model_data.set_sd_model(already_loaded, already_loaded=True) if not SkipWritingToConfig.skip: shared.opts.data["sd_model_checkpoint"] = already_loaded.sd_checkpoint_info.title @@ -678,6 +691,11 @@ def reuse_model_from_already_loaded(sd_model, checkpoint_info, timer): sd_model = model_data.loaded_sd_models.pop() model_data.sd_model = sd_model + sd_vae_state = model_data.loaded_vae_states.pop(sd_model.sd_model_hash, {}) + sd_vae.base_vae = sd_vae_state.get("base_vae", None) + sd_vae.loaded_vae_file = sd_vae_state.get("loaded_vae_file", None) + sd_vae.checkpoint_info = sd_vae_state.get("checkpoint_info", None) + print(f"Reusing loaded model {sd_model.sd_checkpoint_info.title} to load {checkpoint_info.title}") return sd_model else: From 5159edbf0e0e1d5a25fbd588e000487746790117 Mon Sep 17 00:00:00 2001 From: Uminosachi <49424133+Uminosachi@users.noreply.github.com> Date: Sun, 20 Aug 2023 19:44:37 +0900 Subject: [PATCH 2/5] Store base_vae and loaded_vae_file in sd_model --- modules/sd_models.py | 24 ++++++++---------------- 1 file changed, 8 insertions(+), 16 deletions(-) diff --git a/modules/sd_models.py b/modules/sd_models.py index 2c976561e6d..150d550b66b 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -462,7 +462,6 @@ class SdModelData: def __init__(self): self.sd_model = None self.loaded_sd_models = [] - self.loaded_vae_states = {} self.was_loaded_at_least_once = False self.lock = threading.Lock() @@ -489,24 +488,19 @@ def get_sd_model(self): def set_sd_model(self, v, already_loaded=False): self.sd_model = v if already_loaded: - sd_vae_state = self.loaded_vae_states.get(v.sd_model_hash, {}) - sd_vae.base_vae = sd_vae_state.get("base_vae", None) - sd_vae.loaded_vae_file = sd_vae_state.get("loaded_vae_file", None) - sd_vae.checkpoint_info = sd_vae_state.get("checkpoint_info", None) + sd_vae.base_vae = getattr(v, "base_vae", None) + sd_vae.loaded_vae_file = getattr(v, "loaded_vae_file", None) + sd_vae.checkpoint_info = v.sd_checkpoint_info try: self.loaded_sd_models.remove(v) - self.loaded_vae_states.pop(v.sd_model_hash, {}).clear() except ValueError: pass if v is not None: + setattr(v, "base_vae", sd_vae.base_vae) + setattr(v, "loaded_vae_file", sd_vae.loaded_vae_file) self.loaded_sd_models.insert(0, v) - self.loaded_vae_states[v.sd_model_hash] = dict( - base_vae=sd_vae.base_vae, - loaded_vae_file=sd_vae.loaded_vae_file, - checkpoint_info=sd_vae.checkpoint_info, - ) model_data = SdModelData() @@ -661,7 +655,6 @@ def reuse_model_from_already_loaded(sd_model, checkpoint_info, timer): if len(model_data.loaded_sd_models) > shared.opts.sd_checkpoints_limit > 0: print(f"Unloading model {len(model_data.loaded_sd_models)} over the limit of {shared.opts.sd_checkpoints_limit}: {loaded_model.sd_checkpoint_info.title}") model_data.loaded_sd_models.pop() - model_data.loaded_vae_states.pop(loaded_model.sd_model_hash, {}).clear() send_model_to_trash(loaded_model) timer.record("send model to trash") @@ -691,10 +684,9 @@ def reuse_model_from_already_loaded(sd_model, checkpoint_info, timer): sd_model = model_data.loaded_sd_models.pop() model_data.sd_model = sd_model - sd_vae_state = model_data.loaded_vae_states.pop(sd_model.sd_model_hash, {}) - sd_vae.base_vae = sd_vae_state.get("base_vae", None) - sd_vae.loaded_vae_file = sd_vae_state.get("loaded_vae_file", None) - sd_vae.checkpoint_info = sd_vae_state.get("checkpoint_info", None) + sd_vae.base_vae = getattr(sd_model, "base_vae", None) + sd_vae.loaded_vae_file = getattr(sd_model, "loaded_vae_file", None) + sd_vae.checkpoint_info = sd_model.sd_checkpoint_info print(f"Reusing loaded model {sd_model.sd_checkpoint_info.title} to load {checkpoint_info.title}") return sd_model From af5d2e8e5fc4440691fb7f1aa3492def1c755722 Mon Sep 17 00:00:00 2001 From: Uminosachi <49424133+Uminosachi@users.noreply.github.com> Date: Sun, 20 Aug 2023 20:08:22 +0900 Subject: [PATCH 3/5] Change to access sd_model attribute with dot --- modules/sd_models.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/modules/sd_models.py b/modules/sd_models.py index 150d550b66b..dd749122ce2 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -498,8 +498,8 @@ def set_sd_model(self, v, already_loaded=False): pass if v is not None: - setattr(v, "base_vae", sd_vae.base_vae) - setattr(v, "loaded_vae_file", sd_vae.loaded_vae_file) + v.base_vae = sd_vae.base_vae + v.loaded_vae_file = sd_vae.loaded_vae_file self.loaded_sd_models.insert(0, v) From 549b0fc5267e9539f321f0891aa757619b7248cb Mon Sep 17 00:00:00 2001 From: Uminosachi <49424133+Uminosachi@users.noreply.github.com> Date: Sun, 20 Aug 2023 23:06:51 +0900 Subject: [PATCH 4/5] Change where VAE state are stored in model --- modules/sd_models.py | 2 -- modules/sd_vae.py | 4 +++- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/modules/sd_models.py b/modules/sd_models.py index dd749122ce2..d3775ec664d 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -498,8 +498,6 @@ def set_sd_model(self, v, already_loaded=False): pass if v is not None: - v.base_vae = sd_vae.base_vae - v.loaded_vae_file = sd_vae.loaded_vae_file self.loaded_sd_models.insert(0, v) diff --git a/modules/sd_vae.py b/modules/sd_vae.py index dbade06794a..ee118656860 100644 --- a/modules/sd_vae.py +++ b/modules/sd_vae.py @@ -192,7 +192,7 @@ def load_vae_dict(filename, map_location): def load_vae(model, vae_file=None, vae_source="from unknown source"): - global vae_dict, loaded_vae_file + global vae_dict, base_vae, loaded_vae_file # save_settings = False cache_enabled = shared.opts.sd_vae_checkpoint_cache > 0 @@ -230,6 +230,8 @@ def load_vae(model, vae_file=None, vae_source="from unknown source"): restore_base_vae(model) loaded_vae_file = vae_file + model.base_vae = base_vae + model.loaded_vae_file = loaded_vae_file # don't call this from outside From be301f224d26ac4363ce3bd8bcb510b00bd6db27 Mon Sep 17 00:00:00 2001 From: Uminosachi <49424133+Uminosachi@users.noreply.github.com> Date: Mon, 21 Aug 2023 11:28:53 +0900 Subject: [PATCH 5/5] Fix for consistency with shared.opts.sd_vae of UI --- modules/sd_models.py | 1 + 1 file changed, 1 insertion(+) diff --git a/modules/sd_models.py b/modules/sd_models.py index d3775ec664d..27d15e66007 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -671,6 +671,7 @@ def reuse_model_from_already_loaded(sd_model, checkpoint_info, timer): shared.opts.data["sd_checkpoint_hash"] = already_loaded.sd_checkpoint_info.sha256 print(f"Using already loaded model {already_loaded.sd_checkpoint_info.title}: done in {timer.summary()}") + sd_vae.reload_vae_weights(already_loaded) return model_data.sd_model elif shared.opts.sd_checkpoints_limit > 1 and len(model_data.loaded_sd_models) < shared.opts.sd_checkpoints_limit: print(f"Loading model {checkpoint_info.title} ({len(model_data.loaded_sd_models) + 1} out of {shared.opts.sd_checkpoints_limit})")