PyTorch implementation of DivCo. We propose a simple yet effective regularization term, named latent-augmented contrastive loss, that can be applied to arbitrary conditional generative adversarial networks (GANs) across different tasks to alleviate mode collapse and improve the diversity of generated samples.
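The latent-augmented contrastive loss can be read as an InfoNCE-style objective. The framework-free sketch below illustrates that reading under stated assumptions and is not the repository's implementation. For a fixed condition, the image generated from latent code z is the anchor, the image generated from a nearby code z+ is the positive, and images generated from other codes are negatives. Features are plain lists standing in for the output of some feature extractor (which network provides them is an assumption here). The temperature `tau`, the perturbation `radius`, and all function names are hypothetical.

```python
import math
import random


def cosine_similarity(a, b):
    """Cosine similarity between two equal-length, non-zero feature vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b)


def latent_augmented_contrastive_loss(anchor, positive, negatives, tau=1.0):
    """InfoNCE-style loss (hypothetical sketch).

    anchor:    features of the image generated from latent code z
    positive:  features of the image generated from a nearby code z+
    negatives: list of features of images generated from other codes z-
    tau:       temperature (hypothetical default)
    """
    pos = math.exp(cosine_similarity(anchor, positive) / tau)
    neg = sum(math.exp(cosine_similarity(anchor, n) / tau) for n in negatives)
    # Low when the anchor is closer to its positive than to the negatives.
    return -math.log(pos / (pos + neg))


def sample_positive_latent(z, radius=0.1, rng=random):
    """Hypothetical positive sampling: shift each coordinate of z uniformly
    within [-radius, radius]."""
    return [x + rng.uniform(-radius, radius) for x in z]
```

With `tau=1.0`, an anchor identical to its positive and orthogonal to its single negative gives a loss of log(1 + e^-1); swapping the positive and the negative raises the loss to log(1 + e).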
Contact: Rui Liu
DivCo: Diverse Conditional Image Synthesis via Contrastive Generative Adversarial Network
Rui Liu, Yixiao Ge, Ching Lam Choi, Xiaogang Wang, and Hongsheng Li
IEEE Conference on Computer Vision and Pattern Recognition (CVPR), 2021
[arXiv]
If you find DivCo useful in your research, please consider citing:
@inproceedings{Liu_DivCo,
author = {Liu, Rui and Ge, Yixiao and Choi, Ching Lam and Wang, Xiaogang and Li, Hongsheng},
booktitle = {IEEE Conference on Computer Vision and Pattern Recognition},
title = {DivCo: Diverse Conditional Image Synthesis via Contrastive Generative Adversarial Network},
year = {2021}
}
Requirements:
- Python >= 3.6
- PyTorch >= 0.4.0 and the corresponding torchvision (https://pytorch.org/)
Clone this repo:
git clone https://github.com/ruiliu-ai/DivCo.git
Create a datasets folder and download the datasets for each task into it:
mkdir datasets
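The training and test commands pass --dataroot paths of the form ./datasets/<name>. As a convenience, and not part of the repository, the hypothetical helper below reports which of the dataset folders named in this README are missing. The root directory is a parameter because where `datasets` sits relative to each DivCo-* folder is left implicit here.

```python
from pathlib import Path

# Dataset folder names taken from the --dataroot flags in this README.
EXPECTED_DATASETS = {
    "DivCo-DCGAN": ["Cifar10"],
    "DivCo-BicycleGAN": ["facades", "maps"],
    "DivCo-DRIT": ["cat2dog", "yosemite"],
}


def missing_datasets(datasets_root="datasets"):
    """Return the expected dataset folder names not found under datasets_root."""
    root = Path(datasets_root)
    return [
        name
        for names in EXPECTED_DATASETS.values()
        for name in names
        if not (root / name).is_dir()
    ]


if __name__ == "__main__":
    missing = missing_datasets()
    print("All datasets found." if not missing else "Missing: " + ", ".join(missing))
```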
- Dataset: CIFAR-10
- Baseline: DCGAN
cd DivCo/DivCo-DCGAN
python train.py --dataroot ./datasets/Cifar10
- Paired Data: facades and maps
- Baseline: BicycleGAN
You can download the facades and maps datasets from the BicycleGAN [GitHub Project].
We adopt the BicycleGAN network architecture and follow its training procedure.
cd DivCo/DivCo-BicycleGAN
python train.py --dataroot ./datasets/facades
- Unpaired Data: Yosemite (summer <-> winter) and Cat2Dog (cat <-> dog)
- Baseline: DRIT
You can download the datasets from the DRIT [GitHub Project].
Specify --concat 0 for Cat2Dog to handle translations with large shape variation.
cd DivCo/DivCo-DRIT
python train.py --dataroot ./datasets/cat2dog --concat 0 --lambda_contra 0.1
python train.py --dataroot ./datasets/yosemite --concat 1 --lambda_contra 1.0
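The two DRIT commands above use different values of --lambda_contra (0.1 for Cat2Dog, 1.0 for Yosemite). Assuming, as the description of DivCo as a regularization term suggests, that this flag weights the contrastive loss added to the baseline objective, the training objective can be written as follows, where \mathcal{L}_{\text{base}} stands for the unchanged DRIT losses (notation introduced here, not taken from the code):

```latex
\mathcal{L}_{\text{total}} = \mathcal{L}_{\text{base}} + \lambda_{\text{contra}} \, \mathcal{L}_{\text{contra}}
```

Under this reading, the smaller weight used for Cat2Dog gives the baseline losses relatively more influence than on Yosemite.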
Download the pre-trained models and save them into ./models/.
For BicycleGAN, DRIT, and MSGAN, please follow the instructions in the corresponding GitHub projects of the baseline frameworks for further evaluation details.
DivCo-DCGAN
python test.py --dataroot ./datasets/Cifar10 --resume ./models/DivCo-DCGAN/00199.pth
DivCo-BicycleGAN
python test.py --dataroot ./datasets/facades --checkpoints_dir ./models/DivCo-BicycleGAN/facades --epoch 400
python test.py --dataroot ./datasets/maps --checkpoints_dir ./models/DivCo-BicycleGAN/maps --epoch 400
DivCo-DRIT
python test.py --dataroot ./datasets/yosemite --resume ./models/DivCo-DRIT/yosemite/01199.pth --concat 1
python test.py --dataroot ./datasets/cat2dog --resume ./models/DivCo-DRIT/cat2dog/01199.pth --concat 0
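The evaluation commands above load checkpoints from ./models/. The hypothetical helper below, which is not part of the repository, lists which of those paths are missing relative to the directory the commands are run from (`base`). For BicycleGAN only the --checkpoints_dir folders are checked, since the file names inside them follow BicycleGAN's own conventions.

```python
from pathlib import Path

# Checkpoint files taken from the --resume flags above.
CHECKPOINT_FILES = [
    "models/DivCo-DCGAN/00199.pth",
    "models/DivCo-DRIT/yosemite/01199.pth",
    "models/DivCo-DRIT/cat2dog/01199.pth",
]
# BicycleGAN --checkpoints_dir folders; their contents are not inspected.
CHECKPOINT_DIRS = [
    "models/DivCo-BicycleGAN/facades",
    "models/DivCo-BicycleGAN/maps",
]


def missing_checkpoints(base="."):
    """Return the expected checkpoint files and folders that are absent under base."""
    base = Path(base)
    missing = [p for p in CHECKPOINT_FILES if not (base / p).is_file()]
    missing += [p for p in CHECKPOINT_DIRS if not (base / p).is_dir()]
    return missing
```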
- DCGAN [Paper]
- BicycleGAN [GitHub Project]
- DRIT [GitHub Project]
- MSGAN [GitHub Project]
- FID [GitHub Project]
- LPIPS [GitHub Project]
- NDB and JSD [GitHub Project]