Skip to content

Commit

Permalink
feat(bindgen): Python output file support
Browse files Browse the repository at this point in the history
  • Loading branch information
thewtex committed Sep 28, 2023
1 parent e5ecfce commit 2cf3c7e
Show file tree
Hide file tree
Showing 3 changed files with 124 additions and 29 deletions.
9 changes: 8 additions & 1 deletion src/bindgen/python/function-module-return-type.js
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,14 @@ function functionModuleReturnType(interfaceJson) {
jsonOutputs.forEach((value) => {
const canonical = canonicalType(value.type)
const pythonType = interfaceJsonTypeToPythonType.get(canonical)
returnType += `${pythonType}, `
if (value.type.includes('FILE')) {
return
}
if (value.itemsExpectedMax > 1) {
returnType += `List[${pythonType}], `
} else {
returnType += `${pythonType}, `
}
})
returnType = returnType.substring(0, returnType.length - 2)
returnType += "]"
Expand Down
140 changes: 114 additions & 26 deletions src/bindgen/python/wasi/wasi-function-module.js
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,7 @@ import writeIfOverrideNotPresent from '../../write-if-override-not-present.js'

function wasiFunctionModule(interfaceJson, pypackage, modulePath) {
const functionName = snakeCase(interfaceJson.name)
let moduleContent = `# Generated file. Do not edit.
from pathlib import Path, PurePosixPath
let moduleContent = `from pathlib import Path, PurePosixPath
import os
from typing import Dict, Tuple, Optional, List, Any
Expand All @@ -34,20 +32,65 @@ from itkwasm import (
const returnType = functionModuleReturnType(interfaceJson)
const docstring = functionModuleDocstring(interfaceJson)

let pipelineOutputFilePrep = ''
interfaceJson.outputs.forEach((output) => {
if (interfaceJsonTypeToInterfaceType.has(output.type)) {
const interfaceType = interfaceJsonTypeToInterfaceType.get(output.type)
const isArray = output.itemsExpectedMax > 1
if (interfaceType.includes('File')) {
const snake = snakeCase(output.name)
if (isArray) {
pipelineOutputFilePrep += ` ${snake}_pipeline_outputs = [PipelineOutput(InterfaceTypes.${interfaceType}, ${interfaceType}(PurePosixPath(p))) for p in ${snake}]\n`
}
}
}
})

let pipelineOutputs = ''
let haveArray = false
interfaceJson.outputs.forEach((output) => {
if (interfaceJsonTypeToInterfaceType.has(output.type)) {
const interfaceType = interfaceJsonTypeToInterfaceType.get(output.type)
const isArray = output.itemsExpectedMax > 1
switch (interfaceType) {
case "TextFile":
case "BinaryFile":
pipelineOutputs += ` PipelineOutput(InterfaceTypes.${interfaceType}, ${interfaceType}(PurePosixPath(${snakeCase(output.name)}))),\n`
if (isArray) {
haveArray = true
pipelineOutputs += ` *${snakeCase(output.name)}_pipeline_outputs,\n`
} else {
pipelineOutputs += ` PipelineOutput(InterfaceTypes.${interfaceType}, ${interfaceType}(PurePosixPath(${snakeCase(output.name)}))),\n`
}
break
default:
pipelineOutputs += ` PipelineOutput(InterfaceTypes.${interfaceType}),\n`
}
}
})
let pipelineOutputIndices = ''
if (haveArray) {
pipelineOutputIndices += ` output_index = 0\n`
interfaceJson.outputs.forEach((output) => {
if (interfaceJsonTypeToInterfaceType.has(output.type)) {
const interfaceType = interfaceJsonTypeToInterfaceType.get(output.type)
if (interfaceType.includes('File')) {
const snake = snakeCase(output.name)
const isArray = output.itemsExpectedMax > 1
if (isArray) {
pipelineOutputIndices += ` ${snake}_start = output_index\n`
pipelineOutputIndices += ` output_index += len(${snake})\n`
pipelineOutputIndices += ` ${snake}_end = output_index\n`
} else {
pipelineOutputIndices += ` ${snake}_index = output_index\n`
pipelineOutputIndices += ` output_index += 1\n`
}
} else if (!interfaceType.includes('File')) {
pipelineOutputIndices += ` ${snake}_index = output_index\n`
pipelineOutputIndices += ` output_index += 1\n`
}
}
})
}

let pipelineInputs = ''
interfaceJson['inputs'].forEach((input) => {
Expand Down Expand Up @@ -90,18 +133,33 @@ from itkwasm import (
let outputCount = 0
args += " # Outputs\n"
interfaceJson.outputs.forEach((output) => {
const snake = snakeCase(output.name)
if (interfaceJsonTypeToInterfaceType.has(output.type)) {
const interfaceType = interfaceJsonTypeToInterfaceType.get(output.type)
const name = interfaceType.includes('File') ? `str(PurePosixPath(${snakeCase(output.name)}))` : `'${outputCount.toString()}'`
args += ` args.append(${name})\n`
const isArray = output.itemsExpectedMax > 1
let name = ` ${snake}_name = '${outputCount.toString()}'\n`
if (interfaceType.includes('File')) {
if (isArray) {
name = ''
} else {
name = ` ${snake}_name = str(PurePosixPath(${snake}))\n`
}
}
args += name
if (isArray) {
args += ` args.extend([str(PurePosixPath(p)) for p in ${snake}])\n`
} else {
args += ` args.append(${snake}_name)\n`
}
args += '\n'
outputCount++
} else {
const snake = snakeCase(output.name)
args += ` args.append(str(${snake}))\n`
}
})

args += " # Options\n"
args += ` input_count = len(pipeline_inputs)\n`
interfaceJson.parameters.forEach((parameter) => {
if (parameter.name === 'memory-io' || parameter.name === 'version') {
// Internal
Expand All @@ -126,14 +184,14 @@ from itkwasm import (
args += ` args.append(input_file)\n`
} else if (interfaceType.includes('Stream')) {
// for streams
args += ` input_count_string = str(len(pipeline_inputs))\n`
args += ` pipeline_inputs.append(PipelineInput(InterfaceTypes.${interfaceType}, ${interfaceType}(value)))\n`
args += ` args.append(input_count_spring)\n`
args += ` args.append(str(input_count))\n`
args += ` input_count += 1\n`
} else {
// Image, Mesh, PolyData, JsonCompatible
args += ` input_count_string = str(len(pipeline_inputs))\n`
args += ` pipeline_inputs.append(PipelineInput(InterfaceTypes.${interfaceType}, value))\n`
args += ` args.append(input_count_string)\n`
args += ` args.append(str(input_count))\n`
args += ` input_count += 1\n`
}
} else {
args += ` args.append(str(value))\n`
Expand All @@ -150,16 +208,16 @@ from itkwasm import (
args += ` args.append(input_file)\n`
} else if (interfaceType.includes('Stream')) {
// for streams
args += ` input_count_string = str(len(pipeline_inputs))\n`
args += ` pipeline_inputs.append(PipelineInput(InterfaceTypes.${interfaceType}, ${interfaceType}(${snake})))\n`
args += ` args.append('--${parameter.name}')\n`
args += ` args.append(input_count_string)\n`
args += ` args.append(str(input_count))\n`
args += ` input_count += 1\n`
} else {
// Image, Mesh, PolyData, JsonCompatible
args += ` input_count_string = str(len(pipeline_inputs))\n`
args += ` pipeline_inputs.append(PipelineInput(InterfaceTypes.${interfaceType}, ${snake}))\n`
args += ` args.append('--${parameter.name}')\n`
args += ` args.append(input_count_string)\n`
args += ` args.append(str(input_count))\n`
args += ` input_count += 1\n`
}
} else {
args += ` if ${snake}:\n`
Expand Down Expand Up @@ -192,35 +250,65 @@ from itkwasm import (
case "float":
return `float(${value})`
case "Any":
return `${value}.data.data`
return `${value}.data`
default:
return `${value}.data`
}
}
outputCount = 0
const jsonOutputs = interfaceJson['outputs']
if (jsonOutputs.length > 1) {
const numOutputs = interfaceJson.outputs.filter(o => !o.type.includes('FILE')).length
if (numOutputs > 1) {
postOutput += ' result = (\n'
jsonOutputs.forEach((value, index) => {
const outputValue = `outputs[${index}]`
postOutput += ` ${toPythonType(value.type, outputValue)},\n`
} else if (numOutputs === 1){
postOutput = ' result = '
}

if (numOutputs > 0) {
// const outputValue = "outputs[0]"
// postOutput = ` result = ${toPythonType(jsonOutputs[0].type, outputValue)}\n`
const indent = numOutputs > 1 ? ' ' : ''
const comma = numOutputs > 1 ? ',' : ''
jsonOutputs.forEach((value) => {
if (value.type.includes('FILE')) {
outputCount++
return
}
const snake = snakeCase(value.name)
const outputIndex = haveArray ? `${snake}_index` : outputCount.toString()
if (haveArray) {
const isArray = value.itemsExpectedMax > 1
if (isArray) {
const outputValue = `outputs[${snake}_start:${snake}_end]`
postOutput += `${indent}*[${toPythonType(value.type, 'v')} for v in ${outputValue}]${comma}\n`
} else {
const outputValue = `outputs[${snake}_index]`
postOutput += `${indent}${toPythonType(value.type, outputValue)}${comma}\n`
}
} else {
const outputValue = `outputs[${outputIndex}]`
postOutput += `${indent}${toPythonType(value.type, outputValue)}${comma}\n`
}
outputCount++
})
}
if (numOutputs > 1) {
postOutput += ' )\n'
} else {
const outputValue = "outputs[0]"
postOutput = ` result = ${toPythonType(jsonOutputs[0].type, outputValue)}\n`
}
postOutput += ' return result\n'
if (numOutputs > 0) {
postOutput += ' return result\n'
}

moduleContent += `def ${functionName}(
${functionArgs}) -> ${returnType}:
${docstring}
global _pipeline
if _pipeline is None:
_pipeline = Pipeline(file_resources('${pypackage}').joinpath(Path('wasm_modules') / Path('${interfaceJson.name}.wasi.wasm')))
${pipelineOutputFilePrep}
pipeline_outputs: List[PipelineOutput] = [
${pipelineOutputs} ]
${pipelineOutputIndices}
pipeline_inputs: List[PipelineInput] = [
${pipelineInputs} ]
Expand Down
4 changes: 2 additions & 2 deletions src/bindgen/typescript/function-module.js
Original file line number Diff line number Diff line change
Expand Up @@ -306,12 +306,12 @@ function functionModule (srcOutputDir, forNode, interfaceJson, modulePascalCase,
if (interfaceJsonTypeToInterfaceType.has(output.type)) {
const interfaceType = interfaceJsonTypeToInterfaceType.get(output.type)
let name = ` const ${camel}Name = '${outputCount.toString()}'\n`
const isArray = output.itemsExpectedMax > 1 ? '[]' : ''
const isArray = output.itemsExpectedMax > 1
if (interfaceType.includes('File')) {
if (isArray) {
name = ''
} else {
name = ` const ${camel}Name = ${camel}?? '${camel}'\n`
name = ` const ${camel}Name = ${camel}\n`
}
}
functionContent += name
Expand Down

0 comments on commit 2cf3c7e

Please sign in to comment.