
Latte: Latent Diffusion Transformer for Video Generation #8404

Merged · 87 commits · Jul 11, 2024

Conversation

@maxin-cn (Contributor) commented Jun 5, 2024

What does this PR do?

Add Latte to diffusers. Please see this issue.


Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@sayakpaul requested review from DN6 and yiyixuxu on June 5, 2024 at 12:02
@yiyixuxu (Collaborator) left a comment


I left some feedback, thanks!

Review threads on src/diffusers/models/transformers/latte_transformer_3d.py and src/diffusers/models/attention.py (outdated, resolved)
@@ -45,6 +45,7 @@
)
from .kandinsky3 import Kandinsky3Img2ImgPipeline, Kandinsky3Pipeline
from .latent_consistency_models import LatentConsistencyModelImg2ImgPipeline, LatentConsistencyModelPipeline
from .latte import LattePipeline
Collaborator

Suggested change
- from .latte import LattePipeline

don't need this import here :)

@maxin-cn (Contributor Author)

I have removed it.

@maxin-cn (Contributor Author) commented Jun 6, 2024

I left some feedback, thanks!

Hi @yiyixuxu, thanks for your code review. I've removed some unnecessary code from latte_transformer_3d.py. norm_type has not been modified yet. Do you have any suggestions for how best to modify it?

@maxin-cn (Contributor Author) left a comment

I finished my code review. In summary, the unused code in latte_transformer_3d.py was removed, norm_type_latte was removed, and a squeeze_hidden_states flag was added.

Review threads on src/diffusers/models/attention.py (outdated, resolved)
@maxin-cn requested a review from yiyixuxu on June 6, 2024 at 01:01
@maxin-cn (Contributor Author) commented Jun 6, 2024

I finished my code review. In summary, the unused code in latte_transformer_3d.py was removed, norm_type_latte was removed, and a squeeze_hidden_states flag was added. Thanks for the further code review.

@maxin-cn requested a review from yiyixuxu on June 6, 2024 at 09:48
Review threads on .github/ISSUE_TEMPLATE/bug-report.yml, .gitignore, and .vscode/sftp.json (outdated, resolved)
@maxin-cn (Contributor Author)

Hey @maxin-cn, a couple of things still remain:

  • Inference does not work with the example code because the model does not load correctly with from_pretrained. I had to do something like:


import torch
from diffusers import LattePipeline, AutoencoderKL
from diffusers.utils import export_to_gif
from transformers import T5EncoderModel

text_encoder = T5EncoderModel.from_pretrained("maxin-cn/Latte-1", subfolder="text_encoder", torch_dtype=torch.float16)
vae = AutoencoderKL.from_pretrained("maxin-cn/Latte-1", subfolder="vae", torch_dtype=torch.float16)
pipe = LattePipeline.from_pretrained("maxin-cn/Latte-1", text_encoder=text_encoder, vae=vae, torch_dtype=torch.float16).to("cuda")

prompt = "A small cactus with a happy face in the Sahara desert."
videos = pipe(prompt, video_length=16).frames
export_to_gif(videos, "latte.gif")

If we use the current example code, there is an error:

transformer/diffusion_pytorch_model.safetensors not found
Loading pipeline components...:   0%|                                                                                                                               | 0/5 [00:00<?, ?it/s]An error occurred while trying to fetch /home/ubuntu/models/hub/models--maxin-cn--Latte-1/snapshots/7ae199353e4324f0216c93c5e8e408d71cb3350a/vae: Error no file named diffusion_pytorch_model.safetensors found in directory /home/ubuntu/models/hub/models--maxin-cn--Latte-1/snapshots/7ae199353e4324f0216c93c5e8e408d71cb3350a/vae.
Defaulting to unsafe serialization. Pass `allow_pickle=False` to raise an error instead.
Loading pipeline components...:   0%|                                                                                                                               | 0/5 [00:00<?, ?it/s]
Traceback (most recent call last):
  File "/home/ubuntu/personal/diffusers/test.py", line 5, in <module>
    pipe = LattePipeline.from_pretrained("maxin-cn/Latte-1", torch_dtype=torch.float16)
  File "/home/ubuntu/personal/venv/lib/python3.10/site-packages/huggingface_hub/utils/_validators.py", line 114, in _inner_fn
    return fn(*args, **kwargs)
  File "/home/ubuntu/personal/diffusers/src/diffusers/pipelines/pipeline_utils.py", line 881, in from_pretrained
    loaded_sub_model = load_sub_model(
  File "/home/ubuntu/personal/diffusers/src/diffusers/pipelines/pipeline_loading_utils.py", line 703, in load_sub_model
    loaded_sub_model = load_method(os.path.join(cached_folder, name), **loading_kwargs)
  File "/home/ubuntu/personal/venv/lib/python3.10/site-packages/huggingface_hub/utils/_validators.py", line 114, in _inner_fn
    return fn(*args, **kwargs)
  File "/home/ubuntu/personal/diffusers/src/diffusers/models/modeling_utils.py", line 722, in from_pretrained
    model_file = _get_model_file(
  File "/home/ubuntu/personal/venv/lib/python3.10/site-packages/huggingface_hub/utils/_validators.py", line 114, in _inner_fn
    return fn(*args, **kwargs)
  File "/home/ubuntu/personal/diffusers/src/diffusers/utils/hub_utils.py", line 310, in _get_model_file
    raise EnvironmentError(
OSError: Error no file named diffusion_pytorch_model.bin found in directory /home/ubuntu/models/hub/models--maxin-cn--Latte-1/snapshots/7ae199353e4324f0216c93c5e8e408d71cb3350a/vae.

Explanation: The transformer/ directory contains model weights in .bin format. With Diffusers, once we see a .bin file, we assume that all model files are of the same type. However, the vae/ and text_encoder/ directories contain files in safetensors format, so initialization fails because we are now looking for diffusion_pytorch_model.bin instead of diffusion_pytorch_model.safetensors.
Could you update all model files to be in safetensors format?

  • The Latte Fast tests seem to be failing. Could you take a look at how the below could be fixed?

Local logs

...
_________________________________________________________________________ LattePipelineFastTests.test_to_device __________________________________________________________________________

self = <tests.pipelines.latte.test_latte.LattePipelineFastTests testMethod=test_to_device>

    @unittest.skipIf(torch_device != "cuda", reason="CUDA and CPU are required to switch devices")
    def test_to_device(self):
        components = self.get_dummy_components()
        pipe = self.pipeline_class(**components)
        pipe.set_progress_bar_config(disable=None)
    
        pipe.to("cpu")
        model_devices = [component.device.type for component in components.values() if hasattr(component, "device")]
        self.assertTrue(all(device == "cpu" for device in model_devices))
    
>       output_cpu = pipe(**self.get_dummy_inputs("cpu"))[0]

tests/pipelines/test_pipelines_common.py:1309: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 
../venv/lib/python3.10/site-packages/torch/utils/_contextlib.py:115: in decorate_context
    return func(*args, **kwargs)
src/diffusers/pipelines/latte/pipeline_latte.py:773: in __call__
    noise_pred = self.transformer(
../venv/lib/python3.10/site-packages/torch/nn/modules/module.py:1532: in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
../venv/lib/python3.10/site-packages/torch/nn/modules/module.py:1541: in _call_impl
    return forward_call(*args, **kwargs)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

self = LatteTransformer3DModel(
  (pos_embed): PatchEmbed(
    (proj): Conv2d(4, 24, kernel_size=(2, 2), stride=(2, 2))
  )
 ...as=True)
    (act_1): GELU(approximate='tanh')
    (linear_2): Linear(in_features=24, out_features=24, bias=True)
  )
)
hidden_states = tensor([[[-1.2315e+00, -1.5834e+00, -3.1012e-01,  ..., -5.5376e-01,
           2.6413e+00,  9.2141e-01],
         [-8....1.0070e+00],
         [-1.0457e-01,  1.3962e-01,  4.7246e-01,  ...,  1.0995e+00,
           1.5373e+00,  9.4633e-01]]])
timestep = tensor([[ 0.2214,  0.1386,  0.0953,  0.1772,  0.0803,  0.1018,  0.0042, -0.2041,
          0.0433,  0.0170, -0.0644, -... -0.0327, -0.1899,  0.1730,  0.0587,
         -0.0698, -0.0681, -0.0367, -0.1505, -0.0610, -0.1563,  0.1971,  0.1270]])
encoder_hidden_states = tensor([[[-0.1367,  0.1571, -0.1990,  ..., -0.0739, -0.0986,  0.0098],
         [-0.1372,  0.1570, -0.1986,  ..., -0.0...73, -0.1988,  ..., -0.0749, -0.0984,  0.0095],
         [-0.1373,  0.1572, -0.1988,  ..., -0.0753, -0.0988,  0.0099]]])
encoder_attention_mask = None, enable_temporal_attentions = True, return_dict = False

    def forward(
        self,
        hidden_states: torch.Tensor,
        timestep: Optional[torch.LongTensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
        enable_temporal_attentions: bool = True,
        return_dict: bool = True,
    ):
        """
        The [`LatteTransformer3DModel`] forward method.
    
        Args:
            hidden_states shape `(batch size, channel, num_frame, height, width)`:
                Input `hidden_states`.
            timestep ( `torch.LongTensor`, *optional*):
                Used to indicate denoising step. Optional timestep to be applied as an embedding in `AdaLayerNorm`.
            encoder_hidden_states ( `torch.FloatTensor` of shape `(batch size, sequence len, embed dims)`, *optional*):
                Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
                self-attention.
            encoder_attention_mask ( `torch.Tensor`, *optional*):
                Cross-attention mask applied to `encoder_hidden_states`. Two formats supported:
    
                    * Mask `(batch, sequence_length)` True = keep, False = discard.
                    * Bias `(batch, 1, sequence_length)` 0 = keep, -10000 = discard.
    
                If `ndim == 2`: will be interpreted as a mask, then converted into a bias consistent with the format
                above. This bias will be added to the cross-attention scores.
            enable_temporal_attentions:
                (`bool`, *optional*, defaults to `True`): Whether to enable temporal attentions.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain
                tuple.
    
        Returns:
            If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
            `tuple` where the first element is the sample tensor.
        """
    
        # Reshape hidden states
        batch_size, channels, num_frame, height, width = hidden_states.shape
        # batch_size channels num_frame height width -> (batch_size * num_frame) channels height width
        hidden_states = hidden_states.permute(0, 2, 1, 3, 4).reshape(-1, channels, height, width)
    
        # Input
        height, width = (
            hidden_states.shape[-2] // self.config.patch_size,
            hidden_states.shape[-1] // self.config.patch_size,
        )
        num_patches = height * width
    
        hidden_states = self.pos_embed(hidden_states)  # already adds positional embeddings
    
        added_cond_kwargs = {"resolution": None, "aspect_ratio": None}
        timestep, embedded_timestep = self.adaln_single(
            timestep, added_cond_kwargs=added_cond_kwargs, batch_size=batch_size, hidden_dtype=hidden_states.dtype
        )
    
        # Prepare text embeddings for spatial block
        # batch_size num_tokens hidden_size -> (batch_size * num_frame) num_tokens hidden_size
        encoder_hidden_states = self.caption_projection(encoder_hidden_states)  # 3 120 1152
        encoder_hidden_states_spatial = encoder_hidden_states.repeat_interleave(num_frame, dim=0).view(
            -1, encoder_hidden_states.shape[-2], encoder_hidden_states.shape[-1]
        )
    
        # Prepare timesteps for spatial and temporal block
        timestep_spatial = timestep.repeat_interleave(num_frame, dim=0).view(-1, timestep.shape[-1])
        timestep_temp = timestep.repeat_interleave(num_patches, dim=0).view(-1, timestep.shape[-1])
    
        # Spatial and temporal transformer blocks
        for i, (spatial_block, temp_block) in enumerate(
            zip(self.transformer_blocks, self.temporal_transformer_blocks)
        ):
            if self.training and self.gradient_checkpointing:
                hidden_states = torch.utils.checkpoint.checkpoint(
                    spatial_block,
                    hidden_states,
                    None,  # attention_mask
                    encoder_hidden_states_spatial,
                    encoder_attention_mask,
                    timestep_spatial,
                    None,  # cross_attention_kwargs
                    None,  # class_labels
                    use_reentrant=False,
                )
            else:
                hidden_states = spatial_block(
                    hidden_states,
                    None,  # attention_mask
                    encoder_hidden_states_spatial,
                    encoder_attention_mask,
                    timestep_spatial,
                    None,  # cross_attention_kwargs
                    None,  # class_labels
                )
    
            if enable_temporal_attentions:
                # (batch_size * num_frame) num_tokens hidden_size -> (batch_size * num_tokens) num_frame hidden_size
                hidden_states = hidden_states.reshape(
                    batch_size, -1, hidden_states.shape[-2], hidden_states.shape[-1]
                ).permute(0, 2, 1, 3)
                hidden_states = hidden_states.reshape(-1, hidden_states.shape[-2], hidden_states.shape[-1])
    
                if i == 0:
                    hidden_states = hidden_states + self.temp_pos_embed
    
                if self.training and self.gradient_checkpointing:
                    hidden_states = torch.utils.checkpoint.checkpoint(
                        temp_block,
                        hidden_states,
                        None,  # attention_mask
                        None,  # encoder_hidden_states
                        None,  # encoder_attention_mask
                        timestep_temp,
                        None,  # cross_attention_kwargs
                        None,  # class_labels
                        use_reentrant=False,
                    )
                else:
                    hidden_states = temp_block(
                        hidden_states,
                        None,  # attention_mask
                        None,  # encoder_hidden_states
                        None,  # encoder_attention_mask
                        timestep_temp,
                        None,  # cross_attention_kwargs
                        None,  # class_labels
                    )
    
                # (batch_size * num_frame) num_tokens hidden_size -> (batch_size * num_tokens) num_frame hidden_size
                hidden_states = hidden_states.reshape(
                    batch_size, -1, hidden_states.shape[-2], hidden_states.shape[-1]
                ).permute(0, 2, 1, 3)
                hidden_states = hidden_states.reshape(-1, hidden_states.shape[-2], hidden_states.shape[-1])
    
        embedded_timestep = embedded_timestep.repeat_interleave(num_frame, dim=0).view(-1, embedded_timestep.shape[-1])
        shift, scale = (self.scale_shift_table[None] + embedded_timestep[:, None]).chunk(2, dim=1)
        hidden_states = self.norm_out(hidden_states)
        # Modulation
>       hidden_states = hidden_states * (1 + scale) + shift
E       RuntimeError: The size of tensor a (32) must match the size of tensor b (2) at non-singleton dimension 0

src/diffusers/models/transformers/latte_transformer_3d.py:307: RuntimeError
================================================================================ short test summary info =================================================================================
FAILED tests/pipelines/latte/test_latte.py::LattePipelineFastTests::test_callback_cfg - RuntimeError: The size of tensor a (32) must match the size of tensor b (2) at non-singleton dimension 0
FAILED tests/pipelines/latte/test_latte.py::LattePipelineFastTests::test_callback_inputs - RuntimeError: The size of tensor a (32) must match the size of tensor b (2) at non-singleton dimension 0
FAILED tests/pipelines/latte/test_latte.py::LattePipelineFastTests::test_cfg - RuntimeError: The size of tensor a (32) must match the size of tensor b (2) at non-singleton dimension 0
FAILED tests/pipelines/latte/test_latte.py::LattePipelineFastTests::test_cpu_offload_forward_pass_twice - RuntimeError: The size of tensor a (32) must match the size of tensor b (2) at non-singleton dimension 0
FAILED tests/pipelines/latte/test_latte.py::LattePipelineFastTests::test_dict_tuple_outputs_equivalent - RuntimeError: The size of tensor a (32) must match the size of tensor b (2) at non-singleton dimension 0
FAILED tests/pipelines/latte/test_latte.py::LattePipelineFastTests::test_float16_inference - RuntimeError: The size of tensor a (32) must match the size of tensor b (2) at non-singleton dimension 0
FAILED tests/pipelines/latte/test_latte.py::LattePipelineFastTests::test_inference - RuntimeError: The size of tensor a (32) must match the size of tensor b (2) at non-singleton dimension 0
FAILED tests/pipelines/latte/test_latte.py::LattePipelineFastTests::test_inference_batch_consistent - RuntimeError: The size of tensor a (64) must match the size of tensor b (4) at non-singleton dimension 0
FAILED tests/pipelines/latte/test_latte.py::LattePipelineFastTests::test_inference_batch_single_identical - RuntimeError: The size of tensor a (32) must match the size of tensor b (2) at non-singleton dimension 0
FAILED tests/pipelines/latte/test_latte.py::LattePipelineFastTests::test_model_cpu_offload_forward_pass - RuntimeError: The size of tensor a (32) must match the size of tensor b (2) at non-singleton dimension 0
FAILED tests/pipelines/latte/test_latte.py::LattePipelineFastTests::test_num_images_per_prompt - RuntimeError: The size of tensor a (32) must match the size of tensor b (2) at non-singleton dimension 0
FAILED tests/pipelines/latte/test_latte.py::LattePipelineFastTests::test_progress_bar - RuntimeError: The size of tensor a (32) must match the size of tensor b (2) at non-singleton dimension 0
FAILED tests/pipelines/latte/test_latte.py::LattePipelineFastTests::test_save_load_float16 - RuntimeError: The size of tensor a (32) must match the size of tensor b (2) at non-singleton dimension 0
FAILED tests/pipelines/latte/test_latte.py::LattePipelineFastTests::test_save_load_local - RuntimeError: The size of tensor a (32) must match the size of tensor b (2) at non-singleton dimension 0
FAILED tests/pipelines/latte/test_latte.py::LattePipelineFastTests::test_save_load_optional_components - RuntimeError: The size of tensor a (32) must match the size of tensor b (2) at non-singleton dimension 0
FAILED tests/pipelines/latte/test_latte.py::LattePipelineFastTests::test_sequential_cpu_offload_forward_pass - RuntimeError: The size of tensor a (32) must match the size of tensor b (2) at non-singleton dimension 0
FAILED tests/pipelines/latte/test_latte.py::LattePipelineFastTests::test_sequential_offload_forward_pass_twice - RuntimeError: The size of tensor a (32) must match the size of tensor b (2) at non-singleton dimension 0
FAILED tests/pipelines/latte/test_latte.py::LattePipelineFastTests::test_to_device - RuntimeError: The size of tensor a (32) must match the size of tensor b (2) at non-singleton dimension 0
  • make fix-copies seems to still be failing even after running it in my PR. I will run it again and commit once we fix other errors. Could you ensure you have ruff version 0.1.5?

Thanks for your help!
Yes, my ruff version is 0.1.5.
And I will convert the transformer format from bin to safetensors.
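
For reference, one minimal way to do such a conversion is to load the pickled state dict and re-save it with the safetensors library. This is only a sketch with placeholder paths, not the actual script used for the Hub checkpoint:

import torch
from safetensors.torch import save_file

# Placeholder path for the .bin file under transformer/; adjust to the local checkpoint.
state_dict = torch.load("transformer/diffusion_pytorch_model.bin", map_location="cpu")
# safetensors expects contiguous tensors, so make them contiguous before saving.
state_dict = {name: tensor.contiguous() for name, tensor in state_dict.items()}
save_file(state_dict, "transformer/diffusion_pytorch_model.safetensors")

Alternatively, loading the model with diffusers and calling save_pretrained with safe_serialization=True should produce the same result.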

Hi @a-r-r-o-w, I have converted the model (https://huggingface.co/maxin-cn/Latte-1/tree/main/transformer).

Hi @a-r-r-o-w, I have fixed the latte text pipeline.

Fast tests for PRs / PyTorch Example CPU tests (pull_request) seems to fail.

@a-r-r-o-w (Member)

Hi @maxin-cn. Seems like the Latte tests are still broken, but at least the quality ones passed!

Fast PyTorch Pipeline CPU tests - this test fails in the pipeline _text_preprocessing method

PyTorch Example CPU tests - this test fails due to an unrelated error, and I don't think you need to do anything on your end.

Also, with your latest commit ("reduce frames for test"), pretty much all tests are failing with the tensor size mismatch errors that I mentioned in a previous comment. Once these are addressed, we can merge this 💯

@a-r-r-o-w (Member)

Also, you might be interested in 4/8-bit quantization for the entire model, or memory optimization for the T5 text encoder. I am trying a few things here. Presently, inference is extremely slow due to the overhead involved, but if you have any insights from experimenting with the same, it would be awesome to know! As it stands, float16 inference runs in under 17 GB for a 16-frame video, but I'm interested in bringing that down to 8-10 GB.
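
For context, a rough sketch of the kind of memory optimizations being discussed: an 8-bit T5 text encoder via bitsandbytes, and CPU offload of idle pipeline components. This assumes bitsandbytes and accelerate are installed and is not the exact code being experimented with above:

import torch
from transformers import T5EncoderModel, BitsAndBytesConfig
from diffusers import LattePipeline

# Option 1: load the T5 text encoder in 8-bit to shrink its memory footprint;
# it can then be passed to the pipeline via text_encoder=..., as in the example earlier in this thread.
text_encoder_8bit = T5EncoderModel.from_pretrained(
    "maxin-cn/Latte-1",
    subfolder="text_encoder",
    quantization_config=BitsAndBytesConfig(load_in_8bit=True),
)

# Option 2: keep only the component that is currently running on the GPU and
# move the rest back to the CPU between forward passes.
pipe = LattePipeline.from_pretrained("maxin-cn/Latte-1", torch_dtype=torch.float16)
pipe.enable_model_cpu_offload()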

@maxin-cn (Contributor Author)

_text_preprocessing

Hi @a-r-r-o-w, I have fixed the _text_preprocessing bug for testing.

@maxin-cn (Contributor Author)


And I also tested video_length=1. All the tests passed.

@maxin-cn (Contributor Author)

Also, you might be interested in 4/8-bit quantization for the entire model, or memory optimization for the T5 text encoder. I am trying a few things here. Presently, inference is extremely slow due to the overhead involved, but if you have any insights from experimenting with the same, it would be awesome to know! As it stands, float16 inference runs in under 17 GB for a 16-frame video, but I'm interested in bringing that down to 8-10 GB.

Thanks for your suggestions. I haven't tried any quantization techniques at the moment. BTW, is quantized inference necessary for merging this PR? If not, I'd like to integrate your code into the follow-up inference code after Latte's PR is merged.

@a-r-r-o-w (Member)

Thanks for your suggestions. I haven't tried any quantization techniques at the moment. BTW, is quantized inference necessary for merging this PR? If not, I'd like to integrate your code into the follow-up inference code after Latte's PR is merged.

No, quantization is not necessary at all. This PR looks absolutely great to merge now and once CI passes, we can do it.

@maxin-cn (Contributor Author)

Thanks for your suggestions. I haven't tried any quantization techniques at the moment. BTW, is quantized inference necessary for merging this PR? If not, I'd like to integrate your code into the follow-up inference code after Latte's PR is merged.

No, quantization is not necessary at all. This PR looks absolutely great to merge now and once CI passes, we can do it.

Okay. Let's try quantization later.

@a-r-r-o-w (Member)

Thank you for bearing with our reviews/requests over the duration of this PR, and being so quick to respond! This is very cool work ❤️

LGTM! 🤗

@a-r-r-o-w merged commit b8cf84a into huggingface:main on Jul 11, 2024
15 checks passed
@maxin-cn (Contributor Author)

Thank you for bearing with our reviews/requests over the duration of this PR, and being so quick to respond! This is very cool work ❤️

LGTM! 🤗

I would like to extend my heartfelt thanks for your support and responsiveness throughout the duration of this PR!

@maxin-cn (Contributor Author)

Thanks for your suggestions. I haven't tried any quantization techniques at the moment. BTW, is quantized inference necessary for merging this PR? If not, I'd like to integrate your code into the follow-up inference code after Latte's PR is merged.

No, quantization is not necessary at all. This PR looks absolutely great to merge now and once CI passes, we can do it.

Okay. Let's try quantization later.

Hi @a-r-r-o-w! I have added your code for quantization inference. You can find it here. Thank you very much.

sayakpaul added a commit that referenced this pull request Dec 23, 2024
* add Latte to diffusers

* remove print

* remove print

* remove print

* remove unuse codes

* remove layer_norm_latte and add a flag

* remove layer_norm_latte and add a flag

* update latte_pipeline

* update latte_pipeline

* remove unuse squeeze

* add norm_hidden_states.ndim == 2: # for Latte

* fixed test latte pipeline bugs

* fixed test latte pipeline bugs

* delete sh

* add doc for latte

* add licensing

* Move Transformer3DModelOutput to modeling_outputs

* give a default value to sample_size

* remove the einops dependency

* change norm2 for latte

* modify pipeline of latte

* update test for Latte

* modify some codes for latte

* modify for Latte pipeline

* modify for Latte pipeline

* modify for Latte pipeline

* modify for Latte pipeline

* modify for Latte pipeline

* modify for Latte pipeline

* modify for Latte pipeline

* modify for Latte pipeline

* modify for Latte pipeline

* modify for Latte pipeline

* modify for Latte pipeline

* modify for Latte pipeline

* modify for Latte pipeline

* modify for Latte pipeline

* modify for Latte pipeline

* modify for Latte pipeline

* modify for Latte pipeline

* modify for Latte pipeline

* modify for Latte pipeline

* modify for Latte pipeline

* modify for Latte pipeline

* modify for Latte pipeline

* modify for Latte pipeline

* modify for Latte pipeline

* modify for Latte pipeline

* modify for Latte pipeline

* modify for Latte pipeline

* video_length -> num_frames; update prepare_latents copied from

* make fix-copies

* make style

* typo: videe -> video

* update

* modify for Latte pipeline

* modify latte pipeline

* modify latte pipeline

* modify latte pipeline

* modify latte pipeline

* modify for Latte pipeline

* Delete .vscode directory

* make style

* make fix-copies

* add latte transformer 3d to docs _toctree.yml

* update example

* reduce frames for test

* fixed bug of _text_preprocessing

* set num frame to 1 for testing

* remove unuse print

* add text = self._clean_caption(text) again

---------

Co-authored-by: Sayak Paul <[email protected]>
Co-authored-by: YiYi Xu <[email protected]>
Co-authored-by: Aryan <[email protected]>
Co-authored-by: Aryan <[email protected]>