
EnVision-Research/OmniBooth


OmniBooth

OmniBooth: Learning Latent Control for Image Synthesis with Multi-modal Instruction
Leheng Li, Weichao Qiu, Xu Yan, Jing He, Kaiqiang Zhou, Yingjie Cai, Qing Lian, Bingbing Liu, Ying-Cong Chen

OmniBooth is a project focused on synthesizing image data that follows multi-modal instructions. Users can use text or images to control instance generation. This repository provides tools and scripts to process data, train the model, and generate synthetic image data using the COCO dataset or your own data.


Installation

To get started with OmniBooth, follow these steps:

  1. Clone the repository:

    git clone https://github.com/Len-Li/OmniBooth.git
    cd OmniBooth
  2. Set up an environment:

    pip install torch torchvision transformers open_clip_torch
    pip install diffusers==0.26.0.dev0
    # We use an old version of diffusers; please make sure this exact version is installed.
    
    pip install albumentations pycocotools 
    pip install git+https://github.com/cocodataset/panopticapi.git
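After installing, a quick way to confirm the environment, and in particular the old diffusers pin, is to query the installed package metadata. This is a minimal sketch using only the standard library's importlib.metadata; the helper names are hypothetical, and treating any 0.26.x version as acceptable is an assumption rather than something the repository states.

```python
from importlib.metadata import PackageNotFoundError, version

# PyPI distribution names from the pip commands above (panopticapi, installed
# from git, is left out of this check).
REQUIRED = [
    "torch", "torchvision", "transformers", "open_clip_torch",
    "diffusers", "albumentations", "pycocotools",
]
PINNED_DIFFUSERS = "0.26"  # the README pins diffusers==0.26.0.dev0


def installed_versions(packages):
    """Map each distribution name to its installed version, or None if absent."""
    found = {}
    for name in packages:
        try:
            found[name] = version(name)
        except PackageNotFoundError:
            found[name] = None
    return found


def diffusers_matches(installed, prefix=PINNED_DIFFUSERS):
    """True if an installed diffusers version belongs to the pinned 0.26 series."""
    if installed is None:
        return False
    return installed == prefix or installed.startswith(prefix + ".")


if __name__ == "__main__":
    versions = installed_versions(REQUIRED)
    for name, ver in versions.items():
        print(f"{name:16s} {ver or 'NOT INSTALLED'}")
    if not diffusers_matches(versions["diffusers"]):
        print("warning: diffusers is not a 0.26.x release")
```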

Prepare Dataset

You can skip this step if you just want to run a demo generation. I've prepared demo masks in data/instance_dataset for generation; see Inference.

To train OmniBooth, follow the steps below:

  1. Download the COCONut dataset:

    We use the COCONut-S split. Please download the COCONut-S file and relabeled-COCO-val from here and put them in the data/coconut_dataset folder. I recommend using the Kaggle link.

  2. Download the COCO dataset:

    cd data/coconut_dataset 
    mkdir coco && cd coco
    
    wget http://images.cocodataset.org/zips/train2017.zip
    wget http://images.cocodataset.org/zips/val2017.zip
    wget http://images.cocodataset.org/annotations/annotations_trainval2017.zip
    
    unzip train2017.zip && unzip val2017.zip
    unzip annotations_trainval2017.zip
    

    Then, download the instance prompts from Hugging Face (hf).

    After preparation, you should see the following directory structure:

    OmniBooth/
    ├── data/
    │   ├── instance_dataset/
    │   ├── coconut_dataset/
    │   │   ├── coco/
    │   │   ├── coconut_s/
    │   │   ├── relabeled_coco_val/
    │   │   ├── annotations/
    │   │   │   ├── coconut_s.json
    │   │   │   ├── relabeled_coco_val.json
    │   │   │   ├── my-train.json
    │   │   │   ├── my-val.json
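Before training, it can help to confirm that the layout matches the tree above. A minimal sketch, assuming it is run from the OmniBooth/ root; missing_paths is a hypothetical helper that only checks that each listed path exists, not that its contents are valid.

```python
import sys
from pathlib import Path

# Paths from the directory tree above, relative to the OmniBooth/ root.
EXPECTED = [
    "data/instance_dataset",
    "data/coconut_dataset/coco",
    "data/coconut_dataset/coconut_s",
    "data/coconut_dataset/relabeled_coco_val",
    "data/coconut_dataset/annotations/coconut_s.json",
    "data/coconut_dataset/annotations/relabeled_coco_val.json",
    "data/coconut_dataset/annotations/my-train.json",
    "data/coconut_dataset/annotations/my-val.json",
]


def missing_paths(root=".", expected=EXPECTED):
    """Return the expected entries that do not exist under `root`."""
    root = Path(root)
    return [rel for rel in expected if not (root / rel).exists()]


if __name__ == "__main__":
    # Optional first argument: path to the OmniBooth/ root (defaults to cwd).
    for rel in missing_paths(sys.argv[1] if len(sys.argv) > 1 else "."):
        print("missing:", rel)
```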
    

Prepare Checkpoint

Our model is based on stable-diffusion-xl-base-1.0. We additionally use sdxl-vae-fp16-fix to avoid numerical issues in VAE decoding. Please download both models and put them in ./OmniBooth/ckp/.

Download dinov2_vitl14_reg4_pretrain.pth (Link) for image feature extraction and put it in ./OmniBooth/ckp/ as well.

Our OmniBooth checkpoint is released on Hugging Face. To run inference with our model, put the checkpoint files in ./OmniBooth/ckp/.
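The two SDXL models can also be fetched programmatically. This is a sketch, not the project's own tooling: it assumes the public Hugging Face repo IDs stabilityai/stable-diffusion-xl-base-1.0 and madebyollin/sdxl-vae-fp16-fix, uses huggingface_hub's snapshot_download, and names local folders to match MODEL_DIR and VAE_DIR in train.sh. The DINOv2 weights and the OmniBooth checkpoint are left as manual downloads; all helper names are hypothetical.

```python
from pathlib import Path

# Assumed Hub repo IDs for the two base models; local folder names match
# MODEL_DIR and VAE_DIR in train.sh.
HF_MODELS = {
    "stable-diffusion-xl-base-1.0": "stabilityai/stable-diffusion-xl-base-1.0",
    "sdxl-vae-fp16-fix": "madebyollin/sdxl-vae-fp16-fix",
}
# Downloaded by hand. The OmniBooth checkpoint is not listed because its file
# names are not given in this README.
MANUAL_FILES = ["dinov2_vitl14_reg4_pretrain.pth"]


def download_plan(ckp_dir="./ckp"):
    """Return (repo_id, local_dir) pairs for the models fetched from the Hub."""
    ckp = Path(ckp_dir)
    return [(repo_id, ckp / name) for name, repo_id in HF_MODELS.items()]


def missing_manual(ckp_dir="./ckp"):
    """Return manually downloaded files that are not yet in ckp_dir."""
    ckp = Path(ckp_dir)
    return [name for name in MANUAL_FILES if not (ckp / name).is_file()]


def download_all(ckp_dir="./ckp"):
    """Fetch the Hub models; needs `pip install huggingface_hub` and network access."""
    from huggingface_hub import snapshot_download

    for repo_id, local_dir in download_plan(ckp_dir):
        snapshot_download(repo_id=repo_id, local_dir=str(local_dir))
    for name in missing_manual(ckp_dir):
        print(f"still missing: {Path(ckp_dir) / name} (download it manually)")
```

Call download_all() from the OmniBooth/ root. The full SDXL base repository is large; snapshot_download accepts an allow_patterns argument to restrict which files are fetched, but which files the training code actually needs is not specified here.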

Train

bash train.sh

The details of the script are as follows:

export MODEL_DIR="./ckp/stable-diffusion-xl-base-1.0"
export VAE_DIR="./ckp/sdxl-vae-fp16-fix"

export EXP_NAME="omnibooth_train"
export OUTPUT_DIR="./ckp/$EXP_NAME"

accelerate launch --gpu_ids 0,  --num_processes 1  --main_process_port 3226  train.py \
      --pretrained_model_name_or_path=$MODEL_DIR \
      --pretrained_vae_model_name_or_path=$VAE_DIR \
      --output_dir=$OUTPUT_DIR \
      --width=1024 \
      --height=1024 \
      --patch_size=364 \
      --learning_rate=4e-5 \
      --num_train_epochs=12 \
      --train_batch_size=1 \
      --mulscale_batch_size=2 \
      --mixed_precision="fp16" \
      --num_validation_images=2 \
      --validation_steps=500 \
      --checkpointing_steps=5000 \
      --checkpoints_total_limit=10 \
      --ctrl_channel=1024 \
      --use_sdxl=True \
      --enable_xformers_memory_efficient_attention \
      --report_to='wandb' \
      --resume_from_checkpoint="latest" \
      --tracker_project_name="omnibooth-demo" 
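train.sh launches a single process on GPU 0 (--gpu_ids 0, --num_processes 1). As a hedged sketch, the helper below builds the equivalent accelerate launch command for several GPUs. launch_command is hypothetical; --multi_gpu is a standard accelerate launch flag that the original script does not pass, so adding it is an assumption; and the train.py arguments from train.sh still need to be appended. Under data-parallel training each process loads its own batch, so the global batch is the per-GPU batch times the number of GPUs (times any gradient accumulation).

```python
import shlex


def launch_command(num_gpus=8, port=3226, train_args=()):
    """Build an `accelerate launch` command mirroring train.sh for `num_gpus` GPUs."""
    cmd = [
        "accelerate", "launch",
        "--gpu_ids", ",".join(str(i) for i in range(num_gpus)),
        "--num_processes", str(num_gpus),
        "--main_process_port", str(port),
    ]
    if num_gpus > 1:
        cmd.append("--multi_gpu")  # request a multi-GPU data-parallel launch
    return cmd + ["train.py", *train_args]


if __name__ == "__main__":
    # Print an 8-GPU command; append the remaining train.py flags from train.sh.
    print(shlex.join(launch_command(8, train_args=["--train_batch_size=1"])))
```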

Training takes 3 days on 8 NVIDIA A100 GPUs. We use a batch size of 2 and set the image height to 1024; the image width follows the aspect ratio of the ground-truth image. Training uses about 65 GB of memory per GPU.
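As a worked example of "height fixed at 1024, width follows the ground-truth aspect ratio", the arithmetic fits in a small sketch. Rounding the width to a multiple of 8 (the SDXL VAE's spatial downsampling factor) is an assumption for illustration; the repository's dataloader may round differently or to another multiple, and target_size is a hypothetical helper.

```python
def target_size(orig_w, orig_h, height=1024, multiple=8):
    """Resize to a fixed height, keep the aspect ratio, snap width to `multiple`."""
    width = orig_w * height / orig_h
    # Round to the nearest multiple (assumed: 8, the SDXL VAE downsampling factor).
    width = max(multiple, round(width / multiple) * multiple)
    return width, height


# A 640x480 COCO image: 640 * 1024 / 480 = 1365.3, which rounds to 1368 pixels.
print(target_size(640, 480))  # (1368, 1024)
```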

Inference

bash infer.sh

Generated images are saved to ./vis_dir/.

Behavior analysis

  1. Text instructions are not perfect. They work for describing attributes such as color and texture, but, as in prior work, more granular, fine-grained descriptions remain challenging. Scaling the data and the model can help with this problem.
  2. Image instructions may produce images with washed-out colors, possibly due to brightness augmentation. This can be adjusted by editing the global prompt, e.g., with "a brighter image".
  3. Video dataset. Ideally, we would use video datasets to train image-instructed generation, similar to AnyDoor. However, in our multi-modal setting, the cost of obtaining video data together with tracking and panoptic annotations is relatively high, so we only trained our model on the single-view COCO dataset. If you plan to extend the training data to video datasets, please let me know.
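For item 2, one way to apply the suggested global-prompt edit to an instance folder's prompt_dict.json (its format is described under Instance data structure) is sketched below. Appending the phrase after a comma, rather than replacing the prompt, is an assumption, and both helper names are hypothetical.

```python
import json
from pathlib import Path


def add_to_global_prompt(prompt_dict, extra="a brighter image"):
    """Return a copy of prompt_dict with `extra` appended to global_prompt."""
    updated = dict(prompt_dict)
    base = updated.get("global_prompt", "").strip()
    updated["global_prompt"] = f"{base}, {extra}" if base else extra
    return updated


def patch_prompt_file(path, extra="a brighter image"):
    """Rewrite a prompt_dict.json in place with the extra global-prompt phrase."""
    path = Path(path)
    data = json.loads(path.read_text())
    path.write_text(json.dumps(add_to_global_prompt(data, extra), indent=4))
```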

Instance data structure

I provide several instance mask datasets for inference in data/instance_dataset. They are converted from the COCO dataset. The data structure is as follows:

# use data/instance_dataset/plane as an example
0000_mask.png
0000.png
0001_mask.png
0001.png
0002_mask.png
0002.png
...
prompt_dict.json

Each mask file is a binary mask that indicates the instance location. The image file is an optional image reference; set --text_or_img=img to use it.
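The naming convention above can be checked with a short sketch that pairs each NNNN_mask.png with its optional NNNN.png reference. It only inspects file names and does not read the masks; instance_pairs and missing_references are hypothetical helpers, not part of the repository.

```python
import re
from pathlib import Path

MASK_NAME = re.compile(r"^(\d+)_mask\.png$")


def instance_pairs(folder):
    """List (index, mask_path, reference_path_or_None) for every NNNN_mask.png."""
    folder = Path(folder)
    pairs = []
    for path in folder.iterdir():
        match = MASK_NAME.match(path.name)
        if match:
            reference = folder / f"{match.group(1)}.png"
            pairs.append((int(match.group(1)), path,
                          reference if reference.is_file() else None))
    return sorted(pairs, key=lambda item: item[0])


def missing_references(folder):
    """Indices without an image reference (relevant when --text_or_img=img)."""
    return [index for index, _, ref in instance_pairs(folder) if ref is None]
```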

prompt_dict.json is a dictionary containing the instance prompts and a global_prompt. Each prompt is a string describing an instance or the whole image. For example, the prompt_dict.json for plane is:

{
   "prompt_0": "a plane is silhouetted against a cloudy sky", 
   "prompt_1": "a road", 
   "prompt_2": "a pavement of merged", 
   "global_prompt": "large mustard yellow commercial airplane parked in the airport"
}
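Because the keys are strings, sorting them lexically would put prompt_10 before prompt_2. A minimal sketch that reads the file and orders instance prompts by their numeric suffix; it assumes, without confirmation from the repository, that prompt_i describes the instance in the i-th mask file (e.g. prompt_0 with 0000_mask.png). split_prompts and load_prompts are hypothetical helpers.

```python
import json
import re

PROMPT_KEY = re.compile(r"^prompt_(\d+)$")


def split_prompts(prompt_dict):
    """Return (instance prompts ordered by numeric index, global prompt)."""
    indexed = []
    for key, text in prompt_dict.items():
        match = PROMPT_KEY.match(key)
        if match:
            indexed.append((int(match.group(1)), text))
    indexed.sort()  # numeric order, so prompt_10 sorts after prompt_2
    return [text for _, text in indexed], prompt_dict.get("global_prompt", "")


def load_prompts(path):
    """Read a prompt_dict.json file and split it into instance and global prompts."""
    with open(path, encoding="utf-8") as f:
        return split_prompts(json.load(f))
```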

Acknowledgment

We also express our gratitude to the authors of the following open-source projects:

BibTeX

@article{li2024omnibooth,
  title={OmniBooth: Learning Latent Control for Image Synthesis with Multi-modal Instruction},
  author={Li, Leheng and Qiu, Weichao and Yan, Xu and He, Jing and Zhou, Kaiqiang and Cai, Yingjie and Lian, Qing and Liu, Bingbing and Chen, Ying-Cong},
  journal={arXiv preprint arXiv:2410.04932},
  year={2024}
}

This project is licensed under the MIT License - see the LICENSE file for details.
