-
-
Notifications
You must be signed in to change notification settings - Fork 16.7k
/
Copy pathdataloaders.py
1378 lines (1206 loc) Β· 59.1 KB
/
dataloaders.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
"""Dataloaders and dataset utils."""
import contextlib
import glob
import hashlib
import json
import math
import os
import random
import shutil
import time
from itertools import repeat
from multiprocessing.pool import Pool, ThreadPool
from pathlib import Path
from threading import Thread
from urllib.parse import urlparse
import numpy as np
import psutil
import torch
import torch.nn.functional as F
import torchvision
import yaml
from PIL import ExifTags, Image, ImageOps
from torch.utils.data import DataLoader, Dataset, dataloader, distributed
from tqdm import tqdm
from utils.augmentations import (
Albumentations,
augment_hsv,
classify_albumentations,
classify_transforms,
copy_paste,
letterbox,
mixup,
random_perspective,
)
from utils.general import (
DATASETS_DIR,
LOGGER,
NUM_THREADS,
TQDM_BAR_FORMAT,
check_dataset,
check_requirements,
check_yaml,
clean_str,
cv2,
is_colab,
is_kaggle,
segments2boxes,
unzip_file,
xyn2xy,
xywh2xyxy,
xywhn2xyxy,
xyxy2xywhn,
)
from utils.torch_utils import torch_distributed_zero_first
# Parameters
# Help link appended to dataset loading/label error messages in this module
HELP_URL = "See https://docs.ultralytics.com/yolov5/tutorials/train_custom_data"
IMG_FORMATS = "bmp", "dng", "jpeg", "jpg", "mpo", "png", "tif", "tiff", "webp", "pfm"  # include image suffixes
VID_FORMATS = "asf", "avi", "gif", "m4v", "mkv", "mov", "mp4", "mpeg", "mpg", "ts", "wmv"  # include video suffixes
# Distributed-run settings read from environment variables; the fallbacks (-1, -1, 1) apply when they are unset
LOCAL_RANK = int(os.getenv("LOCAL_RANK", -1))  # https://pytorch.org/docs/stable/elastic/run.html
RANK = int(os.getenv("RANK", -1))
WORLD_SIZE = int(os.getenv("WORLD_SIZE", 1))
# True unless the PIN_MEMORY environment variable is set to something other than "true" (case-insensitive)
PIN_MEMORY = str(os.getenv("PIN_MEMORY", True)).lower() == "true"  # global pin_memory for dataloaders

# Get orientation exif tag
# The loop stops with `orientation` bound to the numeric tag id whose name in ExifTags.TAGS is "Orientation"
for orientation in ExifTags.TAGS.keys():
    if ExifTags.TAGS[orientation] == "Orientation":
        break
def get_hash(paths):
    """Return one SHA256 hex digest for a list of paths, built from their total size on disk and their names."""
    total_bytes = 0
    for candidate in paths:
        if os.path.exists(candidate):  # paths that do not exist contribute nothing to the size
            total_bytes += os.path.getsize(candidate)
    digest = hashlib.sha256(str(total_bytes).encode())  # start from the summed size
    digest.update("".join(paths).encode())  # then mix in every path string, in list order
    return digest.hexdigest()
def exif_size(img):
    """
    Return the PIL image size as (width, height), corrected for its EXIF orientation.

    EXIF orientations 5-8 all display the image with its axes swapped (6 and 8 are 90/270 degree rotations, 5 and 7
    are the transposed/transverse variants), so width and height are exchanged for each of them. Any failure while
    reading EXIF data (no `_getexif`, no EXIF block, missing tag) deliberately falls back to the stored size.

    :param img: PIL image, or any object exposing `.size` and optionally `._getexif()`.
    :return: (width, height) tuple.
    """
    s = img.size  # (width, height)
    with contextlib.suppress(Exception):  # best effort: images without usable EXIF keep their stored size
        rotation = dict(img._getexif().items())[0x0112]  # 0x0112 is the EXIF Orientation tag id
        if rotation in {5, 6, 7, 8}:  # every orientation that swaps the displayed axes
            s = (s[1], s[0])
    return s
def exif_transpose(image):
    """
    Apply the transpose implied by a PIL image's EXIF Orientation tag, if it has one.

    Reference: https://github.com/python-pillow/Pillow/blob/master/src/PIL/ImageOps.py exif_transpose().

    :param image: The image to transpose.
    :return: The transposed image with the Orientation tag removed from its stored EXIF bytes, or the input image
        unchanged when no transpose applies.
    """
    exif = image.getexif()
    orientation = exif.get(0x0112, 1)  # a missing tag is treated as orientation 1
    if orientation <= 1:
        return image

    transposes = {
        2: Image.FLIP_LEFT_RIGHT,
        3: Image.ROTATE_180,
        4: Image.FLIP_TOP_BOTTOM,
        5: Image.TRANSPOSE,
        6: Image.ROTATE_270,
        7: Image.TRANSVERSE,
        8: Image.ROTATE_90,
    }
    method = transposes.get(orientation)
    if method is None:  # orientation value outside 2-8: leave the image alone
        return image

    image = image.transpose(method)
    del exif[0x0112]  # the pixels are now upright, so drop the tag before storing the EXIF bytes
    image.info["exif"] = exif.tobytes()
    return image
def seed_worker(worker_id):
    """
    Seed Python's `random` and NumPy's global RNG from the current torch initial seed.

    Passed to the dataloader as `worker_init_fn`; `worker_id` is accepted but not used.
    See https://pytorch.org/docs/stable/notes/randomness.html#dataloader.
    """
    seed = torch.initial_seed() % 2**32  # reduce to the 32-bit range
    random.seed(seed)
    np.random.seed(seed)
# Inherit from DistributedSampler and override iterator
# https://github.com/pytorch/pytorch/blob/master/torch/utils/data/distributed.py
class SmartDistributedSampler(distributed.DistributedSampler):
    """Distributed sampler yielding a per-epoch deterministic ordering, trimmed or padded to `num_samples`."""

    def __iter__(self):
        """Return an iterator over positions 0..n-1 for this rank, permuted with a generator seeded by seed + epoch."""
        rng = torch.Generator()
        rng.manual_seed(self.seed + self.epoch)

        # determine the eventual size of this rank's DDP index list
        per_rank = int((len(self.dataset) - self.rank - 1) / self.num_replicas) + 1  # num_replicas == WORLD_SIZE
        order = torch.randperm(per_rank, generator=rng)
        if not self.shuffle:
            order = order.sort()[0]  # undo the permutation: plain ascending order
        positions = order.tolist()

        if self.drop_last:
            return iter(positions[: self.num_samples])

        # pad by repeating positions from the start until num_samples entries exist
        shortfall = self.num_samples - len(positions)
        if shortfall > len(positions):
            repeats = math.ceil(shortfall / len(positions))
            positions += (positions * repeats)[:shortfall]
        else:
            positions += positions[:shortfall]
        return iter(positions)
def create_dataloader(
    path,
    imgsz,
    batch_size,
    stride,
    single_cls=False,
    hyp=None,
    augment=False,
    cache=False,
    pad=0.0,
    rect=False,
    rank=-1,
    workers=8,
    image_weights=False,
    quad=False,
    prefix="",
    shuffle=False,
    seed=0,
):
    """
    Build a `LoadImagesAndLabels` dataset and wrap it in a configured dataloader.

    Args:
        path: Dataset location, forwarded to `LoadImagesAndLabels`.
        imgsz: Target image size, forwarded to the dataset.
        batch_size: Requested batch size; capped at the dataset length.
        stride: Model stride; converted to int before being passed to the dataset.
        single_cls: Forwarded to the dataset (merge all classes into class 0).
        hyp: Hyperparameter dict forwarded to the dataset.
        augment: Forwarded to the dataset (enable augmentation).
        cache: Forwarded to the dataset as `cache_images`.
        pad: Forwarded to the dataset (rectangular-batch padding).
        rect: Use rectangular batches; forces `shuffle=False`.
        rank: Distributed rank; -1 means no distributed sampler is created.
        workers: Upper bound on dataloader worker processes.
        image_weights: When true a plain `DataLoader` is used instead of `InfiniteDataLoader`.
        quad: Use `collate_fn4` and drop the last incomplete batch.
        prefix: Log-message prefix forwarded to the dataset.
        shuffle: Shuffle samples (handled by the sampler when one is created).
        seed: Offset added to the dataloader generator seed.

    Returns:
        (loader, dataset) tuple.
    """
    if rect and shuffle:
        LOGGER.warning("WARNING ⚠️ --rect is incompatible with DataLoader shuffle, setting shuffle=False")
        shuffle = False
    with torch_distributed_zero_first(rank):  # init dataset *.cache only once if DDP
        dataset = LoadImagesAndLabels(
            path,
            imgsz,
            batch_size,
            augment=augment,  # augmentation
            hyp=hyp,  # hyperparameters
            rect=rect,  # rectangular batches
            cache_images=cache,
            single_cls=single_cls,
            stride=int(stride),
            pad=pad,
            image_weights=image_weights,
            prefix=prefix,
            rank=rank,
        )

    batch_size = min(batch_size, len(dataset))
    nd = torch.cuda.device_count()  # number of CUDA devices
    cpus = os.cpu_count() or 1  # os.cpu_count() returns None when the count is undeterminable
    nw = min([cpus // max(nd, 1), batch_size if batch_size > 1 else 0, workers])  # number of workers
    sampler = None if rank == -1 else SmartDistributedSampler(dataset, shuffle=shuffle)
    loader = DataLoader if image_weights else InfiniteDataLoader  # only DataLoader allows for attribute updates
    generator = torch.Generator()
    generator.manual_seed(6148914691236517205 + seed + RANK)
    return loader(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle and sampler is None,  # shuffle and sampler are mutually exclusive in DataLoader
        num_workers=nw,
        sampler=sampler,
        drop_last=quad,
        pin_memory=PIN_MEMORY,
        collate_fn=LoadImagesAndLabels.collate_fn4 if quad else LoadImagesAndLabels.collate_fn,
        worker_init_fn=seed_worker,
        generator=generator,
    ), dataset
class InfiniteDataLoader(dataloader.DataLoader):
    """
    Dataloader that reuses workers.

    Takes the same arguments as the vanilla DataLoader; a single underlying iterator is created once and shared by
    every pass over the loader.
    """

    def __init__(self, *args, **kwargs):
        """Build the DataLoader, then swap in a repeating batch sampler and create the shared iterator."""
        super().__init__(*args, **kwargs)
        endless = _RepeatSampler(self.batch_sampler)
        # set through object.__setattr__ rather than normal attribute assignment on the DataLoader
        object.__setattr__(self, "batch_sampler", endless)
        self.iterator = super().__iter__()

    def __len__(self):
        """Return the length of the wrapped (non-repeating) batch sampler."""
        wrapped = self.batch_sampler.sampler
        return len(wrapped)

    def __iter__(self):
        """Yield `len(self)` batches from the shared iterator, which never runs out."""
        remaining = len(self)
        while remaining > 0:
            yield next(self.iterator)
            remaining -= 1
class _RepeatSampler:
    """
    Wrap a sampler so that iteration starts over from the beginning each time the sampler is exhausted.

    Args:
        sampler (Sampler)
    """

    def __init__(self, sampler):
        """Store the `Sampler` instance that will be replayed without end."""
        self.sampler = sampler

    def __iter__(self):
        """Yield items from the wrapped sampler forever, restarting it after every full pass."""
        while True:
            for item in self.sampler:
                yield item
class LoadScreenshots:
    """Screen-capture dataloader for YOLOv5 detection; grabs a monitor region with mss on every iteration."""

    def __init__(self, source, img_size=640, stride=32, auto=True, transforms=None):
        """
        Parse the capture region from `source` and open an mss capture session.

        Source = [screen_number left top width height] (pixels)

        After the leading word, `source` may carry 1 value (screen number), 4 values (left top width height) or
        5 values (screen number plus region). Any other count leaves the default: all of screen 0.
        """
        check_requirements("mss")
        import mss

        source, *params = source.split()
        self.screen, left, top, width, height = 0, None, None, None, None  # default to full screen 0
        count = len(params)
        if count == 1:
            self.screen = int(params[0])
        elif count == 4:
            left, top, width, height = map(int, params)
        elif count == 5:
            self.screen, left, top, width, height = map(int, params)

        self.img_size = img_size
        self.stride = stride
        self.transforms = transforms
        self.auto = auto
        self.mode = "stream"
        self.frame = 0
        self.sct = mss.mss()

        # Parse monitor shape: requested offsets are relative to the monitor's own origin
        monitor = self.sct.monitors[self.screen]
        if top is None:
            self.top = monitor["top"]
        else:
            self.top = monitor["top"] + top
        if left is None:
            self.left = monitor["left"]
        else:
            self.left = monitor["left"] + left
        self.width = width or monitor["width"]
        self.height = height or monitor["height"]
        self.monitor = {"left": self.left, "top": self.top, "width": self.width, "height": self.height}

    def __iter__(self):
        """Return the loader itself; it is its own iterator."""
        return self

    def __next__(self):
        """Grab the configured region and return (screen, preprocessed image, raw BGR frame, None, log string)."""
        frame = np.array(self.sct.grab(self.monitor))
        im0 = frame[:, :, :3]  # keep the first three channels: BGRA to BGR
        s = f"screen {self.screen} (LTWH): {self.left},{self.top},{self.width},{self.height}: "

        if not self.transforms:
            im = letterbox(im0, self.img_size, stride=self.stride, auto=self.auto)[0]  # padded resize
            im = im.transpose((2, 0, 1))[::-1]  # HWC to CHW, BGR to RGB
            im = np.ascontiguousarray(im)  # contiguous
        else:
            im = self.transforms(im0)  # transforms
        self.frame += 1
        return str(self.screen), im, im0, None, s  # screen, img, original img, im0s, s
class LoadImages:
    """YOLOv5 image/video dataloader, i.e. `python detect.py --source image.jpg/vid.mp4`."""

    def __init__(self, path, img_size=640, stride=32, auto=True, transforms=None, vid_stride=1):
        """Collect image and video files from a path, glob pattern, directory, *.txt list, or list/tuple of those."""
        if isinstance(path, str) and Path(path).suffix == ".txt":  # *.txt file with img/vid/dir on each line
            path = Path(path).read_text().rsplit()

        entries = sorted(path) if isinstance(path, (list, tuple)) else [path]
        files = []
        for p in entries:
            p = str(Path(p).resolve())
            if "*" in p:  # glob
                files.extend(sorted(glob.glob(p, recursive=True)))
            elif os.path.isdir(p):  # dir
                files.extend(sorted(glob.glob(os.path.join(p, "*.*"))))
            elif os.path.isfile(p):  # files
                files.append(p)
            else:
                raise FileNotFoundError(f"{p} does not exist")

        # split by lower-cased extension; files matching neither format list are dropped
        images, videos = [], []
        for candidate in files:
            suffix = candidate.split(".")[-1].lower()
            if suffix in IMG_FORMATS:
                images.append(candidate)
            if suffix in VID_FORMATS:
                videos.append(candidate)
        ni, nv = len(images), len(videos)

        self.img_size = img_size
        self.stride = stride
        self.files = images + videos  # images first, then videos
        self.nf = ni + nv  # number of files
        self.video_flag = [False] * ni + [True] * nv
        self.mode = "image"
        self.auto = auto
        self.transforms = transforms  # optional
        self.vid_stride = vid_stride  # video frame-rate stride
        if any(videos):
            self._new_video(videos[0])  # new video
        else:
            self.cap = None
        assert self.nf > 0, (
            f"No images or videos found in {p}. Supported formats are:\nimages: {IMG_FORMATS}\nvideos: {VID_FORMATS}"
        )

    def __iter__(self):
        """Reset the file counter and return the loader itself as the iterator."""
        self.count = 0
        return self

    def __next__(self):
        """Return (path, preprocessed image, raw image, capture object, log string) for the next image or frame."""
        if self.count == self.nf:
            raise StopIteration
        path = self.files[self.count]

        if not self.video_flag[self.count]:
            # Read image
            self.count += 1
            im0 = cv2.imread(path)  # BGR
            assert im0 is not None, f"Image Not Found {path}"
            s = f"image {self.count}/{self.nf} {path}: "
        else:
            # Read video: grab vid_stride frames, decode only the last one
            self.mode = "video"
            for _ in range(self.vid_stride):
                self.cap.grab()
            ret_val, im0 = self.cap.retrieve()
            while not ret_val:  # current video exhausted: move on to the next file
                self.count += 1
                self.cap.release()
                if self.count == self.nf:  # last video
                    raise StopIteration
                path = self.files[self.count]
                self._new_video(path)
                ret_val, im0 = self.cap.read()

            self.frame += 1
            # im0 = self._cv2_rotate(im0)  # for use if cv2 autorotation is False
            s = f"video {self.count + 1}/{self.nf} ({self.frame}/{self.frames}) {path}: "

        if not self.transforms:
            im = letterbox(im0, self.img_size, stride=self.stride, auto=self.auto)[0]  # padded resize
            im = im.transpose((2, 0, 1))[::-1]  # HWC to CHW, BGR to RGB
            im = np.ascontiguousarray(im)  # contiguous
        else:
            im = self.transforms(im0)  # transforms

        return path, im, im0, self.cap, s

    def _new_video(self, path):
        """Open `path` as the current video and record its stride-adjusted frame count and orientation metadata."""
        self.frame = 0
        self.cap = cv2.VideoCapture(path)
        total = self.cap.get(cv2.CAP_PROP_FRAME_COUNT)
        self.frames = int(total / self.vid_stride)
        self.orientation = int(self.cap.get(cv2.CAP_PROP_ORIENTATION_META))  # rotation degrees
        # self.cap.set(cv2.CAP_PROP_ORIENTATION_AUTO, 0)  # disable https://github.com/ultralytics/yolov5/issues/8493

    def _cv2_rotate(self, im):
        """
        Rotate a cv2 image according to `self.orientation`.

        Orientation 0 rotates 90 degrees clockwise, 180 rotates 90 degrees counter-clockwise, 90 rotates 180
        degrees; any other value returns the image unchanged.
        NOTE(review): this mapping looks unusual - confirm against the source of the orientation values.
        """
        rotations = {
            0: cv2.ROTATE_90_CLOCKWISE,
            180: cv2.ROTATE_90_COUNTERCLOCKWISE,
            90: cv2.ROTATE_180,
        }
        code = rotations.get(self.orientation)
        return im if code is None else cv2.rotate(im, code)

    def __len__(self):
        """Return the number of image and video files found."""
        return self.nf  # number of files
class LoadStreams:
    """Loads and processes video streams for YOLOv5, supporting various sources including YouTube and IP cameras."""

    def __init__(self, sources="file.streams", img_size=640, stride=32, auto=True, transforms=None, vid_stride=1):
        """
        Open every source, read its first frame and start one daemon reader thread per stream.

        Args:
            sources: A single source string, or the path of a text file holding whitespace-separated sources.
            img_size: Target size passed to `letterbox`.
            stride: Stride passed to `letterbox`.
            auto: Passed to `letterbox`; only kept enabled when all streams letterbox to the same shape.
            transforms: Optional callable applied to each frame instead of the letterbox pipeline.
            vid_stride: Keep one frame out of every `vid_stride` frames grabbed.
        """
        torch.backends.cudnn.benchmark = True  # faster for fixed-size inference
        self.mode = "stream"
        self.img_size = img_size
        self.stride = stride
        self.vid_stride = vid_stride  # video frame-rate stride
        sources = Path(sources).read_text().rsplit() if os.path.isfile(sources) else [sources]
        n = len(sources)
        self.sources = [clean_str(x) for x in sources]  # clean source names for later
        self.imgs, self.fps, self.frames, self.threads = [None] * n, [0] * n, [0] * n, [None] * n
        for i, s in enumerate(sources):  # index, source
            # Start thread to read frames from video stream
            st = f"{i + 1}/{n}: {s}... "
            if urlparse(s).hostname in ("www.youtube.com", "youtube.com", "youtu.be"):  # if source is YouTube video
                # YouTube format i.e. 'https://www.youtube.com/watch?v=Zgi9g1ksQHc' or 'https://youtu.be/LNwODJXcvt4'
                check_requirements(("pafy", "youtube_dl==2020.12.2"))
                import pafy

                s = pafy.new(s).getbest(preftype="mp4").url  # YouTube URL
            # int() instead of eval(): the source string is user input and must never be executed
            s = int(s) if s.isnumeric() else s  # i.e. s = '0' local webcam
            if s == 0:
                assert not is_colab(), "--source 0 webcam unsupported on Colab. Rerun command in a local environment."
                assert not is_kaggle(), "--source 0 webcam unsupported on Kaggle. Rerun command in a local environment."
            cap = cv2.VideoCapture(s)
            assert cap.isOpened(), f"{st}Failed to open {s}"
            w = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
            h = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
            fps = cap.get(cv2.CAP_PROP_FPS)  # warning: may return 0 or nan
            self.frames[i] = max(int(cap.get(cv2.CAP_PROP_FRAME_COUNT)), 0) or float("inf")  # infinite stream fallback
            self.fps[i] = max((fps if math.isfinite(fps) else 0) % 100, 0) or 30  # 30 FPS fallback

            _, self.imgs[i] = cap.read()  # guarantee first frame
            self.threads[i] = Thread(target=self.update, args=(i, cap, s), daemon=True)
            LOGGER.info(f"{st} Success ({self.frames[i]} frames {w}x{h} at {self.fps[i]:.2f} FPS)")
            self.threads[i].start()
        LOGGER.info("")  # newline

        # check for common shapes
        s = np.stack([letterbox(x, img_size, stride=stride, auto=auto)[0].shape for x in self.imgs])
        self.rect = np.unique(s, axis=0).shape[0] == 1  # rect inference if all shapes equal
        self.auto = auto and self.rect
        self.transforms = transforms  # optional
        if not self.rect:
            LOGGER.warning("WARNING ⚠️ Stream shapes differ. For optimal performance supply similarly-shaped streams.")

    def update(self, i, cap, stream):
        """Read frames from stream `i` into `self.imgs[i]` until its frame count is reached or the capture closes."""
        n, f = 0, self.frames[i]  # frames grabbed so far, total frame count (inf when unknown)
        while cap.isOpened() and n < f:
            n += 1
            cap.grab()  # .read() = .grab() followed by .retrieve()
            if n % self.vid_stride == 0:  # decode only every vid_stride-th frame
                success, im = cap.retrieve()
                if success:
                    self.imgs[i] = im
                else:
                    LOGGER.warning("WARNING ⚠️ Video stream unresponsive, please check your IP camera connection.")
                    self.imgs[i] = np.zeros_like(self.imgs[i])  # blank frame with the previous frame's shape
                    cap.open(stream)  # re-open stream if signal was lost
            time.sleep(0.0)  # wait time

    def __iter__(self):
        """Reset the batch counter and return the loader itself as the iterator."""
        self.count = -1
        return self

    def __next__(self):
        """
        Return (sources, preprocessed batch, list of latest raw frames, None, "") for the current frames.

        Raises `StopIteration` once any reader thread has stopped or the 'q' key is pressed.
        """
        self.count += 1
        if not all(x.is_alive() for x in self.threads) or cv2.waitKey(1) == ord("q"):  # q to quit
            cv2.destroyAllWindows()
            raise StopIteration

        im0 = self.imgs.copy()  # shallow copy of the list of latest frames
        if self.transforms:
            im = np.stack([self.transforms(x) for x in im0])  # transforms
        else:
            im = np.stack([letterbox(x, self.img_size, stride=self.stride, auto=self.auto)[0] for x in im0])  # resize
            im = im[..., ::-1].transpose((0, 3, 1, 2))  # BGR to RGB, BHWC to BCHW
            im = np.ascontiguousarray(im)  # contiguous

        return self.sources, im, im0, None, ""

    def __len__(self):
        """Return the number of stream sources."""
        return len(self.sources)
def img2label_paths(img_paths):
    """Map image paths to label paths: the last `/images/` component becomes `/labels/` and the extension becomes
    `.txt`.
    """
    images_dir = f"{os.sep}images{os.sep}"  # /images/ substring
    labels_dir = f"{os.sep}labels{os.sep}"  # /labels/ substring
    label_paths = []
    for img_path in img_paths:
        swapped = labels_dir.join(img_path.rsplit(images_dir, 1))  # replace only the last occurrence
        stem = swapped.rsplit(".", 1)[0]  # drop everything after the final dot
        label_paths.append(stem + ".txt")
    return label_paths
class LoadImagesAndLabels(Dataset):
"""Loads images and their corresponding labels for training and validation in YOLOv5."""
cache_version = 0.6 # dataset labels *.cache version
rand_interp_methods = [cv2.INTER_NEAREST, cv2.INTER_LINEAR, cv2.INTER_CUBIC, cv2.INTER_AREA, cv2.INTER_LANCZOS4]
def __init__(
    self,
    path,
    img_size=640,
    batch_size=16,
    augment=False,
    hyp=None,
    rect=False,
    image_weights=False,
    cache_images=False,
    single_cls=False,
    stride=32,
    pad=0.0,
    min_items=0,
    prefix="",
    rank=-1,
    seed=0,
):
    """
    Collect image files, load or build the label cache, and optionally cache images to RAM or disk.

    Args:
        path: Directory, text file listing images, or list of either.
        img_size: Target image size.
        batch_size: Used to assign each image a batch index (needed for rectangular batch shapes).
        augment: Enable augmentation; also enables mosaic unless `rect` is active.
        hyp: Hyperparameter dict read in `__getitem__`.
        rect: Rectangular batches; ignored when `image_weights` is set.
        image_weights: Stored on the instance; disables `rect`.
        cache_images: "ram", "disk", or a falsy value for no image caching.
        single_cls: Set the class column of every label to 0.
        stride: Stride that rectangular batch shapes are rounded to.
        pad: Padding term added when computing rectangular batch shapes.
        min_items: Drop images with fewer than this many labels (0 disables the filter).
        prefix: Prefix for log and error messages.
        rank: When > -1, restrict `self.indices` to this process's share of the images.
        seed: Seed of the permutation that assigns images to ranks.

    Raises:
        Exception: If the image list cannot be built from `path`.
        AssertionError: If `augment` is set and no labels are found.
    """
    self.img_size = img_size
    self.augment = augment
    self.hyp = hyp
    self.image_weights = image_weights
    self.rect = False if image_weights else rect
    self.mosaic = self.augment and not self.rect  # load 4 images at a time into a mosaic (only during training)
    self.mosaic_border = [-img_size // 2, -img_size // 2]
    self.stride = stride
    self.path = path
    self.albumentations = Albumentations(size=img_size) if augment else None

    try:
        f = []  # image files
        for p in path if isinstance(path, list) else [path]:
            p = Path(p)  # os-agnostic
            if p.is_dir():  # dir
                f += glob.glob(str(p / "**" / "*.*"), recursive=True)
                # f = list(p.rglob('*.*'))  # pathlib
            elif p.is_file():  # file
                with open(p) as t:
                    t = t.read().strip().splitlines()
                    parent = str(p.parent) + os.sep
                    f += [x.replace("./", parent, 1) if x.startswith("./") else x for x in t]  # to global path
                    # f += [p.parent / x.lstrip(os.sep) for x in t]  # to global path (pathlib)
            else:
                raise FileNotFoundError(f"{prefix}{p} does not exist")
        self.im_files = sorted(x.replace("/", os.sep) for x in f if x.split(".")[-1].lower() in IMG_FORMATS)
        # self.img_files = sorted([x for x in f if x.suffix[1:].lower() in IMG_FORMATS])  # pathlib
        assert self.im_files, f"{prefix}No images found"
    except Exception as e:
        raise Exception(f"{prefix}Error loading data from {path}: {e}\n{HELP_URL}") from e

    # Check cache
    self.label_files = img2label_paths(self.im_files)  # labels
    # `p` is the last entry of `path` visited by the loop above
    cache_path = (p if p.is_file() else Path(self.label_files[0]).parent).with_suffix(".cache")
    try:
        cache, exists = np.load(cache_path, allow_pickle=True).item(), True  # load dict
        assert cache["version"] == self.cache_version  # matches current version
        assert cache["hash"] == get_hash(self.label_files + self.im_files)  # identical hash
    except Exception:
        # missing, unreadable, outdated or stale cache: rebuild it
        cache, exists = self.cache_labels(cache_path, prefix), False  # run cache ops

    # Display cache
    nf, nm, ne, nc, n = cache.pop("results")  # found, missing, empty, corrupt, total
    if exists and LOCAL_RANK in {-1, 0}:
        d = f"Scanning {cache_path}... {nf} images, {nm + ne} backgrounds, {nc} corrupt"
        tqdm(None, desc=prefix + d, total=n, initial=n, bar_format=TQDM_BAR_FORMAT)  # display cache results
        if cache["msgs"]:
            LOGGER.info("\n".join(cache["msgs"]))  # display warnings
    assert nf > 0 or not augment, f"{prefix}No labels found in {cache_path}, can not start training. {HELP_URL}"

    # Read cache
    for k in ("hash", "version", "msgs"):  # remove metadata so only per-image entries remain
        cache.pop(k)
    labels, shapes, self.segments = zip(*cache.values())
    nl = len(np.concatenate(labels, 0))  # number of labels
    assert nl > 0 or not augment, f"{prefix}All labels empty in {cache_path}, can not start training. {HELP_URL}"
    self.labels = list(labels)
    self.shapes = np.array(shapes)
    self.im_files = list(cache.keys())  # update
    self.label_files = img2label_paths(cache.keys())  # update

    # Filter images
    if min_items:
        include = np.array([len(x) >= min_items for x in self.labels]).nonzero()[0].astype(int)
        LOGGER.info(f"{prefix}{n - len(include)}/{n} images filtered from dataset")
        self.im_files = [self.im_files[i] for i in include]
        self.label_files = [self.label_files[i] for i in include]
        self.labels = [self.labels[i] for i in include]
        self.segments = [self.segments[i] for i in include]
        self.shapes = self.shapes[include]  # wh

    # Create indices
    n = len(self.shapes)  # number of images
    bi = np.floor(np.arange(n) / batch_size).astype(int)  # batch index
    nb = bi[-1] + 1  # number of batches
    self.batch = bi  # batch index of image
    self.n = n
    self.indices = np.arange(n)
    if rank > -1:  # DDP indices (see: SmartDistributedSampler)
        # force each rank (i.e. GPU process) to sample the same subset of data on every epoch
        self.indices = self.indices[np.random.RandomState(seed=seed).permutation(n) % WORLD_SIZE == RANK]

    # Update labels
    include_class = []  # filter labels to include only these classes (optional)
    self.segments = list(self.segments)
    include_class_array = np.array(include_class).reshape(1, -1)
    for i, (label, segment) in enumerate(zip(self.labels, self.segments)):
        if include_class:
            j = (label[:, 0:1] == include_class_array).any(1)
            self.labels[i] = label[j]
            if segment:
                self.segments[i] = [segment[idx] for idx, elem in enumerate(j) if elem]
        if single_cls:  # single-class training, merge all classes into 0
            self.labels[i][:, 0] = 0

    # Rectangular Training
    if self.rect:
        # Sort by aspect ratio
        s = self.shapes  # wh
        ar = s[:, 1] / s[:, 0]  # aspect ratio
        irect = ar.argsort()
        self.im_files = [self.im_files[i] for i in irect]
        self.label_files = [self.label_files[i] for i in irect]
        self.labels = [self.labels[i] for i in irect]
        self.segments = [self.segments[i] for i in irect]
        self.shapes = s[irect]  # wh
        ar = ar[irect]

        # Set training image shapes
        shapes = [[1, 1]] * nb
        for i in range(nb):
            ari = ar[bi == i]
            mini, maxi = ari.min(), ari.max()
            if maxi < 1:
                shapes[i] = [maxi, 1]
            elif mini > 1:
                shapes[i] = [1, 1 / mini]

        self.batch_shapes = np.ceil(np.array(shapes) * img_size / stride + pad).astype(int) * stride

    # Cache images into RAM/disk for faster training
    if cache_images == "ram" and not self.check_cache_ram(prefix=prefix):
        cache_images = False
    self.ims = [None] * n
    self.npy_files = [Path(f).with_suffix(".npy") for f in self.im_files]
    if cache_images:
        b, gb = 0, 1 << 30  # bytes of cached images, bytes per gigabytes
        self.im_hw0, self.im_hw = [None] * n, [None] * n
        fcn = self.cache_images_to_disk if cache_images == "disk" else self.load_image
        # context manager shuts the pool's threads down once every result has been consumed
        with ThreadPool(NUM_THREADS) as pool:
            results = pool.imap(lambda i: (i, fcn(i)), self.indices)
            pbar = tqdm(results, total=len(self.indices), bar_format=TQDM_BAR_FORMAT, disable=LOCAL_RANK > 0)
            for i, x in pbar:
                if cache_images == "disk":
                    b += self.npy_files[i].stat().st_size
                else:  # 'ram'
                    self.ims[i], self.im_hw0[i], self.im_hw[i] = x  # im, hw_orig, hw_resized = load_image(self, i)
                    b += self.ims[i].nbytes * WORLD_SIZE
                pbar.desc = f"{prefix}Caching images ({b / gb:.1f}GB {cache_images})"
            pbar.close()
def check_cache_ram(self, safety_margin=0.1, prefix=""):
    """
    Estimate whether the whole dataset, resized to `img_size`, fits into currently available RAM.

    Up to 30 randomly chosen images are read and their resized byte size is extrapolated to all `self.n` images.

    Args:
        safety_margin: Fraction of extra headroom required on top of the estimate.
        prefix: Prefix for the log message.

    Returns:
        True when the estimate plus margin is below the available memory, else False (a message is logged).
    """
    b, gb = 0, 1 << 30  # bytes of sampled images, bytes per gigabyte
    n = min(self.n, 30)  # extrapolate from 30 random images
    for _ in range(n):
        im = cv2.imread(random.choice(self.im_files))  # sample image
        ratio = self.img_size / max(im.shape[0], im.shape[1])  # img_size / max(h, w)
        b += im.nbytes * ratio**2  # both image sides scale by `ratio`
    mem_required = b * self.n / n  # bytes required to cache dataset into RAM
    mem = psutil.virtual_memory()
    cache = mem_required * (1 + safety_margin) < mem.available  # to cache or not to cache, that is the question
    if not cache:
        LOGGER.info(
            f"{prefix}{mem_required / gb:.1f}GB RAM required, "
            f"{mem.available / gb:.1f}/{mem.total / gb:.1f}GB available, "
            f"{'caching images ✅' if cache else 'not caching images ⚠️'}"
        )
    return cache
def cache_labels(self, path=Path("./labels.cache"), prefix=""):
    """
    Verify every image/label pair in a process pool, build the label cache dict and try to save it to `path`.

    Args:
        path: Destination of the cache file.
        prefix: Prefix for progress and log messages.

    Returns:
        Dict mapping each accepted image file to [labels, shape, segments], plus the metadata keys "hash",
        "results" (found, missing, empty, corrupt, total), "msgs" and "version". It is returned even when it
        could not be written to disk.
    """
    x = {}  # dict
    nm, nf, ne, nc, msgs = 0, 0, 0, 0, []  # number missing, found, empty, corrupt, messages
    desc = f"{prefix}Scanning {path.parent / path.stem}..."
    with Pool(NUM_THREADS) as pool:
        pbar = tqdm(
            pool.imap(verify_image_label, zip(self.im_files, self.label_files, repeat(prefix))),
            desc=desc,
            total=len(self.im_files),
            bar_format=TQDM_BAR_FORMAT,
        )
        for im_file, lb, shape, segments, nm_f, nf_f, ne_f, nc_f, msg in pbar:
            nm += nm_f
            nf += nf_f
            ne += ne_f
            nc += nc_f
            if im_file:  # a falsy im_file means the pair was rejected
                x[im_file] = [lb, shape, segments]
            if msg:
                msgs.append(msg)
            pbar.desc = f"{desc} {nf} images, {nm + ne} backgrounds, {nc} corrupt"

    pbar.close()
    if msgs:
        LOGGER.info("\n".join(msgs))
    if nf == 0:
        LOGGER.warning(f"{prefix}WARNING ⚠️ No labels found in {path}. {HELP_URL}")
    x["hash"] = get_hash(self.label_files + self.im_files)
    x["results"] = nf, nm, ne, nc, len(self.im_files)
    x["msgs"] = msgs  # warnings
    x["version"] = self.cache_version  # cache version
    try:
        np.save(path, x)  # save cache for next time
        path.with_suffix(".cache.npy").rename(path)  # remove .npy suffix
        LOGGER.info(f"{prefix}New cache created: {path}")
    except Exception as e:
        # best effort: a read-only dataset directory must not stop training
        LOGGER.warning(f"{prefix}WARNING ⚠️ Cache directory {path.parent} is not writeable: {e}")  # not writeable
    return x
def __len__(self):
    """Return how many image files the dataset holds."""
    file_count = len(self.im_files)
    return file_count
# def __iter__(self):
# self.count = -1
# print('ran dataset iter')
# #self.shuffled_vector = np.random.permutation(self.nF) if self.augment else np.arange(self.nF)
# return self
def __getitem__(self, index):
    """
    Load one sample: either a mosaic (optionally mixed up) or a single letterboxed image, then augment it.

    Returns a tuple (image tensor in CHW RGB order, labels tensor of shape (nl, 6) whose column 0 is left at zero,
    image file path, shapes). `shapes` is None for mosaic samples, otherwise ((h0, w0), ((h / h0, w / w0), pad)).
    """
    index = self.indices[index]  # linear, shuffled, or image_weights

    hyp = self.hyp
    # random.random() is only drawn when self.mosaic is truthy (short-circuit `and`)
    if mosaic := self.mosaic and random.random() < hyp["mosaic"]:
        # Load mosaic
        img, labels = self.load_mosaic(index)
        shapes = None

        # MixUp augmentation
        if random.random() < hyp["mixup"]:
            img, labels = mixup(img, labels, *self.load_mosaic(random.choice(self.indices)))

    else:
        # Load image
        img, (h0, w0), (h, w) = self.load_image(index)

        # Letterbox
        shape = self.batch_shapes[self.batch[index]] if self.rect else self.img_size  # final letterboxed shape
        img, ratio, pad = letterbox(img, shape, auto=False, scaleup=self.augment)
        shapes = (h0, w0), ((h / h0, w / w0), pad)  # for COCO mAP rescaling

        labels = self.labels[index].copy()  # copy so the cached labels are not modified
        if labels.size:  # normalized xywh to pixel xyxy format
            labels[:, 1:] = xywhn2xyxy(labels[:, 1:], ratio[0] * w, ratio[1] * h, padw=pad[0], padh=pad[1])

        if self.augment:
            img, labels = random_perspective(
                img,
                labels,
                degrees=hyp["degrees"],
                translate=hyp["translate"],
                scale=hyp["scale"],
                shear=hyp["shear"],
                perspective=hyp["perspective"],
            )

    nl = len(labels)  # number of labels
    if nl:
        # back to normalized xywh, relative to the final image size
        labels[:, 1:5] = xyxy2xywhn(labels[:, 1:5], w=img.shape[1], h=img.shape[0], clip=True, eps=1e-3)

    if self.augment:
        # Albumentations
        img, labels = self.albumentations(img, labels)
        nl = len(labels)  # update after albumentations

        # HSV color-space
        augment_hsv(img, hgain=hyp["hsv_h"], sgain=hyp["hsv_s"], vgain=hyp["hsv_v"])

        # Flip up-down
        if random.random() < hyp["flipud"]:
            img = np.flipud(img)
            if nl:
                labels[:, 2] = 1 - labels[:, 2]  # mirror the normalized y centre

        # Flip left-right
        if random.random() < hyp["fliplr"]:
            img = np.fliplr(img)
            if nl:
                labels[:, 1] = 1 - labels[:, 1]  # mirror the normalized x centre

        # Cutouts
        # labels = cutout(img, labels, p=0.5)
        # nl = len(labels)  # update after cutout

    labels_out = torch.zeros((nl, 6))
    if nl:
        labels_out[:, 1:] = torch.from_numpy(labels)  # column 0 stays zero

    # Convert
    img = img.transpose((2, 0, 1))[::-1]  # HWC to CHW, BGR to RGB
    img = np.ascontiguousarray(img)

    return torch.from_numpy(img), labels_out, self.im_files[index], shapes
def load_image(self, i):
    """
    Fetches image `i` from the RAM cache, else from its *.npy file, else from the original image file.

    An image that is not cached in RAM is scaled by `self.img_size / max(h0, w0)` before being returned.

    Returns (im, original hw, resized hw)
    """
    im = self.ims[i]
    f = self.im_files[i]
    npy_path = self.npy_files[i]

    if im is not None:  # cached in RAM, sizes were stored alongside the image
        return self.ims[i], self.im_hw0[i], self.im_hw[i]  # im, hw_original, hw_resized

    if npy_path.exists():  # disk cache hit
        im = np.load(npy_path)
    else:  # fall back to decoding the image file
        im = cv2.imread(f)  # BGR
        assert im is not None, f"Image Not Found {f}"

    orig_h, orig_w = im.shape[:2]
    scale = self.img_size / max(orig_h, orig_w)
    if scale != 1:  # resize only when the scale factor differs from 1
        # linear interpolation when augmenting or upscaling, area interpolation otherwise
        if self.augment or scale > 1:
            interp = cv2.INTER_LINEAR
        else:
            interp = cv2.INTER_AREA
        new_wh = (math.ceil(orig_w * scale), math.ceil(orig_h * scale))
        im = cv2.resize(im, new_wh, interpolation=interp)
    return im, (orig_h, orig_w), im.shape[:2]  # im, hw_original, hw_resized
def cache_images_to_disk(self, i):
    """
    Saves image `i` to disk as an *.npy file for quicker loading, unless that file already exists.

    An image that cv2 cannot read is not cached: cv2.imread returns None on failure, and saving that None would
    leave behind a pickled object array that np.load rejects by default on every later load.
    """
    f = self.npy_files[i]
    if f.exists():  # already cached on disk
        return
    im = cv2.imread(self.im_files[i])  # BGR, or None when the file is missing/unreadable
    if im is not None:  # never persist a failed read as a cache entry
        np.save(f.as_posix(), im)
def load_mosaic(self, index):
    """
    Builds a 4-image YOLOv5 mosaic from the selected image plus 3 randomly drawn ones.

    The tiles are arranged around a random center on a (2 * img_size) square canvas, their labels and segments are
    shifted into canvas pixel coordinates, and the result is passed through copy_paste and random_perspective.

    Returns (mosaic image, mosaic labels)
    """
    size = self.img_size
    canvas_side = size * 2
    yc, xc = (int(random.uniform(-x, canvas_side + x)) for x in self.mosaic_border)  # mosaic center y, x
    tile_ids = [index] + random.choices(self.indices, k=3)  # 3 additional image indices
    random.shuffle(tile_ids)

    def span(center, length, leading):
        """Returns (canvas_min, canvas_max, tile_min, tile_max) along one axis for a tile of the given length."""
        if leading:  # tile ends at the center (left column / top row)
            lo, hi = max(center - length, 0), center
            return lo, hi, length - (hi - lo), length
        # tile starts at the center (right column / bottom row)
        lo, hi = center, min(center + length, canvas_side)
        return lo, hi, 0, min(length, hi - lo)

    canvas = None
    all_labels, all_segments = [], []
    for slot, src in enumerate(tile_ids):
        img, _, (h, w) = self.load_image(src)
        if canvas is None:  # allocate once, using the channel count of the first tile
            canvas = np.full((canvas_side, canvas_side, img.shape[2]), 114, dtype=np.uint8)

        # slots: 0 top left, 1 top right, 2 bottom left, 3 bottom right
        x1a, x2a, x1b, x2b = span(xc, w, slot in (0, 2))
        y1a, y2a, y1b, y2b = span(yc, h, slot in (0, 1))
        canvas[y1a:y2a, x1a:x2a] = img[y1b:y2b, x1b:x2b]  # canvas[ymin:ymax, xmin:xmax]
        padw, padh = x1a - x1b, y1a - y1b  # offset of the tile origin on the canvas

        # Labels
        labels = self.labels[src].copy()
        segments = self.segments[src].copy()
        if labels.size:
            labels[:, 1:] = xywhn2xyxy(labels[:, 1:], w, h, padw, padh)  # normalized xywh to pixel xyxy format
            segments = [xyn2xy(seg, w, h, padw, padh) for seg in segments]
        all_labels.append(labels)
        all_segments.extend(segments)

    # Concat/clip labels
    merged = np.concatenate(all_labels, 0)
    for arr in (merged[:, 1:], *all_segments):
        np.clip(arr, 0, canvas_side, out=arr)  # clip when using random_perspective()

    # Augment
    hyp = self.hyp
    canvas, merged, all_segments = copy_paste(canvas, merged, all_segments, p=hyp["copy_paste"])
    canvas, merged = random_perspective(
        canvas,
        merged,
        all_segments,
        degrees=hyp["degrees"],
        translate=hyp["translate"],
        scale=hyp["scale"],
        shear=hyp["shear"],
        perspective=hyp["perspective"],
        border=self.mosaic_border,
    )  # border to remove
    return canvas, merged
def load_mosaic9(self, index):
"""Loads 1 image + 8 random images into a 9-image mosaic for augmented YOLOv5 training, returning labels and
segments.
"""
labels9, segments9 = [], []
s = self.img_size
indices = [index] + random.choices(self.indices, k=8) # 8 additional image indices
random.shuffle(indices)
hp, wp = -1, -1 # height, width previous
for i, index in enumerate(indices):
# Load image
img, _, (h, w) = self.load_image(index)
# place img in img9
if i == 0: # center
img9 = np.full((s * 3, s * 3, img.shape[2]), 114, dtype=np.uint8) # base image with 4 tiles
h0, w0 = h, w
c = s, s, s + w, s + h # xmin, ymin, xmax, ymax (base) coordinates
elif i == 1: # top
c = s, s - h, s + w, s
elif i == 2: # top right
c = s + wp, s - h, s + wp + w, s
elif i == 3: # right
c = s + w0, s, s + w0 + w, s + h
elif i == 4: # bottom right
c = s + w0, s + hp, s + w0 + w, s + hp + h
elif i == 5: # bottom
c = s + w0 - w, s + h0, s + w0, s + h0 + h
elif i == 6: # bottom left
c = s + w0 - wp - w, s + h0, s + w0 - wp, s + h0 + h
elif i == 7: # left
c = s - w, s + h0 - h, s, s + h0
elif i == 8: # top left
c = s - w, s + h0 - hp - h, s, s + h0 - hp
padx, pady = c[:2]
x1, y1, x2, y2 = (max(x, 0) for x in c) # allocate coords
# Labels
labels, segments = self.labels[index].copy(), self.segments[index].copy()
if labels.size:
labels[:, 1:] = xywhn2xyxy(labels[:, 1:], w, h, padx, pady) # normalized xywh to pixel xyxy format
segments = [xyn2xy(x, w, h, padx, pady) for x in segments]
labels9.append(labels)
segments9.extend(segments)
# Image
img9[y1:y2, x1:x2] = img[y1 - pady :, x1 - padx :] # img9[ymin:ymax, xmin:xmax]
hp, wp = h, w # height, width previous
# Offset
yc, xc = (int(random.uniform(0, s)) for _ in self.mosaic_border) # mosaic center x, y
img9 = img9[yc : yc + 2 * s, xc : xc + 2 * s]
# Concat/clip labels
labels9 = np.concatenate(labels9, 0)
labels9[:, [1, 3]] -= xc
labels9[:, [2, 4]] -= yc
c = np.array([xc, yc]) # centers
segments9 = [x - c for x in segments9]
for x in (labels9[:, 1:], *segments9):
np.clip(x, 0, 2 * s, out=x) # clip when using random_perspective()