This repository provides a minimalistic implementation of our approach to training large-scale diffusion models from scratch on an extremely low budget. In particular, using only 37M publicly available real and synthetic images, we train a 1.16 billion parameter sparse transformer with a total cost of only $1,890 and achieve an FID of 12.7 in zero-shot generation on the COCO dataset. Please find further details in the official paper: https://arxiv.org/abs/2407.15811.
Prompt: 'Image of an astronaut riding a horse in {} style'. Styles: Origami, Pixel art, Line art, Cyberpunk, Van Gogh Starry Night, Animation, Watercolor, Stained glass.
The current codebase enables the reproduction of our models, as it provides both the training code and the dataset code used in the paper. We also provide pre-trained model checkpoints for off-the-shelf generation.
Clone the repository and install micro_diffusion as a Python package using the following command.
pip install -e .
First, set up the necessary directories to store the datasets and training artifacts.
ln -s path_to_dir_that_stores_data ./datadir
ln -s path_to_dir_that_stores_trained_models ./trained_models
Next, download the training dataset by following datasets.md. It provides instructions on downloading and precomputing the image and caption latents for all five datasets used in the paper. Because downloading and precomputing latents for the full datasets can be time-consuming, datasets.md also describes how to download and use only a small (~1%) fraction of each dataset. We strongly recommend working with this small subset for initial experimentation.
We progressively train the models from low resolution to high resolution. We first train the model on 256×256 resolution images for 280K steps and then fine-tune the model for 55K steps on 512×512 resolution images. The estimated training time for the end-to-end model on an 8×H100 machine is 2.6 days. We provide the training configuration for each stage in the ./configs/ directory, the path passed to --config-path in the training commands.
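The schedule above can be turned into a few back-of-the-envelope numbers. The sketch below is only arithmetic on figures quoted in this README (the step counts, the batch size of 2,048 stated for all training steps, 8 GPUs, 2.6 days, $1,890, and 37M images). It assumes the quoted step counts cover all training at each resolution, and the per-GPU-hour rate it derives is merely implied by those figures, not a price stated by the authors.

```python
# Figures quoted in this README.
STEPS_256 = 280_000           # training steps at 256x256 resolution
STEPS_512 = 55_000            # training steps at 512x512 resolution
BATCH_SIZE = 2_048            # batch size used across all training steps
NUM_GPUS = 8                  # one 8xH100 machine
TRAIN_DAYS = 2.6              # estimated end-to-end training time
TOTAL_COST_USD = 1_890        # total training cost
DATASET_IMAGES = 37_000_000   # real + synthetic training images


def training_budget() -> dict:
    """Derive sample counts and GPU-hours from the quoted figures."""
    samples_256 = STEPS_256 * BATCH_SIZE
    samples_512 = STEPS_512 * BATCH_SIZE
    samples_total = samples_256 + samples_512
    gpu_hours = NUM_GPUS * TRAIN_DAYS * 24
    return {
        "samples_256": samples_256,
        "samples_512": samples_512,
        "samples_total": samples_total,
        # Average number of times each image is seen (assumption: uniform sampling).
        "passes_over_data": samples_total / DATASET_IMAGES,
        "gpu_hours": gpu_hours,
        # Implied rate only; not a price quoted in this README.
        "usd_per_gpu_hour": TOTAL_COST_USD / gpu_hours,
    }


if __name__ == "__main__":
    for name, value in training_budget().items():
        print(f"{name}: {value:,.2f}")
```

Under these assumptions the model sees roughly 686M samples in total, most of them at the cheaper 256×256 resolution.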
Patch masking: By default, our DiT model uses a patch-mixer before the backbone transformer. The patch-mixer significantly reduces the performance degradation caused by masking while still providing a large reduction in training time. We mask 75% of the patches after the patch-mixer at both resolutions. After training with masking, we perform a follow-up fine-tuning with a mask ratio of 0 to slightly improve performance. Overall, training consists of four stages, run in order, and each stage after the first loads the latest checkpoint of the previous one.
composer train.py --config-path ./configs --config-name res_256_pretrain.yaml exp_name=MicroDiTXL_mask_75_res_256_pretrain model.train_mask_ratio=0.75
composer train.py --config-path ./configs --config-name res_256_finetune.yaml exp_name=MicroDiTXL_mask_0_res_256_finetune model.train_mask_ratio=0.0 trainer.load_path=./trained_models/MicroDiTXL_mask_75_res_256_pretrain/latest-rank0.pt
composer train.py --config-path ./configs --config-name res_512_pretrain.yaml exp_name=MicroDiTXL_mask_75_res_512_pretrain model.train_mask_ratio=0.75 trainer.load_path=./trained_models/MicroDiTXL_mask_0_res_256_finetune/latest-rank0.pt
composer train.py --config-path ./configs --config-name res_512_finetune.yaml exp_name=MicroDiTXL_mask_0_res_512_finetune model.train_mask_ratio=0.0 trainer.load_path=./trained_models/MicroDiTXL_mask_75_res_512_pretrain/latest-rank0.pt
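The four commands above form a chain: each stage after the first passes the previous stage's latest-rank0.pt checkpoint through trainer.load_path. The sketch below makes that chaining explicit by rebuilding the same command strings from a stage table. The helper name build_commands and the STAGES table are illustrative and not part of the repository; the sketch only assembles strings and does not launch training.

```python
from typing import List, Optional

# (config file, experiment name, mask ratio), in training order,
# copied from the commands in this README.
STAGES = [
    ("res_256_pretrain.yaml", "MicroDiTXL_mask_75_res_256_pretrain", 0.75),
    ("res_256_finetune.yaml", "MicroDiTXL_mask_0_res_256_finetune", 0.0),
    ("res_512_pretrain.yaml", "MicroDiTXL_mask_75_res_512_pretrain", 0.75),
    ("res_512_finetune.yaml", "MicroDiTXL_mask_0_res_512_finetune", 0.0),
]


def build_commands(model_dir: str = "./trained_models") -> List[str]:
    """Return one composer command per stage, chaining checkpoints."""
    commands: List[str] = []
    previous_exp: Optional[str] = None
    for config_name, exp_name, mask_ratio in STAGES:
        cmd = (
            "composer train.py --config-path ./configs "
            f"--config-name {config_name} exp_name={exp_name} "
            f"model.train_mask_ratio={mask_ratio}"
        )
        # Every stage except the first resumes from the previous stage's weights.
        if previous_exp is not None:
            cmd += f" trainer.load_path={model_dir}/{previous_exp}/latest-rank0.pt"
        commands.append(cmd)
        previous_exp = exp_name
    return commands


if __name__ == "__main__":
    for command in build_commands():
        print(command)
```

Keeping the stages in a single table makes it harder to point a later stage at the wrong checkpoint when experiment names are changed.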
Across all steps, we use a batch size of 2,048, apply center cropping, and do not horizontally flip the images.
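To make the 75% patch-masking ratio concrete, the following minimal sketch counts patch tokens and draws a random subset to keep. It is not the repository's implementation. It assumes a patch size of 2 (suggested by the "_2" suffix in the model name, following the DiT naming convention), latent resolutions of 32 and 64 for 256×256 and 512×512 images (an 8× VAE downsampling factor), and uniform random masking. The function names are hypothetical.

```python
import random
from typing import List


def num_patches(latent_res: int, patch_size: int = 2) -> int:
    """Number of patch tokens for a square latent of side latent_res."""
    per_side = latent_res // patch_size
    return per_side * per_side


def sample_kept_patches(total: int, mask_ratio: float, seed: int = 0) -> List[int]:
    """Return the sorted indices of patches that survive random masking."""
    if not 0.0 <= mask_ratio < 1.0:
        raise ValueError("mask_ratio must be in [0, 1)")
    num_keep = total - int(total * mask_ratio)
    rng = random.Random(seed)
    # Sample without replacement so every kept index is unique.
    return sorted(rng.sample(range(total), num_keep))


if __name__ == "__main__":
    for latent_res in (32, 64):
        total = num_patches(latent_res)
        kept = sample_kept_patches(total, mask_ratio=0.75)
        print(f"latent {latent_res}x{latent_res}: {total} patches, {len(kept)} kept")
```

Under these assumptions, 75% masking leaves 64 of 256 tokens at 256×256 and 256 of 1,024 tokens at 512×512. Since masking is applied after the patch-mixer, only these kept tokens would reach the backbone transformer, which is where the reduction in training time comes from. With a mask ratio of 0, all tokens are kept.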
We release four pre-trained models on Hugging Face (HF). The table below provides download links and a description of each model.
Model Description | VAE (channels) | FID | GenEval Score | Download |
---|---|---|---|---|
MicroDiT_XL_2 trained on 22M real images | SDXL-VAE (4 channel) | 12.72 | 0.46 | link |
MicroDiT_XL_2 trained on 37M images (22M real, 15M synthetic) | SDXL-VAE (4 channel) | 12.66 | 0.46 | link |
MicroDiT_XL_2 trained on 37M images (22M real, 15M synthetic) | Ostris-VAE (16 channel) | 13.04 | 0.40 | link |
MicroDiT_XL_2 trained on 490M synthetic images | SDXL-VAE (4 channel) | 13.26 | 0.52 | link |
All four models are trained with nearly identical training configurations and computational budgets.
Use the following straightforward steps to generate images from the final model at 512×512 resolution.
import torch
from micro_diffusion.models.model import create_latent_diffusion

# in_channels=4 matches the checkpoints trained with the 4-channel SDXL-VAE
model = create_latent_diffusion(latent_res=64, in_channels=4, pos_interp_scale=2.0).to('cuda')
# final_ckpt_path_on_local_disk is a placeholder for the path to a downloaded checkpoint
model.dit.load_state_dict(torch.load(final_ckpt_path_on_local_disk))  # use model.load_state_dict if ckpt includes vae and text-encoder
gen_images = model.generate(prompt=['An elegant squirrel pirate on a ship']*4, num_inference_steps=30,
                            guidance_scale=5.0, seed=2024)
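The generation example hard-codes latent_res=64, in_channels=4, and pos_interp_scale=2.0 for the 512×512, 4-channel setting. The hypothetical helper below shows one way these values could be derived for other settings, such as the 16-channel VAE checkpoint in the table. It rests on assumptions the README does not confirm: an 8× VAE downsampling factor, and pos_interp_scale being the ratio of the latent resolution to a base latent resolution of 32. Both are consistent with the 512×512 example but should be checked against the code before use.

```python
def generation_kwargs(image_res: int, vae_channels: int = 4,
                      vae_downsample: int = 8, base_latent_res: int = 32) -> dict:
    """Guess keyword arguments for create_latent_diffusion (assumptions in lead-in)."""
    if image_res % vae_downsample != 0:
        raise ValueError("image_res must be divisible by the VAE downsampling factor")
    latent_res = image_res // vae_downsample
    return {
        "latent_res": latent_res,
        "in_channels": vae_channels,
        # Assumption: positional embeddings are interpolated relative to the
        # latent resolution used for 256x256 training.
        "pos_interp_scale": latent_res / base_latent_res,
    }


if __name__ == "__main__":
    print(generation_kwargs(512))                   # 4-channel VAE, 512x512
    print(generation_kwargs(512, vae_channels=16))  # 16-channel VAE, 512x512
```

For 512×512 images with a 4-channel VAE, the helper reproduces the values used in the example above.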
We would like to thank previous open-source efforts that we utilize in our code, in particular, the Composer training framework, streaming dataloaders, the Diffusers library, and the original Diffusion Transformers (DiT) implementation.
The code and model weights are released under the Apache 2.0 License.
@article{Sehwag2024MicroDiT,
title={Stretching Each Dollar: Diffusion Training from Scratch on a Micro-Budget},
author={Sehwag, Vikash and Kong, Xianghao and Li, Jingtao and Spranger, Michael and Lyu, Lingjuan},
journal={arXiv preprint arXiv:2407.15811},
year={2024}
}