From ab097bec80acdb78afdde376eb7c9b6ddbbce6bb Mon Sep 17 00:00:00 2001 From: Geor Kasapidi Date: Fri, 24 Feb 2023 14:03:10 +0200 Subject: [PATCH] micro update --- Sources/MPSX/Foundation/DSL.swift | 2 +- Sources/MPSX/Foundation/FPC.swift | 2 +- Sources/MPSX/Foundation/Utilities.swift | 17 +++++++++++++++-- Sources/MPSXTests/FoundationTests.swift | 2 +- 4 files changed, 18 insertions(+), 5 deletions(-) diff --git a/Sources/MPSX/Foundation/DSL.swift b/Sources/MPSX/Foundation/DSL.swift index e5699e8..4938610 100644 --- a/Sources/MPSX/Foundation/DSL.swift +++ b/Sources/MPSX/Foundation/DSL.swift @@ -6,7 +6,7 @@ public extension MPSGraph { constant(vector.rawData, shape: (shape ?? [vector.count]).nsnumbers, dataType: .float32) } - #if !arch(x86_64) + #if arch(arm64) @inlinable func const(_ vector: [Swift.Float16], shape: [Int]? = nil) -> MPSGraphTensor { constant(vector.rawData, shape: (shape ?? [vector.count]).nsnumbers, dataType: .float16) diff --git a/Sources/MPSX/Foundation/FPC.swift b/Sources/MPSX/Foundation/FPC.swift index 892fa4d..85aafbe 100644 --- a/Sources/MPSX/Foundation/FPC.swift +++ b/Sources/MPSX/Foundation/FPC.swift @@ -1,7 +1,7 @@ import Accelerate import Foundation -#if arch(x86_64) +#if !arch(arm64) typealias Float16 = UInt16 #endif diff --git a/Sources/MPSX/Foundation/Utilities.swift b/Sources/MPSX/Foundation/Utilities.swift index 790c5b7..c97338f 100644 --- a/Sources/MPSX/Foundation/Utilities.swift +++ b/Sources/MPSX/Foundation/Utilities.swift @@ -42,6 +42,15 @@ extension Array { } } +extension ArraySlice { + @usableFromInline + var rawData: Data { + withUnsafeBufferPointer { + Data(buffer: $0) // copy + } + } +} + extension Data { func mapMemory(of _: T.Type, _ body: (UnsafeBufferPointer) throws -> R) rethrows -> R { try withUnsafeBytes { @@ -182,10 +191,14 @@ private extension MPSGraphTensorData { public extension MPSGraphTensorData { static func floats(_ array: [Float], shape: [Int]? = nil, device: MTLDevice) -> MPSGraphTensorData { - .init( + let shape = shape ?? [array.count] + + assert(shape.reduce(1, *) == array.count) + + return .init( device: .init(mtlDevice: device), data: array.rawData, - shape: (shape ?? [array.count]).nsnumbers, + shape: shape.nsnumbers, dataType: .float32 ) } diff --git a/Sources/MPSXTests/FoundationTests.swift b/Sources/MPSXTests/FoundationTests.swift index 7c39885..6708e5a 100644 --- a/Sources/MPSXTests/FoundationTests.swift +++ b/Sources/MPSXTests/FoundationTests.swift @@ -30,7 +30,7 @@ final class FoundationTests: XCTestCase { } } - #if !arch(x86_64) + #if arch(arm64) XCTAssertTrue(_test(body: { (x: [Float]) in FPC._Float32_Float16(x) }, compare: { abs($0 - $1) < .ulpOfOne })) XCTAssertTrue(_test(body: { (x: [Float16]) in FPC._Float16_Float32(x) }, compare: { abs($0 - $1) < Float.ulpOfOne })) #endif