From 63588740f6823e172a2a71603cca0c956707bba1 Mon Sep 17 00:00:00 2001 From: Geor Kasapidi Date: Mon, 27 Feb 2023 13:31:22 +0200 Subject: [PATCH] wip --- Package.swift | 6 +- .../ONNX/Experimental/OnnxCustomNode.swift | 31 ++ .../ONNX/Experimental/OnnxModel+Split.swift | 80 ++++ .../MPSX/ONNX/Experimental/OnnxPipeline.swift | 372 ++++++++++++++++++ Sources/MPSX/ONNX/Nodes/Arithmetic.swift | 2 +- Sources/MPSX/ONNX/OnnxError.swift | 2 + Sources/MPSXTests/OnnxTests.swift | 230 ++++++++++- 7 files changed, 720 insertions(+), 3 deletions(-) create mode 100644 Sources/MPSX/ONNX/Experimental/OnnxCustomNode.swift create mode 100644 Sources/MPSX/ONNX/Experimental/OnnxModel+Split.swift create mode 100644 Sources/MPSX/ONNX/Experimental/OnnxPipeline.swift diff --git a/Package.swift b/Package.swift index 0106e0e..0b7c1e7 100644 --- a/Package.swift +++ b/Package.swift @@ -17,6 +17,10 @@ let package = Package( url: "https://github.com/apple/swift-protobuf.git", from: "1.20.1" ), + .package( + url: "https://github.com/prisma-ai/AFFT", + branch: "main" + ), ], targets: [ .target( @@ -27,7 +31,7 @@ let package = Package( ), .testTarget( name: "MPSXTests", - dependencies: ["MPSX"] + dependencies: ["MPSX", "AFFT"] ), ] ) diff --git a/Sources/MPSX/ONNX/Experimental/OnnxCustomNode.swift b/Sources/MPSX/ONNX/Experimental/OnnxCustomNode.swift new file mode 100644 index 0000000..f40c07e --- /dev/null +++ b/Sources/MPSX/ONNX/Experimental/OnnxCustomNode.swift @@ -0,0 +1,31 @@ +import MetalPerformanceShadersGraph + +public protocol OnnxAttributesProvider { + func attr(s name: String) -> String + func attr(f name: String) -> Float? + func attr(floats name: String) -> [Float]? + func attr(i name: String) -> Int? + func attr(ints name: String) -> [Int]? +} + +public protocol OnnxCustomNode { + func preprocess( + inputTensor: MPSGraphTensor, + inputName: String, + graph: MPSGraph + ) -> MPSGraphTensor + + func postprocess( + outputName: String, + outputShape: [Int], + requiredDataType: MPSDataType, + graph: MPSGraph + ) -> (placeholder: MPSGraphTensor, tensor: MPSGraphTensor) + + func eval( + inputs: [MPSGraphTensorData], + outputShapes: [[Int]], + attributesProvider: OnnxAttributesProvider, + in commandBuffer: MPSCommandBuffer + ) throws -> [MPSGraphTensorData] +} diff --git a/Sources/MPSX/ONNX/Experimental/OnnxModel+Split.swift b/Sources/MPSX/ONNX/Experimental/OnnxModel+Split.swift new file mode 100644 index 0000000..c5420b7 --- /dev/null +++ b/Sources/MPSX/ONNX/Experimental/OnnxModel+Split.swift @@ -0,0 +1,80 @@ +extension OnnxModel { + /// Split the graph into subgraphs. + /// - Parameter splitOps: The names of the operators in the graph by which you want to split + /// - Returns: A map where key is the name of the node (node opType exists in splitOps) and value is an array of input tensors that are "alive" at the time of the split. + func split(by splitOps: Set) -> [String: Set] { + struct NodeCounter { + /// How many times the output of a node will be used as an input + var readCount: Int + /// Node lifetime: lower bound == first occurrence in the graph, upper bound == last occurrence as input in the graph. + var range: ClosedRange + } + + let graph = proto.graph + + var counters: [String: NodeCounter] = [:] + counters.reserveCapacity(graph.node.count) + + var nodeIndices: [String: Int] = [:] + nodeIndices.reserveCapacity(graph.node.count) + + // STEP 1: calculate counters for every node from start to end + + for (index, node) in graph.node.enumerated() { + for outputName in node.output { + nodeIndices[outputName] = index + } + + for inputName in node.input { + // only runtime tensors + + guard initializer[inputName] == nil else { + continue + } + + let inputIndex = nodeIndices[inputName, default: 0] + + let counter = counters[inputName, default: .init(readCount: 0, range: inputIndex ... inputIndex)] + + counters[inputName] = .init( + readCount: counter.readCount + 1, + range: counter.range.lowerBound ... index + ) + } + } + + // STEP 2: now we have counters for every node and can split graph into subgraphs using this info. + + var subgraphs: [String: Set] = [:] + + for (index, node) in graph.node.enumerated() { + for inputName in node.input { + guard initializer[inputName] == nil, + var counter = counters[inputName] + else { + continue + } + + if counter.readCount > 0 { + counter.readCount -= 1 + + counters[inputName] = counter + } else { + counters[inputName] = nil + } + } + + // split + + if splitOps.contains(node.opType) { + // TODO: optimize using some tree data structure/sorted collection + + let inputs = counters.filter { $0.value.range.contains(index) } + + subgraphs[node.name] = Set(inputs.map(\.key)).subtracting(Set(node.output)) + } + } + + return subgraphs + } +} diff --git a/Sources/MPSX/ONNX/Experimental/OnnxPipeline.swift b/Sources/MPSX/ONNX/Experimental/OnnxPipeline.swift new file mode 100644 index 0000000..ed0a537 --- /dev/null +++ b/Sources/MPSX/ONNX/Experimental/OnnxPipeline.swift @@ -0,0 +1,372 @@ +import MetalPerformanceShadersGraph + +extension Onnx_NodeProto: OnnxAttributesProvider {} + +/// Experimental version of OnnxGraph with support for custom implementations for unknown operators +public final class OnnxPipeline { + // MARK: Lifecycle + + /// Initialize pipeline instance + /// - Parameters: + /// - model: onnx model + /// - device: metal device for graph compilation + /// - customNodes: user-provided implementations for unknown operations + /// - config: graph building configuration + public init( + model: OnnxModel, + device: MTLDevice, + customNodes: [String: OnnxCustomNode], + config: OnnxGraphConfig = .init() + ) throws { + let onnxGraph = model.proto.graph + + // required for custom operations + + let valueInfo: [String: Onnx_TypeProto.Tensor] = onnxGraph.valueInfo.reduce(into: [:]) { + if case let .tensorType(tensor) = $1.type.value { + $0[$1.name] = tensor // this object contains information about shape, datatype, etc. + } + } + + let splitResult = customNodes.isEmpty ? [:] : model.split(by: Set(customNodes.keys)) + + let populateTensorsOnTheFly = !splitResult.isEmpty + + let graphOptions: MPSGraphOptions = .none + let compilationDescriptor: MPSGraphCompilationDescriptor? = nil + let tensorsDataType = config.tensorsDataType.mpsDataType // fp16 or fp32 + + var constants = model.initializer // weights, biases, etc. + + var pipeline: [PipelineStep] = [] + + var mpsGraph = MPSGraph(options: graphOptions) // current mpsgraph -> will be overwritten if split result is not empty + var mpsTensors: [String: MPSGraphTensor] // table of current mpsgraph tensors for faster lookup by name + + if populateTensorsOnTheFly { + mpsTensors = [:] + } else { + mpsTensors = try model.initializer.mapValues { + try mpsGraph.constant($0, targetDataType: tensorsDataType) + } + } + + // create placeholders + + for input in model.inputs { + let options = config.inputs[input.name] + + let shape = input.shape.enumerated().map { + options?.dims?[$0.offset] ?? Int($0.element) + } + + mpsTensors[input.name] = mpsGraph.input( + shape: shape.nsnumbers, + dataType: tensorsDataType, + valuesRange: options?.valuesRange, + name: input.name + ) + } + + // iterate over graph: onnx guarantees graph nodes are sorted topologically + + for node in onnxGraph.node { + if populateTensorsOnTheFly { + for input in node.input { + if let tensor = model.initializer[input], mpsTensors[input] == nil { + mpsTensors[input] = try mpsGraph.constant(tensor, targetDataType: tensorsDataType) + } + } + } + + let success = try mpsGraph.onnx( + node: node, + optimizedForMPS: model.optimizedForMPS, + tensorsDataType: tensorsDataType, + tensors: &mpsTensors, + constants: &constants + ) + + if success { + continue // known operation + } + + // lookup for a user-provided implementation + + guard let customNode = customNodes[node.opType], + let subgraphOutputs = splitResult[node.name] + else { + throw OnnxError.unsupportedOperator(node.opType) // custom implementation not found + } + + // generate node inputs using current mpsgraph instance and tensor table + + try node.input.forEach { + guard let inputTensor = mpsTensors[$0] else { + throw OnnxError.invalidModel(reason: "Tensor named \($0) not found") + } + + mpsTensors[$0] = customNode.preprocess( + inputTensor: inputTensor, + inputName: $0, + graph: mpsGraph + ) + } + + // finalize current mps graph + + let outputTensors = subgraphOutputs.reduce(into: [:]) { + $0[$1] = mpsTensors[$1] + } + + let subgraph = MPSCompiledGraph( + compilationDescriptor: compilationDescriptor, + device: device, + graph: mpsGraph, + outputTensors: outputTensors + ) + + // create next mpsgraph instance and tensor table + + mpsGraph = .init(options: graphOptions) + mpsTensors = [:] + + var customNames: [String: String] = [:] + + outputTensors.forEach { + let placeholder = mpsGraph.placeholder(shape: $1.shape, dataType: $1.dataType, name: $0) + + mpsTensors[$0] = placeholder + + // ⚠️ mps can change the user-defined name: ex. dots (.) will be replaced with underscores (_) + + if placeholder.operation.name != $0 { + customNames[$0] = placeholder.operation.name + } + } + + // generate node outputs and placeholders + + let outputShapes = try node.output.map { + guard let tensorInfo = valueInfo[$0], tensorInfo.hasShape else { + throw OnnxError.invalidModel(reason: "Shaped tensor named \($0) not found") + } + + let shape = tensorInfo.shape.dim.map { Int($0.dimValue) } + + let (placeholder, tensor) = customNode.postprocess( + outputName: $0, + outputShape: shape, + requiredDataType: tensorsDataType, + graph: mpsGraph + ) + + if placeholder.operation.name != $0 { + customNames[$0] = placeholder.operation.name + } + + guard (tensor.shape ?? []).map(\.intValue) == shape else { + throw OnnxError.incorrectCustomNodeImplementation( + opType: node.opType, + reason: "Shape of tensor named \($0) does not match the required \(shape)" + ) + } + + mpsTensors[$0] = tensor + + return shape + } + + // pipeline++ + + pipeline.append(.graph(subgraph)) + pipeline.append(.custom(.init(proto: node, outputShapes: outputShapes, customNames: customNames))) + } + + // final step: setup onnx graph outputs + + let finalGraph = try MPSCompiledGraph( + compilationDescriptor: compilationDescriptor, + device: device, + graph: mpsGraph, + outputTensors: model.outputs.reduce(into: [:]) { + guard let tensor = mpsTensors[$1] else { + throw OnnxError.invalidModel(reason: "Tensor named \($1) not found") + } + + $0[$1] = mpsGraph.output( + tensor: tensor, + valuesRange: config.outputs[$1]?.valuesRange + ) + } + ) + + pipeline.append(.graph(finalGraph)) + + firstGraph = pipeline.first!.graph! + lastGraph = pipeline.last!.graph! + + self.pipeline = pipeline + self.customNodes = customNodes + } + + // MARK: Public + + public var inputs: [String: MPSGraphTensor] { + firstGraph.inputs + } + + public var outputs: [String: MPSGraphTensor] { + lastGraph.outputs + } + + /// single input -> single output + public func callAsFunction( + _ input: MPSGraphTensorData, + in commandBuffer: MPSCommandBuffer + ) throws -> MPSGraphTensorData { + try encode(inputs: [firstGraph.inputs.first!.key: input], in: commandBuffer).first!.value + } + + /// multiple inputs -> single output + public func callAsFunction( + _ inputs: [String: MPSGraphTensorData], + in commandBuffer: MPSCommandBuffer + ) throws -> MPSGraphTensorData { + try encode(inputs: inputs, in: commandBuffer).first!.value + } + + /// single input -> multiple outputs + public func callAsFunction( + _ input: MPSGraphTensorData, + in commandBuffer: MPSCommandBuffer + ) throws -> [String: MPSGraphTensorData] { + try encode(inputs: [firstGraph.inputs.first!.key: input], in: commandBuffer) + } + + /// multiple inputs -> multiple outputs + public func callAsFunction( + _ inputs: [String: MPSGraphTensorData], + in commandBuffer: MPSCommandBuffer + ) throws -> [String: MPSGraphTensorData] { + try encode(inputs: inputs, in: commandBuffer) + } + + // MARK: Private + + private enum PipelineStep { + case graph(MPSCompiledGraph) + case custom(CustomNode) + + // MARK: Internal + + struct CustomNode { + let proto: Onnx_NodeProto + let outputShapes: [[Int]] + let customNames: [String: String] + } + + var graph: MPSCompiledGraph? { + switch self { + case let .graph(value): + return value + case .custom: + return nil + } + } + } + + private let firstGraph: MPSCompiledGraph + private let lastGraph: MPSCompiledGraph + private let pipeline: [PipelineStep] + + private let customNodes: [String: OnnxCustomNode] + + private func encode( + inputs: [String: MPSGraphTensorData], + in commandBuffer: MPSCommandBuffer + ) throws -> [String: MPSGraphTensorData] { + var outputs = inputs + + for step in pipeline { + try autoreleasepool { + switch step { + case let .graph(graph): + outputs = graph(outputs, in: commandBuffer) + case let .custom(node): + let nodeOutputs = try self.customNodes[node.proto.opType]!.eval( + inputs: node.proto.input.map { + guard let output = outputs[$0] else { + throw OnnxError.incorrectCustomNodeImplementation( + opType: node.proto.opType, + reason: "Input named \($0) not found" + ) + } + return output + }, + outputShapes: node.outputShapes, + attributesProvider: node.proto, + in: commandBuffer + ) + + guard nodeOutputs.count == node.proto.output.count else { + throw OnnxError.incorrectCustomNodeImplementation( + opType: node.proto.opType, + reason: "Unexpected number of outputs" + ) + } + + // use customNames as onnx -> mps name table + + outputs = outputs.reduce(into: [:]) { + $0[node.customNames[$1.key, default: $1.key]] = $1.value + } + + zip(node.proto.output, nodeOutputs).forEach { outputs[node.customNames[$0, default: $0]] = $1 } + } + } + } + + return outputs + } +} + +public extension OnnxPipeline { + func warmUp(in commandBuffer: MPSCommandBuffer) { + for step in pipeline { + guard let graph = step.graph else { + continue + } + + let randomInputs: [String: MPSGraphTensorData] = MPSCompiledGraph(device: commandBuffer.device) { g in + graph.inputs.mapValues { t in + g.randomUniformTensor(withShape: t.shape ?? [], name: nil).cast(to: t.dataType) + } + }([:], in: commandBuffer) + + let _: [String: MPSGraphTensorData] = graph(randomInputs, in: commandBuffer) + } + } + + func imageFrom( + _ inputTexture: MTLTexture, + in commandBuffer: MPSCommandBuffer + ) throws -> MPSTemporaryImage { + try self( + .NCHW(texture: inputTexture, matching: inputs.first!.value, in: commandBuffer), + in: commandBuffer + ).nhwc(in: commandBuffer).temporaryImage(in: commandBuffer) + } + + func texture2DFrom( + _ inputTexture: MTLTexture, + pixelFormat: MTLPixelFormat = .bgra8Unorm, + converter: MPSImageConversion, + in commandBuffer: MPSCommandBuffer + ) throws -> MTLTexture { + try self( + .NCHW(texture: inputTexture, matching: inputs.first!.value, in: commandBuffer), + in: commandBuffer + ).nhwc(in: commandBuffer).texture2D(pixelFormat: pixelFormat, converter: converter, in: commandBuffer) + } +} diff --git a/Sources/MPSX/ONNX/Nodes/Arithmetic.swift b/Sources/MPSX/ONNX/Nodes/Arithmetic.swift index da5b3ab..7d1434a 100644 --- a/Sources/MPSX/ONNX/Nodes/Arithmetic.swift +++ b/Sources/MPSX/ONNX/Nodes/Arithmetic.swift @@ -19,7 +19,7 @@ extension MPSGraph { else { throw OnnxError.invalidInput(node.name) } switch op { - case .add: return a + b + case .add: return a + b.cast(to: a.dataType) case .sub: return a - b case .mul: return a * b case .div: return a / b diff --git a/Sources/MPSX/ONNX/OnnxError.swift b/Sources/MPSX/ONNX/OnnxError.swift index 01eea1b..8078a91 100644 --- a/Sources/MPSX/ONNX/OnnxError.swift +++ b/Sources/MPSX/ONNX/OnnxError.swift @@ -1,6 +1,8 @@ public enum OnnxError: Swift.Error { /// ONNX model has an inconsistent structure or some unsupported features case invalidModel(reason: String) + /// User-provided implementation of custom node is incorrect. See "reason" for details. + case incorrectCustomNodeImplementation(opType: String, reason: String) /// MPSX only supports a subset of the available ONNX operators, so this error will be thrown if an operator is not supported case unsupportedOperator(String) /// ONNX is a very volatile format with a bunch of opsets available, so if the layer input is invalid, this error will be thrown diff --git a/Sources/MPSXTests/OnnxTests.swift b/Sources/MPSXTests/OnnxTests.swift index d883d2e..43681dd 100644 --- a/Sources/MPSXTests/OnnxTests.swift +++ b/Sources/MPSXTests/OnnxTests.swift @@ -1,9 +1,118 @@ +import Accelerate +import AFFT import MetalKit import MetalPerformanceShaders import MetalPerformanceShadersGraph -import MPSX +@testable import MPSX import XCTest +extension MPSCommandBuffer { + func commitAndWait() { + let cmdbuf = rootCommandBuffer + commitAndContinue() + cmdbuf.waitUntilCompleted() + } +} + +@available(macOS 12.3, *) +final class FFTForwardPassNode: OnnxCustomNode { + func preprocess(inputTensor: MPSGraphTensor, inputName _: String, graph _: MPSGraph) -> MPSGraphTensor { + inputTensor.cast(to: .float32) + } + + func postprocess(outputName: String, outputShape: [Int], requiredDataType: MPSDataType, graph: MPSGraph) -> (placeholder: MPSGraphTensor, tensor: MPSGraphTensor) { + FFT2DN.postprocessForwardPass( + shape: .init(channels: outputShape[1] / 2, height: outputShape[2], width: (outputShape[3] - 1) * 2), + graph: graph, + requiredDataType: requiredDataType, + placeholderName: outputName + ) + } + + func eval( + inputs: [MPSGraphTensorData], + outputShapes _: [[Int]], + attributesProvider _: MPSX.OnnxAttributesProvider, + in commandBuffer: MPSCommandBuffer + ) throws -> [MPSGraphTensorData] { + let input = inputs[0] + let shape = input.shape.map(\.intValue) + let ndarray = input.synchronizedNDArray(in: commandBuffer) + + commandBuffer.commitAndWait() + + let floats = ndarray.floats + + let outputs = FFT2DN.transform( + input: floats, + shape: .init(channels: shape[1], height: shape[2], width: shape[3]), + direction: .forward, + realInput: true + ) + + let data: MPSGraphTensorData = .floats(outputs, shape: [1, shape[1] * 2, shape[2], shape[3]], device: commandBuffer.device) + + return [data] + } +} + +@available(macOS 12.3, *) +final class FFTInversePassNode: OnnxCustomNode { + func preprocess(inputTensor: MPSGraphTensor, inputName _: String, graph: MPSGraph) -> MPSGraphTensor { + let input = inputTensor + let shape = input.quadShape! + + return FFT2DN.preprocessInversePass( + shape: .init(channels: shape.1 / 2, height: shape.2, width: (shape.3 - 1) * 2), + graph: graph, + input: input + ).cast(to: .float32) + } + + func postprocess(outputName: String, outputShape: [Int], requiredDataType: MPSDataType, graph: MPSGraph) -> (placeholder: MPSGraphTensor, tensor: MPSGraphTensor) { + let shape = outputShape + + let placeholder = graph.placeholder(shape: [1, shape[1], shape[2], shape[3]].nsnumbers, dataType: .float32, name: outputName) + + return (placeholder, placeholder.cast(to: requiredDataType)) + } + + func eval( + inputs: [MPSGraphTensorData], + outputShapes _: [[Int]], + attributesProvider _: MPSX.OnnxAttributesProvider, + in commandBuffer: MPSCommandBuffer + ) throws -> [MPSGraphTensorData] { + let input = inputs[0] + let shape = input.shape.map(\.intValue) + let ndarray = input.synchronizedNDArray(in: commandBuffer) + + commandBuffer.commitAndWait() + + let floats = ndarray.floats + + let outputs = FFT2DN.transform( + input: floats, + shape: .init(channels: shape[1] / 2, height: shape[2], width: shape[3]), + direction: .inverse, + realInput: false + ) + + let realPart = outputs[0 ..< outputs.count / 2].rawData + + let data = MPSGraphTensorData( + device: .init(mtlDevice: commandBuffer.device), + data: realPart, + shape: [1, shape[1] / 2, shape[2], shape[3]].nsnumbers, + dataType: .float32 + ) + +// let data: MPSGraphTensorData = .floats(realPart, shape: shape, device: commandBuffer.device) + + return [data] + } +} + @available(macOS 12.0, *) final class OnnxTests: XCTestCase { // https://github.com/onnx/models/tree/main/vision/classification/shufflenet @@ -139,4 +248,123 @@ final class OnnxTests: XCTestCase { try save(texture: textureToSave, arg: 3) } + + func testModelSplit() throws { + let model = try OnnxModel(data: data(arg: 1)) // shufflenet-v2-12.onnx + + let transposeSubgraphs: [String: Set] = [ + "Transpose_20": ["358"], + "Transpose_35": ["374"], + "Transpose_50": ["390"], + "Transpose_65": ["406"], + "Transpose_84": ["425"], + "Transpose_99": ["441"], + "Transpose_114": ["457"], + "Transpose_129": ["473"], + "Transpose_144": ["489"], + "Transpose_159": ["505"], + "Transpose_174": ["521"], + "Transpose_189": ["537"], + "Transpose_208": ["556"], + "Transpose_223": ["572"], + "Transpose_238": ["588"], + "Transpose_253": ["604"], + ] + + let concatSubgraphs: [String: Set] = [ + "Concat_17": ["347", "355"], + "Concat_32": ["362", "371"], + "Concat_47": ["378", "387"], + "Concat_62": ["394", "403"], + "Concat_81": ["414", "422"], + "Concat_96": ["429", "438"], + "Concat_111": ["445", "454"], + "Concat_126": ["461", "470"], + "Concat_141": ["477", "486"], + "Concat_156": ["493", "502"], + "Concat_171": ["509", "518"], + "Concat_186": ["525", "534"], + "Concat_205": ["545", "553"], + "Concat_220": ["560", "569"], + "Concat_235": ["576", "585"], + "Concat_250": ["592", "601"], + ] + + XCTAssertEqual(model.split(by: ["Transpose"]), transposeSubgraphs) + XCTAssertEqual(model.split(by: ["Concat"]), concatSubgraphs) + XCTAssertEqual(model.split(by: ["Concat", "Transpose"]), concatSubgraphs.merging(transposeSubgraphs, uniquingKeysWith: { x, _ in x })) + } + + @available(macOS 12.3, *) + func testFFT() async throws { + // STEP 0️⃣: setup model + + // ⚠️⚠️⚠️ You can find required files in 1.1.1 release attachments + + let model = try OnnxModel(data: data(arg: 1)) // candy-8.onnx + + // STEP 1️⃣: setup metal stuff + + let gpu = GPU.default + + // STEP 2️⃣: create onnx graph using model instance, metal device and graph configuration + + let graph = try OnnxPipeline( + model: model, + device: gpu.device, + customNodes: [ + "rfftn": FFTForwardPassNode(), + "irfftn": FFTInversePassNode(), + ], + config: .init(tensorsDataType: .fp16) + ) + + // STEP 3️⃣: prepare inputs and warm up graph + + let inputImage = try await inputTexture(arg: 2) + let inputMask = try await inputTexture(arg: 3) + + let inputs: [String: MPSGraphTensorData] = gpu.commandQueue.sync { + // ❕ This call is optional: first run of the graph is slower than the others, so for clear measurements we perform warm-up. + + graph.warmUp(in: $0) + + return [ + "image": .NCHW( + texture: inputImage, + matching: graph.inputs["image"]!, + in: $0 + ), + "mask": .NCHW( + texture: inputMask, + matching: graph.inputs["mask"]!, + in: $0 + ), + ] + } + + // STEP 4️⃣: measure and run + + func styleTransfer() -> MPSGraphTensorData { + gpu.commandQueue.sync { + try! graph(inputs, in: $0) + } + } + + measure { + _ = styleTransfer() + } + + let rawData = styleTransfer() + + let textureToSave = gpu.commandQueue.sync { + rawData.nhwc(in: $0).texture2D( + pixelFormat: .rgba8Unorm, + converter: gpu.imageConverter, + in: $0 + ) + } + + try save(texture: textureToSave, arg: 4) + } }