From 2cf3c7eb68e8a6bea710467a2f1783ba9833beaa Mon Sep 17 00:00:00 2001 From: Matt McCormick Date: Wed, 27 Sep 2023 21:34:47 -0400 Subject: [PATCH] feat(bindgen): Python output file support --- .../python/function-module-return-type.js | 9 +- .../python/wasi/wasi-function-module.js | 140 ++++++++++++++---- src/bindgen/typescript/function-module.js | 4 +- 3 files changed, 124 insertions(+), 29 deletions(-) diff --git a/src/bindgen/python/function-module-return-type.js b/src/bindgen/python/function-module-return-type.js index 44ee43b8d..407dae055 100644 --- a/src/bindgen/python/function-module-return-type.js +++ b/src/bindgen/python/function-module-return-type.js @@ -9,7 +9,14 @@ function functionModuleReturnType(interfaceJson) { jsonOutputs.forEach((value) => { const canonical = canonicalType(value.type) const pythonType = interfaceJsonTypeToPythonType.get(canonical) - returnType += `${pythonType}, ` + if (value.type.includes('FILE')) { + return + } + if (value.itemsExpectedMax > 1) { + returnType += `List[${pythonType}], ` + } else { + returnType += `${pythonType}, ` + } }) returnType = returnType.substring(0, returnType.length - 2) returnType += "]" diff --git a/src/bindgen/python/wasi/wasi-function-module.js b/src/bindgen/python/wasi/wasi-function-module.js index d7caed5ed..934a6a0d1 100644 --- a/src/bindgen/python/wasi/wasi-function-module.js +++ b/src/bindgen/python/wasi/wasi-function-module.js @@ -13,9 +13,7 @@ import writeIfOverrideNotPresent from '../../write-if-override-not-present.js' function wasiFunctionModule(interfaceJson, pypackage, modulePath) { const functionName = snakeCase(interfaceJson.name) - let moduleContent = `# Generated file. Do not edit. - -from pathlib import Path, PurePosixPath + let moduleContent = `from pathlib import Path, PurePosixPath import os from typing import Dict, Tuple, Optional, List, Any @@ -34,20 +32,65 @@ from itkwasm import ( const returnType = functionModuleReturnType(interfaceJson) const docstring = functionModuleDocstring(interfaceJson) + let pipelineOutputFilePrep = '' + interfaceJson.outputs.forEach((output) => { + if (interfaceJsonTypeToInterfaceType.has(output.type)) { + const interfaceType = interfaceJsonTypeToInterfaceType.get(output.type) + const isArray = output.itemsExpectedMax > 1 + if (interfaceType.includes('File')) { + const snake = snakeCase(output.name) + if (isArray) { + pipelineOutputFilePrep += ` ${snake}_pipeline_outputs = [PipelineOutput(InterfaceTypes.${interfaceType}, ${interfaceType}(PurePosixPath(p))) for p in ${snake}]\n` + } + } + } + }) + let pipelineOutputs = '' + let haveArray = false interfaceJson.outputs.forEach((output) => { if (interfaceJsonTypeToInterfaceType.has(output.type)) { const interfaceType = interfaceJsonTypeToInterfaceType.get(output.type) + const isArray = output.itemsExpectedMax > 1 switch (interfaceType) { case "TextFile": case "BinaryFile": - pipelineOutputs += ` PipelineOutput(InterfaceTypes.${interfaceType}, ${interfaceType}(PurePosixPath(${snakeCase(output.name)}))),\n` + if (isArray) { + haveArray = true + pipelineOutputs += ` *${snakeCase(output.name)}_pipeline_outputs,\n` + } else { + pipelineOutputs += ` PipelineOutput(InterfaceTypes.${interfaceType}, ${interfaceType}(PurePosixPath(${snakeCase(output.name)}))),\n` + } break default: pipelineOutputs += ` PipelineOutput(InterfaceTypes.${interfaceType}),\n` } } }) + let pipelineOutputIndices = '' + if (haveArray) { + pipelineOutputIndices += ` output_index = 0\n` + interfaceJson.outputs.forEach((output) => { + if (interfaceJsonTypeToInterfaceType.has(output.type)) { + const interfaceType = interfaceJsonTypeToInterfaceType.get(output.type) + if (interfaceType.includes('File')) { + const snake = snakeCase(output.name) + const isArray = output.itemsExpectedMax > 1 + if (isArray) { + pipelineOutputIndices += ` ${snake}_start = output_index\n` + pipelineOutputIndices += ` output_index += len(${snake})\n` + pipelineOutputIndices += ` ${snake}_end = output_index\n` + } else { + pipelineOutputIndices += ` ${snake}_index = output_index\n` + pipelineOutputIndices += ` output_index += 1\n` + } + } else if (!interfaceType.includes('File')) { + pipelineOutputIndices += ` ${snake}_index = output_index\n` + pipelineOutputIndices += ` output_index += 1\n` + } + } + }) + } let pipelineInputs = '' interfaceJson['inputs'].forEach((input) => { @@ -90,18 +133,33 @@ from itkwasm import ( let outputCount = 0 args += " # Outputs\n" interfaceJson.outputs.forEach((output) => { + const snake = snakeCase(output.name) if (interfaceJsonTypeToInterfaceType.has(output.type)) { const interfaceType = interfaceJsonTypeToInterfaceType.get(output.type) - const name = interfaceType.includes('File') ? `str(PurePosixPath(${snakeCase(output.name)}))` : `'${outputCount.toString()}'` - args += ` args.append(${name})\n` + const isArray = output.itemsExpectedMax > 1 + let name = ` ${snake}_name = '${outputCount.toString()}'\n` + if (interfaceType.includes('File')) { + if (isArray) { + name = '' + } else { + name = ` ${snake}_name = str(PurePosixPath(${snake}))\n` + } + } + args += name + if (isArray) { + args += ` args.extend([str(PurePosixPath(p)) for p in ${snake}])\n` + } else { + args += ` args.append(${snake}_name)\n` + } + args += '\n' outputCount++ } else { - const snake = snakeCase(output.name) args += ` args.append(str(${snake}))\n` } }) args += " # Options\n" + args += ` input_count = len(pipeline_inputs)\n` interfaceJson.parameters.forEach((parameter) => { if (parameter.name === 'memory-io' || parameter.name === 'version') { // Internal @@ -126,14 +184,14 @@ from itkwasm import ( args += ` args.append(input_file)\n` } else if (interfaceType.includes('Stream')) { // for streams - args += ` input_count_string = str(len(pipeline_inputs))\n` args += ` pipeline_inputs.append(PipelineInput(InterfaceTypes.${interfaceType}, ${interfaceType}(value)))\n` - args += ` args.append(input_count_spring)\n` + args += ` args.append(str(input_count))\n` + args += ` input_count += 1\n` } else { // Image, Mesh, PolyData, JsonCompatible - args += ` input_count_string = str(len(pipeline_inputs))\n` args += ` pipeline_inputs.append(PipelineInput(InterfaceTypes.${interfaceType}, value))\n` - args += ` args.append(input_count_string)\n` + args += ` args.append(str(input_count))\n` + args += ` input_count += 1\n` } } else { args += ` args.append(str(value))\n` @@ -150,16 +208,16 @@ from itkwasm import ( args += ` args.append(input_file)\n` } else if (interfaceType.includes('Stream')) { // for streams - args += ` input_count_string = str(len(pipeline_inputs))\n` args += ` pipeline_inputs.append(PipelineInput(InterfaceTypes.${interfaceType}, ${interfaceType}(${snake})))\n` args += ` args.append('--${parameter.name}')\n` - args += ` args.append(input_count_string)\n` + args += ` args.append(str(input_count))\n` + args += ` input_count += 1\n` } else { // Image, Mesh, PolyData, JsonCompatible - args += ` input_count_string = str(len(pipeline_inputs))\n` args += ` pipeline_inputs.append(PipelineInput(InterfaceTypes.${interfaceType}, ${snake}))\n` args += ` args.append('--${parameter.name}')\n` - args += ` args.append(input_count_string)\n` + args += ` args.append(str(input_count))\n` + args += ` input_count += 1\n` } } else { args += ` if ${snake}:\n` @@ -192,24 +250,54 @@ from itkwasm import ( case "float": return `float(${value})` case "Any": - return `${value}.data.data` + return `${value}.data` default: return `${value}.data` } } + outputCount = 0 const jsonOutputs = interfaceJson['outputs'] - if (jsonOutputs.length > 1) { + const numOutputs = interfaceJson.outputs.filter(o => !o.type.includes('FILE')).length + if (numOutputs > 1) { postOutput += ' result = (\n' - jsonOutputs.forEach((value, index) => { - const outputValue = `outputs[${index}]` - postOutput += ` ${toPythonType(value.type, outputValue)},\n` + } else if (numOutputs === 1){ + postOutput = ' result = ' + } + + if (numOutputs > 0) { + // const outputValue = "outputs[0]" + // postOutput = ` result = ${toPythonType(jsonOutputs[0].type, outputValue)}\n` + const indent = numOutputs > 1 ? ' ' : '' + const comma = numOutputs > 1 ? ',' : '' + jsonOutputs.forEach((value) => { + if (value.type.includes('FILE')) { + outputCount++ + return + } + const snake = snakeCase(value.name) + const outputIndex = haveArray ? `${snake}_index` : outputCount.toString() + if (haveArray) { + const isArray = value.itemsExpectedMax > 1 + if (isArray) { + const outputValue = `outputs[${snake}_start:${snake}_end]` + postOutput += `${indent}*[${toPythonType(value.type, 'v')} for v in ${outputValue}]${comma}\n` + } else { + const outputValue = `outputs[${snake}_index]` + postOutput += `${indent}${toPythonType(value.type, outputValue)}${comma}\n` + } + } else { + const outputValue = `outputs[${outputIndex}]` + postOutput += `${indent}${toPythonType(value.type, outputValue)}${comma}\n` + } + outputCount++ }) + } + if (numOutputs > 1) { postOutput += ' )\n' - } else { - const outputValue = "outputs[0]" - postOutput = ` result = ${toPythonType(jsonOutputs[0].type, outputValue)}\n` } - postOutput += ' return result\n' + if (numOutputs > 0) { + postOutput += ' return result\n' + } moduleContent += `def ${functionName}( ${functionArgs}) -> ${returnType}: @@ -217,10 +305,10 @@ ${functionArgs}) -> ${returnType}: global _pipeline if _pipeline is None: _pipeline = Pipeline(file_resources('${pypackage}').joinpath(Path('wasm_modules') / Path('${interfaceJson.name}.wasi.wasm'))) - +${pipelineOutputFilePrep} pipeline_outputs: List[PipelineOutput] = [ ${pipelineOutputs} ] - +${pipelineOutputIndices} pipeline_inputs: List[PipelineInput] = [ ${pipelineInputs} ] diff --git a/src/bindgen/typescript/function-module.js b/src/bindgen/typescript/function-module.js index 09700fc13..2666eec0a 100644 --- a/src/bindgen/typescript/function-module.js +++ b/src/bindgen/typescript/function-module.js @@ -306,12 +306,12 @@ function functionModule (srcOutputDir, forNode, interfaceJson, modulePascalCase, if (interfaceJsonTypeToInterfaceType.has(output.type)) { const interfaceType = interfaceJsonTypeToInterfaceType.get(output.type) let name = ` const ${camel}Name = '${outputCount.toString()}'\n` - const isArray = output.itemsExpectedMax > 1 ? '[]' : '' + const isArray = output.itemsExpectedMax > 1 if (interfaceType.includes('File')) { if (isArray) { name = '' } else { - name = ` const ${camel}Name = ${camel}?? '${camel}'\n` + name = ` const ${camel}Name = ${camel}\n` } } functionContent += name