
Allow DDPMPipeline half precision #9222

Merged — 2 commits, merged Sep 23, 2024

Conversation

@sbinnee (Contributor) commented Aug 20, 2024

What does this PR do?

I found that DDPMPipeline couldn't run in half precision.

Test code

```python
import torch
from diffusers import DDPMPipeline

pipeline = DDPMPipeline.from_pretrained('google/ddpm-celebahq-256', torch_dtype=torch.float16)
out = pipeline()
```

Error message

```
  0%|                                                                                           | 0/1000 [00:00<?, ?it/s]

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
Cell In[16], line 1
----> 1 out = pipeline()

File ~/.conda/envs/repaint/lib/python3.11/site-packages/torch/utils/_contextlib.py:116, in context_decorator.<locals>.decorate_context(*args, **kwargs)
    113 @functools.wraps(func)
    114 def decorate_context(*args, **kwargs):
    115     with ctx_factory():
--> 116         return func(*args, **kwargs)

File ~/workspace/diffusers/src/diffusers/pipelines/ddpm/pipeline_ddpm.py:114, in DDPMPipeline.__call__(self, batch_size, generator, num_inference_steps, output_type, return_dict)
    110 self.scheduler.set_timesteps(num_inference_steps)
    112 for t in self.progress_bar(self.scheduler.timesteps):
    113     # 1. predict noise model_output
--> 114     model_output = self.unet(image, t).sample
    116     # 2. compute previous image: x_t -> x_t-1
    117     image = self.scheduler.step(model_output, t, image, generator=generator).prev_sample

File ~/.conda/envs/repaint/lib/python3.11/site-packages/torch/nn/modules/module.py:1553, in Module._wrapped_call_impl(self, *args, **kwargs)
   1551     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1552 else:
-> 1553     return self._call_impl(*args, **kwargs)

File ~/.conda/envs/repaint/lib/python3.11/site-packages/torch/nn/modules/module.py:1562, in Module._call_impl(self, *args, **kwargs)
   1557 # If we don't have any hooks, we want to skip the rest of the logic in
   1558 # this function, and just call forward.
   1559 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1560         or _global_backward_pre_hooks or _global_backward_hooks
   1561         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1562     return forward_call(*args, **kwargs)
   1564 try:
   1565     result = None

File ~/workspace/diffusers/src/diffusers/models/unets/unet_2d.py:303, in UNet2DModel.forward(self, sample, timestep, class_labels, return_dict)
    301 # 2. pre-process
    302 skip_sample = sample
--> 303 sample = self.conv_in(sample)
    305 # 3. down
    306 down_block_res_samples = (sample,)

File ~/.conda/envs/repaint/lib/python3.11/site-packages/torch/nn/modules/module.py:1553, in Module._wrapped_call_impl(self, *args, **kwargs)
   1551     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1552 else:
-> 1553     return self._call_impl(*args, **kwargs)

File ~/.conda/envs/repaint/lib/python3.11/site-packages/torch/nn/modules/module.py:1562, in Module._call_impl(self, *args, **kwargs)
   1557 # If we don't have any hooks, we want to skip the rest of the logic in
   1558 # this function, and just call forward.
   1559 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1560         or _global_backward_pre_hooks or _global_backward_hooks
   1561         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1562     return forward_call(*args, **kwargs)
   1564 try:
   1565     result = None

File ~/.conda/envs/repaint/lib/python3.11/site-packages/torch/nn/modules/conv.py:458, in Conv2d.forward(self, input)
    457 def forward(self, input: Tensor) -> Tensor:
--> 458     return self._conv_forward(input, self.weight, self.bias)

File ~/.conda/envs/repaint/lib/python3.11/site-packages/torch/nn/modules/conv.py:454, in Conv2d._conv_forward(self, input, weight, bias)
    450 if self.padding_mode != 'zeros':
    451     return F.conv2d(F.pad(input, self._reversed_padding_repeated_twice, mode=self.padding_mode),
    452                     weight, bias, self.stride,
    453                     _pair(0), self.dilation, self.groups)
--> 454 return F.conv2d(input, weight, bias, self.stride,
    455                 self.padding, self.dilation, self.groups)

RuntimeError: Input type (float) and bias type (c10::Half) should be the same
```

I compared the implementation of DDPMPipeline with that of DDIMPipeline, since they are alike. DDPMPipeline simply does not use unet.dtype when initializing the noise, so the initial image stays in the default float32 and mismatches the half-precision UNet weights.
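A minimal sketch of the kind of fix described above: create the initial noise in the UNet's dtype instead of the default float32. The helper name below is hypothetical and not part of diffusers; only `torch.randn` and its `dtype`/`generator` arguments come from the actual PyTorch API.

```python
import torch

# Hypothetical helper illustrating the fix: build the initial noise in the
# model's dtype so conv_in receives input and weights of the same dtype.
def initial_noise(shape, dtype, generator=None):
    # Without dtype=..., torch.randn returns float32, which triggers
    # "Input type (float) and bias type (c10::Half) should be the same"
    # when the UNet weights are float16.
    return torch.randn(shape, generator=generator, dtype=dtype)

noise = initial_noise((1, 3, 256, 256), dtype=torch.float16)
print(noise.dtype)  # torch.float16
```

In the pipeline itself, the dtype would come from `self.unet.dtype`, mirroring how DDIMPipeline initializes its latents.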

Before submitting

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.


This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

@github-actions github-actions bot added the stale Issues that haven't received updates label Sep 19, 2024
@asomoza asomoza requested a review from yiyixuxu September 19, 2024 18:39
@asomoza asomoza removed the stale Issues that haven't received updates label Sep 19, 2024
@yiyixuxu (Collaborator) left a comment


thanks!

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@yiyixuxu yiyixuxu merged commit 3e69e24 into huggingface:main Sep 23, 2024
15 checks passed
leisuzz pushed a commit to leisuzz/diffusers that referenced this pull request Oct 11, 2024
sayakpaul pushed a commit that referenced this pull request Dec 23, 2024
4 participants