Merge pull request #234 from EmmaRenauld/update_to_scilpy2
Update to scilpy 2.0 and torch 2.2.0
EmmaRenauld authored May 17, 2024
2 parents e21a7c1 + 1ef48e8 commit 38ae414
Showing 14 changed files with 24 additions and 32 deletions.
3 changes: 1 addition & 2 deletions dwi_ml/data/processing/streamlines/data_augmentation.py
@@ -9,8 +9,7 @@
 import numpy as np

 from scilpy.tractograms.streamline_operations import \
-    resample_streamlines_step_size
-from scilpy.utils.streamlines import compress_sft
+    resample_streamlines_step_size, compress_sft


 def resample_or_compress(sft, step_size_mm: float = None,
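For reference, a minimal sketch of the new combined import in use (the helper name below is hypothetical; the two calls mirror those used elsewhere in this diff):

from scilpy.tractograms.streamline_operations import \
    resample_streamlines_step_size, compress_sft

def resample_or_compress_sketch(sft, step_size_mm=None, compress=False):
    # Resample to a fixed step size (in mm) when one is given ...
    if step_size_mm is not None:
        sft = resample_streamlines_step_size(sft, step_size=step_size_mm)
    # ... otherwise optionally compress, as tt_visu_submethods.py does below.
    elif compress:
        sft = compress_sft(sft)
    return sft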
1 change: 1 addition & 0 deletions dwi_ml/io_utils.py
@@ -21,6 +21,7 @@ def add_resample_or_compress_arg(p: ArgumentParser):
         help="Step size to resample the data (in mm). Default: None")
     g.add_argument(
         '--compress', type=float, metavar='r', const=0.01, nargs='?',
+        dest='compress_th',
         help="Compression ratio. Default: None. Default if set: 0.01.\n"
              "If neither step_size nor compress are chosen, streamlines "
              "will be kept \nas they are.")
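The added dest='compress_th' is what makes the parsed value land in args.compress_th rather than args.compress, which is why the scripts below now read args.compress_th. A minimal standalone argparse sketch of that behaviour (a plain parser is used here instead of the group object g from the code above):

from argparse import ArgumentParser

p = ArgumentParser()
p.add_argument('--compress', type=float, metavar='r', const=0.01, nargs='?',
               dest='compress_th')

print(p.parse_args([]).compress_th)                     # None (flag absent)
print(p.parse_args(['--compress']).compress_th)         # 0.01 (const, no value)
print(p.parse_args(['--compress', '0.1']).compress_th)  # 0.1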
4 changes: 2 additions & 2 deletions dwi_ml/testing/projects/tt_visu_colored_sft.py
@@ -7,7 +7,7 @@
 from dipy.io.stateful_tractogram import StatefulTractogram
 from dipy.io.streamline import save_tractogram

-from scilpy.viz.utils import get_colormap
+from scilpy.viz.color import get_lookup_table

 from dwi_ml.testing.projects.tt_visu_utils import (
     get_visu_params_from_options,
@@ -234,7 +234,7 @@ def color_sft_x_y_projections(

 def _color_sft_from_dpp(sft, key, cmap='viridis', vmin=None, vmax=None,
                         prepare_fig: bool = False, title=None, **kw):
-    cmap = get_colormap(cmap)
+    cmap = get_lookup_table(cmap)
     tmp = [np.squeeze(sft.data_per_point[key][s]) for s in range(len(sft))]
     data = np.hstack(tmp)

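For reference, a small sketch of the renamed helper in use; it assumes get_lookup_table returns a matplotlib-style colormap callable on values in [0, 1], and the helper name is hypothetical:

import numpy as np
from scilpy.viz.color import get_lookup_table  # scilpy 2.0 location, per this diff

def dpp_to_rgb_sketch(sft, key, cmap_name='viridis'):
    # Map per-point data to RGB colors, in the spirit of _color_sft_from_dpp.
    cmap = get_lookup_table(cmap_name)
    data = np.hstack([np.squeeze(sft.data_per_point[key][s])
                      for s in range(len(sft))])
    norm = (data - data.min()) / (data.max() - data.min() + 1e-8)
    return (np.asarray(cmap(norm))[:, :3] * 255).astype(np.uint8)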
4 changes: 2 additions & 2 deletions dwi_ml/testing/projects/tt_visu_submethods.py
@@ -65,8 +65,8 @@ def load_data_run_model(parser, args, model: AbstractTransformerModel,
         # Resampling streamlines to a fixed step size, if any
         logging.debug("    Resampling: {}".format(args.step_size))
         sft = resample_streamlines_step_size(sft, step_size=args.step_size)
-    if args.compress:
-        logging.debug("    Compressing: {}".format(args.compress))
+    if args.compress_th:
+        logging.debug("    Compressing: {}".format(args.compress_th))
         sft = compress_sft(sft)

     # To tensor
2 changes: 1 addition & 1 deletion dwi_ml/testing/projects/tt_visu_utils.py
@@ -7,7 +7,7 @@
 from scipy.ndimage import zoom
 from tqdm import tqdm

-from scilpy.io.fetcher import get_home as get_scilpy_folder
+from scilpy import get_home as get_scilpy_folder


 THRESH_IMPORTANT = {
7 changes: 3 additions & 4 deletions dwi_ml/unit_tests/utils/data_and_models_for_tests.py
@@ -5,7 +5,8 @@

 import torch

-from scilpy.io.fetcher import fetch_data, get_home
+from scilpy import get_home
+from scilpy.io.fetcher import fetch_data

 from dwi_ml.data.processing.streamlines.post_processing import \
     compute_directions
@@ -30,9 +31,7 @@ def fetch_testing_data():
     # Access to the file dwi_ml.zip:
     # https://drive.google.com/uc?id=1beRWAorhaINCncttgwqVAP2rNOfx842Q
     name_as_dict = {
-        'data_for_tests_dwi_ml.zip':
-            ['1beRWAorhaINCncttgwqVAP2rNOfx842Q',
-             'da6c94fbef7ac13029acdb8b94325096']}
+        'data_for_tests_dwi_ml.zip': "da6c94fbef7ac13029acdb8b94325096"}
     fetch_data(name_as_dict)

     return testing_data_dir
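A minimal sketch of the scilpy 2.0 pattern used above; the archive name and checksum come from this diff, while the download destination being get_home() is an assumption:

from scilpy import get_home            # top-level import in scilpy 2.0, as above
from scilpy.io.fetcher import fetch_data

# The dict now maps the archive name to its expected md5 only; the Google
# Drive id is no longer passed by the caller.
fetch_data({'data_for_tests_dwi_ml.zip': 'da6c94fbef7ac13029acdb8b94325096'})
print(get_home())  # assumed to be where scilpy stores the fetched data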
17 changes: 5 additions & 12 deletions requirements.txt
@@ -5,23 +5,16 @@
 # Main dependency: scilpy
 # Scilpy and comet_ml both require requests. In comet: >=2.18.*,
 # which installs a version >2.28. Adding request version explicitely.
-#
-# Changed many times the scilpy version. Tried to have latest fixed version,
-# but currently, working on beluga only when installing manually. But using
-# the master is not good; changes too fast for us. Using a fixed commit for
-# now.
-# To using a commit preceding all recent changes in scilpy's test data
-# management: d20d3d4917d40f698aa975f64a575dda34e0c89c
 # -------
 requests==2.28.*
-dipy==1.8.0
--e git+https://github.com/scilus/scilpy.git@d20d3d4917d40f698aa975f64a575dda34e0c89c#egg=scilpy
+dipy==1.9.*
+scilpy==2.0.2

 # -------
 # Other important dependencies
 # -------
 bertviz==1.4.0 # For transformer's visu
-torch==1.13.*
+torch==2.2.0
 tqdm==4.64.*
 comet-ml>=3.22.0
 contextlib2==21.6.0
@@ -30,17 +23,17 @@ jupyter>=1.0.0
 IProgress>=0.4 # For jupyter with tdqm
 nested_lookup==0.2.25 # For lists management
 pynvml>=11.5.0
-scikit-image

 # -------
-# Necessary but should be installed with scilpy (Last check: 01/2024):
+# Necessary but should be installed with scilpy (Last check: 04/2024):
 # -------
 future==0.18.*
 h5py==3.7.* # h5py must absolutely be >2.4: that's when it became thread-safe
 matplotlib==3.6.* # Hint: If matplotlib fails, you may try to install pyQt5.
 nibabel==5.2.*
 numpy==1.23.*
 scipy==1.9.*
+scikit-image==0.22.*


 # --------------- Notes to developers
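After installing from this file, the main pins can be sanity-checked with a short snippet like the one below (a sketch; it only assumes the distributions install under the names pinned above):

from importlib.metadata import version

# Expected per requirements.txt: scilpy 2.0.2, torch 2.2.0, dipy 1.9.*,
# scikit-image 0.22.*
for pkg in ('scilpy', 'torch', 'dipy', 'scikit-image'):
    print(pkg, version(pkg))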
2 changes: 1 addition & 1 deletion scripts_python/dwiml_compute_loss_copy_previous.py
@@ -63,7 +63,7 @@ def main():
     # 1. Prepare fake model
     dg_args = check_args_direction_getter(args)
     model = CopyPrevDirModel(args.dg_key, dg_args, args.skip_first_point,
-                             args.step_size, args.compress)
+                             args.step_size, args.compress_th)
     model.set_context('visu')

     # 2. Load data through the tester
2 changes: 1 addition & 1 deletion scripts_python/dwiml_create_hdf5_dataset.py
@@ -88,7 +88,7 @@ def prepare_hdf5_creator(args):
     # Instantiate a creator and perform checks
     creator = HDF5Creator(Path(args.dwi_ml_ready_folder), args.out_hdf5_file,
                           training_subjs, validation_subjs, testing_subjs,
-                          groups_config, args.step_size, args.compress,
+                          groups_config, args.step_size, args.compress_th,
                           args.enforce_files_presence,
                           args.save_intermediate, intermediate_subdir)

2 changes: 1 addition & 1 deletion scripts_python/dwiml_visualize_noise_on_streamlines.py
@@ -68,7 +68,7 @@ def main():
     subj_sft_data = subj_data.sft_data_list[streamline_group_idx]
     sft = subj_sft_data.as_sft()

-    sft = resample_or_compress(sft, args.step_size, args.compress)
+    sft = resample_or_compress(sft, args.step_size, args.compress_th)
     sft.to_vox()
     sft.to_corner()

4 changes: 2 additions & 2 deletions scripts_python/l2t_track_from_model.py
@@ -101,7 +101,7 @@ def prepare_tracker(parser, args):
         dataset=subset, subj_idx=0, model=model, mask=tracking_mask,
         seed_generator=seed_generator, nbr_seeds=nbr_seeds,
         min_len_mm=args.min_length, max_len_mm=args.max_length,
-        compression_th=args.compress, nbr_processes=args.nbr_processes,
+        compression_th=args.compress_th, nbr_processes=args.nbr_processes,
         save_seeds=args.save_seeds, rng_seed=args.rng_seed,
         track_forward_only=args.track_forward_only,
         step_size_mm=args.step_size, algo=args.algo, theta=theta,
@@ -130,7 +130,7 @@ def main():
     assert_outputs_exist(parser, args, args.out_tractogram)

     verify_streamline_length_options(parser, args)
-    verify_compression_th(args.compress)
+    verify_compression_th(args.compress_th)
     verify_seed_options(parser, args)

     tracker, ref = prepare_tracker(parser, args)
2 changes: 1 addition & 1 deletion scripts_python/l2t_train_model.py
@@ -66,7 +66,7 @@ def init_from_args(args, sub_loggers_level):
     # INPUTS: verifying args
     model = Learn2TrackModel(
         experiment_name=args.experiment_name,
-        step_size=args.step_size, compress_lines=args.compress,
+        step_size=args.step_size, compress_lines=args.compress_th,
         # PREVIOUS DIRS
         prev_dirs_embedding_key=args.prev_dirs_embedding_key,
         prev_dirs_embedded_size=args.prev_dirs_embedded_size,
4 changes: 2 additions & 2 deletions scripts_python/tt_track_from_model.py
@@ -106,7 +106,7 @@ def prepare_tracker(parser, args):
         dataset=subset, subj_idx=0, model=model, mask=tracking_mask,
         seed_generator=seed_generator, nbr_seeds=nbr_seeds,
         min_len_mm=args.min_length, max_len_mm=args.max_length,
-        compression_th=args.compress, nbr_processes=args.nbr_processes,
+        compression_th=args.compress_th, nbr_processes=args.nbr_processes,
         save_seeds=args.save_seeds, rng_seed=args.rng_seed,
         track_forward_only=args.track_forward_only,
         step_size_mm=args.step_size, algo=args.algo, theta=theta,
@@ -138,7 +138,7 @@ def main():
     assert_outputs_exist(parser, args, args.out_tractogram)

     verify_streamline_length_options(parser, args)
-    verify_compression_th(args.compress)
+    verify_compression_th(args.compress_th)
     verify_seed_options(parser, args)

     tracker, ref = prepare_tracker(parser, args)
2 changes: 1 addition & 1 deletion scripts_python/tt_train_model.py
@@ -101,7 +101,7 @@ def init_from_args(args, sub_loggers_level):
     with Timer("\n\nPreparing model", newline=True, color='yellow'):
         model = cls(
             experiment_name=args.experiment_name,
-            step_size=args.step_size, compress_lines=args.compress,
+            step_size=args.step_size, compress_lines=args.compress_th,
             # Concerning inputs:
             max_len=args.max_len, nb_features=args.nb_features,
             positional_encoding_key=args.position_encoding,
