forked from sovit-123/fasterrcnn-pytorch-training-pipeline
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathvalidate.py
89 lines (81 loc) · 2.98 KB
/
validate.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
from torch_utils.engine import evaluate
from datasets import (
create_valid_dataset, create_valid_loader
)
from models.create_fasterrcnn_model import create_model
import torch
import argparse
import yaml
if __name__ == '__main__':
    # Script entry point: build a detection model, load the test (or
    # validation) data described by a YAML config, and run `evaluate` on it.

    # Construct the argument parser.
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '-c', '--config',
        default='data_configs/test_image_config.yaml',
        help='(optional) path to the data config file'
    )
    parser.add_argument(
        '-m', '--model', default='fasterrcnn_resnet50_fpn',
        help='name of the model'
    )
    parser.add_argument(
        '-mw', '--weights', default=None,
        help='path to trained checkpoint weights if providing custom YAML file'
    )
    parser.add_argument(
        '-ims', '--img-size', dest='img_size', default=512, type=int,
        help='image size to feed to the network'
    )
    parser.add_argument(
        '-w', '--workers', default=4, type=int,
        help='number of workers for data processing/transforms/augmentations'
    )
    parser.add_argument(
        '-b', '--batch-size', dest='batch_size', default=8, type=int,
        help='batch size to load the data'
    )
    parser.add_argument(
        '-d', '--device',
        default=torch.device('cuda:0' if torch.cuda.is_available() else 'cpu'),
        help='computation/training device, default is GPU if GPU present'
    )
    args = vars(parser.parse_args())

    # Load the data configurations.
    with open(args['config']) as file:
        data_configs = yaml.safe_load(file)

    # Validation settings and constants.
    # Prefer the test split. Fall back to the validation split only when a
    # TEST_* key is absent; any other error (e.g. Ctrl-C) must propagate
    # instead of being silently swallowed.
    try:
        VALID_DIR_IMAGES = data_configs['TEST_DIR_IMAGES']
        VALID_DIR_LABELS = data_configs['TEST_DIR_LABELS']
    except KeyError:
        VALID_DIR_IMAGES = data_configs['VALID_DIR_IMAGES']
        VALID_DIR_LABELS = data_configs['VALID_DIR_LABELS']
    NUM_CLASSES = data_configs['NC']
    CLASSES = data_configs['CLASSES']
    NUM_WORKERS = args['workers']
    DEVICE = args['device']
    BATCH_SIZE = args['batch_size']

    # Model configurations.
    IMAGE_WIDTH = args['img_size']
    IMAGE_HEIGHT = args['img_size']

    # `create_model` is indexed by model name; keep the selected builder
    # under its own name so the imported lookup object is not shadowed.
    build_model = create_model[args['model']]
    if args['weights'] is None:
        # No checkpoint supplied: request the COCO-pretrained variant.
        model = build_model(num_classes=NUM_CLASSES, coco_model=True)
    else:
        # Checkpoint supplied: build the non-COCO variant and load weights.
        model = build_model(num_classes=NUM_CLASSES, coco_model=False)
        # NOTE: torch.load unpickles the file; only load trusted checkpoints.
        checkpoint = torch.load(args['weights'], map_location=DEVICE)
        model.load_state_dict(checkpoint['model_state_dict'])
    model.to(DEVICE).eval()

    # Build the dataset and loader for the chosen split.
    valid_dataset = create_valid_dataset(
        VALID_DIR_IMAGES, VALID_DIR_LABELS,
        IMAGE_WIDTH, IMAGE_HEIGHT, CLASSES
    )
    valid_loader = create_valid_loader(valid_dataset, BATCH_SIZE, NUM_WORKERS)

    # Run the evaluation.
    coco_evaluator, stats = evaluate(
        model,
        valid_loader,
        device=DEVICE,
        classes=CLASSES,
    )