Skip to content

Commit

Permalink
feat(transform-io): emscripten read_transform_async, write_transform_…
Browse files Browse the repository at this point in the history
…async
  • Loading branch information
thewtex committed Nov 25, 2024
1 parent 2ffc2ca commit 395beb7
Show file tree
Hide file tree
Showing 6 changed files with 263 additions and 2 deletions.
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
# Generated file. To retain edits, remove this comment.

"""itkwasm-transform-io-emscripten: Input and output for scientific and medical coordinate transform file formats. Emscripten implementation."""

from .read_transform_async import read_transform_async
from .write_transform_async import write_transform_async

from .hdf5_read_transform_async import hdf5_read_transform_async
from .hdf5_write_transform_async import hdf5_write_transform_async
from .mat_read_transform_async import mat_read_transform_async
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
from collections import OrderedDict

extension_to_transform_io = OrderedDict([
("h5", "hdf5"),
("hdf5", "hdf5"),
("txt", "txt"),
("mat", "mat"),
("xfm", "mnc"),
("iwt", "wasm"),
("iwt.cbor", "wasm"),
("iwt.cbor.zst", "wasmZstd"),
])
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
import os
from typing import Optional, Union
from pathlib import Path

from itkwasm import (
TransformList,
BinaryFile,
)

from .js_package import js_package

from itkwasm.pyodide import (
to_js,
to_py,
js_resources
)

from .extension_to_transform_io import extension_to_transform_io
from .transform_io_index import transform_io_index

async def read_transform_async(
serialized_transform: os.PathLike,
float_parameters: bool = False,
) -> TransformList:
"""Read an transform file format and convert it to the itk-wasm file format
:param serialized_transform: Input transform serialized in the file format
:type serialized_transform: os.PathLike
:param float_parameters: Use float for the parameter value type. The default is double.
:type float_parameters: bool
:return: Output transform
:rtype: TransformList
"""
js_module = await js_package.js_module
web_worker = js_resources.web_worker

kwargs = {}
if float_parameters:
kwargs["floatParameters"] = to_js(float_parameters)

extension = ''.join(Path(serialized_transform).suffixes)

io = None
if extension in extension_to_transform_io:
func = f"{extension_to_transform_io[extension]}ReadTransform"
io = getattr(js_module, func)
else:
for ioname in transform_io_index:
func = f"{ioname}ReadTransform"
io = getattr(js_module, func)
outputs = await io(to_js(BinaryFile(serialized_transform)), webWorker=web_worker, noCopy=True, **kwargs)
outputs_object_map = outputs.as_object_map()
web_worker = outputs_object_map['webWorker']
js_resources.web_worker = web_worker
could_read = to_py(outputs_object_map['couldRead'])
if could_read:
transform = to_py(outputs_object_map['transform'])
return transform

if io is None:
raise RuntimeError(f"Could not find an transform reader for {extension}")

outputs = await io(to_js(BinaryFile(serialized_transform)), webWorker=web_worker, noCopy=True, **kwargs)
outputs_object_map = outputs.as_object_map()
web_worker = outputs_object_map['webWorker']
could_read = to_py(outputs_object_map['couldRead'])

if not could_read:
raise RuntimeError(f"Could not read {serialized_transform}")

js_resources.web_worker = web_worker

transform = to_py(outputs_object_map['transform'])

return transform

async def transformread_async(
serialized_transform: os.PathLike,
float_parameters: bool = False,
) -> TransformList:
return await read_transform_async(serialized_transform, float_parameters=float_parameters)

transformread_async.__doc__ = f"""{read_transform_async.__doc__}
Alias for read_transform_async.
"""
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
transform_io_index = [
'hdf5',
'mat',
'mnc',
'txt',
'wasm',
'wasm_ztd',
]

Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
import os
import importlib
from pathlib import Path
from typing import Optional, Union

from itkwasm import TransformList, PixelTypes, IntTypes, FloatTypes, BinaryFile

from itkwasm.pyodide import (
to_js,
to_py,
js_resources
)

from .js_package import js_package

from .extension_to_transform_io import extension_to_transform_io
from .transform_io_index import transform_io_index

async def write_transform_async(
transform: TransformList,
serialized_transform: os.PathLike,
float_parameters: bool = False,
use_compression: bool = False,
) -> None:
"""Write an itk-wasm TransformList to an transform file format.
:param transform: Input transform
:type transform: TransformList
:param serialized_transform: Output transform serialized in the file format.
:type serialized_transform: str
:param float_parameters: Use float for the parameter value type. The default is double.
:type float_parameters: bool
:param use_compression: Use compression in the written file
:type use_compression: bool
:param serialized_transform: Input transform serialized in the file format
:type serialized_transform: os.PathLike
"""
js_module = await js_package.js_module
web_worker = js_resources.web_worker

kwargs = {}
if float_parameters:
kwargs["floatParameters"] = to_js(float_parameters)
if use_compression:
kwargs["useCompression"] = to_js(use_compression)

extension = ''.join(Path(serialized_transform).suffixes)

io = None
if extension in extension_to_transform_io:
func = f"{extension_to_transform_io[extension]}WriteTransform"
io = getattr(js_module, func)
else:
for ioname in transform_io_index:
func = f"{ioname}WriteTransform"
io = getattr(js_module, func)
outputs = await io(to_js(transform), to_js(serialized_transform), webWorker=web_worker, noCopy=True, **kwargs)
outputs_object_map = outputs.as_object_map()
web_worker = outputs_object_map['webWorker']
js_resources.web_worker = web_worker
could_write = to_py(outputs_object_map['couldWrite'])
if could_write:
to_py(outputs_object_map['serializedTransform'])
return

if io is None:
raise RuntimeError(f"Could not find an transform writer for {extension}")

outputs = await io(to_js(transform), to_js(serialized_transform), webWorker=web_worker, noCopy=True, **kwargs)
outputs_object_map = outputs.as_object_map()
web_worker = outputs_object_map['webWorker']
js_resources.web_worker = web_worker
could_write = to_py(outputs_object_map['couldWrite'])

if not could_write:
raise RuntimeError(f"Could not write {serialized_transform}")

to_py(outputs_object_map['serializedTransform'])

async def transformwrite_async(
transform: TransformList,
serialized_transform: os.PathLike,
float_parameters: bool = False,
use_compression: bool = False,
) -> None:
return write_transform_async(transform, serialized_transform, float_parameters=float_parameters, use_compression=use_compression)

transformwrite_async.__doc__ = f"""{write_transform_async.__doc__}
Alias for write_transform.
"""
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
import sys

if sys.version_info < (3,10):
pytest.skip("Skipping pyodide tests on older Python", allow_module_level=True)

import pytest
from pytest_pyodide import run_in_pyodide
from .fixtures import package_wheel, input_data

@pytest.mark.driver_timeout(30)
@run_in_pyodide(packages=['micropip', 'numpy'])
async def test_read_write_mesh_async(selenium, package_wheel, input_data):
import micropip
await micropip.install(package_wheel)
def write_input_data_to_fs(input_data, filename):
with open(filename, 'wb') as fp:
fp.write(input_data[filename])

from pathlib import Path

from itkwasm import TransformParameterizations, FloatTypes
import numpy as np

from itkwasm_transform_io_emscripten import read_transform_async, write_transform_async

def write_input_data_to_fs(input_data, filename):
with open(filename, 'wb') as fp:
fp.write(input_data[filename])

def verify_test_linear_transform(transform_list):
assert len(transform_list) == 1
transform = transform_list[0]
assert transform.transformType.transformParameterization == TransformParameterizations.Affine
assert transform.transformType.parametersValueType == FloatTypes.Float64
assert transform.numberOfParameters == 12
assert transform.numberOfFixedParameters == 3
np.testing.assert_allclose(transform.fixedParameters, np.array([0.0, 0.0, 0.0]))
np.testing.assert_allclose(transform.parameters, np.array([
0.65631490118447, 0.5806583745824385, -0.4817536741017158,
-0.7407986817430222, 0.37486398378429736, -0.5573995934598175,
-0.14306664045479867, 0.7227121458012518, 0.676179776908723,
-65.99999999999997, 69.00000000000004, 32.000000000000036]))

test_file_path = 'LinearTransform.h5'
write_input_data_to_fs(input_data, test_file_path)

assert Path(test_file_path).exists()

transform = await read_transform_async(test_file_path)
verify_test_linear_transform(transform)

test_output_file_path = 'out-LinearTransform.h5'

use_compression = False
await write_transform_async(transform, test_output_file_path, use_compression)

transform = await read_transform_async(test_output_file_path)
verify_test_linear_transform(transform)

0 comments on commit 395beb7

Please sign in to comment.