Skip to content

Commit

Permalink
- added fuzzy case to SegmentationBasedAdapter (+activated gpu)
Browse files Browse the repository at this point in the history
- adapted segmentation_loader example
- refactored assert_equal_shapes test
  • Loading branch information
faberno committed Nov 18, 2024
1 parent 6949c2c commit 9ec96bc
Show file tree
Hide file tree
Showing 3 changed files with 42 additions and 23 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,22 @@ class SegmentationBasedAdapter(VolumeCreationAdapterBase):
def create_simulation_volume(self) -> dict:
volumes, x_dim_px, y_dim_px, z_dim_px = self.create_empty_volumes()
wavelength = self.global_settings[Tags.WAVELENGTH]
for key in volumes.keys():
volumes[key] = volumes[key].to('cpu')

segmentation_volume = self.component_settings[Tags.INPUT_SEGMENTATION_VOLUME]
segmentation_classes = np.unique(segmentation_volume, return_counts=False)
x_dim_seg_px, y_dim_seg_px, z_dim_seg_px = np.shape(segmentation_volume)
segmentation_volume = torch.tensor(self.component_settings[Tags.INPUT_SEGMENTATION_VOLUME], device=self.torch_device)
class_mapping = self.component_settings[Tags.SEGMENTATION_CLASS_MAPPING]

if torch.is_floating_point(segmentation_volume):
assert len(segmentation_volume.shape) == 4 and segmentation_volume.shape[0] == len(class_mapping), \
"Fuzzy segmentation must be a 4D array with the first dimension being the number of classes."
fuzzy = True
segmentation_classes = np.arange(segmentation_volume.shape[0])

else:
assert len(segmentation_volume.shape) == 3, "Hard segmentations must be a 3D array."
fuzzy = False
segmentation_classes = torch.unique(segmentation_volume, return_counts=False).cpu().numpy()

x_dim_seg_px, y_dim_seg_px, z_dim_seg_px = np.shape(segmentation_volume)[-3:]

if x_dim_px != x_dim_seg_px:
raise ValueError("x_dim of volumes and segmentation must perfectly match but was {} and {}"
Expand All @@ -38,16 +48,17 @@ def create_simulation_volume(self) -> dict:
raise ValueError("z_dim of volumes and segmentation must perfectly match but was {} and {}"
.format(z_dim_px, z_dim_seg_px))

class_mapping = self.component_settings[Tags.SEGMENTATION_CLASS_MAPPING]

for seg_class in segmentation_classes:
class_properties = class_mapping[seg_class].get_properties_for_wavelength(self.global_settings, wavelength)
for volume_key in volumes.keys():
if isinstance(class_properties[volume_key], (int, float)) or class_properties[volume_key] == None: # scalar
assigned_prop = class_properties[volume_key]
if assigned_prop is None:
assigned_prop = torch.nan
volumes[volume_key][segmentation_volume == seg_class] = assigned_prop
if fuzzy:
volumes[volume_key] += segmentation_volume[seg_class] * assigned_prop
else:
volumes[volume_key][segmentation_volume == seg_class] = assigned_prop
elif len(torch.Tensor.size(class_properties[volume_key])) == 3: # 3D map
assigned_prop = class_properties[volume_key][torch.tensor(segmentation_volume == seg_class)]
assigned_prop[assigned_prop is None] = torch.nan
Expand All @@ -57,6 +68,6 @@ def create_simulation_volume(self) -> dict:

# convert volumes back to CPU
for key in volumes.keys():
volumes[key] = volumes[key].numpy().astype(np.float64, copy=False)
volumes[key] = volumes[key].cpu().numpy().astype(np.float64, copy=False)

return volumes
8 changes: 3 additions & 5 deletions simpa/utils/quality_assurance/data_sanity_testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,10 @@ def assert_equal_shapes(numpy_arrays: list):
if len(numpy_arrays) < 2:
return

shapes = np.asarray([np.shape(_arr) for _arr in numpy_arrays]).astype(float)
mean = np.mean(shapes, axis=0)
for i in range(len(shapes)):
shapes[i, :] = shapes[i, :] - mean
first_array_shape = numpy_arrays[0].shape
equal = ([_arr.shape == first_array_shape for _arr in numpy_arrays])

if not np.sum(np.abs(shapes)) <= 1e-5:
if not all(equal):
raise AssertionError("The given volumes did not all have the same"
" dimensions. Please double check the simulation"
f" parameters. Called from {inspect.stack()[1].function}")
Expand Down
28 changes: 19 additions & 9 deletions simpa_examples/segmentation_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import simpa as sp
import numpy as np
from skimage.data import shepp_logan_phantom
from scipy.ndimage import zoom
from scipy.ndimage import zoom, gaussian_filter
from skimage.transform import resize

# FIXME temporary workaround for newest Intel architectures
Expand All @@ -20,8 +20,8 @@


@profile
def run_segmentation_loader(spacing: float | int = 1.0, input_spacing: float | int = 0.2, path_manager=None,
visualise: bool = True):
def run_segmentation_loader(spacing: float | int = 1.0, input_spacing: float | int = 0.2, fuzzy: bool = False,
path_manager=None, visualise: bool = True):
"""
:param spacing: The simulation spacing between voxels in mm
Expand All @@ -30,19 +30,28 @@ def run_segmentation_loader(spacing: float | int = 1.0, input_spacing: float | i
:param visualise: If VISUALIZE is set to True, the reconstruction result will be plotted
:return: a run through of the example
"""

if path_manager is None:
path_manager = sp.PathManager()

C = 11 # number of classes
label_mask = shepp_logan_phantom()

label_mask = np.digitize(label_mask, bins=np.linspace(0.0, 1.0, 11), right=True)
label_mask = np.digitize(label_mask, bins=np.linspace(0.0, 1.0, C), right=True)
label_mask = label_mask[100:300, 100:300]
label_mask = np.reshape(label_mask, (label_mask.shape[0], 1, label_mask.shape[1]))

segmentation_volume_tiled = np.tile(label_mask, (1, 128, 1))
segmentation_volume_mask = sp.round_x5_away_from_zero(zoom(segmentation_volume_tiled, input_spacing/spacing,
order=0)).astype(int)

if fuzzy:
segmentation_volume_mask = np.eye(C)[segmentation_volume_mask]
segmentation_volume_mask = np.moveaxis(segmentation_volume_mask, -1, 0)
segmentation_volume_mask = gaussian_filter(segmentation_volume_mask, sigma=1e-5, axes=(1, 2, 3)) # smooth the segmentation
segmentation_volume_mask /= segmentation_volume_mask.sum(axis=0, keepdims=True)


def segmentation_class_mapping():
ret_dict = dict()
ret_dict[0] = sp.TISSUE_LIBRARY.heavy_water()
Expand All @@ -68,14 +77,14 @@ def segmentation_class_mapping():
settings[Tags.RANDOM_SEED] = 1234
settings[Tags.WAVELENGTHS] = [700, 800]
settings[Tags.SPACING_MM] = spacing
settings[Tags.DIM_VOLUME_X_MM] = segmentation_volume_mask.shape[0] * spacing
settings[Tags.DIM_VOLUME_Y_MM] = segmentation_volume_mask.shape[1] * spacing
settings[Tags.DIM_VOLUME_Z_MM] = segmentation_volume_mask.shape[2] * spacing
x_dim_mm, y_dim_mm, z_dim_mm = segmentation_volume_mask.shape[-3:]
settings[Tags.DIM_VOLUME_X_MM] = x_dim_mm * spacing
settings[Tags.DIM_VOLUME_Y_MM] = y_dim_mm * spacing
settings[Tags.DIM_VOLUME_Z_MM] = z_dim_mm * spacing

settings.set_volume_creation_settings({
Tags.INPUT_SEGMENTATION_VOLUME: segmentation_volume_mask,
Tags.SEGMENTATION_CLASS_MAPPING: segmentation_class_mapping(),

})

settings.set_optical_settings({
Expand Down Expand Up @@ -108,9 +117,10 @@ def segmentation_class_mapping():
parser = ArgumentParser(description='Run the segmentation loader example')
parser.add_argument("--spacing", default=1, type=float, help='the voxel spacing in mm')
parser.add_argument("--input_spacing", default=0.2, type=float, help='the input spacing in mm')
parser.add_argument("--fuzzy", default=False, type=bool, help='whether to use fuzzy segmentation adapter')
parser.add_argument("--path_manager", default=None, help='the path manager, None uses sp.PathManager')
parser.add_argument("--visualise", default=True, type=bool, help='whether to visualise the result')
config = parser.parse_args()

run_segmentation_loader(spacing=config.spacing, input_spacing=config.input_spacing,
run_segmentation_loader(spacing=config.spacing, input_spacing=config.input_spacing, fuzzy=config.fuzzy,
path_manager=config.path_manager, visualise=config.visualise)

0 comments on commit 9ec96bc

Please sign in to comment.