diff --git a/Sources/AsyncAlgorithms/AsyncBroadcastSequence.swift b/Sources/AsyncAlgorithms/AsyncBroadcastSequence.swift new file mode 100644 index 00000000..beb6def8 --- /dev/null +++ b/Sources/AsyncAlgorithms/AsyncBroadcastSequence.swift @@ -0,0 +1,246 @@ +//===----------------------------------------------------------------------===// +// +// This source file is part of the Swift Async Algorithms open source project +// +// Copyright (c) 2022 Apple Inc. and the Swift project authors +// Licensed under Apache License v2.0 with Runtime Library Exception +// +// See https://swift.org/LICENSE.txt for license information +// +//===----------------------------------------------------------------------===// + +import DequeModule + +extension AsyncSequence where Self: Sendable, Element: Sendable { + public func broadcast() -> AsyncBroadcastSequence { + AsyncBroadcastSequence(self) + } +} + +public struct AsyncBroadcastSequence: Sendable where Base: Sendable, Base.Element: Sendable { + struct State : Sendable { + enum Terminal { + case failure(Error) + case finished + } + + struct Side { + var buffer = Deque() + var terminal: Terminal? + var continuation: UnsafeContinuation, Never>? + + mutating func drain() { + if !buffer.isEmpty, let continuation { + let element = buffer.removeFirst() + continuation.resume(returning: .success(element)) + self.continuation = nil + } else if let terminal, let continuation { + switch terminal { + case .failure(let error): + self.terminal = .finished + continuation.resume(returning: .failure(error)) + case .finished: + continuation.resume(returning: .success(nil)) + } + self.continuation = nil + } + } + + mutating func cancel() { + buffer.removeAll() + terminal = .finished + drain() + } + + mutating func next(_ continuation: UnsafeContinuation, Never>) { + assert(self.continuation == nil) // presume that the sides are NOT sendable iterators... + self.continuation = continuation + drain() + } + + mutating func emit(_ result: Result) { + switch result { + case .success(let element): + if let element { + buffer.append(element) + } else { + terminal = .finished + } + case .failure(let error): + terminal = .failure(error) + } + drain() + } + } + + var id = 0 + var sides = [Int: Side]() + + init() { } + + mutating func establish() -> Int { + defer { id += 1 } + sides[id] = Side() + return id + } + + static func establish(_ state: ManagedCriticalState) -> Int { + state.withCriticalRegion { $0.establish() } + } + + mutating func cancel(_ id: Int) { + if var side = sides.removeValue(forKey: id) { + side.cancel() + } + } + + static func cancel(_ state: ManagedCriticalState, id: Int) { + state.withCriticalRegion { $0.cancel(id) } + } + + mutating func next(_ id: Int, continuation: UnsafeContinuation, Never>) { + sides[id]?.next(continuation) + } + + static func next(_ state: ManagedCriticalState, id: Int) async -> Result { + await withUnsafeContinuation { continuation in + state.withCriticalRegion { $0.next(id, continuation: continuation) } + } + } + + mutating func emit(_ result: Result) { + for id in sides.keys { + sides[id]?.emit(result) + } + } + + static func emit(_ state: ManagedCriticalState, result: Result) { + state.withCriticalRegion { $0.emit(result) } + } + } + + struct Iteration { + enum Status { + case initial(Base) + case iterating(Task) + case terminal + } + + var status: Status + + init(_ base: Base) { + status = .initial(base) + } + + static func task(_ state: ManagedCriticalState, base: Base) -> Task { + Task { + do { + for try await element in base { + State.emit(state, result: .success(element)) + } + State.emit(state, result: .success(nil)) + } catch { + State.emit(state, result: .failure(error)) + } + } + } + + mutating func start(_ state: ManagedCriticalState) -> Bool { + switch status { + case .terminal: + return false + case .initial(let base): + status = .iterating(Iteration.task(state, base: base)) + default: + break + } + return true + } + + mutating func cancel() { + switch status { + case .iterating(let task): + task.cancel() + default: + break + } + status = .terminal + } + + static func start(_ iteration: ManagedCriticalState, state: ManagedCriticalState) -> Bool { + iteration.withCriticalRegion { $0.start(state) } + } + + static func cancel(_ iteration: ManagedCriticalState) { + iteration.withCriticalRegion { $0.cancel() } + } + } + + let state: ManagedCriticalState + let iteration: ManagedCriticalState + + init(_ base: Base) { + state = ManagedCriticalState(State()) + iteration = ManagedCriticalState(Iteration(base)) + } +} + + +extension AsyncBroadcastSequence: AsyncSequence { + public typealias Element = Base.Element + + public struct Iterator: AsyncIteratorProtocol { + final class Context { + let state: ManagedCriticalState + var iteration: ManagedCriticalState + let id: Int + + init(_ state: ManagedCriticalState, _ iteration: ManagedCriticalState) { + self.state = state + self.iteration = iteration + self.id = State.establish(state) + } + + deinit { + State.cancel(state, id: id) + if iteration.isKnownUniquelyReferenced() { + Iteration.cancel(iteration) + } + } + + func next() async rethrows -> Element? { + guard Iteration.start(iteration, state: state) else { + return nil + } + defer { + if Task.isCancelled && iteration.isKnownUniquelyReferenced() { + Iteration.cancel(iteration) + } + } + return try await withTaskCancellationHandler { + let result = await State.next(state, id: id) + return try result._rethrowGet() + } onCancel: { [state, id] in + State.cancel(state, id: id) + } + } + } + + let context: Context + + init(_ state: ManagedCriticalState, _ iteration: ManagedCriticalState) { + context = Context(state, iteration) + } + + public mutating func next() async rethrows -> Element? { + try await context.next() + } + } + + public func makeAsyncIterator() -> Iterator { + Iterator(state, iteration) + } +} + +@available(*, unavailable) +extension AsyncBroadcastSequence.Iterator: Sendable { } diff --git a/Sources/AsyncAlgorithms/Locking.swift b/Sources/AsyncAlgorithms/Locking.swift index 74396080..0d6f6ab6 100644 --- a/Sources/AsyncAlgorithms/Locking.swift +++ b/Sources/AsyncAlgorithms/Locking.swift @@ -117,7 +117,7 @@ struct ManagedCriticalState { } } - private let buffer: ManagedBuffer + private var buffer: ManagedBuffer init(_ initial: State) { buffer = LockedBuffer.create(minimumCapacity: 1) { buffer in @@ -133,6 +133,10 @@ struct ManagedCriticalState { return try critical(&header.pointee) } } + + mutating func isKnownUniquelyReferenced() -> Bool { + Swift.isKnownUniquelyReferenced(&buffer) + } } extension ManagedCriticalState: @unchecked Sendable where State: Sendable { } diff --git a/Tests/AsyncAlgorithmsTests/TestBroadcast.swift b/Tests/AsyncAlgorithmsTests/TestBroadcast.swift new file mode 100644 index 00000000..aa479751 --- /dev/null +++ b/Tests/AsyncAlgorithmsTests/TestBroadcast.swift @@ -0,0 +1,56 @@ +//===----------------------------------------------------------------------===// +// +// This source file is part of the Swift Async Algorithms open source project +// +// Copyright (c) 2022 Apple Inc. and the Swift project authors +// Licensed under Apache License v2.0 with Runtime Library Exception +// +// See https://swift.org/LICENSE.txt for license information +// +//===----------------------------------------------------------------------===// + +@preconcurrency import XCTest +import AsyncAlgorithms + +final class TestBroadcast: XCTestCase { + func test_basic_broadcasting() async { + let base = [1, 2, 3, 4].async + let a = base.broadcast() + let b = a + let results = await withTaskGroup(of: [Int].self) { group in + group.addTask { + await Array(a) + } + group.addTask { + await Array(b) + } + return await Array(group) + } + XCTAssertEqual(results[0], results[1]) + } + + func test_basic_broadcasting_from_channel() async { + let base = AsyncChannel() + let a = base.broadcast() + let b = a + let results = await withTaskGroup(of: [Int].self) { group in + group.addTask { + var sent = [Int]() + for i in 0..<10 { + sent.append(i) + await base.send(i) + } + base.finish() + return sent + } + group.addTask { + await Array(a) + } + group.addTask { + await Array(b) + } + return await Array(group) + } + XCTAssertEqual(results[0], results[1]) + } +}