Source code of our TCSVT'22 paper Reading-strategy Inspired Visual Representation Learning for Text-to-Video Retrieval

Reading-strategy Inspired Visual Representation Learning for Text-to-Video Retrieval

Source code of our paper Reading-strategy Inspired Visual Representation Learning for Text-to-Video Retrieval.


Environments

  • CUDA 10.1

  • Python 3.8.5

  • PyTorch 1.7.0

We used Anaconda to set up a deep learning workspace that supports PyTorch. Run the following script to install the required packages.

conda create --name rivrl_env python=3.8.5
conda activate rivrl_env
git clone https://github.com/LiJiaBei-7/rivrl.git
cd rivrl
pip install -r requirements.txt
conda deactivate
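Before running the setup script, it can help to confirm that the required tools are on your PATH. The `need` helper below is not part of the repository; it is a minimal sketch:

```shell
# Hypothetical preflight check (not part of the repo): verify that a command
# exists on PATH before running the setup script above.
need() {
    command -v "$1" >/dev/null 2>&1 || { echo "missing: $1" >&2; return 1; }
}

# Check the tools used by the setup script.
for cmd in conda git pip; do
    need "$cmd" || echo "please install $cmd first" >&2
done
```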

Required Data

We use three public datasets: MSR-VTT, VATEX, and TGIF. Please refer to here for a detailed description and instructions on downloading the datasets. Since our model uses additional BERT features, you should download the pre-extracted BERT features from Baidu pan (url, password:6knd). Alternatively, you can run the following script to download the BERT features; the extracted data is placed in $HOME/VisualSearch/.

ROOTPATH=$HOME/VisualSearch
mkdir -p $ROOTPATH && cd $ROOTPATH
mkdir bert_extract && cd bert_extract

# download the features of BERT
wget http://8.210.46.84:8787/rivrl/bert/<bert-Name>.tar.gz
tar zxf <bert-Name>.tar.gz -C $ROOTPATH
# <bert-Name> is msrvtt_bert, vatex_bert, and tgif_bert respectively.
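To fetch all three feature archives in one pass, the download-and-unpack steps above can be wrapped in a loop. The URL pattern is taken verbatim from the script above; the `DOWNLOAD` guard is our own addition so the snippet can be dry-run first:

```shell
ROOTPATH=$HOME/VisualSearch

# Build the download URL for one of the three BERT feature archives.
bert_url() {
    echo "http://8.210.46.84:8787/rivrl/bert/$1.tar.gz"
}

# Dry-run by default; set DOWNLOAD=1 to actually fetch and unpack.
for name in msrvtt_bert vatex_bert tgif_bert; do
    url=$(bert_url "$name")
    if [ "${DOWNLOAD:-0}" = "1" ]; then
        wget "$url" && tar zxf "$name.tar.gz" -C "$ROOTPATH"
    else
        echo "would download $url"
    fi
done
```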

RIVRL on MSRVTT10K

Model Training and Evaluation

Run the following script to train and evaluate the RIVRL network. It trains RIVRL and selects the checkpoint that performs best on the validation set as the final model; to save disk space, only that best-performing checkpoint is kept.

ROOTPATH=$HOME/VisualSearch

conda activate rivrl_env

# Train the model on MSR-VTT, using the resnext-101_resnet152-13k features
# Template:
./do_all_msrvtt.sh $ROOTPATH <split-Name> <useBert> <gpu-id>

# Example:
# Train RIVRL with the BERT on MV-Yu 
./do_all_msrvtt.sh $ROOTPATH msrvtt10kyu 1 0

<split-Name> indicates the partition of the dataset: msrvtt10kyu, msrvtt10kmiech, and msrvtt10k denote the MV-Yu, MV-Miech, and MV-Xu partitions, respectively. <useBert> indicates whether to train with BERT as an additional text feature: 1 means the BERT feature is used, while 0 means it is not. <gpu-id> is the index of the GPU to train on.
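To catch a mistyped <split-Name> before launching a long training run, a small lookup can validate the argument. This helper is our own sketch, not part of the repo; in particular, the assumption that msrvtt10k is the identifier for the MV-Xu partition should be checked against the scripts in the repository:

```shell
# Hypothetical helper (not in the repo): map a <split-Name> to the MSR-VTT
# partition it denotes, failing loudly on an unknown name.
split_to_partition() {
    case "$1" in
        msrvtt10kyu)    echo "MV-Yu" ;;
        msrvtt10kmiech) echo "MV-Miech" ;;
        msrvtt10k)      echo "MV-Xu" ;;   # assumed identifier for MV-Xu
        *)              echo "unknown split: $1" >&2; return 1 ;;
    esac
}

split_to_partition msrvtt10kyu   # prints MV-Yu
```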

Evaluation using Provided Checkpoints

Run the following script to download and evaluate our trained checkpoints. The trained checkpoints can also be downloaded from Baidu pan (url, password:wb3c).

ROOTPATH=$HOME/VisualSearch

# download trained checkpoints
wget -P $ROOTPATH http://8.210.46.84:8787/rivrl/best_model/msrvtt/<best_model>.pth.tar
# <best_model> is mv_yu_best, mv_yu_Bert_best, mv_miech_best, mv_miech_Bert_best, mv_xu_best, or mv_xu_Bert_best.
tar zxf $ROOTPATH/<best_model>.pth.tar -C $ROOTPATH

# evaluate on MSR-VTT
# Template:
./do_test.sh $ROOTPATH <split-Name> $MODELDIR <gpu-id>
# $MODELDIR is the path of checkpoints, $ROOTPATH/.../runs_0

# Example:
# evaluate on MV-Yu
./do_test.sh $ROOTPATH msrvtt10kyu $MODELDIR 0
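The intermediate part of $MODELDIR is left unspecified above ($ROOTPATH/.../runs_0). Assuming trained checkpoints end up in directories named runs_0, candidate values can be located with a search like this (our own sketch):

```shell
# List candidate MODELDIR values: every directory named runs_0 under ROOTPATH.
# Assumes the runs_0 naming convention shown in the comments above.
find_modeldirs() {
    find "$1" -type d -name runs_0
}

ROOTPATH=$HOME/VisualSearch
if [ -d "$ROOTPATH" ]; then
    find_modeldirs "$ROOTPATH"
fi
```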

Expected Performance

The expected performance and corresponding pre-trained checkpoints of RIVRL on MSR-VTT are as follows. Notice that due to random factors in SGD-based training, the numbers may differ slightly from those reported in the paper.

| Dataset | Split | BERT | R@1 | R@5 | R@10 | MedR | mAP | SumR | Pre-trained Checkpoint |
|---------|-------|------|-----|-----|------|------|-----|------|------------------------|
| MSR-VTT | MV-Yu | w/o | 24.2 | 51.5 | 63.8 | 5 | 36.86 | 139.5 | mv_yu_best.pth.tar |
| MSR-VTT | MV-Yu | with | 27.9 | 59.3 | 71.3 | 4 | 42.0 | 158.4 | mv_yu_Bert_best.pth.tar |
| MSR-VTT | MV-Miech | w/o | 25.3 | 53.6 | 67.0 | 4 | 38.5 | 145.9 | mv_miech_best.pth.tar |
| MSR-VTT | MV-Miech | with | 26.2 | 56.6 | 68.2 | 4 | 39.92 | 151.0 | mv_miech_Bert_best.pth.tar |
| MSR-VTT | MV-Xu | w/o | 12.9 | 33.0 | 44.6 | 14 | 23.07 | 90.5 | mv_xu_best.pth.tar |
| MSR-VTT | MV-Xu | with | 13.7 | 34.6 | 46.4 | 13 | 24.19 | 94.6 | mv_xu_Bert_best.pth.tar |
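SumR is simply R@1 + R@5 + R@10. For instance, for MV-Yu without BERT, a quick check with awk:

```shell
# SumR = R@1 + R@5 + R@10; values for MV-Yu without BERT.
sumr=$(awk 'BEGIN { printf "%.1f", 24.2 + 51.5 + 63.8 }')
echo "SumR = $sumr"   # SumR = 139.5
```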

RIVRL on VATEX

Model Training and Evaluation

Run the following script to train and evaluate RIVRL network on VATEX.

ROOTPATH=$HOME/VisualSearch

conda activate rivrl_env

# To train the model on the VATEX
./do_all_vatex.sh $ROOTPATH <useBert> <gpu-id>

Expected Performance

Run the following script to download and evaluate our trained model on VATEX. The trained checkpoints can also be downloaded from Baidu pan (url, password:wb3c).

ROOTPATH=$HOME/VisualSearch

# download trained checkpoints and evaluate 
wget -P $ROOTPATH http://8.210.46.84:8787/rivrl/best_model/vatex/<best_model>.pth.tar
# <best_model> is vatex_best or vatex_Bert_best
tar zxf $ROOTPATH/<best_model>.pth.tar -C $ROOTPATH

# evaluate on VATEX
./do_test.sh $ROOTPATH vatex $MODELDIR <gpu-id>
# $MODELDIR is the path of checkpoints, $ROOTPATH/.../runs_0

The expected performance and corresponding pre-trained checkpoints of RIVRL on VATEX are as follows.

| Dataset | BERT | R@1 | R@5 | R@10 | MedR | mAP | SumR | Pre-trained Checkpoint |
|---------|------|-----|-----|------|------|-----|------|------------------------|
| VATEX | w/o | 39.4 | 76.1 | 84.8 | 2 | 55.3 | 200.4 | vatex_best.pth.tar |
| VATEX | with | 39.1 | 76.7 | 85.4 | 2 | 55.4 | 201.0 | vatex_Bert_best.pth.tar |

RIVRL on TGIF

Model Training and Evaluation

Run the following script to train and evaluate RIVRL network on TGIF.

ROOTPATH=$HOME/VisualSearch

conda activate rivrl_env

# To train the model on the TGIF-Li
./do_all_tgif_li.sh $ROOTPATH <useBert> <gpu-id>

# To train the model on the TGIF-Chen
./do_all_tgif_chen.sh $ROOTPATH <useBert> <gpu-id>

Expected Performance

Run the following script to download and evaluate our trained model on TGIF. The trained checkpoints can also be downloaded from Baidu pan (url, password:wb3c).

ROOTPATH=$HOME/VisualSearch

# download trained checkpoints 
wget -P $ROOTPATH http://8.210.46.84:8787/rivrl/best_model/tgif/<best_model>.pth.tar
# <best_model> is tgif_li_best, tgif_li_Bert_best, tgif_chen_best, or tgif_chen_Bert_best.
tar zxf $ROOTPATH/<best_model>.pth.tar -C $ROOTPATH

# evaluate on the TGIF-Li
./do_test.sh $ROOTPATH tgif-li $MODELDIR <gpu-id>

# evaluate on the TGIF-Chen
./do_test.sh $ROOTPATH tgif-chen $MODELDIR <gpu-id>
# $MODELDIR is the path of checkpoints, $ROOTPATH/.../runs_0

The expected performance and corresponding pre-trained checkpoints of RIVRL on TGIF are as follows.

| Dataset | Split | BERT | R@1 | R@5 | R@10 | MedR | mAP | SumR | Pre-trained Checkpoint |
|---------|-------|------|-----|-----|------|------|-----|------|------------------------|
| TGIF | TGIF-Li | w/o | 11.3 | 25.3 | 33.6 | 34 | 18.7 | 70.3 | tgif_li_best.pth.tar |
| TGIF | TGIF-Li | with | 12.1 | 26.6 | 35.1 | 29 | 19.75 | 73.8 | tgif_li_Bert_best.pth.tar |
| TGIF | TGIF-Chen | w/o | 6.4 | 16.1 | 22.4 | 91 | 11.81 | 44.9 | tgif_chen_best.pth.tar |
| TGIF | TGIF-Chen | with | 6.8 | 17.2 | 23.5 | 79 | 12.45 | 47.4 | tgif_chen_Bert_best.pth.tar |

Reference

If you find the package useful, please consider citing our paper:

@article{dong2022reading,
  title={Reading-strategy Inspired Visual Representation Learning for Text-to-Video Retrieval},
  author={Dong, Jianfeng and Wang, Yabing and Chen, Xianke and Qu, Xiaoye and Li, Xirong and He, Yuan and Wang, Xun},
  journal={IEEE Transactions on Circuits and Systems for Video Technology},
  year={2022}
}
