test.py (forked from zyqin19/PROTOSEG)
"""
Main Testing Script
Author: Xiaoyang Wu ([email protected])
Please cite our work if the code is helpful to you.
"""
import os
import random
import numpy as np
import argparse
import collections
import torch
import torch.nn.parallel
import torch.optim
import torch.utils.data
from pcr.models import build_model
from pcr.datasets import build_dataset
from pcr.datasets.utils import collate_fn
from pcr.utils.config import Config, DictAction
from pcr.utils.logger import get_root_logger
from pcr.utils.env import get_random_seed, set_seed
from pcr.engines.test import TEST
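
# Example invocation (a sketch; the paths below are illustrative placeholders,
# not shipped with the repo). --options takes space-separated KEY=VALUE pairs,
# assuming DictAction follows the usual mmcv-style convention:
#   python test.py --config-file configs/s3dis/semseg-ptv2m2-0-base.py \
#       --options save_path=exp/my_run weight=exp/my_run/model/model_best.pth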


def get_parser():
    parser = argparse.ArgumentParser(description='PCR Test Process')
    parser.add_argument('--config-file', default="configs/s3dis/semseg-ptv2m2-0-base.py",
                        metavar="FILE", help="path to config file")
    parser.add_argument('--options', nargs='+', action=DictAction, help='custom options')
    args = parser.parse_args()
    return args


def main():
    args = get_parser()
    # NOTE: this hard-coded dict overrides any --options passed on the command
    # line; comment it out to control save_path/weight from the CLI.
    args.options = {"save_path": '/data/qinzheyun/PTV2/exp/s3dis/semseg-ptv2m2-0-prototype-3',
                    "weight": '/data/qinzheyun/PTV2/exp/s3dis/semseg-ptv2m2-0-prototype-3/model/model_best.pth'}
    # config_parser
    cfg = Config.fromfile(args.config_file)
    if args.options is not None:
        cfg.merge_from_dict(args.options)
    if cfg.seed is None:
        cfg.seed = get_random_seed()
    os.makedirs(cfg.save_path, exist_ok=True)
    # default_setup
    set_seed(cfg.seed)
    cfg.batch_size_val_per_gpu = cfg.batch_size_test  # TODO: add support to multi-gpu test
    cfg.num_worker_per_gpu = cfg.num_worker  # TODO: add support to multi-gpu test (see sketch below)
    # tester init
    weight_name = os.path.basename(cfg.weight).split(".")[0]
    logger = get_root_logger(log_file=os.path.join(cfg.save_path, "test-{}.log".format(weight_name)))
    logger.info("=> Loading config ...")
    logger.info(f"Save path: {cfg.save_path}")
    logger.info(f"Config:\n{cfg.pretty_text}")
    # build model
    logger.info("=> Building model ...")
    model = build_model(cfg.model).cuda()
    n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
    logger.info(f"Num params: {n_parameters}")
    # build dataset
    logger.info("=> Building test dataset & dataloader ...")
    test_dataset = build_dataset(cfg.data.test)
    test_loader = torch.utils.data.DataLoader(test_dataset,
                                              batch_size=cfg.batch_size_val_per_gpu,
                                              shuffle=False,
                                              num_workers=cfg.num_worker_per_gpu,
                                              pin_memory=True,
                                              collate_fn=collate_fn)
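    # A possible multi-gpu extension (sketch only, assuming torch.distributed
    # is already initialized; not part of this repo): shard the test set across
    # ranks with a DistributedSampler instead of a single-process loader, e.g.
    #   sampler = torch.utils.data.distributed.DistributedSampler(test_dataset, shuffle=False)
    #   test_loader = torch.utils.data.DataLoader(test_dataset,
    #                                             batch_size=cfg.batch_size_val_per_gpu,
    #                                             sampler=sampler,
    #                                             num_workers=cfg.num_worker_per_gpu,
    #                                             pin_memory=True,
    #                                             collate_fn=collate_fn)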
    # load checkpoint
    if os.path.isfile(cfg.weight):
        # load to CPU first so checkpoints saved on another device still resolve
        checkpoint = torch.load(cfg.weight, map_location="cpu")
        state_dict = checkpoint['state_dict']
        new_state_dict = collections.OrderedDict()
        for k, v in state_dict.items():
            # strip the DataParallel/DDP prefix, e.g. "module.enc.weight" -> "enc.weight"
            name = k[7:] if k.startswith("module.") else k
            new_state_dict[name] = v
        model.load_state_dict(new_state_dict, strict=True)
        logger.info("=> loaded weight '{}' (epoch {})".format(cfg.weight, checkpoint['epoch']))
        cfg.epochs = checkpoint['epoch']  # TODO: move to self
    else:
        raise RuntimeError("=> no checkpoint found at '{}'".format(cfg.weight))
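    # TEST is assumed to be a registry in pcr.engines.test: cfg.test names the
    # concrete tester class, TEST.build instantiates it, and the instance is
    # called with the config, dataloader, and model to run evaluation.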
    TEST.build(cfg.test)(cfg, test_loader, model)


if __name__ == '__main__':
    main()