-
Notifications
You must be signed in to change notification settings - Fork 19
/
train.py
executable file
·131 lines (106 loc) · 8.33 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
import argparse
from sim_matrix.fusion_scores import fusion_scores
from train_titles import train_titles
from train_video import train_video
import torch.distributed as dist
def get_args(description='CLIP4Clip on Retrieval Task', argv=None):
    """Build the command-line parser, parse the arguments and validate them.

    Args:
        description: Text shown at the top of the ``--help`` output.
        argv: Optional list of argument strings to parse instead of the
            process command line. ``None`` (the default) makes argparse
            read ``sys.argv[1:]``, which is the original behaviour.

    Returns:
        The parsed ``argparse.Namespace``. ``batch_size`` has already been
        divided by ``gradient_accumulation_steps``, and ``loose_type`` is
        forced to ``False`` when ``sim_header`` is ``"tightTransf"``.

    Raises:
        ValueError: If ``gradient_accumulation_steps`` is below 1, or if none
            of ``--do_train``, ``--do_eval`` and ``--do_test`` is given.
    """
    parser = argparse.ArgumentParser(description=description)

    # Run-mode switches.
    parser.add_argument("--do_pretrain", action='store_true', help="Whether to run pre-training.")
    parser.add_argument("--do_train", action='store_true', help="Whether to run training.")
    parser.add_argument("--do_eval", action='store_true', help="Whether to run eval on the dev set.")
    parser.add_argument("--do_test", action='store_true', help="Whether to run test on the dev set.")

    # Data locations and loading.
    parser.add_argument('--train_csv', type=str, default='data/.train.csv', help='')
    parser.add_argument('--val_csv', type=str, default='data/.val.csv', help='')
    parser.add_argument('--data_path', type=str, default='data/caption.pickle', help='data pickle file path')
    parser.add_argument('--features_path', type=str, default='data/videos_feature.pickle', help='feature path')
    parser.add_argument('--num_thread_reader', type=int, default=1, help='')

    # Optimisation hyper-parameters.
    parser.add_argument('--lr', type=float, default=0.0001, help='initial learning rate')
    parser.add_argument('--epochs', type=int, default=20, help='upper epoch limit')
    parser.add_argument('--batch_size', type=int, default=256, help='batch size')
    parser.add_argument('--batch_size_val', type=int, default=3500, help='batch size eval')
    parser.add_argument('--lr_decay', type=float, default=0.9, help='Learning rate exp epoch decay')
    parser.add_argument('--n_display', type=int, default=100, help='Information display frequence')
    parser.add_argument('--video_dim', type=int, default=1024, help='video feature dimension')
    parser.add_argument('--seed', type=int, default=42, help='random seed')
    parser.add_argument('--max_words', type=int, default=20, help='')
    parser.add_argument('--max_frames', type=int, default=100, help='')
    parser.add_argument('--feature_framerate', type=int, default=1, help='')
    parser.add_argument('--margin', type=float, default=0.1, help='margin for loss')
    parser.add_argument('--hard_negative_rate', type=float, default=0.5, help='rate of intra negative sample')
    parser.add_argument('--negative_weighting', type=int, default=1, help='Weight the loss for intra negative')
    parser.add_argument('--n_pair', type=int, default=1, help='Num of pair to output from data loader')

    # Output and checkpoints.
    parser.add_argument("--output_dir", default=None, type=str, required=True,
                        help="The output directory where the model predictions and checkpoints will be written.")
    parser.add_argument("--cross_model", default="cross-base", type=str, required=False, help="Cross module")
    parser.add_argument("--init_model", default=None, type=str, required=False, help="Initial model.")
    parser.add_argument("--resume_model", default=None, type=str, required=False, help="Resume train model.")
    parser.add_argument("--do_lower_case", action='store_true', help="Set this flag if you are using an uncased model.")
    # '%%' is required: argparse %-formats help strings when rendering --help.
    parser.add_argument("--warmup_proportion", default=0.1, type=float,
                        help="Proportion of training to perform linear learning rate warmup for. E.g., 0.1 = 10%% of training.")
    parser.add_argument('--gradient_accumulation_steps', type=int, default=1,
                        help="Number of updates steps to accumulate before performing a backward/update pass.")
    parser.add_argument('--n_gpu', type=int, default=1, help="Changed in the execute process.")

    parser.add_argument("--cache_dir", default="", type=str,
                        help="Where do you want to store the pre-trained models downloaded from s3")

    # Mixed precision.
    parser.add_argument('--fp16', action='store_true',
                        help="Whether to use 16-bit (mixed) precision (through NVIDIA apex) instead of 32-bit")
    parser.add_argument('--fp16_opt_level', type=str, default='O1',
                        help="For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']."
                             "See details at https://nvidia.github.io/apex/amp.html")

    parser.add_argument("--task_type", default="retrieval", type=str, help="Point the task `retrieval` to finetune.")
    parser.add_argument("--datatype", default="msrvtt", type=str, help="Point the dataset to finetune.")

    # Distributed training.
    parser.add_argument("--world_size", default=0, type=int, help="distribted training")
    parser.add_argument("--local_rank", default=0, type=int, help="distribted training")
    parser.add_argument("--rank", default=0, type=int, help="distribted training")

    parser.add_argument('--coef_lr', type=float, default=1., help='coefficient for bert branch.')
    parser.add_argument('--use_mil', action='store_true', help="Whether use MIL as Miech et. al. (2020).")
    parser.add_argument('--sampled_use_mil', action='store_true', help="Whether MIL, has a high priority than use_mil.")

    # Model architecture.
    parser.add_argument('--text_num_hidden_layers', type=int, default=12, help="Layer NO. of text.")
    parser.add_argument('--visual_num_hidden_layers', type=int, default=12, help="Layer NO. of visual.")
    parser.add_argument('--cross_num_hidden_layers', type=int, default=4, help="Layer NO. of cross.")

    parser.add_argument('--loose_type', action='store_true', help="Default using tight type for retrieval.")
    parser.add_argument('--expand_msrvtt_sentences', action='store_true', help="")

    parser.add_argument('--train_frame_order', type=int, default=0, choices=[0, 1, 2],
                        help="Frame order, 0: ordinary order; 1: reverse order; 2: random order.")
    parser.add_argument('--eval_frame_order', type=int, default=0, choices=[0, 1, 2],
                        help="Frame order, 0: ordinary order; 1: reverse order; 2: random order.")

    parser.add_argument('--freeze_layer_num', type=int, default=0, help="Layer NO. of CLIP need to freeze.")
    parser.add_argument('--slice_framepos', type=int, default=0, choices=[0, 1, 2],
                        help="0: cut from head frames; 1: cut from tail frames; 2: extract frames uniformly.")
    parser.add_argument('--linear_patch', type=str, default="2d", choices=["2d", "3d"],
                        help="linear projection of flattened patches.")
    parser.add_argument('--sim_header', type=str, default="meanP",
                        choices=["meanP", "seqLSTM", "seqTransf", "tightTransf", "seqTransf_topk"],
                        help="choice a similarity header.")

    parser.add_argument("--pretrained_clip_name", default="ViT-B/32", type=str, help="Choose a CLIP version")
    parser.add_argument("--strategy", default=2, type=int, help="Sampling strategies.")

    ### interaction
    parser.add_argument('--interaction', type=str, default='dp', help="interaction type for retrieval.")
    parser.add_argument('--wti_arch', type=int, default=2, help="select a architecture for weight branch")
    parser.add_argument('--text_pool_type', type=str, default='clip_top1')
    parser.add_argument("--k", default=1, type=int, help="topk caption.")
    parser.add_argument("--generate_images", default=None, type=str, help="generate images path")
    parser.add_argument("--freeze_text_encoder", action='store_true', help="whether freeze text encoder")

    # argv=None makes argparse fall back to sys.argv[1:].
    args = parser.parse_args(argv)

    # The tight transformer header overrides any --loose_type flag.
    if args.sim_header == "tightTransf":
        args.loose_type = False

    # Check parameters
    if args.gradient_accumulation_steps < 1:
        raise ValueError("Invalid gradient_accumulation_steps parameter: {}, should be >= 1".format(
            args.gradient_accumulation_steps))
    if not args.do_train and not args.do_eval and not args.do_test:
        raise ValueError("At least one of `do_train` or `do_eval` or 'do_test' must be True.")

    # Store the batch size divided by the number of accumulation steps.
    args.batch_size = int(args.batch_size / args.gradient_accumulation_steps)

    return args
def main():
    """Run the video stage, share its result across ranks, then the title stage.

    Steps, in order:
      1. Parse the command-line arguments.
      2. Call ``train_video`` and keep its return value.
      3. Broadcast that value from rank 0 so every process passes the same
         object to ``train_titles``.
      4. On rank 0 only, call ``fusion_scores``.

    NOTE(review): this relies on the default process group already being
    initialised by the time ``dist.broadcast_object_list`` runs, presumably
    inside ``train_video`` - confirm.
    """
    args = get_args()

    # broadcast_object_list works in place on a list, so the single value is
    # wrapped in a one-element list.
    shared_weight = [train_video(args)]
    print('before==', shared_weight)
    dist.broadcast_object_list(shared_weight, src=0)
    print('after===', shared_weight)

    train_titles(args, shared_weight[0])

    # Only rank 0 goes on to the score-fusion step.
    if dist.get_rank() != 0:
        return
    fusion_scores()
# Script entry point: run the pipeline only when this file is executed
# directly, not when it is imported as a module.
if __name__ == "__main__":
    main()