diff --git a/ONNX2MPSX.py b/ONNX2MPSX.py index 06ccd3c..0e3b9f8 100644 --- a/ONNX2MPSX.py +++ b/ONNX2MPSX.py @@ -50,6 +50,7 @@ def convert_onnx_to_mpsx(model, halfs): onnx.helper.make_graph( nodes=model.graph.node, name=model.graph.name, + value_info=model.graph.value_info, inputs=_fp32_to_fp16_info( model.graph.input) if halfs else model.graph.input, outputs=_fp32_to_fp16_info( diff --git a/Sources/MPSX/Foundation/DSL.swift b/Sources/MPSX/Foundation/DSL.swift index bf95460..47d7f25 100644 --- a/Sources/MPSX/Foundation/DSL.swift +++ b/Sources/MPSX/Foundation/DSL.swift @@ -412,8 +412,8 @@ public extension MPSGraphTensor { self, size: [height, width].nsnumbers, mode: mode, - centerResult: true, - alignCorners: false, + centerResult: false, + alignCorners: true, layout: layout, name: nil ) diff --git a/Sources/MPSX/Foundation/Utilities.swift b/Sources/MPSX/Foundation/Utilities.swift index 68a966b..1eef1fc 100644 --- a/Sources/MPSX/Foundation/Utilities.swift +++ b/Sources/MPSX/Foundation/Utilities.swift @@ -468,3 +468,11 @@ public extension MPSGraphTensor { shape?.map(\.intValue) ?? [] } } + +public extension MPSCommandBuffer { + func commitAndWait() { + let _rootCommandBuffer = rootCommandBuffer + commitAndContinue() + _rootCommandBuffer.waitUntilCompleted() + } +} diff --git a/Sources/MPSX/ONNX/Experimental/OnnxCustomNode.swift b/Sources/MPSX/ONNX/Experimental/OnnxCustomNode.swift new file mode 100644 index 0000000..f40c07e --- /dev/null +++ b/Sources/MPSX/ONNX/Experimental/OnnxCustomNode.swift @@ -0,0 +1,31 @@ +import MetalPerformanceShadersGraph + +public protocol OnnxAttributesProvider { + func attr(s name: String) -> String + func attr(f name: String) -> Float? + func attr(floats name: String) -> [Float]? + func attr(i name: String) -> Int? + func attr(ints name: String) -> [Int]? +} + +public protocol OnnxCustomNode { + func preprocess( + inputTensor: MPSGraphTensor, + inputName: String, + graph: MPSGraph + ) -> MPSGraphTensor + + func postprocess( + outputName: String, + outputShape: [Int], + requiredDataType: MPSDataType, + graph: MPSGraph + ) -> (placeholder: MPSGraphTensor, tensor: MPSGraphTensor) + + func eval( + inputs: [MPSGraphTensorData], + outputShapes: [[Int]], + attributesProvider: OnnxAttributesProvider, + in commandBuffer: MPSCommandBuffer + ) throws -> [MPSGraphTensorData] +} diff --git a/Sources/MPSX/ONNX/Experimental/OnnxModel+Split.swift b/Sources/MPSX/ONNX/Experimental/OnnxModel+Split.swift new file mode 100644 index 0000000..c5420b7 --- /dev/null +++ b/Sources/MPSX/ONNX/Experimental/OnnxModel+Split.swift @@ -0,0 +1,80 @@ +extension OnnxModel { + /// Split the graph into subgraphs. + /// - Parameter splitOps: The names of the operators in the graph by which you want to split + /// - Returns: A map where key is the name of the node (node opType exists in splitOps) and value is an array of input tensors that are "alive" at the time of the split. + func split(by splitOps: Set) -> [String: Set] { + struct NodeCounter { + /// How many times the output of a node will be used as an input + var readCount: Int + /// Node lifetime: lower bound == first occurrence in the graph, upper bound == last occurrence as input in the graph. + var range: ClosedRange + } + + let graph = proto.graph + + var counters: [String: NodeCounter] = [:] + counters.reserveCapacity(graph.node.count) + + var nodeIndices: [String: Int] = [:] + nodeIndices.reserveCapacity(graph.node.count) + + // STEP 1: calculate counters for every node from start to end + + for (index, node) in graph.node.enumerated() { + for outputName in node.output { + nodeIndices[outputName] = index + } + + for inputName in node.input { + // only runtime tensors + + guard initializer[inputName] == nil else { + continue + } + + let inputIndex = nodeIndices[inputName, default: 0] + + let counter = counters[inputName, default: .init(readCount: 0, range: inputIndex ... inputIndex)] + + counters[inputName] = .init( + readCount: counter.readCount + 1, + range: counter.range.lowerBound ... index + ) + } + } + + // STEP 2: now we have counters for every node and can split graph into subgraphs using this info. + + var subgraphs: [String: Set] = [:] + + for (index, node) in graph.node.enumerated() { + for inputName in node.input { + guard initializer[inputName] == nil, + var counter = counters[inputName] + else { + continue + } + + if counter.readCount > 0 { + counter.readCount -= 1 + + counters[inputName] = counter + } else { + counters[inputName] = nil + } + } + + // split + + if splitOps.contains(node.opType) { + // TODO: optimize using some tree data structure/sorted collection + + let inputs = counters.filter { $0.value.range.contains(index) } + + subgraphs[node.name] = Set(inputs.map(\.key)).subtracting(Set(node.output)) + } + } + + return subgraphs + } +} diff --git a/Sources/MPSX/ONNX/Experimental/OnnxPipeline.swift b/Sources/MPSX/ONNX/Experimental/OnnxPipeline.swift new file mode 100644 index 0000000..cc799c9 --- /dev/null +++ b/Sources/MPSX/ONNX/Experimental/OnnxPipeline.swift @@ -0,0 +1,410 @@ +import MetalPerformanceShadersGraph + +extension Onnx_NodeProto: OnnxAttributesProvider {} + +/// Experimental version of OnnxGraph with support for custom implementations for unknown operators +public final class OnnxPipeline { + // MARK: Lifecycle + + /// Initialize pipeline instance + /// - Parameters: + /// - model: onnx model + /// - device: metal device for graph compilation + /// - customNodes: user-provided implementations for unknown operations + /// - config: graph building configuration + public init( + model: OnnxModel, + device: MTLDevice, + customNodes: [String: OnnxCustomNode], + config: OnnxGraphConfig = .init() + ) throws { + let onnxGraph = model.proto.graph + + // required for custom operations + + let valueInfo: [String: Onnx_TypeProto.Tensor] = onnxGraph.valueInfo.reduce(into: [:]) { + if case let .tensorType(tensor) = $1.type.value { + $0[$1.name] = tensor // this object contains information about shape, datatype, etc. + } + } + + let splitResult = customNodes.isEmpty ? [:] : model.split(by: Set(customNodes.keys)) + + let populateTensorsOnTheFly = !splitResult.isEmpty + + let graphOptions: MPSGraphOptions = .none + let compilationDescriptor: MPSGraphCompilationDescriptor? = nil + let tensorsDataType = config.tensorsDataType.mpsDataType // fp16 or fp32 + + var constants = model.initializer // weights, biases, etc. + + var pipeline: [PipelineStep] = [] + + var mpsGraph = MPSGraph(options: graphOptions) // current mpsgraph -> will be overwritten if split result is not empty + var mpsTensors: [String: MPSGraphTensor] // table of current mpsgraph tensors for faster lookup by name + + if populateTensorsOnTheFly { + mpsTensors = [:] + } else { + mpsTensors = try model.initializer.mapValues { + try mpsGraph.constant($0, targetDataType: tensorsDataType) + } + } + + // create placeholders + + for input in model.inputs { + let options = config.inputs[input.name] + + let shape = input.shape.enumerated().map { + options?.dims?[$0.offset] ?? Int($0.element) + } + + mpsTensors[input.name] = mpsGraph.input( + shape: shape.nsnumbers, + dataType: tensorsDataType, + valuesRange: options?.valuesRange, + name: input.name + ) + } + + // iterate over graph: onnx guarantees graph nodes are sorted topologically + + for node in onnxGraph.node { + if populateTensorsOnTheFly { + for input in node.input { + if let tensor = model.initializer[input], mpsTensors[input] == nil { + mpsTensors[input] = try mpsGraph.constant(tensor, targetDataType: tensorsDataType) + } + } + } + + let success = try mpsGraph.onnx( + node: node, + optimizedForMPS: model.optimizedForMPS, + tensorsDataType: tensorsDataType, + tensors: &mpsTensors, + constants: &constants + ) + + if success { + continue // known operation + } + + // lookup for a user-provided implementation + + guard let customNode = customNodes[node.opType], + let subgraphOutputs = splitResult[node.name] + else { + throw OnnxError.unsupportedOperator(node.opType) // custom implementation not found + } + + // generate node inputs using current mpsgraph instance and tensor table + + try node.input.forEach { + guard let inputTensor = mpsTensors[$0] else { + throw OnnxError.invalidModel(reason: "Tensor named \($0) not found") + } + + mpsTensors[$0] = customNode.preprocess( + inputTensor: inputTensor, + inputName: $0, + graph: mpsGraph + ) + } + + // finalize current mps graph + + let outputTensors = subgraphOutputs.reduce(into: [:]) { + $0[$1] = mpsTensors[$1] + } + + let subgraph = MPSCompiledGraph( + compilationDescriptor: compilationDescriptor, + device: device, + graph: mpsGraph, + outputTensors: outputTensors + ) + + // create next mpsgraph instance and tensor table + + mpsGraph = .init(options: graphOptions) + mpsTensors = [:] + + var customNames: [String: String] = [:] + + outputTensors.forEach { + let placeholder = mpsGraph.placeholder(shape: $1.shape, dataType: $1.dataType, name: $0) + + mpsTensors[$0] = placeholder + + // ⚠️ mps can change the user-defined name: ex. dots (.) will be replaced with underscores (_) + + if placeholder.operation.name != $0 { + customNames[$0] = placeholder.operation.name + } + } + + // generate node outputs and placeholders + + let outputShapes = try node.output.map { + guard let tensorInfo = valueInfo[$0], tensorInfo.hasShape else { + throw OnnxError.invalidModel(reason: "Shaped tensor named \($0) not found") + } + + let shape = tensorInfo.shape.dim.map { Int($0.dimValue) } + + let (placeholder, tensor) = customNode.postprocess( + outputName: $0, + outputShape: shape, + requiredDataType: tensorsDataType, + graph: mpsGraph + ) + + if placeholder.operation.name != $0 { + customNames[$0] = placeholder.operation.name + } + + guard (tensor.shape ?? []).map(\.intValue) == shape else { + throw OnnxError.incorrectCustomNodeImplementation( + opType: node.opType, + reason: "Shape of tensor named \($0) does not match the required \(shape)" + ) + } + + mpsTensors[$0] = tensor + + return shape + } + + // pipeline++ + + pipeline.append(.graph(subgraph)) + pipeline.append(.custom(.init(proto: node, outputShapes: outputShapes, customNames: customNames))) + } + + // final step: setup onnx graph outputs + + let finalGraph = try MPSCompiledGraph( + compilationDescriptor: compilationDescriptor, + device: device, + graph: mpsGraph, + outputTensors: model.outputs.reduce(into: [:]) { + guard let tensor = mpsTensors[$1] else { + throw OnnxError.invalidModel(reason: "Tensor named \($1) not found") + } + + $0[$1] = mpsGraph.output( + tensor: tensor, + valuesRange: config.outputs[$1]?.valuesRange + ) + } + ) + + pipeline.append(.graph(finalGraph)) + + firstGraph = pipeline.first!.graph! + lastGraph = pipeline.last!.graph! + + self.pipeline = pipeline + self.customNodes = customNodes + } + + // MARK: Public + + public var inputs: [String: MPSGraphTensor] { + firstGraph.inputs + } + + public var outputs: [String: MPSGraphTensor] { + lastGraph.outputs + } + + public var input: MPSGraphTensor { + firstGraph.input + } + + public var output: MPSGraphTensor { + lastGraph.output + } + + /// single input -> single output + public func callAsFunction( + _ input: MPSGraphTensorData, + in commandBuffer: MPSCommandBuffer + ) throws -> MPSGraphTensorData { + try encode(inputs: [firstGraph.inputs.first!.key: input], in: commandBuffer).first!.value + } + + /// multiple inputs -> single output + public func callAsFunction( + _ inputs: [String: MPSGraphTensorData], + in commandBuffer: MPSCommandBuffer + ) throws -> MPSGraphTensorData { + try encode(inputs: inputs, in: commandBuffer).first!.value + } + + /// single input -> multiple outputs + public func callAsFunction( + _ input: MPSGraphTensorData, + in commandBuffer: MPSCommandBuffer + ) throws -> [String: MPSGraphTensorData] { + try encode(inputs: [firstGraph.inputs.first!.key: input], in: commandBuffer) + } + + /// multiple inputs -> multiple outputs + public func callAsFunction( + _ inputs: [String: MPSGraphTensorData], + in commandBuffer: MPSCommandBuffer + ) throws -> [String: MPSGraphTensorData] { + try encode(inputs: inputs, in: commandBuffer) + } + + // MARK: Private + + private enum PipelineStep { + case graph(MPSCompiledGraph) + case custom(CustomNode) + + // MARK: Internal + + struct CustomNode { + let proto: Onnx_NodeProto + let outputShapes: [[Int]] + let customNames: [String: String] + } + + var graph: MPSCompiledGraph? { + switch self { + case let .graph(value): + return value + case .custom: + return nil + } + } + } + + private let firstGraph: MPSCompiledGraph + private let lastGraph: MPSCompiledGraph + private let pipeline: [PipelineStep] + + private let customNodes: [String: OnnxCustomNode] + + private func encode( + inputs: [String: MPSGraphTensorData], + in commandBuffer: MPSCommandBuffer + ) throws -> [String: MPSGraphTensorData] { + var outputs = inputs + + for step in pipeline { + try autoreleasepool { + switch step { + case let .graph(graph): + outputs = graph(outputs, in: commandBuffer) + case let .custom(node): + let nodeOutputs = try self.customNodes[node.proto.opType]!.eval( + inputs: node.proto.input.map { + guard let output = outputs[$0] else { + throw OnnxError.incorrectCustomNodeImplementation( + opType: node.proto.opType, + reason: "Input named \($0) not found" + ) + } + return output + }, + outputShapes: node.outputShapes, + attributesProvider: node.proto, + in: commandBuffer + ) + + guard nodeOutputs.count == node.proto.output.count else { + throw OnnxError.incorrectCustomNodeImplementation( + opType: node.proto.opType, + reason: "Unexpected number of outputs" + ) + } + + // use customNames as onnx -> mps name table + + outputs = outputs.reduce(into: [:]) { + $0[node.customNames[$1.key, default: $1.key]] = $1.value + } + + zip(node.proto.output, nodeOutputs).forEach { outputs[node.customNames[$0, default: $0]] = $1 } + } + } + } + + return outputs + } +} + +public extension OnnxPipeline { + func warmUp(in commandBuffer: MPSCommandBuffer) { + for step in pipeline { + guard let graph = step.graph else { + continue + } + + let randomInputs: [String: MPSGraphTensorData] = MPSCompiledGraph(device: commandBuffer.device) { g in + graph.inputs.mapValues { t in + g.randomUniformTensor(withShape: t.shape ?? [], name: nil).cast(to: t.dataType) + } + }([:], in: commandBuffer) + + let _: [String: MPSGraphTensorData] = graph(randomInputs, in: commandBuffer) + } + } + + func inputsFromTextures(_ dict: [String: MTLTexture], in commandBuffer: MPSCommandBuffer) -> [String: MPSGraphTensorData] { + dict.reduce(into: [:]) { + if let input = self.inputs[$1.key] { + $0[$1.key] = .NCHW(texture: $1.value, matching: input, in: commandBuffer) + } else { + assertionFailure("no input with key \($1.key)") + } + } + } + + func imageFrom( + _ inputTexture: MTLTexture, + pixelFormat: MTLPixelFormat? = nil, + converter: MPSImageConversion? = nil, + in commandBuffer: MPSCommandBuffer + ) throws -> MPSTemporaryImage { + try self( + .NCHW(texture: inputTexture, matching: input, in: commandBuffer), + in: commandBuffer + ).nhwc(in: commandBuffer).temporaryImage2D( + pixelFormat: pixelFormat, + converter: converter, + in: commandBuffer + ) + } + + func texture2DFrom( + _ inputTexture: MTLTexture, + pixelFormat: MTLPixelFormat = .bgra8Unorm, + converter: MPSImageConversion, + in commandBuffer: MPSCommandBuffer + ) throws -> MTLTexture { + try self( + .NCHW(texture: inputTexture, matching: input, in: commandBuffer), + in: commandBuffer + ).nhwc(in: commandBuffer).texture2D( + pixelFormat: pixelFormat, + converter: converter, + in: commandBuffer + ) + } + + func arrayFrom( + _ inputTexture: MTLTexture, + in commandBuffer: MPSCommandBuffer + ) throws -> MPSNDArray { + try self( + .NCHW(texture: inputTexture, matching: input, in: commandBuffer), + in: commandBuffer + ).synchronizedNDArray(in: commandBuffer) + } +} diff --git a/Sources/MPSX/ONNX/Nodes/Resize.swift b/Sources/MPSX/ONNX/Nodes/Resize.swift index a06f777..e0ed551 100644 --- a/Sources/MPSX/ONNX/Nodes/Resize.swift +++ b/Sources/MPSX/ONNX/Nodes/Resize.swift @@ -21,7 +21,7 @@ extension MPSGraph { } return input.resize( - mode: node.attr(s: "mode") == "linear" ? .bilinear : .nearest, + mode: node.attr(s: "mode").contains("linear") ? .bilinear : .nearest, layout: .NCHW, height: Int((Float(shape.2) * scales.0).rounded()), width: Int((Float(shape.3) * scales.1).rounded()) diff --git a/Sources/MPSX/ONNX/OnnxError.swift b/Sources/MPSX/ONNX/OnnxError.swift index 01eea1b..8078a91 100644 --- a/Sources/MPSX/ONNX/OnnxError.swift +++ b/Sources/MPSX/ONNX/OnnxError.swift @@ -1,6 +1,8 @@ public enum OnnxError: Swift.Error { /// ONNX model has an inconsistent structure or some unsupported features case invalidModel(reason: String) + /// User-provided implementation of custom node is incorrect. See "reason" for details. + case incorrectCustomNodeImplementation(opType: String, reason: String) /// MPSX only supports a subset of the available ONNX operators, so this error will be thrown if an operator is not supported case unsupportedOperator(String) /// ONNX is a very volatile format with a bunch of opsets available, so if the layer input is invalid, this error will be thrown diff --git a/Sources/MPSXTests/OnnxExperimentalTests.swift b/Sources/MPSXTests/OnnxExperimentalTests.swift new file mode 100644 index 0000000..bd1ecd9 --- /dev/null +++ b/Sources/MPSXTests/OnnxExperimentalTests.swift @@ -0,0 +1,53 @@ +import MetalKit +import MetalPerformanceShaders +import MetalPerformanceShadersGraph +@testable import MPSX +import XCTest + +final class OnnxExperimentalTests: XCTestCase { + func testModelSplit() throws { + let model = try OnnxModel(data: data(bundlePath: "\(testResourcesPath)/shufflenet-v2-12.onnx")) + + let transposeSubgraphs: [String: Set] = [ + "Transpose_20": ["358"], + "Transpose_35": ["374"], + "Transpose_50": ["390"], + "Transpose_65": ["406"], + "Transpose_84": ["425"], + "Transpose_99": ["441"], + "Transpose_114": ["457"], + "Transpose_129": ["473"], + "Transpose_144": ["489"], + "Transpose_159": ["505"], + "Transpose_174": ["521"], + "Transpose_189": ["537"], + "Transpose_208": ["556"], + "Transpose_223": ["572"], + "Transpose_238": ["588"], + "Transpose_253": ["604"], + ] + + let concatSubgraphs: [String: Set] = [ + "Concat_17": ["347", "355"], + "Concat_32": ["362", "371"], + "Concat_47": ["378", "387"], + "Concat_62": ["394", "403"], + "Concat_81": ["414", "422"], + "Concat_96": ["429", "438"], + "Concat_111": ["445", "454"], + "Concat_126": ["461", "470"], + "Concat_141": ["477", "486"], + "Concat_156": ["493", "502"], + "Concat_171": ["509", "518"], + "Concat_186": ["525", "534"], + "Concat_205": ["545", "553"], + "Concat_220": ["560", "569"], + "Concat_235": ["576", "585"], + "Concat_250": ["592", "601"], + ] + + XCTAssertEqual(model.split(by: ["Transpose"]), transposeSubgraphs) + XCTAssertEqual(model.split(by: ["Concat"]), concatSubgraphs) + XCTAssertEqual(model.split(by: ["Concat", "Transpose"]), concatSubgraphs.merging(transposeSubgraphs, uniquingKeysWith: { x, _ in x })) + } +}