-
Notifications
You must be signed in to change notification settings - Fork 89
/
train.py
87 lines (72 loc) · 2.48 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
import argparse
import yaml
import os
import torch
import torch.nn as nn
from utils.dataloader import get_dataloader_and_vocab
from utils.trainer import Trainer
from utils.helper import (
get_model_class,
get_optimizer_class,
get_lr_scheduler,
save_config,
save_vocab,
)
def train(config):
    """Run one training experiment described by ``config`` and save its artifacts.

    ``config`` is indexed by key; the keys read here are ``model_dir``,
    ``model_name``, ``dataset``, ``data_dir``, ``train_batch_size``,
    ``val_batch_size``, ``shuffle``, ``optimizer``, ``learning_rate``,
    ``epochs``, ``train_steps``, ``val_steps`` and ``checkpoint_frequency``.

    The steps are: create the output directory, build the train and
    validation dataloaders (the validation one reuses the vocabulary
    returned for the train split), build the model, loss, optimizer and
    LR scheduler, run ``Trainer.train()``, then save the model, the loss
    history, the vocabulary and the config into ``config["model_dir"]``.

    Raises:
        FileExistsError: if ``config["model_dir"]`` already exists
            (``os.makedirs`` is called without ``exist_ok``).
    """
    # Fails if the directory is already there, so a finished run is not
    # silently written over.
    os.makedirs(config["model_dir"])

    def _load_split(split, batch_size_key, vocab):
        # Both splits share every setting except the split name, the
        # batch-size key and the vocabulary passed in.
        return get_dataloader_and_vocab(
            model_name=config["model_name"],
            ds_name=config["dataset"],
            ds_type=split,
            data_dir=config["data_dir"],
            batch_size=config[batch_size_key],
            shuffle=config["shuffle"],
            vocab=vocab,
        )

    # vocab=None for the train split; the vocabulary it returns is then
    # handed to the validation split.
    train_dataloader, vocab = _load_split("train", "train_batch_size", None)
    val_dataloader, _ = _load_split("valid", "val_batch_size", vocab)

    vocab_size = len(vocab.get_stoi())
    print(f"Vocabulary size: {vocab_size}")

    model = get_model_class(config["model_name"])(vocab_size=vocab_size)
    criterion = nn.CrossEntropyLoss()
    optimizer = get_optimizer_class(config["optimizer"])(
        model.parameters(), lr=config["learning_rate"]
    )
    lr_scheduler = get_lr_scheduler(optimizer, config["epochs"], verbose=True)

    # Use the GPU when torch reports one, otherwise fall back to the CPU.
    device_name = "cuda" if torch.cuda.is_available() else "cpu"
    device = torch.device(device_name)

    trainer_kwargs = {
        "model": model,
        "epochs": config["epochs"],
        "train_dataloader": train_dataloader,
        "train_steps": config["train_steps"],
        "val_dataloader": val_dataloader,
        "val_steps": config["val_steps"],
        "criterion": criterion,
        "optimizer": optimizer,
        "checkpoint_frequency": config["checkpoint_frequency"],
        "lr_scheduler": lr_scheduler,
        "device": device,
        "model_dir": config["model_dir"],
        "model_name": config["model_name"],
    }
    trainer = Trainer(**trainer_kwargs)

    trainer.train()
    print("Training finished.")

    # Persist everything produced by the run next to each other.
    trainer.save_model()
    trainer.save_loss()
    save_vocab(vocab, config["model_dir"])
    save_config(config, config["model_dir"])
    print("Model artifacts saved to folder:", config["model_dir"])
if __name__ == '__main__':
    # Script entry point: parse --config, load that YAML file and train with it.
    parser = argparse.ArgumentParser()
    parser.add_argument('--config', type=str, required=True, help='path to yaml config')
    args = parser.parse_args()

    # Pin the encoding so the config is decoded the same way on every
    # platform instead of depending on the locale's default encoding.
    with open(args.config, 'r', encoding='utf-8') as stream:
        config = yaml.safe_load(stream)

    # safe_load yields None for an empty file and may yield a list or scalar
    # for other documents; train() indexes the config by key, so report a
    # clear usage error here instead of failing later with a TypeError.
    if not isinstance(config, dict):
        parser.error(f"config file {args.config!r} must contain a YAML mapping")

    train(config)