
CLIP Guided Images Mixing With Stable Diffusion

This image mixing pipeline is now available in the official diffusers repository.

This approach lets you combine two images using standard diffusion models, without any prior models. It modifies and extends the existing CLIP guided stable diffusion algorithm to work with images.

WARNING: It is hard to get a good image mixing result on the first try.

Code Examples Description

Open In Colab

You can find all examples in the ./jupyters folder:

| File Name | Description |
| --- | --- |
| example-no-CoCa.ipynb | Short minimal example of images mixing. The weakness of this approach is that you have to write a prompt for each image. |
| example-stable-diffusion-2-base.ipynb | Example with stable-diffusion-2-base. CoCa is used for prompt generation. |
| example-load-by-parts.ipynb | Example where each diffusers module is loaded separately. |
| example-find-best-mix-result.ipynb | Step-by-step explanation of how to select the mixing parameters (by exhaustive enumeration of each parameter). |
| example-as-augmentation.ipynb | Using image mixing for image augmentation. Summer-to-winter example. |

Short Method Description

The algorithm is based on the idea of CLIP guided stable diffusion img2img, but with some modifications:

  • Two images and (optionally) two prompts (a description of each image) are expected.
  • An interpolated (content-style) CLIP image embedding is used (the original uses a CLIP text embedding).
  • An interpolated (content-style) text embedding is used for guidance (the original uses a single text embedding).
  • (Optionally) a CoCa model is used to generate the image descriptions (see the sketch after this list).
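
A minimal sketch of caption generation with CoCa via the open_clip package (the model and checkpoint names follow common open_clip examples; the image path is a placeholder, and this is an illustration rather than the repository's exact code):

import torch
import open_clip
from PIL import Image

# Load a CoCa model and its preprocessing transform (names are assumptions)
model, _, transform = open_clip.create_model_and_transforms(
    model_name="coca_ViT-L-14",
    pretrained="mscoco_finetuned_laion2B-s13B-b90k",
)

image = Image.open("./images/boromir.jpg").convert("RGB")
image = transform(image).unsqueeze(0)

# Generate a caption and strip the special tokens
with torch.no_grad():
    generated = model.generate(image)
caption = open_clip.decode(generated[0]).split("<end_of_text>")[0].replace("<start_of_text>", "")
print(caption)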

Using different coefficients you can select the type of mixing: from style to content or from content to style. Parameter descriptions are given below.

Style to prompt and prompt to style give different results (see the example).

Getting Started

git clone https://github.com/TheDenk/images_mixing.git
cd images_mixing

pip install -r requirements.txt

Short Example

import torch
from PIL import Image
from diffusers import DiffusionPipeline
from transformers import CLIPFeatureExtractor, CLIPModel

# Loading additional models
feature_extractor = CLIPFeatureExtractor.from_pretrained(
    "laion/CLIP-ViT-B-32-laion2B-s34B-b79K"
)
clip_model = CLIPModel.from_pretrained(
    "laion/CLIP-ViT-B-32-laion2B-s34B-b79K", torch_dtype=torch.float16
)

# Pipeline creation
mixing_pipeline = DiffusionPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4",
    custom_pipeline="./images_mixing.py",
    clip_model=clip_model,
    feature_extractor=feature_extractor,
    torch_dtype=torch.float16,
)
mixing_pipeline.enable_attention_slicing()
mixing_pipeline = mixing_pipeline.to("cuda")

# Pipeline running
generator = torch.Generator(device="cuda").manual_seed(117)
content_image = Image.open('./images/boromir.jpg').convert("RGB")
style_image = Image.open('./images/gigachad.jpg').convert("RGB")

pipe_images = mixing_pipeline(
    content_prompt='boromir',
    style_prompt='gigachad',
    num_inference_steps=50,
    content_image=content_image,
    style_image=style_image,
    noise_strength=0.6,
    slerp_latent_style_strength=0.8,
    slerp_prompt_style_strength=0.2,
    slerp_clip_image_style_strength=0.2,
    guidance_scale=9.0,
    batch_size=1,
    clip_guidance_scale=100,
    generator=generator,
).images
pipe_images[0]
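
The pipeline returns a list of PIL images, so the result can be saved with, for example, pipe_images[0].save('./result.png').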

Using as augmentation

With Segment Anything you can effectively augment an image dataset (see the Jupyter notebook example).
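
A rough sketch of one way to do this, assuming the segment_anything package and a downloaded SAM checkpoint (paths, point coordinates, and the idea of pasting the mixed image only inside the mask are illustrative, not the repository's exact recipe):

import numpy as np
from PIL import Image
from segment_anything import SamPredictor, sam_model_registry

# Load a SAM checkpoint (model type and path are placeholders)
sam = sam_model_registry["vit_b"](checkpoint="./sam_vit_b.pth")
predictor = SamPredictor(sam)

content_image = Image.open("./images/summer.jpg").convert("RGB")
# Assume mixed_image comes from the mixing pipeline, resized to the content image size
mixed_image = pipe_images[0].resize(content_image.size)

# Predict a mask around a point of interest (coordinates are placeholders)
predictor.set_image(np.array(content_image))
masks, scores, _ = predictor.predict(
    point_coords=np.array([[256, 256]]),
    point_labels=np.array([1]),
)
mask = masks[np.argmax(scores)]

# Replace only the masked region with the mixed (e.g. winter-styled) pixels
augmented = np.array(content_image).copy()
augmented[mask] = np.array(mixed_image)[mask]
Image.fromarray(augmented)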

Short Parameters Description

Each slerp_ parameter affects both images, style and content (more style means less content, and vice versa):

$\text{content\_strength} = 1.0 - \text{style\_strength}$

| Parameter Name | Description |
| --- | --- |
| slerp_latent_style_strength | Affects the initial noised latent. Computed as a spherical interpolation (slerp) between the latents of the style image and the content image. |
| slerp_prompt_style_strength | Affects each diffusion iteration, both as the usual prompt and in the CLIP-guided step. Computed with the CLIP text model as a slerp between the text embeddings of the style prompt and the content prompt. |
| slerp_clip_image_style_strength | Affects each diffusion iteration in the CLIP-guided step. Computed with the CLIP image model as a slerp between the image embeddings of the style image and the content image. |
| noise_strength | Plain noise coefficient. Lower values keep more of the original information from the starting latent. Recommended range: 0.5 to 0.7. |
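
All three slerp_* coefficients boil down to spherical linear interpolation between a content tensor and a style tensor. A minimal sketch of how such a blend is typically computed (an illustration, not the pipeline's exact code; the embedding size is a placeholder):

import torch

def slerp(style_strength, content_emb, style_emb, eps=1e-7):
    # Spherical linear interpolation: style_strength=0 returns the content
    # embedding, style_strength=1 returns the style embedding.
    c = content_emb / content_emb.norm(dim=-1, keepdim=True)
    s = style_emb / style_emb.norm(dim=-1, keepdim=True)
    dot = (c * s).sum(dim=-1, keepdim=True).clamp(-1 + eps, 1 - eps)
    theta = torch.acos(dot)
    sin_theta = torch.sin(theta)
    w_content = torch.sin((1.0 - style_strength) * theta) / sin_theta
    w_style = torch.sin(style_strength * theta) / sin_theta
    return w_content * content_emb + w_style * style_emb

# Example: blend two CLIP text embeddings with 20% style, 80% content
content_emb = torch.randn(1, 768)
style_emb = torch.randn(1, 768)
mixed_emb = slerp(0.2, content_emb, style_emb)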

Recommended starting parameters for mixing from style to content:

noise_strength=0.5
slerp_latent_style_strength=0.8
slerp_prompt_style_strength=0.2
slerp_clip_image_style_strength=0.2

Recommended starting parameters for mixing from content to style:

noise_strength=0.5
slerp_latent_style_strength=0.2
slerp_prompt_style_strength=0.8
slerp_clip_image_style_strength=0.8
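
For example, the content-to-style preset plugs into the pipeline from the Short Example above like this (all other arguments kept as before):

pipe_images = mixing_pipeline(
    content_prompt='boromir',
    style_prompt='gigachad',
    num_inference_steps=50,
    content_image=content_image,
    style_image=style_image,
    noise_strength=0.5,
    slerp_latent_style_strength=0.2,
    slerp_prompt_style_strength=0.8,
    slerp_clip_image_style_strength=0.8,
    guidance_scale=9.0,
    batch_size=1,
    clip_guidance_scale=100,
    generator=generator,
).images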

Contacts

Issues should be raised directly in the repository. For professional support and recommendations, please contact [email protected].

