ChenyangSi/FreeU

FreeU: Free Lunch in Diffusion U-Net

S-Lab, Nanyang Technological University

Paper | Project Page | Video | Demo

CVPR 2024 Oral

We propose FreeU, a method that substantially improves diffusion model sample quality at no cost: no training, no additional parameters, and no increase in memory or sampling time.

📖 For more visual results, check out our Project Page.

Usage

  • A demo is available on Hugging Face (huge thanks to AK and the whole HF team for their support).
  • You can run the Gradio demo locally with python demos/app.py.

FreeU Code

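import torch
import torch.fft as fft


def Fourier_filter(x, threshold, scale):
    # FFT over the two spatial dimensions; shift the zero frequency to the center
    x_freq = fft.fftn(x, dim=(-2, -1))
    x_freq = fft.fftshift(x_freq, dim=(-2, -1))

    B, C, H, W = x_freq.shape
    # Build the mask on the same device as the input instead of assuming CUDA
    mask = torch.ones((B, C, H, W), device=x.device)

    # Scale only the low-frequency band around the center of the spectrum
    crow, ccol = H // 2, W // 2
    mask[..., crow - threshold:crow + threshold, ccol - threshold:ccol + threshold] = scale
    x_freq = x_freq * mask

    # IFFT: undo the shift and keep the real part
    x_freq = fft.ifftshift(x_freq, dim=(-2, -1))
    x_filtered = fft.ifftn(x_freq, dim=(-2, -1)).real

    return x_filtered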

class Free_UNetModel(UNetModel):
    """
    :param b1: backbone factor of the first stage block of decoder.
    :param b2: backbone factor of the second stage block of decoder.
    :param s1: skip factor of the first stage block of decoder.
    :param s2: skip factor of the second stage block of decoder.
    """

    def __init__(
        self,
        b1,
        b2,
        s1,
        s2,
        *args,
        **kwargs
    ):
        super().__init__(*args, **kwargs)
        self.b1 = b1 
        self.b2 = b2
        self.s1 = s1
        self.s2 = s2

    def forward(self, x, timesteps=None, context=None, y=None, **kwargs):
        """
        Apply the model to an input batch.
        :param x: an [N x C x ...] Tensor of inputs.
        :param timesteps: a 1-D batch of timesteps.
        :param context: conditioning plugged in via crossattn
        :param y: an [N] Tensor of labels, if class-conditional.
        :return: an [N x C x ...] Tensor of outputs.
        """
        assert (y is not None) == (
            self.num_classes is not None
        ), "must specify y if and only if the model is class-conditional"
        hs = []
        t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
        emb = self.time_embed(t_emb)

        if self.num_classes is not None:
            assert y.shape[0] == x.shape[0]
            emb = emb + self.label_emb(y)

        h = x.type(self.dtype)
        for module in self.input_blocks:
            h = module(h, emb, context)
            hs.append(h)
        h = self.middle_block(h, emb, context)
        for module in self.output_blocks:
            hs_ = hs.pop()

            # --------------- FreeU code -----------------------
            # Only operate on the first two stages
            if h.shape[1] == 1280:
                hidden_mean = h.mean(1).unsqueeze(1)
                B = hidden_mean.shape[0]
                hidden_max, _ = torch.max(hidden_mean.view(B, -1), dim=-1, keepdim=True) 
                hidden_min, _ = torch.min(hidden_mean.view(B, -1), dim=-1, keepdim=True)
                hidden_mean = (hidden_mean - hidden_min.unsqueeze(2).unsqueeze(3)) / (hidden_max - hidden_min).unsqueeze(2).unsqueeze(3)

                h[:,:640] = h[:,:640] * ((self.b1 - 1 ) * hidden_mean + 1)
                hs_ = Fourier_filter(hs_, threshold=1, scale=self.s1)
            if h.shape[1] == 640:
                hidden_mean = h.mean(1).unsqueeze(1)
                B = hidden_mean.shape[0]
                hidden_max, _ = torch.max(hidden_mean.view(B, -1), dim=-1, keepdim=True) 
                hidden_min, _ = torch.min(hidden_mean.view(B, -1), dim=-1, keepdim=True)
                hidden_mean = (hidden_mean - hidden_min.unsqueeze(2).unsqueeze(3)) / (hidden_max - hidden_min).unsqueeze(2).unsqueeze(3)

                h[:,:320] = h[:,:320] * ((self.b2 - 1 ) * hidden_mean + 1)
                hs_ = Fourier_filter(hs_, threshold=1, scale=self.s2)
            # ---------------------------------------------------------

            h = torch.cat([h, hs_], dim=1)
            h = module(h, emb, context)
        h = h.type(x.dtype)
        if self.predict_codebook_ids:
            return self.id_predictor(h)
        else:
            return self.out(h)
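
The two FreeU operations above can be tried in isolation on random tensors. The following is a minimal sketch, not the repository's code: fourier_filter and backbone_scale_map are hypothetical helper names, the tensor shapes are arbitrary, and the filter builds its mask on the input's device so it also runs on CPU.

```python
import torch
import torch.fft as fft


def fourier_filter(x, threshold, scale):
    """Scale the centered low-frequency band of x (hypothetical, device-agnostic helper)."""
    x_freq = fft.fftshift(fft.fftn(x, dim=(-2, -1)), dim=(-2, -1))
    B, C, H, W = x_freq.shape
    mask = torch.ones((B, C, H, W), device=x.device)
    crow, ccol = H // 2, W // 2
    mask[..., crow - threshold:crow + threshold, ccol - threshold:ccol + threshold] = scale
    x_freq = fft.ifftshift(x_freq * mask, dim=(-2, -1))
    return fft.ifftn(x_freq, dim=(-2, -1)).real


def backbone_scale_map(h, b):
    """Per-pixel multiplier in [1, b] built from the min-max normalized channel mean."""
    hidden_mean = h.mean(1, keepdim=True)            # [B, 1, H, W]
    B = hidden_mean.shape[0]
    flat = hidden_mean.reshape(B, -1)
    hidden_max = flat.max(dim=-1, keepdim=True)[0].reshape(B, 1, 1, 1)
    hidden_min = flat.min(dim=-1, keepdim=True)[0].reshape(B, 1, 1, 1)
    hidden_mean = (hidden_mean - hidden_min) / (hidden_max - hidden_min)
    return (b - 1) * hidden_mean + 1


torch.manual_seed(0)
skip = torch.randn(2, 8, 16, 16)       # stand-in for a skip feature map (hs_)
backbone = torch.randn(2, 8, 16, 16)   # stand-in for a backbone feature map (h)

identity = fourier_filter(skip, threshold=1, scale=1.0)   # scale=1 leaves the input unchanged
damped = fourier_filter(skip, threshold=1, scale=0.2)     # scale<1 damps the low frequencies
scale_map = backbone_scale_map(backbone, b=1.4)

# Apply the map to the first half of the channels, as the forward pass above does
scaled = backbone.clone()
scaled[:, :4] = scaled[:, :4] * scale_map
```

With threshold=1 the masked band includes the zero-frequency term, so the per-channel spatial mean of damped is 0.2 times that of skip, while the scale map ranges from 1 at the weakest location to b at the strongest.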

Parameters

Adjust these parameters to suit your model, image/video style, or task. The following settings are a starting point.

SD1.4: (will be updated soon)

b1: 1.3, b2: 1.4, s1: 0.9, s2: 0.2

SD1.5: (will be updated soon)

b1: 1.5, b2: 1.6, s1: 0.9, s2: 0.2

SD2.1

b1: 1.1, b2: 1.2, s1: 0.9, s2: 0.2

b1: 1.4, b2: 1.6, s1: 0.9, s2: 0.2

SDXL

b1: 1.3, b2: 1.4, s1: 0.9, s2: 0.2 (see SDXL results)

Range for More Parameters

When trying additional parameters, consider the following ranges:

  • b1: 1 ≤ b1 ≤ 1.2
  • b2: 1.2 ≤ b2 ≤ 1.6
  • s1: s1 ≤ 1
  • s2: s2 ≤ 1
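
When sweeping parameters, it can help to flag values that fall outside these suggested ranges. The sketch below is only an illustration: SUGGESTED_RANGES and out_of_range are hypothetical names, and a flagged value is not an error, just a setting outside the range listed above.

```python
# Suggested exploration ranges from the list above; None means no bound on that side.
SUGGESTED_RANGES = {
    "b1": (1.0, 1.2),
    "b2": (1.2, 1.6),
    "s1": (None, 1.0),
    "s2": (None, 1.0),
}


def out_of_range(params):
    """Return the names of parameters that fall outside the suggested ranges."""
    flagged = []
    for name, (low, high) in SUGGESTED_RANGES.items():
        value = params[name]
        if (low is not None and value < low) or (high is not None and value > high):
            flagged.append(name)
    return flagged


# Two of the parameter sets listed in this README
sd21 = {"b1": 1.1, "b2": 1.2, "s1": 0.9, "s2": 0.2}
sdxl = {"b1": 1.3, "b2": 1.4, "s1": 0.9, "s2": 0.2}

print(out_of_range(sd21))  # []
print(out_of_range(sdxl))  # ['b1']
```

Note that some of the listed settings, such as b1 = 1.3 for SDXL, sit above the suggested b1 range, so treat the ranges as a guide for exploration and not as hard limits.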

Results from the community

If you have tried FreeU and want to share your results, let me know and we can add a link here.

BibTeX

@inproceedings{si2023freeu,
  title={FreeU: Free Lunch in Diffusion U-Net},
  author={Si, Chenyang and Huang, Ziqi and Jiang, Yuming and Liu, Ziwei},
  booktitle={CVPR},
  year={2024}
}

🗞️ License

Distributed under the MIT License. See LICENSE for more information.