diff --git a/.gitignore b/.gitignore index 1d4a153..91c1cfd 100644 --- a/.gitignore +++ b/.gitignore @@ -23,3 +23,5 @@ config/train_test_config.py *.json *.npy + +t2boost.py diff --git a/angiboost.py b/angiboost.py new file mode 100644 index 0000000..851c989 --- /dev/null +++ b/angiboost.py @@ -0,0 +1,73 @@ +#!/usr/bin/env python3 + +""" +Angio Boost module - train a model on single subject from scratch, then make prediction + +Editor: Marshall Xu +Last Edited: 04/10/2024 +""" + +import config.angiboost_config as angiboost_config +from utils import preprocess_procedure, make_prediction +from utils import TTA_Training +import os + +args = angiboost_config.args +# input images & labels +ds_path = args.ds_path # needed as input argument +ps_path = args.ps_path +seg_path = args.lb_path # needed as input argument (places to store the initial segmentation) +prep_mode = args.prep_mode # needed as input argument +outmo_path = args.outmo # needed as input argument +out_path = args.out_path # needed as input argument +pretrained = args.pretrained # needed as input argument + +if os.path.exists(seg_path) == False: + print(f"{seg_path} does not exist.") + os.mkdir(seg_path) + print(f"{seg_path} has been created!") + +if os.path.exists(out_path) == False: + print(f"{out_path} does not exist.") + os.mkdir(out_path) + print(f"{out_path} has been created!") + +# when the preprocess is skipped, +# directly take the raw data for prediction +if prep_mode == 4: + ps_path = ds_path + +if __name__ == "__main__": + print("Boosting session will start shortly..") + print("Parameters Info:\n*************************************************************\n") + print(f"Input image path: {ds_path}, Segmentation path: {seg_path}, Prep_mode: {prep_mode}\n") + print(f"Epoch number: {args.ep}, Learning rate: {args.lr} \n") + + # preprocess procedure + preprocess_procedure(ds_path, ps_path, prep_mode) + + # genereate the initial segmentation + make_prediction(args.mo, args.ic, args.oc, + args.fil, ps_path, seg_path, + args.thresh, args.cc, pretrained, + mip_flag=False) + + # initialize the training process + train_process = TTA_Training(args.loss_m, args.mo, + args.ic, args.oc, args.fil, + args.op, args.lr, + args.optim_gamma, args.ep, + args.batch_mul, + args.osz, args.aug_mode) + + # traning loop (this could be separate out ) + train_process.train(ps_path, seg_path, outmo_path) + + # make prediction + make_prediction(args.mo, args.ic, args.oc, + args.fil, ps_path, out_path, + args.thresh, args.cc, outmo_path, + mip_flag=True) + + print(f"Boosting session has been completed! Resultant segmentation has been saved to {out_path}.") + diff --git a/config/angiboost_config.py b/config/angiboost_config.py new file mode 100644 index 0000000..bb86b54 --- /dev/null +++ b/config/angiboost_config.py @@ -0,0 +1,67 @@ +""" +argparse configuration for angiboost.py (openrecon) + +Editor: Marshall Xu +Last Edited: 04/10/2024 +""" +#@TODO: remember change the optim_patience + +import argparse + +angiboost_parser = argparse.ArgumentParser(description="VesselBoost AngiBoost arguments") + +# Train.py +# input /output (train.py) +angiboost_parser.add_argument('--ds_path', default = "/ds_path/", help="input image path") +angiboost_parser.add_argument('--lb_path', default = "/lb_path/", help="initially generated label path") +angiboost_parser.add_argument('--ps_path', type=str, default = "/preprocessed_path/", help="path of the preprocessed data") +angiboost_parser.add_argument('--out_path', type=str, default = "/out_path/", help="path of the output segmentation") +angiboost_parser.add_argument('--pretrained', type=str, default = "/pretrained_model_path/", help="path of the prertrained model") + +# preprocessing mode +angiboost_parser.add_argument('--prep_mode', type=int, default=4, help="Preprocessing mode options. prep_mode=1 : bias field correction only | prep_mode=2 : denoising only | prep_mode=3 : bfc + denoising | prep_mode=4 : no preprocessing applied") + +# The following needs to be changed manually (for now) +angiboost_parser.add_argument('--outmo', default = "./saved_models/model", help="output model path, e.g. ./saved_models/xxxxx") + +# model configuration +# model name, available: [unet3d, aspp, atrous] +angiboost_parser.add_argument('--mo', type=str, default="unet3d", help=argparse.SUPPRESS) +# input channel for unet 3d +angiboost_parser.add_argument('--ic', type=int, default=1, help=argparse.SUPPRESS) +# output channel for unet 3d +angiboost_parser.add_argument('--oc', type=int, default=1, help=argparse.SUPPRESS) +# number of filters for each layer in unet 3d +angiboost_parser.add_argument('--fil', type=int, default=16, help=argparse.SUPPRESS) + +# Training configurations +angiboost_parser.add_argument('--lr', type=float, default=1e-3, help="learning rate, dtype: float, default=1e-3") +angiboost_parser.add_argument('--ep', type=int, default=1000, help="epoch number (times of iteration), dtype: int, default=16") + +# expected size after zooming +angiboost_parser.add_argument('--osz', type=tuple, default=(64,64,64), help=argparse.SUPPRESS) +# optimizer type, available: [sgd, adam] +angiboost_parser.add_argument('--op', type=str, default="adam", help=argparse.SUPPRESS) +# loss metric type, available: [bce, dice, tver] +angiboost_parser.add_argument('--loss_m', type=str, default="tver", help=argparse.SUPPRESS) + +# Optimizer tuning +# Decays the learning rate of each parameter group by this ratio, dtype: float +angiboost_parser.add_argument('--optim_gamma', type=float, default=0.95, help=argparse.SUPPRESS) +# Number of steps with no improvement after which learning rate will be reduced. For example, if patience = 2, then we will ignore the first 2 steps with no improvement, and will only decrease the LR after the 3rd step if the loss still hasn’t improved then. Default: 10. +# Discarded feature (06/20/2023) +# angiboost_parser.add_argument('--optim_patience', type=int, default=100, help=argparse.SUPPRESS) + +# Augmentation mode, available : [on, off, test, mode1] +angiboost_parser.add_argument('--aug_mode', type=str, default="mode1", help=argparse.SUPPRESS) + +# batch size multiplier +angiboost_parser.add_argument('--batch_mul', type=int, default=4, help=argparse.SUPPRESS) + +# postprocessing / threshold +# hard thresholding value +angiboost_parser.add_argument('--thresh', type=float, default=0.1, help="binary threshold for the probability map after prediction, default=0.1") +# connected components analysis threshold value (denoising) +angiboost_parser.add_argument('--cc', type=int, default=10, help="connected components analysis threshold value (denoising), default=10") + +args = angiboost_parser.parse_args() \ No newline at end of file diff --git a/utils/module_utils.py b/utils/module_utils.py index b96c851..96fc018 100644 --- a/utils/module_utils.py +++ b/utils/module_utils.py @@ -211,8 +211,8 @@ def one_img_process(self, img_name, load_model, thresh, connect_thresh, mip_flag # save the maximum intensity projection as jpg file if mip_flag == True: mip = np.max(postprocessed_output, axis=2) - # save_mip_path_post = os.path.join(self.output_path, img_name.split('.')[0], ".jpg") - save_mip_path_post = self.output_path + img_name.split('.')[0] + ".jpg" + save_mip_path_post = os.path.join(self.output_path, img_name.split('.')[0]) + ".jpg" + # save_mip_path_post = self.output_path + img_name.split('.')[0] + ".jpg" #rotate the mip 90 degrees, counterclockwise mip = np.rot90(mip, axes=(0, 1)) plt.imsave(save_mip_path_post, mip, cmap='gray')