Official PyTorch implementation of TerDiT, as presented in the paper:
TerDiT: Ternary Diffusion Models with Transformers
Xudong Lu, Aojun Zhou, Ziyi Lin, Qi Liu, Yuhui Xu, Renrui Zhang, Yafei Wen, Shuai Ren, Peng Gao, Junchi Yan, Hongsheng Li
Step 1: Please first follow the installation instructions of Large-DiT-ImageNet. We use cuda-12.1 and gcc 9.4.0.
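Before building anything, it can help to confirm that your toolchain matches the versions above. The following is a minimal sketch and is not part of this repository: it reports the CUDA version PyTorch was built with (`torch.version.cuda`) and the gcc version on PATH, then compares them with cuda-12.1 and gcc 9.4.0. The helper names (`check_environment` and friends) are hypothetical. Note that it only reads the CUDA version PyTorch was compiled against, not the system CUDA toolkit that will compile the kernel.

```python
import importlib.util
import shutil
import subprocess
from typing import Optional

EXPECTED_CUDA = "12.1"  # version stated in this README
EXPECTED_GCC = "9.4.0"  # version stated in this README


def torch_cuda_version() -> Optional[str]:
    """CUDA version PyTorch was built with, or None if torch is absent or CPU-only."""
    if importlib.util.find_spec("torch") is None:
        return None
    import torch
    return torch.version.cuda


def gcc_version() -> Optional[str]:
    """Full gcc version string such as '9.4.0', or None if gcc is unavailable."""
    if shutil.which("gcc") is None:
        return None
    # -dumpfullversion exists in GCC 7 and later; on failure stdout is empty.
    result = subprocess.run(["gcc", "-dumpfullversion"], capture_output=True, text=True)
    return result.stdout.strip() or None


def check_environment() -> dict:
    """Collect detected versions and whether they match the README's versions."""
    cuda, gcc = torch_cuda_version(), gcc_version()
    return {
        "torch_cuda": cuda,
        "gcc": gcc,
        "cuda_matches": cuda == EXPECTED_CUDA,
        "gcc_matches": gcc == EXPECTED_GCC,
    }


if __name__ == "__main__":
    for key, value in check_environment().items():
        print(f"{key}: {value}")
```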
Step 2: Install the required packages:
- Replace fairscale with our modified package:
  cd fairscale
  pip install -e .
- Install the 2-bit kernel functions from hqq:
  cd kernel
  python3 setup_cuda.py install
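TerDiT uses ternary weights, so each weight takes one of three values {-1, 0, +1} and fits in 2 bits; this is presumably why a 2-bit kernel is used. As an illustration only, the sketch below packs ternary values into 2-bit codes, four per byte, and unpacks them again. The encoding (store value + 1, lowest bits first) and the function names are our assumptions, not necessarily the layout used by the hqq kernel.

```python
from typing import List


def pack_ternary(values: List[int]) -> bytes:
    """Pack ternary values {-1, 0, 1} into 2-bit codes, four values per byte."""
    out = bytearray((len(values) + 3) // 4)
    for i, v in enumerate(values):
        if v not in (-1, 0, 1):
            raise ValueError(f"not a ternary value: {v}")
        code = v + 1  # map -1/0/1 to 0/1/2 so each value fits in 2 bits
        out[i // 4] |= code << (2 * (i % 4))  # value i goes to bit offset 2*(i % 4)
    return bytes(out)


def unpack_ternary(packed: bytes, n: int) -> List[int]:
    """Recover the first n ternary values from packed 2-bit codes."""
    if n > 4 * len(packed):
        raise ValueError("packed buffer is too short for n values")
    return [((packed[i // 4] >> (2 * (i % 4))) & 0b11) - 1 for i in range(n)]
```

For example, `pack_ternary([1, -1, 0, 1, 0])` yields `bytes([146, 1])`, and `unpack_ternary` with `n=5` restores the original list. Compared with 16-bit weights, 2-bit storage is 8x smaller, ignoring any scale factors stored alongside the weights.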
We offer trained checkpoints for ImageNet generation at both 256x256 and 512x512 resolutions. At 256x256, we provide models with 600M and 4.2B parameters; at 512x512, we provide the 4.2B model. The checkpoints are available on Hugging Face 🤗.
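The available checkpoints can be captured in a small lookup table. The sketch below is hypothetical: `AVAILABLE_CHECKPOINTS` simply mirrors the list above, and `download_checkpoint` is a thin wrapper around `huggingface_hub.snapshot_download`. The Hugging Face repository ID is not given here, so `<hf-repo-id>` in the usage note is a placeholder you must replace.

```python
from typing import Dict, List

# Mirrors this README: resolution -> available model sizes.
AVAILABLE_CHECKPOINTS: Dict[int, List[str]] = {
    256: ["600M", "4.2B"],
    512: ["4.2B"],
}


def models_for_resolution(resolution: int) -> List[str]:
    """Return the model sizes released for a given square resolution."""
    if resolution not in AVAILABLE_CHECKPOINTS:
        raise ValueError(
            f"no checkpoint for {resolution}x{resolution}; "
            f"choose from {sorted(AVAILABLE_CHECKPOINTS)}"
        )
    return AVAILABLE_CHECKPOINTS[resolution]


def download_checkpoint(repo_id: str, local_dir: str) -> str:
    """Download a Hugging Face repo snapshot into local_dir and return its path."""
    # Imported lazily so the lookup helpers work without huggingface_hub installed.
    from huggingface_hub import snapshot_download
    return snapshot_download(repo_id=repo_id, local_dir=local_dir)
```

For example, `download_checkpoint("<hf-repo-id>", "checkpoints/3B_1180000")` would place a snapshot where the sampling command expects it, assuming that repository holds a single checkpoint.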
To sample images, run the following (assuming your checkpoint has been downloaded to checkpoints/3B_1180000):
python -u sample.py --ckpt checkpoints/3B_1180000 --local_diffusers_model_root /path/to/diffusers_models --seed 42
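To draw samples for several seeds, the command can be wrapped in a short script. This is a sketch that uses only the three flags shown above; the function names are hypothetical. This README does not say whether sample.py writes each seed to a separate output file, so check that before running many seeds.

```python
import subprocess
import sys
from typing import Iterable, List


def build_sample_command(ckpt: str, diffusers_root: str, seed: int) -> List[str]:
    """Assemble the sample.py invocation using only its documented flags."""
    return [
        sys.executable, "-u", "sample.py",  # current interpreter, unbuffered output
        "--ckpt", ckpt,
        "--local_diffusers_model_root", diffusers_root,
        "--seed", str(seed),
    ]


def sample_many_seeds(ckpt: str, diffusers_root: str, seeds: Iterable[int]) -> None:
    """Run sample.py once per seed; raises CalledProcessError on the first failure."""
    for seed in seeds:
        subprocess.run(build_sample_command(ckpt, diffusers_root, seed), check=True)
```

For example, `sample_many_seeds("checkpoints/3B_1180000", "/path/to/diffusers_models", [0, 1, 42])` runs sample.py three times from the repository root and stops at the first failing run.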
TerDiT achieves results competitive with full-precision models. Notably, TerDiT-4.2B outperforms the original DiT on the ImageNet 512x512 task.
This project is released under the MIT license. Please see the LICENSE file for more information.
Feel free to discuss the paper or code with us through issues or email!
- Xudong Lu: [email protected]
If you find our paper and code useful in your research, please cite:
@misc{lu2024terdit,
title={TerDiT: Ternary Diffusion Models with Transformers},
author={Xudong Lu and Aojun Zhou and Ziyi Lin and Qi Liu and Yuhui Xu and Renrui Zhang and Yafei Wen and Shuai Ren and Peng Gao and Junchi Yan and Hongsheng Li},
year={2024},
eprint={2405.14854},
archivePrefix={arXiv},
primaryClass={cs.CV}
}