Unofficial implementation of MaskGIT in PyTorch and PyTorch Lightning.
Samples were generated with a model trained for 100 epochs on ImageNet, tokenized with a VQGAN (f=16, codebook size 1024) from taming-transformers. Images were downscaled to 128x128 pixels for faster training.
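With a downsampling factor f=16, a 128x128 image becomes a 128/16 = 8 by 8 grid of VQGAN codes. That is 64 code positions per image, and each position holds an index into the 1024-entry codebook. The arithmetic can be sketched as below. The helper name `token_grid` is hypothetical and is not part of this repository.

```python
def token_grid(image_size: int, downsample_factor: int, codebook_size: int) -> dict:
    """Latent grid shape and code count for a square image (illustrative helper)."""
    if image_size % downsample_factor != 0:
        raise ValueError("image_size must be divisible by downsample_factor")
    side = image_size // downsample_factor  # side length of the latent code grid
    return {
        "grid": (side, side),
        "num_tokens": side * side,    # number of code positions per image
        "vocab_size": codebook_size,  # each code is an index in [0, codebook_size)
    }


result = token_grid(128, 16, 1024)
print(result)
# → {'grid': (8, 8), 'num_tokens': 64, 'vocab_size': 1024}
```

The transformer sees these grids, not raw pixels. This is why tokenizing the dataset once, ahead of training, keeps training cheap.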
Example reconstructions of the original data for comparison:
Install dependencies
poetry install
Tokenize dataset
Download a VQGAN checkpoint and its config from https://github.com/CompVis/taming-transformers into models/<model-name>/, as config.yaml and model.ckpt (the paths used in the command below).
Encode dataset:
poetry run python extract_latents.py \
--config_path=models/<model-name>/config.yaml \
--ckpt_path=models/<model-name>/model.ckpt \
--data_root=<path-to-images> \
--save_path=<codes-path>
<path-to-images> has to contain train, val and test folders.
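Before encoding, it can help to confirm that the data root has the expected split folders. The sketch below checks only that `train`, `val` and `test` exist as directories. It does not check what `extract_latents.py` expects inside them (for example, class subfolders). The function name `missing_splits` is a hypothetical helper, not part of the repository.

```python
from pathlib import Path
from typing import List

# Split folders the README says <path-to-images> must contain.
REQUIRED_SPLITS = ("train", "val", "test")


def missing_splits(data_root: str) -> List[str]:
    """Return the required split folders that are absent under data_root."""
    root = Path(data_root)
    return [split for split in REQUIRED_SPLITS if not (root / split).is_dir()]
```

For example, `missing_splits("/data/imagenet128")` returns an empty list when all three folders are present. Otherwise it returns the names of the folders that are missing.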
Run training:
poetry run python train.py fit \
--config=config/config.yaml \
--config=config/data/<data-config>.yaml \
--config=config/models/transformer/base.yaml \
--data.root=<codes-path> \
--trainer.accelerator=gpu --trainer.devices=1 --trainer.precision=16
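The training objective comes from the MaskGIT paper. For each image, sample a ratio r uniformly from [0, 1). Then replace ceil(γ(r)·N) of the N code positions with a special mask token, using the cosine schedule γ(r) = cos(πr/2). The model is trained to predict the original codes at the masked positions. The sketch below illustrates that masking step in plain Python under stated assumptions. It is not taken from this repository's `train.py`, which works on batched PyTorch tensors and may differ in details. Using 1024 (one past the last codebook index) as the mask token id is an assumption made for illustration.

```python
import math
import random


def cosine_mask_ratio(r: float) -> float:
    """Cosine schedule gamma(r) = cos(pi * r / 2); maps r in [0, 1) to (0, 1]."""
    return math.cos(math.pi * r / 2)


def random_mask(tokens, mask_token_id, rng=random):
    """Mask a schedule-dependent number of positions in a token sequence.

    Returns the masked sequence and a boolean list marking masked positions,
    i.e. where a cross-entropy loss against the original tokens would apply.
    """
    n = len(tokens)
    r = rng.random()  # r ~ U[0, 1)
    num_masked = max(1, math.ceil(cosine_mask_ratio(r) * n))
    positions = set(rng.sample(range(n), num_masked))
    masked = [mask_token_id if i in positions else t for i, t in enumerate(tokens)]
    is_masked = [i in positions for i in range(n)]
    return masked, is_masked


# Usage: one 8x8 grid of codes from a 1024-entry codebook.
rng = random.Random(0)
codes = [rng.randrange(1024) for _ in range(64)]
# Assumption: the mask token id is 1024, one past the last codebook index.
masked, is_masked = random_mask(codes, mask_token_id=1024, rng=rng)
print(sum(is_masked), "of", len(codes), "positions masked")
```

The cosine schedule is biased toward high mask ratios. The MaskGIT paper reports that this choice works better than a linear schedule.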
Logging to W&B can be enabled by adding the following line to the training command:
--config=config/loggers/wandb.yaml
TODO
- Improve sampling
- Large scale training
Reference
Chang, Huiwen, Han Zhang, Lu Jiang, Ce Liu, and William T. Freeman. “MaskGIT: Masked Generative Image Transformer.” arXiv, February 8, 2022. https://doi.org/10.48550/arXiv.2202.04200.