diff --git a/Sources/AsyncAlgorithms/Merge/MergeStateMachine.swift b/Sources/AsyncAlgorithms/Merge/MergeStateMachine.swift index bb832ada..5c6fd8fe 100644 --- a/Sources/AsyncAlgorithms/Merge/MergeStateMachine.swift +++ b/Sources/AsyncAlgorithms/Merge/MergeStateMachine.swift @@ -41,7 +41,8 @@ struct MergeStateMachine< buffer: Deque, upstreamContinuations: [UnsafeContinuation], upstreamsFinished: Int, - downstreamContinuation: UnsafeContinuation? + downstreamContinuation: UnsafeContinuation?, + cancelled: Bool ) /// The state once any of the upstream sequences threw an `Error`. @@ -100,11 +101,11 @@ struct MergeStateMachine< // Nothing to do here. No demand was signalled until now return .none - case .merging(_, _, _, _, .some): + case .merging(_, _, _, _, .some, _): // An iterator was deinitialized while we have a suspended continuation. preconditionFailure("Internal inconsistency current state \(self.state) and received iteratorDeinitialized()") - case let .merging(task, _, upstreamContinuations, _, .none): + case let .merging(task, _, upstreamContinuations, _, .none, _): // The iterator was dropped which signals that the consumer is finished. // We can transition to finished now and need to clean everything up. state = .finished @@ -142,7 +143,8 @@ struct MergeStateMachine< buffer: .init(), upstreamContinuations: [], // This should reserve capacity in the variadic generics case upstreamsFinished: 0, - downstreamContinuation: nil + downstreamContinuation: nil, + cancelled: false ) case .merging, .upstreamFailure, .finished: @@ -175,11 +177,11 @@ struct MergeStateMachine< // Child tasks are only created after we transitioned to `merging` preconditionFailure("Internal inconsistency current state \(self.state) and received childTaskSuspended()") - case .merging(_, _, _, _, .some): + case .merging(_, _, _, _, .some, _): // We have outstanding demand so request the next element return .resumeContinuation(upstreamContinuation: continuation) - case .merging(let task, let buffer, var upstreamContinuations, let upstreamsFinished, .none): + case .merging(let task, let buffer, var upstreamContinuations, let upstreamsFinished, .none, let cancelled): // There is no outstanding demand from the downstream // so we are storing the continuation and resume it once there is demand. state = .modifying @@ -191,7 +193,8 @@ struct MergeStateMachine< buffer: buffer, upstreamContinuations: upstreamContinuations, upstreamsFinished: upstreamsFinished, - downstreamContinuation: nil + downstreamContinuation: nil, + cancelled: cancelled ) return .none @@ -236,7 +239,7 @@ struct MergeStateMachine< // Child tasks that are producing elements are only created after we transitioned to `merging` preconditionFailure("Internal inconsistency current state \(self.state) and received elementProduced()") - case let .merging(task, buffer, upstreamContinuations, upstreamsFinished, .some(downstreamContinuation)): + case let .merging(task, buffer, upstreamContinuations, upstreamsFinished, .some(downstreamContinuation), cancelled): // We produced an element and have an outstanding downstream continuation // this means we can go right ahead and resume the continuation with that element precondition(buffer.isEmpty, "We are holding a continuation so the buffer must be empty") @@ -246,7 +249,8 @@ struct MergeStateMachine< buffer: buffer, upstreamContinuations: upstreamContinuations, upstreamsFinished: upstreamsFinished, - downstreamContinuation: nil + downstreamContinuation: nil, + cancelled: cancelled ) return .resumeContinuation( @@ -254,7 +258,7 @@ struct MergeStateMachine< element: element ) - case .merging(let task, var buffer, let upstreamContinuations, let upstreamsFinished, .none): + case .merging(let task, var buffer, let upstreamContinuations, let upstreamsFinished, .none, let cancelled): // There is not outstanding downstream continuation so we must buffer the element // This happens if we race our upstream sequences to produce elements // and the _losers_ are signalling their produced element @@ -267,7 +271,8 @@ struct MergeStateMachine< buffer: buffer, upstreamContinuations: upstreamContinuations, upstreamsFinished: upstreamsFinished, - downstreamContinuation: nil + downstreamContinuation: nil, + cancelled: cancelled ) return .none @@ -310,7 +315,7 @@ struct MergeStateMachine< case .initial: preconditionFailure("Internal inconsistency current state \(self.state) and received upstreamFinished()") - case .merging(let task, let buffer, let upstreamContinuations, var upstreamsFinished, let .some(downstreamContinuation)): + case .merging(let task, let buffer, let upstreamContinuations, var upstreamsFinished, let .some(downstreamContinuation), let cancelled): // One of the upstreams finished precondition(buffer.isEmpty, "We are holding a continuation so the buffer must be empty") @@ -335,13 +340,14 @@ struct MergeStateMachine< buffer: buffer, upstreamContinuations: upstreamContinuations, upstreamsFinished: upstreamsFinished, - downstreamContinuation: downstreamContinuation + downstreamContinuation: downstreamContinuation, + cancelled: cancelled ) return .none } - case .merging(let task, let buffer, let upstreamContinuations, var upstreamsFinished, .none): + case .merging(let task, let buffer, let upstreamContinuations, var upstreamsFinished, .none, let cancelled): // First we increment our counter of finished upstreams upstreamsFinished += 1 @@ -350,7 +356,8 @@ struct MergeStateMachine< buffer: buffer, upstreamContinuations: upstreamContinuations, upstreamsFinished: upstreamsFinished, - downstreamContinuation: nil + downstreamContinuation: nil, + cancelled: cancelled ) if upstreamsFinished == self.numberOfUpstreamSequences { @@ -402,7 +409,7 @@ struct MergeStateMachine< case .initial: preconditionFailure("Internal inconsistency current state \(self.state) and received upstreamThrew()") - case let .merging(task, buffer, upstreamContinuations, _, .some(downstreamContinuation)): + case let .merging(task, buffer, upstreamContinuations, _, .some(downstreamContinuation), _): // An upstream threw an error and we have a downstream continuation. // We just need to resume the downstream continuation with the error and cancel everything precondition(buffer.isEmpty, "We are holding a continuation so the buffer must be empty") @@ -417,7 +424,7 @@ struct MergeStateMachine< upstreamContinuations: upstreamContinuations ) - case let .merging(task, buffer, upstreamContinuations, _, .none): + case let .merging(task, buffer, upstreamContinuations, _, .none, _): // An upstream threw an error and we don't have a downstream continuation. // We need to store the error and wait for the downstream to consume the // rest of the buffer and the error. However, we can already cancel the task @@ -454,10 +461,7 @@ struct MergeStateMachine< upstreamContinuations: [UnsafeContinuation] ) /// Indicates that the task and the upstream continuations should be cancelled. - case cancelTaskAndUpstreamContinuations( - task: Task, - upstreamContinuations: [UnsafeContinuation] - ) + case cancelTask(Task) /// Indicates that nothing should be done. case none } @@ -471,26 +475,21 @@ struct MergeStateMachine< return .none - case let .merging(task, _, upstreamContinuations, _, .some(downstreamContinuation)): - // The downstream Task got cancelled so we need to cancel our upstream Task - // and resume all continuations. We can also transition to finished. - state = .finished + case let .merging(task, buffer, upstreamContinuations, upstreamFinished, downstreamContinuation, cancelled): + guard !cancelled else { + return .none + } - return .resumeDownstreamContinuationWithNilAndCancelTaskAndUpstreamContinuations( - downstreamContinuation: downstreamContinuation, + self.state = .merging( task: task, - upstreamContinuations: upstreamContinuations + buffer: buffer, + upstreamContinuations: upstreamContinuations, + upstreamsFinished: upstreamFinished, + downstreamContinuation: downstreamContinuation, + cancelled: true ) - case let .merging(task, _, upstreamContinuations, _, .none): - // The downstream Task got cancelled so we need to cancel our upstream Task - // and resume all continuations. We can also transition to finished. - state = .finished - - return .cancelTaskAndUpstreamContinuations( - task: task, - upstreamContinuations: upstreamContinuations - ) + return .cancelTask(task) case .upstreamFailure: // An upstream already threw and we cancelled everything already. @@ -531,11 +530,11 @@ struct MergeStateMachine< // We are transitioning to merging in the taskStarted method. return .startTaskAndSuspendDownstreamTask(base1, base2, base3) - case .merging(_, _, _, _, .some): + case .merging(_, _, _, _, .some, _): // We have multiple AsyncIterators iterating the sequence preconditionFailure("Internal inconsistency current state \(self.state) and received next()") - case .merging(let task, var buffer, let upstreamContinuations, let upstreamsFinished, .none): + case .merging(let task, var buffer, let upstreamContinuations, let upstreamsFinished, .none, let cancelled): state = .modifying if let element = buffer.popFirst() { @@ -545,7 +544,8 @@ struct MergeStateMachine< buffer: buffer, upstreamContinuations: upstreamContinuations, upstreamsFinished: upstreamsFinished, - downstreamContinuation: nil + downstreamContinuation: nil, + cancelled: cancelled ) return .returnElement(.success(element)) @@ -556,7 +556,8 @@ struct MergeStateMachine< buffer: buffer, upstreamContinuations: upstreamContinuations, upstreamsFinished: upstreamsFinished, - downstreamContinuation: nil + downstreamContinuation: nil, + cancelled: cancelled ) return .suspendDownstreamTask @@ -601,21 +602,22 @@ struct MergeStateMachine< mutating func next(for continuation: UnsafeContinuation) -> NextForAction { switch state { case .initial, - .merging(_, _, _, _, .some), + .merging(_, _, _, _, .some, _), .upstreamFailure, .finished: // All other states are handled by `next` already so we should never get in here with // any of those preconditionFailure("Internal inconsistency current state \(self.state) and received next(for:)") - case let .merging(task, buffer, upstreamContinuations, upstreamsFinished, .none): + case let .merging(task, buffer, upstreamContinuations, upstreamsFinished, .none, cancelled): // We suspended the task and need signal the upstreams state = .merging( task: task, buffer: buffer, upstreamContinuations: [], // TODO: don't alloc new array here upstreamsFinished: upstreamsFinished, - downstreamContinuation: continuation + downstreamContinuation: continuation, + cancelled: cancelled ) return .resumeUpstreamContinuations( diff --git a/Sources/AsyncAlgorithms/Merge/MergeStorage.swift b/Sources/AsyncAlgorithms/Merge/MergeStorage.swift index 9dedee76..443c95cd 100644 --- a/Sources/AsyncAlgorithms/Merge/MergeStorage.swift +++ b/Sources/AsyncAlgorithms/Merge/MergeStorage.swift @@ -128,12 +128,7 @@ final class MergeStorage< downstreamContinuation.resume(returning: nil) - case let .cancelTaskAndUpstreamContinuations( - task, - upstreamContinuations - ): - upstreamContinuations.forEach { $0.resume(throwing: CancellationError()) } - + case let .cancelTask(task): task.cancel() case .none: @@ -262,8 +257,8 @@ final class MergeStorage< task, upstreamContinuations ): - upstreamContinuations.forEach { $0.resume(throwing: CancellationError()) } task.cancel() + upstreamContinuations.forEach { $0.resume() } downstreamContinuation.resume(returning: nil) @@ -273,8 +268,8 @@ final class MergeStorage< task, upstreamContinuations ): - upstreamContinuations.forEach { $0.resume(throwing: CancellationError()) } task.cancel() + upstreamContinuations.forEach { $0.resume() } break loop case .none: diff --git a/Tests/AsyncAlgorithmsTests/TestMerge.swift b/Tests/AsyncAlgorithmsTests/TestMerge.swift index c8d5e1ce..293e74a8 100644 --- a/Tests/AsyncAlgorithmsTests/TestMerge.swift +++ b/Tests/AsyncAlgorithmsTests/TestMerge.swift @@ -201,6 +201,38 @@ final class TestMerge2: XCTestCase { } t.cancel() } + + func testAsyncStreamElementsThatAreInjectedOnCancellationAreDelivered() async { + let (stream1, continuation1) = AsyncStream.makeStream(of: Int.self) + let (stream2, continuation2) = AsyncStream.makeStream(of: Int.self) + continuation1.onTermination = { reason in + XCTAssertEqual(reason, .cancelled) + continuation1.yield(1) + } + continuation2.onTermination = { reason in + XCTAssertEqual(reason, .cancelled) + continuation2.yield(2) + } + continuation1.yield(0) // initial + let merge = merge(stream1, stream2) + let finished = expectation(description: "finished") + let iterated = expectation(description: "iterated") + let task = Task { + var count = 0 + for await _ in merge { + if count == 0 { iterated.fulfill() } + count += 1 + } + finished.fulfill() + XCTAssertEqual(count, 3) + } + // ensure the other task actually starts + await fulfillment(of: [iterated], timeout: 1.0) + // cancellation should ensure the loop finishes + // without regards to the remaining underlying sequence + task.cancel() + await fulfillment(of: [finished], timeout: 1.0) + } } final class TestMerge3: XCTestCase { @@ -555,4 +587,41 @@ final class TestMerge3: XCTestCase { iterator = nil } + + func testAsyncStreamElementsThatAreInjectedOnCancellationAreDelivered() async { + let (stream1, continuation1) = AsyncStream.makeStream(of: Int.self) + let (stream2, continuation2) = AsyncStream.makeStream(of: Int.self) + let (stream3, continuation3) = AsyncStream.makeStream(of: Int.self) + continuation1.onTermination = { reason in + XCTAssertEqual(reason, .cancelled) + continuation1.yield(1) + } + continuation2.onTermination = { reason in + XCTAssertEqual(reason, .cancelled) + continuation2.yield(2) + } + continuation3.onTermination = { reason in + XCTAssertEqual(reason, .cancelled) + continuation3.yield(3) + } + continuation1.yield(0) // initial + let merge = merge(stream1, stream2, stream3) + let finished = expectation(description: "finished") + let iterated = expectation(description: "iterated") + let task = Task { + var count = 0 + for await _ in merge { + if count == 0 { iterated.fulfill() } + count += 1 + } + finished.fulfill() + XCTAssertEqual(count, 4) + } + // ensure the other task actually starts + await fulfillment(of: [iterated], timeout: 1.0) + // cancellation should ensure the loop finishes + // without regards to the remaining underlying sequence + task.cancel() + await fulfillment(of: [finished], timeout: 1.0) + } }