arXiv (coming soon)
Official PyTorch implementation and pre-trained models for GraFT: Gradual Fusion Transformer for Multimodal Re-Identification
We introduce the Gradual Fusion Transformer (GraFT), a model tailored for Multimodal Object Re-Identification (ReID). Traditional ReID models scale poorly to multiple modalities because they rely heavily on late fusion, which delays merging information from the different modalities. GraFT addresses this with learnable fusion tokens that guide self-attention across encoders, capturing both modality-specific and object-centric features. Complementing its core design, GraFT uses a novel training paradigm and an augmented triplet loss that refine the feature embedding space for ReID tasks. Our extensive ablation studies empirically validate our architectural design choices, and GraFT consistently outperforms established methods on multimodal ReID benchmarks.
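The abstract mentions an augmented triplet loss but does not define the augmentation. As a point of reference, here is a minimal sketch of the standard margin-based triplet loss, max(0, d(a, p) - d(a, n) + margin), with Euclidean distance. The function names and the default margin of 0.3 are illustrative assumptions, not GraFT's actual loss or settings.

```python
import math
from typing import Sequence


def euclidean(u: Sequence[float], v: Sequence[float]) -> float:
    """Euclidean distance between two embedding vectors of equal length."""
    if len(u) != len(v):
        raise ValueError("embeddings must have the same dimension")
    return math.sqrt(sum((a - b) ** 2 for a, b in zip(u, v)))


def triplet_loss(anchor: Sequence[float], positive: Sequence[float],
                 negative: Sequence[float], margin: float = 0.3) -> float:
    """Standard hinge triplet loss (baseline sketch, NOT GraFT's augmented variant).

    The loss is zero once the positive is closer to the anchor than the
    negative by at least `margin`; otherwise it grows linearly.
    """
    return max(0.0, euclidean(anchor, positive) - euclidean(anchor, negative) + margin)


# Usage: anchor and positive share an identity, the negative does not
a = [0.0, 0.0]
p = [0.0, 1.0]  # distance 1.0 from the anchor
n = [3.0, 4.0]  # distance 5.0 from the anchor
print(triplet_loss(a, p, n))  # 0.0, since 1.0 - 5.0 + 0.3 < 0
```

A negative that sits inside the margin (for example at distance 1.2) yields a small positive loss, which is what pushes such "hard" negatives away during training.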
We benchmark GraFT against other methods on the RGBNT100 and RGBN300 datasets. The results are shown in the following table:
Method | RGBNT100 Params | RGBNT100 mAP | RGBNT100 R1 | RGBNT100 R5 | RGBNT100 R10 | RGBN300 Params | RGBN300 mAP | RGBN300 R1 | RGBN300 R5 | RGBN300 R10
---|---|---|---|---|---|---|---|---|---|---
HAMNet | 78M | 65.4 | 85.5 | 87.9 | 88.8 | 52M | 61.9 | 84.0 | 86.0 | 87.0
DANet | 78M | N/A | N/A | N/A | N/A | 52M | 71.0 | 89.9 | 90.9 | 91.5
GAFNet | 130M | 74.4 | 93.4 | 94.5 | 95.0 | 130M | 72.7 | 91.9 | 93.6 | 94.2
Multi-Stream ViT | 274M | 74.6 | 91.3 | 92.8 | 93.5 | 187M | 73.7 | 91.9 | 94.1 | 94.8
GraFT (Ours) | 101M | 76.6 | 94.3 | 95.3 | 96.0 | 97M | 75.1 | 92.1 | 94.5 | 95.2
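The table reports mean average precision (mAP) and Rank-k accuracy (R1, R5, R10). The sketch below shows, in plain Python, how these metrics are commonly computed from a query-by-gallery distance matrix. It is a simplified illustration under stated assumptions, not this repository's evaluation code: the function name `evaluate` is hypothetical, queries with no true match in the gallery are skipped, and the same-camera filtering that many ReID protocols apply is omitted.

```python
from typing import Dict, List, Sequence, Tuple


def evaluate(dist: Sequence[Sequence[float]], q_ids: Sequence[int],
             g_ids: Sequence[int],
             ks: Tuple[int, ...] = (1, 5, 10)) -> Tuple[float, Dict[int, float]]:
    """Return (mAP, {k: Rank-k accuracy}) for a query x gallery distance matrix.

    Hypothetical, simplified helper: no camera-based filtering of gallery items.
    """
    hits = {k: 0 for k in ks}
    aps: List[float] = []
    for qi, row in enumerate(dist):
        # Rank gallery items from closest to farthest for this query
        order = sorted(range(len(row)), key=lambda gi: row[gi])
        matches = [g_ids[gi] == q_ids[qi] for gi in order]
        if not any(matches):
            continue  # no true match in the gallery: skip this query
        first = matches.index(True)  # 0-based rank of the first correct match
        for k in ks:
            if first < k:
                hits[k] += 1
        # Average precision: mean of precision@i at every position i holding a true match
        n_rel, precisions = 0, []
        for i, is_match in enumerate(matches, start=1):
            if is_match:
                n_rel += 1
                precisions.append(n_rel / i)
        aps.append(sum(precisions) / n_rel)
    n_valid = len(aps)
    mean_ap = sum(aps) / n_valid if n_valid else 0.0
    cmc = {k: (hits[k] / n_valid if n_valid else 0.0) for k in ks}
    return mean_ap, cmc


# Usage: 2 queries against a gallery of 3 images
dist = [[0.1, 0.5, 0.9],
        [0.2, 0.8, 0.3]]
mean_ap, cmc = evaluate(dist, q_ids=[1, 2], g_ids=[1, 2, 2])
print(round(mean_ap, 4), cmc)  # 0.7917 {1: 0.5, 5: 1.0, 10: 1.0}
```

Rank-k only asks whether any correct match appears in the top k, while mAP rewards ranking all correct matches early, which is why a method can have high R10 but a noticeably lower mAP.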
To do:
- Release pre-trained models
Setup:

- Clone the repository from the correct branch:

      git clone <see HTTPS or SSH option>

- Create a Python virtual environment:

      python -m venv venv

- Activate the virtual environment:

      source venv/bin/activate

- Install the requirements:

      pip install -r requirements.txt
- Install Hugging Face Transformers from source (required for DeiT):

      pip install git+https://github.com/huggingface/transformers

- Check that train_optuna.sh points to the correct YAML file in configs/, then run:

      sh ./train_optuna.sh
- Check the configs:
  - Check the paths: output_dir, dataroot, dataset, ckpt_dir
  - Example settings:

        use_optuna: True
        gpus: [0, 1, 2, 3]
        model_name: "transformer_baseline_reid_v2"
        # Weights and Biases logging
        use_wandb: True
        wandb_project: "mm-mafia-reid-baseline" # generally stays the same
        study_name: "transformer_baseline_v2_rn100_param=5m_no_downsampling_patch=24" # experiment level
        wandb_run_name: "transformer_baseline_v2_rn100_param=5m_no_downsampling_patch=24" # keep the same as study_name
        wandb_trial_name: "larger patch size=64, lower seq_len=4, e-5-6lr, transformer_encoder=3" # trial name under the study
- If using Optuna for hyperparameter search:
  - If use_optuna is True, make sure to run train_optuna.py.
  - Check the learning-rate and weight-decay ranges specified at the beginning of train_optuna.py.
- Activate the virtual environment:

      source venv/bin/activate

- Run the job scheduler interface (optional):

      python webapp_ui/app.py
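Before launching train_optuna.sh, it can help to sanity-check the config settings described in the setup notes (the paths output_dir, dataroot, dataset, ckpt_dir, and the comment that wandb_run_name should match study_name). The sketch below assumes the YAML file has already been loaded into a Python dict (for example with PyYAML); the helper `check_config` and its exact rules are hypothetical and are not part of this repository.

```python
import os
from typing import Any, Dict, List

# Keys the setup notes ask you to verify (hypothetical list mirroring the README)
PATH_KEYS = ("output_dir", "dataroot", "dataset", "ckpt_dir")


def check_config(cfg: Dict[str, Any]) -> List[str]:
    """Return human-readable problems; an empty list means the config looks sane.

    Hypothetical pre-flight helper: rules are derived from the example config comments.
    """
    problems: List[str] = []
    for key in PATH_KEYS:
        if not cfg.get(key):
            problems.append(f"missing or empty: {key}")
    # Only dataroot must already exist; output/checkpoint dirs can be created later
    dataroot = cfg.get("dataroot")
    if dataroot and not os.path.isdir(dataroot):
        problems.append(f"dataroot does not exist: {dataroot}")
    gpus = cfg.get("gpus", [])
    if not isinstance(gpus, list) or not all(isinstance(g, int) for g in gpus):
        problems.append("gpus must be a list of integer device ids")
    # The example config says to keep wandb_run_name the same as study_name
    if cfg.get("use_wandb") and cfg.get("wandb_run_name") != cfg.get("study_name"):
        problems.append("wandb_run_name should match study_name")
    return problems


# Usage with a partially filled config (paths not yet set)
cfg = {"use_optuna": True, "gpus": [0, 1, 2, 3], "use_wandb": True,
       "study_name": "exp", "wandb_run_name": "exp"}
for problem in check_config(cfg):
    print(problem)
```

Returning a list instead of raising on the first error lets you fix every misconfigured field in one pass before a long multi-GPU run.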
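The Optuna notes mention learning-rate and weight-decay ranges at the top of train_optuna.py. Such ranges usually span several orders of magnitude, so they are typically sampled on a log scale (in Optuna, via `trial.suggest_float(name, low, high, log=True)`). The stdlib-only sketch below illustrates that idea with plain random search instead of Optuna; the bounds and function names are placeholders, not the ranges or code used in train_optuna.py.

```python
import math
import random
from typing import Dict

# Placeholder bounds (assumptions); the real ranges live at the top of train_optuna.py
LR_RANGE = (1e-6, 1e-4)
WD_RANGE = (1e-5, 1e-2)


def log_uniform(low: float, high: float, rng: random.Random) -> float:
    """Sample so that every order of magnitude in [low, high] is equally likely."""
    return math.exp(rng.uniform(math.log(low), math.log(high)))


def sample_hparams(rng: random.Random) -> Dict[str, float]:
    """Draw one hypothetical trial's learning rate and weight decay."""
    return {"lr": log_uniform(*LR_RANGE, rng),
            "weight_decay": log_uniform(*WD_RANGE, rng)}


rng = random.Random(0)  # fixed seed so the trials are reproducible
for trial in range(3):
    print(trial, sample_hparams(rng))
```

With a linear uniform draw over [1e-6, 1e-4], about 99% of samples would land above 1e-6 * 100 = 1e-4 / 1.01, effectively never exploring the smaller learning rates; the log scale spreads trials evenly across 1e-6, 1e-5, and 1e-4.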