@inproceedings{liu2021content,
title = {Content-Aware GAN Compression},
author = {Liu, Yuchen and Shu, Zhixin and Li, Yijun and Lin, Zhe and Perazzi, Federico and Kung, S.Y.},
booktitle = {IEEE Conference on Computer Vision and Pattern Recognition (CVPR)},
year = {2021},
}
We propose a novel content-aware approach for GAN compression. With content-awareness, our 11x-accelerated GAN performs comparably to the full-size model on image generation and image editing.
We show an example above of the generative ability of our 11x-accelerated generator vs. the full-size one. In particular, our model generates contents of interest that are visually comparable to those of the full-size model.
We show an example above typifying the effectiveness of our compressed StyleGAN2 for image style-mixing and morphing. When we mix middle styles from B, the original full-size model suffers a significant identity loss, while our approach better preserves the person's identity. We also observe that our morphed images have a smoother expression transition in the beard region compared to the full-size model, substantiating our advantage in latent-space smoothness.
We provide an additional example above.
In our work, we make the first attempt to bring content awareness into channel pruning and knowledge distillation.
Specifically, we leverage a content-parsing network to identify contents of interest (COI), a set of spatial locations with salient semantic concepts, within the generated images. We design a content-aware pruning metric (with a forward and backward path) to remove channels that are least sensitive to the COI in the generated images. For knowledge distillation, we restrict the distillation region to the COI of the teacher's outputs, which further enhances the distillation of target contents.
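To make the COI-restricted distillation concrete, here is a minimal sketch of a distillation loss masked to the contents of interest. The names `parsing_net`, `coi_labels`, and the choice of an L1 pixel loss are our illustrative assumptions, not the repository's exact implementation:

```python
import torch

def coi_distillation_loss(student_img, teacher_img, parsing_net, coi_labels):
    """Distill the student only inside the contents of interest (COI).

    student_img, teacher_img: (N, 3, H, W) generator outputs.
    parsing_net: content-parsing network returning per-pixel class logits.
    coi_labels: semantic class ids treated as contents of interest.
    """
    with torch.no_grad():
        seg = parsing_net(teacher_img).argmax(dim=1)   # (N, H, W) class map
        mask = torch.zeros_like(seg, dtype=torch.float32)
        for c in coi_labels:                           # union of COI classes
            mask = mask + (seg == c).float()
        mask = mask.clamp(max=1.0).unsqueeze(1)        # (N, 1, H, W)
    # L1 distillation restricted to COI pixels (loss choice is an assumption)
    diff = (student_img - teacher_img).abs() * mask
    return diff.sum() / mask.sum().clamp(min=1.0)
```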
We have tested our code under the following environment:
python == 3.6.5
pytorch == 1.6.0
torchvision == 0.7.0
CUDA == 10.2
To start with, you can first download a full-size generator checkpoint from:
and place it under the folder ./Model/full_size_model/.
Once you get the full-size checkpoint, you can prune the generator by:
python3 prune.py \
--generated_img_size=256 \
--ckpt=/path/to/full/size/model/ \
--remove_ratio=0.7 \
--info_print
We adopt a uniform channel-pruning ratio for every layer. The above procedure removes 70% of the channels in each layer of the generator. The pruned checkpoint will be saved at ./Model/pruned_model/.
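As a quick sanity check on the arithmetic, the sketch below shows what a uniform 70% ratio does to per-layer channel counts, and how it relates to the quoted 11x acceleration (the layer widths are hypothetical, not the actual StyleGAN2 configuration):

```python
# Uniform pruning: every layer keeps the same fraction of channels.
remove_ratio = 0.7
layer_widths = [512, 512, 256, 128, 64]  # hypothetical channel counts

for width in layer_widths:
    kept = round(width * (1 - remove_ratio))
    print(f"{width:4d} channels -> {kept:3d} kept")

# With input and output channels both reduced to 30%, each conv's
# multiply-accumulates scale by 0.3 * 0.3 = 0.09, i.e. roughly the
# 11x acceleration quoted above.
```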
We then retrain the pruned generator by:
python3 train.py \
--size=256 \
--path=/path/to/ffhq/data/folder/ \
--ckpt=/path/to/pruned/model/ \
--teacher_ckpt=/path/to/full/size/model/ \
--iter=450001 \
--batch_size=16
You may adjust the variables gpu_device_ids and primary_device for the GPU setup in train_hyperparams.py.
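For reference, the GPU setup inside train_hyperparams.py might look like the following; the values shown are placeholders for a 2-GPU run, not the file's defaults:

```python
# train_hyperparams.py (excerpt; placeholder values, not the defaults)
gpu_device_ids = [0, 1]      # GPU indices used for training
primary_device = 'cuda:0'    # device hosting the primary model replica
```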
The time for retraining 11x-compressed models on V100 GPUs:
| Model | Batch Size | Iterations | # GPUs | Time (Hours) |
|---|---|---|---|---|
| 256px StyleGAN2 | 16 | 450k | 2 | 131 |
| 1024px StyleGAN2 | 16 | 450k | 4 | 251 |
A typical training curve for the 11x-compressed 256px StyleGAN2:
To evaluate the model quantitatively, we provide get_fid.py and get_ppl.py to compute the model's FID and PPL scores.
FID Evaluation:
python3 get_fid.py \
--generated_img_size=256 \
--ckpt=/path/to/model/ \
--n_sample=50000 \
--batch_size=64 \
--info_print
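As background, FID is the Fréchet distance between two Gaussians fitted to Inception features of real and generated images. A minimal sketch of that closed-form distance, assuming the feature means and covariances have already been extracted (the function name is ours):

```python
import numpy as np
from scipy import linalg

def frechet_distance(mu1, sigma1, mu2, sigma2):
    """Fréchet distance between N(mu1, sigma1) and N(mu2, sigma2),
    the Gaussians fitted to Inception features of the two image sets."""
    diff = mu1 - mu2
    covmean, _ = linalg.sqrtm(sigma1 @ sigma2, disp=False)  # matrix sqrt
    covmean = covmean.real  # discard tiny imaginary parts from numerics
    return diff @ diff + np.trace(sigma1 + sigma2 - 2.0 * covmean)
```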
PPL Evaluation:
python3 get_ppl.py \
--generated_img_size=256 \
--ckpt=/path/to/model/ \
--n_sample=5000 \
--eps=1e-4 \
--info_print
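As background on the --eps flag: PPL measures the LPIPS distance between images generated from two latents an epsilon-step apart along an interpolation path, scaled by 1/eps². A minimal one-sample sketch, assuming a generator g and the lpips package (linear interpolation is used here for brevity):

```python
import torch
import lpips  # pip install lpips; perceptual distance used by PPL

percept = lpips.LPIPS(net='vgg')

def ppl_sample(g, z1, z2, eps=1e-4):
    """One-sample estimate of perceptual path length for generator g."""
    t = torch.rand(1).item()                  # random point on the path
    img_a = g(torch.lerp(z1, z2, t))
    img_b = g(torch.lerp(z1, z2, t + eps))
    # LPIPS distance of an eps-step, scaled by 1/eps^2 (PPL definition)
    return percept(img_a, img_b) / (eps ** 2)
```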
We also provide an image projector which returns a (real image, projected image) pair in Image_Projection_Visualization.png, as well as the PSNR and LPIPS scores between this pair:
python3 get_projected_image.py \
--generated_img_size=256 \
--ckpt=/path/to/model/ \
--image_file=/path/to/an/RGB/image/ \
--num_iters=800 \
--info_print
An example of Image_Projection_Visualization.png projected by a full-size 256px StyleGAN2:
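Under the hood, projectors of this kind typically optimize a latent code until the generated image matches the target; the sketch below illustrates such a loop. The latent dimensionality, loss weights, and optimizer settings are our assumptions, not the exact settings of get_projected_image.py:

```python
import torch
import torch.nn.functional as F
import lpips  # pip install lpips

def project(g, target, num_iters=800, lr=0.1):
    """Optimize a latent so that g(latent) reproduces `target` (1, 3, H, W).
    The 512-dim latent and loss weighting are illustrative assumptions."""
    percept = lpips.LPIPS(net='vgg').to(target.device)
    latent = torch.randn(1, 512, device=target.device, requires_grad=True)
    opt = torch.optim.Adam([latent], lr=lr)
    for _ in range(num_iters):
        img = g(latent)
        # Perceptual term plus a small pixel-wise term for stability
        loss = percept(img, target).mean() + 0.1 * F.mse_loss(img, target)
        opt.zero_grad()
        loss.backward()
        opt.step()
    return latent.detach()
```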
We provide the Helen-Set55 dataset on Google Drive.
We provide the following checkpoints of our content-aware compressed StyleGAN2:
Our implementation builds upon the following projects:
PyTorch StyleGAN2: https://github.com/rosinality/stylegan2-pytorch
Face Parsing BiSeNet: https://github.com/zllrunning/face-parsing.PyTorch
Fréchet Inception Distance: https://github.com/mseitzer/pytorch-fid
Learned Perceptual Image Patch Similarity: https://github.com/richzhang/PerceptualSimilarity