# loader.py
import os
import copy
import math
import logging
from typing import List

import numpy as np
import PIL
import torch
import torchvision

LOG = logging.getLogger(__name__)

def define_path(use_jaad=True, use_pie=True, use_titan=True):
    """
    Define the correct paths to the datasets' annotations and images.
    """
    all_anns_paths = {'JAAD': {'anns': '../../DATA/annotations/JAAD/JAAD_DATA.pkl',
                               'split': '../../DATA/annotations/JAAD/splits'},
                      'PIE': {'anns': '../../DATA/annotations/PIE/PIE_DATA.pkl'},
                      'TITAN': {'anns': '../../DATA/annotations/TITAN/titan_0_4',
                                'split': '../../DATA/annotations/TITAN/splits'},
                      }
    all_image_dir = {'JAAD': '../../DATA/JAAD/images/',
                     'PIE': '../../DATA/PIE/images/',
                     'TITAN': '../../DATA/TITAN/images_anonymized/',
                     }
    anns_paths = {}
    image_dir = {}
    if use_jaad:
        anns_paths['JAAD'] = all_anns_paths['JAAD']
        image_dir['JAAD'] = all_image_dir['JAAD']
    if use_pie:
        anns_paths['PIE'] = all_anns_paths['PIE']
        image_dir['PIE'] = all_image_dir['PIE']
    if use_titan:
        anns_paths['TITAN'] = all_anns_paths['TITAN']
        image_dir['TITAN'] = all_image_dir['TITAN']
    return anns_paths, image_dir
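
# Usage sketch: the helper below is hypothetical (not part of the original
# module) and only shows how the flags select a subset of datasets.
def _example_define_path():
    anns_paths, image_dir = define_path(use_jaad=True, use_pie=False, use_titan=True)
    # 'PIE' is absent from both dicts; JAAD and TITAN keep their relative paths.
    assert 'PIE' not in anns_paths and 'PIE' not in image_dir
    return anns_paths, image_dir
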
class ImageList(torch.utils.data.Dataset):
    """
    Basic dataloader for images
    """
    def __init__(self, image_paths, preprocess=None):
        self.image_paths = image_paths
        self.preprocess = preprocess

    def __getitem__(self, index):
        image_path = self.image_paths[index]
        with open(image_path, 'rb') as f:
            image = PIL.Image.open(f).convert('RGB')
        if self.preprocess is not None:
            image = self.preprocess(image)
        return image

    def __len__(self):
        return len(self.image_paths)
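
# Usage sketch: a hypothetical helper showing ImageList with a standard
# torchvision transform as `preprocess`; the image paths are placeholders.
def _example_image_list(image_paths):
    preprocess = torchvision.transforms.Compose([
        torchvision.transforms.Resize((224, 224)),
        torchvision.transforms.ToTensor(),
    ])
    dataset = ImageList(image_paths, preprocess=preprocess)
    loader = torch.utils.data.DataLoader(dataset, batch_size=4)
    return next(iter(loader))  # image batch of shape (batch, 3, 224, 224)
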
class MultiLoader:
    """Load batches from multiple datasets, interleaved according to weights."""

    last_task_index = None

    def __init__(self, loaders: List[torch.utils.data.DataLoader],
                 weights=None, n_batches=None):
        self.loaders = loaders
        self._weights = weights
        if self._weights is None:
            self._weights = [1.0 / len(loaders) for _ in range(len(loaders))]
        elif len(self._weights) == len(loaders) - 1:
            self._weights.append(1.0 - sum(self._weights))
        elif len(self._weights) != len(loaders):
            raise ValueError('invalid dataset weights: {}'.format(self._weights))
        assert all(w > 0.0 for w in self._weights)
        # normalize weights across datasets
        sum_w = sum(self._weights)
        self._weights = [w / sum_w for w in self._weights]
        LOG.info('dataset weights: %s', self._weights)
        # total number of batches in one epoch
        self.n_batches = int(min(len(l) / w for l, w in zip(loaders, self._weights)))
        if n_batches is not None:
            self.n_batches = min(self.n_batches, n_batches)

    def __iter__(self):
        loader_iters = [iter(l) for l in self.loaders]
        # count the batches loaded from each dataset
        n_loaded = [0 for _ in self.loaders]
        while True:
            # pick the loader that is furthest behind its target share
            loader_index = int(np.argmin([n / w for n, w in zip(n_loaded, self._weights)]))
            next_batch = next(loader_iters[loader_index], None)
            if next_batch is None:
                break
            n_loaded[loader_index] += 1
            MultiLoader.last_task_index = loader_index
            yield next_batch
            if sum(n_loaded) >= self.n_batches:
                break

    def __len__(self):
        return self.n_batches
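
# Usage sketch: a hypothetical helper demonstrating the interleaving. With
# weights [0.75, 0.25] and loaders of 90 and 30 batches, roughly three batches
# of the first loader are yielded for every batch of the second.
def _example_multi_loader():
    from torch.utils.data import DataLoader, TensorDataset
    loader_a = DataLoader(TensorDataset(torch.zeros(90, 1)), batch_size=1)
    loader_b = DataLoader(TensorDataset(torch.ones(30, 1)), batch_size=1)
    multi = MultiLoader([loader_a, loader_b], weights=[0.75, 0.25])
    counts = [0, 0]
    for _ in multi:
        counts[MultiLoader.last_task_index] += 1
    return counts  # [90, 30]: both loaders are exhausted in one epoch
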
class FrameDataset(torch.utils.data.Dataset):
    """
    Basic dataloader for single-frame samples
    """
    def __init__(self, samples, image_dir, preprocess=None):
        self.samples = samples
        self.image_dir = image_dir
        self.preprocess = preprocess

    def __getitem__(self, index):
        ids = list(self.samples.keys())
        idx = ids[index]
        frame = self.samples[idx]['frame']
        bbox = copy.deepcopy(self.samples[idx]['bbox'])
        source = self.samples[idx]['source']
        anns = {'bbox': bbox, 'source': source}
        TTE = self.samples[idx]['TTE']
        label = self.samples[idx].get('trans_label', None)
        # placeholders when behavior / scene-attribute annotations are missing
        behavior = self.samples[idx].get('behavior', [-1, -1, -1, -1])
        attributes = self.samples[idx].get('attributes', [-1, -1, -1, -1, -1, -1])
        # build the image path for the sample's source dataset
        image_path = None
        if source == 'JAAD':
            vid = self.samples[idx]['video_number']
            image_path = os.path.join(self.image_dir['JAAD'], vid, '{:05d}.png'.format(frame))
        elif source == 'PIE':
            vid = self.samples[idx]['video_number']
            sid = self.samples[idx]['set_number']
            image_path = os.path.join(self.image_dir['PIE'], sid, vid, '{:05d}.png'.format(frame))
        elif source == 'TITAN':
            vid = self.samples[idx]['video_number']
            image_path = os.path.join(self.image_dir['TITAN'], vid, 'images', '{:06d}.png'.format(frame))
        with open(image_path, 'rb') as f:
            img = PIL.Image.open(f).convert('RGB')
        if self.preprocess is not None:
            img, anns = self.preprocess(img, anns)
        img_tensor = torchvision.transforms.ToTensor()(img)
        if label is not None:
            label = torch.tensor(label).to(torch.float32)
        if not math.isnan(TTE):
            TTE = torch.tensor(round(TTE, 2)).to(torch.float32)
        attributes = torch.tensor(attributes).to(torch.float32)
        sample = {'image': img_tensor, 'bbox': anns['bbox'], 'id': idx,
                  'label': label, 'source': source, 'TTE': TTE,
                  'attributes': attributes, 'behavior': behavior}
        return sample

    def __len__(self):
        return len(self.samples)
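
# Usage sketch: a hypothetical single-entry `samples` dict following the keys
# read above; real samples come from the annotation pipeline, and the image
# file must exist under image_dir for __getitem__ to succeed.
def _example_frame_dataset(image_dir):
    samples = {'JAAD_0001': {'frame': 120, 'bbox': [604.0, 422.0, 660.0, 572.0],
                             'source': 'JAAD', 'TTE': 1.5, 'trans_label': 1,
                             'video_number': 'video_0001'}}
    dataset = FrameDataset(samples, image_dir)
    return torch.utils.data.DataLoader(dataset, batch_size=1)
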
class SequenceDataset(torch.utils.data.Dataset):
    """
    Basic dataloader for loading sequence/history samples
    """
    def __init__(self, samples, image_dir, preprocess=None):
        """
        :params: samples: transition history samples (dict)
                 image_dir: root directory of images extracted from the video clips
                 preprocess: optional preprocessing applied to image tensors and annotations
        """
        self.samples = samples
        self.image_dir = image_dir
        self.preprocess = preprocess

    def __getitem__(self, index):
        ids = list(self.samples.keys())
        idx = ids[index]
        frames = self.samples[idx]['frame']
        bbox = copy.deepcopy(self.samples[idx]['bbox'])
        source = self.samples[idx]['source']
        action = self.samples[idx]['action']
        TTE = round(self.samples[idx]['TTE'], 2)
        label = self.samples[idx].get('trans_label', None)
        bbox_new = []
        image_path = None
        img_tensors = []
        # load and preprocess every frame of the history sequence
        for i in range(len(frames)):
            anns = {'bbox': bbox[i], 'source': source}
            if source == 'JAAD':
                vid = self.samples[idx]['video_number']
                image_path = os.path.join(self.image_dir['JAAD'], vid, '{:05d}.png'.format(frames[i]))
            elif source == 'PIE':
                vid = self.samples[idx]['video_number']
                sid = self.samples[idx]['set_number']
                image_path = os.path.join(self.image_dir['PIE'], sid, vid, '{:05d}.png'.format(frames[i]))
            elif source == 'TITAN':
                vid = self.samples[idx]['video_number']
                image_path = os.path.join(self.image_dir['TITAN'], vid, 'images', '{:06d}.png'.format(frames[i]))
            with open(image_path, 'rb') as f:
                img = PIL.Image.open(f).convert('RGB')
            if self.preprocess is not None:
                img, anns = self.preprocess(img, anns)
            img_tensors.append(torchvision.transforms.ToTensor()(img))
            bbox_new.append(anns['bbox'])
        img_tensors = torch.stack(img_tensors)
        if label is not None:
            label = torch.tensor(label).to(torch.float32)
        sample = {'image': img_tensors, 'bbox': bbox_new, 'action': action,
                  'id': idx, 'label': label, 'source': source, 'TTE': TTE}
        return sample

    def __len__(self):
        return len(self.samples)
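
# Usage note (hypothetical helper): sequences vary in length, so the default
# collate function can only handle one sample per batch here (and samples need
# a 'trans_label' for collation); PaddedSequenceDataset below lifts the
# batch-size limit by padding every sequence to a common length.
def _example_sequence_dataset(samples, image_dir):
    dataset = SequenceDataset(samples, image_dir)
    return torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=True)
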
class PaddedSequenceDataset(torch.utils.data.Dataset):
    """
    Dataloader for loading sequence/history samples;
    all sequences are padded to a common length
    """
    def __init__(self, samples, image_dir, padded_length=10, preprocess=None, hflip_p=0.0):
        """
        :params: samples: transition history samples (dict)
                 image_dir: root directory of images extracted from the video clips
                 padded_length: length of each sequence after padding
                 preprocess: optional preprocessing applied to image tensors and annotations
                 hflip_p: probability of horizontally flipping a whole sequence
        """
        self.samples = samples
        self.image_dir = image_dir
        self.preprocess = preprocess
        self.padded_length = padded_length
        self.hflip_p = hflip_p

    def __getitem__(self, index):
        ids = list(self.samples.keys())
        idx = ids[index]
        frames = self.samples[idx]['frame']
        bbox = copy.deepcopy(self.samples[idx]['bbox'])
        source = self.samples[idx]['source']
        action = self.samples[idx]['action']
        TTE = self.samples[idx]['TTE']
        if source == 'PIE':
            set_number = self.samples[idx]['set_number']
        else:
            set_number = None
        label = self.samples[idx].get('trans_label', None)
        behavior = self.samples[idx].get('behavior', [-1, -1, -1, -1])
        attributes = self.samples[idx].get('attributes', [-1, -1, -1, -1, -1, -1])
        bbox_new = []
        bbox_ped_new = []
        image_path = None
        img_tensors = []
        # decide once per sequence so every frame is flipped consistently
        hflip = float(torch.rand(1).item()) < self.hflip_p
        for i in range(len(frames)):
            anns = {'bbox': bbox[i], 'source': source}
            if source == 'JAAD':
                vid = self.samples[idx]['video_number']
                image_path = os.path.join(self.image_dir['JAAD'], vid, '{:05d}.png'.format(frames[i]))
            elif source == 'PIE':
                vid = self.samples[idx]['video_number']
                sid = self.samples[idx]['set_number']
                image_path = os.path.join(self.image_dir['PIE'], sid, vid, '{:05d}.png'.format(frames[i]))
            elif source == 'TITAN':
                vid = self.samples[idx]['video_number']
                image_path = os.path.join(self.image_dir['TITAN'], vid, 'images', '{:06d}.png'.format(frames[i]))
            with open(image_path, 'rb') as f:
                img = PIL.Image.open(f).convert('RGB')
            if hflip:
                img = img.transpose(PIL.Image.FLIP_LEFT_RIGHT)
                # mirror the box: x coordinates flip around the image width
                w, h = img.size
                x_max = w - anns['bbox'][0]
                x_min = w - anns['bbox'][2]
                anns['bbox'][0] = x_min
                anns['bbox'][2] = x_max
            anns['bbox_ped'] = copy.deepcopy(anns['bbox'])
            if self.preprocess is not None:
                img, anns = self.preprocess(img, anns)
            img_tensors.append(torchvision.transforms.ToTensor()(img))
            bbox_new.append(anns['bbox'])
            bbox_ped_new.append(anns['bbox_ped'])
        img_tensors = torch.stack(img_tensors)
        imgs_size = img_tensors.size()
        # zero-pad the stacked frames and per-frame annotations to padded_length
        img_tensors_padded = torch.zeros((self.padded_length, imgs_size[1], imgs_size[2], imgs_size[3]))
        img_tensors_padded[:imgs_size[0], :, :, :] = img_tensors
        bbox_new_padded = copy.deepcopy(bbox_new)
        bbox_ped_new_padded = copy.deepcopy(bbox_ped_new)
        action_padded = copy.deepcopy(action)
        behavior_padded = copy.deepcopy(behavior)
        for i in range(imgs_size[0], self.padded_length):
            bbox_new_padded.append([0, 0, 0, 0])
            bbox_ped_new_padded.append([0, 0, 0, 0])
            action_padded.append(-1)
            behavior_padded.append([-1, -1, -1, -1])
        # true (unpadded) sequence length
        seq_len = imgs_size[0]
        if label is not None:
            label = torch.tensor(label).to(torch.float32)
        TTE_tag = -1
        if not math.isnan(TTE):
            TTE = torch.tensor(round(TTE, 2)).to(torch.float32)
            TTE_tag = torch.tensor(TTE_tag).to(torch.float32)
        attributes = torch.tensor(attributes).to(torch.float32)
        sample = {'image': img_tensors_padded, 'bbox': bbox_new_padded, 'bbox_ped': bbox_ped_new_padded,
                  'seq_length': seq_len, 'action': action_padded, 'id': idx, 'label': label,
                  'source': source, 'TTE': TTE,
                  'behavior': behavior_padded, 'attributes': attributes}
        return sample

    def __len__(self):
        return len(self.samples)
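
# Usage sketch: a hypothetical helper showing how the padding pays off.
# 'seq_length' records the true length, so a recurrent model can skip the
# zero frames via pack_padded_sequence. Assumes every sample has a label, a
# non-NaN TTE, and a fixed image size so the default collate can batch them.
def _example_padded_sequence_dataset(samples, image_dir):
    from torch.nn.utils.rnn import pack_padded_sequence
    dataset = PaddedSequenceDataset(samples, image_dir, padded_length=10, hflip_p=0.5)
    loader = torch.utils.data.DataLoader(dataset, batch_size=4, shuffle=True)
    batch = next(iter(loader))
    feats = batch['image'].flatten(start_dim=2)  # (B, T, C*H*W) stand-in for CNN features
    return pack_padded_sequence(feats, batch['seq_length'],
                                batch_first=True, enforce_sorted=False)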