From bc37a1c5b14aa5b70db7f471a5c9bf6d03949996 Mon Sep 17 00:00:00 2001 From: Adam Fowler Date: Wed, 5 Apr 2023 14:37:23 +0100 Subject: [PATCH 1/2] Add AsyncSequence.enumerated() --- .../AsyncEnumeratedSequence.swift | 65 +++++++++++++++++++ 1 file changed, 65 insertions(+) create mode 100644 Sources/AsyncAlgorithms/AsyncEnumeratedSequence.swift diff --git a/Sources/AsyncAlgorithms/AsyncEnumeratedSequence.swift b/Sources/AsyncAlgorithms/AsyncEnumeratedSequence.swift new file mode 100644 index 00000000..81c4d7cc --- /dev/null +++ b/Sources/AsyncAlgorithms/AsyncEnumeratedSequence.swift @@ -0,0 +1,65 @@ +//===----------------------------------------------------------------------===// +// +// This source file is part of the Swift Async Algorithms open source project +// +// Copyright (c) 2022 Apple Inc. and the Swift project authors +// Licensed under Apache License v2.0 with Runtime Library Exception +// +// See https://swift.org/LICENSE.txt for license information +// +//===----------------------------------------------------------------------===// + + +/// An enumeration of the elements of an AsyncSequence. +/// +/// `AsyncEnumeratedSequence` generates a sequence of pairs (*n*, *x*), where *n*s are +/// consecutive `Int` values starting at zero, and *x*s are the elements from an +/// base AsyncSequence. +/// +/// To create an instance of `EnumeratedSequence`, call `enumerated()` on an +/// AsyncSequence. +public struct AsyncEnumeratedSequence { + @usableFromInline + let base: Base + + @usableFromInline + init(_ base: Base) { + self.base = base + } +} + +extension AsyncEnumeratedSequence: AsyncSequence { + public typealias Element = (Int, Base.Element) + + public struct AsyncIterator: AsyncIteratorProtocol { + @usableFromInline + var baseIterator: Base.AsyncIterator + @usableFromInline + var index: Int + + @usableFromInline + init(baseIterator: Base.AsyncIterator) { + self.baseIterator = baseIterator + self.index = 0 + } + + @inlinable + public mutating func next() async rethrows -> AsyncEnumeratedSequence.Element? { + let value = try await self.baseIterator.next().map { (self.index, $0) } + self.index += 1 + return value + } + } + + @inlinable + public __consuming func makeAsyncIterator() -> AsyncIterator { + return .init(baseIterator: self.base.makeAsyncIterator()) + } +} + +extension AsyncEnumeratedSequence: Sendable where Base: Sendable {} + +extension AsyncSequence { + /// Return an enumaterated AsyncSequence + public func enumerated() -> AsyncEnumeratedSequence { return AsyncEnumeratedSequence(self) } +} From 2a9903ee1664c218ad1d57ff97e331a0e52f79f2 Mon Sep 17 00:00:00 2001 From: Adam Fowler Date: Wed, 5 Apr 2023 14:37:43 +0100 Subject: [PATCH 2/2] Add tests for AsyncSequnce.enumerated() --- .../AsyncAlgorithmsTests/TestEnumerate.swift | 78 +++++++++++++++++++ 1 file changed, 78 insertions(+) create mode 100644 Tests/AsyncAlgorithmsTests/TestEnumerate.swift diff --git a/Tests/AsyncAlgorithmsTests/TestEnumerate.swift b/Tests/AsyncAlgorithmsTests/TestEnumerate.swift new file mode 100644 index 00000000..7aaec03a --- /dev/null +++ b/Tests/AsyncAlgorithmsTests/TestEnumerate.swift @@ -0,0 +1,78 @@ +//===----------------------------------------------------------------------===// +// +// This source file is part of the Swift Async Algorithms open source project +// +// Copyright (c) 2022 Apple Inc. and the Swift project authors +// Licensed under Apache License v2.0 with Runtime Library Exception +// +// See https://swift.org/LICENSE.txt for license information +// +//===----------------------------------------------------------------------===// + +import XCTest +import AsyncAlgorithms +import AsyncSequenceValidation + +final class TestEnumerated: XCTestCase { + func testEnumerate() async { + let source = ["a", "b", "c", "d"] + let enumerated = source.async.enumerated() + var actual = [(Int, String)]() + var iterator = enumerated.makeAsyncIterator() + while let item = await iterator.next() { + actual.append(item) + } + XCTAssertEqual(actual, .init(source.enumerated())) + let pastEnd = await iterator.next() + XCTAssertNil(pastEnd) + } + + func testEmpty() async { + let source = [String]() + let enumerated = source.async.enumerated() + var iterator = enumerated.makeAsyncIterator() + let pastEnd = await iterator.next() + XCTAssertNil(pastEnd) + } + + func testEnumeratedThrowsWhenBaseSequenceThrows() async throws { + let sequence = ["a", "b", "c", "d"].async.map { try throwOn("c", $0) }.enumerated() + var iterator = sequence.makeAsyncIterator() + var collected = [(Int, String)]() + do { + while let item = try await iterator.next() { + collected.append(item) + } + XCTFail() + } catch { + XCTAssertEqual(error as? Failure, Failure()) + } + XCTAssertEqual(collected, [(0, "a"), (1, "b")]) + + let pastEnd = try await iterator.next() + XCTAssertNil(pastEnd) + } + + func testEnumeratedFinishesWhenCancelled() { + let source = Indefinite(value: "a") + let sequence = source.async.enumerated() + let finished = expectation(description: "finished") + let iterated = expectation(description: "iterated") + let task = Task { + var firstIteration = false + for await _ in sequence { + if !firstIteration { + firstIteration = true + iterated.fulfill() + } + } + finished.fulfill() + } + // ensure the other task actually starts + wait(for: [iterated], timeout: 1.0) + // cancellation should ensure the loop finishes + // without regards to the remaining underlying sequence + task.cancel() + wait(for: [finished], timeout: 1.0) + } +}