Skip to content

Commit

Permalink
wip
Browse files Browse the repository at this point in the history
  • Loading branch information
geor-kasapidi committed Apr 19, 2023
1 parent 86f2ac9 commit de43cf0
Show file tree
Hide file tree
Showing 9 changed files with 588 additions and 3 deletions.
1 change: 1 addition & 0 deletions ONNX2MPSX.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ def convert_onnx_to_mpsx(model, halfs):
onnx.helper.make_graph(
nodes=model.graph.node,
name=model.graph.name,
value_info=model.graph.value_info,
inputs=_fp32_to_fp16_info(
model.graph.input) if halfs else model.graph.input,
outputs=_fp32_to_fp16_info(
Expand Down
4 changes: 2 additions & 2 deletions Sources/MPSX/Foundation/DSL.swift
Original file line number Diff line number Diff line change
Expand Up @@ -412,8 +412,8 @@ public extension MPSGraphTensor {
self,
size: [height, width].nsnumbers,
mode: mode,
centerResult: true,
alignCorners: false,
centerResult: false,
alignCorners: true,
layout: layout,
name: nil
)
Expand Down
8 changes: 8 additions & 0 deletions Sources/MPSX/Foundation/Utilities.swift
Original file line number Diff line number Diff line change
Expand Up @@ -468,3 +468,11 @@ public extension MPSGraphTensor {
shape?.map(\.intValue) ?? []
}
}

public extension MPSCommandBuffer {
func commitAndWait() {
let _rootCommandBuffer = rootCommandBuffer
commitAndContinue()
_rootCommandBuffer.waitUntilCompleted()
}
}
31 changes: 31 additions & 0 deletions Sources/MPSX/ONNX/Experimental/OnnxCustomNode.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
import MetalPerformanceShadersGraph

public protocol OnnxAttributesProvider {
func attr(s name: String) -> String
func attr(f name: String) -> Float?
func attr(floats name: String) -> [Float]?
func attr(i name: String) -> Int?
func attr(ints name: String) -> [Int]?
}

public protocol OnnxCustomNode {
func preprocess(
inputTensor: MPSGraphTensor,
inputName: String,
graph: MPSGraph
) -> MPSGraphTensor

func postprocess(
outputName: String,
outputShape: [Int],
requiredDataType: MPSDataType,
graph: MPSGraph
) -> (placeholder: MPSGraphTensor, tensor: MPSGraphTensor)

func eval(
inputs: [MPSGraphTensorData],
outputShapes: [[Int]],
attributesProvider: OnnxAttributesProvider,
in commandBuffer: MPSCommandBuffer
) throws -> [MPSGraphTensorData]
}
80 changes: 80 additions & 0 deletions Sources/MPSX/ONNX/Experimental/OnnxModel+Split.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
extension OnnxModel {
/// Split the graph into subgraphs.
/// - Parameter splitOps: The names of the operators in the graph by which you want to split
/// - Returns: A map where key is the name of the node (node opType exists in splitOps) and value is an array of input tensors that are "alive" at the time of the split.
func split(by splitOps: Set<String>) -> [String: Set<String>] {
struct NodeCounter {
/// How many times the output of a node will be used as an input
var readCount: Int
/// Node lifetime: lower bound == first occurrence in the graph, upper bound == last occurrence as input in the graph.
var range: ClosedRange<Int>
}

let graph = proto.graph

var counters: [String: NodeCounter] = [:]
counters.reserveCapacity(graph.node.count)

var nodeIndices: [String: Int] = [:]
nodeIndices.reserveCapacity(graph.node.count)

// STEP 1: calculate counters for every node from start to end

for (index, node) in graph.node.enumerated() {
for outputName in node.output {
nodeIndices[outputName] = index
}

for inputName in node.input {
// only runtime tensors

guard initializer[inputName] == nil else {
continue
}

let inputIndex = nodeIndices[inputName, default: 0]

let counter = counters[inputName, default: .init(readCount: 0, range: inputIndex ... inputIndex)]

counters[inputName] = .init(
readCount: counter.readCount + 1,
range: counter.range.lowerBound ... index
)
}
}

// STEP 2: now we have counters for every node and can split graph into subgraphs using this info.

var subgraphs: [String: Set<String>] = [:]

for (index, node) in graph.node.enumerated() {
for inputName in node.input {
guard initializer[inputName] == nil,
var counter = counters[inputName]
else {
continue
}

if counter.readCount > 0 {
counter.readCount -= 1

counters[inputName] = counter
} else {
counters[inputName] = nil
}
}

// split

if splitOps.contains(node.opType) {
// TODO: optimize using some tree data structure/sorted collection

let inputs = counters.filter { $0.value.range.contains(index) }

subgraphs[node.name] = Set(inputs.map(\.key)).subtracting(Set(node.output))
}
}

return subgraphs
}
}
Loading

0 comments on commit de43cf0

Please sign in to comment.