-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy path1-generate_dataset.py
63 lines (50 loc) · 2.24 KB
/
1-generate_dataset.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
import os
import tensorflow as tf
from sklearn.model_selection import train_test_split
import cv2
def _bytes_feature(value):
    """Wrap one value in a ``tf.train.Feature`` holding a single-entry bytes list.

    Args:
        value: The payload to store, presumably ``bytes`` (the caller in this
            file passes ``ndarray.tobytes()`` output).

    Returns:
        A ``tf.train.Feature`` whose ``bytes_list`` contains exactly ``value``.
    """
    payload = tf.train.BytesList(value=[value])
    return tf.train.Feature(bytes_list=payload)
def _int64_feature(value):
    """Wrap one integer in a ``tf.train.Feature`` holding a single-entry int64 list.

    Args:
        value: The integer to store (in this file: a class label).

    Returns:
        A ``tf.train.Feature`` whose ``int64_list`` contains exactly ``value``.
    """
    numbers = tf.train.Int64List(value=[value])
    return tf.train.Feature(int64_list=numbers)
def image_to_tfexample(image_data, label):
    """Build a ``tf.train.Example`` for one image/label pair.

    The example carries exactly two features:
      - ``"image_raw"``: ``image_data`` stored as a bytes feature.
      - ``"label"``: ``label`` stored as an int64 feature.

    Args:
        image_data: Serialized image bytes.
        label: Integer label for the image.

    Returns:
        A ``tf.train.Example`` ready to be serialized.
    """
    # NOTE(review): only the raw pixel bytes are stored, with no height, width
    # or dtype. Readers must know the image shape out of band to decode it.
    image_feature = _bytes_feature(image_data)
    label_feature = _int64_feature(label)
    features = tf.train.Features(
        feature={
            "image_raw": image_feature,
            "label": label_feature,
        }
    )
    return tf.train.Example(features=features)
def load_images_and_labels(folder_path):
    """Load grayscale PNG images grouped into per-class subdirectories.

    Every immediate subdirectory of ``folder_path`` is one class. Its name
    must be a single character. Every image inside it gets that character's
    code point, ``ord(name)``, as its label; e.g. a folder named ``"3"``
    yields label 51, not 3.

    Class folders and files are visited in sorted order. This makes the
    returned lists, and any seeded train/test split built from them,
    reproducible regardless of the filesystem's ``listdir`` order.

    Args:
        folder_path: Root directory containing one subdirectory per class.

    Returns:
        A tuple ``(images, labels)`` of equal-length lists. ``images`` holds
        the arrays returned by ``cv2.imread(path, cv2.IMREAD_GRAYSCALE)``;
        ``labels`` holds the matching integer labels.

    Raises:
        ValueError: If a class folder name is not exactly one character, or
            if a ``.png`` file cannot be decoded by OpenCV.
    """
    images = []
    labels = []
    # Sort so that the output order (and hence a seeded split) is deterministic.
    class_dirs = sorted(
        d for d in os.listdir(folder_path)
        if os.path.isdir(os.path.join(folder_path, d))
    )
    for class_name in class_dirs:
        # ord() only accepts a single character. Fail with a clear message
        # instead of an opaque TypeError.
        if len(class_name) != 1:
            raise ValueError(
                f"Class folder name must be a single character, got "
                f"{class_name!r} in {folder_path!r}"
            )
        # NOTE(review): labels are code points (e.g. '0' -> 48), not 0-based
        # indices. Confirm this matches what the training code expects.
        label = ord(class_name)
        class_path = os.path.join(folder_path, class_name)
        for file_name in sorted(os.listdir(class_path)):
            if not file_name.endswith(".png"):
                continue
            file_path = os.path.join(class_path, file_name)
            image = cv2.imread(file_path, cv2.IMREAD_GRAYSCALE)
            # cv2.imread reports failure by returning None rather than raising.
            # Appending None would only crash later in ``image.tobytes()``.
            if image is None:
                raise ValueError(f"Could not read image file: {file_path!r}")
            images.append(image)
            labels.append(label)
    return images, labels
def write_tfrecords(images, labels, output_file):
    """Serialize paired images and labels into one TFRecord file.

    Args:
        images: Iterable of image arrays. Each is written via ``.tobytes()``,
            so only the raw buffer is stored, without shape or dtype.
        labels: Iterable of integer labels aligned with ``images``. Pairing
            uses ``zip``, so extra items in the longer input are dropped.
        output_file: Destination path handed to ``tf.io.TFRecordWriter``.
    """
    with tf.io.TFRecordWriter(output_file) as record_writer:
        for img, lbl in zip(images, labels):
            serialized = image_to_tfexample(img.tobytes(), lbl).SerializeToString()
            record_writer.write(serialized)
def main(*, input_dir="processed", output_dir="./dataset", test_size=0.2,
         random_state=42):
    """Build train/test TFRecord files from a folder of labelled images.

    Called with no arguments, it behaves like the original script. It reads
    from ``processed`` and writes ``./dataset/train.tfrecords`` and
    ``./dataset/test.tfrecords`` with a 20% test split seeded with 42.

    Args:
        input_dir: Root folder passed to ``load_images_and_labels``.
        output_dir: Directory that receives ``train.tfrecords`` and
            ``test.tfrecords``; created if missing.
        test_size: Fraction of samples held out for the test file.
        random_state: Seed for the train/test shuffle.

    Raises:
        ValueError: If no images are found under ``input_dir``.
    """
    images, labels = load_images_and_labels(input_dir)
    # train_test_split cannot split an empty dataset; report the real cause.
    if not images:
        raise ValueError(f"No .png images found under {input_dir!r}")

    # Split into training and test sets
    images_train, images_test, labels_train, labels_test = train_test_split(
        images, labels, test_size=test_size, random_state=random_state
    )

    # Generate TFRecord files
    os.makedirs(output_dir, exist_ok=True)
    write_tfrecords(images_train, labels_train,
                    os.path.join(output_dir, "train.tfrecords"))
    write_tfrecords(images_test, labels_test,
                    os.path.join(output_dir, "test.tfrecords"))
    print("TFRecord files generated successfully!")
if __name__ == "__main__":
    # Run the dataset-generation pipeline only when executed as a script,
    # not when this module is imported.
    main()