-
Notifications
You must be signed in to change notification settings - Fork 13
/
Copy pathtrain.sh
34 lines (30 loc) · 946 Bytes
/
train.sh
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
#!/usr/bin/env bash
# Two-stage GPV training launcher.
#
# Usage: train.sh LEARNING_DATASETS DATA_SPLIT EXP_NAME OUTPUT_DIR DATA_DIR
#
#   LEARNING_DATASETS  passed through as learning_datasets=...
#   DATA_SPLIT         passed through as task_configs.data_split=...;
#                      "original_split" selects detr_coco.pth, any other
#                      value selects detr_coco_sce.pth
#   EXP_NAME           experiment name (exp_name=...)
#   OUTPUT_DIR         output root (output_dir=...)
#   DATA_DIR           data root (data_dir=...); DETR checkpoints are read
#                      from ${DATA_DIR}/detr/
#
# Stage 1 trains with training.freeze=True starting from a pretrained DETR
# checkpoint. Stage 2 resumes from the stage-1 checkpoint with
# training.freeze=False.

# Abort on the first failure so stage 2 never starts from a checkpoint that
# stage 1 failed to produce; -u catches typos in variable names.
set -euo pipefail

if (( $# < 5 )); then
  printf 'Usage: %s LEARNING_DATASETS DATA_SPLIT EXP_NAME OUTPUT_DIR DATA_DIR\n' \
    "${0##*/}" >&2
  exit 2
fi

LEARNING_DATASETS=$1
DATA_SPLIT=$2
EXP_NAME=$3
OUTPUT_DIR=$4
DATA_DIR=$5

# Choose the pretrained DETR weights that match the requested data split.
DETR_CKPT="${DATA_DIR}/detr/detr_coco_sce.pth"
if [[ "$DATA_SPLIT" == "original_split" ]]; then
  DETR_CKPT="${DATA_DIR}/detr/detr_coco.pth"
fi

# Stage 1: DETR components are frozen and the rest of the model weights are
# finetuned.
python -m exp.gpv.train_distr \
  exp_name="$EXP_NAME" \
  output_dir="$OUTPUT_DIR" \
  data_dir="$DATA_DIR" \
  learning_datasets="$LEARNING_DATASETS" \
  task_configs.data_split="$DATA_SPLIT" \
  model.pretr_detr="$DETR_CKPT" \
  training.freeze=True

# Path to the checkpoint saved by stage 1.
CKPT="${OUTPUT_DIR}/${EXP_NAME}/ckpts/model.pth"

# Stage 2: finetune the entire model, including the DETR weights.
python -m exp.gpv.train_distr \
  exp_name="$EXP_NAME" \
  output_dir="$OUTPUT_DIR" \
  data_dir="$DATA_DIR" \
  learning_datasets="$LEARNING_DATASETS" \
  task_configs.data_split="$DATA_SPLIT" \
  training.ckpt="$CKPT" \
  training.freeze=False