Ambient Diffusion Posterior Sampling: Solving Inverse Problems with Diffusion Models trained on Corrupted Data
This repository hosts the official PyTorch implementation of the paper: Ambient Diffusion Posterior Sampling: Solving Inverse Problems with Diffusion Models trained on Corrupted Data.
Authored by: Asad Aali, Giannis Daras, Brett Levac, Sidharth Kumar, Alexandros G. Dimakis, Jonathan I. Tamir
Figure : Prior samples from Ambient Diffusion trained with under-sampled data at R = 2, 4, 6, 8 (columns 1 - 4), EDM trained with L1-wavelet reconstructions of subsampled data at R = 2, 4, 6, 8 (columns 5 - 8), NCSNV2 trained with fully sampled data (column 9) and EDM trained with fully sampled data (column 10)
We provide a framework for solving inverse problems with diffusion models learned from linearly corrupted data. Our method, Ambient Diffusion Posterior Sampling (A-DPS), leverages a generative model pre-trained on one type of corruption (e.g. image inpainting) to perform posterior sampling conditioned on measurements from a potentially different forward process (e.g. image blurring). We test the efficacy of our approach on standard natural image datasets (CelebA, FFHQ, and AFHQ) and we show that A-DPS can sometimes outperform models trained on clean data for several image restoration tasks in both speed and performance. We further extend the Ambient Diffusion framework to train MRI models with access only to Fourier subsampled multi-coil MRI measurements at various acceleration factors (R = 2, 4, 6, 8). We again observe that models trained on highly subsampled data are better priors for solving inverse problems in the high acceleration regime than models trained on fully sampled data.
The recommended way to run the code is with an Anaconda/Miniconda environment. First, clone the repository:
git clone https://github.com/utcsilab/ambient-diffusion-mri.git
.
Then, create a new Anaconda environment and install the dependencies:
conda env create -f environment.yml -n ambient
You will also need to have diffusers
installed from the source. To do so, run:
pip install git+https://github.com/huggingface/diffusers.git
From our experiments, we share nine (9) pre-trained models:
- A supervised EDM model trained on fully sampled (R = 1) FastMRI data.
- Four Ambient Diffusion models trained on undersampled FastMRI data at acceleration rates of R = 2, 4, 6, 8.
- Four EDM models trained after L1-Wavelet compressed sensing reconstruction of the training set at acceleration rates of R = 2, 4, 6, 8.
The checkpoints are available here. To download from the terminal, simply run:
wget -v -O ambient_models.zip -L https://utexas.box.com/shared/static/axofnwib9kukdpa92ge4ays87dmuvpf7.zip
For the experiments, we used a pre-processed version of NYU's FastMRI dataset.
To set up the dataset for training/inference, follow the instructions provided here.
To train a new Ambient Diffusion model on the FastMRI dataset, run the following bash script:
ambient-diffusion-mri/train.sh
R=4
EXPERIMENT_NAME=brainMRI_R=$R
GPUS_PER_NODE=1
GPU=0
DATA_PATH=/path_to_dataset/numpy/ksp_brainMRI_384.zip
OUTDIR=/path_to_output/models/$EXPERIMENT_NAME
CORR=$R
DELTA=5
BATCH=8
METHOD=ambient
torchrun --standalone --nproc_per_node=$GPUS_PER_NODE \
train.py --gpu=$GPU --outdir=$OUTDIR --experiment_name=$EXPERIMENT_NAME \
--dump=200 --cond=0 --arch=ddpmpp \
--precond=$METHOD --cres=1,1,1,1 --lr=2e-4 --dropout=0.1 --augment=0 \
--data=$DATA_PATH --norm=2 --max_grad_norm=1.0 --mask_full_rgb=True \
--corruption_probability=$CORR --delta_probability=$DELTA --batch=$BATCH \
--normalize=False --fp16=True --wandb_id=$EXPERIMENT_NAME
To generate images from the trained model, run the following bash script:
ambient-diffusion-mri/prior.sh
:
R=4
EXPERIMENT_NAME=brainMRI_prior_R=$R
GPUS_PER_NODE=1
GPU=0
MODEL_PATH=/path_to_model/models/brainMRI_R=$R
MAPS_PATH=/path_to_dataset/numpy/ksp_brainMRI_384.zip
SEEDS=1000
BATCH=8
torchrun --standalone --nproc_per_node=$GPUS_PER_NODE \
prior.py --gpu=$GPU --network=$MODEL_PATH/ --maps_path=$MAPS_PATH\
--outdir=results/$EXPERIMENT_NAME \
--experiment_name=$EXPERIMENT_NAME \
--ref=$MODEL_PATH/stats.jsonl \
--seeds=$SEEDS --batch=$BATCH \
--mask_full_rgb=True --num_masks=1 --guidance_scale=0.0 \
--training_options_loc=$MODEL_PATH/training_options.json \
--num=$SEEDS --img_channels=2 --with_wandb=False
This will generate 1000 images in the folder <results/$EXPERIMENT_NAME>
.
To generate posterior samples given a trained model, run the following bash script:
ambient-diffusion-mri/solve_inverse_adps.sh
:
TRAINING_R=4
EXPERIMENT_NAME=brainMRI_ambientDPS
GPUS_PER_NODE=1
GPU=0
MODEL_PATH=/path_to_model/models/brainMRI_R=$TRAINING_R
MEAS_PATH=/path_to_measurements
STEPS=500
METHOD=ambient
for seed in 15
do
for R in 2 4 6 8
do
for sample in {0..100}
do
torchrun --standalone --nproc_per_node=$GPUS_PER_NODE \
solve_inverse_adps.py --seed $seed --latent_seeds $seed --gpu $GPU \
--sample $sample --inference_R $R --training_R $TRAINING_R \
--l_ss 1 --num_steps $STEPS --S_churn 0 \
--measurements_path $MEAS_PATH --network $MODEL_PATH \
--outdir results/$EXPERIMENT_NAME --img_channels 2 --method $METHOD
done
done
done
To generate posterior samples given a trained model, run the following bash script:
ambient-diffusion-mri/solve_inverse_1step.sh
:
R=4
EXPERIMENT_NAME=brainMRI_1step_R=$R
GPUS_PER_NODE=1
GPU=0
MODEL_PATH=/path_to_model/models/brainMRI_R=$R
MEAS_PATH=/path_to_measurements
SEEDS=100
torchrun --standalone --nproc_per_node=$GPUS_PER_NODE \
solve_inverse_1step.py --gpu=$GPU --network=$MODEL_PATH/ \
--outdir=results/$EXPERIMENT_NAME \
--experiment_name=$EXPERIMENT_NAME \
--ref=$MODEL_PATH/stats.jsonl \
--seeds=$SEEDS --batch=1 \
--mask_full_rgb=True --training_options_loc=$MODEL_PATH/training_options.json \
--measurements_path=$MEAS_PATH --num=2 --img_channels=2 --with_wandb=False
The following script was used for calculating the FID scores:
ambient-diffusion-mri/fid.sh
:
Example:
python fid.py ref --data=path_to_ref_data --dest=path_to_ref_scores.npz
torchrun --standalone --nproc_per_node=1 fid.py calc --images=path_to_priors --ref=path_to_ref_scores.npz