add PAG support for SD Controlnet Img2Img #8810
Conversation
Hi @yiyixuxu, please review this when you get a chance. I am having some difficulty with the tests, so please take a look at those as well.
max_diff = np.abs(image_slice.flatten() - expected_slice).max()
assert max_diff < 1e-3, f"output is different from expected, {image_slice.flatten()}"

# def test_ip_adapter_single(self):
we can remove these tests, no?
Yes, the commented test will be removed.
Can you run make style and make fix-copies so that the quality tests pass? They are currently failing.
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
Thank you for adding support for this! It seems like a lot of the perturbed attention guidance part is not handled yet and this is a raw copy of ControlNetSDImg2Img, no? Let me try to help you with the required changes:
- Refer to add PAG support #7944 and add PAG support for SD architecture #8725.
- Try to understand what happens in the pag_utils.py file.
- Take a look at one of the existing PAG pipeline implementations. Notice that at various locations in the code, we check the self.do_perturbed_attention_guidance flag and handle things differently from normal CFG. You will have to apply these changes as well. It might be a little tricky to do with controlnet, since you also need to take care of control_model_input.
- The easiest way to see all the differences would be to view the diff of the non-PAG and PAG variants (for example, StableDiffusionXLPipeline and StableDiffusionXLPAGPipeline) side by side.
- Once you're comfortable and have made all the required changes, try running through all the different scenarios: guess_mode true and false, guidance_scale == 1 and guidance_scale > 1, pag_scale == 0 and pag_scale > 0, etc.
If you need any additional help, feel free to ping me any time.
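To illustrate the batching that the do_perturbed_attention_guidance flag controls, here is a minimal sketch (the helper name is hypothetical, not the diffusers API): each active guidance mode adds one extra copy of the latents to the UNet's per-step batch.

```python
import torch

def expand_latents(latents, do_cfg, do_pag):
    # Hypothetical helper for illustration: one forward pass per guidance branch.
    multiplier = 1          # conditional prediction
    if do_cfg:
        multiplier += 1     # unconditional prediction
    if do_pag:
        multiplier += 1     # perturbed-attention prediction
    return torch.cat([latents] * multiplier)

latents = torch.randn(1, 4, 8, 8)
print(expand_latents(latents, do_cfg=True, do_pag=True).shape[0])  # 3
```

With CFG + PAG both active, the batch triples, which is why the fixed `* 2` expansion in the copied pipeline breaks.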
raise AttributeError("Could not access latents of provided encoder_output")


def prepare_image(image):
Could you add the missing # Copied from comment here?
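For context, the # Copied from convention (visible elsewhere in this PR) is a comment naming the fully qualified source of the copied definition, placed directly above it, so that make fix-copies can keep the copies in sync. A sketch, where the exact source path shown is an assumption, not verified against the repository:

```python
# Copied from diffusers.pipelines.controlnet.pipeline_controlnet_img2img.prepare_image
# (the source path above is illustrative; point it at wherever this function
# was actually copied from)
def prepare_image(image):
    # body copied verbatim from the source pipeline
    ...
```

make fix-copies then rewrites the body whenever the referenced source changes, so the comment must name the true origin.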
self.register_to_config(requires_safety_checker=requires_safety_checker)
self.set_pag_applied_layers(pag_applied_layers)


# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt
I don't think this method is being used anywhere and so can be removed.
extra_step_kwargs["generator"] = generator
return extra_step_kwargs


def check_inputs(
Please add the missing # Copied from comment here as well.
"not-safe-for-work" (nsfw) content. | ||
""" | ||
|
||
callback = kwargs.pop("callback", None) |
These callbacks have been deprecated; you can remove them.
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
pag_scale: float = 3.0,
pag_adaptive_scale: float = 0.0,
**kwargs,
**kwargs,
else:
    assert False


# 5. Prepare timesteps
You might have to fix the step numbering here
with self.progress_bar(total=num_inference_steps) as progress_bar:
    for i, t in enumerate(timesteps):
        # expand the latents if we are doing classifier free guidance
        latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
Are you sure the forward pass is working? Shouldn't this be something like
latent_model_input = torch.cat([latents] * (prompt_embeds.shape[0] // latents.shape[0]))
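A standalone sketch (with illustrative shapes, not taken from the pipeline) of why deriving the multiplier from prompt_embeds works: the embedding batch already reflects how many guidance branches are active, so the same expression handles plain CFG (2x), PAG-only (2x), and CFG + PAG (3x).

```python
import torch

batch = 2
latents = torch.randn(batch, 4, 8, 8)
# Assumed layout for illustration: with CFG + PAG active, prompt_embeds
# stacks [uncond, text, text_perturbed], so its batch is 3x the latent batch.
prompt_embeds = torch.randn(3 * batch, 77, 768)

# Deriving the multiplier from the embeddings keeps latents and embeddings aligned.
latent_model_input = torch.cat([latents] * (prompt_embeds.shape[0] // latents.shape[0]))
print(latent_model_input.shape)  # torch.Size([6, 4, 8, 8])
```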
if callback is not None and i % callback_steps == 0:
    step_idx = i // getattr(self.scheduler, "order", 1)
    callback(step_idx, t, latents)
)[0]

# perform guidance
if self.do_classifier_free_guidance:
The perturbed guidance part does not seem to have been implemented in the different places where it is supposed to be added.
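For reference, a hedged sketch of how CFG and PAG predictions are typically combined (the function name and exact combination are illustrative, not the diffusers helper): CFG pushes the prediction toward the text-conditioned branch, while PAG pushes it away from the perturbed-attention branch.

```python
import torch

def combine_guidance(noise_uncond, noise_text, noise_perturb,
                     guidance_scale, pag_scale):
    # Illustrative combination, not the exact diffusers implementation.
    # Classifier-free guidance term:
    noise_pred = noise_uncond + guidance_scale * (noise_text - noise_uncond)
    # Perturbed-attention guidance term:
    noise_pred = noise_pred + pag_scale * (noise_text - noise_perturb)
    return noise_pred
```

With pag_scale == 0 this reduces to plain CFG, which matches the scenario checklist earlier in the thread.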
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])

if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
    image_embeds = self.prepare_ip_adapter_image_embeds(
IP Adapter perturbed embeddings need to be generated differently. Please refer to one of the linked PRs.
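As a sketch of what "generated differently" means here (an illustrative helper, not the actual diffusers method): the perturbed branch runs on the conditional embeddings, so the embedding batch has to match the expanded [uncond, cond, perturbed] latent layout.

```python
import torch

def prepare_embeds_for_pag(uncond_embeds, cond_embeds, do_cfg, do_pag):
    # Hypothetical helper: the PAG branch reuses the *conditional*
    # embeddings, so with CFG + PAG the batch layout becomes
    # [uncond, cond, cond] to line up with the expanded latents.
    parts = []
    if do_cfg:
        parts.append(uncond_embeds)
    parts.append(cond_embeds)
    if do_pag:
        parts.append(cond_embeds)
    return torch.cat(parts)
```

The linked PRs (#7944, #8725) show how the actual pipelines assemble these embeddings.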
Hi @a-r-r-o-w, are you referring to this part here?
No, I'm referring to this (lines 1156-1177):

if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
Also, it might help to change the title of this PR to something like "add PAG support for SD Controlnet Img2Img" to reflect the intent correctly when merged.
Hi @a-r-r-o-w, I am trying to fix the coding mistakes and coding style by calling the make command.
You will need to install make. I'm assuming you are on Windows, since make is available by default on Linux and macOS. If you install Git for Windows, you will easily be able to use it. Otherwise, try https://stackoverflow.com/questions/32127524/how-to-install-and-use-make-in-windows
Hi @a-r-r-o-w, I am completing the PAG pipeline first, before the associated tests. Please check this whenever you have time. Thanks!
Can you show us some examples with and without PAG enabled? Also, please post a minimal reproducible example. And yes, we can work on the tests later, in a similar fashion to how others have done in their PRs.
Hi, I am moving this PR here; further communication will be done there. I am closing this one.
What does this PR do?
Part of #8710
Before submitting: see the documentation guidelines, and here are tips on formatting docstrings.
Who can review?
Tagging @yiyixuxu