Skip to content

Introduce mapError function #324

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 11 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
99 changes: 99 additions & 0 deletions Sources/AsyncAlgorithms/AsyncMapErrorSequence.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
//===----------------------------------------------------------------------===//
//
// This source file is part of the Swift.org open source project
//
// Copyright (c) 2024 Apple Inc. and the Swift project authors
// Licensed under Apache License v2.0 with Runtime Library Exception
//
// See https://swift.org/LICENSE.txt for license information
// See https://swift.org/CONTRIBUTORS.txt for the list of Swift project authors
//
//===----------------------------------------------------------------------===//

extension AsyncSequence {

/// Converts any failure into a new error.
///
/// - Parameter transform: A closure that takes the failure as a parameter and returns a new error.
/// - Returns: An asynchronous sequence that maps the error thrown into the one produced by the transform closure.
///
/// Use the ``mapError(_:)`` operator when you need to replace one error type with another.
public func mapError<ErrorType>(transform: @Sendable @escaping (Error) -> ErrorType) -> AsyncMapErrorSequence<Self, ErrorType> {
.init(base: self, transform: transform)
}
}

/// An asynchronous sequence that converts any failure into a new error.
public struct AsyncMapErrorSequence<Base: AsyncSequence, ErrorType: Error>: AsyncSequence {

public typealias AsyncIterator = Iterator
public typealias Element = Base.Element

private let base: Base
private let transform: @Sendable (Error) -> ErrorType

init(
base: Base,
transform: @Sendable @escaping (Error) -> ErrorType
) {
self.base = base
self.transform = transform
}

public func makeAsyncIterator() -> Iterator {
Iterator(
base: base.makeAsyncIterator(),
transform: transform
)
}
}

extension AsyncMapErrorSequence {

/// The iterator that produces elements of the map sequence.
public struct Iterator: AsyncIteratorProtocol {

public typealias Element = Base.Element

private var base: Base.AsyncIterator

private let transform: @Sendable (Error) -> ErrorType

init(
base: Base.AsyncIterator,
transform: @Sendable @escaping (Error) -> ErrorType
) {
self.base = base
self.transform = transform
}

#if compiler(>=6.0)
public mutating func next() async throws(ErrorType) -> Element? {
do {
return try await base.next()
} catch {
throw transform(error)
}
}

@available(macOS 15.0, iOS 18.0, watchOS 11.0, tvOS 18.0, visionOS 2.0, *)
public mutating func next(isolation actor: isolated (any Actor)?) async throws(ErrorType) -> Element? {
do {
return try await base.next(isolation: actor)
} catch {
throw transform(error)
}
}
#else
public mutating func next() async throws -> Element? {
do {
return try await base.next()
} catch {
throw transform(error)
}
}
#endif
}
}

extension AsyncMapErrorSequence: Sendable where Base: Sendable, Base.Element: Sendable {}
89 changes: 89 additions & 0 deletions Tests/AsyncAlgorithmsTests/TestMapError.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
import AsyncAlgorithms
import XCTest

final class TestMapError: XCTestCase {

func test_mapError() async throws {
let array = [URLError(.badURL)]
let sequence = array.async
.map { throw $0 }
.mapError { _ in
MyAwesomeError()
}

do {
for try await _ in sequence {
XCTFail("sequence should throw")
}
} catch {
#if compiler(>=6.0)
// NO-OP
// The compiler already checks that for us since we're using typed throws.
// Writing that assert will just give compiler warning.
#else
XCTAssert(error is MyAwesomeError)
#endif
}
}

func test_nonThrowing() async throws {
let array = [1, 2, 3, 4, 5]
let sequence = array.async
.mapError { _ in
MyAwesomeError()
}

var actual: [Int] = []
for try await value in sequence {
actual.append(value)
}
XCTAssertEqual(array, actual)
}

func test_cancellation() async throws {
let source = Indefinite(value: "test").async
let sequence = source.mapError { _ in MyAwesomeError() }

let finished = expectation(description: "finished")
let iterated = expectation(description: "iterated")

let task = Task {
var firstIteration = false
for try await el in sequence {
XCTAssertEqual(el, "test")

if !firstIteration {
firstIteration = true
iterated.fulfill()
}
}
finished.fulfill()
}

// ensure the other task actually starts
await fulfillment(of: [iterated], timeout: 1.0)
// cancellation should ensure the loop finishes
// without regards to the remaining underlying sequence
task.cancel()
await fulfillment(of: [finished], timeout: 1.0)
}

func test_empty() async throws {
let array: [Int] = []
let sequence = array.async
.mapError { _ in
MyAwesomeError()
}

var actual: [Int] = []
for try await value in sequence {
actual.append(value)
}
XCTAssert(actual.isEmpty)
}
}

private extension TestMapError {

struct MyAwesomeError: Error {}
}