Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Graph split #3

Draft
wants to merge 2 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions ONNX2MPSX.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ def convert_onnx_to_mpsx(model, halfs):
onnx.helper.make_graph(
nodes=model.graph.node,
name=model.graph.name,
value_info=model.graph.value_info,
inputs=_fp32_to_fp16_info(
model.graph.input) if halfs else model.graph.input,
outputs=_fp32_to_fp16_info(
Expand Down
4 changes: 2 additions & 2 deletions Sources/MPSX/Foundation/DSL.swift
Original file line number Diff line number Diff line change
Expand Up @@ -412,8 +412,8 @@ public extension MPSGraphTensor {
self,
size: [height, width].nsnumbers,
mode: mode,
centerResult: true,
alignCorners: false,
centerResult: false,
alignCorners: true,
layout: layout,
name: nil
)
Expand Down
8 changes: 8 additions & 0 deletions Sources/MPSX/Foundation/Utilities.swift
Original file line number Diff line number Diff line change
Expand Up @@ -468,3 +468,11 @@ public extension MPSGraphTensor {
shape?.map(\.intValue) ?? []
}
}

public extension MPSCommandBuffer {
func commitAndWait() {
let _rootCommandBuffer = rootCommandBuffer
commitAndContinue()
_rootCommandBuffer.waitUntilCompleted()
}
}
31 changes: 31 additions & 0 deletions Sources/MPSX/ONNX/Experimental/OnnxCustomNode.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
import MetalPerformanceShadersGraph

public protocol OnnxAttributesProvider {
func attr(s name: String) -> String
func attr(f name: String) -> Float?
func attr(floats name: String) -> [Float]?
func attr(i name: String) -> Int?
func attr(ints name: String) -> [Int]?
}

public protocol OnnxCustomNode {
func preprocess(
inputTensor: MPSGraphTensor,
inputName: String,
graph: MPSGraph
) -> MPSGraphTensor

func postprocess(
outputName: String,
outputShape: [Int],
requiredDataType: MPSDataType,
graph: MPSGraph
) -> (placeholder: MPSGraphTensor, tensor: MPSGraphTensor)

func eval(
inputs: [MPSGraphTensorData],
outputShapes: [[Int]],
attributesProvider: OnnxAttributesProvider,
in commandBuffer: MPSCommandBuffer
) throws -> [MPSGraphTensorData]
}
80 changes: 80 additions & 0 deletions Sources/MPSX/ONNX/Experimental/OnnxModel+Split.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
extension OnnxModel {
/// Split the graph into subgraphs.
/// - Parameter splitOps: The names of the operators in the graph by which you want to split
/// - Returns: A map where key is the name of the node (node opType exists in splitOps) and value is an array of input tensors that are "alive" at the time of the split.
func split(by splitOps: Set<String>) -> [String: Set<String>] {
struct NodeCounter {
/// How many times the output of a node will be used as an input
var readCount: Int
/// Node lifetime: lower bound == first occurrence in the graph, upper bound == last occurrence as input in the graph.
var range: ClosedRange<Int>
}

let graph = proto.graph

var counters: [String: NodeCounter] = [:]
counters.reserveCapacity(graph.node.count)

var nodeIndices: [String: Int] = [:]
nodeIndices.reserveCapacity(graph.node.count)

// STEP 1: calculate counters for every node from start to end

for (index, node) in graph.node.enumerated() {
for outputName in node.output {
nodeIndices[outputName] = index
}

for inputName in node.input {
// only runtime tensors

guard initializer[inputName] == nil else {
continue
}

let inputIndex = nodeIndices[inputName, default: 0]

let counter = counters[inputName, default: .init(readCount: 0, range: inputIndex ... inputIndex)]

counters[inputName] = .init(
readCount: counter.readCount + 1,
range: counter.range.lowerBound ... index
)
}
}

// STEP 2: now we have counters for every node and can split graph into subgraphs using this info.

var subgraphs: [String: Set<String>] = [:]

for (index, node) in graph.node.enumerated() {
for inputName in node.input {
guard initializer[inputName] == nil,
var counter = counters[inputName]
else {
continue
}

if counter.readCount > 0 {
counter.readCount -= 1

counters[inputName] = counter
} else {
counters[inputName] = nil
}
}

// split

if splitOps.contains(node.opType) {
// TODO: optimize using some tree data structure/sorted collection

let inputs = counters.filter { $0.value.range.contains(index) }

subgraphs[node.name] = Set(inputs.map(\.key)).subtracting(Set(node.output))
}
}

return subgraphs
}
}
Loading