diff --git a/docs/source/en/api/pipelines/animatediff.md b/docs/source/en/api/pipelines/animatediff.md
index 835b4d7b7fdd..bfd6ab973d5e 100644
--- a/docs/source/en/api/pipelines/animatediff.md
+++ b/docs/source/en/api/pipelines/animatediff.md
@@ -25,6 +25,9 @@ The abstract of the paper is the following:
| Pipeline | Tasks | Demo
|---|---|:---:|
| [AnimateDiffPipeline](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/animatediff/pipeline_animatediff.py) | *Text-to-Video Generation with AnimateDiff* |
+| [AnimateDiffControlNetPipeline](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py) | *Controlled Video-to-Video Generation with AnimateDiff using ControlNet* |
+| [AnimateDiffSparseControlNetPipeline](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py) | *Controlled Video-to-Video Generation with AnimateDiff using SparseCtrl* |
+| [AnimateDiffSDXLPipeline](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py) | *Text-to-Video Generation with AnimateDiff SDXL* |
| [AnimateDiffVideoToVideoPipeline](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py) | *Video-to-Video Generation with AnimateDiff* |
## Available checkpoints
@@ -100,6 +103,83 @@ AnimateDiff tends to work better with finetuned Stable Diffusion models. If you
+### AnimateDiffControlNetPipeline
+
+AnimateDiff can also be used with ControlNets. ControlNet was introduced in [Adding Conditional Control to Text-to-Image Diffusion Models](https://huggingface.co/papers/2302.05543) by Lvmin Zhang, Anyi Rao, and Maneesh Agrawala. With a ControlNet model, you can provide an additional control image to condition and control Stable Diffusion generation. For example, if you provide depth maps, the ControlNet model generates a video that preserves the spatial information from the depth maps. It is a more flexible and accurate way to control the video generation process.
+
+```python
+import torch
+from diffusers import AnimateDiffControlNetPipeline, AutoencoderKL, ControlNetModel, MotionAdapter, LCMScheduler
+from diffusers.utils import export_to_gif, load_video
+
+# Additionally, you will need to preprocess videos before they can be used with the ControlNet
+# HF maintains just the right package for it: `pip install controlnet_aux`
+from controlnet_aux.processor import ZoeDetector
+
+# Download controlnets from https://huggingface.co/lllyasviel/ControlNet-v1-1 to use .from_single_file
+# Download Diffusers-format controlnets, such as https://huggingface.co/lllyasviel/sd-controlnet-depth, to use .from_pretrained()
+controlnet = ControlNetModel.from_single_file("control_v11f1p_sd15_depth.pth", torch_dtype=torch.float16)
+
+# We use AnimateLCM for this example but one can use the original motion adapters as well (for example, https://huggingface.co/guoyww/animatediff-motion-adapter-v1-5-3)
+motion_adapter = MotionAdapter.from_pretrained("wangfuyun/AnimateLCM")
+
+vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse", torch_dtype=torch.float16)
+pipe: AnimateDiffControlNetPipeline = AnimateDiffControlNetPipeline.from_pretrained(
+ "SG161222/Realistic_Vision_V5.1_noVAE",
+ motion_adapter=motion_adapter,
+ controlnet=controlnet,
+ vae=vae,
+).to(device="cuda", dtype=torch.float16)
+pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config, beta_schedule="linear")
+pipe.load_lora_weights("wangfuyun/AnimateLCM", weight_name="AnimateLCM_sd15_t2v_lora.safetensors", adapter_name="lcm-lora")
+pipe.set_adapters(["lcm-lora"], [0.8])
+
+depth_detector = ZoeDetector.from_pretrained("lllyasviel/Annotators").to("cuda")
+video = load_video("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/animatediff-vid2vid-input-1.gif")
+conditioning_frames = []
+
+with pipe.progress_bar(total=len(video)) as progress_bar:
+ for frame in video:
+ conditioning_frames.append(depth_detector(frame))
+ progress_bar.update()
+
+prompt = "a panda, playing a guitar, sitting in a pink boat, in the ocean, mountains in background, realistic, high quality"
+negative_prompt = "bad quality, worst quality"
+
+video = pipe(
+ prompt=prompt,
+ negative_prompt=negative_prompt,
+ num_frames=len(video),
+ num_inference_steps=10,
+ guidance_scale=2.0,
+ conditioning_frames=conditioning_frames,
+ generator=torch.Generator().manual_seed(42),
+).frames[0]
+
+export_to_gif(video, "animatediff_controlnet.gif", fps=8)
+```
+
+Here are some sample outputs:
+
+| Source Video | Output Video |
+|:---:|:---:|
+| raccoon playing a guitar | a panda, playing a guitar, sitting in a pink boat, in the ocean, mountains in background, realistic, high quality |
+
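+The pipeline also accepts more than one ControlNet. When a list of ControlNets is passed, they are wrapped in a `MultiControlNetModel`, and `conditioning_frames` is then expected to be a list containing one list of frames per ControlNet. The snippet below is a minimal sketch: it assumes `depth_frames` and `openpose_frames` have already been prepared with `controlnet_aux` (one conditioning image per generated frame), and the checkpoint names are only examples of SD1.5-compatible ControlNets.
+
+```python
+import torch
+from diffusers import AnimateDiffControlNetPipeline, ControlNetModel, MotionAdapter
+
+# Illustrative checkpoints; any SD1.5-compatible ControlNets can be used here
+controlnet_depth = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-depth", torch_dtype=torch.float16)
+controlnet_openpose = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-openpose", torch_dtype=torch.float16)
+
+motion_adapter = MotionAdapter.from_pretrained("guoyww/animatediff-motion-adapter-v1-5-3", torch_dtype=torch.float16)
+
+# Passing a list of ControlNets wraps them in a MultiControlNetModel internally
+pipe = AnimateDiffControlNetPipeline.from_pretrained(
+    "SG161222/Realistic_Vision_V5.1_noVAE",
+    motion_adapter=motion_adapter,
+    controlnet=[controlnet_depth, controlnet_openpose],
+    torch_dtype=torch.float16,
+).to("cuda")
+
+# `depth_frames` and `openpose_frames` are placeholder lists of preprocessed conditioning
+# frames (one per generated frame), prepared with `controlnet_aux` as in the example above
+video = pipe(
+    prompt="a panda, playing a guitar, sitting in a pink boat, in the ocean, mountains in background",
+    negative_prompt="bad quality, worst quality",
+    num_frames=16,
+    conditioning_frames=[depth_frames, openpose_frames],
+    controlnet_conditioning_scale=[1.0, 0.8],
+).frames[0]
+```
+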
### AnimateDiffSparseControlNetPipeline
[SparseCtrl: Adding Sparse Controls to Text-to-Video Diffusion Models](https://arxiv.org/abs/2311.16933) introduces a method for achieving controlled generation in text-to-video diffusion models. It was proposed by Yuwei Guo, Ceyuan Yang, Anyi Rao, Maneesh Agrawala, Dahua Lin, and Bo Dai.
@@ -762,6 +842,12 @@ pipe = AnimateDiffPipeline.from_pretrained("emilianJR/epiCRealism", motion_adapt
- all
- __call__
+## AnimateDiffControlNetPipeline
+
+[[autodoc]] AnimateDiffControlNetPipeline
+ - all
+ - __call__
+
## AnimateDiffSparseControlNetPipeline
[[autodoc]] AnimateDiffSparseControlNetPipeline
diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py
index 10bda1316bd7..9bb7be4b0dd6 100644
--- a/src/diffusers/__init__.py
+++ b/src/diffusers/__init__.py
@@ -232,6 +232,7 @@
"AmusedImg2ImgPipeline",
"AmusedInpaintPipeline",
"AmusedPipeline",
+ "AnimateDiffControlNetPipeline",
"AnimateDiffPipeline",
"AnimateDiffSDXLPipeline",
"AnimateDiffSparseControlNetPipeline",
@@ -652,6 +653,7 @@
AmusedImg2ImgPipeline,
AmusedInpaintPipeline,
AmusedPipeline,
+ AnimateDiffControlNetPipeline,
AnimateDiffPipeline,
AnimateDiffSDXLPipeline,
AnimateDiffSparseControlNetPipeline,
diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py
index f1d41b60d090..e1a56efa851d 100644
--- a/src/diffusers/pipelines/__init__.py
+++ b/src/diffusers/pipelines/__init__.py
@@ -118,6 +118,7 @@
_import_structure["amused"] = ["AmusedImg2ImgPipeline", "AmusedInpaintPipeline", "AmusedPipeline"]
_import_structure["animatediff"] = [
"AnimateDiffPipeline",
+ "AnimateDiffControlNetPipeline",
"AnimateDiffSDXLPipeline",
"AnimateDiffSparseControlNetPipeline",
"AnimateDiffVideoToVideoPipeline",
@@ -419,6 +420,7 @@
else:
from .amused import AmusedImg2ImgPipeline, AmusedInpaintPipeline, AmusedPipeline
from .animatediff import (
+ AnimateDiffControlNetPipeline,
AnimateDiffPipeline,
AnimateDiffSDXLPipeline,
AnimateDiffSparseControlNetPipeline,
diff --git a/src/diffusers/pipelines/animatediff/__init__.py b/src/diffusers/pipelines/animatediff/__init__.py
index 5f4d6f29391c..3ee72bc44003 100644
--- a/src/diffusers/pipelines/animatediff/__init__.py
+++ b/src/diffusers/pipelines/animatediff/__init__.py
@@ -22,6 +22,7 @@
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
else:
_import_structure["pipeline_animatediff"] = ["AnimateDiffPipeline"]
+ _import_structure["pipeline_animatediff_controlnet"] = ["AnimateDiffControlNetPipeline"]
_import_structure["pipeline_animatediff_sdxl"] = ["AnimateDiffSDXLPipeline"]
_import_structure["pipeline_animatediff_sparsectrl"] = ["AnimateDiffSparseControlNetPipeline"]
_import_structure["pipeline_animatediff_video2video"] = ["AnimateDiffVideoToVideoPipeline"]
@@ -35,6 +36,7 @@
else:
from .pipeline_animatediff import AnimateDiffPipeline
+ from .pipeline_animatediff_controlnet import AnimateDiffControlNetPipeline
from .pipeline_animatediff_sdxl import AnimateDiffSDXLPipeline
from .pipeline_animatediff_sparsectrl import AnimateDiffSparseControlNetPipeline
from .pipeline_animatediff_video2video import AnimateDiffVideoToVideoPipeline
diff --git a/src/diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py b/src/diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py
new file mode 100644
index 000000000000..5574556a9b96
--- /dev/null
+++ b/src/diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py
@@ -0,0 +1,1066 @@
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import inspect
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+
+import torch
+import torch.nn.functional as F
+from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection
+
+from ...image_processor import PipelineImageInput
+from ...loaders import IPAdapterMixin, StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin
+from ...models import AutoencoderKL, ControlNetModel, ImageProjection, UNet2DConditionModel, UNetMotionModel
+from ...models.lora import adjust_lora_scale_text_encoder
+from ...models.unets.unet_motion_model import MotionAdapter
+from ...schedulers import KarrasDiffusionSchedulers
+from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
+from ...utils.torch_utils import is_compiled_module, randn_tensor
+from ...video_processor import VideoProcessor
+from ..controlnet.multicontrolnet import MultiControlNetModel
+from ..free_init_utils import FreeInitMixin
+from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
+from .pipeline_output import AnimateDiffPipelineOutput
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+EXAMPLE_DOC_STRING = """
+ Examples:
+ ```py
+ >>> import torch
+ >>> from diffusers import (
+ ... AnimateDiffControlNetPipeline,
+ ... AutoencoderKL,
+ ... ControlNetModel,
+ ... MotionAdapter,
+ ... LCMScheduler,
+ ... )
+ >>> from diffusers.utils import export_to_gif, load_video
+
+        >>> # Additionally, you will need to preprocess videos before they can be used with the ControlNet
+ >>> # HF maintains just the right package for it: `pip install controlnet_aux`
+ >>> from controlnet_aux.processor import ZoeDetector
+
+ >>> # Download controlnets from https://huggingface.co/lllyasviel/ControlNet-v1-1 to use .from_single_file
+ >>> # Download Diffusers-format controlnets, such as https://huggingface.co/lllyasviel/sd-controlnet-depth, to use .from_pretrained()
+ >>> controlnet = ControlNetModel.from_single_file("control_v11f1p_sd15_depth.pth", torch_dtype=torch.float16)
+
+ >>> # We use AnimateLCM for this example but one can use the original motion adapters as well (for example, https://huggingface.co/guoyww/animatediff-motion-adapter-v1-5-3)
+ >>> motion_adapter = MotionAdapter.from_pretrained("wangfuyun/AnimateLCM")
+
+ >>> vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse", torch_dtype=torch.float16)
+ >>> pipe: AnimateDiffControlNetPipeline = AnimateDiffControlNetPipeline.from_pretrained(
+ ... "SG161222/Realistic_Vision_V5.1_noVAE",
+ ... motion_adapter=motion_adapter,
+ ... controlnet=controlnet,
+ ... vae=vae,
+ ... ).to(device="cuda", dtype=torch.float16)
+ >>> pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config, beta_schedule="linear")
+ >>> pipe.load_lora_weights(
+ ... "wangfuyun/AnimateLCM", weight_name="AnimateLCM_sd15_t2v_lora.safetensors", adapter_name="lcm-lora"
+ ... )
+ >>> pipe.set_adapters(["lcm-lora"], [0.8])
+
+ >>> depth_detector = ZoeDetector.from_pretrained("lllyasviel/Annotators").to("cuda")
+ >>> video = load_video(
+ ... "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/animatediff-vid2vid-input-1.gif"
+ ... )
+ >>> conditioning_frames = []
+
+ >>> with pipe.progress_bar(total=len(video)) as progress_bar:
+ ... for frame in video:
+ ... conditioning_frames.append(depth_detector(frame))
+ ... progress_bar.update()
+
+ >>> prompt = "a panda, playing a guitar, sitting in a pink boat, in the ocean, mountains in background, realistic, high quality"
+ >>> negative_prompt = "bad quality, worst quality"
+
+ >>> video = pipe(
+ ... prompt=prompt,
+ ... negative_prompt=negative_prompt,
+ ... num_frames=len(video),
+ ... num_inference_steps=10,
+ ... guidance_scale=2.0,
+ ... conditioning_frames=conditioning_frames,
+ ... generator=torch.Generator().manual_seed(42),
+ ... ).frames[0]
+
+ >>> export_to_gif(video, "animatediff_controlnet.gif", fps=8)
+ ```
+"""
+
+
+class AnimateDiffControlNetPipeline(
+ DiffusionPipeline,
+ StableDiffusionMixin,
+ TextualInversionLoaderMixin,
+ IPAdapterMixin,
+ StableDiffusionLoraLoaderMixin,
+ FreeInitMixin,
+):
+ r"""
+ Pipeline for text-to-video generation with ControlNet guidance.
+
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
+ implemented for all pipelines (downloading, saving, running on a particular device, etc.).
+
+ The pipeline also inherits the following loading methods:
+ - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings
+ - [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`] for loading LoRA weights
+ - [`~loaders.StableDiffusionLoraLoaderMixin.save_lora_weights`] for saving LoRA weights
+ - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters
+
+ Args:
+ vae ([`AutoencoderKL`]):
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
+ text_encoder ([`CLIPTextModel`]):
+ Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)).
+ tokenizer (`CLIPTokenizer`):
+ A [`~transformers.CLIPTokenizer`] to tokenize text.
+ unet ([`UNet2DConditionModel`]):
+ A [`UNet2DConditionModel`] used to create a UNetMotionModel to denoise the encoded video latents.
+ motion_adapter ([`MotionAdapter`]):
+ A [`MotionAdapter`] to be used in combination with `unet` to denoise the encoded video latents.
+ scheduler ([`SchedulerMixin`]):
+ A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
+ [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
+ """
+
+ model_cpu_offload_seq = "text_encoder->unet->vae"
+ _optional_components = ["feature_extractor", "image_encoder"]
+ _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
+
+ def __init__(
+ self,
+ vae: AutoencoderKL,
+ text_encoder: CLIPTextModel,
+ tokenizer: CLIPTokenizer,
+ unet: Union[UNet2DConditionModel, UNetMotionModel],
+ motion_adapter: MotionAdapter,
+ controlnet: Union[ControlNetModel, List[ControlNetModel], Tuple[ControlNetModel], MultiControlNetModel],
+ scheduler: KarrasDiffusionSchedulers,
+ feature_extractor: Optional[CLIPImageProcessor] = None,
+ image_encoder: Optional[CLIPVisionModelWithProjection] = None,
+ ):
+ super().__init__()
+ if isinstance(unet, UNet2DConditionModel):
+ unet = UNetMotionModel.from_unet2d(unet, motion_adapter)
+
+ if isinstance(controlnet, (list, tuple)):
+ controlnet = MultiControlNetModel(controlnet)
+
+ self.register_modules(
+ vae=vae,
+ text_encoder=text_encoder,
+ tokenizer=tokenizer,
+ unet=unet,
+ motion_adapter=motion_adapter,
+ controlnet=controlnet,
+ scheduler=scheduler,
+ feature_extractor=feature_extractor,
+ image_encoder=image_encoder,
+ )
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor)
+ self.control_video_processor = VideoProcessor(
+ vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True, do_normalize=False
+ )
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_prompt with num_images_per_prompt -> num_videos_per_prompt
+ def encode_prompt(
+ self,
+ prompt,
+ device,
+ num_images_per_prompt,
+ do_classifier_free_guidance,
+ negative_prompt=None,
+ prompt_embeds: Optional[torch.Tensor] = None,
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
+ lora_scale: Optional[float] = None,
+ clip_skip: Optional[int] = None,
+ ):
+ r"""
+ Encodes the prompt into text encoder hidden states.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ prompt to be encoded
+ device: (`torch.device`):
+ torch device
+ num_images_per_prompt (`int`):
+ number of images that should be generated per prompt
+ do_classifier_free_guidance (`bool`):
+ whether to use classifier free guidance or not
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+ less than `1`).
+ prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+ argument.
+ lora_scale (`float`, *optional*):
+ A LoRA scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
+ clip_skip (`int`, *optional*):
+ Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
+ the output of the pre-final layer will be used for computing the prompt embeddings.
+ """
+ # set lora scale so that monkey patched LoRA
+ # function of text encoder can correctly access it
+ if lora_scale is not None and isinstance(self, StableDiffusionLoraLoaderMixin):
+ self._lora_scale = lora_scale
+
+ # dynamically adjust the LoRA scale
+ if not USE_PEFT_BACKEND:
+ adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
+ else:
+ scale_lora_layers(self.text_encoder, lora_scale)
+
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ if prompt_embeds is None:
+ # textual inversion: process multi-vector tokens if necessary
+ if isinstance(self, TextualInversionLoaderMixin):
+ prompt = self.maybe_convert_prompt(prompt, self.tokenizer)
+
+ text_inputs = self.tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=self.tokenizer.model_max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+ text_input_ids = text_inputs.input_ids
+ untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
+
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
+ text_input_ids, untruncated_ids
+ ):
+ removed_text = self.tokenizer.batch_decode(
+ untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
+ )
+ logger.warning(
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
+ f" {self.tokenizer.model_max_length} tokens: {removed_text}"
+ )
+
+ if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
+ attention_mask = text_inputs.attention_mask.to(device)
+ else:
+ attention_mask = None
+
+ if clip_skip is None:
+ prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask)
+ prompt_embeds = prompt_embeds[0]
+ else:
+ prompt_embeds = self.text_encoder(
+ text_input_ids.to(device), attention_mask=attention_mask, output_hidden_states=True
+ )
+ # Access the `hidden_states` first, that contains a tuple of
+ # all the hidden states from the encoder layers. Then index into
+ # the tuple to access the hidden states from the desired layer.
+ prompt_embeds = prompt_embeds[-1][-(clip_skip + 1)]
+ # We also need to apply the final LayerNorm here to not mess with the
+ # representations. The `last_hidden_states` that we typically use for
+ # obtaining the final prompt representations passes through the LayerNorm
+ # layer.
+ prompt_embeds = self.text_encoder.text_model.final_layer_norm(prompt_embeds)
+
+ if self.text_encoder is not None:
+ prompt_embeds_dtype = self.text_encoder.dtype
+ elif self.unet is not None:
+ prompt_embeds_dtype = self.unet.dtype
+ else:
+ prompt_embeds_dtype = prompt_embeds.dtype
+
+ prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)
+
+ bs_embed, seq_len, _ = prompt_embeds.shape
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
+
+ # get unconditional embeddings for classifier free guidance
+ if do_classifier_free_guidance and negative_prompt_embeds is None:
+ uncond_tokens: List[str]
+ if negative_prompt is None:
+ uncond_tokens = [""] * batch_size
+ elif prompt is not None and type(prompt) is not type(negative_prompt):
+ raise TypeError(
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
+ f" {type(prompt)}."
+ )
+ elif isinstance(negative_prompt, str):
+ uncond_tokens = [negative_prompt]
+ elif batch_size != len(negative_prompt):
+ raise ValueError(
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
+ " the batch size of `prompt`."
+ )
+ else:
+ uncond_tokens = negative_prompt
+
+ # textual inversion: process multi-vector tokens if necessary
+ if isinstance(self, TextualInversionLoaderMixin):
+ uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer)
+
+ max_length = prompt_embeds.shape[1]
+ uncond_input = self.tokenizer(
+ uncond_tokens,
+ padding="max_length",
+ max_length=max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+
+ if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
+ attention_mask = uncond_input.attention_mask.to(device)
+ else:
+ attention_mask = None
+
+ negative_prompt_embeds = self.text_encoder(
+ uncond_input.input_ids.to(device),
+ attention_mask=attention_mask,
+ )
+ negative_prompt_embeds = negative_prompt_embeds[0]
+
+ if do_classifier_free_guidance:
+ # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
+ seq_len = negative_prompt_embeds.shape[1]
+
+ negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)
+
+ negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
+
+ if self.text_encoder is not None:
+ if isinstance(self, StableDiffusionLoraLoaderMixin) and USE_PEFT_BACKEND:
+ # Retrieve the original scale by scaling back the LoRA layers
+ unscale_lora_layers(self.text_encoder, lora_scale)
+
+ return prompt_embeds, negative_prompt_embeds
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image
+ def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None):
+ dtype = next(self.image_encoder.parameters()).dtype
+
+ if not isinstance(image, torch.Tensor):
+ image = self.feature_extractor(image, return_tensors="pt").pixel_values
+
+ image = image.to(device=device, dtype=dtype)
+ if output_hidden_states:
+ image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2]
+ image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0)
+ uncond_image_enc_hidden_states = self.image_encoder(
+ torch.zeros_like(image), output_hidden_states=True
+ ).hidden_states[-2]
+ uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave(
+ num_images_per_prompt, dim=0
+ )
+ return image_enc_hidden_states, uncond_image_enc_hidden_states
+ else:
+ image_embeds = self.image_encoder(image).image_embeds
+ image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
+ uncond_image_embeds = torch.zeros_like(image_embeds)
+
+ return image_embeds, uncond_image_embeds
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds
+ def prepare_ip_adapter_image_embeds(
+ self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, do_classifier_free_guidance
+ ):
+ image_embeds = []
+ if do_classifier_free_guidance:
+ negative_image_embeds = []
+ if ip_adapter_image_embeds is None:
+ if not isinstance(ip_adapter_image, list):
+ ip_adapter_image = [ip_adapter_image]
+
+ if len(ip_adapter_image) != len(self.unet.encoder_hid_proj.image_projection_layers):
+ raise ValueError(
+ f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters."
+ )
+
+ for single_ip_adapter_image, image_proj_layer in zip(
+ ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers
+ ):
+ output_hidden_state = not isinstance(image_proj_layer, ImageProjection)
+ single_image_embeds, single_negative_image_embeds = self.encode_image(
+ single_ip_adapter_image, device, 1, output_hidden_state
+ )
+
+ image_embeds.append(single_image_embeds[None, :])
+ if do_classifier_free_guidance:
+ negative_image_embeds.append(single_negative_image_embeds[None, :])
+ else:
+ for single_image_embeds in ip_adapter_image_embeds:
+ if do_classifier_free_guidance:
+ single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2)
+ negative_image_embeds.append(single_negative_image_embeds)
+ image_embeds.append(single_image_embeds)
+
+ ip_adapter_image_embeds = []
+ for i, single_image_embeds in enumerate(image_embeds):
+ single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0)
+ if do_classifier_free_guidance:
+ single_negative_image_embeds = torch.cat([negative_image_embeds[i]] * num_images_per_prompt, dim=0)
+ single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds], dim=0)
+
+ single_image_embeds = single_image_embeds.to(device=device)
+ ip_adapter_image_embeds.append(single_image_embeds)
+
+ return ip_adapter_image_embeds
+
+ def decode_latents(self, latents, decode_batch_size: int = 16):
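+        # Decode the frames in chunks of `decode_batch_size` to bound peak VAE memory usage when decoding long videos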
+ latents = 1 / self.vae.config.scaling_factor * latents
+
+ batch_size, channels, num_frames, height, width = latents.shape
+ latents = latents.permute(0, 2, 1, 3, 4).reshape(batch_size * num_frames, channels, height, width)
+
+ video = []
+ for i in range(0, latents.shape[0], decode_batch_size):
+ batch_latents = latents[i : i + decode_batch_size]
+ batch_latents = self.vae.decode(batch_latents).sample
+ video.append(batch_latents)
+
+ video = torch.cat(video)
+ video = video[None, :].reshape((batch_size, num_frames, -1) + video.shape[2:]).permute(0, 2, 1, 3, 4)
+ # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
+ video = video.float()
+ return video
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
+ def prepare_extra_step_kwargs(self, generator, eta):
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+ # and should be between [0, 1]
+
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ extra_step_kwargs = {}
+ if accepts_eta:
+ extra_step_kwargs["eta"] = eta
+
+ # check if the scheduler accepts generator
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ if accepts_generator:
+ extra_step_kwargs["generator"] = generator
+ return extra_step_kwargs
+
+ def check_inputs(
+ self,
+ prompt,
+ height,
+ width,
+ num_frames,
+ negative_prompt=None,
+ prompt_embeds=None,
+ negative_prompt_embeds=None,
+ callback_on_step_end_tensor_inputs=None,
+ video=None,
+ controlnet_conditioning_scale=1.0,
+ control_guidance_start=0.0,
+ control_guidance_end=1.0,
+ ):
+ if height % 8 != 0 or width % 8 != 0:
+ raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
+
+ if callback_on_step_end_tensor_inputs is not None and not all(
+ k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
+ ):
+ raise ValueError(
+ f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
+ )
+
+ if prompt is not None and prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ elif prompt is None and prompt_embeds is None:
+ raise ValueError(
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
+ )
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+
+ if negative_prompt is not None and negative_prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
+ )
+
+ if prompt_embeds is not None and negative_prompt_embeds is not None:
+ if prompt_embeds.shape != negative_prompt_embeds.shape:
+ raise ValueError(
+ "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
+ f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
+ f" {negative_prompt_embeds.shape}."
+ )
+
+ # `prompt` needs more sophisticated handling when there are multiple
+ # conditionings.
+ if isinstance(self.controlnet, MultiControlNetModel):
+ if isinstance(prompt, list):
+ logger.warning(
+ f"You have {len(self.controlnet.nets)} ControlNets and you have passed {len(prompt)}"
+ " prompts. The conditionings will be fixed across the prompts."
+ )
+
+ # Check `image`
+ is_compiled = hasattr(F, "scaled_dot_product_attention") and isinstance(
+ self.controlnet, torch._dynamo.eval_frame.OptimizedModule
+ )
+ if (
+ isinstance(self.controlnet, ControlNetModel)
+ or is_compiled
+ and isinstance(self.controlnet._orig_mod, ControlNetModel)
+ ):
+ if not isinstance(video, list):
+ raise TypeError(f"For single controlnet, `image` must be of type `list` but got {type(video)}")
+ if len(video) != num_frames:
+                raise ValueError(f"Expected image to have length {num_frames} but got {len(video)=}")
+ elif (
+ isinstance(self.controlnet, MultiControlNetModel)
+ or is_compiled
+ and isinstance(self.controlnet._orig_mod, MultiControlNetModel)
+ ):
+            if not isinstance(video, list) or not isinstance(video[0], list):
+                raise TypeError(f"For multiple controlnets: `image` must be a list of lists but got {type(video)=}")
+            if len(video[0]) != num_frames:
+                raise ValueError(f"Expected length of image sublist to be {num_frames} but got {len(video[0])=}")
+            if any(len(img) != len(video[0]) for img in video):
+                raise ValueError("All conditioning frame batches for multicontrolnet must be the same size")
+ else:
+ assert False
+
+ # Check `controlnet_conditioning_scale`
+ if (
+ isinstance(self.controlnet, ControlNetModel)
+ or is_compiled
+ and isinstance(self.controlnet._orig_mod, ControlNetModel)
+ ):
+ if not isinstance(controlnet_conditioning_scale, float):
+ raise TypeError("For single controlnet: `controlnet_conditioning_scale` must be type `float`.")
+ elif (
+ isinstance(self.controlnet, MultiControlNetModel)
+ or is_compiled
+ and isinstance(self.controlnet._orig_mod, MultiControlNetModel)
+ ):
+ if isinstance(controlnet_conditioning_scale, list):
+ if any(isinstance(i, list) for i in controlnet_conditioning_scale):
+                    raise ValueError("A single batch of multiple conditionings is supported at the moment.")
+ elif isinstance(controlnet_conditioning_scale, list) and len(controlnet_conditioning_scale) != len(
+ self.controlnet.nets
+ ):
+ raise ValueError(
+ "For multiple controlnets: When `controlnet_conditioning_scale` is specified as `list`, it must have"
+ " the same length as the number of controlnets"
+ )
+ else:
+ assert False
+
+ if not isinstance(control_guidance_start, (tuple, list)):
+ control_guidance_start = [control_guidance_start]
+
+ if not isinstance(control_guidance_end, (tuple, list)):
+ control_guidance_end = [control_guidance_end]
+
+ if len(control_guidance_start) != len(control_guidance_end):
+ raise ValueError(
+ f"`control_guidance_start` has {len(control_guidance_start)} elements, but `control_guidance_end` has {len(control_guidance_end)} elements. Make sure to provide the same number of elements to each list."
+ )
+
+ if isinstance(self.controlnet, MultiControlNetModel):
+ if len(control_guidance_start) != len(self.controlnet.nets):
+ raise ValueError(
+ f"`control_guidance_start`: {control_guidance_start} has {len(control_guidance_start)} elements but there are {len(self.controlnet.nets)} controlnets available. Make sure to provide {len(self.controlnet.nets)}."
+ )
+
+ for start, end in zip(control_guidance_start, control_guidance_end):
+ if start >= end:
+ raise ValueError(
+ f"control guidance start: {start} cannot be larger or equal to control guidance end: {end}."
+ )
+ if start < 0.0:
+ raise ValueError(f"control guidance start: {start} can't be smaller than 0.")
+ if end > 1.0:
+ raise ValueError(f"control guidance end: {end} can't be larger than 1.0.")
+
+ # Copied from diffusers.pipelines.text_to_video_synthesis.pipeline_text_to_video_synth.TextToVideoSDPipeline.prepare_latents
+ def prepare_latents(
+ self, batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, latents=None
+ ):
+ shape = (
+ batch_size,
+ num_channels_latents,
+ num_frames,
+ height // self.vae_scale_factor,
+ width // self.vae_scale_factor,
+ )
+ if isinstance(generator, list) and len(generator) != batch_size:
+ raise ValueError(
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+ )
+
+ if latents is None:
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+ else:
+ latents = latents.to(device)
+
+ # scale the initial noise by the standard deviation required by the scheduler
+ latents = latents * self.scheduler.init_noise_sigma
+ return latents
+
+ def prepare_video(
+ self,
+ video,
+ width,
+ height,
+ batch_size,
+ num_videos_per_prompt,
+ device,
+ dtype,
+ do_classifier_free_guidance=False,
+ guess_mode=False,
+ ):
+ video = self.control_video_processor.preprocess_video(video, height=height, width=width).to(
+ dtype=torch.float32
+ )
+ video = video.permute(0, 2, 1, 3, 4).flatten(0, 1)
+ video_batch_size = video.shape[0]
+
+ if video_batch_size == 1:
+ repeat_by = batch_size
+ else:
+ # image batch size is the same as prompt batch size
+ repeat_by = num_videos_per_prompt
+
+ video = video.repeat_interleave(repeat_by, dim=0)
+ video = video.to(device=device, dtype=dtype)
+
+ if do_classifier_free_guidance and not guess_mode:
+ video = torch.cat([video] * 2)
+
+ return video
+
+ @property
+ def guidance_scale(self):
+ return self._guidance_scale
+
+ @property
+ def clip_skip(self):
+ return self._clip_skip
+
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # corresponds to doing no classifier free guidance.
+ @property
+ def do_classifier_free_guidance(self):
+ return self._guidance_scale > 1
+
+ @property
+ def cross_attention_kwargs(self):
+ return self._cross_attention_kwargs
+
+ @property
+ def num_timesteps(self):
+ return self._num_timesteps
+
+ @torch.no_grad()
+ def __call__(
+ self,
+ prompt: Union[str, List[str]] = None,
+ num_frames: Optional[int] = 16,
+ height: Optional[int] = None,
+ width: Optional[int] = None,
+ num_inference_steps: int = 50,
+ guidance_scale: float = 7.5,
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ num_videos_per_prompt: Optional[int] = 1,
+ eta: float = 0.0,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.Tensor] = None,
+ prompt_embeds: Optional[torch.Tensor] = None,
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
+ ip_adapter_image: Optional[PipelineImageInput] = None,
+ ip_adapter_image_embeds: Optional[PipelineImageInput] = None,
+ conditioning_frames: Optional[List[PipelineImageInput]] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+ controlnet_conditioning_scale: Union[float, List[float]] = 1.0,
+ guess_mode: bool = False,
+ control_guidance_start: Union[float, List[float]] = 0.0,
+ control_guidance_end: Union[float, List[float]] = 1.0,
+ clip_skip: Optional[int] = None,
+ callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
+ decode_batch_size: int = 16,
+ ):
+ r"""
+ The call function to the pipeline for generation.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`.
+ height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
+ The height in pixels of the generated video.
+ width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
+ The width in pixels of the generated video.
+ num_frames (`int`, *optional*, defaults to 16):
+                The number of video frames that are generated. Defaults to 16 frames, which at 8 frames per second
+                amounts to 2 seconds of video.
+ num_inference_steps (`int`, *optional*, defaults to 50):
+ The number of denoising steps. More denoising steps usually lead to a higher quality videos at the
+ expense of slower inference.
+ guidance_scale (`float`, *optional*, defaults to 7.5):
+ A higher guidance scale value encourages the model to generate images closely linked to the text
+ `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts to guide what to not include in image generation. If not defined, you need to
+ pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`).
+ eta (`float`, *optional*, defaults to 0.0):
+ Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies
+ to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+ A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
+ generation deterministic.
+ latents (`torch.Tensor`, *optional*):
+ Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for video
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+ tensor is generated by sampling using the supplied random `generator`. Latents should be of shape
+ `(batch_size, num_channel, num_frames, height, width)`.
+ prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
+ provided, text embeddings are generated from the `prompt` input argument.
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
+ not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.
+ ip_adapter_image (`PipelineImageInput`, *optional*):
+ Optional image input to work with IP Adapters.
+ ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*):
+                Pre-generated image embeddings for IP-Adapter. It should be a list with the same length as the number
+                of IP-Adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. It should
+                contain the negative image embedding if `do_classifier_free_guidance` is set to `True`. If not
+                provided, embeddings are computed from the `ip_adapter_image` input argument.
+ conditioning_frames (`List[PipelineImageInput]`, *optional*):
+ The ControlNet input condition to provide guidance to the `unet` for generation. If multiple
+ ControlNets are specified, images must be passed as a list such that each element of the list can be
+ correctly batched for input to a single ControlNet.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+ The output format of the generated video. Choose between `torch.Tensor`, `PIL.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+                Whether or not to return an [`~pipelines.animatediff.pipeline_output.AnimateDiffPipelineOutput`]
+                instead of a plain tuple.
+ cross_attention_kwargs (`dict`, *optional*):
+ A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in
+ [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
+ controlnet_conditioning_scale (`float` or `List[float]`, *optional*, defaults to 1.0):
+ The outputs of the ControlNet are multiplied by `controlnet_conditioning_scale` before they are added
+ to the residual in the original `unet`. If multiple ControlNets are specified in `init`, you can set
+ the corresponding scale as a list.
+ guess_mode (`bool`, *optional*, defaults to `False`):
+ The ControlNet encoder tries to recognize the content of the input image even if you remove all
+ prompts. A `guidance_scale` value between 3.0 and 5.0 is recommended.
+ control_guidance_start (`float` or `List[float]`, *optional*, defaults to 0.0):
+ The percentage of total steps at which the ControlNet starts applying.
+ control_guidance_end (`float` or `List[float]`, *optional*, defaults to 1.0):
+ The percentage of total steps at which the ControlNet stops applying.
+ clip_skip (`int`, *optional*):
+ Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
+ the output of the pre-final layer will be used for computing the prompt embeddings.
+ callback_on_step_end (`Callable`, *optional*):
+                A function that is called at the end of each denoising step during inference. The function is called
+                with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
+ callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
+ `callback_on_step_end_tensor_inputs`.
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
+ will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
+ `._callback_tensor_inputs` attribute of your pipeline class.
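+            decode_batch_size (`int`, *optional*, defaults to 16):
+                The number of frames to decode at a time when converting the denoised latents into video frames with
+                the VAE. Lower values reduce peak memory usage at the cost of slightly slower decoding.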
+
+ Examples:
+
+ Returns:
+ [`~pipelines.animatediff.pipeline_output.AnimateDiffPipelineOutput`] or `tuple`:
+ If `return_dict` is `True`, [`~pipelines.animatediff.pipeline_output.AnimateDiffPipelineOutput`] is
+ returned, otherwise a `tuple` is returned where the first element is a list with the generated frames.
+ """
+ controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet
+
+ # align format for control guidance
+ if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list):
+ control_guidance_start = len(control_guidance_end) * [control_guidance_start]
+ elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list):
+ control_guidance_end = len(control_guidance_start) * [control_guidance_end]
+ elif not isinstance(control_guidance_start, list) and not isinstance(control_guidance_end, list):
+ mult = len(controlnet.nets) if isinstance(controlnet, MultiControlNetModel) else 1
+ control_guidance_start, control_guidance_end = (
+ mult * [control_guidance_start],
+ mult * [control_guidance_end],
+ )
+
+ # 0. Default height and width to unet
+ height = height or self.unet.config.sample_size * self.vae_scale_factor
+ width = width or self.unet.config.sample_size * self.vae_scale_factor
+
+ num_videos_per_prompt = 1
+
+ # 1. Check inputs. Raise error if not correct
+ self.check_inputs(
+ prompt=prompt,
+ height=height,
+ width=width,
+ num_frames=num_frames,
+ negative_prompt=negative_prompt,
+ callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ video=conditioning_frames,
+ controlnet_conditioning_scale=controlnet_conditioning_scale,
+ control_guidance_start=control_guidance_start,
+ control_guidance_end=control_guidance_end,
+ )
+
+ self._guidance_scale = guidance_scale
+ self._clip_skip = clip_skip
+ self._cross_attention_kwargs = cross_attention_kwargs
+
+ # 2. Define call parameters
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ device = self._execution_device
+
+ if isinstance(controlnet, MultiControlNetModel) and isinstance(controlnet_conditioning_scale, float):
+ controlnet_conditioning_scale = [controlnet_conditioning_scale] * len(controlnet.nets)
+
+ global_pool_conditions = (
+ controlnet.config.global_pool_conditions
+ if isinstance(controlnet, ControlNetModel)
+ else controlnet.nets[0].config.global_pool_conditions
+ )
+ guess_mode = guess_mode or global_pool_conditions
+
+ # 3. Encode input prompt
+ text_encoder_lora_scale = (
+ cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
+ )
+ prompt_embeds, negative_prompt_embeds = self.encode_prompt(
+ prompt,
+ device,
+ num_videos_per_prompt,
+ self.do_classifier_free_guidance,
+ negative_prompt,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ lora_scale=text_encoder_lora_scale,
+ clip_skip=self.clip_skip,
+ )
+ # For classifier free guidance, we need to do two forward passes.
+ # Here we concatenate the unconditional and text embeddings into a single batch
+ # to avoid doing two forward passes
+ if self.do_classifier_free_guidance:
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
+
+ if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
+ image_embeds = self.prepare_ip_adapter_image_embeds(
+ ip_adapter_image,
+ ip_adapter_image_embeds,
+ device,
+ batch_size * num_videos_per_prompt,
+ self.do_classifier_free_guidance,
+ )
+
+ if isinstance(controlnet, ControlNetModel):
+ conditioning_frames = self.prepare_video(
+ video=conditioning_frames,
+ width=width,
+ height=height,
+ batch_size=batch_size * num_videos_per_prompt * num_frames,
+ num_videos_per_prompt=num_videos_per_prompt,
+ device=device,
+ dtype=controlnet.dtype,
+ do_classifier_free_guidance=self.do_classifier_free_guidance,
+ guess_mode=guess_mode,
+ )
+ elif isinstance(controlnet, MultiControlNetModel):
+ cond_prepared_videos = []
+ for frame_ in conditioning_frames:
+ prepared_video = self.prepare_video(
+ video=frame_,
+ width=width,
+ height=height,
+ batch_size=batch_size * num_videos_per_prompt * num_frames,
+ num_videos_per_prompt=num_videos_per_prompt,
+ device=device,
+ dtype=controlnet.dtype,
+ do_classifier_free_guidance=self.do_classifier_free_guidance,
+ guess_mode=guess_mode,
+ )
+ cond_prepared_videos.append(prepared_video)
+ conditioning_frames = cond_prepared_videos
+ else:
+ assert False
+
+ # 4. Prepare timesteps
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
+ timesteps = self.scheduler.timesteps
+
+ # 5. Prepare latent variables
+ num_channels_latents = self.unet.config.in_channels
+ latents = self.prepare_latents(
+ batch_size * num_videos_per_prompt,
+ num_channels_latents,
+ num_frames,
+ height,
+ width,
+ prompt_embeds.dtype,
+ device,
+ generator,
+ latents,
+ )
+
+ # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
+
+ # 7. Add image embeds for IP-Adapter
+ added_cond_kwargs = (
+ {"image_embeds": image_embeds}
+ if ip_adapter_image is not None or ip_adapter_image_embeds is not None
+ else None
+ )
+
+ # 7.1 Create tensor stating which controlnets to keep
+ controlnet_keep = []
+ for i in range(len(timesteps)):
+ keeps = [
+ 1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e)
+ for s, e in zip(control_guidance_start, control_guidance_end)
+ ]
+ controlnet_keep.append(keeps[0] if isinstance(controlnet, ControlNetModel) else keeps)
+
+ num_free_init_iters = self._free_init_num_iters if self.free_init_enabled else 1
+ for free_init_iter in range(num_free_init_iters):
+ if self.free_init_enabled:
+ latents, timesteps = self._apply_free_init(
+ latents, free_init_iter, num_inference_steps, device, latents.dtype, generator
+ )
+
+ self._num_timesteps = len(timesteps)
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
+
+ # 8. Denoising loop
+ with self.progress_bar(total=self._num_timesteps) as progress_bar:
+ for i, t in enumerate(timesteps):
+ # expand the latents if we are doing classifier free guidance
+ latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+
+ if guess_mode and self.do_classifier_free_guidance:
+ # Infer ControlNet only for the conditional batch.
+ control_model_input = latents
+ control_model_input = self.scheduler.scale_model_input(control_model_input, t)
+ controlnet_prompt_embeds = prompt_embeds.chunk(2)[1]
+ else:
+ control_model_input = latent_model_input
+ controlnet_prompt_embeds = prompt_embeds
+ controlnet_prompt_embeds = controlnet_prompt_embeds.repeat_interleave(num_frames, dim=0)
+
+ if isinstance(controlnet_keep[i], list):
+ cond_scale = [c * s for c, s in zip(controlnet_conditioning_scale, controlnet_keep[i])]
+ else:
+ controlnet_cond_scale = controlnet_conditioning_scale
+ if isinstance(controlnet_cond_scale, list):
+ controlnet_cond_scale = controlnet_cond_scale[0]
+ cond_scale = controlnet_cond_scale * controlnet_keep[i]
+
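+                    # The ControlNet is an image (2D) model, so fold the frame dimension into the batch
+                    # dimension before passing the per-frame latents and conditioning frames through it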
+ control_model_input = torch.transpose(control_model_input, 1, 2)
+ control_model_input = control_model_input.reshape(
+ (-1, control_model_input.shape[2], control_model_input.shape[3], control_model_input.shape[4])
+ )
+
+ down_block_res_samples, mid_block_res_sample = self.controlnet(
+ control_model_input,
+ t,
+ encoder_hidden_states=controlnet_prompt_embeds,
+ controlnet_cond=conditioning_frames,
+ conditioning_scale=cond_scale,
+ guess_mode=guess_mode,
+ return_dict=False,
+ )
+
+ # predict the noise residual
+ noise_pred = self.unet(
+ latent_model_input,
+ t,
+ encoder_hidden_states=prompt_embeds,
+ cross_attention_kwargs=self.cross_attention_kwargs,
+ added_cond_kwargs=added_cond_kwargs,
+ down_block_additional_residuals=down_block_res_samples,
+ mid_block_additional_residual=mid_block_res_sample,
+ ).sample
+
+ # perform guidance
+ if self.do_classifier_free_guidance:
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+ # compute the previous noisy sample x_t -> x_t-1
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
+
+ if callback_on_step_end is not None:
+ callback_kwargs = {}
+ for k in callback_on_step_end_tensor_inputs:
+ callback_kwargs[k] = locals()[k]
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
+
+ latents = callback_outputs.pop("latents", latents)
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
+ negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
+
+ # call the callback, if provided
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+ progress_bar.update()
+
+ # 9. Post processing
+ if output_type == "latent":
+ video = latents
+ else:
+ video_tensor = self.decode_latents(latents, decode_batch_size)
+ video = self.video_processor.postprocess_video(video=video_tensor, output_type=output_type)
+
+ # 10. Offload all models
+ self.maybe_free_model_hooks()
+
+ if not return_dict:
+ return (video,)
+
+ return AnimateDiffPipelineOutput(frames=video)
diff --git a/src/diffusers/utils/__init__.py b/src/diffusers/utils/__init__.py
index f41edfcda3d8..d2633d2ec9e7 100644
--- a/src/diffusers/utils/__init__.py
+++ b/src/diffusers/utils/__init__.py
@@ -93,7 +93,7 @@
is_xformers_available,
requires_backends,
)
-from .loading_utils import load_image
+from .loading_utils import load_image, load_video
from .logging import get_logger
from .outputs import BaseOutput
from .peft_utils import (
diff --git a/src/diffusers/utils/dummy_torch_and_transformers_objects.py b/src/diffusers/utils/dummy_torch_and_transformers_objects.py
index 105c1a5def9d..ead85a2a498e 100644
--- a/src/diffusers/utils/dummy_torch_and_transformers_objects.py
+++ b/src/diffusers/utils/dummy_torch_and_transformers_objects.py
@@ -77,6 +77,21 @@ def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
+class AnimateDiffControlNetPipeline(metaclass=DummyObject):
+ _backends = ["torch", "transformers"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch", "transformers"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+
class AnimateDiffPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
diff --git a/src/diffusers/utils/export_utils.py b/src/diffusers/utils/export_utils.py
index bb5307756491..1a3270f759e5 100644
--- a/src/diffusers/utils/export_utils.py
+++ b/src/diffusers/utils/export_utils.py
@@ -9,10 +9,7 @@
import PIL.Image
import PIL.ImageOps
-from .import_utils import (
- BACKENDS_MAPPING,
- is_opencv_available,
-)
+from .import_utils import BACKENDS_MAPPING, is_opencv_available
from .logging import get_logger
diff --git a/src/diffusers/utils/loading_utils.py b/src/diffusers/utils/loading_utils.py
index aa087e981731..9d13dcd6cccb 100644
--- a/src/diffusers/utils/loading_utils.py
+++ b/src/diffusers/utils/loading_utils.py
@@ -1,13 +1,16 @@
import os
-from typing import Callable, Union
+import tempfile
+from typing import Callable, List, Optional, Union
import PIL.Image
import PIL.ImageOps
import requests
+from .import_utils import BACKENDS_MAPPING, is_opencv_available
+
def load_image(
- image: Union[str, PIL.Image.Image], convert_method: Callable[[PIL.Image.Image], PIL.Image.Image] = None
+ image: Union[str, PIL.Image.Image], convert_method: Optional[Callable[[PIL.Image.Image], PIL.Image.Image]] = None
) -> PIL.Image.Image:
"""
Loads `image` to a PIL Image.
@@ -15,7 +18,7 @@ def load_image(
Args:
image (`str` or `PIL.Image.Image`):
The image to convert to the PIL Image format.
- convert_method (Callable[[PIL.Image.Image], PIL.Image.Image], optional):
+ convert_method (Callable[[PIL.Image.Image], PIL.Image.Image], *optional*):
A conversion method to apply to the image after loading it. When set to `None` the image will be converted
"RGB".
@@ -47,3 +50,73 @@ def load_image(
image = image.convert("RGB")
return image
+
+
+def load_video(
+ video: str,
+ convert_method: Optional[Callable[[List[PIL.Image.Image]], List[PIL.Image.Image]]] = None,
+) -> List[PIL.Image.Image]:
+ """
+ Loads `video` to a list of PIL Image.
+
+ Args:
+ video (`str`):
+            A URL or local path to the video to load and convert to a list of PIL Images.
+ convert_method (Callable[[List[PIL.Image.Image]], List[PIL.Image.Image]], *optional*):
+            A conversion method to apply to the list of loaded frames (for example, resizing or recoloring them).
+            When set to `None`, the frames are returned exactly as they were read.
+
+ Returns:
+ `List[PIL.Image.Image]`:
+ The video as a list of PIL images.
+ """
+ is_url = video.startswith("http://") or video.startswith("https://")
+ is_file = os.path.isfile(video)
+ was_tempfile_created = False
+
+ if not (is_url or is_file):
+ raise ValueError(
+ f"Incorrect path or URL. URLs must start with `http://` or `https://`, and {video} is not a valid path."
+ )
+
+ if is_url:
+ video_data = requests.get(video, stream=True).raw
+ video_path = tempfile.NamedTemporaryFile(suffix=os.path.splitext(video)[1], delete=False).name
+ was_tempfile_created = True
+ with open(video_path, "wb") as f:
+ f.write(video_data.read())
+
+ video = video_path
+
+ pil_images = []
+ if video.endswith(".gif"):
+ gif = PIL.Image.open(video)
+ try:
+ while True:
+ pil_images.append(gif.copy())
+ gif.seek(gif.tell() + 1)
+ except EOFError:
+ pass
+
+ else:
+ if is_opencv_available():
+ import cv2
+ else:
+ raise ImportError(BACKENDS_MAPPING["opencv"][1].format("load_video"))
+
+ video_capture = cv2.VideoCapture(video)
+ success, frame = video_capture.read()
+ while success:
+ frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
+ pil_images.append(PIL.Image.fromarray(frame))
+ success, frame = video_capture.read()
+
+ video_capture.release()
+
+ if was_tempfile_created:
+ os.remove(video_path)
+
+ if convert_method is not None:
+ pil_images = convert_method(pil_images)
+
+ return pil_images
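
Since `load_video` is new in this PR, a short usage sketch may help. The asset URL below is only an example (any local video or GIF path works as well), and the resize inside `convert_method` is likewise illustrative:

```python
# Usage sketch for the new loader (illustrative; substitute your own asset if needed).
from diffusers.utils import load_video

url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/animatediff-vid2vid-input-1.gif"

# Plain loading: returns a list of PIL.Image.Image frames.
frames = load_video(url)

# `convert_method` post-processes all frames in one pass, e.g. to force a
# consistent color mode and resolution before feeding them to a pipeline.
frames_rgb_256 = load_video(
    url,
    convert_method=lambda images: [image.convert("RGB").resize((256, 256)) for image in images],
)
```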
diff --git a/tests/pipelines/animatediff/test_animatediff_controlnet.py b/tests/pipelines/animatediff/test_animatediff_controlnet.py
new file mode 100644
index 000000000000..f46e514f9ac0
--- /dev/null
+++ b/tests/pipelines/animatediff/test_animatediff_controlnet.py
@@ -0,0 +1,431 @@
+import unittest
+
+import numpy as np
+import torch
+from PIL import Image
+from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer
+
+import diffusers
+from diffusers import (
+ AnimateDiffControlNetPipeline,
+ AutoencoderKL,
+ ControlNetModel,
+ DDIMScheduler,
+ DPMSolverMultistepScheduler,
+ LCMScheduler,
+ MotionAdapter,
+ StableDiffusionPipeline,
+ UNet2DConditionModel,
+ UNetMotionModel,
+)
+from diffusers.utils import logging
+from diffusers.utils.testing_utils import torch_device
+
+from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_PARAMS
+from ..test_pipelines_common import (
+ IPAdapterTesterMixin,
+ PipelineFromPipeTesterMixin,
+ PipelineTesterMixin,
+ SDFunctionTesterMixin,
+)
+
+
+def to_np(tensor):
+ if isinstance(tensor, torch.Tensor):
+ tensor = tensor.detach().cpu().numpy()
+
+ return tensor
+
+
+class AnimateDiffControlNetPipelineFastTests(
+ IPAdapterTesterMixin, SDFunctionTesterMixin, PipelineTesterMixin, PipelineFromPipeTesterMixin, unittest.TestCase
+):
+ pipeline_class = AnimateDiffControlNetPipeline
+ params = TEXT_TO_IMAGE_PARAMS
+ batch_params = TEXT_TO_IMAGE_BATCH_PARAMS.union({"conditioning_frames"})
+ required_optional_params = frozenset(
+ [
+ "num_inference_steps",
+ "generator",
+ "latents",
+ "return_dict",
+ "callback_on_step_end",
+ "callback_on_step_end_tensor_inputs",
+ ]
+ )
+
+ def get_dummy_components(self):
+ cross_attention_dim = 8
+ block_out_channels = (8, 8)
+
+ torch.manual_seed(0)
+ unet = UNet2DConditionModel(
+ block_out_channels=block_out_channels,
+ layers_per_block=2,
+ sample_size=8,
+ in_channels=4,
+ out_channels=4,
+ down_block_types=("CrossAttnDownBlock2D", "DownBlock2D"),
+ up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
+ cross_attention_dim=cross_attention_dim,
+ norm_num_groups=2,
+ )
+ scheduler = DDIMScheduler(
+ beta_start=0.00085,
+ beta_end=0.012,
+ beta_schedule="linear",
+ clip_sample=False,
+ )
+ torch.manual_seed(0)
+ controlnet = ControlNetModel(
+ block_out_channels=block_out_channels,
+ layers_per_block=2,
+ in_channels=4,
+ down_block_types=("CrossAttnDownBlock2D", "DownBlock2D"),
+ cross_attention_dim=cross_attention_dim,
+ conditioning_embedding_out_channels=(8, 8),
+ norm_num_groups=1,
+ )
+ torch.manual_seed(0)
+ vae = AutoencoderKL(
+ block_out_channels=block_out_channels,
+ in_channels=3,
+ out_channels=3,
+ down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"],
+ up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"],
+ latent_channels=4,
+ norm_num_groups=2,
+ )
+ torch.manual_seed(0)
+ text_encoder_config = CLIPTextConfig(
+ bos_token_id=0,
+ eos_token_id=2,
+ hidden_size=cross_attention_dim,
+ intermediate_size=37,
+ layer_norm_eps=1e-05,
+ num_attention_heads=4,
+ num_hidden_layers=5,
+ pad_token_id=1,
+ vocab_size=1000,
+ )
+ text_encoder = CLIPTextModel(text_encoder_config)
+ tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
+ motion_adapter = MotionAdapter(
+ block_out_channels=block_out_channels,
+ motion_layers_per_block=2,
+ motion_norm_num_groups=2,
+ motion_num_attention_heads=4,
+ )
+
+ components = {
+ "unet": unet,
+ "controlnet": controlnet,
+ "scheduler": scheduler,
+ "vae": vae,
+ "motion_adapter": motion_adapter,
+ "text_encoder": text_encoder,
+ "tokenizer": tokenizer,
+ "feature_extractor": None,
+ "image_encoder": None,
+ }
+ return components
+
+ def get_dummy_inputs(self, device, seed: int = 0, num_frames: int = 2):
+ if str(device).startswith("mps"):
+ generator = torch.manual_seed(seed)
+ else:
+ generator = torch.Generator(device=device).manual_seed(seed)
+
+ video_height = 32
+ video_width = 32
+ conditioning_frames = [Image.new("RGB", (video_width, video_height))] * num_frames
+
+ inputs = {
+ "prompt": "A painting of a squirrel eating a burger",
+ "conditioning_frames": conditioning_frames,
+ "generator": generator,
+ "num_inference_steps": 2,
+ "num_frames": num_frames,
+ "guidance_scale": 7.5,
+ "output_type": "pt",
+ }
+ return inputs
+
+ def test_from_pipe_consistent_config(self):
+ assert self.original_pipeline_class == StableDiffusionPipeline
+ original_repo = "hf-internal-testing/tinier-stable-diffusion-pipe"
+ original_kwargs = {"requires_safety_checker": False}
+
+ # create original_pipeline_class(sd)
+ pipe_original = self.original_pipeline_class.from_pretrained(original_repo, **original_kwargs)
+
+ # original_pipeline_class(sd) -> pipeline_class
+ pipe_components = self.get_dummy_components()
+ pipe_additional_components = {}
+ for name, component in pipe_components.items():
+ if name not in pipe_original.components:
+ pipe_additional_components[name] = component
+
+ pipe = self.pipeline_class.from_pipe(pipe_original, **pipe_additional_components)
+
+ # pipeline_class -> original_pipeline_class(sd)
+ original_pipe_additional_components = {}
+ for name, component in pipe_original.components.items():
+ if name not in pipe.components or not isinstance(component, pipe.components[name].__class__):
+ original_pipe_additional_components[name] = component
+
+ pipe_original_2 = self.original_pipeline_class.from_pipe(pipe, **original_pipe_additional_components)
+
+ # compare the config
+ original_config = {k: v for k, v in pipe_original.config.items() if not k.startswith("_")}
+ original_config_2 = {k: v for k, v in pipe_original_2.config.items() if not k.startswith("_")}
+ assert original_config_2 == original_config
+
+ def test_motion_unet_loading(self):
+ components = self.get_dummy_components()
+ pipe = self.pipeline_class(**components)
+
+ assert isinstance(pipe.unet, UNetMotionModel)
+
+ @unittest.skip("Attention slicing is not enabled in this pipeline")
+ def test_attention_slicing_forward_pass(self):
+ pass
+
+ def test_ip_adapter_single(self):
+ expected_pipe_slice = None
+ if torch_device == "cpu":
+ expected_pipe_slice = np.array(
+ [
+ 0.6604,
+ 0.4099,
+ 0.4928,
+ 0.5706,
+ 0.5096,
+ 0.5012,
+ 0.6051,
+ 0.5169,
+ 0.5021,
+ 0.4864,
+ 0.4261,
+ 0.5779,
+ 0.5822,
+ 0.4049,
+ 0.5253,
+ 0.6160,
+ 0.4150,
+ 0.5155,
+ ]
+ )
+ return super().test_ip_adapter_single(expected_pipe_slice=expected_pipe_slice)
+
+ def test_dict_tuple_outputs_equivalent(self):
+ expected_slice = None
+ if torch_device == "cpu":
+ expected_slice = np.array([0.6051, 0.5169, 0.5021, 0.6160, 0.4150, 0.5155])
+ return super().test_dict_tuple_outputs_equivalent(expected_slice=expected_slice)
+
+ def test_inference_batch_single_identical(
+ self,
+ batch_size=2,
+ expected_max_diff=1e-4,
+ additional_params_copy_to_batched_inputs=["num_inference_steps"],
+ ):
+ components = self.get_dummy_components()
+ pipe = self.pipeline_class(**components)
+        for component in pipe.components.values():
+            if hasattr(component, "set_default_attn_processor"):
+                component.set_default_attn_processor()
+
+ pipe.to(torch_device)
+ pipe.set_progress_bar_config(disable=None)
+ inputs = self.get_dummy_inputs(torch_device)
+        # Reset generator in case it has been used in self.get_dummy_inputs
+ inputs["generator"] = self.get_generator(0)
+
+ logger = logging.get_logger(pipe.__module__)
+ logger.setLevel(level=diffusers.logging.FATAL)
+
+ # batchify inputs
+ batched_inputs = {}
+ batched_inputs.update(inputs)
+
+ for name in self.batch_params:
+ if name not in inputs:
+ continue
+
+ value = inputs[name]
+ if name == "prompt":
+ len_prompt = len(value)
+ batched_inputs[name] = [value[: len_prompt // i] for i in range(1, batch_size + 1)]
+ batched_inputs[name][-1] = 100 * "very long"
+
+ else:
+ batched_inputs[name] = batch_size * [value]
+
+ if "generator" in inputs:
+ batched_inputs["generator"] = [self.get_generator(i) for i in range(batch_size)]
+
+ if "batch_size" in inputs:
+ batched_inputs["batch_size"] = batch_size
+
+ for arg in additional_params_copy_to_batched_inputs:
+ batched_inputs[arg] = inputs[arg]
+
+ output = pipe(**inputs)
+ output_batch = pipe(**batched_inputs)
+
+ assert output_batch[0].shape[0] == batch_size
+
+ max_diff = np.abs(to_np(output_batch[0][0]) - to_np(output[0][0])).max()
+ assert max_diff < expected_max_diff
+
+ @unittest.skipIf(torch_device != "cuda", reason="CUDA and CPU are required to switch devices")
+ def test_to_device(self):
+ components = self.get_dummy_components()
+ pipe = self.pipeline_class(**components)
+ pipe.set_progress_bar_config(disable=None)
+
+ pipe.to("cpu")
+        # The pipeline creates a new motion UNet under the hood, so we need to check the device from pipe.components
+ model_devices = [
+ component.device.type for component in pipe.components.values() if hasattr(component, "device")
+ ]
+ self.assertTrue(all(device == "cpu" for device in model_devices))
+
+ output_cpu = pipe(**self.get_dummy_inputs("cpu"))[0]
+ self.assertTrue(np.isnan(output_cpu).sum() == 0)
+
+ pipe.to("cuda")
+ model_devices = [
+ component.device.type for component in pipe.components.values() if hasattr(component, "device")
+ ]
+ self.assertTrue(all(device == "cuda" for device in model_devices))
+
+ output_cuda = pipe(**self.get_dummy_inputs("cuda"))[0]
+ self.assertTrue(np.isnan(to_np(output_cuda)).sum() == 0)
+
+ def test_to_dtype(self):
+ components = self.get_dummy_components()
+ pipe = self.pipeline_class(**components)
+ pipe.set_progress_bar_config(disable=None)
+
+        # The pipeline creates a new motion UNet under the hood, so we need to check the dtype from pipe.components
+ model_dtypes = [component.dtype for component in pipe.components.values() if hasattr(component, "dtype")]
+ self.assertTrue(all(dtype == torch.float32 for dtype in model_dtypes))
+
+ pipe.to(dtype=torch.float16)
+ model_dtypes = [component.dtype for component in pipe.components.values() if hasattr(component, "dtype")]
+ self.assertTrue(all(dtype == torch.float16 for dtype in model_dtypes))
+
+ def test_prompt_embeds(self):
+ components = self.get_dummy_components()
+ pipe = self.pipeline_class(**components)
+ pipe.set_progress_bar_config(disable=None)
+ pipe.to(torch_device)
+
+ inputs = self.get_dummy_inputs(torch_device)
+ inputs.pop("prompt")
+ inputs["prompt_embeds"] = torch.randn((1, 4, pipe.text_encoder.config.hidden_size), device=torch_device)
+ pipe(**inputs)
+
+ def test_free_init(self):
+ components = self.get_dummy_components()
+ pipe: AnimateDiffControlNetPipeline = self.pipeline_class(**components)
+ pipe.set_progress_bar_config(disable=None)
+ pipe.to(torch_device)
+
+ inputs_normal = self.get_dummy_inputs(torch_device)
+ frames_normal = pipe(**inputs_normal).frames[0]
+
+ pipe.enable_free_init(
+ num_iters=2,
+ use_fast_sampling=True,
+ method="butterworth",
+ order=4,
+ spatial_stop_frequency=0.25,
+ temporal_stop_frequency=0.25,
+ )
+ inputs_enable_free_init = self.get_dummy_inputs(torch_device)
+ frames_enable_free_init = pipe(**inputs_enable_free_init).frames[0]
+
+ pipe.disable_free_init()
+ inputs_disable_free_init = self.get_dummy_inputs(torch_device)
+ frames_disable_free_init = pipe(**inputs_disable_free_init).frames[0]
+
+ sum_enabled = np.abs(to_np(frames_normal) - to_np(frames_enable_free_init)).sum()
+ max_diff_disabled = np.abs(to_np(frames_normal) - to_np(frames_disable_free_init)).max()
+ self.assertGreater(
+ sum_enabled, 1e1, "Enabling of FreeInit should lead to results different from the default pipeline results"
+ )
+ self.assertLess(
+ max_diff_disabled,
+ 1e-4,
+ "Disabling of FreeInit should lead to results similar to the default pipeline results",
+ )
+
+ def test_free_init_with_schedulers(self):
+ components = self.get_dummy_components()
+ pipe: AnimateDiffControlNetPipeline = self.pipeline_class(**components)
+ pipe.set_progress_bar_config(disable=None)
+ pipe.to(torch_device)
+
+ inputs_normal = self.get_dummy_inputs(torch_device)
+ frames_normal = pipe(**inputs_normal).frames[0]
+
+ schedulers_to_test = [
+ DPMSolverMultistepScheduler.from_config(
+ components["scheduler"].config,
+ timestep_spacing="linspace",
+ beta_schedule="linear",
+ algorithm_type="dpmsolver++",
+ steps_offset=1,
+ clip_sample=False,
+ ),
+ LCMScheduler.from_config(
+ components["scheduler"].config,
+ timestep_spacing="linspace",
+ beta_schedule="linear",
+ steps_offset=1,
+ clip_sample=False,
+ ),
+ ]
+ components.pop("scheduler")
+
+ for scheduler in schedulers_to_test:
+ components["scheduler"] = scheduler
+ pipe: AnimateDiffControlNetPipeline = self.pipeline_class(**components)
+ pipe.set_progress_bar_config(disable=None)
+ pipe.to(torch_device)
+
+ pipe.enable_free_init(num_iters=2, use_fast_sampling=False)
+
+ inputs = self.get_dummy_inputs(torch_device)
+ frames_enable_free_init = pipe(**inputs).frames[0]
+ sum_enabled = np.abs(to_np(frames_normal) - to_np(frames_enable_free_init)).sum()
+
+ self.assertGreater(
+ sum_enabled,
+ 1e1,
+ "Enabling of FreeInit should lead to results different from the default pipeline results",
+ )
+
+ def test_vae_slicing(self, video_count=2):
+ device = "cpu" # ensure determinism for the device-dependent torch.Generator
+ components = self.get_dummy_components()
+ pipe = self.pipeline_class(**components)
+ pipe = pipe.to(device)
+ pipe.set_progress_bar_config(disable=None)
+
+ inputs = self.get_dummy_inputs(device)
+ inputs["prompt"] = [inputs["prompt"]] * video_count
+ inputs["conditioning_frames"] = [inputs["conditioning_frames"]] * video_count
+ output_1 = pipe(**inputs)
+
+ # make sure sliced vae decode yields the same result
+ pipe.enable_vae_slicing()
+ inputs = self.get_dummy_inputs(device)
+ inputs["prompt"] = [inputs["prompt"]] * video_count
+ inputs["conditioning_frames"] = [inputs["conditioning_frames"]] * video_count
+ output_2 = pipe(**inputs)
+
+ assert np.abs(output_2[0].flatten() - output_1[0].flatten()).max() < 1e-2