- Built on top of the StyleGAN2-ADA implementation
**Compatibility**
- Compatible with old network pickles created using the TensorFlow version.
- New ZIP/PNG based dataset format for maximal interoperability with existing 3rd party tools.
- TFRecords datasets are no longer supported — they need to be converted to the new format.
- New JSON-based format for logs, metrics, and training curves.
- Training curves are also exported in the old TFEvents format if TensorBoard is installed.
- Command line syntax is mostly unchanged, with a few exceptions (e.g., dataset_tool.py).
- Comparison methods are not supported (--cmethod, --dcap, --cfg=cifarbaseline, --aug=adarv).
- Truncation is now disabled by default.
Path | Description |
---|---|
stylegan2-ada-pytorch | Main directory hosted on Amazon S3 |
├ ada-paper.pdf | Paper PDF |
├ images | Curated example images produced using the pre-trained models |
├ videos | Curated example interpolation videos |
└ pretrained | Pre-trained models |
├ ffhq.pkl | FFHQ at 1024x1024, trained using original StyleGAN2 |
├ brecahad.pkl | BreCaHAD at 512x512, trained from scratch using ADA |
├ paper-fig7c-training-set-sweeps | Models used in Fig.7c (sweep over training set size) |
├ paper-fig11a-small-datasets | Models used in Fig.11a (small datasets & transfer learning) |
├ paper-fig11b-cifar10 | Models used in Fig.11b (CIFAR-10) |
├ transfer-learning-source-nets | Models used as starting point for transfer learning |
└ metrics | Feature detectors used by the quality metrics |
Pre-trained networks are stored as *.pkl files that can be referenced using local filenames or URLs:
# Generate curated MetFaces images without truncation (Fig.10 left)
python generate.py --outdir=out --trunc=1 --seeds=85,265,297,849 \
--network=https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/metfaces.pkl
# Generate uncurated MetFaces images with truncation (Fig.12 upper left)
python generate.py --outdir=out --trunc=0.7 --seeds=600-605 \
--network=https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/metfaces.pkl
# Generate class conditional CIFAR-10 images (Fig.17 left, Car)
python generate.py --outdir=out --seeds=0-35 --class=1 \
--network=https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/cifar10.pkl
# Style mixing example
python style_mixing.py --outdir=out --rows=85,100,75,458,1500 --cols=55,821,1789,293 \
--network=https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/metfaces.pkl
Outputs from the above commands are placed under out/*.png, controlled by --outdir.
Downloaded network pickles are cached under $HOME/.cache/dnnlib, which can be overridden by setting the DNNLIB_CACHE_DIR environment variable.
The default PyTorch extension build directory is $HOME/.cache/torch_extensions, which can be overridden by setting TORCH_EXTENSIONS_DIR.
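For instance, here is a minimal sketch of loading a pickle by URL from Python so that it lands in that cache. The cache path is illustrative, and the snippet assumes the repository's dnnlib and torch_utils packages are importable:
import os
import pickle

import dnnlib  # from this repository; must be on PYTHONPATH (along with torch_utils)

# Hypothetical cache location; DNNLIB_CACHE_DIR is consulted when the download happens.
os.environ['DNNLIB_CACHE_DIR'] = '/data/cache/dnnlib'

url = 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/metfaces.pkl'
with dnnlib.util.open_url(url) as f:   # downloads once, then reuses the cached copy
    G = pickle.load(f)['G_ema']        # generator with averaged weights (on the CPU)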
To find the matching latent vector for a given image file, run:
python projector.py --outdir=out --target=~/mytargetimg.png \
--network=https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/ffhq.pkl
For optimal results, the target image should be cropped and aligned similar to the FFHQ dataset. The above command saves the projection target out/target.png, result out/proj.png, latent vector out/projected_w.npz, and progression video out/proj.mp4. You can render the resulting latent vector by specifying --projected_w for generate.py:
python generate.py --outdir=out --projected_w=out/projected_w.npz \
--network=https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/ffhq.pkl
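To consume the projection result directly in Python instead of going through generate.py, a minimal sketch, assuming ffhq.pkl has been downloaded locally, the repository is on PYTHONPATH, and the .npz stores the latent under the key 'w' with shape [1, num_ws, w_dim]:
import pickle

import numpy as np
import PIL.Image
import torch

with open('ffhq.pkl', 'rb') as f:            # same network that was used for projection
    G = pickle.load(f)['G_ema'].cuda()

ws = np.load('out/projected_w.npz')['w']     # assumed key: 'w'
ws = torch.tensor(ws, device='cuda')

img = G.synthesis(ws, noise_mode='const')    # NCHW, float32, dynamic range [-1, +1]
img = (img.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255).to(torch.uint8)
PIL.Image.fromarray(img[0].cpu().numpy(), 'RGB').save('out/proj_from_w.png')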
You can use pre-trained networks in your own Python code as follows:
import pickle
import torch

with open('ffhq.pkl', 'rb') as f:
    G = pickle.load(f)['G_ema'].cuda()  # torch.nn.Module
z = torch.randn([1, G.z_dim]).cuda()    # latent codes
c = None                                # class labels (not used in this example)
img = G(z, c)                           # NCHW, float32, dynamic range [-1, +1]
The above code requires torch_utils and dnnlib to be accessible via PYTHONPATH. It does not need source code for the networks themselves; their class definitions are loaded from the pickle via torch_utils.persistence.
The pickle contains three networks. 'G' and 'D' are instantaneous snapshots taken during training, and 'G_ema' represents a moving average of the generator weights over several training steps. The networks are regular instances of torch.nn.Module, with all of their parameters and buffers placed on the CPU at import and gradient computation disabled by default.
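A short sketch of pulling the individual networks out of the pickle and moving one onto the GPU; variable names are illustrative, and the same PYTHONPATH requirement applies:
import pickle

import torch

with open('ffhq.pkl', 'rb') as f:
    data = pickle.load(f)

G     = data['G']      # generator snapshot taken during training
D     = data['D']      # discriminator snapshot taken during training
G_ema = data['G_ema']  # generator with exponentially averaged weights

# The modules arrive on the CPU with gradients disabled; opt in explicitly if needed.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
G_ema = G_ema.to(device)
G_ema.requires_grad_(True)   # only needed if you intend to backpropagate through it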
The generator consists of two submodules, G.mapping and G.synthesis, that can be executed separately. They also support various additional options:
w = G.mapping(z, c, truncation_psi=0.5, truncation_cutoff=8)
img = G.synthesis(w, noise_mode='const', force_fp32=True)
Please refer to generate.py, style_mixing.py, and projector.py for further examples.
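For class-conditional models, the label c is expected as a one-hot vector of shape [1, G.c_dim] instead of None. A minimal sketch using the CIFAR-10 model, assuming cifar10.pkl has been downloaded locally; class index 1 corresponds to the car images generated earlier:
import pickle

import torch

with open('cifar10.pkl', 'rb') as f:         # class-conditional CIFAR-10 generator
    G = pickle.load(f)['G_ema'].cuda()

z = torch.randn([1, G.z_dim]).cuda()         # latent code
c = torch.zeros([1, G.c_dim]).cuda()         # one-hot class label
c[:, 1] = 1                                  # class index 1 (car)
w = G.mapping(z, c, truncation_psi=1)        # truncation disabled, matching the default
img = G.synthesis(w, noise_mode='const')     # NCHW, float32, dynamic range [-1, +1]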
Datasets are stored as uncompressed ZIP archives containing uncompressed PNG files and a metadata file dataset.json for labels.
Custom datasets can be created from a folder containing images; see python dataset_tool.py --help for more information. Alternatively, the folder can also be used directly as a dataset, without running it through dataset_tool.py first, but doing so may lead to suboptimal performance.
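A rough sketch of building such an archive by hand; the file names are illustrative, and the assumed dataset.json layout is a "labels" list of [filename, class index] pairs (omit or null the labels for an unlabeled dataset):
import json
import zipfile

# Assumed metadata layout: {"labels": [[filename, class_index], ...]} or {"labels": None}.
metadata = {
    'labels': [
        ['img00000000.png', 0],   # illustrative file names
        ['img00000001.png', 3],
    ]
}

# A dataset is an uncompressed ZIP of uncompressed PNGs plus dataset.json.
with zipfile.ZipFile('mydataset.zip', 'w', compression=zipfile.ZIP_STORED) as zf:
    zf.writestr('dataset.json', json.dumps(metadata))
    # zf.write('img00000000.png')  # add the actual image files alongside the metadata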
Legacy TFRecords datasets are not supported — see below for instructions on how to convert them.
FFHQ:
Step 1: Download the Flickr-Faces-HQ dataset as TFRecords.
Step 2: Extract images from TFRecords using dataset_tool.py from the TensorFlow version of StyleGAN2-ADA:
# Using dataset_tool.py from TensorFlow version at
# https://github.com/NVlabs/stylegan2-ada/
python ../stylegan2-ada/dataset_tool.py unpack \
--tfrecord_dir=~/ffhq-dataset/tfrecords/ffhq --output_dir=/tmp/ffhq-unpacked
Step 3: Create a ZIP archive using dataset_tool.py from this repository:
# Original 1024x1024 resolution.
python dataset_tool.py --source=/tmp/ffhq-unpacked --dest=~/datasets/ffhq.zip
# Scaled down 256x256 resolution.
python dataset_tool.py --source=/tmp/ffhq-unpacked --dest=~/datasets/ffhq256x256.zip \
--width=256 --height=256
In its most basic form, training new networks boils down to:
python train.py --outdir=~/training-runs --data=~/mydataset.zip --gpus=1 --dry-run
python train.py --outdir=~/training-runs --data=~/mydataset.zip --gpus=1
The first command is optional; it validates the arguments, prints out the training configuration, and exits. The second command kicks off the actual training.
In this example, the results are saved to a newly created directory ~/training-runs/<ID>-mydataset-auto1, controlled by --outdir. The training exports network pickles (network-snapshot-<INT>.pkl) and example images (fakes<INT>.png) at regular intervals (controlled by --snap). For each pickle, it also evaluates FID (controlled by --metrics) and logs the resulting scores in metric-fid50k_full.jsonl (as well as TFEvents if TensorBoard is installed).
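Because the metric log is plain JSON Lines, it is easy to post-process. A small sketch, assuming each line is a JSON object whose "results" field maps the metric name to its score:
import json

# Assumption: one JSON object per line, FID stored under results['fid50k_full'].
fids = []
with open('metric-fid50k_full.jsonl') as f:
    for line in f:
        record = json.loads(line)
        fids.append(record['results']['fid50k_full'])

print(f'snapshots evaluated: {len(fids)}, best FID: {min(fids):.2f}')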
The name of the output directory reflects the training configuration. For example, 00000-mydataset-auto1 indicates that the base configuration was auto1, meaning that the hyperparameters were selected automatically for training on one GPU. The base configuration is controlled by --cfg:
Base config | Description |
---|---|
auto (default) | Automatically select reasonable defaults based on resolution and GPU count. Serves as a good starting point for new datasets but does not necessarily lead to optimal results. |
stylegan2 | Reproduce results for StyleGAN2 config F at 1024x1024 using 1, 2, 4, or 8 GPUs. |
paper256 | Reproduce results for FFHQ and LSUN Cat at 256x256 using 1, 2, 4, or 8 GPUs. |
paper512 | Reproduce results for BreCaHAD and AFHQ at 512x512 using 1, 2, 4, or 8 GPUs. |
paper1024 | Reproduce results for MetFaces at 1024x1024 using 1, 2, 4, or 8 GPUs. |
cifar | Reproduce results for CIFAR-10 (tuned configuration) using 1 or 2 GPUs. |
The training configuration can be further customized with additional command line options:
- --aug=noaug disables ADA.
- --cond=1 enables class-conditional training (requires a dataset with labels).
- --mirror=1 amplifies the dataset with x-flips. Often beneficial, even with ADA.
- --resume=ffhq1024 --snap=10 performs transfer learning from FFHQ trained at 1024x1024.
- --resume=~/training-runs/<NAME>/network-snapshot-<INT>.pkl resumes a previous training run.
- --gamma=10 overrides R1 gamma. We recommend trying a couple of different values for each new dataset.
- --aug=ada --target=0.7 adjusts ADA target value (default: 0.6).
- --augpipe=blit enables pixel blitting but disables all other augmentations.
- --augpipe=bgcfnc enables all available augmentations (blit, geom, color, filter, noise, cutout).
Please refer to python train.py --help for the full list.
References:
- GANs Trained by a Two Time-Scale Update Rule Converge to a Local Nash Equilibrium, Heusel et al. 2017
- Demystifying MMD GANs, Bińkowski et al. 2018
- Improved Precision and Recall Metric for Assessing Generative Models, Kynkäänniemi et al. 2019
- Improved Techniques for Training GANs, Salimans et al. 2016
- A Style-Based Generator Architecture for Generative Adversarial Networks, Karras et al. 2018
@inproceedings{Karras2020ada,
title = {Training Generative Adversarial Networks with Limited Data},
author = {Tero Karras and Miika Aittala and Janne Hellsten and Samuli Laine and Jaakko Lehtinen and Timo Aila},
booktitle = {Proc. NeurIPS},
year = {2020}
}
@inproceedings{zhao2021comodgan,
title={Large Scale Image Completion via Co-Modulated Generative Adversarial Networks},
author={Zhao, Shengyu and Cui, Jonathan and Sheng, Yilun and Dong, Yue and Liang, Xiao and Chang, Eric I and Xu, Yan},
booktitle={International Conference on Learning Representations (ICLR)},
year={2021}
}