Skip to content

Commit

Permalink
feat(python): add cast_image function to itkwasm package
Browse files Browse the repository at this point in the history
  • Loading branch information
thewtex committed Aug 21, 2023
1 parent c9abf86 commit 5e0c127
Show file tree
Hide file tree
Showing 4 changed files with 96 additions and 2 deletions.
4 changes: 3 additions & 1 deletion packages/core/python/itkwasm/itkwasm/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""itkwasm: Python interface to itk-wasm WebAssembly modules."""

__version__ = "1.0b105"
__version__ = "1.0b128"

from .interface_types import InterfaceTypes
from .image import Image, ImageType
Expand All @@ -19,6 +19,7 @@
from .int_types import IntTypes
from .pixel_types import PixelTypes
from .environment_dispatch import environment_dispatch, function_factory
from .cast_image import cast_image

__all__ = [
"InterfaceTypes",
Expand All @@ -43,4 +44,5 @@
"PixelTypes",
"environment_dispatch",
"function_factory",
"cast_image",
]
63 changes: 63 additions & 0 deletions packages/core/python/itkwasm/itkwasm/cast_image.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
from typing import Optional
from dataclasses import asdict
import copy

import numpy as np

from .image import Image, ImageType
from .pixel_types import PixelTypes
from .int_types import IntTypes
from .float_types import FloatTypes

def cast_image(input_image: Image,
pixel_type: Optional[PixelTypes]=None,
component_type: Optional[IntTypes | FloatTypes]=None) -> Image:
"""Cast an image to another pixel type and / or component type."""
output_image_type = ImageType(**asdict(input_image.imageType))

if pixel_type is not None:
output_image_type.pixelType = pixel_type
if pixel_type == PixelTypes.Scalar and output_image_type.components != 1:
raise ValueError("PixelType Scalar requires components == 1")

if component_type is not None and component_type != input_image.imageType.componentType:
output_image_type.componentType = component_type

output_image = Image(output_image_type)

output_image.name = input_image.name
output_image.origin = list(input_image.origin)
output_image.spacing = list(input_image.spacing)
output_image.direction = input_image.direction.copy()
output_image.size = list(input_image.size)
output_image.metadata = copy.deepcopy(input_image.metadata)

if input_image.data is not None:
if output_image_type.componentType == input_image.imageType.componentType:
output_image.data = input_image.data.copy()
else:
component_type = output_image_type.componentType
if component_type == IntTypes.UInt8:
output_image.data = input_image.data.astype(np.uint8)
elif component_type == IntTypes.Int8:
output_image.data = input_image.data.astype(np.int8)
elif component_type == IntTypes.UInt16:
output_image.data = input_image.data.astype(np.uint16)
elif component_type == IntTypes.Int16:
output_image.data = input_image.data.astype(np.int16)
elif component_type == IntTypes.UInt32:
output_image.data = input_image.data.astype(np.uint32)
elif component_type == IntTypes.Int32:
output_image.data = input_image.data.astype(np.int32)
elif component_type == IntTypes.UInt64:
output_image.data = input_image.data.astype(np.uint64)
elif component_type == IntTypes.Int64:
output_image.data = input_image.data.astype(np.int64)
elif component_type == FloatTypes.Float32:
output_image.data = input_image.data.astype(np.float32)
elif component_type == FloatTypes.Float64:
output_image.data = input_image.data.astype(np.float64)
else:
raise ValueError('Unsupported component type')

return output_image
29 changes: 29 additions & 0 deletions packages/core/python/itkwasm/test/test_cast_image.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
from pathlib import Path

import itk

from itkwasm import Image
from itkwasm import FloatTypes, PixelTypes
import numpy as np

from itkwasm import cast_image

def test_cast_image_component():
data = Path(__file__).absolute().parent / "input" / "cthead1.png"
itk_image = itk.imread(data, itk.UC)
itk_image_dict = itk.dict_from_image(itk_image)
itkwasm_image = Image(**itk_image_dict)

itkwasm_image_double = cast_image(itkwasm_image, component_type=FloatTypes.Float64)
assert itkwasm_image_double.imageType.componentType == FloatTypes.Float64
assert np.array_equal(itkwasm_image.data, itkwasm_image_double.data)

def test_cast_image_pixel_type():
data = Path(__file__).absolute().parent / "input" / "cthead1.png"
itk_image = itk.imread(data, itk.UC)
itk_image_dict = itk.dict_from_image(itk_image)
itkwasm_image = Image(**itk_image_dict)

itkwasm_image_vector = cast_image(itkwasm_image, pixel_type=PixelTypes.VariableLengthVector)
assert itkwasm_image_vector.imageType.pixelType == PixelTypes.VariableLengthVector
assert np.array_equal(itkwasm_image.data, itkwasm_image_vector.data)
2 changes: 1 addition & 1 deletion src/core/castImage.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ import IntTypes from './IntTypes.js'
import FloatTypes from './FloatTypes.js'

/**
* Cast an image to another PixelType or ComponentType
* Cast an image to another PixelType and/or ComponentType
*
* @param {Image} image - The input image
* @param {CastImageOptions} options - specify the componentType and/or pixelType of the output
Expand Down

0 comments on commit 5e0c127

Please sign in to comment.