
Snuffy: Efficient Whole Slide Image Classifier


Hossein Jafarinia, Alireza Alipanah, Danial Hamdi, Saeed Razavi, Nahal Mirzaie, Mohammad Hossein Rohban

[arXiv] [Project Page] [Demo] [BibTex]

PyTorch implementation of the Multiple Instance Learning (MIL) framework described in the paper Snuffy: Efficient Whole Slide Image Classifier (ECCV 2024, accepted).



Snuffy is a novel MIL-pooling method based on sparse transformers, designed to address the computational challenges in Whole Slide Image (WSI) classification for digital pathology. Our approach mitigates performance loss with limited pre-training and enables continual few-shot pre-training as a competitive option.

Key features:

  • Tailored sparsity pattern for pathology
  • Theoretically proven universal approximator, with a tight probabilistic bound on the number of layers
  • Superior WSI and patch-level accuracies on CAMELYON16 and TCGA Lung cancer datasets
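To give an intuition for why sparse attention helps with gigapixel slides, the sketch below builds a boolean attention mask in which a small set of "global" patches attends to, and is attended by, every patch, while all other patches see only themselves and the globals. This is a generic global-token pattern written for illustration under our own assumptions (the function name and the list-based mask are hypothetical); it is not the exact Snuffy sparsity pattern, which is defined in the paper and the code.

```python
from typing import List, Sequence


def global_token_mask(num_tokens: int, global_idx: Sequence[int]) -> List[List[bool]]:
    """Illustrative sparse mask: mask[i][j] is True when token i may attend to token j.

    Assumption: a simple global-token pattern, not Snuffy's exact implementation.
    """
    globals_ = set(global_idx)
    return [
        # Each token sees itself; global rows see every token; global columns are seen by all.
        [i == j or i in globals_ or j in globals_ for j in range(num_tokens)]
        for i in range(num_tokens)
    ]


mask = global_token_mask(num_tokens=6, global_idx=[0, 3])
allowed = sum(map(sum, mask))
print(f"{allowed} of {6 * 6} attention pairs kept")
```

With n patches and λ global patches, the number of kept pairs in this pattern grows roughly as O(nλ) instead of O(n²), which is what makes attention over tens of thousands of patches tractable.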

Overview

This repository provides a complete, runnable implementation of the Snuffy framework, including code for computing the FROC metric, which, to the best of our knowledge, no other WSI classification framework provides. The pipeline consists of four steps:

  1. Slide Patching: WSIs are divided into manageable patches.
  2. Self-Supervised Learning: An SSL method is trained on the patches to create an embedder.
  3. Feature Extraction: The embedder computes features (embeddings) for each slide.
  4. MIL Training: The Snuffy MIL framework is applied to the computed features.

Each step in this pipeline can be executed independently, with intermediate results available for download to facilitate continued processing.
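The FROC metric mentioned above scores lesion-level detection: as the detection threshold is lowered, it tracks sensitivity against the average number of false positives per slide, and CAMELYON16 reports the mean sensitivity at 1/4, 1/2, 1, 2, 4, and 8 false positives per slide. The sketch below is a simplified, step-wise version for intuition only. It assumes detections have already been matched to lesion IDs (or marked as false positives); the repository's ASAP-based evaluation and the official challenge script may handle matching and interpolation differently.

```python
from typing import List, Optional, Sequence, Tuple

# One detection: (confidence score, matched lesion ID, or None for a false positive).
Detection = Tuple[float, Optional[str]]

# Average false positives per slide at which CAMELYON16 reads off sensitivity.
FP_RATES = (0.25, 0.5, 1, 2, 4, 8)


def froc_score(detections: Sequence[Detection], num_lesions: int, num_slides: int) -> float:
    """Simplified FROC: mean sensitivity at FP_RATES, using a step-wise curve (no interpolation)."""
    hit_lesions, false_positives = set(), 0
    curve: List[Tuple[float, float]] = []  # (average FPs per slide, sensitivity)
    # Lower the threshold one detection at a time, from the most to the least confident.
    for _, lesion in sorted(detections, key=lambda d: d[0], reverse=True):
        if lesion is None:
            false_positives += 1
        else:
            hit_lesions.add(lesion)  # several hits on one lesion count once
        curve.append((false_positives / num_slides, len(hit_lesions) / num_lesions))
    # For each FP budget, take the best sensitivity reachable without exceeding it.
    sensitivities = [max((s for f, s in curve if f <= rate), default=0.0) for rate in FP_RATES]
    return sum(sensitivities) / len(sensitivities)


# Hypothetical detections on two slides containing two lesions ("a" and "b").
dets = [(0.9, "a"), (0.8, None), (0.7, "b"), (0.6, None), (0.5, None)]
print(froc_score(dets, num_lesions=2, num_slides=2))
```

In this toy case lesion "b" is only found after one false positive (0.5 FPs per slide), so the 1/4 budget scores 0.5 sensitivity and every other budget scores 1.0.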

Table of Contents
  1. Requirements
  2. Download Data
  3. Slide Preparation: Patching
  4. Train/Val/Test Split and N-Shot Dataset Creation
  5. Training the Embedder
  6. Feature Extraction
  7. MIL Training
  8. Visualization
  9. Acknowledgments
  10. Citation

Requirements

System Requirements

  • Operating System: Ubuntu 20.04 LTS (or compatible Linux distribution)
  • Python Version: 3.8 or later
  • GPU: Recommended for faster processing (CUDA-compatible)

Notes

  • Disk Space: Ensure you have sufficient disk space for dataset downloads and processing, especially if you intend to work with raw slides rather than pre-computed embeddings. Raw slide data can be very large.
  • Hardware: The MIL training code can run on both GPU and CPU. For optimal performance, a GPU is strongly recommended.

Downloading and Preparing Datasets

  1. AWS CLI: To download the raw CAMELYON16 whole-slide images, you need the AWS CLI. Install it with:
curl "https://awscli.amazonaws.com/awscli-exe-linux-x86_64.zip" -o "awscliv2.zip"
unzip awscliv2.zip
./aws/install
  2. GDC Client (for downloading the TCGA dataset): This is downloaded and installed automatically by the download_tcga_lung.sh script.

  3. OpenSlide is required if you intend to patch the slides yourself with the deepzoom_tiler_camelyon16.py or deepzoom_tiler_tcga_lung_cancer.py scripts. Install OpenSlide with:

# Update package list and install OpenSlide
apt-get update
apt-get install openslide-tools

Running Snuffy

  1. The ASAP package is required for calculating the FROC metric. Install ASAP and its multiresolutionimageinterface Python package as follows:
# Download and install ASAP
wget https://github.com/computationalpathologygroup/ASAP/releases/download/ASAP-2.1/ASAP-2.1-py38-Ubuntu2004.deb
apt-get install -f "./ASAP-2.1-py38-Ubuntu2004.deb"
  2. Required Python packages can be installed with:
# Install Python packages from requirements.txt
pip install -r requirements.txt

Note: The requirements.txt file includes specific package versions used and verified in our experiments. However, newer versions available in your environment may also be compatible.

Additional Components

  1. MAE with Adapter: Refer to the MAE repository for installation instructions.

    Important: If you are using PyTorch 1.8 or later, follow the instructions in the MAE repository to fix a compatibility issue with the timm module, or run the following script instead:

    chmod +x requirements_timm_patch.sh
    ./requirements_timm_patch.sh

    Note that this repository also includes a modified version of timm that supports the adapter functionality.

Download Data

CAMELYON16

  1. List and Download Dataset: Run the following commands to list and download the CAMELYON16 dataset:

    aws s3 ls --no-sign-request s3://camelyon-dataset/CAMELYON16/ --recursive
    aws s3 cp --no-sign-request s3://camelyon-dataset/CAMELYON16/ raw_data/camelyon16 --recursive
  2. Directory Structure: After downloading, your raw_data/camelyon16 directory should look like this:

    -- camelyon16
        |-- README.md
        |-- annotations
        |-- background_tissue
        |-- checksums.md5
        |-- evaluation
        |-- images
        |-- license.txt
        |-- masks
        `-- pathology-tissue-background-segmentation.json
  3. Organize Files:
    Use the provided script to copy the necessary files into the datasets/camelyon16 directory. If space is limited, modify the script to move files instead of copying them.

    python move_camelyon16_tifs.py
  4. Final Directory Structure:

    datasets/camelyon16
    |-- annotations
    |   |-- test_001.xml
    |   |-- tumor_001.xml
    |   |-- ...
    |-- masks
    |   |-- normal_001_mask.tif
    |   |-- test_001_mask.tif
    |   |-- tumor_001_mask.tif
    |   |-- ...
    |-- 0_normal
    |   |-- normal_004.tif
    |   |-- test_018.tif
    |   |-- ...
    |-- 1_tumor
    |   |-- test_046.tif
    |   |-- tumor_075.tif
    |   |-- ...
    |-- reference.csv
    |-- n_shot_dataset_maker.py
    |-- train_validation_test_reverse_camelyon.py
    `-- train_validation_test_splitter_camelyon.py

TCGA Lung Cancer

To download the TCGA Lung Cancer dataset, run the following script. It downloads the slides listed in the LUAD and LUSC manifests to the datasets/tcga/{luad, lusc} directories, storing each slide in its own directory named after its ID in the manifest.

chmod +x download_tcga_lung.sh
./download_tcga_lung.sh

MIL datasets

Download the MIL datasets (sourced from the DSMIL project) and unzip them into the datasets/ directory.

wget -O mil-dataset.zip https://uwmadison.box.com/shared/static/arvv7f1k8c2m8e2hugqltxgt9zbbpbh2.zip
unzip mil-dataset.zip -d datasets/

Slide Preparation: Patching

CAMELYON16

This script processes TIFF slides located in datasets/camelyon16/{0_normal, 1_tumor}/. For each slide, it creates a directory at datasets/camelyon16/single/{0_normal, 1_tumor}/{slide_name}, saving the extracted patches as JPEG images.

python deepzoom_tiler_camelyon16.py

TCGA Lung Cancer

This script processes SVS slides in datasets/tcga/{lusc, luad}/ and saves the extracted patches in datasets/tcga/single/{lusc, luad}/{slide_name} as JPEG images.

python deepzoom_tiler_tcga_lung_cancer.py

For both scripts, see the argument definitions in each script for details on the available options and what they do.

Train/Val/Test Split and N-Shot Dataset Creation

CAMELYON16

To split the CAMELYON16 dataset:

cd datasets/camelyon16
python train_validation_test_splitter_camelyon.py

This script reorganizes the directory structure from:

datasets/camelyon16/single/{0_normal, 1_tumor}

to:

datasets/camelyon16/single/fold1/{train, validation, test}/{0_normal, 1_tumor}

The official CAMELYON16 test set is used for testing, while the remaining data is randomly split into training and validation sets with an 80/20 ratio. You can adjust the fold number directly in the script.
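The 80/20 split of the non-test slides can be pictured as a seeded shuffle followed by a cut. The function below is a hypothetical sketch (its name, the seed handling, and applying it separately to each class folder are assumptions), not the code in train_validation_test_splitter_camelyon.py.

```python
import random
from typing import List, Sequence, Tuple


def split_train_validation(
    slides: Sequence[str], train_ratio: float = 0.8, seed: int = 0
) -> Tuple[List[str], List[str]]:
    """Shuffle slide names reproducibly and cut them into train/validation lists."""
    shuffled = list(slides)
    random.Random(seed).shuffle(shuffled)  # a fixed seed gives the same split every run
    n_train = round(len(shuffled) * train_ratio)
    return shuffled[:n_train], shuffled[n_train:]


# Hypothetical slide names; apply once per class folder (e.g. 0_normal, 1_tumor).
names = [f"tumor_{i:03d}" for i in range(1, 11)]
train, valid = split_train_validation(names)
```

Splitting each class folder separately keeps the class balance of the training and validation sets close to that of the full training pool.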

To reverse the CAMELYON16 split:

cd datasets/camelyon16
python train_validation_test_reverse_camelyon.py

The processed and shuffled datasets are saved with filenames that reflect the dataset name, fold count, and split ratio.

TCGA Lung Cancer

K-Fold Cross Validation Split

The fold_generator.py script creates K-Fold cross-validation splits for the TCGA data, ensuring that a single patient's slides are not divided across multiple splits. It uses the patients.csv reference file and stores the fold information in datasets/tcga/folds/fold_{i}.csv.

To run the K-Fold split:

cd datasets/tcga
python fold_generator.py
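The key constraint in fold_generator.py is that all slides from one patient land in the same fold, so no patient appears in both training and evaluation data. Below is a minimal sketch of that idea, assuming a slide-to-patient mapping is already available (the repository reads it from patients.csv; in TCGA barcodes the patient ID is the first 12 characters, TCGA-XX-XXXX). The greedy size-balancing rule and the function name are our own assumptions.

```python
from collections import defaultdict
from typing import Dict, List


def patient_grouped_folds(slide_to_patient: Dict[str, str], k: int) -> List[List[str]]:
    """Assign whole patients to k folds so that no patient's slides are split across folds."""
    slides_by_patient: Dict[str, List[str]] = defaultdict(list)
    for slide, patient in slide_to_patient.items():
        slides_by_patient[patient].append(slide)

    folds: List[List[str]] = [[] for _ in range(k)]
    # Greedy balancing (an assumption): largest patients first, each into the currently smallest fold.
    for patient in sorted(slides_by_patient, key=lambda p: (-len(slides_by_patient[p]), p)):
        target = min(range(k), key=lambda i: len(folds[i]))
        folds[target].extend(sorted(slides_by_patient[patient]))
    return folds


# Hypothetical slides: s1 and s2 come from the same patient.
folds = patient_grouped_folds({"s1": "p1", "s2": "p1", "s3": "p2", "s4": "p3"}, k=2)
```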

Selecting a Fold

After generating folds, use the train_validation_test_splitter_tcga.py script to organize the directories according to a selected fold:

python train_validation_test_splitter_tcga.py

This script reorganizes the directory structure from:

datasets/tcga/single/{0_luad, 1_lusc}

to:

datasets/tcga/single/fold{i}/{train, validation, test}/{0_luad, 1_lusc}

De-selecting a Fold

To reverse the TCGA split and restore the original directory structure:

cd datasets/tcga
python train_validation_test_reverse_tcga.py

MIL Datasets

The mil_cross_validation.py script converts the MIL datasets downloaded earlier (Musk1, Musk2, Elephant) into a format compatible with Snuffy and then generates cross-validation splits, ensuring that every fold contains both negative and positive bags.

cd datasets/mil_dataset
# python mil_cross_validation.py --dataset [Musk1, Musk2, Elephant] --num_folds [10] --train_valid_ratio [0.2]
python mil_cross_validation.py --dataset Musk1
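Because Musk1, Musk2, and Elephant are small, a purely random split can leave a fold with bags of only one class. One common way to guarantee both classes in every fold is to shuffle each class separately and deal its bags round-robin across folds. The sketch below shows that approach; it is our assumption about the technique, not necessarily what mil_cross_validation.py does.

```python
import random
from typing import List, Sequence


def stratified_folds(labels: Sequence[int], k: int, seed: int = 0) -> List[List[int]]:
    """Return k lists of bag indices; each class is shuffled and dealt round-robin across folds."""
    rng = random.Random(seed)
    folds: List[List[int]] = [[] for _ in range(k)]
    for cls in sorted(set(labels)):
        members = [i for i, y in enumerate(labels) if y == cls]
        rng.shuffle(members)
        for position, bag_index in enumerate(members):
            folds[position % k].append(bag_index)  # every fold gets a share of every class
    return folds


# Hypothetical labels: 10 negative bags followed by 10 positive bags.
folds = stratified_folds([0] * 10 + [1] * 10, k=5)
```

Both classes reach every fold as long as each class has at least k bags.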

N-Shot Patch Dataset Creation

CAMELYON16

To create a 50-shot patch dataset (an n-shot dataset contains at most n patches per WSI):

cd datasets/camelyon16
python n_shot_dataset_maker.py --shots=50

This will create a new folder named single/fold1_50shot based on the dataset in single/fold1. In this new folder, each slide will have at most 50 patches (or all patches if the original number is less than 50).
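Conceptually, n-shot sampling keeps every patch of a small slide and a random subset of n patches from a larger one. The helper below is a hypothetical sketch; uniform random sampling and the fixed seed are assumptions about how n_shot_dataset_maker.py chooses patches.

```python
import random
from typing import List, Sequence


def n_shot_sample(patch_paths: Sequence[str], n: int, seed: int = 0) -> List[str]:
    """Return at most n patch paths for one slide."""
    if len(patch_paths) <= n:
        return list(patch_paths)  # small slide: keep every patch
    return random.Random(seed).sample(list(patch_paths), n)  # large slide: uniform subset


# Hypothetical patch file names for one slide.
patches = [f"patch_{i}.jpeg" for i in range(100)]
subset = n_shot_sample(patches, n=50)
```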

TCGA

The TCGA dataset has an analogous script; for example, to create a 5-shot dataset:

cd datasets/tcga
python n_shot_dataset_maker_tcga.py --shots 5

Training the Embedder

| Method | Instructions | Embedder Weights | Embeddings |
|---|---|---|---|
| SimCLR (from scratch) | Refer to DSMIL | Weights | Embeddings |
| DINO (from scratch) | Refer to DINO (use a ViT-S/16) | Weights | Embeddings |
| DINO (with Adapter) | Refer to the DINO with Adapter section | Weights | Embeddings |
| MAE (with Adapter) | Refer to the MAE with Adapter section | Weights | Embeddings |

DINO with Adapter

Download the DINO ImageNet-1K pretrained ViT-S/8 full weights:

wget https://dl.fbaipublicfiles.com/dino/dino_deitsmall8_pretrain/dino_deitsmall8_pretrain_full_checkpoint.pth

Continue pretraining with DINO Adapter:

python dino_adapter/main_dino_adapter.py \
  --adapter_ffn_scalar=10 \
  --arch=vit_small \
  --batch_size_per_gpu=16 \
  --clip_grad=3 \
  --data_path_train=datasets/camelyon16/single/fold1_50shot/train \
  --data_path_valid=datasets/camelyon16/single/fold1_50shot/validation \
  --epochs=100 \
  --ffn_num=32 \
  --freeze_last_layer=0 \
  --full_checkpoint=dino_deitsmall8_pretrain_full_checkpoint.pth \
  --lr__warmup_epochs__minlr="[0.0005, 10, 1e-06]" \
  --momentum_teacher=0.9995 \
  --norm_last_layer=False \
  --output_dir=out \
  --patch_size=8 \
  --random_head=1 \
  --teacher_temp__warmup_teacher_temp_epochs="[0.04, 0]" \
  --warmup_teacher_temp=0.04 \
  --weight_decay__weight_decay_end="[0.04, 0.4]"
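The double-underscore flags pack several hyperparameters into one list. For example, --lr__warmup_epochs__minlr="[0.0005, 10, 1e-06]" appears to bundle the learning rate, warmup epochs, and minimum learning rate, which match the defaults of DINO's --lr, --warmup_epochs, and --min_lr. DINO schedules these with a linear warmup followed by a per-iteration cosine decay. The sketch below reproduces that shape in plain Python with our own naming; how main_dino_adapter.py unpacks the list is an assumption.

```python
import json
import math
from typing import List


def cosine_schedule(base: float, final: float, epochs: int, iters_per_epoch: int,
                    warmup_epochs: int = 0, start_warmup: float = 0.0) -> List[float]:
    """Per-iteration values: linear warmup to `base`, then cosine decay towards `final`."""
    warmup_iters = warmup_epochs * iters_per_epoch
    total_iters = epochs * iters_per_epoch
    values = []
    for it in range(total_iters):
        if it < warmup_iters:
            values.append(start_warmup + (base - start_warmup) * it / warmup_iters)
        else:
            progress = (it - warmup_iters) / (total_iters - warmup_iters)
            values.append(final + 0.5 * (base - final) * (1 + math.cos(math.pi * progress)))
    return values


# Assumed layout of the packed flag: [lr, warmup_epochs, min_lr].
lr, warmup_epochs, min_lr = json.loads("[0.0005, 10, 1e-06]")
schedule = cosine_schedule(lr, min_lr, epochs=100, iters_per_epoch=50, warmup_epochs=warmup_epochs)
```

The weight-decay pair [0.04, 0.4] can be read the same way: a cosine schedule that moves from the first value to the second over training.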

MAE with Adapter

Download the MAE ImageNet-1K pretrained ViT-B/16 full weights:

wget https://dl.fbaipublicfiles.com/mae/pretrain/mae_pretrain_vit_base_full.pth

Continue pretraining with MAE Adapter:

torchrun main_pretrain_adapter.py \
--accum_iter=1 \
--adapter_ffn_scalar=1 \
--blr__min_lr__warmup_epochs="[0.001, 0, 40]" \
--data_path=datasets/camelyon16/single/fold1_200shot \
--epochs=400 \
--full_checkpoint=mae_pretrain_vit_base_full.pth \
--norm_pix_loss=0 \
--train_linears__linears_from_scratch="[1, 1]"
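In the MAE codebase, the base learning rate (--blr) is converted to an absolute learning rate with the linear scaling rule lr = blr × effective batch size / 256, where the effective batch size is batch_size × accum_iter × number of GPUs. Assuming main_pretrain_adapter.py keeps this rule and MAE's default --batch_size of 64 (the command above does not set it), the arithmetic for the blr of 0.001 packed into --blr__min_lr__warmup_epochs looks like this:

```python
def mae_absolute_lr(blr: float, batch_size: int, accum_iter: int, world_size: int) -> float:
    """Linear scaling rule used by MAE: lr = blr * effective_batch_size / 256."""
    eff_batch_size = batch_size * accum_iter * world_size
    return blr * eff_batch_size / 256


# Assumed values: MAE's default batch size (64), --accum_iter=1 from the command, a single GPU.
lr = mae_absolute_lr(blr=0.001, batch_size=64, accum_iter=1, world_size=1)
print(lr)
```

This works out to about 2.5e-4. Adding GPUs or raising accum_iter increases the effective learning rate proportionally, which is why MAE exposes the batch-size-independent --blr.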

Feature Extraction

The compute_feats.py script extracts features (embeddings) from a dataset using a specified embedder model. It processes the dataset and saves the cleaned embedder weights, feature vectors, and corresponding labels.

Input Dataset Structure

The dataset is expected to follow this directory structure:

datasets/
└── {dataset_name}/
    ├── single/
    │   └── {fold}/
    │       ├── train/
    │       ├── validation/
    │       └── test/
    └── tile_label.csv
  • {dataset_name}: The name of your dataset.
  • {fold}: The specific fold of data (e.g., fold1, fold2, ...).
  • train/, validation/, test/: Directories containing the patches for training, validation, and testing, respectively.
  • tile_label.csv: CSV file containing the patch labels, if available, created by the deepzoom_tiler scripts.

Output Directory Structure

The script saves the outputs in the following directory structure:

embeddings/
└── {embedder}_{version_name}/
    └── {dataset_name}/
        ├── embedder.pth
        ├── {train, test, validation}/
        │   ├── {0_normal, 1_tumor}.csv
        │   └── {0_normal, 1_tumor}/
        │       └── {slide_name}.csv
        └── {dataset_name}.csv
  • {embedder}: The name of the embedder model used (e.g., SimCLR).
  • {version_name}: The version name of the embedder model.
  • {dataset_name}: The name of the dataset.
  • embedder.pth: The cleaned embedder weights.
  • {slide_name}.csv: CSV file containing [feature_0, ..., feature_{d-1}, position, label] for each slide, where d is the embedding size (e.g., feature_0, ..., feature_511 for the 512-dimensional ResNet-18 SimCLR embedder). Each row corresponds to a patch from the slide.
  • {split}/{class_name}.csv: CSV file containing [bag_path, bag_label] for each class in each split (train/validation/test).
  • {dataset_name}.csv: CSV file containing [bag_path, bag_label] for the whole dataset.
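To inspect the extracted features, each {slide_name}.csv can be read back as a bag: every column except the last two is a feature, followed by position and label. The sketch below uses only the standard library; it assumes the file has a single header row, and the helper names are hypothetical.

```python
import csv
from typing import Iterable, List, Tuple

Bag = Tuple[List[List[float]], List[str], List[str]]


def parse_bag(lines: Iterable[str]) -> Bag:
    """Split rows of [feature_0, ..., feature_{d-1}, position, label] into features, positions, labels."""
    reader = csv.reader(lines)
    next(reader, None)  # skip the (assumed) header row
    features, positions, labels = [], [], []
    for row in reader:
        *values, position, label = row
        features.append([float(v) for v in values])
        positions.append(position)
        labels.append(label)
    return features, positions, labels


def load_bag(csv_path: str) -> Bag:
    """Read one slide's CSV from disk."""
    with open(csv_path, newline="") as f:
        return parse_bag(f)
```

The feature rows form an n × d matrix (one row per patch); d should equal the --feats_size later passed to train.py.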

Usage on CAMELYON16

SimCLR from scratch

python compute_feats.py \
  --backbone=resnet18 \
  --norm_layer=instance \
  --weights=embedders/dsmil_simclr.pth \
  --embedder=SimCLR \
  --version_name=dsmil_simclr

DINO from scratch

python compute_feats.py \
  --embedder=DINO \
  --num_classes=2048 \
  --backbone=vit_small \
  --weights=embedders/dino_scratch.pth \
  --version_name=dino_scratch

DINO with Adapter

python compute_feats.py \
  --embedder=DINO \
  --num_classes=2048 \
  --backbone=vit_small \
  --patch_size=8 \
  --weights=embedders/dino_adapter.pth \
  --ffn_num=32 \
  --adapter_ffn_scalar=10 \
  --version_name=dino_adapter \
  --use_adapter \
  --transform 1

MAE with Adapter

python compute_feats.py \
  --embedder=MAE \
  --num_classes=512 \
  --backbone=mae_vit_base_patch16 \
  --weights=embedders/mae_adapter.pth \
  --ffn_num=64 \
  --adapter_ffn_scalar=1 \
  --version_name=mae_adapter \
  --use_adapter \
  --transform 1

Usage on TCGA Lung

SimCLR from scratch

python compute_feats.py \
  --backbone=resnet18 \
  --dataset=tcga \
  --norm_layer=instance \
  --weights=embedders/dsmil_simclr_tcga.pth \
  --embedder=SimCLR \
  --version_name=dsmil_simclr

MIL Training

Example Run for CAMELYON16

DINO from scratch

python train.py \
  --activation=relu \
  --arch=snuffy \
  --betas="[0.9, 0.999]" \
  --big_lambda=900 \
  --dataset=camelyon16 \
  --embedding=DINO_dino_scratch \
  --encoder_dropout=0.1 \
  --feats_size=384 \
  --l2normed_embeddings=1 \
  --lr=0.02 \
  --num_epochs=200 \
  --num_heads=4 \
  --optimizer=adamw \
  --random_patch_share=0.7777777777777778 \
  --scheduler=cosine \
  --single_weight__lr_multiplier=1 \
  --soft_average=0 \
  --weight_decay=0.05 \
  --weight_init__weight_init_i__weight_init_b="['trunc_normal', 'xavier_uniform', 'trunc_normal']"

DINO with Adapter

python train.py \
  --activation=relu \
  --arch=snuffy \
  --betas="[0.9, 0.999]" \
  --big_lambda=500 \
  --dataset=camelyon16 \
  --embedding=DINO_dino_adapter \
  --encoder_dropout=0.1 \
  --feats_size=384 \
  --l2normed_embeddings=1 \
  --lr=0.02 \
  --num_epochs=200 \
  --num_heads=4 \
  --optimizer=adamw \
  --random_patch_share=0.5 \
  --scheduler=cosine \
  --single_weight__lr_multiplier=1 \
  --soft_average=1 \
  --weight_decay=0.05 \
  --weight_init__weight_init_i__weight_init_b="['trunc_normal', 'xavier_uniform', 'trunc_normal']"

MAE with Adapter

python train.py \
  --activation=relu \
  --arch=snuffy \
  --betas="[0.9, 0.999]" \
  --big_lambda=500 \
  --dataset=camelyon16 \
  --embedding=MAE_mae_adapter \
  --encoder_dropout=0 \
  --feats_size=768 \
  --l2normed_embeddings=0 \
  --lr=0.02 \
  --num_epochs=200 \
  --num_heads=4 \
  --optimizer=adamw \
  --random_patch_share=0.5 \
  --scheduler=cosine \
  --single_weight__lr_multiplier=1 \
  --soft_average=1 \
  --weight_decay=0.05 \
  --weight_init__weight_init_i__weight_init_b="['trunc_normal', 'xavier_uniform', 'trunc_normal']"

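--feats_size must match the embedding size produced during Feature Extraction (e.g., 384 for the ViT-S DINO embedders and 768 for the ViT-B MAE embedder, as in the examples above). Of the --big_lambda selected patches, --random_patch_share × --big_lambda are chosen at random and the rest are top patches.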

For TCGA use --arch=snuffy_multiclass.
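For the examples above, the split of --big_lambda into top and random patches works out as follows. The helpers are a hypothetical sketch: rounding the random share and ranking patches by a generic per-patch score are assumptions, not train.py's exact code.

```python
import random
from typing import List, Sequence, Tuple


def split_lambda(big_lambda: int, random_patch_share: float) -> Tuple[int, int]:
    """Return (number of top patches, number of random patches)."""
    n_random = round(big_lambda * random_patch_share)  # rounding is an assumption
    return big_lambda - n_random, n_random


def select_global_patches(scores: Sequence[float], big_lambda: int,
                          random_patch_share: float, seed: int = 0) -> List[int]:
    """Hypothetical selection: highest-scoring patches first, then a random draw from the rest."""
    n_top, n_random = split_lambda(big_lambda, random_patch_share)
    ranked = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)
    top, rest = ranked[:n_top], ranked[n_top:]
    return top + random.Random(seed).sample(rest, min(n_random, len(rest)))


print(split_lambda(900, 0.7777777777777778))  # (200, 700): DINO-from-scratch example
print(split_lambda(500, 0.5))                 # (250, 250): adapter examples
```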

Example Run for MIL Datasets

python train.py \
  --arch=snuffy \
  --dataset=musk1 \
  --num_heads=2 \
  --cv_num_folds 10 \
  --cv_valid_ratio 0.2 \
  --cv_current_fold 1

Notes:

  1. Feature Size is automatically set based on the dataset ('musk1' and 'musk2': 166, 'elephant': 230). No manual adjustment needed.
  2. MultiHeadAttention: Ensure the feature size is divisible by the number of heads.
  3. Cross-Validation: Use mil_cross_validation.py to generate a shuffle file ({dataset_file_name}_{num_folds}folds_{valid_ratio}split.pkl, e.g. musk1_10folds_0.2split.pkl). Set --cv_num_folds and --cv_valid_ratio in train.py to the values used when generating the file so that it is read correctly, and select the fold to train with --cv_current_fold.
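Two of these constraints are easy to check before launching a run. The helpers below are hypothetical; the file-name pattern is copied from note 3, and the divisibility check reflects the standard multi-head attention requirement that each head receives feats_size / num_heads dimensions.

```python
def shuffle_file_name(dataset: str, num_folds: int, valid_ratio: float) -> str:
    """Build the cross-validation file name using the pattern from note 3."""
    return f"{dataset}_{num_folds}folds_{valid_ratio}split.pkl"


def head_dim(feats_size: int, num_heads: int) -> int:
    """Per-head dimension; multi-head attention needs feats_size to be divisible by num_heads."""
    if feats_size % num_heads != 0:
        raise ValueError(f"feats_size={feats_size} is not divisible by num_heads={num_heads}")
    return feats_size // num_heads


print(shuffle_file_name("musk1", 10, 0.2))  # musk1_10folds_0.2split.pkl
print(head_dim(166, 2), head_dim(230, 2))   # 83 115
```

Since 166 is not divisible by 4, Musk1 and Musk2 cannot use --num_heads=4; the example run above uses --num_heads=2.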

Visualization

In the figure below, the black line outlines the tumor area. The model's attention is represented by a color overlay, where red indicates the highest attention and blue indicates the lowest. As shown, the model effectively highlights the tumor regions.

To create heatmaps similar to the one shown above, run the following command:

python roi.py \
  --batch_size=512 \
  --num_workers=24 \
  --embedder_weights=embedders/clean/camelyon16/SimCLR/embedder.pth \
  --aggregator_weights=aggregators/snuffy_simclr_dsmil.pth \
  --thres_tumor=0.75959325 \
  --num_heads=2 \
  --encoder_dropout=0.2 \
  --k=900 \
  --random_patch_share=0.7777777777777778 \
  --activation=gelu \
  --depth=5

The script requires the following inputs:

  • --embedder_weights: Path to the embedder weights file
  • --aggregator_weights: Path to the aggregator weights file
  • Ground truth masks located in datasets/camelyon16/masks/
  • Raw TIFF slides located in datasets/camelyon16/1_tumor/
  • Name and label of slides located in datasets/camelyon16/reference.csv

For each slide, the script generates the following outputs:

  • Heatmaps saved in roi_output/{slide_name}/cmaps/, where:
    • jet_slide.png is the raw slide.
    • jet.png is the slide with the attention map overlay and the ground truth tumor region outlined in black.

By default, the script processes 3 slides from the CAMELYON16 test set, but you can customize the slides to process by modifying the script. Additionally, reducing the DPI setting can speed up processing.
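The overlay comes down to placing each patch's attention score at its grid position and rescaling the scores to [0, 1] before a colormap (jet, judging by the output file names) maps low values to blue and high values to red. Below is a minimal sketch of the rescaling step with hypothetical names, assuming integer (x, y) patch coordinates; roi.py may normalize differently.

```python
from typing import List, Optional, Sequence, Tuple


def attention_grid(patches: Sequence[Tuple[int, int, float]],
                   width: int, height: int) -> List[List[Optional[float]]]:
    """Place per-patch attention scores on a (row=y, col=x) grid, min-max scaled to [0, 1].

    Cells without a tissue patch stay None so they can be left transparent in the overlay.
    """
    scores = [s for _, _, s in patches]
    lo, hi = min(scores), max(scores)
    span = (hi - lo) or 1.0  # avoid division by zero when all scores are equal
    grid: List[List[Optional[float]]] = [[None] * width for _ in range(height)]
    for x, y, s in patches:
        grid[y][x] = (s - lo) / span
    return grid


# Hypothetical attention scores for three patches on a 2 x 2 grid.
grid = attention_grid([(0, 0, 2.0), (1, 0, 4.0), (1, 1, 3.0)], width=2, height=2)
```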

You can download the aggregator used for creating the figure above from here.

Acknowledgments

This codebase is built upon the work of DSMIL, DINO, and MAE. We extend our gratitude to the authors for their valuable contributions.

Citation

If you find our work helpful for your research, please consider giving a star to this repository and citing the following BibTeX entry.

@misc{jafarinia2024snuffyefficientslideimage,
      title={Snuffy: Efficient Whole Slide Image Classifier}, 
      author={Hossein Jafarinia and Alireza Alipanah and Danial Hamdi and Saeed Razavi and Nahal Mirzaie and Mohammad Hossein Rohban},
      year={2024},
      eprint={2408.08258},
      archivePrefix={arXiv},
      primaryClass={cs.CV},
      url={https://arxiv.org/abs/2408.08258}, 
}