GraFT: Gradual Fusion Transformer for Multimodal Re-Identification

Nano1337/GraFT

arXiv (coming soon)

Official PyTorch implementation and pre-trained models for GraFT: Gradual Fusion Transformer for Multimodal Re-Identification

We introduce the Gradual Fusion Transformer (GraFT), a model designed for multimodal object re-identification (ReID). Traditional ReID models scale poorly as modalities are added because they rely heavily on late fusion, which delays the merging of information from different modalities. GraFT addresses this with learnable fusion tokens that guide self-attention across encoders, capturing both modality-specific and object-centric features. GraFT is also trained with a new training paradigm and an augmented triplet loss, which refine the feature embedding space for ReID tasks. Our ablation studies empirically validate the architectural design choices, and GraFT consistently outperforms prior methods on established multimodal ReID benchmarks.
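To make the fusion-token idea concrete, here is a minimal PyTorch sketch. It is not the code in this repository: the class name `FusionTokenSketch`, the sequential hand-off of tokens from one encoder to the next, the layer sizes, and the choice of three modalities are all assumptions made for illustration. It only shows one plausible way that shared learnable tokens can take part in self-attention inside each modality encoder.

```python
import torch
from torch import nn


class FusionTokenSketch(nn.Module):
    """Hypothetical illustration: learnable fusion tokens passed through per-modality encoders."""

    def __init__(self, dim=64, num_modalities=3, num_fusion_tokens=1, num_heads=4):
        super().__init__()
        self.num_fusion_tokens = num_fusion_tokens
        # Learnable tokens shared by all modalities (assumed design, not from the repo).
        self.fusion_tokens = nn.Parameter(torch.zeros(1, num_fusion_tokens, dim))
        nn.init.trunc_normal_(self.fusion_tokens, std=0.02)
        # One small transformer encoder layer per modality.
        self.encoders = nn.ModuleList([
            nn.TransformerEncoderLayer(
                d_model=dim, nhead=num_heads, dim_feedforward=4 * dim, batch_first=True
            )
            for _ in range(num_modalities)
        ])

    def forward(self, modality_tokens):
        # modality_tokens: list of tensors, each of shape (batch, seq_len, dim).
        batch = modality_tokens[0].shape[0]
        fusion = self.fusion_tokens.expand(batch, -1, -1)
        for encoder, tokens in zip(self.encoders, modality_tokens):
            # Fusion tokens attend jointly with this modality's tokens.
            x = encoder(torch.cat([fusion, tokens], dim=1))
            # Carry the updated fusion tokens on to the next modality's encoder.
            fusion = x[:, : self.num_fusion_tokens]
        # Pool the fusion tokens into a single embedding per sample.
        return fusion.mean(dim=1)
```

In this sketch, calling the module with a list of three `(batch, seq_len, dim)` tensors returns one `(batch, dim)` embedding per sample, which could then be fed to a metric-learning loss such as a triplet loss.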

[Image: Fig2 Final Image]

Datasets and Results

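We benchmarked GraFT against other methods on the RGBNT100 and RGBN300 datasets. Params is the parameter count, mAP is mean average precision, and R1, R5, and R10 are rank-1, rank-5, and rank-10 accuracy. Results are shown in the following table: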

Method              RGBNT100                            RGBN300
                    Params  mAP   R1    R5    R10       Params  mAP   R1    R5    R10
HAMNet              78M     65.4  85.5  87.9  88.8      52M     61.9  84.0  86.0  87.0
DANet               78M     N/A   N/A   N/A   N/A       52M     71.0  89.9  90.9  91.5
GAFNet              130M    74.4  93.4  94.5  95.0      130M    72.7  91.9  93.6  94.2
Multi-Stream ViT    274M    74.6  91.3  92.8  93.5      187M    73.7  91.9  94.1  94.8
GraFT (Ours)        101M    76.6  94.3  95.3  96.0      97M     75.1  92.1  94.5  95.2
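For readers unfamiliar with the metrics in the table, the sketch below computes mAP and rank-k accuracy from ranked retrieval results using their standard definitions. It is an illustration only, not the evaluation code used for the reported numbers. The function names are hypothetical, and details that ReID protocols often add (such as filtering gallery items by camera) are left out.

```python
def average_precision(ranked_matches):
    """AP for one query.

    ranked_matches: list of bools for the gallery sorted by descending
    similarity; True means the gallery item has the query's identity.
    """
    hits = 0
    precision_sum = 0.0
    for rank, is_match in enumerate(ranked_matches, start=1):
        if is_match:
            hits += 1
            precision_sum += hits / rank  # precision at this recall point
    return precision_sum / hits if hits else 0.0


def rank_k_hit(ranked_matches, k):
    """True if at least one correct match appears in the top k results."""
    return any(ranked_matches[:k])


def evaluate(all_ranked_matches, ks=(1, 5, 10)):
    """Return mAP and rank-k accuracy (both in percent) over all queries."""
    n = len(all_ranked_matches)
    results = {"mAP": 100 * sum(average_precision(r) for r in all_ranked_matches) / n}
    for k in ks:
        results[f"R{k}"] = 100 * sum(rank_k_hit(r, k) for r in all_ranked_matches) / n
    return results


# Two toy queries, each ranked against a gallery of four items.
toy = [
    [True, False, True, False],   # AP = (1/1 + 2/3) / 2
    [False, False, True, False],  # AP = 1/3
]
print(evaluate(toy))
```

In the toy example, only the first query has a correct match at rank 1, so R1 is 50%, while both queries have a match within the top 5, so R5 is 100%.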
[Image: Pareto Image]
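The trade-off between model size and accuracy can also be checked directly from the table. The sketch below assumes the comparison of interest is parameter count (lower is better) against mAP (higher is better); that reading, and the helper name `pareto_front`, are assumptions. DANet is omitted for RGBNT100 because its results are listed as N/A.

```python
# (params in millions, mAP) copied from the results table above.
RGBNT100 = {
    "HAMNet": (78, 65.4),
    "GAFNet": (130, 74.4),
    "Multi-Stream ViT": (274, 74.6),
    "GraFT": (101, 76.6),
}
RGBN300 = {
    "HAMNet": (52, 61.9),
    "DANet": (52, 71.0),
    "GAFNet": (130, 72.7),
    "Multi-Stream ViT": (187, 73.7),
    "GraFT": (97, 75.1),
}


def pareto_front(models):
    """Return the models that no other model beats on both size and mAP."""
    front = []
    for name, (params, m_ap) in models.items():
        dominated = any(
            other != name
            and p <= params          # no larger
            and a >= m_ap            # no less accurate
            and (p < params or a > m_ap)  # strictly better on at least one axis
            for other, (p, a) in models.items()
        )
        if not dominated:
            front.append(name)
    return sorted(front)


print("RGBNT100:", pareto_front(RGBNT100))
print("RGBN300:", pareto_front(RGBN300))
```

With the table values, GraFT is on the front for both datasets, alongside the smallest baseline in each case (HAMNet on RGBNT100, DANet on RGBN300).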

Catalog

  • Release Pre-trained models

Setup

  1. Clone from the correct branch

    git clone <see HTTPS or SSH option>
  2. Create python venv

    python -m venv venv
  3. Activate venv

    source venv/bin/activate
  4. Install requirements

    pip install -r requirements.txt
  5. Install Hugging Face Transformers from source to use DeiT

    pip install git+https://github.com/huggingface/transformers
  6. Check that train_optuna.sh points to the correct configs/yaml file, then run it

    sh ./train_optuna.sh
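After the setup steps, a quick sanity check can confirm that the virtual environment is active and that the key packages are importable. This helper is not part of the repository, and the function names are hypothetical. The package list assumes `torch` is installed by requirements.txt, since this is a PyTorch implementation.

```python
import importlib.util
import sys


def in_virtualenv():
    """True when the interpreter is running inside a venv."""
    return sys.prefix != getattr(sys, "base_prefix", sys.prefix)


def missing_packages(names=("torch", "transformers")):
    """Return the package names that cannot be found by the import system."""
    return [name for name in names if importlib.util.find_spec(name) is None]


print("virtualenv active:", in_virtualenv())
print("missing packages:", missing_packages())
```

If the first line prints False, activate the environment with `source venv/bin/activate` before launching training.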

Pre-Experiment Checklist

  • Check configs
  • Check paths: output_dir, dataroot, dataset, ckpt_dir
    use_optuna: True

    gpus: [0, 1, 2, 3]

    model_name: "transformer_baseline_reid_v2"

    # Weights and Biases Logging
    use_wandb: True
    wandb_project: "mm-mafia-reid-baseline" # generally stays same
    study_name: "transformer_baseline_v2_rn100_param=5m_no_downsampling_patch=24" # experiment level
    wandb_run_name: "transformer_baseline_v2_rn100_param=5m_no_downsampling_patch=24" # keep same as study_name
    wandb_trial_name: "elarger patch size=64, lower seq_len=4, e-5-6lr, transformer_encoder=3" # trial_name under study
  • If using Optuna for hyperparameter search:

    • If use_optuna is True, make sure to run train_optuna.py
      • Check the lr and weight decay ranges specified at the beginning of train_optuna.py
  • Activate virtual environment

    source venv/bin/activate
  • Run job scheduler interface (optional)

    python webapp_ui/app.py
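The checklist above can be partly automated. The sketch below is not part of the repository: it assumes the YAML config has already been loaded into a plain dict (for example with a YAML parser), and the function `check_config` and its `entry_script` argument are hypothetical. The key names are the ones listed in the checklist.

```python
REQUIRED_PATH_KEYS = ("output_dir", "dataroot", "dataset", "ckpt_dir")


def check_config(cfg, entry_script):
    """Return a list of human-readable problems found in a config dict."""
    problems = []
    # Every path-related setting from the checklist must be present and non-empty.
    for key in REQUIRED_PATH_KEYS:
        if not cfg.get(key):
            problems.append(f"missing or empty path setting: {key}")
    # use_optuna: True should go together with train_optuna.py.
    if cfg.get("use_optuna") and entry_script != "train_optuna.py":
        problems.append("use_optuna is True but entry script is not train_optuna.py")
    # The config comment says wandb_run_name should stay the same as study_name.
    if cfg.get("use_wandb") and cfg.get("wandb_run_name") != cfg.get("study_name"):
        problems.append("wandb_run_name differs from study_name")
    return problems


# Example with placeholder values (the paths here are made up for illustration).
example = {
    "use_optuna": True,
    "use_wandb": True,
    "study_name": "my_study",
    "wandb_run_name": "my_study",
    "output_dir": "out/",
    "dataroot": "data/",
    "dataset": "RGBNT100",
    "ckpt_dir": "ckpt/",
}
print(check_config(example, "train_optuna.py"))
```

An empty list means none of these checks failed. It does not verify that the paths exist on disk or that the Optuna search ranges are sensible.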
