Skip to content

Commit 2d60d9e

Browse files
authored
Support Rust async fn ret Result<(),enum> (#189)
This commit adds support for returning Result<(), TransparentEnum> from Rust async functions. Here's an example of using this: ```Rust //Rust #[swift_bridge::bridge] mod ffi { enum SomeEnum { //... } extern "Rust" { async fn some_function() -> Result<(), SomeEnum>; } } ``` ```Swift //Swift do { let value = try await some_function() //... } catch let error as SomeEnum { //... } catch { //... } ```
1 parent 68c3a38 commit 2d60d9e

File tree

4 files changed

+148
-13
lines changed

4 files changed

+148
-13
lines changed

SwiftRustIntegrationTestRunner/SwiftRustIntegrationTestRunnerTests/AsyncTests.swift

+26-4
Original file line numberDiff line numberDiff line change
@@ -87,8 +87,6 @@ class AsyncTests: XCTestCase {
8787
XCTFail()
8888
} catch let error as AsyncResultErrEnum {
8989
switch error {
90-
case .NoFields:
91-
XCTFail()
9290
case .UnnamedFields(_, _):
9391
XCTFail()
9492
case .NamedFields(let valueUInt32):
@@ -114,8 +112,6 @@ class AsyncTests: XCTestCase {
114112
let _: AsyncResultOpaqueRustType1 = try await rust_async_func_return_result_opaque_rust_and_transparent_enum(false)
115113
} catch let error as AsyncResultErrEnum {
116114
switch error {
117-
case .NoFields:
118-
XCTFail()
119115
case .UnnamedFields(_, _):
120116
XCTFail()
121117
case .NamedFields(let value):
@@ -155,6 +151,32 @@ class AsyncTests: XCTestCase {
155151

156152
}
157153

154+
/// Verify that we can return a Result<(), TransparentEnum> from async Rust function
155+
func testSwiftCallsRustAsyncFnReturnResultNullTransparentEnum() async throws {
156+
//Should return an Unit type
157+
do {
158+
let _: () = try await rust_async_func_return_result_null_and_transparent_enum(true)
159+
} catch {
160+
XCTFail()
161+
}
162+
163+
//Should throw an AsyncResultErrEnum
164+
do {
165+
let _ = try await rust_async_func_return_result_null_and_transparent_enum(false)
166+
XCTFail()
167+
} catch let error as AsyncResultErrEnum {
168+
switch error {
169+
case .UnnamedFields(let valueString, let valueInt32):
170+
XCTAssertEqual(valueString.toString(), "foo")
171+
XCTAssertEqual(valueInt32, 123)
172+
case .NamedFields(_):
173+
XCTFail()
174+
}
175+
} catch {
176+
XCTFail()
177+
}
178+
}
179+
158180
func testSwiftCallsRustAsyncFnRetStruct() async throws {
159181
let _: AsyncRustFnReturnStruct = await rust_async_return_struct()
160182
}

crates/swift-bridge-ir/src/bridged_type/bridgeable_result.rs

+9-8
Original file line numberDiff line numberDiff line change
@@ -184,9 +184,6 @@ impl BuiltInResult {
184184
)
185185
}
186186
TypePosition::SwiftCallsRustAsyncOnCompleteReturnTy => {
187-
if self.ok_ty.can_be_encoded_with_zero_bytes() {
188-
todo!()
189-
}
190187
if self.err_ty.can_be_encoded_with_zero_bytes() {
191188
todo!()
192189
}
@@ -417,11 +414,15 @@ typedef struct {c_enum_name}{{{c_tag_name} tag; union {c_fields_name} payload;}}
417414
types: &TypeDeclarations,
418415
) -> String {
419416
if self.is_custom_result_type() {
420-
let ok = self.ok_ty.convert_ffi_expression_to_swift_type(
421-
&format!("{expression}.payload.ok"),
422-
type_pos,
423-
types,
424-
);
417+
let ok = if self.ok_ty.can_be_encoded_with_zero_bytes() {
418+
"()".to_string()
419+
} else {
420+
self.ok_ty.convert_ffi_expression_to_swift_type(
421+
&format!("{expression}.payload.ok"),
422+
type_pos,
423+
types,
424+
)
425+
};
425426
let err = self.err_ty.convert_ffi_expression_to_swift_type(
426427
&format!("{expression}.payload.err"),
427428
type_pos,

crates/swift-bridge-ir/src/codegen/codegen_tests/async_function_codegen_tests.rs

+97
Original file line numberDiff line numberDiff line change
@@ -852,3 +852,100 @@ void __swift_bridge__$some_function(void* callback_wrapper, void __swift_bridge_
852852
.test();
853853
}
854854
}
855+
856+
/// Verify that we generate the correct code for extern "Rust" async functions that returns a Result<(), TransparentEnum>.
857+
mod extern_rust_async_function_returns_result_null_transparent_enum {
858+
use super::*;
859+
860+
fn bridge_module() -> TokenStream {
861+
quote! {
862+
#[swift_bridge::bridge]
863+
mod ffi {
864+
enum ErrEnum {
865+
ErrVariant1,
866+
ErrVariant2,
867+
}
868+
extern "Rust" {
869+
async fn some_function() -> Result<(), ErrEnum>;
870+
}
871+
}
872+
}
873+
}
874+
875+
fn expected_rust_tokens() -> ExpectedRustTokens {
876+
ExpectedRustTokens::Contains(quote! {
877+
pub extern "C" fn __swift_bridge__some_function(
878+
callback_wrapper: *mut std::ffi::c_void,
879+
callback: extern "C" fn(*mut std::ffi::c_void, ResultVoidAndErrEnum) -> (),
880+
) {
881+
let callback_wrapper = swift_bridge::async_support::SwiftCallbackWrapper(callback_wrapper);
882+
let fut = super::some_function();
883+
let task = async move {
884+
let val = match fut.await {
885+
Ok(ok) => ResultVoidAndErrEnum::Ok,
886+
Err(err) => ResultVoidAndErrEnum::Err(err.into_ffi_repr()),
887+
};
888+
let callback_wrapper = callback_wrapper;
889+
let callback_wrapper = callback_wrapper.0;
890+
891+
(callback)(callback_wrapper, val)
892+
};
893+
swift_bridge::async_support::ASYNC_RUNTIME.spawn_task(Box::pin(task))
894+
}
895+
})
896+
}
897+
898+
// TODO: Replace `Error` with the concrete error type `ErrorType`.
899+
// As of Feb 2023 using the concrete error type leads to a compile time error.
900+
// This seems like a bug in the Swift compiler.
901+
902+
fn expected_swift_code() -> ExpectedSwiftCode {
903+
ExpectedSwiftCode::ContainsAfterTrim(
904+
r#"
905+
public func some_function() async throws -> () {
906+
func onComplete(cbWrapperPtr: UnsafeMutableRawPointer?, rustFnRetVal: __swift_bridge__$ResultVoidAndErrEnum) {
907+
let wrapper = Unmanaged<CbWrapper$some_function>.fromOpaque(cbWrapperPtr!).takeRetainedValue()
908+
switch rustFnRetVal.tag { case __swift_bridge__$ResultVoidAndErrEnum$ResultOk: wrapper.cb(.success(())) case __swift_bridge__$ResultVoidAndErrEnum$ResultErr: wrapper.cb(.failure(rustFnRetVal.payload.err.intoSwiftRepr())) default: fatalError() }
909+
}
910+
911+
return try await withCheckedThrowingContinuation({ (continuation: CheckedContinuation<(), Error>) in
912+
let callback = { rustFnRetVal in
913+
continuation.resume(with: rustFnRetVal)
914+
}
915+
916+
let wrapper = CbWrapper$some_function(cb: callback)
917+
let wrapperPtr = Unmanaged.passRetained(wrapper).toOpaque()
918+
919+
__swift_bridge__$some_function(wrapperPtr, onComplete)
920+
})
921+
}
922+
class CbWrapper$some_function {
923+
var cb: (Result<(), Error>) -> ()
924+
925+
public init(cb: @escaping (Result<(), Error>) -> ()) {
926+
self.cb = cb
927+
}
928+
}
929+
"#,
930+
)
931+
}
932+
933+
fn expected_c_header() -> ExpectedCHeader {
934+
ExpectedCHeader::ContainsAfterTrim(
935+
r#"
936+
void __swift_bridge__$some_function(void* callback_wrapper, void __swift_bridge__$some_function$async(void* callback_wrapper, struct __swift_bridge__$ResultVoidAndErrEnum ret));
937+
"#,
938+
)
939+
}
940+
941+
#[test]
942+
fn extern_rust_async_function_returns_result_null_transparent_enum() {
943+
CodegenTest {
944+
bridge_module: bridge_module().into(),
945+
expected_rust_tokens: expected_rust_tokens(),
946+
expected_swift_code: expected_swift_code(),
947+
expected_c_header: expected_c_header(),
948+
}
949+
.test();
950+
}
951+
}

crates/swift-integration-tests/src/async_function.rs

+16-1
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,6 @@ mod ffi {
4646
}
4747

4848
enum AsyncResultErrEnum {
49-
NoFields,
5049
UnnamedFields(String, i32),
5150
NamedFields { value: u32 },
5251
}
@@ -61,6 +60,9 @@ mod ffi {
6160
async fn rust_async_func_return_result_transparent_enum_and_opaque_rust(
6261
succeed: bool,
6362
) -> Result<AsyncResultOkEnum, AsyncResultOpaqueRustType1>;
63+
async fn rust_async_func_return_result_null_and_transparent_enum(
64+
succeed: bool,
65+
) -> Result<(), AsyncResultErrEnum>;
6466
}
6567
}
6668

@@ -161,3 +163,16 @@ async fn rust_async_func_return_result_transparent_enum_and_opaque_rust(
161163
Err(AsyncResultOpaqueRustType1(1000))
162164
}
163165
}
166+
167+
async fn rust_async_func_return_result_null_and_transparent_enum(
168+
succeed: bool,
169+
) -> Result<(), ffi::AsyncResultErrEnum> {
170+
if succeed {
171+
Ok(())
172+
} else {
173+
Err(ffi::AsyncResultErrEnum::UnnamedFields(
174+
"foo".to_string(),
175+
123,
176+
))
177+
}
178+
}

0 commit comments

Comments
 (0)