From 9cd37557d581dd30fb6031ae30bd583443c3effd Mon Sep 17 00:00:00 2001 From: YiYi Xu Date: Wed, 25 Sep 2024 19:09:54 -1000 Subject: [PATCH] flux controlnet fix (control_modes batch & others) (#9507) * flux controlnet mode to take into account batch size * incorporate yiyixuxu's suggestions (cleaner logic) as well as clean up control mode handling for multi case * fix * fix use_guidance when controlnet is a multi and does not have config --------- Co-authored-by: Christopher Beckham Co-authored-by: Sayak Paul --- src/diffusers/models/controlnet_flux.py | 23 +++++------ .../flux/pipeline_flux_controlnet.py | 39 ++++++++++++------- 2 files changed, 37 insertions(+), 25 deletions(-) diff --git a/src/diffusers/models/controlnet_flux.py b/src/diffusers/models/controlnet_flux.py index 036e5654a98e..88ad49d2b776 100644 --- a/src/diffusers/models/controlnet_flux.py +++ b/src/diffusers/models/controlnet_flux.py @@ -502,16 +502,17 @@ def forward( control_block_samples = block_samples control_single_block_samples = single_block_samples else: - control_block_samples = [ - control_block_sample + block_sample - for control_block_sample, block_sample in zip(control_block_samples, block_samples) - ] - - control_single_block_samples = [ - control_single_block_sample + block_sample - for control_single_block_sample, block_sample in zip( - control_single_block_samples, single_block_samples - ) - ] + if block_samples is not None and control_block_samples is not None: + control_block_samples = [ + control_block_sample + block_sample + for control_block_sample, block_sample in zip(control_block_samples, block_samples) + ] + if single_block_samples is not None and control_single_block_samples is not None: + control_single_block_samples = [ + control_single_block_sample + block_sample + for control_single_block_sample, block_sample in zip( + control_single_block_samples, single_block_samples + ) + ] return control_block_samples, control_single_block_samples diff --git a/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py b/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py index 11b71b1cbece..6c072c482020 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py @@ -747,10 +747,12 @@ def __call__( width_control_image, ) - # set control mode + # Here we ensure that `control_mode` has the same length as the control_image. if control_mode is not None: + if not isinstance(control_mode, int): + raise ValueError(" For `FluxControlNet`, `control_mode` should be an `int` or `None`") control_mode = torch.tensor(control_mode).to(device, dtype=torch.long) - control_mode = control_mode.reshape([-1, 1]) + control_mode = control_mode.view(-1, 1).expand(control_image.shape[0], 1) elif isinstance(self.controlnet, FluxMultiControlNetModel): control_images = [] @@ -785,16 +787,22 @@ def __call__( control_image = control_images + # Here we ensure that `control_mode` has the same length as the control_image. + if isinstance(control_mode, list) and len(control_mode) != len(control_image): + raise ValueError( + "For Multi-ControlNet, `control_mode` must be a list of the same " + + " length as the number of controlnets (control images) specified" + ) + if not isinstance(control_mode, list): + control_mode = [control_mode] * len(control_image) # set control mode - control_mode_ = [] - if isinstance(control_mode, list): - for cmode in control_mode: - if cmode is None: - control_mode_.append(-1) - else: - control_mode_.append(cmode) - control_mode = torch.tensor(control_mode_).to(device, dtype=torch.long) - control_mode = control_mode.reshape([-1, 1]) + control_modes = [] + for cmode in control_mode: + if cmode is None: + cmode = -1 + control_mode = torch.tensor(cmode).expand(control_images[0].shape[0]).to(device, dtype=torch.long) + control_modes.append(control_mode) + control_mode = control_modes # 4. Prepare latent variables num_channels_latents = self.transformer.config.in_channels // 4 @@ -840,9 +848,12 @@ def __call__( # broadcast to batch dimension in a way that's compatible with ONNX/Core ML timestep = t.expand(latents.shape[0]).to(latents.dtype) - guidance = ( - torch.tensor([guidance_scale], device=device) if self.controlnet.config.guidance_embeds else None - ) + if isinstance(self.controlnet, FluxMultiControlNetModel): + use_guidance = self.controlnet.nets[0].config.guidance_embeds + else: + use_guidance = self.controlnet.config.guidance_embeds + + guidance = torch.tensor([guidance_scale], device=device) if use_guidance else None guidance = guidance.expand(latents.shape[0]) if guidance is not None else None # controlnet