Skip to content

Commit

Permalink
less allocations, more tests, new api
Browse files Browse the repository at this point in the history
  • Loading branch information
geor-kasapidi committed Feb 19, 2023
1 parent 2b9be21 commit f3f52ce
Show file tree
Hide file tree
Showing 13 changed files with 437 additions and 352 deletions.
21 changes: 17 additions & 4 deletions Sources/MPSX/Foundation/CompiledGraph.swift
Original file line number Diff line number Diff line change
Expand Up @@ -4,19 +4,32 @@ import MetalPerformanceShadersGraph
public final class MPSCompiledGraph {
// MARK: Lifecycle

public init(
public convenience init(
options: MPSGraphOptions = .none,
compilationDescriptor: MPSGraphCompilationDescriptor? = nil,
device: MTLDevice,
body: (MPSGraph) throws -> [String: MPSGraphTensor]
) rethrows {
let graph = MPSGraph()
graph.options = options
let graph = MPSGraph(options: options)

let outputTensors = try autoreleasepool {
try body(graph)
}

self.init(
compilationDescriptor: compilationDescriptor,
device: device,
graph: graph,
outputTensors: outputTensors
)
}

internal init(
compilationDescriptor: MPSGraphCompilationDescriptor? = nil,
device: MTLDevice,
graph: MPSGraph,
outputTensors: [String: MPSGraphTensor]
) {
let executable = autoreleasepool {
graph.compile(
with: .init(mtlDevice: device),
Expand All @@ -28,7 +41,7 @@ public final class MPSCompiledGraph {
compilationDescriptor: compilationDescriptor
)
}
executable.options = options
executable.options = graph.options

let outputKeys = outputTensors.reduce(into: [:]) { $0[$1.value.operation.name] = $1.key }

Expand Down
2 changes: 1 addition & 1 deletion Sources/MPSX/Foundation/DSL.swift
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ public extension MPSGraph {
constant(vector.rawData, shape: (shape ?? [vector.count]).nsnumbers, dataType: .float32)
}

#if !((os(macOS) || targetEnvironment(macCatalyst)) && arch(x86_64))
#if !arch(x86_64)
@inlinable
func const(_ vector: [Swift.Float16], shape: [Int]? = nil) -> MPSGraphTensor {
constant(vector.rawData, shape: (shape ?? [vector.count]).nsnumbers, dataType: .float16)
Expand Down
105 changes: 0 additions & 105 deletions Sources/MPSX/Foundation/FPAC.swift

This file was deleted.

76 changes: 76 additions & 0 deletions Sources/MPSX/Foundation/FPC.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
import Accelerate
import Foundation

// On x86_64 builds this module declares its own `Float16` as an alias for `UInt16`,
// so half-precision values are carried as raw 16-bit patterns on that architecture.
// NOTE(review): presumably needed because `Swift.Float16` is unavailable on Intel
// macOS targets — confirm against the supported platforms.
#if arch(x86_64)
typealias Float16 = UInt16
#endif

/// Fast floating point conversion.
///
/// Each entry point takes any `AccelerateBuffer` of the source element type and
/// returns a freshly allocated array with the same number of elements.
/// Half-precision conversions go through vImage's planar converters; integer to
/// `Float` conversions go through `vDSP.convertElements`.
enum FPC {
    // MARK: Internal

    /// Converts 32-bit floats to 16-bit floats via `vImageConvert_PlanarFtoPlanar16F`.
    static func _Float32_Float16<T: AccelerateBuffer>(_ input: T) -> [Float16] where T.Element == Float {
        convert(input, vImageConvert_PlanarFtoPlanar16F)
    }

    /// Converts 16-bit floats to 32-bit floats via `vImageConvert_Planar16FtoPlanarF`.
    static func _Float16_Float32<T: AccelerateBuffer>(_ input: T) -> [Float] where T.Element == Float16 {
        convert(input, vImageConvert_Planar16FtoPlanarF)
    }

    /// Converts `Int8` elements to `Float` via `vDSP.convertElements`.
    static func _Int8_Float32<T: AccelerateBuffer>(_ input: T) -> [Float] where T.Element == Int8 {
        convert(input, vDSP.convertElements)
    }

    /// Converts `Int16` elements to `Float` via `vDSP.convertElements`.
    static func _Int16_Float32<T: AccelerateBuffer>(_ input: T) -> [Float] where T.Element == Int16 {
        convert(input, vDSP.convertElements)
    }

    /// Converts `Int32` elements to `Float` via `vDSP.convertElements`.
    static func _Int32_Float32<T: AccelerateBuffer>(_ input: T) -> [Float] where T.Element == Int32 {
        convert(input, vDSP.convertElements)
    }

    /// Converts `UInt8` elements to `Float` via `vDSP.convertElements`.
    static func _UInt8_Float32<T: AccelerateBuffer>(_ input: T) -> [Float] where T.Element == UInt8 {
        convert(input, vDSP.convertElements)
    }

    /// Converts `UInt16` elements to `Float` via `vDSP.convertElements`.
    static func _UInt16_Float32<T: AccelerateBuffer>(_ input: T) -> [Float] where T.Element == UInt16 {
        convert(input, vDSP.convertElements)
    }

    /// Converts `UInt32` elements to `Float` via `vDSP.convertElements`.
    static func _UInt32_Float32<T: AccelerateBuffer>(_ input: T) -> [Float] where T.Element == UInt32 {
        convert(input, vDSP.convertElements)
    }

    // MARK: Private

    /// Runs a buffer-to-buffer conversion that writes straight into the result
    /// array's uninitialized storage.
    ///
    /// - Parameters:
    ///   - input: Source elements.
    ///   - body: Conversion routine; it receives `input` and a destination buffer
    ///     of `input.count` elements and is expected to fill all of them.
    /// - Returns: An array of `input.count` converted elements.
    @_transparent
    private static func convert<U: AccelerateBuffer, V: Numeric>(_ input: U, _ body: (U, inout UnsafeMutableBufferPointer<V>) -> Void) -> [V] {
        .init(unsafeUninitializedCapacity: input.count) { buffer, initializedCount in
            body(input, &buffer)
            initializedCount = input.count
        }
    }

    /// Runs a vImage planar conversion, treating both input and output as a
    /// single image row of `input.count` pixels.
    ///
    /// - Parameters:
    ///   - input: Source elements.
    ///   - body: vImage routine taking (source, destination, flags).
    /// - Returns: An array of `input.count` converted elements; empty when the
    ///   input is empty.
    private static func convert<U: AccelerateBuffer, V: Numeric>(
        _ input: U,
        _ body: (UnsafePointer<vImage_Buffer>, UnsafePointer<vImage_Buffer>, vImage_Flags) -> vImage_Error
    ) -> [V] where U.Element: Numeric {
        // Nothing to convert; also avoids touching a possibly-nil base address.
        guard input.count > 0 else { return [] }

        var output = [V](repeating: 0, count: input.count)

        // Describes `count` contiguous elements of type `T` as a one-row vImage buffer.
        @_transparent
        func buffer<T>(of _: T.Type, pointer: UnsafeMutableRawPointer, count: Int) -> vImage_Buffer {
            .init(data: pointer, height: 1, width: UInt(count), rowBytes: count * MemoryLayout<T>.stride)
        }

        input.withUnsafeBufferPointer { inputPointer in
            output.withUnsafeMutableBufferPointer { outputPointer in
                guard
                    let source = inputPointer.baseAddress,
                    let destination = outputPointer.baseAddress
                else {
                    return
                }

                // Row size must be derived from the element type, not from the container type `U`.
                // vImage_Buffer.data is a mutable pointer even for the read-only source.
                var inputBuffer = buffer(
                    of: U.Element.self,
                    pointer: UnsafeMutableRawPointer(mutating: source),
                    count: inputPointer.count
                )
                var outputBuffer = buffer(
                    of: V.self,
                    pointer: UnsafeMutableRawPointer(destination),
                    count: outputPointer.count
                )

                let status = body(&inputBuffer, &outputBuffer, vImage_Flags(kvImageNoFlags))
                // Surface unexpected vImage failures in debug builds; release behaviour is unchanged.
                assert(status == kvImageNoError, "vImage conversion failed with error \(status)")
            }
        }

        return output
    }
}
Loading

0 comments on commit f3f52ce

Please sign in to comment.