PyTorch implementation of "Learning to Rematch Mismatched Pairs for Robust Cross-Modal Retrieval" (CVPR 2024).
- Python 3.8
- torch 1.12
- numpy
- scikit-learn
- pomegranate (installation: pomegranate requires Cython 0.29, NumPy, SciPy, NetworkX, and joblib; then run python setup.py build and python setup.py install to install it.)
- Punkt Sentence Tokenizer:
import nltk
nltk.download('punkt')  # non-interactive equivalent of running nltk.download() and entering "d punkt"
We follow NCR to obtain image features and vocabularies.
We use the same noise index settings as DECL and RCL, which can be found in noise_index. The mismatching ratio (noise ratio) is set to 0.2, 0.4, 0.6, or 0.8.
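To make the noise settings concrete, here is a minimal sketch of what a noise index encodes: for each caption, the index of the image it is paired with, where a chosen fraction of pairs has been deliberately mismatched. It is hypothetical: the function name make_noise_index, the one-caption-per-image simplification, and the "permute the images of a random subset of captions" scheme are assumptions for illustration. It does not reproduce the files in noise_index, which should be used for results comparable with DECL and RCL.

```python
import random
from typing import List


def make_noise_index(num_pairs: int, noise_ratio: float, seed: int = 0) -> List[int]:
    """Hypothetical sketch: return t2i, where t2i[i] is the image paired with caption i.

    Assumes one caption per image (clean pairing: caption i <-> image i).
    A `noise_ratio` fraction of captions is sampled at random, and their image
    assignments are permuted among themselves to create mismatched pairs.
    """
    if not 0.0 <= noise_ratio <= 1.0:
        raise ValueError("noise_ratio must be in [0, 1]")
    rng = random.Random(seed)  # seeded so the same index is produced every run
    t2i = list(range(num_pairs))  # clean pairing
    num_noisy = int(noise_ratio * num_pairs)
    noisy_positions = rng.sample(range(num_pairs), num_noisy)
    shuffled_images = [t2i[p] for p in noisy_positions]
    rng.shuffle(shuffled_images)
    for pos, img in zip(noisy_positions, shuffled_images):
        t2i[pos] = img
    return t2i


if __name__ == "__main__":
    for ratio in (0.2, 0.4, 0.6, 0.8):
        t2i = make_noise_index(1000, ratio, seed=42)
        mismatched = sum(i != j for i, j in enumerate(t2i))
        # A random permutation can leave a few pairs in place, so the realized
        # mismatching ratio may be slightly below the nominal one.
        print(f"nominal={ratio:.1f} realized={mismatched / len(t2i):.3f}")
```

Because the subset is permuted rather than reassigned independently, every image is still used exactly once; this is one reason fixed, shared noise index files matter: two differently seeded sketches like this one would produce different mismatched pairs.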
To train, modify the necessary parameters in the corresponding script, then run it.
For Flickr30K:
sh train_f30k.sh
For MS-COCO:
sh train_coco.sh
For CC152K:
sh train_cc152k.sh
To test, modify the necessary parameters, then run:
python main_testing.py
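Cross-modal retrieval results are commonly reported as Recall@K (R@K), the percentage of queries whose ground-truth match appears among the top K results. The sketch below is a hypothetical illustration, not necessarily what main_testing.py computes: it assumes a square image-caption similarity matrix with exactly one ground-truth caption per image, whereas Flickr30K and MS-COCO pair each image with five captions.

```python
from typing import Dict, Sequence


def recall_at_k(sims: Sequence[Sequence[float]], ks: Sequence[int] = (1, 5, 10)) -> Dict[int, float]:
    """Image-to-text Recall@K (in percent) for a square similarity matrix.

    Assumes sims[i][j] is the similarity between image i and caption j, and that
    caption i is the only ground-truth match for image i.
    """
    n = len(sims)
    ranks = []
    for i, row in enumerate(sims):
        # 0-based rank of the ground-truth caption: number of captions scored
        # strictly higher (ties are resolved in favor of the ground truth).
        ranks.append(sum(1 for s in row if s > row[i]))
    return {k: 100.0 * sum(r < k for r in ranks) / n for k in ks}


if __name__ == "__main__":
    toy_sims = [
        [0.9, 0.1, 0.2],
        [0.3, 0.2, 0.8],
        [0.1, 0.4, 0.7],
    ]
    print(recall_at_k(toy_sims, ks=(1, 3)))
```

Text-to-image recall can be obtained the same way by passing the transposed similarity matrix.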
The pre-trained models are available here.
The code is based on SCAN, SGRAF, NCR, DECL, and KPG-RL, which are licensed under Apache 2.0.