Skip to content

Commit

Permalink
Merge pull request #238 from scil-vital/atheb/clear_cache
Browse files Browse the repository at this point in the history
ENH: make cache clearing optional
  • Loading branch information
EmmaRenauld authored Oct 8, 2024
2 parents 04639fa + aaf48f2 commit ce0a2d7
Show file tree
Hide file tree
Showing 4 changed files with 48 additions and 16 deletions.
33 changes: 23 additions & 10 deletions dwi_ml/data/processing/volume/interpolation.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,8 @@ def torch_nearest_neighbor_interpolation(volume: torch.Tensor,


def torch_trilinear_interpolation(volume: torch.Tensor,
coords_vox_corner: torch.Tensor):
coords_vox_corner: torch.Tensor,
clear_cache=True):
"""Evaluates the data volume at given coordinates using trilinear
interpolation on a torch tensor.
Expand All @@ -71,6 +72,9 @@ def torch_trilinear_interpolation(volume: torch.Tensor,
The input volume to interpolate from
coords_vox_corner : torch.Tensor with shape (N,3)
The coordinates where to interpolate. (Origin = corner, space = vox).
clear_cache : bool
If True, will clear the cache after interpolation. This can be useful
to save memory, but will slow down the function.
Returns
-------
Expand Down Expand Up @@ -146,7 +150,8 @@ def torch_trilinear_interpolation(volume: torch.Tensor,
# p: of shape n x 8 x features
# Q1: n x 8 x 1
# This can save a bit of space.
torch.cuda.empty_cache()
if clear_cache:
torch.cuda.empty_cache()

# return torch.sum(p * Q1, dim=1)
# Able to have bigger batches by avoiding 3D matrix.
Expand All @@ -164,7 +169,8 @@ def torch_trilinear_interpolation(volume: torch.Tensor,


def interpolate_volume_in_neighborhood(
volume_as_tensor, coords_vox_corner, neighborhood_vectors_vox=None):
volume_as_tensor, coords_vox_corner, neighborhood_vectors_vox=None,
clear_cache=True):
"""
Params
------
Expand All @@ -178,6 +184,9 @@ def interpolate_volume_in_neighborhood(
The neighboors to add to each coord. Do not include the current point
([0,0,0]). Values are considered in the same space as
coords_vox_corner, and should thus be in voxel space.
clear_cache: bool
If True, will clear the cache after interpolation. This can be useful
to save memory, but will slow down the function.
Returns
-------
Expand All @@ -202,7 +211,8 @@ def interpolate_volume_in_neighborhood(
# DWI data features for each neighbor are concatenated.
# Result is of shape: (M * (N+1), F).
flat_subj_x_data = torch_trilinear_interpolation(volume_as_tensor,
coords_vox_corner)
coords_vox_corner,
clear_cache)

# Neighbors become new features of the current point.
# Reshape signal into (M, (N+1)*F))
Expand All @@ -211,20 +221,23 @@ def interpolate_volume_in_neighborhood(

else: # No neighborhood:
subj_x_data = torch_trilinear_interpolation(volume_as_tensor,
coords_vox_corner)
coords_vox_corner,
clear_cache)

if volume_as_tensor.is_cuda:
logging.debug("Emptying cache now. Can be a little slow but saves A "
"LOT of memory for our\n current trilinear "
"interpolation. (We could improve code \na little but "
"would loose a lot of readability).")

# Ex: with 2000 streamlines (134 000 points), with neighborhood axis
# [1 2] (13 neighbors), 47 features per point, we remove 6.5 GB of
# memory!
logging.debug("Torch trilinear: emptying cache. Before:")
log_gpu_memory_usage(logging.getLogger())
torch.cuda.empty_cache()
logging.debug("After")
log_gpu_memory_usage(logging.getLogger())
if clear_cache:
logging.debug("Torch trilinear: emptying cache. Before:")
log_gpu_memory_usage(logging.getLogger())
torch.cuda.empty_cache()
logging.debug("After")
log_gpu_memory_usage(logging.getLogger())

return subj_x_data, coords_vox_corner
8 changes: 5 additions & 3 deletions dwi_ml/models/main_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -454,7 +454,8 @@ def forward(self, inputs, target_streamlines: List[torch.tensor]):

class MainModelOneInput(MainModelAbstract):
def prepare_batch_one_input(self, streamlines, subset: MultisubjectSubset,
subj_idx, input_group_idx, prepare_mask=False):
subj_idx, input_group_idx, prepare_mask=False,
clear_cache=True):
"""
These params are passed by either the batch loader or the propagator,
which manage the data.
Expand Down Expand Up @@ -495,10 +496,11 @@ def prepare_batch_one_input(self, streamlines, subset: MultisubjectSubset,
if isinstance(self, ModelWithNeighborhood):
# Adding neighborhood.
subj_x_data, coords_torch = interpolate_volume_in_neighborhood(
data_tensor, flat_subj_x_coords, self.neighborhood_vectors)
data_tensor, flat_subj_x_coords, self.neighborhood_vectors,
clear_cache=clear_cache)
else:
subj_x_data, coords_torch = interpolate_volume_in_neighborhood(
data_tensor, flat_subj_x_coords, None)
data_tensor, flat_subj_x_coords, None, clear_cache=clear_cache)

# Split the flattened signal back to streamlines
lengths = [len(s) for s in streamlines]
Expand Down
20 changes: 18 additions & 2 deletions dwi_ml/tracking/tracking_mask.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,11 +42,27 @@ def is_vox_corner_in_bound(self, xyz: torch.Tensor):
torch.any(torch.less(xyz, self.lower_bound), dim=-1),
torch.any(torch.greater_equal(xyz, self.higher_bound), dim=-1))

def get_value_at_vox_corner_coordinate(self, xyz, interpolation):
def get_value_at_vox_corner_coordinate(
self, xyz, interpolation, clear_cache=True
):
""" Get the value at the voxel corner coordinate.
Parameters
----------
xyz : torch.Tensor
Coordinates in voxel space, corner origin. Of shape [n, 3].
interpolation : str
Either 'nearest' or 'trilinear'.
clear_cache : bool
Whether to clear the cache after computing the interpolation.
Can be useful to save memory but will slow down the function.
Only used if interpolation is 'trilinear'.
"""

if interpolation == 'nearest':
return torch_nearest_neighbor_interpolation(self.data, xyz)
else:
return torch_trilinear_interpolation(self.data, xyz)
return torch_trilinear_interpolation(self.data, xyz, clear_cache)

def is_vox_corner_in_mask(self, xyz):
# Clipping to bound.
Expand Down
3 changes: 2 additions & 1 deletion dwi_ml/training/batch_loaders.py
Original file line number Diff line number Diff line change
Expand Up @@ -390,6 +390,7 @@ def __init__(self, input_group_name, **kw):
.format(input_group_name,
self.dataset.volume_groups))
self.input_group_idx = idx
self.clear_cache = True

@property
def params_for_checkpoint(self):
Expand Down Expand Up @@ -448,7 +449,7 @@ def load_batch_inputs(self, batch_streamlines: List[torch.tensor],
# before adding streamline to batch.
subbatch_x_data = self.model.prepare_batch_one_input(
streamlines, self.context_subset, subj,
self.input_group_idx)
self.input_group_idx, clear_cache=self.clear_cache)

batch_x_data.extend(subbatch_x_data)

Expand Down

0 comments on commit ce0a2d7

Please sign in to comment.