-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathU-Net_model_to_be_investigated
100 lines (76 loc) · 3.94 KB
/
U-Net_model_to_be_investigated
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
import os
import matplotlib.pyplot as plt
import tensorflow as tf
from keras_preprocessing.image import ImageDataGenerator
from tensorflow import keras
# ---------------------------------------------------------------------------
# Configuration constants
# ---------------------------------------------------------------------------

# Seed handed to both the image and the mask directory iterators.
SEED = 909

# Batch sizes for the training and the test generators.
BATCH_SIZE_TRAIN = BATCH_SIZE_TEST = 8

# Spatial size every image/mask is resized to, as (height, width).
IMAGE_HEIGHT = IMAGE_WIDTH = 512
IMG_SIZE = IMAGE_HEIGHT, IMAGE_WIDTH

# Dataset layout on disk: <data_dir>/{training,test}/{img,mask}
data_dir = '/home/data'

data_dir_train = os.path.join(data_dir, 'training')
data_dir_train_image, data_dir_train_mask = (
    os.path.join(data_dir_train, sub_dir) for sub_dir in ('img', 'mask')
)

data_dir_test = os.path.join(data_dir, 'test')
data_dir_test_image, data_dir_test_mask = (
    os.path.join(data_dir_test, sub_dir) for sub_dir in ('img', 'mask')
)

# Sample counts; the training script divides these by the batch size to get
# the number of steps per epoch.
NUM_TRAIN = 1413
NUM_TEST = 210

# Number of passes over the training data.
NUM_OF_EPOCHS = 10
def create_segmentation_generator_train(img_path, mask_path, BATCH_SIZE):
    """Build an iterator of (image batch, mask batch) pairs for training.

    Two ``ImageDataGenerator`` instances read grayscale images from
    ``img_path`` and ``mask_path`` respectively, resize them to ``IMG_SIZE``
    and scale pixel values by 1/255. Both directory iterators receive the
    same ``SEED`` so that images and masks are drawn in the same order.

    Args:
        img_path: Directory passed to ``flow_from_directory`` for the images.
        mask_path: Directory passed to ``flow_from_directory`` for the masks.
        BATCH_SIZE: Number of samples per yielded batch.

    Returns:
        A ``zip`` iterator yielding ``(image_batch, mask_batch)`` tuples.
    """
    data_gen_args = dict(rescale=1./255)
    img_datagen = ImageDataGenerator(**data_gen_args)
    # Must be keyword-unpacked (**). A single * unpacks the dict *keys*, which
    # would pass the string 'rescale' as the first positional argument and
    # leave the masks un-rescaled.
    mask_datagen = ImageDataGenerator(**data_gen_args)

    # Identical settings for both streams; class_mode=None means no labels
    # are produced, only the image arrays.
    flow_args = dict(
        target_size=IMG_SIZE,
        class_mode=None,
        color_mode='grayscale',
        batch_size=BATCH_SIZE,
        seed=SEED,
    )
    img_generator = img_datagen.flow_from_directory(img_path, **flow_args)
    mask_generator = mask_datagen.flow_from_directory(mask_path, **flow_args)
    return zip(img_generator, mask_generator)
def create_segmentation_generator_test(img_path, mask_path, BATCH_SIZE):
    """Build an iterator of (image batch, mask batch) pairs for evaluation.

    Two ``ImageDataGenerator`` instances read grayscale images from
    ``img_path`` and ``mask_path`` respectively, resize them to ``IMG_SIZE``
    and scale pixel values by 1/255. Both directory iterators receive the
    same ``SEED`` so that images and masks are drawn in the same order.

    Args:
        img_path: Directory passed to ``flow_from_directory`` for the images.
        mask_path: Directory passed to ``flow_from_directory`` for the masks.
        BATCH_SIZE: Number of samples per yielded batch.

    Returns:
        A ``zip`` iterator yielding ``(image_batch, mask_batch)`` tuples.
    """
    data_gen_args = dict(rescale=1./255)
    img_datagen = ImageDataGenerator(**data_gen_args)
    # Must be keyword-unpacked (**). A single * unpacks the dict *keys*, which
    # would pass the string 'rescale' as the first positional argument and
    # leave the masks un-rescaled.
    mask_datagen = ImageDataGenerator(**data_gen_args)

    # Identical settings for both streams; class_mode=None means no labels
    # are produced, only the image arrays.
    flow_args = dict(
        target_size=IMG_SIZE,
        class_mode=None,
        color_mode='grayscale',
        batch_size=BATCH_SIZE,
        seed=SEED,
    )
    img_generator = img_datagen.flow_from_directory(img_path, **flow_args)
    mask_generator = mask_datagen.flow_from_directory(mask_path, **flow_args)
    return zip(img_generator, mask_generator)
def display(display_list):
    """Show the given arrays side by side in a single row of subplots.

    Panels are titled, in order, 'Input Image', 'True Mask' and
    'Predicted Mask', so ``display_list`` may hold at most three entries.

    Args:
        display_list: Sequence of image arrays accepted by
            ``tf.keras.preprocessing.image.array_to_img``.
    """
    plt.figure(figsize=(15, 15))
    panel_titles = ['Input Image', 'True Mask', 'Predicted Mask']
    n_panels = len(display_list)
    for position, array in enumerate(display_list, start=1):
        # subplot positions are 1-based, the title list is 0-based.
        plt.subplot(1, n_panels, position)
        plt.title(panel_titles[position - 1])
        pil_image = tf.keras.preprocessing.image.array_to_img(array)
        plt.imshow(pil_image, cmap='gray')
    plt.show()
def unet(n_levels, initial_features=32, n_blocks=2, kernel_size=3, pooling_size=2, in_channels=1, out_channels=1):
    """Build a U-Net style Keras model.

    The encoder applies ``n_blocks`` convolutions per level, doubling the
    number of filters at each level and max-pooling between levels. The
    decoder mirrors it with transposed convolutions, concatenating the saved
    encoder output of the matching level before its convolutions.

    Args:
        n_levels: Number of resolution levels, including the bottleneck.
        initial_features: Number of convolution filters at the first level.
        n_blocks: How many convolutions are applied at each level.
        kernel_size: Kernel size of the (transposed) convolutions.
        pooling_size: Max-pool size, also used as the up-sampling stride.
        in_channels: Number of channels of the input image.
        out_channels: Number of channels of the output map.

    Returns:
        An uncompiled ``keras.Model`` named
        ``UNET-L{n_levels}-F{initial_features}`` whose input shape is
        ``(IMAGE_HEIGHT, IMAGE_WIDTH, in_channels)``.
    """
    inputs = keras.layers.Input(shape=(IMAGE_HEIGHT, IMAGE_WIDTH, in_channels))
    x = inputs

    convpars = dict(kernel_size=kernel_size, activation='relu', padding='same')

    # downstream (encoder)
    skips = {}
    for level in range(n_levels):
        for _ in range(n_blocks):
            x = keras.layers.Conv2D(initial_features * 2 ** level, **convpars)(x)
        # The deepest level is the bottleneck: no skip connection, no pooling.
        if level < n_levels - 1:
            skips[level] = x
            x = keras.layers.MaxPool2D(pooling_size)(x)

    # upstream (decoder)
    for level in reversed(range(n_levels - 1)):
        x = keras.layers.Conv2DTranspose(initial_features * 2 ** level, strides=pooling_size, **convpars)(x)
        x = keras.layers.Concatenate()([x, skips[level]])
        for _ in range(n_blocks):
            x = keras.layers.Conv2D(initial_features * 2 ** level, **convpars)(x)

    # output
    # Sigmoid for a single-channel output, softmax across channels otherwise.
    # Previously this value was computed but the layer hard-coded 'sigmoid'.
    activation = 'sigmoid' if out_channels == 1 else 'softmax'
    x = keras.layers.Conv2D(out_channels, kernel_size=1, activation=activation, padding='same')(x)

    return keras.Model(inputs=[inputs], outputs=[x], name=f'UNET-L{n_levels}-F{initial_features}')
# Number of batches that make up one pass over each data set.
EPOCH_STEP_TRAIN = NUM_TRAIN // BATCH_SIZE_TRAIN
# Use the test batch size here (was BATCH_SIZE_TRAIN; both are currently 8).
EPOCH_STEP_TEST = NUM_TEST // BATCH_SIZE_TEST

# The generators were referenced below without ever being created, which
# raised NameError; build them from the factory functions defined above.
train_generator = create_segmentation_generator_train(
    data_dir_train_image, data_dir_train_mask, BATCH_SIZE_TRAIN
)
test_generator = create_segmentation_generator_test(
    data_dir_test_image, data_dir_test_mask, BATCH_SIZE_TEST
)

# 4-level U-Net with the default single-channel input and output.
model = unet(4)
model.compile(
    optimizer=tf.keras.optimizers.Adam(clipvalue=0.5),
    loss='binary_crossentropy',
    metrics=['accuracy'],
)

# NOTE(review): fit_generator is deprecated in newer TensorFlow releases in
# favour of Model.fit, which accepts generators directly - confirm the
# installed version before switching.
model.fit_generator(
    generator=train_generator,
    steps_per_epoch=EPOCH_STEP_TRAIN,
    validation_data=test_generator,
    validation_steps=EPOCH_STEP_TEST,
    epochs=NUM_OF_EPOCHS,
)