Skip to content

Commit

Permalink
feat(Python): Add PolyData support
Browse files Browse the repository at this point in the history
  • Loading branch information
thewtex committed Feb 24, 2023
1 parent a55b916 commit 8299724
Show file tree
Hide file tree
Showing 4 changed files with 178 additions and 3 deletions.
129 changes: 127 additions & 2 deletions src/python/itkwasm/itkwasm/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,9 @@
from .binary_stream import BinaryStream
from .text_file import TextFile
from .binary_file import BinaryFile
from .image import Image, ImageType
from .mesh import Mesh, MeshType
from .image import Image
from .mesh import Mesh
from .polydata import PolyData
from .int_types import IntTypes
from .float_types import FloatTypes

Expand Down Expand Up @@ -166,6 +167,76 @@ def run(self, args: List[str], outputs: List[PipelineOutput]=[], inputs: List[Pi
"cellData": f"data:application/vnd.itk.address,0:{cell_data_ptr}",
}
self._set_input_json(mesh_json, index)
elif input_.type == InterfaceTypes.PolyData:
polydata = input_.data
if polydata.numberOfPoints:
pv = bytes(polydata.points)
else:
pv = bytes([])
points_ptr = self._set_input_array(pv, index, 0)

if polydata.verticesBufferSize:
pv = bytes(polydata.vertices)
else:
pv = bytes([])
vertices_ptr = self._set_input_array(pv, index, 1)

if polydata.linesBufferSize:
pv = bytes(polydata.lines)
else:
pv = bytes([])
lines_ptr = self._set_input_array(pv, index, 2)

if polydata.polygonsBufferSize:
pv = bytes(polydata.polygons)
else:
pv = bytes([])
polygons_ptr = self._set_input_array(pv, index, 3)

if polydata.triangleStripsBufferSize:
pv = bytes(polydata.triangleStrips)
else:
pv = bytes([])
triangleStrips_ptr = self._set_input_array(pv, index, 4)

if polydata.numberOfPointPixels:
pv = bytes(polydata.pointData)
else:
pv = bytes([])
pointData_ptr = self._set_input_array(pv, index, 5)

if polydata.numberOfCellPixels:
pv = bytes(polydata.cellData)
else:
pv = bytes([])
cellData_ptr = self._set_input_array(pv, index, 6)

polydata_json = {
"polyDataType": asdict(polydata.polyDataType),
"name": polydata.name,

"numberOfPoints": polydata.numberOfPoints,
"points": f"data:application/vnd.itk.address,0:{points_ptr}",

"verticesBufferSize": polydata.verticesBufferSize,
"vertices": f"data:application/vnd.itk.address,0:{vertices_ptr}",

"linesBufferSize": polydata.linesBufferSize,
"lines": f"data:application/vnd.itk.address,0:{lines_ptr}",

"polygonsBufferSize": polydata.polygonsBufferSize,
"polygons": f"data:application/vnd.itk.address,0:{polygons_ptr}",

"triangleStripsBufferSize": polydata.triangleStripsBufferSize,
"triangleStrips": f"data:application/vnd.itk.address,0:{triangleStrips_ptr}",

"numberOfPointPixels": polydata.numberOfPointPixels,
"pointData": f"data:application/vnd.itk.address,0:{pointData_ptr}",

"numberOfCellPixels": polydata.numberOfCellPixels,
"cellData": f"data:application/vnd.itk.address,0:{cellData_ptr}"
}
self._set_input_json(polydata_json, index)
else:
raise ValueError(f'Unexpected/not yet supported input.type {input_.type}')

Expand Down Expand Up @@ -240,6 +311,60 @@ def run(self, args: List[str], outputs: List[PipelineOutput]=[], inputs: List[Pi
mesh.cellData = _memoryview_to_numpy_array(mesh.meshType.cellPixelComponentType, bytes([]))

output_data = PipelineOutput(InterfaceTypes.Mesh, mesh)
elif output.type == InterfaceTypes.PolyData:
polydata_json = self._get_output_json(index)
polydata = PolyData(**polydata_json)

if polydata.numberOfPoints > 0:
data_ptr = self.output_array_address(0, index, 0)
data_size = self.output_array_size(0, index, 0)
polydata.points = _memoryview_to_numpy_array(FloatTypes.Float32, memoryview(self.memory.buffer)[data_ptr:data_ptr+data_size])
else:
polydata.points = _memoryview_to_numpy_array(FloatTypes.Float32, bytes([]))

if polydata.verticesBufferSize > 0:
data_ptr = self.output_array_address(0, index, 1)
data_size = self.output_array_size(0, index, 1)
polydata.vertices = _memoryview_to_numpy_array(IntTypes.UInt32, memoryview(self.memory.buffer)[data_ptr:data_ptr+data_size])
else:
polydata.vertices = _memoryview_to_numpy_array(IntTypes.UInt32, bytes([]))

if polydata.linesBufferSize > 0:
data_ptr = self.output_array_address(0, index, 2)
data_size = self.output_array_size(0, index, 2)
polydata.lines = _memoryview_to_numpy_array(IntTypes.UInt32, memoryview(self.memory.buffer)[data_ptr:data_ptr+data_size])
else:
polydata.lines = _memoryview_to_numpy_array(IntTypes.UInt32, bytes([]))

if polydata.polygonsBufferSize > 0:
data_ptr = self.output_array_address(0, index, 3)
data_size = self.output_array_size(0, index, 3)
polydata.polygons = _memoryview_to_numpy_array(IntTypes.UInt32, memoryview(self.memory.buffer)[data_ptr:data_ptr+data_size])
else:
polydata.polygons = _memoryview_to_numpy_array(IntTypes.UInt32, bytes([]))

if polydata.triangleStripsBufferSize > 0:
data_ptr = self.output_array_address(0, index, 4)
data_size = self.output_array_size(0, index, 4)
polydata.triangleStrips = _memoryview_to_numpy_array(IntTypes.UInt32, memoryview(self.memory.buffer)[data_ptr:data_ptr+data_size])
else:
polydata.triangleStrips = _memoryview_to_numpy_array(IntTypes.UInt32, bytes([]))

if polydata.numberOfPointPixels > 0:
data_ptr = self.output_array_address(0, index, 5)
data_size = self.output_array_size(0, index, 5)
polydata.pointData = _memoryview_to_numpy_array(polydata.polyDataType.pointPixelComponentType, memoryview(self.memory.buffer)[data_ptr:data_ptr+data_size])
else:
polydata.triangleStrips = _memoryview_to_numpy_array(polydata.polyDataType.pointPixelComponentType, bytes([]))

if polydata.numberOfCellPixels > 0:
data_ptr = self.output_array_address(0, index, 6)
data_size = self.output_array_size(0, index, 6)
polydata.cellData = _memoryview_to_numpy_array(polydata.polyDataType.cellPixelComponentType, memoryview(self.memory.buffer)[data_ptr:data_ptr+data_size])
else:
polydata.triangleStrips = _memoryview_to_numpy_array(polydata.polyDataType.cellPixelComponentType, bytes([]))

output_data = PipelineOutput(InterfaceTypes.PolyData, polydata)

populated_outputs.append(output_data)

Expand Down
Binary file not shown.
Binary file not shown.
52 changes: 51 additions & 1 deletion src/python/itkwasm/test/test_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ def test_pipeline_write_read_image():
outputs = pipeline.run(args, pipeline_outputs, pipeline_inputs)

out_image = itk.image_from_dict(asdict(outputs[0].data))
# To be addresses in itk-5.3.1
# To be addressed in itk-5.3.1
out_image.SetRegions([256,256])

baseline = itk.imread(test_baseline_dir / "test_pipeline_write_read_image.png")
Expand Down Expand Up @@ -157,3 +157,53 @@ def test_pipeline_write_read_mesh():

assert out_mesh.GetNumberOfPoints() == 2903
assert out_mesh.GetNumberOfCells() == 3263

def test_pipeline_write_read_polydata():
pipeline = Pipeline(test_input_dir / 'mesh-to-poly-data.wasi.wasm')

data = test_input_dir / "cow.vtk"
itk_mesh = itk.meshread(data)
itk_mesh_dict = itk.dict_from_mesh(itk_mesh)
itkwasm_mesh = Mesh(**itk_mesh_dict)

pipeline_inputs = [
PipelineInput(InterfaceTypes.Mesh, itkwasm_mesh),
]

pipeline_outputs = [
PipelineOutput(InterfaceTypes.PolyData),
]

args = [
'0',
'0',
'--memory-io',]

outputs = pipeline.run(args, pipeline_outputs, pipeline_inputs)
polydata = outputs[0].data

pipeline = Pipeline(test_input_dir / 'poly-data-to-mesh.wasi.wasm')

pipeline_inputs = [
PipelineInput(InterfaceTypes.PolyData, polydata),
]

pipeline_outputs = [
PipelineOutput(InterfaceTypes.Mesh),
]

args = [
'0',
'0',
'--memory-io',]

outputs = pipeline.run(args, pipeline_outputs, pipeline_inputs)

out_mesh_dict = asdict(outputs[0].data)
# native itk python binaries require uint64
out_mesh_dict['cells'] = out_mesh_dict['cells'].astype(np.uint64)
out_mesh_dict['meshType']['cellComponentType'] = 'uint64'
out_mesh = itk.mesh_from_dict(out_mesh_dict)

assert out_mesh.GetNumberOfPoints() == 2903
assert out_mesh.GetNumberOfCells() == 3263

0 comments on commit 8299724

Please sign in to comment.