
[Community] StyleAligned Pipeline #6489

Merged
merged 8 commits into from
Jan 11, 2024

Conversation

a-r-r-o-w
Member

@a-r-r-o-w a-r-r-o-w commented Jan 8, 2024

What does this PR do?

Adds the SDXL implementation of Style Aligned.

Code: https://github.com/google/style-aligned
Paper: https://style-aligned-gen.github.io/data/StyleAligned.pdf
Project page: https://style-aligned-gen.github.io/

Fixes #6063.

Some results can be found in the Colab notebook.

GPU support generously provided by ModelsLab.

Before submitting

Who can review?

@sayakpaul @patrickvonplaten @amirhertz

@sayakpaul
Member

Thank you!

Could you maybe post some visual examples directly on the PR so that they're immediately available to viewers?

@a-r-r-o-w
Member Author

a-r-r-o-w commented Jan 8, 2024

Below are some results from the Colab notebook above for txt2img.

Normal SDXL Text-to-Image

style-aligned-disabled

SDXL Text-to-Image with StyleAligned

style-aligned-enabled

Below are results for the img2img task. Note that the results are better when generating multiple images from the same input image. Using different initial input images, as is the case here, leads to artifacts and other problems. One fun case is the influence of multiple images on each other when setting adain_values=True and full_attention_share=True.

Normal SDXL Image-to-Image

image

SDXL Image-to-Image with StyleAligned settings-1

style-aligned-enabled-img2img-2

SDXL Image-to-Image with StyleAligned settings-2

style-aligned-enabled-img2img

Below are results for inpainting.

SDXL Inpainting with StyleAligned settings-1

style-aligned-enabled-inpaint

SDXL Inpainting with StyleAligned settings-2

style-aligned-enabled-inpaint-2

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@a-r-r-o-w
Member Author

@sayakpaul Just a mention for maybe adding this to core diffusers in the future, because it is a very generic approach in terms of implementation and can be applied to pretty much anything - sd1.5/sdxl, txt2img/img2img/inpaint, controlnet, animatediff, etc. It would be as simple as adding a new attention processor and corresponding .enable_style_aligned() and .disable_style_aligned() methods. Essentially, wherever AttnProcessor2_0 is used, this can be applied as well.

The TL;DR of the paper is that you can maintain style consistency across a batch of images thanks to the proposed novel shared attention mechanism. Presently, many people use LoRAs or other finetuning techniques to achieve the same effect (for example, a LoRA finetuned on realistic car images to obtain cars that look similar to the training data), whereas StyleAligned gives you the ability to do so without any additional training. You can also combine it with feature-specific LoRAs for improved quality and adherence to a particular style. I've tried using it with futuristic sci-fi LoRAs so far, and it is really, really cool!
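
For illustration, here is a minimal sketch of what such a core API could look like (the standalone helper names are hypothetical; SharedAttentionProcessor and its arguments come from the pipeline file added in this PR):

from diffusers.models.attention_processor import AttnProcessor2_0

# SharedAttentionProcessor is defined in examples/community/pipeline_sdxl_style_aligned.py
from pipeline_sdxl_style_aligned import SharedAttentionProcessor

def enable_style_aligned(pipe):
    # Swap the shared-attention processor into every self-attention ("attn1") layer
    # and keep the default processor for the cross-attention layers.
    attn_procs = {}
    for name in pipe.unet.attn_processors.keys():
        if "attn1" in name:
            attn_procs[name] = SharedAttentionProcessor(
                share_attention=True,
                adain_queries=True,
                adain_keys=True,
                adain_values=False,
                full_attention_share=False,
                shared_score_scale=1.0,
                shared_score_shift=0.0,
            )
        else:
            attn_procs[name] = AttnProcessor2_0()
    pipe.unet.set_attn_processor(attn_procs)

def disable_style_aligned(pipe):
    # Restore the default processor everywhere.
    pipe.unet.set_attn_processor(AttnProcessor2_0())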

@sayakpaul
Member

@sayakpaul Just a mention for maybe adding this to core diffusers in the future, because it is a very generic approach in terms of implementation and can be applied to pretty much anything - sd1.5/sdxl, txt2img/img2img/inpaint, controlnet, animatediff, etc. It would be as simple as adding a new attention processor and corresponding .enable_style_aligned() and .disable_style_aligned() methods. Essentially, wherever AttnProcessor2_0 is used, this can be applied as well.

You know how it is for core :D If we see enough usage for it, we will.

@charchit7
Contributor

@sayakpaul Just a mention for maybe adding this to core diffusers in the future, because it is a very generic approach in terms of implementation and can be applied to pretty much anything - sd1.5/sdxl, txt2img/img2img/inpaint, controlnet, animatediff, etc. It would be as simple as adding a new attention processor and corresponding .enable_style_aligned() and .disable_style_aligned() methods. Essentially, wherever AttnProcessor2_0 is used, this can be applied as well.

You know how it is for core :D If we see enough usage for it, we will.

:) how do you guys keep track of usage? any internal tool?

Contributor

@patrickvonplaten patrickvonplaten left a comment


Very cool! Happy to merge, maybe once we have an example in the community folder README.md here: https://github.com/huggingface/diffusers/blob/main/examples/community/README.md

@a-r-r-o-w
Member Author

Very cool! Happy to merge, maybe once we have an example in the community folder README.md here: https://github.com/huggingface/diffusers/blob/main/examples/community/README.md

Thanks Patrick! I want to remove the einops dependency, which is unnecessary here, and refactor a bit; I'll get to it in a couple of hours.
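
For context, the einops calls are just reshapes that plain torch ops can express; a minimal sketch with illustrative shapes (not taken from the pipeline):

import torch

# A typical einops pattern and its plain-torch equivalent.
# einops: rearrange(x, "b n (h d) -> b h n d", h=heads)
heads, head_dim = 8, 64
x = torch.randn(2, 77, heads * head_dim)
b, n, _ = x.shape
y = x.view(b, n, heads, head_dim).transpose(1, 2)  # shape: (b, heads, n, head_dim)

# And the inverse, rearrange(y, "b h n d -> b n (h d)"):
x_back = y.transpose(1, 2).reshape(b, n, heads * head_dim)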

@a-r-r-o-w a-r-r-o-w marked this pull request as ready for review January 9, 2024 20:50
@a-r-r-o-w
Member Author

Here's some minimal code for anyone wanting to try it out after merge (as custom_pipeline doesn't work otherwise):

Code
import torch
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
from PIL import Image

model_id = "a-r-r-o-w/dreamshaper-xl-turbo"
pipe = DiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16, variant="fp16", custom_pipeline="pipeline_sdxl_style_aligned")
pipe = pipe.to("cuda")

# Enable memory saving techniques
pipe.enable_vae_slicing()
pipe.enable_vae_tiling()

prompt = [
  "a toy airplane. macro photo. 3d game asset",
  "a toy bicycle. macro photo. 3d game asset",
]
negative_prompt = "low quality, worst quality, "

pipe.enable_style_aligned(
    share_group_norm=False,
    share_layer_norm=False,
    share_attention=True,
    adain_queries=True,
    adain_keys=True,
    adain_values=False,
    full_attention_share=False,
    shared_score_scale=1.0,
    shared_score_shift=0.0,
    only_self_level=0.0,
)

# txt2img
images = pipe(
    prompt=prompt,
    negative_prompt=negative_prompt,
    guidance_scale=2,
    height=1024,
    width=1024,
    num_inference_steps=10,
    generator=torch.Generator().manual_seed(42),
).images

# img2img
image = Image.open("/path/to/local/img.png").convert("RGB")
images = pipe(
    image=image,
    prompt=["dog, paper art, origami"] * 2,
    negative_prompt=negative_prompt,
    guidance_scale=2,
    height=1024,
    width=1024,
    num_inference_steps=10,
    generator=torch.Generator().manual_seed(42),
).images

# inpaint
image = Image.open("/path/to/local/img.png").convert("RGB")
mask = Image.open("/path/to/local/mask.png").convert("RGB")
images = pipe(
    image=image,
    mask=mask,
    prompt=["ancient warrior fighting a creature, dragon-like, almost alien"] * 2,
    negative_prompt=negative_prompt,
    guidance_scale=2,
    height=1024,
    width=1024,
    num_inference_steps=10,
    generator=torch.Generator().manual_seed(42),
).images

At the moment, the pipeline supports txt2img, img2img and inpainting in one single ginormous file. It can, however, be used for any task; it is as simple as setting the attention processor to the SharedAttentionProcessor used here. If there is a good amount of usage, or if anyone in the community needs it for a specific task, feel free to ping me and I'll have it ready in no time :)

@patrickvonplaten patrickvonplaten merged commit 9df566e into huggingface:main Jan 11, 2024
14 checks passed
@a-r-r-o-w a-r-r-o-w deleted the stylealigned branch January 11, 2024 13:38
@lingcong-k

Hi @a-r-r-o-w, thanks for your awesome work.
I wonder, does the pipeline include inpainting + controlnet + REFERENCE_IMAGE?
I saw the original style_aligned GitHub repo supports this.

image

@a-r-r-o-w
Member Author

a-r-r-o-w commented Feb 13, 2024

Hey @lingcong-k, thanks for the kind words! You can use StyleAligned with any diffusers pipeline. Currently, this pipeline only supports txt2img, img2img and inpaint because community examples are not meant to be fully featured in every aspect. You can check the bottom cells in this colab notebook to see how to use it with any pipeline. If there is more usage of this pipeline, it is possible that the diffusers team will integrate it as a core feature supported in all pipelines. StyleAligned, in particular, introduces an attention processor that can simply be swapped in for any existing attention processor (the default is AttnProcessor2_0), so adding it to core in the future would be really simple.

Code
!pip install git+https://github.com/huggingface/diffusers.git@main transformers accelerate torchsde
!wget https://raw.githubusercontent.com/a-r-r-o-w/diffusers/62d557a2f567a7b931c05fac4b1137dca5bf0afa/examples/community/pipeline_sdxl_style_aligned.py

from diffusers import StableDiffusionXLControlNetPipeline, ControlNetModel, AutoencoderKL
from diffusers.utils import load_image, make_image_grid
from PIL import Image
import cv2
import numpy as np
import torch

original_image = load_image(
    "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/hf-logo.png"
)

image = np.array(original_image)

low_threshold = 100
high_threshold = 200

image = cv2.Canny(image, low_threshold, high_threshold)
image = image[:, :, None]
image = np.concatenate([image, image, image], axis=2)
canny_image = Image.fromarray(image)
make_image_grid([original_image, canny_image], rows=1, cols=2)

controlnet = ControlNetModel.from_pretrained(
    "diffusers/controlnet-canny-sdxl-1.0",
    torch_dtype=torch.float16,
    use_safetensors=True
)
vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16, use_safetensors=True)
pipe = StableDiffusionXLControlNetPipeline.from_pretrained(
    "a-r-r-o-w/dreamshaper-xl-turbo",
    controlnet=controlnet,
    vae=vae,
    torch_dtype=torch.float16,
    variant="fp16",
)
pipe.enable_vae_slicing()
pipe.enable_vae_tiling()

from pipeline_sdxl_style_aligned import SharedAttentionProcessor, concat_first, get_switch_vec
from typing import Dict, List, Union

import torch.nn as nn
from diffusers.models.attention_processor import AttnProcessor2_0

def register_shared_norm(self, share_group_norm: bool = True, share_layer_norm: bool = True):
    r"""Helper method to register shared group/layer normalization layers."""

    def register_norm_forward(norm_layer: Union[nn.GroupNorm, nn.LayerNorm]) -> Union[nn.GroupNorm, nn.LayerNorm]:
        if not hasattr(norm_layer, "orig_forward"):
            setattr(norm_layer, "orig_forward", norm_layer.forward)
        orig_forward = norm_layer.orig_forward

        def forward_(hidden_states: torch.Tensor) -> torch.Tensor:
            # Append the reference (first) item's features so the normalization
            # statistics are shared across the batch, then slice back to the
            # original sequence length.
            n = hidden_states.shape[-2]
            hidden_states = concat_first(hidden_states, dim=-2)
            hidden_states = orig_forward(hidden_states)
            return hidden_states[..., :n, :]

        norm_layer.forward = forward_
        return norm_layer

    def get_norm_layers(pipeline_, norm_layers_: Dict[str, List[Union[nn.GroupNorm, nn.LayerNorm]]]):
        if isinstance(pipeline_, nn.LayerNorm) and share_layer_norm:
            norm_layers_["layer"].append(pipeline_)
        if isinstance(pipeline_, nn.GroupNorm) and share_group_norm:
            norm_layers_["group"].append(pipeline_)
        else:
            for layer in pipeline_.children():
                get_norm_layers(layer, norm_layers_)

    norm_layers = {"group": [], "layer": []}
    get_norm_layers(self.unet, norm_layers)

    norm_layers_list = []
    for key in ["group", "layer"]:
        for layer in norm_layers[key]:
            norm_layers_list.append(register_norm_forward(layer))

    return norm_layers_list

def enable_shared_attention_processors(
    self,
    share_attention: bool,
    adain_queries: bool,
    adain_keys: bool,
    adain_values: bool,
    full_attention_share: bool,
    shared_score_scale: float,
    shared_score_shift: float,
    only_self_level: float,
):
    r"""Helper method to enable usage of Shared Attention Processor."""
    attn_procs = {}
    num_self_layers = len([name for name in self.unet.attn_processors.keys() if "attn1" in name])

    # Per-self-attention-layer switch derived from `only_self_level`: entries that are
    # True keep the default AttnProcessor2_0 instead of the shared processor.
    only_self_vec = get_switch_vec(num_self_layers, only_self_level)

    for i, name in enumerate(self.unet.attn_processors.keys()):
        is_self_attention = "attn1" in name
        if is_self_attention:
            if only_self_vec[i // 2]:
                attn_procs[name] = AttnProcessor2_0()
            else:
                attn_procs[name] = SharedAttentionProcessor(
                    share_attention=share_attention,
                    adain_queries=adain_queries,
                    adain_keys=adain_keys,
                    adain_values=adain_values,
                    full_attention_share=full_attention_share,
                    shared_score_scale=shared_score_scale,
                    shared_score_shift=shared_score_shift,
                )
        else:
            attn_procs[name] = AttnProcessor2_0()

    self.unet.set_attn_processor(attn_procs)

share_group_norm = False
share_layer_norm = False
share_attention = True
adain_queries = True
adain_keys = True
adain_values = False
full_attention_share = False
shared_score_scale = 1.0
shared_score_shift = 0.0
only_self_level = 0.0

pipe.style_aligned_norm_layers = register_shared_norm(pipe, share_group_norm, share_layer_norm)
enable_shared_attention_processors(
    pipe,
    share_attention=share_attention,
    adain_queries=adain_queries,
    adain_keys=adain_keys,
    adain_values=adain_values,
    full_attention_share=full_attention_share,
    shared_score_scale=shared_score_scale,
    shared_score_shift=shared_score_shift,
    only_self_level=only_self_level,
)
pipe = pipe.to("cuda")

prompt = "aerial view, a futuristic research complex in a bright foggy jungle, hard lighting"
negative_prompt = 'low quality, bad quality, sketches'

image = pipe(
    prompt,
    negative_prompt=negative_prompt,
    image=canny_image,
    controlnet_conditioning_scale=0.5,
    guidance_scale=2,
    num_inference_steps=20,
).images[0]
make_image_grid([original_image, canny_image, image], rows=1, cols=3)

multi_images = pipe(
    prompt=[prompt, "an astronaut floating in space, futuristic, high quality"],
    negative_prompt=[negative_prompt] * 2,
    image=canny_image,
    controlnet_conditioning_scale=0.5,
    guidance_scale=2,
    num_inference_steps=20,
).images
make_image_grid(multi_images, rows=1, cols=2)

image

@lingcong-k

Hey @lingcong-k, thanks for the kind words! You can use StyleAligned with any diffusers pipeline. Currently, this pipeline only supports txt2img, img2img and inpaint because community examples are not meant to be fully featured in every aspect. You can check the bottom cells in this colab notebook to see how to use it with any pipeline. If there is more usage of this pipeline, it is possible that the diffusers team will integrate it as a core feature supported in all pipelines. StyleAligned, in particular, introduces an attention processor that can simply be swapped in for any existing attention processor (the default is AttnProcessor2_0), so adding it to core in the future would be really simple.


Wow, thanks so much for your detailed answer. That helps a lot :)

@a-r-r-o-w
Member Author

@sayakpaul @DN6 @yiyixuxu What are your thoughts on supporting the StyleAligned attention processor from within diffusers? Many recent works on stylistic generation have cited this method and developed improvements on it. There is definitely some traction in the community, and I, too, am working on a new method that compares against it in what would be my first research paper. Thank you for your time.

@yiyixuxu
Collaborator

yiyixuxu commented Apr 8, 2024

cc @asomoza here - can you take a look?

@asomoza
Member

asomoza commented Apr 10, 2024

This works, but IMO it has two problems:

  1. Once you generate the images you lose the style
  2. You are forced to use batch generations, and the VRAM consumption is already restrictive for SDXL users. This approach uses even more VRAM, so many people will not be able to use it.

For example, this is what I got with the style "macro photo. 3d game asset":

StyleAligned results (four images): style_aligned_20240410131107_42_0 through style_aligned_20240410131107_42_3

Using the first image with an IP Adapter (style):

IP Adapter results (three images): 20240410134009, 20240410134044, 20240410134057

I can generate an unlimited number of images using the same style. In fact, I think the IP Adapter transfers the style better, but that's just my opinion and is based on not that many tests.
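
For reference, a minimal sketch of the IP Adapter style-transfer setup described above (the repo id, weight name and scale here are typical values, not the exact settings used for these images):

import torch
from diffusers import StableDiffusionXLPipeline
from diffusers.utils import load_image

pipe = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16, variant="fp16"
).to("cuda")

# Load the SDXL IP-Adapter and use one of the StyleAligned outputs as the style reference.
pipe.load_ip_adapter("h94/IP-Adapter", subfolder="sdxl_models", weight_name="ip-adapter_sdxl.bin")
pipe.set_ip_adapter_scale(0.6)

style_image = load_image("style_aligned_20240410131107_42_0.png")  # first StyleAligned image

image = pipe(
    prompt="a toy car. macro photo. 3d game asset",
    negative_prompt="low quality, worst quality",
    ip_adapter_image=style_image,
    guidance_scale=5.0,
    num_inference_steps=30,
).images[0]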

I'm still curious, @a-r-r-o-w, in your opinion, what would be the benefit of using this versus IP Adapters?

I still think this is nice, but I don't see that many people using it. Especially now with the release of InstantStyle or the ComfyUI IP Adapter node.

AmericanPresidentJimmyCarter pushed a commit to AmericanPresidentJimmyCarter/diffusers that referenced this pull request Apr 26, 2024
* add stylealigned sdxl pipeline

* bugfix

* update docs

* remove einops dependency

* update README

* update example docstring