From 7de0a2ada7ad4765ea83813e757bdc60b2a16031 Mon Sep 17 00:00:00 2001 From: CastielMa Date: Sat, 25 May 2024 23:10:14 -0400 Subject: [PATCH] thread safe extra network list_items #13014 --- extensions-builtin/Lora/ui_extra_networks_lora.py | 6 +++++- gcloud_deploy.sh | 1 + modules/ui_extra_networks_checkpoints.py | 6 +++++- modules/ui_extra_networks_hypernets.py | 12 +++++++++--- modules/ui_extra_networks_textual_inversion.py | 10 +++++++--- 5 files changed, 27 insertions(+), 8 deletions(-) diff --git a/extensions-builtin/Lora/ui_extra_networks_lora.py b/extensions-builtin/Lora/ui_extra_networks_lora.py index 55409a7829d..13e78f0e1dc 100644 --- a/extensions-builtin/Lora/ui_extra_networks_lora.py +++ b/extensions-builtin/Lora/ui_extra_networks_lora.py @@ -17,6 +17,8 @@ def refresh(self): def create_item(self, name, index=None, enable_filter=True): lora_on_disk = networks.available_networks.get(name) + if lora_on_disk is None: + return path, ext = os.path.splitext(lora_on_disk.filename) @@ -66,7 +68,9 @@ def create_item(self, name, index=None, enable_filter=True): return item def list_items(self): - for index, name in enumerate(networks.available_networks): + # instantiate a list to protect against concurrent modification + names = list(networks.available_networks) + for index, name in enumerate(names): item = self.create_item(name, index) if item is not None: diff --git a/gcloud_deploy.sh b/gcloud_deploy.sh index c2d728755ca..a58056b3703 100644 --- a/gcloud_deploy.sh +++ b/gcloud_deploy.sh @@ -1,5 +1,6 @@ # install +sudo apt install nvidia-cuda-toolkit sudo apt-get update sudo apt -y install wget git python3 diff --git a/modules/ui_extra_networks_checkpoints.py b/modules/ui_extra_networks_checkpoints.py index ca6c26076f9..f535448cb5b 100644 --- a/modules/ui_extra_networks_checkpoints.py +++ b/modules/ui_extra_networks_checkpoints.py @@ -15,6 +15,8 @@ def refresh(self): def create_item(self, name, index=None, enable_filter=True): checkpoint: sd_models.CheckpointInfo = sd_models.checkpoint_aliases.get(name) + if checkpoint is None: + return path, ext = os.path.splitext(checkpoint.filename) return { "name": checkpoint.name_for_extra, @@ -32,7 +34,9 @@ def create_item(self, name, index=None, enable_filter=True): def list_items(self): names = list(sd_models.checkpoints_list) for index, name in enumerate(names): - yield self.create_item(name, index) + item = self.create_item(name, index) + if item is not None: + yield item def allowed_directories_for_previews(self): return [v for v in [shared.cmd_opts.ckpt_dir, sd_models.model_path] if v is not None] diff --git a/modules/ui_extra_networks_hypernets.py b/modules/ui_extra_networks_hypernets.py index 4cedf085196..97af697cbbc 100644 --- a/modules/ui_extra_networks_hypernets.py +++ b/modules/ui_extra_networks_hypernets.py @@ -13,7 +13,9 @@ def refresh(self): shared.reload_hypernetworks() def create_item(self, name, index=None, enable_filter=True): - full_path = shared.hypernetworks[name] + full_path = shared.hypernetworks.get(name) + if full_path is None: + return path, ext = os.path.splitext(full_path) sha256 = sha256_from_cache(full_path, f'hypernet/{name}') shorthash = sha256[0:10] if sha256 else None @@ -31,8 +33,12 @@ def create_item(self, name, index=None, enable_filter=True): } def list_items(self): - for index, name in enumerate(shared.hypernetworks): - yield self.create_item(name, index) + names = list(shared.hypernetworks) + for index, name in enumerate(names): + item = self.create_item(name, index) + if item is not None: + yield item + def allowed_directories_for_previews(self): return [shared.cmd_opts.hypernetwork_dir] diff --git a/modules/ui_extra_networks_textual_inversion.py b/modules/ui_extra_networks_textual_inversion.py index 55ef0ea7b54..087ee8ca594 100644 --- a/modules/ui_extra_networks_textual_inversion.py +++ b/modules/ui_extra_networks_textual_inversion.py @@ -14,7 +14,8 @@ def refresh(self): def create_item(self, name, index=None, enable_filter=True): embedding = sd_hijack.model_hijack.embedding_db.word_embeddings.get(name) - + if embedding is None: + return path, ext = os.path.splitext(embedding.filename) return { "name": name, @@ -29,8 +30,11 @@ def create_item(self, name, index=None, enable_filter=True): } def list_items(self): - for index, name in enumerate(sd_hijack.model_hijack.embedding_db.word_embeddings): - yield self.create_item(name, index) + names = list(sd_hijack.model_hijack.embedding_db.word_embeddings) + for index, name in enumerate(names): + item = self.create_item(name, index) + if item is not None: + yield item def allowed_directories_for_previews(self): return list(sd_hijack.model_hijack.embedding_db.embedding_dirs)