Skip to content

Commit

Permalink
Merge pull request #222 from KMarshallX/openrecon
Browse files Browse the repository at this point in the history
openrecon branch
  • Loading branch information
KMarshallX authored Oct 4, 2024
2 parents dac79bd + db18bda commit ca62c4e
Show file tree
Hide file tree
Showing 4 changed files with 144 additions and 2 deletions.
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -23,3 +23,5 @@ config/train_test_config.py

*.json
*.npy

t2boost.py
73 changes: 73 additions & 0 deletions angiboost.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
#!/usr/bin/env python3

"""
Angio Boost module - train a model on single subject from scratch, then make prediction
Editor: Marshall Xu
Last Edited: 04/10/2024
"""

import config.angiboost_config as angiboost_config
from utils import preprocess_procedure, make_prediction
from utils import TTA_Training
import os

args = angiboost_config.args
# input images & labels
ds_path = args.ds_path # needed as input argument
ps_path = args.ps_path
seg_path = args.lb_path # needed as input argument (places to store the initial segmentation)
prep_mode = args.prep_mode # needed as input argument
outmo_path = args.outmo # needed as input argument
out_path = args.out_path # needed as input argument
pretrained = args.pretrained # needed as input argument

if os.path.exists(seg_path) == False:
print(f"{seg_path} does not exist.")
os.mkdir(seg_path)
print(f"{seg_path} has been created!")

if os.path.exists(out_path) == False:
print(f"{out_path} does not exist.")
os.mkdir(out_path)
print(f"{out_path} has been created!")

# when the preprocess is skipped,
# directly take the raw data for prediction
if prep_mode == 4:
ps_path = ds_path

if __name__ == "__main__":
print("Boosting session will start shortly..")
print("Parameters Info:\n*************************************************************\n")
print(f"Input image path: {ds_path}, Segmentation path: {seg_path}, Prep_mode: {prep_mode}\n")
print(f"Epoch number: {args.ep}, Learning rate: {args.lr} \n")

# preprocess procedure
preprocess_procedure(ds_path, ps_path, prep_mode)

# genereate the initial segmentation
make_prediction(args.mo, args.ic, args.oc,
args.fil, ps_path, seg_path,
args.thresh, args.cc, pretrained,
mip_flag=False)

# initialize the training process
train_process = TTA_Training(args.loss_m, args.mo,
args.ic, args.oc, args.fil,
args.op, args.lr,
args.optim_gamma, args.ep,
args.batch_mul,
args.osz, args.aug_mode)

# traning loop (this could be separate out )
train_process.train(ps_path, seg_path, outmo_path)

# make prediction
make_prediction(args.mo, args.ic, args.oc,
args.fil, ps_path, out_path,
args.thresh, args.cc, outmo_path,
mip_flag=True)

print(f"Boosting session has been completed! Resultant segmentation has been saved to {out_path}.")

67 changes: 67 additions & 0 deletions config/angiboost_config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
"""
argparse configuration for angiboost.py (openrecon)
Editor: Marshall Xu
Last Edited: 04/10/2024
"""
#@TODO: remember change the optim_patience

import argparse

angiboost_parser = argparse.ArgumentParser(description="VesselBoost AngiBoost arguments")

# Train.py
# input /output (train.py)
angiboost_parser.add_argument('--ds_path', default = "/ds_path/", help="input image path")
angiboost_parser.add_argument('--lb_path', default = "/lb_path/", help="initially generated label path")
angiboost_parser.add_argument('--ps_path', type=str, default = "/preprocessed_path/", help="path of the preprocessed data")
angiboost_parser.add_argument('--out_path', type=str, default = "/out_path/", help="path of the output segmentation")
angiboost_parser.add_argument('--pretrained', type=str, default = "/pretrained_model_path/", help="path of the prertrained model")

# preprocessing mode
angiboost_parser.add_argument('--prep_mode', type=int, default=4, help="Preprocessing mode options. prep_mode=1 : bias field correction only | prep_mode=2 : denoising only | prep_mode=3 : bfc + denoising | prep_mode=4 : no preprocessing applied")

# The following needs to be changed manually (for now)
angiboost_parser.add_argument('--outmo', default = "./saved_models/model", help="output model path, e.g. ./saved_models/xxxxx")

# model configuration
# model name, available: [unet3d, aspp, atrous]
angiboost_parser.add_argument('--mo', type=str, default="unet3d", help=argparse.SUPPRESS)
# input channel for unet 3d
angiboost_parser.add_argument('--ic', type=int, default=1, help=argparse.SUPPRESS)
# output channel for unet 3d
angiboost_parser.add_argument('--oc', type=int, default=1, help=argparse.SUPPRESS)
# number of filters for each layer in unet 3d
angiboost_parser.add_argument('--fil', type=int, default=16, help=argparse.SUPPRESS)

# Training configurations
angiboost_parser.add_argument('--lr', type=float, default=1e-3, help="learning rate, dtype: float, default=1e-3")
angiboost_parser.add_argument('--ep', type=int, default=1000, help="epoch number (times of iteration), dtype: int, default=16")

# expected size after zooming
angiboost_parser.add_argument('--osz', type=tuple, default=(64,64,64), help=argparse.SUPPRESS)
# optimizer type, available: [sgd, adam]
angiboost_parser.add_argument('--op', type=str, default="adam", help=argparse.SUPPRESS)
# loss metric type, available: [bce, dice, tver]
angiboost_parser.add_argument('--loss_m', type=str, default="tver", help=argparse.SUPPRESS)

# Optimizer tuning
# Decays the learning rate of each parameter group by this ratio, dtype: float
angiboost_parser.add_argument('--optim_gamma', type=float, default=0.95, help=argparse.SUPPRESS)
# Number of steps with no improvement after which learning rate will be reduced. For example, if patience = 2, then we will ignore the first 2 steps with no improvement, and will only decrease the LR after the 3rd step if the loss still hasn’t improved then. Default: 10.
# Discarded feature (06/20/2023)
# angiboost_parser.add_argument('--optim_patience', type=int, default=100, help=argparse.SUPPRESS)

# Augmentation mode, available : [on, off, test, mode1]
angiboost_parser.add_argument('--aug_mode', type=str, default="mode1", help=argparse.SUPPRESS)

# batch size multiplier
angiboost_parser.add_argument('--batch_mul', type=int, default=4, help=argparse.SUPPRESS)

# postprocessing / threshold
# hard thresholding value
angiboost_parser.add_argument('--thresh', type=float, default=0.1, help="binary threshold for the probability map after prediction, default=0.1")
# connected components analysis threshold value (denoising)
angiboost_parser.add_argument('--cc', type=int, default=10, help="connected components analysis threshold value (denoising), default=10")

args = angiboost_parser.parse_args()
4 changes: 2 additions & 2 deletions utils/module_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,8 +211,8 @@ def one_img_process(self, img_name, load_model, thresh, connect_thresh, mip_flag
# save the maximum intensity projection as jpg file
if mip_flag == True:
mip = np.max(postprocessed_output, axis=2)
# save_mip_path_post = os.path.join(self.output_path, img_name.split('.')[0], ".jpg")
save_mip_path_post = self.output_path + img_name.split('.')[0] + ".jpg"
save_mip_path_post = os.path.join(self.output_path, img_name.split('.')[0]) + ".jpg"
# save_mip_path_post = self.output_path + img_name.split('.')[0] + ".jpg"
#rotate the mip 90 degrees, counterclockwise
mip = np.rot90(mip, axes=(0, 1))
plt.imsave(save_mip_path_post, mip, cmap='gray')
Expand Down

0 comments on commit ca62c4e

Please sign in to comment.