do_eval.py
import argparse
import logging
import os
from functools import partial

import numpy as np
import torch
from torch.utils.data import RandomSampler, SequentialSampler
from tqdm import tqdm

import structs
import utils
from modeling.vectornet import VectorNet

logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s - %(message)s',
                    datefmt='%m/%d/%Y %H:%M:%S',
                    level=logging.INFO)
logger = logging.getLogger(__name__)

tqdm = partial(tqdm, dynamic_ncols=True)


def eval_instance_argoverse(batch_size, args, pred, mapping, file2pred, file2labels, DEs, iter_bar):
    """Accumulate per-instance predictions and displacement errors for one Argoverse batch."""
    for i in range(batch_size):
        a_pred = pred[i]
        assert a_pred.shape == (6, args.future_frame_num, 2)
        if args.argoverse2:
            # Argoverse 2 scenarios are keyed by file name.
            file_name = os.path.split(mapping[i]['file_name'])[1]
            file2pred[file_name] = a_pred
        else:
            # Argoverse 1 sequences are keyed by the integer id in the file name (extension stripped).
            file_name_int = int(os.path.split(mapping[i]['file_name'])[1][:-4])
            file2pred[file_name_int] = a_pred
            if not args.do_test:
                file2labels[file_name_int] = mapping[i]['origin_labels']

    if not args.do_test:
        # Displacement error of the top-ranked trajectory at every future frame.
        DE = np.zeros([batch_size, args.future_frame_num])
        for i in range(batch_size):
            origin_labels = mapping[i]['origin_labels']
            for j in range(args.future_frame_num):
                DE[i][j] = np.sqrt((origin_labels[j][0] - pred[i, 0, j, 0]) ** 2 +
                                   (origin_labels[j][1] - pred[i, 0, j, 1]) ** 2)
        DEs.append(DE)

        # Miss rate: fraction of instances whose final displacement error exceeds 2.0 m.
        miss_rate = 0.0
        if 0 in utils.method2FDEs:
            FDEs = utils.method2FDEs[0]
            miss_rate = np.sum(np.array(FDEs) > 2.0) / len(FDEs)
        iter_bar.set_description('Iter (MR=%5.3f)' % miss_rate)


def do_eval(args):
    device = torch.device(
        "cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")

    print("Loading Evaluation Dataset", args.data_dir)
    if args.argoverse:
        from dataset_argoverse import Dataset
    eval_dataset = Dataset(args, args.eval_batch_size)
    eval_sampler = SequentialSampler(eval_dataset)
    eval_dataloader = torch.utils.data.DataLoader(eval_dataset,
                                                  batch_size=args.eval_batch_size,
                                                  sampler=eval_sampler,
                                                  collate_fn=utils.batch_list_to_batch_tensors,
                                                  pin_memory=False)

    model = VectorNet(args)
    print('torch.cuda.device_count', torch.cuda.device_count())

    # Restore model weights from the checkpoint given by args.model_recover_path.
    logger.info("***** Recover model: %s *****", args.model_recover_path)
    if args.model_recover_path is None:
        raise ValueError("model_recover_path not specified.")
    model_recover = torch.load(args.model_recover_path)
    model.load_state_dict(model_recover)

    # Optionally restore the trajectory-completion sub-modules from a separate checkpoint.
    if 'set_predict-train_recover' in args.other_params and 'complete_traj' in args.other_params:
        model_recover = torch.load(args.other_params['set_predict-train_recover'])
        utils.load_model(model.decoder.complete_traj_cross_attention, model_recover,
                         prefix='decoder.complete_traj_cross_attention.')
        utils.load_model(model.decoder.complete_traj_decoder, model_recover,
                         prefix='decoder.complete_traj_decoder.')

    model.to(device)
    model.eval()

    file2pred = {}
    file2labels = {}
    iter_bar = tqdm(eval_dataloader, desc='Iter (loss=X.XXX)')
    DEs = []
    length = len(iter_bar)

    if args.argoverse2:
        metrics = utils.PredictionMetrics()
    argo_pred = structs.ArgoPred()

    for step, batch in enumerate(iter_bar):
        pred_trajectory, pred_score, _ = model(batch, device)
        mapping = batch
        batch_size = pred_trajectory.shape[0]
        for i in range(batch_size):
            assert pred_trajectory[i].shape == (6, args.future_frame_num, 2)
            assert pred_score[i].shape == (6,)
            argo_pred[mapping[i]['file_name']] = structs.MultiScoredTrajectory(
                pred_score[i].copy(), pred_trajectory[i].copy())

        if args.argoverse2:
            # Argoverse 2: accumulate the official metrics (minADE, minFDE, MR, brier-minFDE).
            from av2.datasets.motion_forecasting.eval.metrics \
                import compute_ade, compute_fde, compute_brier_fde
            for i in range(batch_size):
                forecasted_trajectories = pred_trajectory[i][:, :, :]
                gt_trajectory = mapping[i]['gt_trajectory_global_coordinates'][:, :]
                # Exponentiate the scores and renormalize so the probabilities sum to 1.
                forecast_probabilities = np.exp(pred_score[i])
                forecast_probabilities = forecast_probabilities * (1.0 / forecast_probabilities.sum())
                assert forecasted_trajectories.shape == (6, 60, 2)
                assert gt_trajectory.shape == (60, 2)
                assert forecast_probabilities.shape == (6,)

                ade = compute_ade(forecasted_trajectories, gt_trajectory)
                fde = compute_fde(forecasted_trajectories, gt_trajectory)
                brier_fde = compute_brier_fde(forecasted_trajectories, gt_trajectory, forecast_probabilities)

                metrics.minADE.accumulate(ade.min())
                metrics.minFDE.accumulate(fde.min())
                metrics.MR.accumulate(fde.min() > 2.0)
                metrics.brier_minFDE.accumulate(brier_fde.min())
        else:
            eval_instance_argoverse(batch_size, args, pred_trajectory, mapping,
                                    file2pred, file2labels, DEs, iter_bar)

    if 'optimization' in args.other_params:
        utils.select_goals_by_optimization(None, None, close=True)

    if args.argoverse2:
        import json
        print('Metrics:')
        print(json.dumps(metrics.serialize(), indent=4))
    else:
        from dataset_argoverse import post_eval
        post_eval(args, file2pred, file2labels, DEs)


def main():
    parser = argparse.ArgumentParser()
    utils.add_argument(parser)
    args: utils.Args = parser.parse_args()
    utils.init(args, logger)

    device = torch.device(
        "cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
    logger.info("device: {}".format(device))

    do_eval(args)


if __name__ == "__main__":
    main()
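
# Example invocation (hypothetical sketch): the flags below assume that
# utils.add_argument exposes CLI options matching the attribute names used in
# this script (e.g. --argoverse, --data_dir, --eval_batch_size,
# --future_frame_num, --model_recover_path). Check utils.add_argument for the
# authoritative option list and defaults before running.
#
#   python do_eval.py --argoverse \
#       --data_dir <path/to/validation/data> \
#       --eval_batch_size 64 \
#       --future_frame_num 30 \
#       --model_recover_path <path/to/checkpoint>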