dataset.py
import os
import csv
import numpy as np
import PIL.Image
from torch.utils.data import Dataset


class FFHQFake(Dataset):
    """Synthetic FFHQ dataset: images with binary attribute labels, W latent codes, and poses."""

    def __init__(self, dataset_path, transform, resolution, csvfile='', **kwargs):
        super().__init__()
        print(csvfile)
        self.dataset_path = dataset_path
        self.transform = transform
        self.resolution = resolution

        # Each CSV row is a single space-separated record: the image filename
        # followed by per-attribute scores, binarized below with a 0.8 threshold.
        with open(csvfile, newline='') as f:
            reader = csv.reader(f)
            data = list(reader)

        self.label_list = []
        self.style_list = []
        # Attribute counts over the 40,000 samples (label1 / label2, ratio = label2 / label1):
        # hair color: 32019 / 7981
        # gender:      9254 / 30746  (3.32)
        # bangs:      37992 / 2008   (0.05)
        # age:        13476 / 26524  (1.96)
        # smile:      15718 / 24282  (1.54)
        # beard:      13915 / 26085  (1.87)
        for j in range(len(data)):
            item_list = []
            items = data[j][0].split(' ')
            for i in range(len(items)):
                if i != 0:  # skip the filename column
                    item_list.append(1 if float(items[i]) > 0.8 else 0)
            self.label_list.append(item_list)
            self.style_list.append(items[0])

    def __len__(self):
        return len(self.label_list)

    def __getitem__(self, index):
        imgname = self.style_list[index]
        # dataset_path may be a glob pattern such as '<dir>/*.png'; keep only the directory part.
        datapath = self.dataset_path.split('*.png')[0]
        imgpath = os.path.join(datapath, imgname)

        # Sidecar files share the image stem: '<name>_ws.npy' holds the latent codes,
        # '<name>_pose.npy' the pose.
        styles_path = imgpath.split('.png')[0] + '_ws.npy'
        labels = np.array(self.label_list[index])
        styles = np.load(styles_path).squeeze()

        X = PIL.Image.open(imgpath)
        img = self.transform(X)

        pose_path = imgpath.split('.png')[0] + '_pose.npy'
        poses = np.load(pose_path).squeeze()
        return img, labels, styles, poses
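
A minimal usage sketch for this class, assuming the file is importable as `dataset`. The directory layout, CSV filename, transform, and batch size below are illustrative assumptions, not part of dataset.py; the CSV is assumed to contain one space-separated record per row in the form "<name>.png score1 score2 ...", with matching '<name>_ws.npy' and '<name>_pose.npy' files next to each image.

# Usage sketch (assumptions: paths, CSV name, transform, and batch size are hypothetical)
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from dataset import FFHQFake

transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
])

# dataset_path is treated as a glob-style pattern whose '*.png' suffix is stripped,
# so '<dir>/*.png' effectively points at the image directory.
ds = FFHQFake(
    dataset_path='data/ffhq_fake/*.png',          # assumed image directory pattern
    transform=transform,
    resolution=256,
    csvfile='data/ffhq_fake/attributes.csv',      # assumed attribute-score CSV
)

loader = DataLoader(ds, batch_size=8, shuffle=True, num_workers=2)
img, labels, styles, poses = next(iter(loader))
print(img.shape, labels.shape, styles.shape, poses.shape)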