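#!/usr/bin/env bash
# train.sh -- fine-tune a pre-trained DialogVED model with fairseq.
# Expected invocation (inferred from the getopts flags handled below):
#   bash train.sh -p <pretrained_model_path> \
#                 -t <dialogved_standard|dialogved_large|dialogved_seq2seq> \
#                 -d <dailydialog|dstc7avsd|personachat>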
while getopts ":p:t:d:" opt
do
case $opt in
p)
PRETRAINED_MODEL_PATH="$OPTARG"
;;
t)
PRETRAINED_MODEL_TYPE="$OPTARG"
;;
d)
DATASET="$OPTARG"
;;
?)
echo "未知参数"
exit 1;;
esac
done
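# Assumes the script is launched from the repository root; all paths below are relative to it.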
PROJECT_PATH='.'
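# Map the requested model type onto its fairseq --arch name (presumably registered by the --user-dir plugin used below).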
if [ "$PRETRAINED_MODEL_TYPE" == "dialogved_standard" ]; then
echo '-------- model type: dialogved standard --------'
ARCH=ngram_transformer_prophet_vae_standard
elif [ "$PRETRAINED_MODEL_TYPE" == "dialogved_large" ]; then
echo '-------- model type: dialogved large --------'
ARCH=ngram_transformer_prophet_vae_large
elif [ "$PRETRAINED_MODEL_TYPE" == "dialogved_seq2seq" ]; then
echo '-------- model type: dialogved seq2seq --------'
ARCH=ngram_transformer_prophet_seq2seq
else
echo "model type ${PRETRAINED_MODEL_TYPE} not found!"
exit 1
fi
if [ "$DATASET" == "dailydialog" ]; then
echo '-------- fine-tune on dataset: dailydialog --------'
NUM_WORKERS=10
CRITERION=ved_loss
TASK=ved_translate
USER_DIR=${PROJECT_PATH}/src
DATA_DIR=${PROJECT_PATH}/data/finetune/dailydialog
SAVE_DIR=${DATA_DIR}/checkpoints
TB_LOGDIR=${DATA_DIR}/tensorboard
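# Key hyperparameters (flag names are taken from this script; the readings of the
# custom ved_loss weights are assumptions based on their names, not verified docs):
#   --update-freq 4 with --max-tokens 4500: gradients are accumulated over 4 steps,
#     for an effective batch of up to ~18k tokens per GPU per update.
#   --kl-loss-weight / --target-kl: weight and target value for the KL term of the
#     latent variable (a free-bits-style threshold, judging by the flag name).
#   --latent-bow-loss-weight: auxiliary bag-of-words prediction loss on the latent.
#   --load-from-pretrained-model: initializes weights from the checkpoint passed via -p.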
fairseq-train \
${DATA_DIR}/binary \
--fp16 \
--user-dir ${USER_DIR} --task ${TASK} --arch ${ARCH} \
--optimizer adam --adam-betas '(0.9, 0.98)' --clip-norm 1.0 \
--lr 0.0003 \
--lr-scheduler inverse_sqrt --warmup-init-lr 1e-07 --warmup-updates 2000 \
--criterion $CRITERION --label-smoothing 0.1 \
--update-freq 4 --max-tokens 4500 --max-sentences 16 \
--num-workers ${NUM_WORKERS} \
--dropout 0.1 --attention-dropout 0.1 --activation-dropout 0.0 --weight-decay 0.01 \
--encoder-layer-drop 0.0 \
--save-dir ${SAVE_DIR} \
--max-epoch 10 \
--keep-last-epochs 10 \
--max-source-positions 512 \
--max-target-positions 128 \
--kl-loss-weight 1.0 \
--target-kl 5.0 \
--cls-bow-loss-weight 0.0 \
--latent-bow-loss-weight 1.0 \
--masked-lm-loss-weight 0.0 \
--tensorboard-logdir ${TB_LOGDIR} \
--dataset-impl mmap \
--empty-cache-freq 64 \
--seed 1 \
--skip-invalid-size-inputs-valid-test \
--distributed-no-spawn \
--ddp-backend no_c10d \
--load-from-pretrained-model "${PRETRAINED_MODEL_PATH}"
elif [ "$DATASET" == "dstc7avsd" ]; then
echo '-------- fine-tune on dataset: dstc7avsd --------'
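# Same recipe as dailydialog; only the data, checkpoint, and log paths change.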
NUM_WORKERS=10
CRITERION=ved_loss
TASK=ved_translate
USER_DIR=${PROJECT_PATH}/src
DATA_DIR=${PROJECT_PATH}/data/finetune/dstc7avsd
SAVE_DIR=${DATA_DIR}/checkpoints
TB_LOGDIR=${DATA_DIR}/tensorboard
fairseq-train \
${DATA_DIR}/binary \
--fp16 \
--user-dir ${USER_DIR} --task ${TASK} --arch ${ARCH} \
--optimizer adam --adam-betas '(0.9, 0.98)' --clip-norm 1.0 \
--lr 0.0003 \
--lr-scheduler inverse_sqrt --warmup-init-lr 1e-07 --warmup-updates 2000 \
--criterion $CRITERION --label-smoothing 0.1 \
--update-freq 4 --max-tokens 4500 --max-sentences 16 \
--num-workers ${NUM_WORKERS} \
--dropout 0.1 --attention-dropout 0.1 --activation-dropout 0.0 --weight-decay 0.01 \
--encoder-layer-drop 0.0 \
--save-dir ${SAVE_DIR} \
--max-epoch 10 \
--keep-last-epochs 10 \
--max-source-positions 512 \
--max-target-positions 128 \
--kl-loss-weight 1.0 \
--target-kl 5.0 \
--cls-bow-loss-weight 0.0 \
--latent-bow-loss-weight 1.0 \
--masked-lm-loss-weight 0.0 \
--tensorboard-logdir ${TB_LOGDIR} \
--dataset-impl mmap \
--empty-cache-freq 64 \
--seed 1 \
--skip-invalid-size-inputs-valid-test \
--distributed-no-spawn \
--ddp-backend no_c10d \
--load-from-pretrained-model "${PRETRAINED_MODEL_PATH}"
elif [ "$DATASET" == "personachat" ]; then
echo '-------- fine-tune on dataset: personachat --------'
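# Again the dailydialog recipe, here with the personachat paths.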
NUM_WORKERS=10
CRITERION=ved_loss
TASK=ved_translate
USER_DIR=${PROJECT_PATH}/src
DATA_DIR=${PROJECT_PATH}/data/finetune/personachat
SAVE_DIR=${DATA_DIR}/checkpoints
TB_LOGDIR=${DATA_DIR}/tensorboard
fairseq-train \
${DATA_DIR}/binary \
--fp16 \
--user-dir ${USER_DIR} --task ${TASK} --arch ${ARCH} \
--optimizer adam --adam-betas '(0.9, 0.98)' --clip-norm 1.0 \
--lr 0.0003 \
--lr-scheduler inverse_sqrt --warmup-init-lr 1e-07 --warmup-updates 2000 \
--criterion $CRITERION --label-smoothing 0.1 \
--update-freq 4 --max-tokens 4500 --max-sentences 16 \
--num-workers ${NUM_WORKERS} \
--dropout 0.1 --attention-dropout 0.1 --activation-dropout 0.0 --weight-decay 0.01 \
--encoder-layer-drop 0.0 \
--save-dir ${SAVE_DIR} \
--max-epoch 10 \
--keep-last-epochs 10 \
--max-source-positions 512 \
--max-target-positions 128 \
--kl-loss-weight 1.0 \
--target-kl 5.0 \
--cls-bow-loss-weight 0.0 \
--latent-bow-loss-weight 1.0 \
--masked-lm-loss-weight 0.0 \
--tensorboard-logdir ${TB_LOGDIR} \
--dataset-impl mmap \
--empty-cache-freq 64 \
--seed 1 \
--skip-invalid-size-inputs-valid-test \
--distributed-no-spawn \
--ddp-backend no_c10d \
--load-from-pretrained-model "${PRETRAINED_MODEL_PATH}"
else
echo "dataset ${DATASET} not found!"
exit 1
fi