Skip to content

Commit

Permalink
Merge pull request #221 from KMarshallX:update_20240429
Browse files Browse the repository at this point in the history
Update_20240429
  • Loading branch information
KMarshallX authored Apr 29, 2024
2 parents 46e1df3 + 932e9cb commit 8d18aa0
Show file tree
Hide file tree
Showing 6 changed files with 137 additions and 27 deletions.
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -17,3 +17,5 @@ readme_img/Snipaste_2023-03-15_09-11-34.jpg
/config/__pycache__
config/train_test_config.py

*.json
*.npy
14 changes: 14 additions & 0 deletions utils/eval_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,20 @@ def out_file_reader(file_path):
plt.grid(True)
plt.show()

def cv_helper(ps_path):

if any(entry.is_dir() for entry in os.scandir(ps_path) if entry.is_dir()):
raise ValueError("The image directory contains subdirs")

# make sure the image path and seg path contains equal number of files
raw_file_list = os.listdir(ps_path)

cv_dict = {}
for i, current_file in enumerate(raw_file_list):
cv_dict.__setitem__(current_file, raw_file_list[: i] + raw_file_list[i + 1 :])

return cv_dict

def mra_deskull(img_path, msk_path, mip_flag):
"""
Apply a mask on the target nifti image and generate an MIP image.
Expand Down
12 changes: 8 additions & 4 deletions utils/module_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,10 @@ def normaliser(self, x):
def sigmoid(self, z):
return 1/(1+np.exp(-z))

def nn_sigmoid(self, z):
sigmoid_fn = torch.nn.Sigmoid()
return sigmoid_fn(z)

def inference(self, test_patches, load_model, ori_size):
print("Prediction procedure starts!")
# Predict each 3D patch
Expand All @@ -125,18 +129,18 @@ def inference(self, test_patches, load_model, ori_size):

single_patch = test_patches[i,j,k, :,:,:]
single_patch_input = single_patch[None, :]
single_patch_input = torch.from_numpy(single_patch_input).type(torch.FloatTensor).unsqueeze(0) # type: ignore
single_patch_input = torch.from_numpy(single_patch_input).type(torch.FloatTensor).unsqueeze(0)

single_patch_prediction = load_model(single_patch_input)
single_patch_prediction = self.nn_sigmoid(load_model(single_patch_input))

single_patch_prediction_out = single_patch_prediction.detach().numpy()[0,0,:,:,:]

test_patches[i,j,k, :,:,:] = single_patch_prediction_out

test_output = unpatchify(test_patches, (ori_size[0], ori_size[1], ori_size[2]))
test_output_sigmoid = self.sigmoid(test_output)

print("Prediction procedure ends! Please wait for the post processing!")
return test_output_sigmoid
return test_output

def post_processing_pipeline(self, arr, percent, connect_threshold):
"""
Expand Down
48 changes: 42 additions & 6 deletions utils/single_data_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,10 @@
import nibabel as nib
import os

from .unet_utils import RandomCrop3D, standardiser
from .unet_utils import RandomCrop3D, standardiser, normaliser

class single_channel_loader:
def __init__(self, raw_img, seg_img, patch_size, step):
def __init__(self, raw_img, seg_img, patch_size, step, test_mode=False):
"""
:param raw_img: str, path of the raw file
:param seg_img: str, path of the label file
Expand All @@ -30,6 +30,8 @@ def __init__(self, raw_img, seg_img, patch_size, step):
self.patch_size = patch_size
self.step = step

self.test_mode = test_mode

def __repr__(self):
return f"Processing image {self.raw_img} and its segmentation {self.seg_img}\n"

Expand All @@ -45,12 +47,12 @@ def __iter__(self):
assert (raw_size[2] == seg_size[2]), "Input image and segmentation dimension not matched, 2"

for i in range(self.step):
cropper = RandomCrop3D(raw_size, self.patch_size)
cropper = RandomCrop3D(raw_size, self.patch_size, self.test_mode)
img_crop, seg_crop = cropper(self.raw_arr, self.seg_arr)

yield img_crop, seg_crop

def multi_channel_loader(ps_path, seg_path, patch_size, step):
def multi_channel_loader(ps_path, seg_path, patch_size, step, test_mode=False):
"""
Loads multiple images and their corresponding segmentation masks from the given paths.
Args:
Expand Down Expand Up @@ -81,6 +83,40 @@ def multi_channel_loader(ps_path, seg_path, patch_size, step):
break
assert (seg_img_name != None), f"There is no corresponding label to {raw_file_list[i]}!"
# a linked hashmap to store the provoked data loaders
loaders_dict.__setitem__(i, single_channel_loader(raw_img_name, seg_img_name, patch_size, step))
loaders_dict.__setitem__(i, single_channel_loader(raw_img_name, seg_img_name, patch_size, step, test_mode))

return loaders_dict

def cv_multi_channel_loader(ps_path, seg_path, train_list, patch_size, step, test_mode=False):
"""
Loads multiple images and their corresponding segmentation masks from the given paths.
Args:
ps_path (str): Path to the folder containing the processed images.
seg_path (str): Path to the folder containing the label images.
train_list (list): A list containing the indices of the images to be used for training.
patch_size (tuple): Size of the patches to be extracted from the images.
step (int): Step size for the sliding window approach to extract patches.
Returns:
dict: A dictionary containing the initialized single_channel_loaders for each image.
"""
# make sure the image path and seg path contains equal number of files
raw_file_list = os.listdir(ps_path)
seg_file_list = os.listdir(seg_path)
assert (len(raw_file_list) == len(seg_file_list)), "Number of images and correspinding segs not matched!"

# initialize single_channel_loaders for each image
# and store the initialized loaders in a linked hashmaps
loaders_dict = dict()
for i in range(len(train_list)):
# joined path to the current image file
raw_img_name = os.path.join(ps_path, train_list[i])
# find the corresponding seg file in the seg_folder
seg_img_name = None
for j in range(len(seg_file_list)):
if seg_file_list[j].find(train_list[i].split('.')[0]) != -1:
seg_img_name = os.path.join(seg_path, seg_file_list[j])
break
assert (seg_img_name != None), f"There is no corresponding label to {raw_file_list[i]}!"
loaders_dict.__setitem__(i, single_channel_loader(raw_img_name, seg_img_name, patch_size, step, test_mode))

return loaders_dict
return loaders_dict
38 changes: 33 additions & 5 deletions utils/train_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,8 @@
import torch
from tqdm import tqdm
from .unet_utils import *
from .single_data_loader import single_channel_loader, multi_channel_loader
from .eval_utils import cv_helper
from .single_data_loader import single_channel_loader, multi_channel_loader, cv_multi_channel_loader
from .module_utils import prediction_and_postprocess
from models import Unet, ASPPCNN, CustomSegmentationNetwork, MainArchitecture

Expand Down Expand Up @@ -94,7 +95,8 @@ def __init__(self, loss_name, model_name,
batch_mul,
patch_size, augmentation_mode,
pretrained_model = None,
thresh = None, connect_thresh = None):
thresh = None, connect_thresh = None,
test_mode = False):
# type of the loss metric
self.loss_name = loss_name
# type of the model
Expand All @@ -119,6 +121,8 @@ def __init__(self, loss_name, model_name,
self.pretrained_model = pretrained_model
# hardware config
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# test mode
self.test_mode = test_mode

def loss_init(self):
return loss_metric(self.loss_name)
Expand All @@ -131,7 +135,10 @@ def scheduler_init(self, optimizer):
return torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor = self.optim_gamma, patience = optim_patience)

def aug_init(self):
return aug_utils(self.aug_config[0], self.aug_config[1])
if not self.test_mode:
return aug_utils(self.aug_config[0], "on")
else: # test_mode = True
return aug_utils(self.aug_config[0], "off")

def pretrained_model_loader(self):
load_model = self.model_init()
Expand Down Expand Up @@ -196,15 +203,36 @@ def training_loop(self, data_loader, model, save_path):
def train(self, ps_path, seg_path, out_mo_path):
# initialize the data loader
step = int(self.epoch_num * self.batch_mul)
multi_image_loder = multi_channel_loader(ps_path, seg_path, self.aug_config[0], step)
multi_image_loder = multi_channel_loader(ps_path, seg_path, self.aug_config[0], step, self.test_mode)
# initialize the model
model = self.model_init()

print(f"\nIn this test, the batch size is {6 * self.batch_mul}\n")
# print(f"\nIn this test, the batch size is {6 * self.batch_mul}\n")

# training loop
self.training_loop(multi_image_loder, model, out_mo_path)

def cross_valid_train(self, ps_path, seg_path, model_path):
cv_dict = cv_helper(ps_path)
cnt = 0
print(f"Total {len(cv_dict)} will be generated!\n")
for key, value in cv_dict.items():
cnt += 1
print(f"Cross validation {cnt} will start shortly!\n Test image is {key}\n")
# initialize the data loader
step = int(self.epoch_num * self.batch_mul)
multi_image_loder = cv_multi_channel_loader(ps_path, seg_path, value, self.aug_config[0], step, self.test_mode)
# initialize the model
model = self.model_init()
# out model path
test_name = key.split('.')[0]
if os.path.exists(model_path) == False:
os.makedirs(model_path)
print(f"{model_path}doesn't exists! {model_path} has been created!")
out_mo_path = os.path.join(model_path, f"cv_{cnt}_{test_name}")
# training loop
self.training_loop(multi_image_loder, model, out_mo_path)

def test_time_adaptation(self, ps_path, px_path, out_path, out_mo_path, resource_opt):
# traverse each image
processed_data_list = os.listdir(ps_path)
Expand Down
50 changes: 38 additions & 12 deletions utils/unet_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,18 +201,36 @@ def __call__(self, input, segin):
segin = self.zooming(segin)

if self.mode == "on":
# print("Aug mode on, rotation/flip only")
input_batch = np.stack((input, self.rot(input, 1), self.rot(input, 2), self.rot(input, 3),
self.flip_hr(input, 1), self.flip_vt(input, 1)), axis=0)
segin_batch = np.stack((segin, self.rot(segin, 1), self.rot(segin, 2), self.rot(segin, 3),
self.flip_hr(segin, 1), self.flip_vt(segin, 1)), axis=0)
elif self.mode == "off":
elif self.mode == "repeat":
# print("Aug mode repeat, repeat the same patch 6 times")
input_batch = np.stack((input, input, input, input, input, input), axis=0)
segin_batch = np.stack((segin, segin, segin, segin, segin, segin), axis=0)
elif self.mode == "mode1":
# print("Aug mode 1, rotation & blurring")
input_batch = np.stack((input, self.rot(input, 1), self.rot(input, 2), self.rot(input, 3), self.filter(input, 2), self.filter(input, 3)), axis=0)
segin_batch = np.stack((segin, self.rot(segin, 1), self.rot(segin, 2), self.rot(segin, 3), segin, segin), axis=0)
elif self.mode == "test":

elif self.mode == "mode2":
# print("Aug mode 2, only one patch, with blurring effect")
input_batch = np.expand_dims(self.filter(input, 2), axis=0)
segin_batch = np.expand_dims(self.filter(segin, 2), axis=0)
elif self.mode == "mode3":
ind = np.random.randint(0, 2)
if ind == 0:
k = np.random.randint(1, 4)
print("Aug mode 3, rotate")
input_batch = np.expand_dims(self.rot(input, k), axis=0)
segin_batch = np.expand_dims(self.rot(segin, k), axis=0)
elif ind == 1:
print("Aug mode 3, blur")
input_batch = np.expand_dims(self.filter(input, 2), axis=0)
segin_batch = np.expand_dims(self.filter(segin, 2), axis=0)
elif self.mode == "off":
print("Aug mode off")
input_batch = np.expand_dims(input, axis=0)
segin_batch = np.expand_dims(segin, axis=0)

Expand Down Expand Up @@ -242,21 +260,29 @@ class RandomCrop3D():
"""
Resample the input image slab by randomly cropping a 3D volume, and reshape to a fixed size e.g.(64,64,64)
"""
def __init__(self, img_sz, exp_sz):
def __init__(self, img_sz, exp_sz, test_mode=False, test_crop_sz=(64,64,64)):
h, w, d = img_sz
# test 0925, constraint the higher bound of the crop size to be 128
crop_h = torch.randint(32, h, (1,)).item()
crop_w = torch.randint(32, w, (1,)).item()
crop_d = torch.randint(32, d, (1,)).item()
assert (h, w, d) > (crop_h, crop_w, crop_d)
self.test_mode = test_mode
if not test_mode:
crop_h = torch.randint(32, h, (1,)).item()
crop_w = torch.randint(32, w, (1,)).item()
crop_d = torch.randint(32, d, (1,)).item()
assert (h, w, d) > (crop_h, crop_w, crop_d)
self.crop_sz = tuple((crop_h, crop_w, crop_d))
else:
self.crop_sz = test_crop_sz
self.img_sz = tuple((h, w, d))
self.crop_sz = tuple((crop_h, crop_w, crop_d))
self.exp_sz = exp_sz

def __call__(self, img, lab):
slice_hwd = [self._get_slice(i, k) for i, k in zip(self.img_sz, self.crop_sz)]
return scind.zoom(self._crop(img, *slice_hwd),(self.exp_sz[0]/self.crop_sz[0], self.exp_sz[1]/self.crop_sz[1], self.exp_sz[2]/self.crop_sz[2]), order=0, mode='nearest'), scind.zoom(self._crop(lab, *slice_hwd),(self.exp_sz[0]/self.crop_sz[0], self.exp_sz[1]/self.crop_sz[1], self.exp_sz[2]/self.crop_sz[2]), order=0, mode='nearest')

if not self.test_mode:
print("TEST MODE: The patch are zoomed")
return scind.zoom(self._crop(img, *slice_hwd),(self.exp_sz[0]/self.crop_sz[0], self.exp_sz[1]/self.crop_sz[1], self.exp_sz[2]/self.crop_sz[2]), order=0, mode='nearest'), scind.zoom(self._crop(lab, *slice_hwd),(self.exp_sz[0]/self.crop_sz[0], self.exp_sz[1]/self.crop_sz[1], self.exp_sz[2]/self.crop_sz[2]), order=0, mode='nearest')
else:
print("TEST MODE: The patch hasn't been zoomed")
return self._crop(img, *slice_hwd), self._crop(lab, *slice_hwd)

@staticmethod
def _get_slice(sz, crop_sz):
try :
Expand Down

0 comments on commit 8d18aa0

Please sign in to comment.