-
Notifications
You must be signed in to change notification settings - Fork 0
/
dataloader.py
67 lines (61 loc) · 2.55 KB
/
dataloader.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
import os
from torch.utils.data import Dataset
from PIL import Image
class _PetImageDataset(Dataset):
    """Image-classification dataset backed by a folder-per-class directory tree.

    Images are collected from
    ``<root_dir>/<_SPECIES_DIR>/Images/<dataset_type>/<class folder>/`` for
    every class, keeping only files whose names end with one of
    ``_EXTENSIONS``. Each sample is a ``(path, label)`` pair where ``label``
    is the index of the class in ``self.classes``.

    Subclasses set ``_SPECIES_DIR`` and ``_CLASS_TO_FOLDER``. The insertion
    order of ``_CLASS_TO_FOLDER`` fixes the label indices.

    Args:
        root_dir: Root directory that contains the species sub-directory.
        dataset_type: Sub-directory under ``Images`` selecting the split.
        transform: Optional callable applied to each RGB PIL image in
            ``__getitem__``.

    Raises:
        FileNotFoundError: If a class directory does not exist (raised by
            ``os.listdir``).
    """

    _SPECIES_DIR = ''
    # Ordered mapping: class name -> folder name on disk.
    _CLASS_TO_FOLDER = {}
    # File-name suffixes accepted as images (str.endswith accepts a tuple).
    _EXTENSIONS = ('.jpg',)

    def __init__(self, root_dir, dataset_type, transform=None):
        self.root_dir = root_dir
        self.dataset_type = dataset_type
        self.transform = transform
        # Copies, so per-instance edits never mutate the class-level constants.
        self.classes = list(self._CLASS_TO_FOLDER)
        self.class_to_folder = dict(self._CLASS_TO_FOLDER)
        self.class_to_idx = {c: i for i, c in enumerate(self.classes)}
        self.idx_to_class = dict(enumerate(self.classes))
        self.data = self._scan()

    def _scan(self):
        """Return ``(path, label)`` pairs for every matching image, in a stable order."""
        split_dir = os.path.join(
            self.root_dir, self._SPECIES_DIR, 'Images', self.dataset_type
        )
        samples = []
        for c in self.classes:
            class_dir = os.path.join(split_dir, self.class_to_folder[c])
            label = self.class_to_idx[c]
            # os.listdir order is filesystem-dependent; sort so sample indices
            # are reproducible across machines and runs.
            for fname in sorted(os.listdir(class_dir)):
                if fname.endswith(self._EXTENSIONS):
                    samples.append((os.path.join(class_dir, fname), label))
        return samples

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        """Load sample ``idx`` as ``(image, label)``; image is RGB, optionally transformed."""
        path, label = self.data[idx]
        # The context manager closes the underlying file handle; convert()
        # returns a new, fully loaded image that stays valid after close.
        with Image.open(path) as img:
            image = img.convert('RGB')
        if self.transform is not None:
            image = self.transform(image)
        return image, label


class DogDataset(_PetImageDataset):
    """Dog-breed images under ``<root_dir>/dogs/Images/<dataset_type>/<breed>/``.

    Labels: 0 = Beagle, 1 = Siberian Husky, 2 = Toy Poodle.
    """

    _SPECIES_DIR = 'dogs'
    _CLASS_TO_FOLDER = {
        'Beagle': 'Beagle',
        'Siberian Husky': 'Siberian Husky',
        'Toy Poodle': 'Toy Poodle',
    }


class CatDataset(_PetImageDataset):
    """Cat-breed images under ``<root_dir>/cats/Images/<dataset_type>/<breed>/``.

    Labels: 0 = Siamese, 1 = Sphynx, 2 = Bengal.
    """

    _SPECIES_DIR = 'cats'
    _CLASS_TO_FOLDER = {
        'Siamese': 'Siamese',
        'Sphynx': 'Sphynx',
        'Bengal': 'Bengal',
    }