Skip to content

Commit

Permalink
Fix tests for L2T and the GV phase after adding a validation set
Browse files Browse the repository at this point in the history
  • Loading branch information
EmmaRenauld committed May 14, 2024
1 parent 3b13599 commit 702260b
Show file tree
Hide file tree
Showing 11 changed files with 69 additions and 74 deletions.
35 changes: 35 additions & 0 deletions dwi_ml/training/batch_loaders.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,9 @@ def __init__(self, dataset: MultiSubjectDataset, model: MainModelAbstract,
# Find idx of streamline group
self.streamline_group_idx = self.dataset.streamline_groups.index(
self.streamline_group_name)
self.data_contains_connectivity = \
self.dataset.streamlines_contain_connectivity[
self.streamline_group_idx]

# Set random numbers
self.rng = rng
Expand Down Expand Up @@ -314,6 +317,38 @@ def load_batch_streamlines(

return batch_streamlines, final_s_ids_per_subj

def load_batch_connectivity_matrices(
        self, streamline_ids_per_subj: Dict[int, slice]):
    """
    Load the connectivity matrix (and its associated metadata) for each
    subject present in the batch.

    Parameters
    ----------
    streamline_ids_per_subj: Dict[int, slice]
        Mapping of subject id to the slice of streamline ids belonging to
        that subject in the batch. Only the keys (subject ids) are used
        here; matrices are loaded whole, per subject.

    Returns
    -------
    Tuple of four lists, one entry per subject (in batch order):
    (matrices, volume_sizes, connectivity_nb_blocs, connectivity_labels).

    Raises
    ------
    ValueError
        If the dataset does not contain connectivity matrices.
    """
    if not self.data_contains_connectivity:
        raise ValueError("No connectivity matrix in this dataset.")

    matrices = []
    volume_sizes = []
    connectivity_nb_blocs = []
    connectivity_labels = []
    for subj in streamline_ids_per_subj:
        # No cache for the sft data: access it directly. When used
        # through a DataLoader with multiprocessing, each worker process
        # opens its own handle.
        subj_data = self.context_subset.subjs_data_list \
            .get_subj_with_handle(subj)
        sft_data = subj_data.sft_data_list[self.streamline_group_idx]

        # NOTE(review): the whole matrix is loaded here; loading only the
        # required indices might be possible — confirm with the hdf5 layout.
        mat, size, nb_blocs, labels = \
            sft_data.get_connectivity_matrix_and_info()
        matrices.append(mat)
        volume_sizes.append(size)
        connectivity_nb_blocs.append(nb_blocs)
        connectivity_labels.append(labels)

    return (matrices, volume_sizes,
            connectivity_nb_blocs, connectivity_labels)


class DWIMLBatchLoaderOneInput(DWIMLAbstractBatchLoader):
"""
Expand Down
2 changes: 1 addition & 1 deletion dwi_ml/training/projects/learn2track_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from dwi_ml.models.projects.learn2track_model import Learn2TrackModel
from dwi_ml.tracking.io_utils import prepare_tracking_mask
from dwi_ml.tracking.propagation import propagate_multiple_lines
from dwi_ml.training.with_generation.trainer import \
from dwi_ml.training.trainers_withGV import \
DWIMLTrainerForTrackingOneInput

logger = logging.getLogger('trainer_logger')
Expand Down
2 changes: 1 addition & 1 deletion dwi_ml/training/projects/transformer_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from dwi_ml.tracking.io_utils import prepare_tracking_mask
from dwi_ml.tracking.propagation import propagate_multiple_lines

from dwi_ml.training.with_generation.trainer import \
from dwi_ml.training.trainers_withGV import \
DWIMLTrainerForTrackingOneInput


Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# -*- coding: utf-8 -*-
"""
Adds a tracking step to verify the generation process. Metrics on the
streamlines are:
Adds a generation-validation phase: a tracking step. Metrics on the streamlines
are:
- Very good / acceptable / very far IS threshold:
Percentage of streamlines ending inside a radius of 15 / 25 / 40 voxels of
Expand Down Expand Up @@ -47,10 +47,9 @@
from dwi_ml.models.main_models import ModelWithDirectionGetter
from dwi_ml.tracking.propagation import propagate_multiple_lines
from dwi_ml.tracking.io_utils import prepare_tracking_mask
from dwi_ml.training.batch_loaders import DWIMLBatchLoaderOneInput
from dwi_ml.training.trainers import DWIMLTrainerOneInput
from dwi_ml.training.utils.monitoring import BatchHistoryMonitor
from dwi_ml.training.with_generation.batch_loader import \
DWIMLBatchLoaderWithConnectivity

logger = logging.getLogger('train_logger')

Expand All @@ -65,7 +64,7 @@

class DWIMLTrainerForTrackingOneInput(DWIMLTrainerOneInput):
model: ModelWithDirectionGetter
batch_loader: DWIMLBatchLoaderWithConnectivity
batch_loader: DWIMLBatchLoaderOneInput

def __init__(self, add_a_tracking_validation_phase: bool = False,
tracking_phase_frequency: int = 1,
Expand Down Expand Up @@ -105,6 +104,12 @@ def __init__(self, add_a_tracking_validation_phase: bool = False,

self.compute_connectivity = self.batch_loader.data_contains_connectivity

# -------- Checks
if add_a_tracking_validation_phase and \
tracking_phase_mask_group is None:
raise NotImplementedError("Not ready to run without a tracking "
"mask.")

# -------- Monitors
# At training time: only the one metric used for training.
# At validation time: A lot of exploratory metrics monitors.
Expand Down Expand Up @@ -177,8 +182,7 @@ def validate_one_batch(self, data, epoch):
(gen_n, mean_final_dist, mean_clipped_final_dist,
percent_IS_very_good, percent_IS_acceptable,
percent_IS_very_far, diverging_pnt, connectivity) = \
self.validation_generation_one_batch(
data, compute_all_scores=True)
self.gv_phase_one_batch(data, compute_all_scores=True)

self.tracking_very_good_IS_monitor.update(
percent_IS_very_good, weight=gen_n)
Expand All @@ -194,8 +198,9 @@ def validate_one_batch(self, data, epoch):
self.tracking_valid_diverg_monitor.update(
diverging_pnt, weight=gen_n)

self.tracking_connectivity_score_monitor.update(
connectivity, weight=gen_n)
if self.compute_connectivity:
self.tracking_connectivity_score_monitor.update(
connectivity, weight=gen_n)
elif len(self.tracking_mean_final_distance_monitor.average_per_epoch) == 0:
logger.info("Skipping tracking-like generation validation "
"from batch. No values yet: adding fake initial "
Expand All @@ -216,7 +221,8 @@ def validate_one_batch(self, data, epoch):
self.tracking_clipped_final_distance_monitor.update(
ACCEPTABLE_THRESHOLD)

self.tracking_connectivity_score_monitor.update(1)
if self.compute_connectivity:
self.tracking_connectivity_score_monitor.update(1)
else:
logger.info("Skipping tracking-like generation validation "
"from batch. Copying previous epoch's values.")
Expand All @@ -230,7 +236,7 @@ def validate_one_batch(self, data, epoch):
self.tracking_connectivity_score_monitor]:
monitor.update(monitor.average_per_epoch[-1])

def validation_generation_one_batch(self, data, compute_all_scores=False):
def gv_phase_one_batch(self, data, compute_all_scores=False):
"""
Use tractography to generate streamlines starting from the "true"
seeds and first few segments. Expected results are the batch's
Expand Down Expand Up @@ -304,12 +310,12 @@ def validation_generation_one_batch(self, data, compute_all_scores=False):
total_point += abs(100 - div_point)
diverging_point = total_point / len(lines)

invalid_ratio_severe = invalid_ratio_severe.cpu().numpy().astype(np.float32)
invalid_ratio_acceptable = invalid_ratio_acceptable.cpu().numpy().astype(np.float32)
invalid_ratio_loose = invalid_ratio_loose.cpu().numpy().astype(np.float32)
final_dist = final_dist.cpu().numpy().astype(np.float32)
final_dist_clipped = final_dist_clipped.cpu().numpy().astype(np.float32)
diverging_point = np.asarray(diverging_point, dtype=np.float32)
invalid_ratio_severe = invalid_ratio_severe.item()
invalid_ratio_acceptable = invalid_ratio_acceptable.item()
invalid_ratio_loose = invalid_ratio_loose.item()
final_dist = final_dist.item()
final_dist_clipped = final_dist_clipped.item()

return (len(lines), final_dist, final_dist_clipped,
invalid_ratio_severe, invalid_ratio_acceptable,
invalid_ratio_loose, diverging_point,
Expand Down
5 changes: 2 additions & 3 deletions dwi_ml/training/utils/batch_loaders.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,7 @@

from dwi_ml.experiment_utils.prints import format_dict_to_str
from dwi_ml.experiment_utils.timer import Timer
from dwi_ml.training.with_generation.batch_loader import \
DWIMLBatchLoaderWithConnectivity
from dwi_ml.training.batch_loaders import DWIMLBatchLoaderOneInput


def add_args_batch_loader(p: argparse.ArgumentParser):
Expand Down Expand Up @@ -44,7 +43,7 @@ def add_args_batch_loader(p: argparse.ArgumentParser):
def prepare_batch_loader(dataset, model, args, sub_loggers_level):
# Preparing the batch loader.
with Timer("\nPreparing batch loader...", newline=True, color='pink'):
batch_loader = DWIMLBatchLoaderWithConnectivity(
batch_loader = DWIMLBatchLoaderOneInput(
dataset=dataset, model=model,
input_group_name=args.input_group_name,
streamline_group_name=args.streamline_group_name,
Expand Down
Empty file.
43 changes: 0 additions & 43 deletions dwi_ml/training/with_generation/batch_loader.py

This file was deleted.

5 changes: 2 additions & 3 deletions scripts_python/l2t_resume_training_from_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,11 @@
from dwi_ml.experiment_utils.timer import Timer
from dwi_ml.io_utils import add_verbose_arg
from dwi_ml.models.projects.learn2track_model import Learn2TrackModel
from dwi_ml.training.batch_loaders import DWIMLBatchLoaderOneInput
from dwi_ml.training.batch_samplers import DWIMLBatchIDSampler
from dwi_ml.training.projects.learn2track_trainer import Learn2TrackTrainer
from dwi_ml.training.utils.experiment import add_args_resuming_experiment
from dwi_ml.training.utils.trainer import run_experiment
from dwi_ml.training.with_generation.batch_loader import \
DWIMLBatchLoaderWithConnectivity


def prepare_arg_parser():
Expand Down Expand Up @@ -60,7 +59,7 @@ def init_from_checkpoint(args, checkpoint_path):
dataset, checkpoint_state['batch_sampler_params'], sub_loggers_level)

# Prepare batch loader
batch_loader = DWIMLBatchLoaderWithConnectivity.init_from_checkpoint(
batch_loader = DWIMLBatchLoaderOneInput.init_from_checkpoint(
dataset, model, checkpoint_state['batch_loader_params'],
sub_loggers_level)

Expand Down
5 changes: 2 additions & 3 deletions scripts_python/l2t_update_deprecated_exp.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,9 @@
from dwi_ml.experiment_utils.prints import format_dict_to_str
from dwi_ml.io_utils import add_verbose_arg
from dwi_ml.models.projects.learn2track_model import Learn2TrackModel
from dwi_ml.training.batch_loaders import DWIMLBatchLoaderOneInput
from dwi_ml.training.batch_samplers import DWIMLBatchIDSampler
from dwi_ml.training.projects.learn2track_trainer import Learn2TrackTrainer
from dwi_ml.training.with_generation.batch_loader import \
DWIMLBatchLoaderWithConnectivity


def prepare_arg_parser():
Expand Down Expand Up @@ -217,7 +216,7 @@ def fix_checkpoint(args, model):
# Init stuff will succeed if ok.
batch_sampler = DWIMLBatchIDSampler.init_from_checkpoint(
dataset, checkpoint_state['batch_sampler_params'])
batch_loader = DWIMLBatchLoaderWithConnectivity.init_from_checkpoint(
batch_loader = DWIMLBatchLoaderOneInput.init_from_checkpoint(
dataset, model, checkpoint_state['batch_loader_params'])
experiments_path, experiment_name = os.path.split(args.out_experiment)
trainer = Learn2TrackTrainer.init_from_checkpoint(
Expand Down
1 change: 1 addition & 0 deletions scripts_python/tests/test_all_steps_l2t.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,5 +160,6 @@ def test_training_with_generation_validation(script_runner, experiments_path):
'--max_batches_per_epoch_validation', '1',
'-v', 'INFO', '--step_size', '0.5',
'--add_a_tracking_validation_phase',
'--tracking_mask', 'wm_mask',
'--tracking_phase_frequency', '1', option)
assert ret.success
5 changes: 2 additions & 3 deletions scripts_python/tt_resume_training_from_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,11 @@
from dwi_ml.experiment_utils.timer import Timer
from dwi_ml.io_utils import add_verbose_arg, verify_which_model_in_path
from dwi_ml.models.projects.transformer_models import find_transformer_class
from dwi_ml.training.batch_loaders import DWIMLBatchLoaderOneInput
from dwi_ml.training.batch_samplers import DWIMLBatchIDSampler
from dwi_ml.training.projects.transformer_trainer import TransformerTrainer
from dwi_ml.training.utils.experiment import add_args_resuming_experiment
from dwi_ml.training.utils.trainer import run_experiment
from dwi_ml.training.with_generation.batch_loader import \
DWIMLBatchLoaderWithConnectivity


def prepare_arg_parser():
Expand Down Expand Up @@ -62,7 +61,7 @@ def init_from_checkpoint(args, checkpoint_path):
dataset, checkpoint_state['batch_sampler_params'], sub_loggers_level)

# Prepare batch loader
batch_loader = DWIMLBatchLoaderWithConnectivity.init_from_checkpoint(
batch_loader = DWIMLBatchLoaderOneInput.init_from_checkpoint(
dataset, model, checkpoint_state['batch_loader_params'],
sub_loggers_level)

Expand Down

0 comments on commit 702260b

Please sign in to comment.