"""training_ModelNet.py -- train the KPCNN model on the ModelNet dataset."""
import os
import time
import shutil
from datasets.ModelNet import ModelNetDataset
from utils.config import Config
from trainer_cls import Trainer
from models.KPCNN import KPCNN
from datasets.dataloader import get_dataloader
from torch import optim
from torch import nn
import torch
class ModelNetConfig(Config):
    """Hyper-parameters for the ModelNet run, layered on top of ``Config``.

    Every value is a plain class attribute; this script reads them when it
    builds the datasets, data loaders, model, optimizer and LR scheduler.
    """

    # ---- dataset -------------------------------------------------------
    dataset = 'ModelNet'  # also used in the snapshot / tensorboard folder names
    num_classes = 40
    first_subsampling_dl = 0.02  # forwarded to ModelNetDataset as-is
    in_features_dim = 4
    data_train_dir = "./data/modelnet40_normal_resampled/"
    data_test_dir = "./data/modelnet40_normal_resampled/"
    train_batch_size = 8
    test_batch_size = 8

    # ---- model ---------------------------------------------------------
    # Ordered list of block names handed to the network builder; laid out
    # here as one 'simple' stem, four (resnetb, resnetb_strided) stages,
    # a final 'resnetb', and a 'global_average' head.
    architecture = [
        'simple', 'resnetb',
        'resnetb_strided', 'resnetb',
        'resnetb_strided', 'resnetb',
        'resnetb_strided', 'resnetb',
        'resnetb_strided', 'resnetb',
        'global_average',
    ]
    dropout = 0.5
    resume = None  # NOTE(review): presumably a checkpoint path to resume from -- confirm in Trainer
    use_batch_norm = True
    # PyTorch BN momentum 0.02 corresponds to TensorFlow BN momentum 0.98,
    # see https://github.com/pytorch/examples/issues/289
    batch_norm_momentum = 0.02

    # ---- kernel point convolution --------------------------------------
    KP_influence = 'linear'
    KP_extent = 1.0
    convolution_mode = 'sum'

    # ---- training ------------------------------------------------------
    max_epoch = 200
    learning_rate = 5e-3
    momentum = 0.98  # SGD momentum
    # Decay factor chosen so that 80 scheduler steps shrink the LR by 10x.
    exp_gamma = 0.1 ** (1 / 80)
    exp_interval = 1  # NOTE(review): presumably epochs between scheduler steps -- confirm in Trainer
class Args(object):
    """Container of everything handed to ``Trainer`` for one ModelNet run.

    Constructing an instance has side effects: it creates the snapshot and
    tensorboard directories, copies the training / dataset / dataloader
    sources into the snapshot folder, builds both datasets and their
    loaders, and instantiates the model, optimizer and LR scheduler.
    """

    def __init__(self, config, *, is_test=False):
        """Build all run-time objects from ``config``.

        Args:
            config: configuration object exposing the attributes read below
                (``dataset``, ``data_train_dir``, ``train_batch_size``,
                ``learning_rate``, ...), e.g. a ``ModelNetConfig`` instance.
            is_test: when true, ``'Test'`` is appended to the experiment id
                so the run's output folders are distinguishable. Defaults
                to ``False``, which matches the previous hard-coded value.
        """
        # Month/day/hour/minute stamp keeps the output folders of separate
        # runs apart.
        suffix = 'Test' if is_test else ''
        self.experiment_id = "KPConvNet" + time.strftime('%m%d%H%M') + suffix
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.verbose = True

        # snapshot
        self.snapshot_interval = 5  # NOTE(review): presumably in epochs -- confirm in Trainer
        snapshot_root = f'snapshot/{config.dataset}_{self.experiment_id}'
        tensorboard_root = f'tensorboard/{config.dataset}_{self.experiment_id}'
        os.makedirs(snapshot_root, exist_ok=True)
        os.makedirs(tensorboard_root, exist_ok=True)
        # Keep a copy of the code that produced this run next to its
        # checkpoints. Paths are relative to the current working directory.
        shutil.copy2(os.path.join('.', 'training_ModelNet.py'), os.path.join(snapshot_root, 'train.py'))
        shutil.copy2(os.path.join('datasets', 'ModelNet.py'), os.path.join(snapshot_root, 'dataset.py'))
        shutil.copy2(os.path.join('datasets', 'dataloader.py'), os.path.join(snapshot_root, 'dataloader.py'))
        self.save_dir = os.path.join(snapshot_root, 'models/')
        self.result_dir = os.path.join(snapshot_root, 'results/')
        self.tboard_dir = tensorboard_root

        # dataset & dataloader
        self.train_set = ModelNetDataset(
            root=config.data_train_dir,
            split='train',
            first_subsampling_dl=config.first_subsampling_dl,
            config=config,
        )
        self.test_set = ModelNetDataset(
            root=config.data_test_dir,
            split='test',
            first_subsampling_dl=config.first_subsampling_dl,
            config=config,
        )
        # NOTE(review): num_workers is set to the batch size; this looks
        # accidental but is kept as-is -- confirm the intended worker count.
        self.train_loader = get_dataloader(
            dataset=self.train_set,
            batch_size=config.train_batch_size,
            shuffle=True,
            num_workers=config.train_batch_size,
        )
        self.test_loader = get_dataloader(
            dataset=self.test_set,
            batch_size=config.test_batch_size,
            shuffle=False,
            num_workers=config.test_batch_size,
        )
        print("Training set size:", len(self.train_loader.dataset))
        print("Test set size:", len(self.test_loader.dataset))

        # model
        self.model = KPCNN(config)
        self.resume = config.resume

        # optimizer
        self.start_epoch = 0
        self.epoch = config.max_epoch
        self.optimizer = optim.SGD(
            self.model.parameters(),
            lr=config.learning_rate,
            momentum=config.momentum,
            weight_decay=1e-6,
        )
        self.scheduler = optim.lr_scheduler.ExponentialLR(self.optimizer, gamma=config.exp_gamma)
        self.scheduler_interval = config.exp_interval

        # evaluate
        self.evaluate_interval = 1
        self.evaluate_metric = nn.CrossEntropyLoss(reduction='mean')

        self.check_args()

    def check_args(self):
        """Ensure the model, result and tensorboard directories exist.

        Uses ``exist_ok=True`` rather than an exists-then-create check so a
        directory created concurrently does not raise.

        Returns:
            ``self``, so the call can be chained.

        Raises:
            OSError: if a directory cannot be created (including when the
                path already exists but is not a directory).
        """
        for directory in (self.save_dir, self.result_dir, self.tboard_dir):
            os.makedirs(directory, exist_ok=True)
        return self
if __name__ == '__main__':
    # Script entry point: assemble the run from the ModelNet configuration
    # and hand it to the trainer.
    args = Args(ModelNetConfig())
    Trainer(args).train()