Skip to content

Implement AsyncCombineLatestMultipleSequence #322

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 5 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,145 @@
//===----------------------------------------------------------------------===//
//
// This source file is part of the Swift.org open source project
//
// Copyright (c) 2024 Apple Inc. and the Swift project authors
// Licensed under Apache License v2.0 with Runtime Library Exception
//
// See https://swift.org/LICENSE.txt for license information
// See https://swift.org/CONTRIBUTORS.txt for the list of Swift project authors
//
//===----------------------------------------------------------------------===//

// MARK: - Public interface

/// Creates an asynchronous sequence that combines the latest values from multiple ``AsyncSequence`` with the same element type
/// by emitting an array of the values.
///
/// The new asynchronous sequence only emits a value whenever any of the base ``AsyncSequence``s
/// emit a value (so long as each of the bases have emitted at least one value).
///
/// - Important: It finishes when one of the bases finishes before emitting any value or when all bases finished.
///
/// - Throws: It throws when one of the bases throws.
///
/// - Note: This function requires the return type to be the same for all ``AsyncSequence``.
public func combineLatest<Sequence: AsyncSequence, ElementOfResult: Sendable>(_ sequences: [Sequence]) -> AsyncThrowingStream<[ElementOfResult], Error> where Sequence.Element == ElementOfResult, Sequence: Sendable {
AsyncCombineLatestMultipleSequence(sequences: sequences).stream
}

/// Creates an asynchronous sequence that combines the latest values from multiple ``AsyncSequence`` with the same element type
/// by emitting an array of the values.
///
/// The new asynchronous sequence only emits a value whenever any of the base ``AsyncSequence``s
/// emit a value (so long as each of the bases have emitted at least one value).
///
/// - Important: It finishes when one of the bases finishes before emitting any value or when all bases finished.
///
/// - Throws: It throws when one of the bases throws.
///
/// - Note: This function requires the return type to be the same for all ``AsyncSequence``.
public func combineLatest<Sequence: AsyncSequence, ElementOfResult: Sendable>(_ sequences: Sequence...) -> AsyncThrowingStream<[ElementOfResult], Error> where Sequence.Element == ElementOfResult, Sequence: Sendable {
AsyncCombineLatestMultipleSequence(sequences: sequences).stream
}

// MARK: - Private helpers

fileprivate final class AsyncCombineLatestMultipleSequence<Sequence: AsyncSequence, ElementOfResult: Sendable>: Sendable where Sequence.Element == ElementOfResult, Sequence: Sendable {

private let results: ManagedCriticalState<[State]>
private let continuation: AsyncThrowingStream<[ElementOfResult], Error>.Continuation

fileprivate let stream: AsyncThrowingStream<[ElementOfResult], Error>

fileprivate init(sequences: [Sequence]) {
self.results = .init(
Array(
repeating: State.initial,
count: sequences.count
)
)

(self.stream, self.continuation) = AsyncThrowingStream<[ElementOfResult], Error>.makeStream()

let task = Task {
await withTaskGroup(of: Void.self) { group in
for (index, sequence) in sequences.enumerated() {
group.addTask {
do {
var lastKnownValue: ElementOfResult?
for try await value in sequence {
self.set(state: .succeeded(value), at: index)
lastKnownValue = value
}
self.set(state: .finished(lastKnownValue: lastKnownValue), at: index)
} catch {
self.set(state: .failed(error), at: index)
}
}
}
}
}

continuation.onTermination = { _ in
task.cancel()
}
}
}

private extension AsyncCombineLatestMultipleSequence {

func set(state: State, at index: Int) {
results.withCriticalRegion { array in
array[index] = state

var allFinished = true
var latestResults: [ElementOfResult] = []
latestResults.reserveCapacity(array.count)

for state in array {
switch state {
case .initial:
// Only emit updates when all have value.
return

case .succeeded(let elementOfResult):
latestResults.append(elementOfResult)
allFinished = false

case .failed(let error):
continuation.finish(throwing: error)
return

case .finished(let lastKnownValue):
if let lastKnownValue {
latestResults.append(lastKnownValue)
} else {
// If `lastKnownValue` is nil,
// that means the async sequence finished before emitting any value.
// And we'll never be able to complete the entire array.
continuation.finish()
return
}
}
}

if allFinished {
continuation.finish()
} else if case .succeeded = state {
continuation.yield(latestResults)
}
}
}
}

// MARK: - Type definitions

private extension AsyncCombineLatestMultipleSequence {

enum State {
case initial
case succeeded(ElementOfResult)
case failed(Error)
case finished(lastKnownValue: ElementOfResult?)
}
}
83 changes: 83 additions & 0 deletions Tests/AsyncAlgorithmsTests/TestCombineLatest.swift
Original file line number Diff line number Diff line change
Expand Up @@ -394,3 +394,86 @@ final class TestCombineLatest3: XCTestCase {
XCTAssertEqual(value, [(1, "a", 4), (2, "a", 4), (2, "b", 4), (2, "b", 5), (3, "b", 5), (3, "c", 5), (3, "c", 6)])
}
}

final class TestCombineLatestMultiple: XCTestCase {

func test_CorrectOrdering() async throws {
let s1 = AsyncStream {
$0.yield(1)
$0.finish()
}

let s2 = AsyncStream { continuation in
Task {
continuation.yield(2)
try? await Task.sleep(nanoseconds: 20_000_000)
continuation.yield(5)
continuation.finish()
}
}

let s3 = AsyncStream { continuation in
Task {
continuation.yield(3)
try? await Task.sleep(nanoseconds: 10_000_000)
continuation.yield(4)
continuation.finish()
}
}

let s4 = AsyncStream {
$0.yield(0)
$0.finish()
}

let expectedResult = [
[1, 2, 3, 0],
[1, 2, 4, 0],
[1, 5, 4, 0]
]
var expectedResultIterator = expectedResult.makeIterator()

for try await result in combineLatest(s1, s2, s3, s4) {
XCTAssertEqual(result, expectedResultIterator.next())
}
}

func test_EarlyReturn() async throws {
let s1 = AsyncStream {
$0.yield(1)
$0.finish()
}

let s2 = AsyncStream<Int> {
$0.finish()
}

for try await _ in combineLatest(s1, s2, s1, s1) {
XCTFail("`combineLatest` shouldn't return any value as s2 never emits any value.")
}
}

func test_ThrowingPropagation() async throws {
let s1 = AsyncThrowingStream<Int, Error> {
$0.yield(1)
$0.finish()
}

let s2 = AsyncThrowingStream<Int, Error> {
$0.finish(throwing: Some.error)
}

do {
for try await _ in combineLatest(s1, s2, s1, s1) {
XCTFail("Expects error to be thrown")
}
} catch {
let error = try XCTUnwrap(error as? Some)
XCTAssertEqual(error, .error)
}
}

private enum Some: Error {
case error
}
}