Skip to content

Commit

Permalink
reconfigured files
Browse files Browse the repository at this point in the history
  • Loading branch information
KMarshallX committed Feb 19, 2024
1 parent cadc4d2 commit 2b9b451
Show file tree
Hide file tree
Showing 11 changed files with 32 additions and 27 deletions.
7 changes: 4 additions & 3 deletions boost.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,10 @@
Last Edited: 22/10/2023
"""

import config.boost_config as boost_config
from utils.module_utils import preprocess_procedure, make_prediction
from utils.train_utils import *
from config import boost_config
from utils import preprocess_procedure, make_prediction
from utils import TTA_Training
import os

args = boost_config.args
# input images & labels
Expand Down
4 changes: 4 additions & 0 deletions config/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
from . import adapt_config
from . import boost_config
from . import pred_config
from . import train_config
4 changes: 4 additions & 0 deletions models/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
from .aspp import CustomSegmentationNetwork
from .asppcnn import ASPPCNN
from .ra_unet import MainArchitecture
from .unet_3d import Unet
4 changes: 2 additions & 2 deletions prediction.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@
"""

import os
import config.pred_config as pred_config
from utils.module_utils import preprocess_procedure, make_prediction
from config import pred_config
from utils import preprocess_procedure, make_prediction

args = pred_config.pred_parser.parse_args()

Expand Down
7 changes: 3 additions & 4 deletions test_time_adaptation.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,9 @@
"""

import os
import config.adapt_config as adapt_config
from utils.module_utils import preprocess_procedure
from utils.unet_utils import *
from utils.train_utils import TTA_Training
from config import adapt_config
from utils import preprocess_procedure
from utils import TTA_Training


args = adapt_config.args
Expand Down
7 changes: 3 additions & 4 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,9 @@
Last Edited: 18/10/2023
"""

import config.train_config as train_config
from utils.module_utils import preprocess_procedure
from utils.train_utils import *
from utils.single_data_loader import multi_channel_loader
from config import train_config
from utils import preprocess_procedure
from utils import TTA_Training

args = train_config.args
# input images & labels
Expand Down
5 changes: 5 additions & 0 deletions utils/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
from .eval_utils import eval_scores, out_file_reader, mra_deskull
from .module_utils import make_prediction, preprocess_procedure
from .single_data_loader import multi_channel_loader
from .train_utils import TTA_Training
from .unet_utils import *
4 changes: 2 additions & 2 deletions utils/module_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,8 @@
from patchify import patchify, unpatchify
import cc3d

from utils.unet_utils import *
from models.unet_3d import Unet
from .unet_utils import *
from models import Unet

class preprocess:
"""
Expand Down
2 changes: 1 addition & 1 deletion utils/single_data_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import nibabel as nib
import os

from .unet_utils import RandomCrop3D, standardiser, normaliser
from .unet_utils import RandomCrop3D, standardiser

class single_channel_loader:
def __init__(self, raw_img, seg_img, patch_size, step):
Expand Down
11 changes: 4 additions & 7 deletions utils/train_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,10 @@
import shutil
import torch
from tqdm import tqdm
from utils.unet_utils import *
from utils.single_data_loader import single_channel_loader, multi_channel_loader
from utils.module_utils import prediction_and_postprocess
from models.unet_3d import Unet
from models.asppcnn import ASPPCNN
from models.aspp import CustomSegmentationNetwork
from models.ra_unet import MainArchitecture
from .unet_utils import *
from .single_data_loader import single_channel_loader, multi_channel_loader
from .module_utils import prediction_and_postprocess
from models import Unet, ASPPCNN, CustomSegmentationNetwork, MainArchitecture

def model_chosen(model_name, in_chan, out_chan, filter_num):
if model_name == "unet3d":
Expand Down
4 changes: 0 additions & 4 deletions utils/unet_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,6 @@
import torch.nn.functional as F
import numpy as np
import scipy.ndimage as scind
import nibabel as nib
from tqdm import tqdm
from patchify import patchify, unpatchify
import os

class DiceLoss(nn.Module):
def __init__(self, smooth = 1e-4):
Expand Down

0 comments on commit 2b9b451

Please sign in to comment.