[bug] Precedence of operations in VAE should be slicing -> tiling (huggingface#9342)

* bugfix: precedence of operations should be slicing -> tiling

* fix typo

* fix another typo

* deprecate current implementation of tiled_encode and use new impl

* Update src/diffusers/models/autoencoders/autoencoder_kl.py

Co-authored-by: YiYi Xu <[email protected]>

* Update src/diffusers/models/autoencoders/autoencoder_kl.py

---------

Co-authored-by: Sayak Paul <[email protected]>
Co-authored-by: YiYi Xu <[email protected]>
3 people authored and 蒋硕 committed Oct 11, 2024
1 parent d65e2af commit 5dbf8f6
Showing 1 changed file with 71 additions and 11 deletions.
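
For context, the fix matters when both memory-saving toggles are enabled at once. A minimal usage sketch (checkpoint name and shapes are illustrative; `enable_slicing`/`enable_tiling` are the public switches behind the `use_slicing`/`use_tiling` flags checked in the diff below):

```python
import torch
from diffusers import AutoencoderKL

vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse")
vae.enable_slicing()  # encode/decode the batch one sample at a time
vae.enable_tiling()   # split large images into overlapping tiles

with torch.no_grad():
    images = torch.randn(4, 3, 1024, 1024)  # larger than the default 512 tile size
    posterior = vae.encode(images).latent_dist
    latents = posterior.sample()
```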
src/diffusers/models/autoencoders/autoencoder_kl.py (+71 −11)
@@ -18,6 +18,7 @@

from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders.single_file_model import FromOriginalModelMixin
+from ...utils import deprecate
from ...utils.accelerate_utils import apply_forward_hook
from ..attention_processor import (
    ADDED_KV_ATTENTION_PROCESSORS,
@@ -245,6 +246,18 @@ def set_default_attn_processor(self):

        self.set_attn_processor(processor)

+    def _encode(self, x: torch.Tensor) -> torch.Tensor:
+        batch_size, num_channels, height, width = x.shape
+
+        if self.use_tiling and (width > self.tile_sample_min_size or height > self.tile_sample_min_size):
+            return self._tiled_encode(x)
+
+        enc = self.encoder(x)
+        if self.quant_conv is not None:
+            enc = self.quant_conv(enc)
+
+        return enc

    @apply_forward_hook
    def encode(
        self, x: torch.Tensor, return_dict: bool = True
@@ -261,21 +274,13 @@ def encode(
                The latent representations of the encoded images. If `return_dict` is True, a
                [`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain `tuple` is returned.
        """
-        if self.use_tiling and (x.shape[-1] > self.tile_sample_min_size or x.shape[-2] > self.tile_sample_min_size):
-            return self.tiled_encode(x, return_dict=return_dict)
-
        if self.use_slicing and x.shape[0] > 1:
-            encoded_slices = [self.encoder(x_slice) for x_slice in x.split(1)]
+            encoded_slices = [self._encode(x_slice) for x_slice in x.split(1)]
            h = torch.cat(encoded_slices)
        else:
-            h = self.encoder(x)
-
-        if self.quant_conv is not None:
-            moments = self.quant_conv(h)
-        else:
-            moments = h
+            h = self._encode(x)

-        posterior = DiagonalGaussianDistribution(moments)
+        posterior = DiagonalGaussianDistribution(h)

        if not return_dict:
            return (posterior,)
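
The net effect of this hunk: with both toggles on and a batch of large images, the old early return sent the whole batch through `tiled_encode`, so slicing was silently ignored; now each sample is encoded, and tiled if large, one at a time. A small sketch of the shapes involved (hypothetical sizes, default 512-pixel tiles):

```python
import torch

x = torch.randn(8, 3, 1024, 1024)  # batch of 8 images larger than the 512 tile size

# Old order: use_tiling was checked first, so tiled_encode(x) ran on all 8
# images at once and use_slicing never applied.
# New order: use_slicing is checked first, so x.split(1) yields eight
# (1, 3, 1024, 1024) slices; _encode() then tiles each slice individually,
# bounding peak memory by a single sample rather than the whole batch.
slices = x.split(1)
assert len(slices) == 8 and slices[0].shape == (1, 3, 1024, 1024)
```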
@@ -337,6 +342,54 @@ def blend_h(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
            b[:, :, :, x] = a[:, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, x] * (x / blend_extent)
        return b

+    def _tiled_encode(self, x: torch.Tensor) -> torch.Tensor:
+        r"""Encode a batch of images using a tiled encoder.
+
+        When this option is enabled, the VAE will split the input tensor into tiles to compute encoding in several
+        steps. This is useful to keep memory use constant regardless of image size. The end result of tiled encoding
+        is different from non-tiled encoding because each tile uses a different encoder. To avoid tiling artifacts,
+        the tiles overlap and are blended together to form a smooth output. You may still see tile-sized changes in
+        the output, but they should be much less noticeable.
+
+        Args:
+            x (`torch.Tensor`): Input batch of images.
+
+        Returns:
+            `torch.Tensor`:
+                The latent representation of the encoded images.
+        """
+        overlap_size = int(self.tile_sample_min_size * (1 - self.tile_overlap_factor))
+        blend_extent = int(self.tile_latent_min_size * self.tile_overlap_factor)
+        row_limit = self.tile_latent_min_size - blend_extent
+
+        # Split the image into overlapping tiles of `tile_sample_min_size` (512 by default) and encode them separately.
+        rows = []
+        for i in range(0, x.shape[2], overlap_size):
+            row = []
+            for j in range(0, x.shape[3], overlap_size):
+                tile = x[:, :, i : i + self.tile_sample_min_size, j : j + self.tile_sample_min_size]
+                tile = self.encoder(tile)
+                if self.config.use_quant_conv:
+                    tile = self.quant_conv(tile)
+                row.append(tile)
+            rows.append(row)
+        result_rows = []
+        for i, row in enumerate(rows):
+            result_row = []
+            for j, tile in enumerate(row):
+                # Blend the current tile with the tile above and the tile to the left,
+                # then crop it and append it to the result row.
+                if i > 0:
+                    tile = self.blend_v(rows[i - 1][j], tile, blend_extent)
+                if j > 0:
+                    tile = self.blend_h(row[j - 1], tile, blend_extent)
+                result_row.append(tile[:, :, :row_limit, :row_limit])
+            result_rows.append(torch.cat(result_row, dim=3))
+
+        enc = torch.cat(result_rows, dim=2)
+        return enc
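
With the usual SD VAE defaults the tile geometry above works out as follows (a worked sketch; the values `tile_sample_min_size=512`, `tile_latent_min_size=64`, and `tile_overlap_factor=0.25` are assumed defaults, not read from a model):

```python
tile_sample_min_size = 512  # pixel-space tile edge
tile_latent_min_size = 64   # latent-space tile edge (512 / 8 downsampling)
tile_overlap_factor = 0.25

overlap_size = int(tile_sample_min_size * (1 - tile_overlap_factor))  # 384: stride between tile origins
blend_extent = int(tile_latent_min_size * tile_overlap_factor)        # 16: latent rows/cols blended
row_limit = tile_latent_min_size - blend_extent                       # 48: latent rows/cols kept per tile

# blend_v/blend_h ramp a weight from 0 to 1 across the 16 overlapping latent
# rows/cols, so adjacent tiles cross-fade instead of producing visible seams.
assert (overlap_size, blend_extent, row_limit) == (384, 16, 48)
```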

    def tiled_encode(self, x: torch.Tensor, return_dict: bool = True) -> AutoencoderKLOutput:
        r"""Encode a batch of images using a tiled encoder.
@@ -356,6 +409,13 @@ def tiled_encode(self, x: torch.Tensor, return_dict: bool = True) -> AutoencoderKLOutput:
                If return_dict is True, a [`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a
                plain `tuple` is returned.
        """
+        deprecation_message = (
+            "The tiled_encode implementation supporting the `return_dict` parameter is deprecated. In the future, "
+            "the implementation of this method will be replaced with that of `_tiled_encode` and you will no longer "
+            "be able to pass `return_dict`. You will also have to create a `DiagonalGaussianDistribution()` from the "
+            "returned value."
+        )
+        deprecate("tiled_encode", "1.0.0", deprecation_message, standard_warn=False)

        overlap_size = int(self.tile_sample_min_size * (1 - self.tile_overlap_factor))
        blend_extent = int(self.tile_latent_min_size * self.tile_overlap_factor)
        row_limit = self.tile_latent_min_size - blend_extent
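
Per the deprecation message above, callers migrate roughly as follows (a hedged sketch: `vae` is assumed to be a loaded `AutoencoderKL` and `x` an image batch tensor; `_tiled_encode` returns raw latent moments, not a distribution):

```python
from diffusers.models.autoencoders.vae import DiagonalGaussianDistribution

# Old (deprecated): returns an AutoencoderKLOutput when return_dict=True.
posterior = vae.tiled_encode(x, return_dict=True).latent_dist

# New: build the distribution from the raw latent moments yourself.
moments = vae._tiled_encode(x)
posterior = DiagonalGaussianDistribution(moments)
```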
