diff --git a/FlyingSocks/Sources/Task+Timeout.swift b/FlyingSocks/Sources/Task+Timeout.swift index bb869e9d..da9d9754 100644 --- a/FlyingSocks/Sources/Task+Timeout.swift +++ b/FlyingSocks/Sources/Task+Timeout.swift @@ -1,6 +1,6 @@ // // TaskTimeout.swift -// TaskTimeout +// swift-timeout // // Created by Simon Whitty on 31/08/2024. // Copyright 2024 Simon Whitty @@ -8,7 +8,7 @@ // Distributed under the permissive MIT license // Get the latest version from here: // -// https://github.com/swhitty/TaskTimeout +// https://github.com/swhitty/swift-timeout // // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal @@ -31,11 +31,11 @@ import Foundation -package struct TimeoutError: LocalizedError { - package var errorDescription: String? +public struct TimeoutError: LocalizedError { + public var errorDescription: String? - package init(timeout: TimeInterval) { - self.errorDescription = "Task timed out before completion. Timeout: \(timeout) seconds." + init(_ description: String) { + self.errorDescription = description } } @@ -45,34 +45,31 @@ package func withThrowingTimeout( seconds: TimeInterval, body: () async throws -> sending T ) async throws -> sending T { - let transferringBody = { try await Transferring(body()) } - typealias NonSendableClosure = () async throws -> Transferring - typealias SendableClosure = @Sendable () async throws -> Transferring - return try await withoutActuallyEscaping(transferringBody) { - (_ fn: @escaping NonSendableClosure) async throws -> Transferring in - let sendableFn = unsafeBitCast(fn, to: SendableClosure.self) - return try await _withThrowingTimeout(isolation: isolation, seconds: seconds, body: sendableFn) - }.value -} - -// Sendable -private func _withThrowingTimeout( - isolation: isolated (any Actor)? = #isolation, - seconds: TimeInterval, - body: @Sendable @escaping () async throws -> T -) async throws -> T { - try await withThrowingTaskGroup(of: T.self, isolation: isolation) { group in - group.addTask { - try await body() + try await withoutActuallyEscaping(body) { escapingBody in + let bodyTask = Task { + defer { _ = isolation } + return try await Transferring(escapingBody()) } - group.addTask { + let timeoutTask = Task { + defer { bodyTask.cancel() } try await Task.sleep(nanoseconds: UInt64(seconds * 1_000_000_000)) - throw TimeoutError(timeout: seconds) + throw TimeoutError("Task timed out before completion. Timeout: \(seconds) seconds.") } - let success = try await group.next()! - group.cancelAll() - return success - } + + let bodyResult = await withTaskCancellationHandler { + await bodyTask.result + } onCancel: { + bodyTask.cancel() + } + timeoutTask.cancel() + + if case .failure(let timeoutError) = await timeoutTask.result, + timeoutError is TimeoutError { + throw timeoutError + } else { + return try bodyResult.get() + } + }.value } #else package func withThrowingTimeout( @@ -100,7 +97,7 @@ private func _withThrowingTimeout( } group.addTask { try await Task.sleep(nanoseconds: UInt64(seconds * 1_000_000_000)) - throw TimeoutError(timeout: seconds) + throw TimeoutError("Task timed out before completion. Timeout: \(seconds) seconds.") } let success = try await group.next()! group.cancelAll() @@ -132,26 +129,13 @@ package extension Task { } case .afterTimeout(let seconds): if seconds > 0 { - return try await getValue(cancellingAfter: seconds) + return try await withThrowingTimeout(seconds: seconds) { + try await getValue(cancelling: .whenParentIsCancelled) + } } else { cancel() return try await value } } } - - private func getValue(cancellingAfter seconds: TimeInterval) async throws -> Success { - try await withThrowingTaskGroup(of: Void.self) { group in - group.addTask { - _ = try await getValue(cancelling: .whenParentIsCancelled) - } - group.addTask { - try await Task.sleep(nanoseconds: UInt64(seconds * 1_000_000_000)) - throw TimeoutError(timeout: seconds) - } - _ = try await group.next()! - group.cancelAll() - return try await value - } - } } diff --git a/FlyingSocks/Tests/SocketTests.swift b/FlyingSocks/Tests/SocketTests.swift index 766d93be..2cbe27cc 100644 --- a/FlyingSocks/Tests/SocketTests.swift +++ b/FlyingSocks/Tests/SocketTests.swift @@ -76,7 +76,7 @@ struct SocketTests { try s1.close() try s2.close() - #expect(throws: SocketError.disconnected) { + #expect(throws: (any Error).self) { try s1.read() } } diff --git a/FlyingSocks/Tests/Task+TimeoutTests.swift b/FlyingSocks/Tests/Task+TimeoutTests.swift index aa09ca54..a9fe00e5 100644 --- a/FlyingSocks/Tests/Task+TimeoutTests.swift +++ b/FlyingSocks/Tests/Task+TimeoutTests.swift @@ -50,7 +50,7 @@ struct TaskTimeoutTests { func timeoutThrowsError_WhenTimeoutExpires() async { // given let task = Task(timeout: 0.01) { - try? await Task.sleep(seconds: 10) + try await Task.sleep(seconds: 10) } // then @@ -141,30 +141,26 @@ struct TaskTimeoutTests { ) } - @MainActor - @Test + @Test @MainActor func mainActor_ReturnsValue() async throws { let val = try await withThrowingTimeout(seconds: 1) { - #if compiler(>=5.10) MainActor.assertIsolated() - #endif + try await Task.sleep(nanoseconds: 1_000) + MainActor.assertIsolated() return "Fish" } #expect(val == "Fish") } @Test - func mainActorThrowsError_WhenTimeoutExpires() async throws { - let task = Task { @MainActor in + func mainActorThrowsError_WhenTimeoutExpires() async { + await #expect(throws: TimeoutError.self) { @MainActor in try await withThrowingTimeout(seconds: 0.05) { MainActor.assertIsolated() - try? await Task.sleep(nanoseconds: 60_000_000_000) + defer { MainActor.assertIsolated() } + try await Task.sleep(nanoseconds: 60_000_000_000) } } - - await #expect(throws: TimeoutError.self) { - try await task.value - } } @Test @@ -186,17 +182,32 @@ struct TaskTimeoutTests { @Test func actor_ReturnsValue() async throws { - let val = try await TestActor().returningString("Fish") - #expect(val == "Fish") + #expect( + try await TestActor("Fish").returningValue() == "Fish" + ) } @Test func actorThrowsError_WhenTimeoutExpires() async { await #expect(throws: TimeoutError.self) { - _ = try await TestActor().returningString( - after: 60, - timeout: 0.05 - ) + try await withThrowingTimeout(seconds: 0.05) { + try await TestActor().returningValue(after: 60, timeout: 0.05) + } + } + } + + @Test + func timeout_cancels() async { + let task = Task { + try await withThrowingTimeout(seconds: 1) { + try await Task.sleep(nanoseconds: 1_000_000_000) + } + } + + task.cancel() + + await #expect(throws: CancellationError.self) { + try await task.value } } } @@ -206,9 +217,15 @@ extension Task where Success: Sendable, Failure == any Error { // Start a new Task with a timeout. init(priority: TaskPriority? = nil, timeout: TimeInterval, operation: @escaping @Sendable () async throws -> Success) { self = Task(priority: priority) { - try await withThrowingTimeout(seconds: timeout) { - try await operation() + do { + return try await withThrowingTimeout(seconds: timeout) { + try await operation() + } + } catch { + print(error) + throw error } + } } } @@ -227,19 +244,23 @@ public struct NonSendable { } } -private final actor TestActor { +private final actor TestActor { + + private var value: T - func returningString(_ string: String = "Fish", after sleep: TimeInterval = 0, timeout: TimeInterval = 1) async throws -> String { - try await returningValue(string, after: sleep, timeout: timeout) + init(_ value: T) { + self.value = value } - func returningValue(_ value: T, after sleep: TimeInterval = 0, timeout: TimeInterval = 1) async throws -> T { + init() where T == String { + self.init("fish") + } + + func returningValue(after sleep: TimeInterval = 0, timeout: TimeInterval = 1) async throws -> T { try await withThrowingTimeout(seconds: timeout) { - if #available(macOS 14.0, iOS 17.0, watchOS 10.0, tvOS 17.0, *) { - assertIsolated() - } - try await Task.sleep(seconds: sleep) - return value + try await Task.sleep(nanoseconds: UInt64(sleep * 1_000_000_000)) + self.assertIsolated() + return self.value } } } diff --git a/FlyingSocks/XCTests/Task+TimeoutTests.swift b/FlyingSocks/XCTests/Task+TimeoutTests.swift index 438381d6..faf19ed3 100644 --- a/FlyingSocks/XCTests/Task+TimeoutTests.swift +++ b/FlyingSocks/XCTests/Task+TimeoutTests.swift @@ -47,7 +47,7 @@ final class TaskTimeoutTests: XCTestCase { func testTimeoutThrowsError_WhenTimeoutExpires() async { // given let task = Task(timeout: 0.5) { - try? await Task.sleep(seconds: 10) + try await Task.sleep(seconds: 10) } // then @@ -146,9 +146,9 @@ final class TaskTimeoutTests: XCTestCase { @MainActor func testMainActor_ReturnsValue() async throws { let val = try await withThrowingTimeout(seconds: 1) { - #if compiler(>=5.10) - MainActor.assertIsolated() - #endif + MainActor.safeAssertIsolated() + try await Task.sleep(nanoseconds: 1_000) + MainActor.safeAssertIsolated() return "Fish" } XCTAssertEqual(val, "Fish") @@ -158,10 +158,9 @@ final class TaskTimeoutTests: XCTestCase { func testMainActorThrowsError_WhenTimeoutExpires() async { do { try await withThrowingTimeout(seconds: 0.05) { - #if compiler(>=5.10) - MainActor.assertIsolated() - #endif - try? await Task.sleep(nanoseconds: 60_000_000_000) + MainActor.safeAssertIsolated() + defer { MainActor.safeAssertIsolated() } + try await Task.sleep(nanoseconds: 60_000_000_000) } XCTFail("Expected Error") } catch { @@ -185,13 +184,13 @@ final class TaskTimeoutTests: XCTestCase { } func testActor_ReturnsValue() async throws { - let val = try await TestActor().returningString("Fish") + let val = try await TestActor("Fish").returningValue() XCTAssertEqual(val, "Fish") } func testActorThrowsError_WhenTimeoutExpires() async { do { - _ = try await TestActor().returningString( + _ = try await TestActor().returningValue( after: 60, timeout: 0.05 ) @@ -200,6 +199,23 @@ final class TaskTimeoutTests: XCTestCase { XCTAssertTrue(error is TimeoutError) } } + + func testTimeout_Cancels() async { + let task = Task { + try await withThrowingTimeout(seconds: 1) { + try await Task.sleep(nanoseconds: 1_000_000_000) + } + } + + task.cancel() + + do { + _ = try await task.value + XCTFail("Expected Error") + } catch { + XCTAssertTrue(error is CancellationError) + } + } } extension Task where Success: Sendable, Failure == any Error { @@ -228,19 +244,36 @@ public struct NonSendable { } } -private final actor TestActor { +private final actor TestActor { + + private var value: T + + init(_ value: T) { + self.value = value + } - func returningString(_ string: String = "Fish", after sleep: TimeInterval = 0, timeout: TimeInterval = 1) async throws -> String { - try await returningValue(string, after: sleep, timeout: timeout) + init() where T == String { + self.init("fish") } - func returningValue(_ value: T, after sleep: TimeInterval = 0, timeout: TimeInterval = 1) async throws -> T { + func returningValue(after sleep: TimeInterval = 0, timeout: TimeInterval = 1) async throws -> T { try await withThrowingTimeout(seconds: timeout) { - if #available(macOS 14.0, iOS 17.0, watchOS 10.0, tvOS 17.0, *) { - assertIsolated() - } - try await Task.sleep(seconds: sleep) - return value + try await Task.sleep(nanoseconds: UInt64(sleep * 1_000_000_000)) + #if compiler(>=5.10) + self.assertIsolated() + #endif + return self.value } } } + +private extension MainActor { + + static func safeAssertIsolated() { + #if compiler(>=5.10) + assertIsolated() + #else + precondition(Thread.isMainThread) + #endif + } +}