-
Notifications
You must be signed in to change notification settings - Fork 12
/
Copy pathtrain.py
108 lines (82 loc) · 3.45 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
import os
import random
from collections import Counter
from functools import reduce
import torch
import torchvision.transforms as transforms
from tensorboardX import SummaryWriter
import data.aligned_conc_dataset as dataset
import util.utils as util
from config.default_config import DefaultConfig
from config.resnet_sunrgbd_config import RESNET_SUNRGBD_CONFIG
from data import DataProvider
from model.models import create_model
# Global configuration object plus the per-backbone argument sets that can be
# merged into it.
cfg = DefaultConfig()
args = {'resnet_sunrgbd': RESNET_SUNRGBD_CONFIG().args()}

# Setting random seed: draw one when the config does not provide it, then seed
# both Python's and torch's generators with the same value.
if cfg.MANUAL_SEED is None:
    cfg.MANUAL_SEED = random.randint(1, 10000)
random.seed(cfg.MANUAL_SEED)
torch.manual_seed(cfg.MANUAL_SEED)

# args for different backbones
cfg.parse(args['resnet_sunrgbd'])

# Restrict the visible GPUs to the configured ids before querying CUDA.
os.environ["CUDA_VISIBLE_DEVICES"] = cfg.GPU_IDS
# NOTE: despite its name this holds the *number* of visible CUDA devices.
device_ids = torch.cuda.device_count()
print('device_ids:', device_ids)

# Directory containing this script, with components joined by '/' regardless
# of the platform separator.
project_name = '/'.join(os.path.realpath(__file__).split(os.sep)[:-1])
util.mkdir('logs')
# data
def _build_transform(crop_cls, flip):
    """Compose the preprocessing pipeline shared by all datasets.

    The pipeline is: ``dataset.Resize`` to (LOAD_SIZE, LOAD_SIZE), ``crop_cls``
    to (FINE_SIZE, FINE_SIZE), an optional ``dataset.RandomHorizontalFlip``,
    ``dataset.ToTensor`` and ``dataset.Normalize`` with mean/std of 0.5 for
    each of the three channels.

    Args:
        crop_cls: crop transform class, instantiated with the fine size.
        flip: whether to insert the random horizontal flip after the crop.
    """
    steps = [
        dataset.Resize((cfg.LOAD_SIZE, cfg.LOAD_SIZE)),
        crop_cls((cfg.FINE_SIZE, cfg.FINE_SIZE)),
    ]
    if flip:
        steps.append(dataset.RandomHorizontalFlip())
    steps.append(dataset.ToTensor())
    steps.append(dataset.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]))
    return transforms.Compose(steps)


# Training data: random crop + random flip.
train_dataset = dataset.AlignedConcDataset(
    cfg,
    data_dir=cfg.DATA_DIR_TRAIN,
    transform=_build_transform(dataset.RandomCrop, flip=True),
)

# Validation data: center crop, no flip.
val_dataset = dataset.AlignedConcDataset(
    cfg,
    data_dir=cfg.DATA_DIR_VAL,
    transform=_build_transform(dataset.CenterCrop, flip=False),
)
batch_size_val = cfg.BATCH_SIZE

# Optional unlabeled data, using the same pipeline as the training set.
unlabeled_loader = None
if cfg.UNLABELED:
    unlabeled_dataset = dataset.AlignedConcDataset(
        cfg,
        data_dir=cfg.DATA_DIR_UNLABELED,
        transform=_build_transform(dataset.RandomCrop, flip=True),
        labeled=False,
    )
    unlabeled_loader = DataProvider(cfg, dataset=unlabeled_dataset)

train_loader = DataProvider(cfg, dataset=train_dataset)
val_loader = DataProvider(
    cfg,
    dataset=val_dataset,
    batch_size=batch_size_val,
    shuffle=False,
)
# class weights
# Per-label sample counts of the training set. Each entry of ``imgs`` carries
# its label at index 1. The counts are emitted in ascending label order so
# that position k of the tensor belongs to the k-th smallest label;
# Counter.values() on its own follows first-encounter order, which matches
# the label order only when ``imgs`` happens to be grouped by ascending label.
# NOTE(review): a label with zero training samples gets no slot here —
# confirm every class is present in the training split.
label_counts = Counter(item[1] for item in train_loader.dataset.imgs)
num_classes_train = [label_counts[label] for label in sorted(label_counts)]
cfg.CLASS_WEIGHTS_TRAIN = torch.FloatTensor(num_classes_train)

writer = SummaryWriter(log_dir=cfg.LOG_PATH)  # tensorboard
model = create_model(cfg, writer)
model.set_data_loader(train_loader, val_loader, unlabeled_loader)
def train():
    """Run the training job configured at module level and save the result.

    When ``cfg.RESUME`` is set, the checkpoint at
    ``cfg.CHECKPOINTS_DIR/cfg.RESUME_PATH`` is loaded into ``model.net`` and
    ``cfg.START_EPOCH`` is taken from the checkpoint's ``'epoch'`` entry; if
    ``cfg.INIT_EPOCH`` is set as well, ``cfg.START_EPOCH`` is reset to 1.
    Then ``model.train_parameters(cfg)`` runs and a checkpoint named
    ``<MODEL>_<WHICH_DIRECTION>_<NITER_TOTAL>.pth`` is saved.

    The tensorboard writer is closed on the way out, including when any of
    the steps above raises, so the writer is not left open on failure.
    """
    try:
        if cfg.RESUME:
            checkpoint_path = os.path.join(cfg.CHECKPOINTS_DIR, cfg.RESUME_PATH)
            checkpoint = model.load_checkpoint(
                model.net, checkpoint_path, keep_kw_module=False, keep_fc=True
            )
            cfg.START_EPOCH = checkpoint['epoch']

            # NOTE(review): INIT_EPOCH is treated as a sub-case of RESUME
            # (weights come from a checkpoint, epoch counter restarts) —
            # confirm against the original indentation.
            if cfg.INIT_EPOCH:
                # just load pretrained parameters
                print('load checkpoint from another source')
                cfg.START_EPOCH = 1

        print('>>> task path is {0}'.format(project_name))

        # train
        model.train_parameters(cfg)

        print('save model ...')
        model_filename = '{0}_{1}_{2}.pth'.format(
            cfg.MODEL, cfg.WHICH_DIRECTION, cfg.NITER_TOTAL
        )
        model.save_checkpoint(cfg.NITER_TOTAL, model_filename)
    finally:
        # Close the writer even if loading/training/saving failed.
        if writer is not None:
            writer.close()


if __name__ == '__main__':
    train()