From f84e6d4eebdbcae32d7409e53e4f143ccb973cd1 Mon Sep 17 00:00:00 2001
From: allenbenz
Date: Tue, 14 Feb 2023 19:22:05 -0800
Subject: [PATCH 1/8] Make it so that the model learns to change the
 zero-frequency component freely, because that component is now being
 randomized ~10 times faster than for the base distribution

https://www.crosslabs.org/blog/diffusion-with-offset-noise
---
 scripts/trainer.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/scripts/trainer.py b/scripts/trainer.py
index bf5a572..8e7b1de 100644
--- a/scripts/trainer.py
+++ b/scripts/trainer.py
@@ -2721,7 +2721,8 @@ def help(event=None):
             if args.sample_from_batch > 0:
                 args.batch_tokens = batch[0][5]
             # Sample noise that we'll add to the latents
-            noise = torch.randn_like(latents)
+            # and some extra bits to allow better learning of strong contrasts
+            noise = torch.randn_like(latents) + (0.1 * torch.randn(latents.shape[0], latents.shape[1], 1, 1).to(accelerator.device))
             bsz = latents.shape[0]
             # Sample a random timestep for each image
             timesteps = torch.randint(0, int(noise_scheduler.config.num_train_timesteps * args.max_denoising_strength), (bsz,), device=latents.device)

From 8d356ceb0867364816b502ce9dd0e16318817f5c Mon Sep 17 00:00:00 2001
From: allenbenz
Date: Thu, 16 Feb 2023 20:51:36 -0800
Subject: [PATCH 2/8] Add offset noise flag and weight control to the gui.

---
 scripts/configuration_gui.py | 37 +++++++++++++++++++++++++++++++++++-
 scripts/trainer.py | 17 +++++++++++++++--
 2 files changed, 51 insertions(+), 3 deletions(-)

diff --git a/scripts/configuration_gui.py b/scripts/configuration_gui.py
index e42d390..7a35bd4 100644
--- a/scripts/configuration_gui.py
+++ b/scripts/configuration_gui.py
@@ -888,6 +888,8 @@ def create_default_variables(self):
         self.save_sample_controlled_seed = []
         self.delete_checkpoints_when_full_drive = True
         self.use_image_names_as_captions = True
+        self.use_offset_noise = False
+        self.offset_noise_weight = 0.1
         self.num_samples_to_generate = 1
         self.auto_balance_concept_datasets = False
         self.sample_width = 512
@@ -1571,6 +1573,24 @@ def create_trainer_settings_widgets(self):
         self.prior_loss_preservation_weight_entry = ctk.CTkEntry(self.training_frame_subframe)
         self.prior_loss_preservation_weight_entry.grid(row=19, column=3, sticky="w")
         self.prior_loss_preservation_weight_entry.insert(0, self.prior_loss_weight)
+
+        #create contrasting light and color checkbox
+        self.use_offset_noise_var = tk.IntVar()
+        self.use_offset_noise_var.set(self.use_offset_noise)
+        #create label
+        self.offset_noise_label = ctk.CTkLabel(self.training_frame_subframe, text="With Offset Noise")
+        offset_noise_label_ttp = CreateToolTip(self.offset_noise_label, "Apply offset noise to latents to learn image contrast.")
+        self.offset_noise_label.grid(row=20, column=0, sticky="nsew")
+        #create checkbox
+        self.offset_noise_checkbox = ctk.CTkSwitch(self.training_frame_subframe, variable=self.use_offset_noise_var)
+        self.offset_noise_checkbox.grid(row=20, column=1, sticky="nsew")
+        #create offset noise weight entry
+        self.offset_noise_weight_label = ctk.CTkLabel(self.training_frame_subframe, text="Offset Noise Weight")
+        offset_noise_weight_label_ttp = CreateToolTip(self.offset_noise_weight_label, "The weight of the offset noise.")
+        self.offset_noise_weight_label.grid(row=20, column=1, sticky="e")
+        self.offset_noise_weight_entry = ctk.CTkEntry(self.training_frame_subframe)
+        self.offset_noise_weight_entry.grid(row=20, column=3, sticky="w")
+        self.offset_noise_weight_entry.insert(0, 
self.offset_noise_weight) def create_dataset_settings_widgets(self): @@ -1957,7 +1977,7 @@ def create_plyaground_widgets(self): self.play_negative_prompt_entry.bind("", lambda event: self.play_generate_image(self.play_model_entry.get(), self.play_prompt_entry.get(), self.play_negative_prompt_entry.get(), self.play_seed_entry.get(), self.play_scheduler_variable.get(), int(self.play_resolution_slider_height.get()), int(self.play_resolution_slider_width.get()), self.play_cfg_slider.get(), self.play_steps_slider.get())) #add convert to ckpt button - self.play_convert_to_ckpt_button = ctk.CTkButton(self.playground_frame_subframe, text="Convert To CKPT", command=lambda:self.convert_to_ckpt(model_path=self.play_model_entry.get())) + self.play_convert_to_ckpt_button = ctk.CTkButton(self.playground_frame_subframe, text="Convert To CKPT", command=lambda:self.convert_to_safetensors(model_path=self.play_model_entry.get())) #add interative generation button to act as a toggle #convert to safetensors button @@ -3070,6 +3090,8 @@ def save_config(self, config_file=None): configure['attention'] = self.attention_var.get() configure['batch_prompt_sampling'] = int(self.batch_prompt_sampling_optionmenu_var.get()) configure['shuffle_dataset_per_epoch'] = self.shuffle_dataset_per_epoch_var.get() + configure['use_offset_noise'] = self.use_offset_noise_var.get() + configure['offset_noise_weight'] = self.offset_noise_weight_entry.get() #save the configure file #if the file exists, delete it if os.path.exists(file_name): @@ -3222,6 +3244,9 @@ def load_config(self,file_name=None): self.attention_var.set(configure["attention"]) self.batch_prompt_sampling_optionmenu_var.set(str(configure['batch_prompt_sampling'])) self.shuffle_dataset_per_epoch_var.set(configure["shuffle_dataset_per_epoch"]) + self.use_offset_noise_var.set(configure["use_offset_noise"]) + self.offset_noise_weight_entry.delete(0, tk.END) + self.offset_noise_weight_entry.insert(0, configure["offset_noise_weight"]) self.update() def process_inputs(self,export=None): @@ -3291,6 +3316,9 @@ def process_inputs(self,export=None): self.attention = self.attention_var.get() self.batch_prompt_sampling = int(self.batch_prompt_sampling_optionmenu_var.get()) self.shuffle_dataset_per_epoch = self.shuffle_dataset_per_epoch_var.get() + self.use_offset_noise = self.use_offset_noise_var.get() + self.offset_noise_weight = self.offset_noise_weight_entry.get() + mode = 'normal' if self.cloud_mode == False and export == None: #check if output path exists @@ -3579,6 +3607,13 @@ def process_inputs(self,export=None): batBase += ' --use_image_names_as_captions' else: batBase += f' "--use_image_names_as_captions" ' + if self.use_offset_noise == True: + if export == 'Linux': + batBase += f' --with_offset_noise' + batBase += f' --offset_noise_weight={self.offset_noise_weight}' + else: + batBase += f' "--with_offset_noise" ' + batBase += f' "--offset_noise_weight={self.offset_noise_weight}" ' if self.auto_balance_concept_datasets == True: if export == 'Linux': batBase += ' --auto_balance_concept_datasets' diff --git a/scripts/trainer.py b/scripts/trainer.py index 8e7b1de..ae57a84 100644 --- a/scripts/trainer.py +++ b/scripts/trainer.py @@ -254,6 +254,14 @@ def parse_args(): help="Flag to add prior preservation loss.", ) parser.add_argument("--prior_loss_weight", type=float, default=1.0, help="The weight of prior preservation loss.") + parser.add_argument( + "--with_offset_noise", + default=False, + action="store_true", + help="Flag to offset noise applied to latents.", + ) + + 
parser.add_argument("--offset_noise_weight", type=float, default=0.1, help="The weight of offset noise applied during training.") parser.add_argument( "--num_class_images", type=int, @@ -2721,8 +2729,13 @@ def help(event=None): if args.sample_from_batch > 0: args.batch_tokens = batch[0][5] # Sample noise that we'll add to the latents - # and some extra bits to allow better learning of strong contrasts - noise = torch.randn_like(latents) + (0.1 * torch.randn(latents.shape[0], latents.shape[1], 1, 1).to(accelerator.device)) + # and some extra bits to make it so that the model learns to change the zero-frequency of the component freely + # https://www.crosslabs.org/blog/diffusion-with-offset-noise + if (args.with_offset_noise == True): + noise = torch.randn_like(latents) + (args.offset_noise_weight * torch.randn(latents.shape[0], latents.shape[1], 1, 1).to(accelerator.device)) + else: + noise = torch.randn_like(latents) + bsz = latents.shape[0] # Sample a random timestep for each image timesteps = torch.randint(0, int(noise_scheduler.config.num_train_timesteps * args.max_denoising_strength), (bsz,), device=latents.device) From ece7071f6926e03aebb445577262be5e694b7a0a Mon Sep 17 00:00:00 2001 From: Buckzor Date: Sat, 18 Feb 2023 13:56:43 +0000 Subject: [PATCH 3/8] Update .gitignore --- .gitignore | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/.gitignore b/.gitignore index d062eb8..309c70d 100644 --- a/.gitignore +++ b/.gitignore @@ -1,4 +1,5 @@ - +out/ +training/ cudnn_windows/cudnn64_8.dll cudnn_windows/cudnn_adv_infer64_8.dll cudnn_windows/cudnn_adv_train64_8.dll From 4cfe801a18cf09e9c131d66add3352d98bf61301 Mon Sep 17 00:00:00 2001 From: Buckzor Date: Wed, 22 Feb 2023 19:51:49 +0000 Subject: [PATCH 4/8] Add concept repeats --- scripts/configuration_gui.py | 48 +++++++++---- scripts/dataloaders_util.py | 133 +++++++++++++++++++++++------------ 2 files changed, 123 insertions(+), 58 deletions(-) diff --git a/scripts/configuration_gui.py b/scripts/configuration_gui.py index 731a5e7..b2e451a 100644 --- a/scripts/configuration_gui.py +++ b/scripts/configuration_gui.py @@ -75,8 +75,9 @@ def __init__(self, parent, concept=None,width=150,height=150, *args, **kwargs): self.concept_do_not_balance = False self.process_sub_dirs = False self.image_preview = self.default_image_preview + self.repeat_concept = '1' #create concept - self.concept = Concept(self.concept_name, self.concept_data_path, self.concept_class_name, self.concept_class_path,self.flip_p, self.concept_do_not_balance,self.process_sub_dirs, self.image_preview, None) + self.concept = Concept(self.concept_name, self.concept_data_path, self.concept_class_name, self.concept_class_path,self.flip_p, self.concept_do_not_balance,self.process_sub_dirs, self.image_preview, None, self.repeat_concept) else: self.concept = concept self.concept.image_preview = self.make_image_preview() @@ -274,6 +275,14 @@ def __init__(self, parent,conceptWidget,concept,*args, **kwargs): if self.concept.flip_p != '': self.flip_probability_entry.insert(0, self.concept.flip_p) #self.flip_probability_entry.bind("", self.create_right_click_menu) + + #entry and label for concept repeats + self.repeat_concept_label = ctk.CTkLabel(self.concept_frame_subframe, text="Repeat Concept:") + self.repeat_concept_label.grid(row=4, column=0, sticky="nsew",padx=5,pady=5) + self.repeat_concept_entry = ctk.CTkEntry(self.concept_frame_subframe,width=200,placeholder_text="1") + self.repeat_concept_entry.grid(row=4, column=1, sticky="e",padx=5,pady=5) + if 
self.concept.repeat_concept != '': + self.repeat_concept_entry.insert(1, self.concept.repeat_concept) #make a label for dataset balancingprocess_sub_dirs self.balance_dataset_label = ctk.CTkLabel(self.concept_frame_subframe, text="Don't Balance Dataset") @@ -452,6 +461,8 @@ def save(self): flip_p = self.flip_probability_entry.get() #get the dataset balancing balance_dataset = self.balance_dataset_switch.get() + #get concept repeats + repeat_concept = self.repeat_concept_entry.get() #create the concept process_sub_dirs = self.process_sub_dirs_switch.get() #image preview @@ -459,14 +470,14 @@ def save(self): #get the main window image_preview_label = self.image_preview_label #update the concept - self.concept.update(concept_name, concept_path, class_name, class_path,flip_p,balance_dataset,process_sub_dirs,image_preview,image_preview_label) + self.concept.update(concept_name, concept_path, class_name, class_path,flip_p,balance_dataset,process_sub_dirs,image_preview,image_preview_label,repeat_concept) self.conceptWidget.update_button() #close the window self.destroy() #class of the concept class Concept: - def __init__(self, concept_name, concept_path, class_name, class_path,flip_p, balance_dataset=None,process_sub_dirs=None,image_preview=None, image_container=None): + def __init__(self, concept_name, concept_path, class_name, class_path,flip_p, balance_dataset=None,process_sub_dirs=None,image_preview=None, image_container=None,repeat_concept=1): if concept_name == None: concept_name = "" if concept_path == None: @@ -477,6 +488,8 @@ def __init__(self, concept_name, concept_path, class_name, class_path,flip_p, ba class_path = "" if flip_p == None: flip_p = "" + if repeat_concept == None: + repeat_concept = "1" if balance_dataset == None: balance_dataset = False if process_sub_dirs == None: @@ -496,8 +509,9 @@ def __init__(self, concept_name, concept_path, class_name, class_path,flip_p, ba self.image_preview = image_preview self.image_container = image_container self.process_sub_dirs = process_sub_dirs + self.repeat_concept = repeat_concept #update the concept - def update(self, concept_name, concept_path, class_name, class_path,flip_p,balance_dataset,process_sub_dirs, image_preview, image_container): + def update(self, concept_name, concept_path, class_name, class_path,flip_p,balance_dataset,process_sub_dirs, image_preview, image_container,repeat_concept): self.concept_name = concept_name self.concept_path = concept_path self.concept_class_name = class_name @@ -509,9 +523,10 @@ def update(self, concept_name, concept_path, class_name, class_path,flip_p,balan self.image_preview = image_preview self.image_container = image_container self.process_sub_dirs = process_sub_dirs + self.repeat_concept = repeat_concept #get the cocept details def get_details(self): - return self.concept_name, self.concept_path, self.concept_class_name, self.concept_class_path,self.flip_p, self.concept_do_not_balance,self.process_sub_dirs, self.image_preview, self.image_container + return self.concept_name, self.concept_path, self.concept_class_name, self.concept_class_path,self.flip_p, self.concept_do_not_balance,self.process_sub_dirs, self.image_preview, self.image_container, self.repeat_concept #class to make popup right click menu with select all, copy, paste, cut, and delete when right clicked on an entry box class DynamicGrid(ctk.CTkFrame): def __init__(self, parent, *args, **kwargs): @@ -1578,17 +1593,17 @@ def create_trainer_settings_widgets(self): self.use_offset_noise_var = tk.IntVar() 
self.use_offset_noise_var.set(self.use_offset_noise) #create label - self.offset_noise_label = ctk.CTkLabel(self.training_frame_subframe, text="With Offset Noise") + self.offset_noise_label = ctk.CTkLabel(self.training_frame_finetune_subframe, text="With Offset Noise") offset_noise_label_ttp = CreateToolTip(self.offset_noise_label, "Apply offset noise to latents to learn image contrast.") self.offset_noise_label.grid(row=20, column=0, sticky="nsew") #create checkbox - self.offset_noise_checkbox = ctk.CTkSwitch(self.training_frame_subframe, variable=self.use_offset_noise_var) + self.offset_noise_checkbox = ctk.CTkSwitch(self.training_frame_finetune_subframe, variable=self.use_offset_noise_var) self.offset_noise_checkbox.grid(row=20, column=1, sticky="nsew") #create prior loss preservation weight entry - self.offset_noise_weight_label = ctk.CTkLabel(self.training_frame_subframe, text="Offset Noise Weight") + self.offset_noise_weight_label = ctk.CTkLabel(self.training_frame_finetune_subframe, text="Offset Noise Weight") offset_noise_weight_label_ttp = CreateToolTip(self.offset_noise_weight_label, "The weight of the offset noise.") self.offset_noise_weight_label.grid(row=20, column=1, sticky="e") - self.offset_noise_weight_entry = ctk.CTkEntry(self.training_frame_subframe) + self.offset_noise_weight_entry = ctk.CTkEntry(self.training_frame_finetune_subframe) self.offset_noise_weight_entry.grid(row=20, column=3, sticky="w") self.offset_noise_weight_entry.insert(0, self.offset_noise_weight) @@ -2188,6 +2203,7 @@ def packageForCloud(self): new_concept['class_data_dir'] = 'datasets' + '/' + concept_class_name if concept_class_name != '' else '' new_concept['do_not_balance'] = concept['do_not_balance'] new_concept['use_sub_dirs'] = concept['use_sub_dirs'] + new_concept['repeat_concept'] = concept['repeat_concept'] new_concepts.append(new_concept) #make scripts folder self.save_concept_to_json(filename=self.full_export_path + os.sep + 'stabletune_concept_list.json', preMadeConcepts=new_concepts) @@ -2843,7 +2859,7 @@ def save_concept_to_json(self,filename=None,preMadeConcepts=None): concepts = [] for widget in self.concept_widgets: concept = widget.concept - concept_dict = {'instance_prompt' : concept.concept_name, 'class_prompt' : concept.concept_class_name, 'instance_data_dir' : concept.concept_path, 'class_data_dir' : concept.concept_class_path,'flip_p' : concept.flip_p, 'do_not_balance' : concept.concept_do_not_balance, 'use_sub_dirs' : concept.process_sub_dirs} + concept_dict = {'instance_prompt' : concept.concept_name, 'class_prompt' : concept.concept_class_name, 'instance_data_dir' : concept.concept_path, 'class_data_dir' : concept.concept_class_path,'flip_p' : concept.flip_p, 'do_not_balance' : concept.concept_do_not_balance, 'use_sub_dirs' : concept.process_sub_dirs, 'repeat_concept' : concept.repeat_concept} concepts.append(concept_dict) if file != None: #write the json to the file @@ -2869,7 +2885,9 @@ def load_concept_from_json(self): #print(concept) if 'flip_p' not in concept: concept['flip_p'] = '' - concept = Concept(concept_name=concept["instance_prompt"], class_name=concept["class_prompt"], concept_path=concept["instance_data_dir"], class_path=concept["class_data_dir"],flip_p=concept['flip_p'],balance_dataset=concept["do_not_balance"], process_sub_dirs=concept["use_sub_dirs"]) + if 'repeat_concept' not in concept: + concept['repeat_concept'] = '1' + concept = Concept(concept_name=concept["instance_prompt"], class_name=concept["class_prompt"], 
concept_path=concept["instance_data_dir"], class_path=concept["class_data_dir"],flip_p=concept['flip_p'],balance_dataset=concept["do_not_balance"], process_sub_dirs=concept["use_sub_dirs"], repeat_concept=concept["repeat_concept"]) self.add_new_concept(concept) #self.canvas.configure(scrollregion=self.canvas.bbox("all")) self.update() return concept_json @@ -3012,7 +3030,7 @@ def update_concepts(self): self.concepts = [] for i in range(len(self.concept_widgets)): concept = self.concept_widgets[i].concept - self.concepts.append({'instance_prompt' : concept.concept_name, 'class_prompt' : concept.concept_class_name, 'instance_data_dir' : concept.concept_path, 'class_data_dir' : concept.concept_class_path,'flip_p' : concept.flip_p, 'do_not_balance' : concept.concept_do_not_balance, 'use_sub_dirs' : concept.process_sub_dirs}) + self.concepts.append({'instance_prompt' : concept.concept_name, 'class_prompt' : concept.concept_class_name, 'instance_data_dir' : concept.concept_path, 'class_data_dir' : concept.concept_class_path,'flip_p' : concept.flip_p, 'do_not_balance' : concept.concept_do_not_balance, 'use_sub_dirs' : concept.process_sub_dirs, 'repeat_concept' : concept.repeat_concept}) def save_config(self, config_file=None): #save the configure file import json @@ -3128,7 +3146,11 @@ def load_config(self,file_name=None): flip_p = configure["concepts"][i]["flip_p"] balance_dataset = configure["concepts"][i]["do_not_balance"] process_sub_dirs = configure["concepts"][i]["use_sub_dirs"] - concept = Concept(concept_name=inst_prompt, class_name=class_prompt, concept_path=inst_data_dir, class_path=class_data_dir,flip_p=flip_p,balance_dataset=balance_dataset,process_sub_dirs=process_sub_dirs) + if 'repeat_concept' not in configure["concepts"][i]: + print(configure["concepts"][i].keys()) + configure["concepts"][i]['repeat_concept'] = '1' + repeat_concept = configure["concepts"][i]["repeat_concept"] + concept = Concept(concept_name=inst_prompt, class_name=class_prompt, concept_path=inst_data_dir, class_path=class_data_dir,flip_p=flip_p,balance_dataset=balance_dataset,process_sub_dirs=process_sub_dirs,repeat_concept=repeat_concept) self.add_new_concept(concept) except Exception as e: print(e) diff --git a/scripts/dataloaders_util.py b/scripts/dataloaders_util.py index 0b64367..10caa65 100644 --- a/scripts/dataloaders_util.py +++ b/scripts/dataloaders_util.py @@ -736,6 +736,7 @@ class DataLoaderMultiAspect(): data_root: root folder of training data batch_size: number of images per batch flip_p: probability of flipping image horizontally (i.e. 
0-0.5) + epeat_concept: How many times to repeat each concept in the dataset """ def __init__( self, @@ -756,6 +757,7 @@ def __init__( extra_module=None, mask_prompts=None, load_mask=False, + repeat_concept=1, ): self.resolution = resolution self.debug_level = debug_level @@ -778,22 +780,40 @@ def __init__( #process sub directories flag print(f" {bcolors.WARNING} Preloading images...{bcolors.ENDC}") + + #Get concept repeat count + for concept in concept_list: + if 'repeat_concept' in concept.keys(): + repeat_concept = concept['repeat_concept'] + if repeat_concept == '': + repeat_concept = 1 + else: + repeat_concept = int(repeat_concept) + if repeat_concept > 1: + print(f" {bcolors.WARNING} Repeating concept {concept['instance_data_dir']} {repeat_concept} times...{bcolors.ENDC}") if balance_datasets: print(f" {bcolors.WARNING} Balancing datasets...{bcolors.ENDC}") #get the concept with the least number of images in instance_data_dir for concept in concept_list: count = 0 + if 'repeat_concept' in concept.keys(): + repeat_concept = concept['repeat_concept'] + if repeat_concept == '': + repeat_concept = 1 + else: + repeat_concept = int(repeat_concept) + if 'use_sub_dirs' in concept: if concept['use_sub_dirs'] == 1: tot = 0 for root, dirs, files in os.walk(concept['instance_data_dir']): tot += len(files) - count = tot + count = tot*repeat_concept else: - count = len(os.listdir(concept['instance_data_dir'])) + count = len(os.listdir(concept['instance_data_dir']))*repeat_concept else: - count = len(os.listdir(concept['instance_data_dir'])) + count = len(os.listdir(concept['instance_data_dir']))*repeat_concept print(f"{concept['instance_data_dir']} has count of {count}") concept['count'] = count @@ -802,16 +822,16 @@ def __init__( min_concept_num_images = min_concept['count'] print(" Min concept: ",min_concept['instance_data_dir']," with ",min_concept_num_images," images") - balance_cocnept_list = [] + balance_concept_list = [] for concept in concept_list: #if concept has a key do not balance it if 'do_not_balance' in concept: if concept['do_not_balance'] == True: - balance_cocnept_list.append(-1) + balance_concept_list.append(-1) else: - balance_cocnept_list.append(min_concept_num_images) + balance_concept_list.append(min_concept_num_images) else: - balance_cocnept_list.append(min_concept_num_images) + balance_concept_list.append(min_concept_num_images) for concept in concept_list: if 'use_sub_dirs' in concept: if concept['use_sub_dirs'] == True: @@ -824,7 +844,7 @@ def __init__( #self.class_image_paths = [] min_concept_num_images = None if balance_datasets: - min_concept_num_images = balance_cocnept_list[concept_list.index(concept)] + min_concept_num_images = balance_concept_list[concept_list.index(concept)] data_root = concept['instance_data_dir'] data_root_class = concept['class_data_dir'] concept_prompt = concept['instance_prompt'] @@ -835,18 +855,26 @@ def __init__( flip_p = 0.0 else: flip_p = float(flip_p) + + if 'repeat_concept' in concept.keys(): + repeat_concept = concept['repeat_concept'] + if repeat_concept == '': + repeat_concept = 1 + else: + repeat_concept = int(repeat_concept) + self.__recurse_data_root(self=self, recurse_root=data_root,use_sub_dirs=use_sub_dirs) random.Random(self.seed).shuffle(self.image_paths) if self.model_variant == 'depth2img': print(f" {bcolors.WARNING} ** Loading Depth2Img Pipeline To Process Dataset{bcolors.ENDC}") self.vae_scale_factor = self.extra_module.depth_images(self.image_paths) - prepared_train_data.extend(self.__prescan_images(debug_level, 
self.image_paths, flip_p,use_image_names_as_captions,concept_prompt,use_text_files_as_captions=self.use_text_files_as_captions)[0:min_concept_num_images]) # ImageTrainItem[] + prepared_train_data.extend(self.__prescan_images(debug_level, self.image_paths, flip_p,use_image_names_as_captions,concept_prompt,repeat_concept,use_text_files_as_captions=self.use_text_files_as_captions)[0:min_concept_num_images]) # ImageTrainItem[] if add_class_images_to_dataset: self.image_paths = [] self.__recurse_data_root(self=self, recurse_root=data_root_class,use_sub_dirs=use_sub_dirs) random.Random(self.seed).shuffle(self.image_paths) use_image_names_as_captions = False - prepared_train_data.extend(self.__prescan_images(debug_level, self.image_paths, flip_p,use_image_names_as_captions,concept_class_prompt,use_text_files_as_captions=self.use_text_files_as_captions)) # ImageTrainItem[] + prepared_train_data.extend(self.__prescan_images(debug_level, self.image_paths, flip_p,use_image_names_as_captions,concept_class_prompt,repeat_concept,use_text_files_as_captions=self.use_text_files_as_captions)) # ImageTrainItem[] self.image_caption_pairs = self.__bucketize_images(prepared_train_data, batch_size=batch_size, debug_level=debug_level,aspect_mode=self.aspect_mode,action_preference=self.action_preference) if self.with_prior_loss and add_class_images_to_dataset == False: @@ -861,7 +889,7 @@ def __init__( print(f" {bcolors.WARNING} ** Depth2Img To Process Class Dataset{bcolors.ENDC}") self.vae_scale_factor = self.extra_module.depth_images(self.image_paths) use_image_names_as_captions = False - self.class_image_caption_pairs.extend(self.__prescan_images(debug_level, self.class_images_path, flip_p,use_image_names_as_captions,concept_class_prompt,use_text_files_as_captions=self.use_text_files_as_captions)) + self.class_image_caption_pairs.extend(self.__prescan_images(debug_level, self.class_images_path, flip_p,use_image_names_as_captions,concept_class_prompt,repeat_concept,use_text_files_as_captions=self.use_text_files_as_captions)) self.class_image_caption_pairs = self.__bucketize_images(self.class_image_caption_pairs, batch_size=batch_size, debug_level=debug_level,aspect_mode=self.aspect_mode,action_preference=self.action_preference) if mask_prompts is not None: print(f" {bcolors.WARNING} Checking and generating missing masks...{bcolors.ENDC}") @@ -878,43 +906,43 @@ def get_all_images(self): return self.image_caption_pairs else: return self.image_caption_pairs, self.class_image_caption_pairs - def __prescan_images(self,debug_level: int, image_paths: list, flip_p=0.0,use_image_names_as_captions=True,concept=None,use_text_files_as_captions=False): + def __prescan_images(self,debug_level: int, image_paths: list, flip_p=0.0,use_image_names_as_captions=True,concept=None,repeat_concept=1,use_text_files_as_captions=False): """ Create ImageTrainItem objects with metadata for hydration later """ decorated_image_train_items = [] - - for pathname in image_paths: - identifier = concept - if use_image_names_as_captions: - caption_from_filename = os.path.splitext(os.path.basename(pathname))[0].split("_")[0] - identifier = caption_from_filename - if use_text_files_as_captions: - txt_file_path = os.path.splitext(pathname)[0] + ".txt" - - if os.path.exists(txt_file_path): - try: - with open(txt_file_path, 'r',encoding='utf-8',errors='ignore') as f: - identifier = f.readline().rstrip() - f.close() - if len(identifier) < 1: - raise ValueError(f" *** Could not find valid text in: {txt_file_path}") - - except Exception as e: - print(f" 
{bcolors.FAIL} *** Error reading {txt_file_path} to get caption, falling back to filename{bcolors.ENDC}") - print(e) - identifier = caption_from_filename - pass - #print("identifier: ",identifier) - image = Image.open(pathname) - width, height = image.size - image_aspect = width / height + for i in range(repeat_concept): + for pathname in image_paths: + identifier = concept + if use_image_names_as_captions: + caption_from_filename = os.path.splitext(os.path.basename(pathname))[0].split("_")[0] + identifier = caption_from_filename + if use_text_files_as_captions: + txt_file_path = os.path.splitext(pathname)[0] + ".txt" + + if os.path.exists(txt_file_path): + try: + with open(txt_file_path, 'r',encoding='utf-8',errors='ignore') as f: + identifier = f.readline().rstrip() + f.close() + if len(identifier) < 1: + raise ValueError(f" *** Could not find valid text in: {txt_file_path}") + + except Exception as e: + print(f" {bcolors.FAIL} *** Error reading {txt_file_path} to get caption, falling back to filename{bcolors.ENDC}") + print(e) + identifier = caption_from_filename + pass + #print("identifier: ",identifier) + image = Image.open(pathname) + width, height = image.size + image_aspect = width / height - target_wh = min(self.aspects, key=lambda aspects:abs(aspects[0]/aspects[1] - image_aspect)) + target_wh = min(self.aspects, key=lambda aspects:abs(aspects[0]/aspects[1] - image_aspect)) - image_train_item = ImageTrainItem(image=None, mask=None, extra=None, caption=identifier, target_wh=target_wh, pathname=pathname, flip_p=flip_p,model_variant=self.model_variant, load_mask=self.load_mask) + image_train_item = ImageTrainItem(image=None, mask=None, extra=None, caption=identifier, target_wh=target_wh, pathname=pathname, flip_p=flip_p,model_variant=self.model_variant, load_mask=self.load_mask) - decorated_image_train_items.append(image_train_item) + decorated_image_train_items.append(image_train_item) return decorated_image_train_items @staticmethod @@ -1058,6 +1086,7 @@ def __init__( extra_module=None, mask_prompts=None, load_mask=None, + repeat_concept=1, ): self.use_image_names_as_captions = use_image_names_as_captions self.size = size @@ -1072,7 +1101,19 @@ def __init__( self.variant_warning = False self.vae_scale_factor = None self.load_mask = load_mask + self.repeat_concept = repeat_concept for concept in concepts_list: + + #Get concept repeat count + if 'repeat_concept' in concept.keys(): + repeat_concept = concept['repeat_concept'] + if repeat_concept == '': + repeat_concept = 1 + else: + repeat_concept = int(repeat_concept) + if repeat_concept > 1: + print(f" {bcolors.WARNING} Repeating concept {concept['instance_data_dir']} {repeat_concept} times...{bcolors.ENDC}") + if 'use_sub_dirs' in concept: if concept['use_sub_dirs'] == True: use_sub_dirs = True @@ -1081,12 +1122,14 @@ def __init__( else: use_sub_dirs = False - for i in range(repeats): - self.__recurse_data_root(self, concept,use_sub_dirs=use_sub_dirs) + for i in range(repeat_concept): + for i in range(repeats): + self.__recurse_data_root(self, concept,use_sub_dirs=use_sub_dirs) if with_prior_preservation: - for i in range(repeats): - self.__recurse_data_root(self, concept,use_sub_dirs=False,class_images=True) + for i in range(repeat_concept): + for i in range(repeats): + self.__recurse_data_root(self, concept,use_sub_dirs=False,class_images=True) if mask_prompts is not None: print(f" {bcolors.WARNING} Checking and generating missing masks{bcolors.ENDC}") clip_seg = ClipSeg() From 7e34682c79e09ba740a2ecae49cfa24691b45099 Mon Sep 17 
00:00:00 2001 From: Buckzor Date: Wed, 22 Feb 2023 20:18:28 +0000 Subject: [PATCH 5/8] Fixing bug introduced by my branch --- scripts/configuration_gui.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scripts/configuration_gui.py b/scripts/configuration_gui.py index b2e451a..e8d9cc0 100644 --- a/scripts/configuration_gui.py +++ b/scripts/configuration_gui.py @@ -1992,7 +1992,7 @@ def create_plyaground_widgets(self): self.play_negative_prompt_entry.bind("", lambda event: self.play_generate_image(self.play_model_entry.get(), self.play_prompt_entry.get(), self.play_negative_prompt_entry.get(), self.play_seed_entry.get(), self.play_scheduler_variable.get(), int(self.play_resolution_slider_height.get()), int(self.play_resolution_slider_width.get()), self.play_cfg_slider.get(), self.play_steps_slider.get())) #add convert to ckpt button - self.play_convert_to_ckpt_button = ctk.CTkButton(self.playground_frame_subframe, text="Convert To CKPT", command=lambda:self.convert_to_safetensors(model_path=self.play_model_entry.get())) + self.play_convert_to_ckpt_button = ctk.CTkButton(self.playground_frame_subframe, text="Convert To CKPT", command=lambda:self.convert_to_ckpt(model_path=self.play_model_entry.get())) #add interative generation button to act as a toggle #convert to safetensors button From 3cdda001ee49ce8b43b298b37cbbac374606fe58 Mon Sep 17 00:00:00 2001 From: Buckzor Date: Wed, 22 Feb 2023 20:18:54 +0000 Subject: [PATCH 6/8] Added repeat_concept into missing shared_dataloader --- scripts/dataloaders_util.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/scripts/dataloaders_util.py b/scripts/dataloaders_util.py index 10caa65..ebbf659 100644 --- a/scripts/dataloaders_util.py +++ b/scripts/dataloaders_util.py @@ -333,6 +333,7 @@ def __init__(self, extra_module=None, mask_prompts=None, load_mask=False, + repeat_concept=1, ): self.debug_level = debug_level @@ -392,6 +393,7 @@ def __init__(self, extra_module=self.extra_module, mask_prompts=mask_prompts, load_mask=load_mask, + repeat_concept=repeat_concept, ) #print(self.image_train_items) From cd9e4d554089a41bef5951d68946a4debbcb7b99 Mon Sep 17 00:00:00 2001 From: Buckzor Date: Thu, 23 Feb 2023 10:25:49 +0000 Subject: [PATCH 7/8] Fix repeat flip_p box not being visible --- scripts/configuration_gui.py | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/scripts/configuration_gui.py b/scripts/configuration_gui.py index e8d9cc0..4ee2dff 100644 --- a/scripts/configuration_gui.py +++ b/scripts/configuration_gui.py @@ -204,7 +204,7 @@ def __init__(self, parent,conceptWidget,concept,*args, **kwargs): self.parent = parent self.conceptWidget = conceptWidget self.concept = concept - self.geometry("576x297") + self.geometry("576x327") self.resizable(False, False) #self.protocol("WM_DELETE_WINDOW", self.on_close) self.wait_visibility() @@ -214,9 +214,9 @@ def __init__(self, parent,conceptWidget,concept,*args, **kwargs): #self.default_image_preview = ImageTk.PhotoImage(self.default_image_preview) #make a frame for the concept window - self.concept_frame = ctk.CTkFrame(self, width=600, height=300) + self.concept_frame = ctk.CTkFrame(self, width=600, height=320) self.concept_frame.grid(row=0, column=0, sticky="nsew",padx=10,pady=10) - self.concept_frame_subframe=ctk.CTkFrame(self.concept_frame, width=600, height=300) + self.concept_frame_subframe=ctk.CTkFrame(self.concept_frame, width=600, height=320) #4 column grid #self.concept_frame.grid_columnconfigure(0, weight=1) 
#self.concept_frame.grid_columnconfigure(1, weight=5) @@ -278,32 +278,32 @@ def __init__(self, parent,conceptWidget,concept,*args, **kwargs): #entry and label for concept repeats self.repeat_concept_label = ctk.CTkLabel(self.concept_frame_subframe, text="Repeat Concept:") - self.repeat_concept_label.grid(row=4, column=0, sticky="nsew",padx=5,pady=5) + self.repeat_concept_label.grid(row=5, column=0, sticky="nsew",padx=5,pady=5) self.repeat_concept_entry = ctk.CTkEntry(self.concept_frame_subframe,width=200,placeholder_text="1") - self.repeat_concept_entry.grid(row=4, column=1, sticky="e",padx=5,pady=5) + self.repeat_concept_entry.grid(row=5, column=1, sticky="e",padx=5,pady=5) if self.concept.repeat_concept != '': self.repeat_concept_entry.insert(1, self.concept.repeat_concept) #make a label for dataset balancingprocess_sub_dirs self.balance_dataset_label = ctk.CTkLabel(self.concept_frame_subframe, text="Don't Balance Dataset") - self.balance_dataset_label.grid(row=5, column=0, sticky="nsew",padx=5,pady=5) + self.balance_dataset_label.grid(row=6, column=0, sticky="nsew",padx=5,pady=5) #make a switch to enable or disable dataset balancing self.balance_dataset_switch = ctk.CTkSwitch(self.concept_frame_subframe, text="", variable=tk.BooleanVar()) - self.balance_dataset_switch.grid(row=5, column=1, sticky="e",padx=5,pady=5) + self.balance_dataset_switch.grid(row=6, column=1, sticky="e",padx=5,pady=5) if self.concept.concept_do_not_balance == True: self.balance_dataset_switch.toggle() self.process_sub_dirs = ctk.CTkLabel(self.concept_frame_subframe, text="Search Sub-Directories") - self.process_sub_dirs.grid(row=6, column=0, sticky="nsew",padx=5,pady=5) + self.process_sub_dirs.grid(row=7, column=0, sticky="nsew",padx=5,pady=5) #make a switch to enable or disable dataset balancing self.process_sub_dirs_switch = ctk.CTkSwitch(self.concept_frame_subframe, text="", variable=tk.BooleanVar()) - self.process_sub_dirs_switch.grid(row=6, column=1, sticky="e",padx=5,pady=5) + self.process_sub_dirs_switch.grid(row=7, column=1, sticky="e",padx=5,pady=5) if self.concept.process_sub_dirs == True: self.process_sub_dirs_switch.toggle() #self.balance_dataset_switch.set(self.concept.concept_do_not_balance) #add image preview self.image_preview_label = ctk.CTkLabel(self.concept_frame_subframe,text='', width=150, height=150,image=ctk.CTkImage(self.default_image_preview,size=(150,150))) - self.image_preview_label.grid(row=0, column=4,rowspan=5, sticky="nsew",padx=5,pady=5) + self.image_preview_label.grid(row=0, column=4,rowspan=6, sticky="nsew",padx=5,pady=5) if self.concept.image_preview != None or self.concept.image_preview != "": #print(self.concept.image_preview) self.update_preview_image(entry=None,path=None,pil_image=self.concept.image_preview) From 0364a607835b6e7ab51e87d022e7e1bc2877a662 Mon Sep 17 00:00:00 2001 From: Buckzor Date: Thu, 23 Feb 2023 12:09:47 +0000 Subject: [PATCH 8/8] Added separate buckets for concepts. This option puts selected concepts in their own buckets so they do not get batched together with the rest of the images of the same resolution. This should prevent some bleeding over between concepts. 
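As a rough standalone sketch of that batching idea (illustrative only: the make_batches helper, the bucket_id field, and the dict-based items below are not the actual dataloader code, which tags ImageTrainItem objects with a separate_bucket_count), grouping by (width, height, bucket_id) keeps a flagged concept in its own batches even when resolutions match:

    from collections import defaultdict

    def make_batches(items, batch_size):
        # Key each image by resolution plus a per-concept bucket id; a non-zero
        # id keeps a "separate bucket" concept from ever sharing a batch with
        # other concepts at the same resolution.
        buckets = defaultdict(list)
        for item in items:
            width, height = item["target_wh"]
            buckets[(width, height, item["bucket_id"])].append(item)
        batches = []
        for bucket in buckets.values():
            for start in range(0, len(bucket), batch_size):
                batches.append(bucket[start:start + batch_size])
        return batches

    # Same resolution, different bucket ids -> the two images land in separate batches.
    items = [{"target_wh": (512, 512), "bucket_id": 0, "path": "a.png"},
             {"target_wh": (512, 512), "bucket_id": 1, "path": "b.png"}]
    print(make_batches(items, batch_size=2))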
--- scripts/configuration_gui.py | 49 ++++++++++++++++++++++++++---------- scripts/dataloaders_util.py | 49 ++++++++++++++++++++++++++++-------- 2 files changed, 74 insertions(+), 24 deletions(-) diff --git a/scripts/configuration_gui.py b/scripts/configuration_gui.py index 4ee2dff..8f9b239 100644 --- a/scripts/configuration_gui.py +++ b/scripts/configuration_gui.py @@ -76,8 +76,9 @@ def __init__(self, parent, concept=None,width=150,height=150, *args, **kwargs): self.process_sub_dirs = False self.image_preview = self.default_image_preview self.repeat_concept = '1' + self.separate_bucket = False #create concept - self.concept = Concept(self.concept_name, self.concept_data_path, self.concept_class_name, self.concept_class_path,self.flip_p, self.concept_do_not_balance,self.process_sub_dirs, self.image_preview, None, self.repeat_concept) + self.concept = Concept(self.concept_name, self.concept_data_path, self.concept_class_name, self.concept_class_path,self.flip_p, self.concept_do_not_balance,self.process_sub_dirs, self.image_preview, None, self.repeat_concept, self.separate_bucket ) else: self.concept = concept self.concept.image_preview = self.make_image_preview() @@ -204,7 +205,7 @@ def __init__(self, parent,conceptWidget,concept,*args, **kwargs): self.parent = parent self.conceptWidget = conceptWidget self.concept = concept - self.geometry("576x327") + self.geometry("576x377") self.resizable(False, False) #self.protocol("WM_DELETE_WINDOW", self.on_close) self.wait_visibility() @@ -214,9 +215,9 @@ def __init__(self, parent,conceptWidget,concept,*args, **kwargs): #self.default_image_preview = ImageTk.PhotoImage(self.default_image_preview) #make a frame for the concept window - self.concept_frame = ctk.CTkFrame(self, width=600, height=320) + self.concept_frame = ctk.CTkFrame(self, width=600, height=380) self.concept_frame.grid(row=0, column=0, sticky="nsew",padx=10,pady=10) - self.concept_frame_subframe=ctk.CTkFrame(self.concept_frame, width=600, height=320) + self.concept_frame_subframe=ctk.CTkFrame(self.concept_frame, width=600, height=380) #4 column grid #self.concept_frame.grid_columnconfigure(0, weight=1) #self.concept_frame.grid_columnconfigure(1, weight=5) @@ -301,9 +302,19 @@ def __init__(self, parent,conceptWidget,concept,*args, **kwargs): if self.concept.process_sub_dirs == True: self.process_sub_dirs_switch.toggle() #self.balance_dataset_switch.set(self.concept.concept_do_not_balance) + + #make a label for separate concept buckets + self.separate_bucket_label = ctk.CTkLabel(self.concept_frame_subframe, text="Separate Buckets") + self.separate_bucket_label.grid(row=8, column=0, sticky="nsew",padx=5,pady=5) + #make a switch to enable or disable creation of separate buckets for each concept + self.separate_bucket_switch = ctk.CTkSwitch(self.concept_frame_subframe, text="", variable=tk.BooleanVar()) + self.separate_bucket_switch.grid(row=8, column=1, sticky="e",padx=5,pady=5) + if self.concept.separate_bucket == True: + self.separate_bucket_switch.toggle() + #add image preview self.image_preview_label = ctk.CTkLabel(self.concept_frame_subframe,text='', width=150, height=150,image=ctk.CTkImage(self.default_image_preview,size=(150,150))) - self.image_preview_label.grid(row=0, column=4,rowspan=6, sticky="nsew",padx=5,pady=5) + self.image_preview_label.grid(row=0, column=4,rowspan=7, sticky="nsew",padx=5,pady=5) if self.concept.image_preview != None or self.concept.image_preview != "": #print(self.concept.image_preview) 
self.update_preview_image(entry=None,path=None,pil_image=self.concept.image_preview) @@ -463,6 +474,8 @@ def save(self): balance_dataset = self.balance_dataset_switch.get() #get concept repeats repeat_concept = self.repeat_concept_entry.get() + #get the separate bucket switch + separate_bucket = self.separate_bucket_switch.get() #create the concept process_sub_dirs = self.process_sub_dirs_switch.get() #image preview @@ -470,14 +483,14 @@ def save(self): #get the main window image_preview_label = self.image_preview_label #update the concept - self.concept.update(concept_name, concept_path, class_name, class_path,flip_p,balance_dataset,process_sub_dirs,image_preview,image_preview_label,repeat_concept) + self.concept.update(concept_name, concept_path, class_name, class_path,flip_p,balance_dataset,process_sub_dirs,image_preview,image_preview_label,repeat_concept,separate_bucket) self.conceptWidget.update_button() #close the window self.destroy() #class of the concept class Concept: - def __init__(self, concept_name, concept_path, class_name, class_path,flip_p, balance_dataset=None,process_sub_dirs=None,image_preview=None, image_container=None,repeat_concept=1): + def __init__(self, concept_name, concept_path, class_name, class_path,flip_p, balance_dataset=None,process_sub_dirs=None,image_preview=None, image_container=None,repeat_concept=1, separate_bucket=None): if concept_name == None: concept_name = "" if concept_path == None: @@ -494,6 +507,8 @@ def __init__(self, concept_name, concept_path, class_name, class_path,flip_p, ba balance_dataset = False if process_sub_dirs == None: process_sub_dirs = False + if separate_bucket == None: + separate_bucket = False if image_preview == None: image_preview = "" if image_container == None: @@ -510,8 +525,9 @@ def __init__(self, concept_name, concept_path, class_name, class_path,flip_p, ba self.image_container = image_container self.process_sub_dirs = process_sub_dirs self.repeat_concept = repeat_concept + self.separate_bucket = separate_bucket #update the concept - def update(self, concept_name, concept_path, class_name, class_path,flip_p,balance_dataset,process_sub_dirs, image_preview, image_container,repeat_concept): + def update(self, concept_name, concept_path, class_name, class_path,flip_p,balance_dataset,process_sub_dirs, image_preview, image_container,repeat_concept,separate_bucket): self.concept_name = concept_name self.concept_path = concept_path self.concept_class_name = class_name @@ -524,9 +540,10 @@ def update(self, concept_name, concept_path, class_name, class_path,flip_p,balan self.image_container = image_container self.process_sub_dirs = process_sub_dirs self.repeat_concept = repeat_concept + self.separate_bucket = separate_bucket #get the cocept details def get_details(self): - return self.concept_name, self.concept_path, self.concept_class_name, self.concept_class_path,self.flip_p, self.concept_do_not_balance,self.process_sub_dirs, self.image_preview, self.image_container, self.repeat_concept + return self.concept_name, self.concept_path, self.concept_class_name, self.concept_class_path,self.flip_p, self.concept_do_not_balance,self.process_sub_dirs, self.image_preview, self.image_container, self.repeat_concept, self.separate_bucket #class to make popup right click menu with select all, copy, paste, cut, and delete when right clicked on an entry box class DynamicGrid(ctk.CTkFrame): def __init__(self, parent, *args, **kwargs): @@ -2204,6 +2221,7 @@ def packageForCloud(self): new_concept['do_not_balance'] = concept['do_not_balance'] 
new_concept['use_sub_dirs'] = concept['use_sub_dirs'] new_concept['repeat_concept'] = concept['repeat_concept'] + new_concept['separate_bucket'] = concept['separate_bucket'] new_concepts.append(new_concept) #make scripts folder self.save_concept_to_json(filename=self.full_export_path + os.sep + 'stabletune_concept_list.json', preMadeConcepts=new_concepts) @@ -2859,7 +2877,7 @@ def save_concept_to_json(self,filename=None,preMadeConcepts=None): concepts = [] for widget in self.concept_widgets: concept = widget.concept - concept_dict = {'instance_prompt' : concept.concept_name, 'class_prompt' : concept.concept_class_name, 'instance_data_dir' : concept.concept_path, 'class_data_dir' : concept.concept_class_path,'flip_p' : concept.flip_p, 'do_not_balance' : concept.concept_do_not_balance, 'use_sub_dirs' : concept.process_sub_dirs, 'repeat_concept' : concept.repeat_concept} + concept_dict = {'instance_prompt' : concept.concept_name, 'class_prompt' : concept.concept_class_name, 'instance_data_dir' : concept.concept_path, 'class_data_dir' : concept.concept_class_path,'flip_p' : concept.flip_p, 'do_not_balance' : concept.concept_do_not_balance, 'use_sub_dirs' : concept.process_sub_dirs, 'repeat_concept' : concept.repeat_concept, 'separate_bucket' : concept.separate_bucket} concepts.append(concept_dict) if file != None: #write the json to the file @@ -2887,7 +2905,7 @@ def load_concept_from_json(self): concept['flip_p'] = '' if 'repeat_concept' not in concept: concept['repeat_concept'] = '1' - concept = Concept(concept_name=concept["instance_prompt"], class_name=concept["class_prompt"], concept_path=concept["instance_data_dir"], class_path=concept["class_data_dir"],flip_p=concept['flip_p'],balance_dataset=concept["do_not_balance"], process_sub_dirs=concept["use_sub_dirs"], repeat_concept=concept["repeat_concept"]) + concept = Concept(concept_name=concept["instance_prompt"], class_name=concept["class_prompt"], concept_path=concept["instance_data_dir"], class_path=concept["class_data_dir"],flip_p=concept['flip_p'],balance_dataset=concept["do_not_balance"], process_sub_dirs=concept["use_sub_dirs"], repeat_concept=concept["repeat_concept"], separate_bucket=concept["separate_bucket"]) self.add_new_concept(concept) #self.canvas.configure(scrollregion=self.canvas.bbox("all")) self.update() return concept_json @@ -3030,7 +3048,7 @@ def update_concepts(self): self.concepts = [] for i in range(len(self.concept_widgets)): concept = self.concept_widgets[i].concept - self.concepts.append({'instance_prompt' : concept.concept_name, 'class_prompt' : concept.concept_class_name, 'instance_data_dir' : concept.concept_path, 'class_data_dir' : concept.concept_class_path,'flip_p' : concept.flip_p, 'do_not_balance' : concept.concept_do_not_balance, 'use_sub_dirs' : concept.process_sub_dirs, 'repeat_concept' : concept.repeat_concept}) + self.concepts.append({'instance_prompt' : concept.concept_name, 'class_prompt' : concept.concept_class_name, 'instance_data_dir' : concept.concept_path, 'class_data_dir' : concept.concept_class_path,'flip_p' : concept.flip_p, 'do_not_balance' : concept.concept_do_not_balance, 'use_sub_dirs' : concept.process_sub_dirs, 'repeat_concept' : concept.repeat_concept, 'separate_bucket' : concept.separate_bucket}) def save_config(self, config_file=None): #save the configure file import json @@ -3150,7 +3168,12 @@ def load_config(self,file_name=None): print(configure["concepts"][i].keys()) configure["concepts"][i]['repeat_concept'] = '1' repeat_concept = configure["concepts"][i]["repeat_concept"] - 
concept = Concept(concept_name=inst_prompt, class_name=class_prompt, concept_path=inst_data_dir, class_path=class_data_dir,flip_p=flip_p,balance_dataset=balance_dataset,process_sub_dirs=process_sub_dirs,repeat_concept=repeat_concept) + + if 'separate_bucket' not in configure["concepts"][i]: + print(configure["concepts"][i].keys()) + configure["concepts"][i]['separate_bucket'] = False + separate_bucket = configure["concepts"][i]["separate_bucket"] + concept = Concept(concept_name=inst_prompt, class_name=class_prompt, concept_path=inst_data_dir, class_path=class_data_dir,flip_p=flip_p,balance_dataset=balance_dataset,process_sub_dirs=process_sub_dirs,repeat_concept=repeat_concept,separate_bucket=separate_bucket) self.add_new_concept(concept) except Exception as e: print(e) diff --git a/scripts/dataloaders_util.py b/scripts/dataloaders_util.py index ebbf659..55603ee 100644 --- a/scripts/dataloaders_util.py +++ b/scripts/dataloaders_util.py @@ -334,6 +334,7 @@ def __init__(self, mask_prompts=None, load_mask=False, repeat_concept=1, + separate_bucket=False, ): self.debug_level = debug_level @@ -394,6 +395,7 @@ def __init__(self, mask_prompts=mask_prompts, load_mask=load_mask, repeat_concept=repeat_concept, + separate_bucket=separate_bucket, ) #print(self.image_train_items) @@ -493,7 +495,7 @@ class ImageTrainItem(): pathname: path to image file flip_p: probability of flipping image (0.0 to 1.0) """ - def __init__(self, image: Image, mask: Image, extra: Image, caption: str, target_wh: list, pathname: str, flip_p=0.0, model_variant='base', load_mask=False): + def __init__(self, image: Image, mask: Image, extra: Image, caption: str, target_wh: list, pathname: str, flip_p=0.0, model_variant='base', load_mask=False, separate_bucket_count=0): self.caption = caption self.target_wh = target_wh self.pathname = pathname @@ -506,7 +508,8 @@ def __init__(self, image: Image, mask: Image, extra: Image, caption: str, target self.load_mask=load_mask self.is_dupe = [] self.variant_warning = False - + self.separate_bucket_count = separate_bucket_count + self.image = image self.mask = mask self.extra = extra @@ -760,6 +763,7 @@ def __init__( mask_prompts=None, load_mask=False, repeat_concept=1, + separate_bucket=False, ): self.resolution = resolution self.debug_level = debug_level @@ -775,8 +779,12 @@ def __init__( self.model_variant = model_variant self.extra_module = extra_module self.load_mask = load_mask + self.repeat_concept = repeat_concept + self.separate_bucket=separate_bucket, + separate_bucket_count = 0 prepared_train_data = [] + self.aspects = get_aspect_buckets(resolution) #print(f"* DLMA resolution {resolution}, buckets: {self.aspects}") #process sub directories flag @@ -793,6 +801,7 @@ def __init__( repeat_concept = int(repeat_concept) if repeat_concept > 1: print(f" {bcolors.WARNING} Repeating concept {concept['instance_data_dir']} {repeat_concept} times...{bcolors.ENDC}") + if balance_datasets: print(f" {bcolors.WARNING} Balancing datasets...{bcolors.ENDC}") @@ -834,6 +843,9 @@ def __init__( balance_concept_list.append(min_concept_num_images) else: balance_concept_list.append(min_concept_num_images) + + total_separate_bucket_count = 0 + for concept in concept_list: if 'use_sub_dirs' in concept: if concept['use_sub_dirs'] == True: @@ -864,20 +876,33 @@ def __init__( repeat_concept = 1 else: repeat_concept = int(repeat_concept) + + + + if concept['separate_bucket'] == True: + total_separate_bucket_count += 1 + separate_bucket_count = total_separate_bucket_count + else: + separate_bucket_count = 0 + 
self.__recurse_data_root(self=self, recurse_root=data_root,use_sub_dirs=use_sub_dirs) random.Random(self.seed).shuffle(self.image_paths) if self.model_variant == 'depth2img': print(f" {bcolors.WARNING} ** Loading Depth2Img Pipeline To Process Dataset{bcolors.ENDC}") self.vae_scale_factor = self.extra_module.depth_images(self.image_paths) - prepared_train_data.extend(self.__prescan_images(debug_level, self.image_paths, flip_p,use_image_names_as_captions,concept_prompt,repeat_concept,use_text_files_as_captions=self.use_text_files_as_captions)[0:min_concept_num_images]) # ImageTrainItem[] + prepared_train_data.extend(self.__prescan_images(debug_level, self.image_paths, flip_p,use_image_names_as_captions,concept_prompt,repeat_concept,separate_bucket_count,use_text_files_as_captions=self.use_text_files_as_captions)[0:min_concept_num_images]) # ImageTrainItem[] if add_class_images_to_dataset: self.image_paths = [] self.__recurse_data_root(self=self, recurse_root=data_root_class,use_sub_dirs=use_sub_dirs) random.Random(self.seed).shuffle(self.image_paths) use_image_names_as_captions = False - prepared_train_data.extend(self.__prescan_images(debug_level, self.image_paths, flip_p,use_image_names_as_captions,concept_class_prompt,repeat_concept,use_text_files_as_captions=self.use_text_files_as_captions)) # ImageTrainItem[] - + prepared_train_data.extend(self.__prescan_images(debug_level, self.image_paths, flip_p,use_image_names_as_captions,concept_class_prompt,repeat_concept,separate_bucket_count,use_text_files_as_captions=self.use_text_files_as_captions)) # ImageTrainItem[] + + + if total_separate_bucket_count > 0: + print(f" {bcolors.WARNING} There are {total_separate_bucket_count} concepts using separate buckets...{bcolors.ENDC}") + self.image_caption_pairs = self.__bucketize_images(prepared_train_data, batch_size=batch_size, debug_level=debug_level,aspect_mode=self.aspect_mode,action_preference=self.action_preference) if self.with_prior_loss and add_class_images_to_dataset == False: self.class_image_caption_pairs = [] @@ -891,7 +916,7 @@ def __init__( print(f" {bcolors.WARNING} ** Depth2Img To Process Class Dataset{bcolors.ENDC}") self.vae_scale_factor = self.extra_module.depth_images(self.image_paths) use_image_names_as_captions = False - self.class_image_caption_pairs.extend(self.__prescan_images(debug_level, self.class_images_path, flip_p,use_image_names_as_captions,concept_class_prompt,repeat_concept,use_text_files_as_captions=self.use_text_files_as_captions)) + self.class_image_caption_pairs.extend(self.__prescan_images(debug_level, self.class_images_path, flip_p,use_image_names_as_captions,concept_class_prompt,repeat_concept,separate_bucket_count,use_text_files_as_captions=self.use_text_files_as_captions)) self.class_image_caption_pairs = self.__bucketize_images(self.class_image_caption_pairs, batch_size=batch_size, debug_level=debug_level,aspect_mode=self.aspect_mode,action_preference=self.action_preference) if mask_prompts is not None: print(f" {bcolors.WARNING} Checking and generating missing masks...{bcolors.ENDC}") @@ -908,7 +933,7 @@ def get_all_images(self): return self.image_caption_pairs else: return self.image_caption_pairs, self.class_image_caption_pairs - def __prescan_images(self,debug_level: int, image_paths: list, flip_p=0.0,use_image_names_as_captions=True,concept=None,repeat_concept=1,use_text_files_as_captions=False): + def __prescan_images(self,debug_level: int, image_paths: list, 
flip_p=0.0,use_image_names_as_captions=True,concept=None,repeat_concept=1,separate_bucket_count=0,use_text_files_as_captions=False): """ Create ImageTrainItem objects with metadata for hydration later """ @@ -942,7 +967,7 @@ def __prescan_images(self,debug_level: int, image_paths: list, flip_p=0.0,use_im target_wh = min(self.aspects, key=lambda aspects:abs(aspects[0]/aspects[1] - image_aspect)) - image_train_item = ImageTrainItem(image=None, mask=None, extra=None, caption=identifier, target_wh=target_wh, pathname=pathname, flip_p=flip_p,model_variant=self.model_variant, load_mask=self.load_mask) + image_train_item = ImageTrainItem(image=None, mask=None, extra=None, caption=identifier, target_wh=target_wh, pathname=pathname, flip_p=flip_p,model_variant=self.model_variant, load_mask=self.load_mask,separate_bucket_count=separate_bucket_count) decorated_image_train_items.append(image_train_item) return decorated_image_train_items @@ -957,10 +982,12 @@ def __bucketize_images(prepared_train_data: list, batch_size=1, debug_level=0,as buckets = {} for image_caption_pair in prepared_train_data: target_wh = image_caption_pair.target_wh + separate_bucket_count = image_caption_pair.separate_bucket_count - if (target_wh[0],target_wh[1]) not in buckets: - buckets[(target_wh[0],target_wh[1])] = [] - buckets[(target_wh[0],target_wh[1])].append(image_caption_pair) + #concept_bucket = image_caption_pair.concept_bucket + if (target_wh[0],target_wh[1],separate_bucket_count) not in buckets: + buckets[(target_wh[0],target_wh[1],separate_bucket_count)] = [] + buckets[(target_wh[0],target_wh[1],separate_bucket_count)].append(image_caption_pair) print(f" ** Number of buckets: {len(buckets)}") for bucket in buckets: bucket_len = len(buckets[bucket])
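# A minimal standalone sketch of the offset-noise sampling from patches 1-2,
# assuming a latents tensor of shape [batch, channels, height, width]; the
# function name and signature below are illustrative, not the trainer's code.
import torch

def sample_offset_noise(latents: torch.Tensor, weight: float = 0.1) -> torch.Tensor:
    # One extra noise value per sample and per channel, broadcast across the
    # whole spatial grid: shifting every pixel of a channel by the same amount
    # randomizes the zero-frequency (mean) component much faster, so the model
    # also learns to change overall brightness/contrast, not just detail.
    offset = weight * torch.randn(latents.shape[0], latents.shape[1], 1, 1,
                                  device=latents.device)
    return torch.randn_like(latents) + offset

# Usage in the training loop would look like:
#     noise = sample_offset_noise(latents, weight=args.offset_noise_weight)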