Skip to content

Commit

Permalink
micro update
Browse files Browse the repository at this point in the history
  • Loading branch information
geor-kasapidi committed Feb 24, 2023
1 parent f3f52ce commit ab097be
Show file tree
Hide file tree
Showing 4 changed files with 18 additions and 5 deletions.
2 changes: 1 addition & 1 deletion Sources/MPSX/Foundation/DSL.swift
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ public extension MPSGraph {
constant(vector.rawData, shape: (shape ?? [vector.count]).nsnumbers, dataType: .float32)
}

#if !arch(x86_64)
#if arch(arm64)
@inlinable
func const(_ vector: [Swift.Float16], shape: [Int]? = nil) -> MPSGraphTensor {
constant(vector.rawData, shape: (shape ?? [vector.count]).nsnumbers, dataType: .float16)
Expand Down
2 changes: 1 addition & 1 deletion Sources/MPSX/Foundation/FPC.swift
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import Accelerate
import Foundation

#if arch(x86_64)
#if !arch(arm64)
typealias Float16 = UInt16
#endif

Expand Down
17 changes: 15 additions & 2 deletions Sources/MPSX/Foundation/Utilities.swift
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,15 @@ extension Array {
}
}

extension ArraySlice {
@usableFromInline
var rawData: Data {
withUnsafeBufferPointer {
Data(buffer: $0) // copy
}
}
}

extension Data {
func mapMemory<T, R>(of _: T.Type, _ body: (UnsafeBufferPointer<T>) throws -> R) rethrows -> R {
try withUnsafeBytes {
Expand Down Expand Up @@ -182,10 +191,14 @@ private extension MPSGraphTensorData {

public extension MPSGraphTensorData {
static func floats(_ array: [Float], shape: [Int]? = nil, device: MTLDevice) -> MPSGraphTensorData {
.init(
let shape = shape ?? [array.count]

assert(shape.reduce(1, *) == array.count)

return .init(
device: .init(mtlDevice: device),
data: array.rawData,
shape: (shape ?? [array.count]).nsnumbers,
shape: shape.nsnumbers,
dataType: .float32
)
}
Expand Down
2 changes: 1 addition & 1 deletion Sources/MPSXTests/FoundationTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ final class FoundationTests: XCTestCase {
}
}

#if !arch(x86_64)
#if arch(arm64)
XCTAssertTrue(_test(body: { (x: [Float]) in FPC._Float32_Float16(x) }, compare: { abs($0 - $1) < .ulpOfOne }))
XCTAssertTrue(_test(body: { (x: [Float16]) in FPC._Float16_Float32(x) }, compare: { abs($0 - $1) < Float.ulpOfOne }))
#endif
Expand Down

0 comments on commit ab097be

Please sign in to comment.