-
Notifications
You must be signed in to change notification settings - Fork 4
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #222 from KMarshallX/openrecon
openrecon branch
- Loading branch information
Showing
4 changed files
with
144 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -23,3 +23,5 @@ config/train_test_config.py | |
|
||
*.json | ||
*.npy | ||
|
||
t2boost.py |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,73 @@ | ||
#!/usr/bin/env python3 | ||
|
||
""" | ||
Angio Boost module - train a model on single subject from scratch, then make prediction | ||
Editor: Marshall Xu | ||
Last Edited: 04/10/2024 | ||
""" | ||
|
||
import config.angiboost_config as angiboost_config | ||
from utils import preprocess_procedure, make_prediction | ||
from utils import TTA_Training | ||
import os | ||
|
||
args = angiboost_config.args | ||
# input images & labels | ||
ds_path = args.ds_path # needed as input argument | ||
ps_path = args.ps_path | ||
seg_path = args.lb_path # needed as input argument (places to store the initial segmentation) | ||
prep_mode = args.prep_mode # needed as input argument | ||
outmo_path = args.outmo # needed as input argument | ||
out_path = args.out_path # needed as input argument | ||
pretrained = args.pretrained # needed as input argument | ||
|
||
if os.path.exists(seg_path) == False: | ||
print(f"{seg_path} does not exist.") | ||
os.mkdir(seg_path) | ||
print(f"{seg_path} has been created!") | ||
|
||
if os.path.exists(out_path) == False: | ||
print(f"{out_path} does not exist.") | ||
os.mkdir(out_path) | ||
print(f"{out_path} has been created!") | ||
|
||
# when the preprocess is skipped, | ||
# directly take the raw data for prediction | ||
if prep_mode == 4: | ||
ps_path = ds_path | ||
|
||
if __name__ == "__main__": | ||
print("Boosting session will start shortly..") | ||
print("Parameters Info:\n*************************************************************\n") | ||
print(f"Input image path: {ds_path}, Segmentation path: {seg_path}, Prep_mode: {prep_mode}\n") | ||
print(f"Epoch number: {args.ep}, Learning rate: {args.lr} \n") | ||
|
||
# preprocess procedure | ||
preprocess_procedure(ds_path, ps_path, prep_mode) | ||
|
||
# genereate the initial segmentation | ||
make_prediction(args.mo, args.ic, args.oc, | ||
args.fil, ps_path, seg_path, | ||
args.thresh, args.cc, pretrained, | ||
mip_flag=False) | ||
|
||
# initialize the training process | ||
train_process = TTA_Training(args.loss_m, args.mo, | ||
args.ic, args.oc, args.fil, | ||
args.op, args.lr, | ||
args.optim_gamma, args.ep, | ||
args.batch_mul, | ||
args.osz, args.aug_mode) | ||
|
||
# traning loop (this could be separate out ) | ||
train_process.train(ps_path, seg_path, outmo_path) | ||
|
||
# make prediction | ||
make_prediction(args.mo, args.ic, args.oc, | ||
args.fil, ps_path, out_path, | ||
args.thresh, args.cc, outmo_path, | ||
mip_flag=True) | ||
|
||
print(f"Boosting session has been completed! Resultant segmentation has been saved to {out_path}.") | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,67 @@ | ||
""" | ||
argparse configuration for angiboost.py (openrecon) | ||
Editor: Marshall Xu | ||
Last Edited: 04/10/2024 | ||
""" | ||
#@TODO: remember change the optim_patience | ||
|
||
import argparse | ||
|
||
angiboost_parser = argparse.ArgumentParser(description="VesselBoost AngiBoost arguments") | ||
|
||
# Train.py | ||
# input /output (train.py) | ||
angiboost_parser.add_argument('--ds_path', default = "/ds_path/", help="input image path") | ||
angiboost_parser.add_argument('--lb_path', default = "/lb_path/", help="initially generated label path") | ||
angiboost_parser.add_argument('--ps_path', type=str, default = "/preprocessed_path/", help="path of the preprocessed data") | ||
angiboost_parser.add_argument('--out_path', type=str, default = "/out_path/", help="path of the output segmentation") | ||
angiboost_parser.add_argument('--pretrained', type=str, default = "/pretrained_model_path/", help="path of the prertrained model") | ||
|
||
# preprocessing mode | ||
angiboost_parser.add_argument('--prep_mode', type=int, default=4, help="Preprocessing mode options. prep_mode=1 : bias field correction only | prep_mode=2 : denoising only | prep_mode=3 : bfc + denoising | prep_mode=4 : no preprocessing applied") | ||
|
||
# The following needs to be changed manually (for now) | ||
angiboost_parser.add_argument('--outmo', default = "./saved_models/model", help="output model path, e.g. ./saved_models/xxxxx") | ||
|
||
# model configuration | ||
# model name, available: [unet3d, aspp, atrous] | ||
angiboost_parser.add_argument('--mo', type=str, default="unet3d", help=argparse.SUPPRESS) | ||
# input channel for unet 3d | ||
angiboost_parser.add_argument('--ic', type=int, default=1, help=argparse.SUPPRESS) | ||
# output channel for unet 3d | ||
angiboost_parser.add_argument('--oc', type=int, default=1, help=argparse.SUPPRESS) | ||
# number of filters for each layer in unet 3d | ||
angiboost_parser.add_argument('--fil', type=int, default=16, help=argparse.SUPPRESS) | ||
|
||
# Training configurations | ||
angiboost_parser.add_argument('--lr', type=float, default=1e-3, help="learning rate, dtype: float, default=1e-3") | ||
angiboost_parser.add_argument('--ep', type=int, default=1000, help="epoch number (times of iteration), dtype: int, default=16") | ||
|
||
# expected size after zooming | ||
angiboost_parser.add_argument('--osz', type=tuple, default=(64,64,64), help=argparse.SUPPRESS) | ||
# optimizer type, available: [sgd, adam] | ||
angiboost_parser.add_argument('--op', type=str, default="adam", help=argparse.SUPPRESS) | ||
# loss metric type, available: [bce, dice, tver] | ||
angiboost_parser.add_argument('--loss_m', type=str, default="tver", help=argparse.SUPPRESS) | ||
|
||
# Optimizer tuning | ||
# Decays the learning rate of each parameter group by this ratio, dtype: float | ||
angiboost_parser.add_argument('--optim_gamma', type=float, default=0.95, help=argparse.SUPPRESS) | ||
# Number of steps with no improvement after which learning rate will be reduced. For example, if patience = 2, then we will ignore the first 2 steps with no improvement, and will only decrease the LR after the 3rd step if the loss still hasn’t improved then. Default: 10. | ||
# Discarded feature (06/20/2023) | ||
# angiboost_parser.add_argument('--optim_patience', type=int, default=100, help=argparse.SUPPRESS) | ||
|
||
# Augmentation mode, available : [on, off, test, mode1] | ||
angiboost_parser.add_argument('--aug_mode', type=str, default="mode1", help=argparse.SUPPRESS) | ||
|
||
# batch size multiplier | ||
angiboost_parser.add_argument('--batch_mul', type=int, default=4, help=argparse.SUPPRESS) | ||
|
||
# postprocessing / threshold | ||
# hard thresholding value | ||
angiboost_parser.add_argument('--thresh', type=float, default=0.1, help="binary threshold for the probability map after prediction, default=0.1") | ||
# connected components analysis threshold value (denoising) | ||
angiboost_parser.add_argument('--cc', type=int, default=10, help="connected components analysis threshold value (denoising), default=10") | ||
|
||
args = angiboost_parser.parse_args() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters