Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add PAG support for sd_1.5 controlnet #8816

Closed

Conversation

tuanh123789
Copy link
Contributor

@tuanh123789 tuanh123789 commented Jul 9, 2024

What does this PR do?

Adds PAG (Perturbed-Attention Guidance) support for SD 1.5 with controlnet models (StableDiffusionControlNetPAGPipeline)
Continuation of #8710

Before submitting

Who can review?

@yiyixuxu
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

Sample using SD1.5 controlnet Sample SD1.5 controlnet PAG
test test_10
from diffusers import AutoPipelineForText2Image, ControlNetModel, UniPCMultistepScheduler
from diffusers.utils import load_image
import numpy as np
import torch

import cv2
from PIL import Image

# download an image
image = load_image(
   "https://hf.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/hf-logo.png"
)
image = np.array(image)

# get canny image
image = cv2.Canny(image, 100, 200)
image = image[:, :, None]
image = np.concatenate([image, image, image], axis=2)
canny_image = Image.fromarray(image)

# load control net and stable diffusion v1-5
controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-canny", torch_dtype=torch.float16)
pipe = AutoPipelineForText2Image.from_pretrained(
   "runwayml/stable-diffusion-v1-5", controlnet=controlnet, torch_dtype=torch.float16, enable_pag=True
)

# speed up diffusion process with faster scheduler and memory optimization
# remove following line if xformers is not installed
pipe.enable_xformers_memory_efficient_attention()

pipe.enable_model_cpu_offload()

# generate image
generator = torch.manual_seed(0)
image = pipe(
   "aerial view, a futuristic research complex in a bright foggy jungle, hard lighting",
   guidance_scale=7.5,
   generator=generator,
   image=canny_image,
   pag_scale=10,
).images[0]

@a-r-r-o-w a-r-r-o-w requested a review from yiyixuxu July 9, 2024 09:40
@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

Copy link
Member

@a-r-r-o-w a-r-r-o-w left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

thank you for taking this up! it's looking good to me for the most part but requesting a few minor changes and then this should be ready for merge. also need you to rebase/merge your branch on latest diffusers:main correctly in order for the SD3 Inpaint pipeline changes to reflect because they are all deleted at the moment

@@ -288,12 +288,12 @@
"StableCascadePriorPipeline",
"StableDiffusion3ControlNetPipeline",
"StableDiffusion3Img2ImgPipeline",
"StableDiffusion3InpaintPipeline",
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why has this been removed? Could you rebase/cherrypick your branch on current diffusers main?

@@ -527,11 +525,7 @@
StableUnCLIPImg2ImgPipeline,
StableUnCLIPPipeline,
)
from .stable_diffusion_3 import (
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Need to revert these changes

@@ -962,7 +962,7 @@ def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])


class StableDiffusion3InpaintPipeline(metaclass=DummyObject):
class StableDiffusion3Pipeline(metaclass=DummyObject):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Will need to revert these changes and re-run make fix-copies

logger = logging.get_logger(__name__) # pylint: disable=invalid-name


EXAMPLE_DOC_STRING = """
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nice!

self.set_pag_applied_layers(pag_applied_layers)

# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt
def _encode_prompt(
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This method is deprecated and therefore can be removed from here

Comment on lines +1007 to +1012
callback (`Callable`, *optional*):
A function that calls every `callback_steps` steps during inference. The function is called with the
following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`.
callback_steps (`int`, *optional*, defaults to 1):
The frequency at which the `callback` function is called. If not specified, the callback is called at
every step.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
callback (`Callable`, *optional*):
A function that calls every `callback_steps` steps during inference. The function is called with the
following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`.
callback_steps (`int`, *optional*, defaults to 1):
The frequency at which the `callback` function is called. If not specified, the callback is called at
every step.

Comment on lines +1056 to +1070
callback = kwargs.pop("callback", None)
callback_steps = kwargs.pop("callback_steps", None)

if callback is not None:
deprecate(
"callback",
"1.0.0",
"Passing `callback` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
)
if callback_steps is not None:
deprecate(
"callback_steps",
"1.0.0",
"Passing `callback_steps` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
callback = kwargs.pop("callback", None)
callback_steps = kwargs.pop("callback_steps", None)
if callback is not None:
deprecate(
"callback",
"1.0.0",
"Passing `callback` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
)
if callback_steps is not None:
deprecate(
"callback_steps",
"1.0.0",
"Passing `callback_steps` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
)

These have been deprecated and we do not need to have them in newer pipelines. You'll have to make the necessary changes elsewhere too to remove any usage occurences

prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])

if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
ip_adapter_image_embeds = self.prepare_ip_adapter_image_embeds(
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

IP Adapter image embeds need to be prepared differently for perturbed attention guidance. Please take a look here:

if ip_adapter_image is not None or ip_adapter_image_embeds is not None:


# 7.1 Add image embeds for IP-Adapter

if ip_adapter_image_embeds is not None:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh, I see you've done what I asked in previous comment here. Is there a reason to do it separately? If not, can you move it under the same if-branch?

Comment on lines +1374 to +1376
if callback is not None and i % callback_steps == 0:
step_idx = i // getattr(self.scheduler, "order", 1)
callback(step_idx, t, latents)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
if callback is not None and i % callback_steps == 0:
step_idx = i // getattr(self.scheduler, "order", 1)
callback(step_idx, t, latents)

it's been deprecated

@tuanh123789 tuanh123789 closed this Jul 9, 2024
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants