diff --git a/ktor-plugins/ktor-client-rsocket/src/commonMain/kotlin/io/rsocket/kotlin/ktor/client/RSocketSupport.kt b/ktor-plugins/ktor-client-rsocket/src/commonMain/kotlin/io/rsocket/kotlin/ktor/client/RSocketSupport.kt index f81e8bca..c780c326 100644 --- a/ktor-plugins/ktor-client-rsocket/src/commonMain/kotlin/io/rsocket/kotlin/ktor/client/RSocketSupport.kt +++ b/ktor-plugins/ktor-client-rsocket/src/commonMain/kotlin/io/rsocket/kotlin/ktor/client/RSocketSupport.kt @@ -1,5 +1,5 @@ /* - * Copyright 2015-2024 the original author or authors. + * Copyright 2015-2025 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -26,7 +26,6 @@ import io.rsocket.kotlin.* import io.rsocket.kotlin.core.* import io.rsocket.kotlin.transport.* import io.rsocket.kotlin.transport.ktor.websocket.internal.* -import kotlinx.coroutines.* import kotlin.coroutines.* private val RSocketSupportConfigKey = AttributeKey("RSocketSupportConfig") @@ -66,9 +65,7 @@ private class RSocketSupportTarget( override val coroutineContext: CoroutineContext get() = client.coroutineContext @RSocketTransportApi - override fun connectClient(handler: RSocketConnectionHandler): Job = launch { - client.webSocket(request) { - handler.handleKtorWebSocketConnection(this) - } + override suspend fun connectClient(): RSocketConnection { + return KtorWebSocketConnection(client.webSocketSession(request)) } } diff --git a/ktor-plugins/ktor-server-rsocket/src/commonMain/kotlin/io/rsocket/kotlin/ktor/server/RSocketSupport.kt b/ktor-plugins/ktor-server-rsocket/src/commonMain/kotlin/io/rsocket/kotlin/ktor/server/RSocketSupport.kt index 5346e7da..54a9c127 100644 --- a/ktor-plugins/ktor-server-rsocket/src/commonMain/kotlin/io/rsocket/kotlin/ktor/server/RSocketSupport.kt +++ b/ktor-plugins/ktor-server-rsocket/src/commonMain/kotlin/io/rsocket/kotlin/ktor/server/RSocketSupport.kt @@ -1,5 +1,5 @@ /* - * Copyright 2015-2024 the original author or authors. + * Copyright 2015-2025 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -24,6 +24,7 @@ import io.rsocket.kotlin.* import io.rsocket.kotlin.core.* import io.rsocket.kotlin.transport.* import io.rsocket.kotlin.transport.ktor.websocket.internal.* +import kotlinx.coroutines.* private val RSocketSupportConfigKey = AttributeKey("RSocketSupportConfig") @@ -54,8 +55,8 @@ internal fun Route.rSocketHandler(acceptor: ConnectionAcceptor): suspend Default val config = application.attributes.getOrNull(RSocketSupportConfigKey) ?: error("Plugin RSocketSupport is not installed. Consider using `install(RSocketSupport)` in server config first.") - val handler = config.server.createHandler(acceptor) return { - handler.handleKtorWebSocketConnection(this) + config.server.acceptConnection(acceptor, KtorWebSocketConnection(this)) + awaitCancellation() } } diff --git a/rsocket-core/api/rsocket-core.api b/rsocket-core/api/rsocket-core.api index 6fd3917d..bcce5398 100644 --- a/rsocket-core/api/rsocket-core.api +++ b/rsocket-core/api/rsocket-core.api @@ -222,9 +222,9 @@ public final class io/rsocket/kotlin/core/RSocketConnectorBuilderKt { } public final class io/rsocket/kotlin/core/RSocketServer { + public final fun acceptConnection (Lio/rsocket/kotlin/ConnectionAcceptor;Lio/rsocket/kotlin/transport/RSocketConnection;)V public final fun bind (Lio/rsocket/kotlin/transport/ServerTransport;Lio/rsocket/kotlin/ConnectionAcceptor;)Ljava/lang/Object; public final fun bindIn (Lkotlinx/coroutines/CoroutineScope;Lio/rsocket/kotlin/transport/ServerTransport;Lio/rsocket/kotlin/ConnectionAcceptor;)Ljava/lang/Object; - public final fun createHandler (Lio/rsocket/kotlin/ConnectionAcceptor;)Lio/rsocket/kotlin/transport/RSocketConnectionHandler; public final fun startServer (Lio/rsocket/kotlin/transport/RSocketServerTarget;Lio/rsocket/kotlin/ConnectionAcceptor;Lkotlin/coroutines/Continuation;)Ljava/lang/Object; } @@ -751,14 +751,10 @@ public final class io/rsocket/kotlin/transport/ClientTransportKt { } public abstract interface class io/rsocket/kotlin/transport/RSocketClientTarget : kotlinx/coroutines/CoroutineScope { - public abstract fun connectClient (Lio/rsocket/kotlin/transport/RSocketConnectionHandler;)Lkotlinx/coroutines/Job; + public abstract fun connectClient (Lkotlin/coroutines/Continuation;)Ljava/lang/Object; } -public abstract interface class io/rsocket/kotlin/transport/RSocketConnection { -} - -public abstract interface class io/rsocket/kotlin/transport/RSocketConnectionHandler { - public abstract fun handleConnection (Lio/rsocket/kotlin/transport/RSocketConnection;Lkotlin/coroutines/Continuation;)Ljava/lang/Object; +public abstract interface class io/rsocket/kotlin/transport/RSocketConnection : kotlinx/coroutines/CoroutineScope { } public abstract interface class io/rsocket/kotlin/transport/RSocketMultiplexedConnection : io/rsocket/kotlin/transport/RSocketConnection { @@ -766,16 +762,13 @@ public abstract interface class io/rsocket/kotlin/transport/RSocketMultiplexedCo public abstract fun createStream (Lkotlin/coroutines/Continuation;)Ljava/lang/Object; } -public abstract interface class io/rsocket/kotlin/transport/RSocketMultiplexedConnection$Stream : java/lang/AutoCloseable { - public abstract fun close ()V - public abstract fun isClosedForSend ()Z +public abstract interface class io/rsocket/kotlin/transport/RSocketMultiplexedConnection$Stream : kotlinx/coroutines/CoroutineScope { public abstract fun receiveFrame (Lkotlin/coroutines/Continuation;)Ljava/lang/Object; public abstract fun sendFrame (Lkotlinx/io/Buffer;Lkotlin/coroutines/Continuation;)Ljava/lang/Object; public abstract fun setSendPriority (I)V } public abstract interface class io/rsocket/kotlin/transport/RSocketSequentialConnection : io/rsocket/kotlin/transport/RSocketConnection { - public abstract fun isClosedForSend ()Z public abstract fun receiveFrame (Lkotlin/coroutines/Continuation;)Ljava/lang/Object; public abstract fun sendFrame (ILkotlinx/io/Buffer;Lkotlin/coroutines/Continuation;)Ljava/lang/Object; } @@ -784,7 +777,7 @@ public abstract interface class io/rsocket/kotlin/transport/RSocketServerInstanc } public abstract interface class io/rsocket/kotlin/transport/RSocketServerTarget : kotlinx/coroutines/CoroutineScope { - public abstract fun startServer (Lio/rsocket/kotlin/transport/RSocketConnectionHandler;Lkotlin/coroutines/Continuation;)Ljava/lang/Object; + public abstract fun startServer (Lkotlin/jvm/functions/Function1;Lkotlin/coroutines/Continuation;)Ljava/lang/Object; } public abstract interface class io/rsocket/kotlin/transport/RSocketTransport : kotlinx/coroutines/CoroutineScope { @@ -809,7 +802,7 @@ public abstract interface class io/rsocket/kotlin/transport/ServerTransport { } public final class io/rsocket/kotlin/transport/internal/PrioritizationFrameQueue { - public fun (I)V + public fun ()V public final fun cancel ()V public final fun close ()V public final fun dequeueFrame (Lkotlin/coroutines/Continuation;)Ljava/lang/Object; diff --git a/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/connection/ConnectionEstablishmentContext.kt b/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/connection/ConnectionEstablishmentContext.kt index 33b739d7..826e3d18 100644 --- a/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/connection/ConnectionEstablishmentContext.kt +++ b/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/connection/ConnectionEstablishmentContext.kt @@ -1,5 +1,5 @@ /* - * Copyright 2015-2024 the original author or authors. + * Copyright 2015-2025 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -20,23 +20,16 @@ import io.rsocket.kotlin.frame.* import io.rsocket.kotlin.frame.io.* import io.rsocket.kotlin.keepalive.* import io.rsocket.kotlin.payload.* -import io.rsocket.kotlin.transport.* import kotlinx.io.* // send/receive setup, resume, resume ok, lease, error -@RSocketTransportApi internal abstract class ConnectionEstablishmentContext( - private val frameCodec: FrameCodec, + protected val frameCodec: FrameCodec, ) { - protected abstract suspend fun receiveFrameRaw(): Buffer? - protected abstract suspend fun sendFrame(frame: Buffer) - private suspend fun sendFrame(frame: Frame): Unit = sendFrame(frameCodec.encodeFrame(frame)) + protected abstract suspend fun receiveConnectionFrameRaw(): Buffer? + protected abstract suspend fun sendConnectionFrameRaw(frame: Buffer) - // only setup|lease|resume|resume_ok|error frames - suspend fun receiveFrame(): Frame = frameCodec.decodeFrame( - expectedStreamId = 0, - frame = receiveFrameRaw() ?: error("Expected frame during connection establishment but nothing was received") - ) + protected suspend fun sendFrameConnectionFrame(frame: Frame): Unit = sendConnectionFrameRaw(frameCodec.encodeFrame(frame)) suspend fun sendSetup( version: Version, @@ -45,5 +38,11 @@ internal abstract class ConnectionEstablishmentContext( resumeToken: Buffer?, payloadMimeType: PayloadMimeType, payload: Payload, - ): Unit = sendFrame(SetupFrame(version, honorLease, keepAlive, resumeToken, payloadMimeType, payload)) + ): Unit = sendFrameConnectionFrame(SetupFrame(version, honorLease, keepAlive, resumeToken, payloadMimeType, payload)) + + // only setup|lease|resume|resume_ok|error frames + suspend fun receiveFrame(): Frame = frameCodec.decodeFrame( + expectedStreamId = 0, + frame = receiveConnectionFrameRaw() ?: error("Expected frame during connection establishment but nothing was received") + ) } diff --git a/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/connection/ConnectionEstablishmentHandler.kt b/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/connection/ConnectionEstablishmentHandler.kt deleted file mode 100644 index 532fdb19..00000000 --- a/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/connection/ConnectionEstablishmentHandler.kt +++ /dev/null @@ -1,107 +0,0 @@ -/* - * Copyright 2015-2024 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package io.rsocket.kotlin.connection - -import io.rsocket.kotlin.* -import io.rsocket.kotlin.core.* -import io.rsocket.kotlin.frame.* -import io.rsocket.kotlin.internal.io.* -import io.rsocket.kotlin.keepalive.* -import io.rsocket.kotlin.transport.* -import kotlinx.coroutines.* -import kotlin.coroutines.* - -@RSocketTransportApi -internal abstract class ConnectionEstablishmentHandler( - private val isClient: Boolean, - private val frameCodec: FrameCodec, - private val connectionAcceptor: ConnectionAcceptor, - private val interceptors: Interceptors, - private val requesterDeferred: CompletableDeferred?, -) : RSocketConnectionHandler { - abstract suspend fun establishConnection(context: ConnectionEstablishmentContext): ConnectionConfig - - private suspend fun wrapConnection( - connection: RSocketConnection, - requestContext: CoroutineContext, - ): Connection2 = when (connection) { - is RSocketMultiplexedConnection -> { - val initialStream = when { - isClient -> connection.createStream() - else -> connection.acceptStream() ?: error("Initial stream should be received") - } - initialStream.setSendPriority(0) - MultiplexedConnection(isClient, frameCodec, requestContext, connection, initialStream) - } - - is RSocketSequentialConnection -> { - SequentialConnection(isClient, frameCodec, requestContext, connection) - } - } - - @Suppress("SuspendFunctionOnCoroutineScope") - private suspend fun CoroutineScope.handleConnection(connection: Connection2) { - try { - val connectionConfig = connection.establishConnection(this@ConnectionEstablishmentHandler) - try { - val requester = interceptors.wrapRequester(connection) - val responder = interceptors.wrapResponder( - with(interceptors.wrapAcceptor(connectionAcceptor)) { - ConnectionAcceptorContext(connectionConfig, requester).accept() - } - ) - - // link completing of requester, connection and requestHandler - requester.coroutineContext.job.invokeOnCompletion { - coroutineContext.job.cancel("Requester cancelled", it) - } - responder.coroutineContext.job.invokeOnCompletion { - coroutineContext.job.cancel("Responder cancelled", it) - } - coroutineContext.job.invokeOnCompletion { cause -> - // the responder is not linked to `coroutineContext` - responder.cancel("Connection closed", cause) - } - - requesterDeferred?.complete(requester) - - val keepAliveHandler = KeepAliveHandler(connectionConfig.keepAlive, connection, this) - connection.handleConnection( - ConnectionInbound(connection.coroutineContext, responder, keepAliveHandler) - ) - } catch (cause: Throwable) { - connectionConfig.setupPayload.close() - throw cause - } - } catch (cause: Throwable) { - connection.close() - withContext(NonCancellable) { - connection.sendError( - when (cause) { - is RSocketError -> cause - else -> RSocketError.ConnectionError(cause.message ?: "Connection failed") - } - ) - } - throw cause - } - } - - final override suspend fun handleConnection(connection: RSocketConnection): Unit = coroutineScope { - handleConnection(wrapConnection(connection, coroutineContext.supervisorContext())) - } -} diff --git a/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/connection/ConnectionInbound.kt b/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/connection/ConnectionInbound.kt index e023848d..aae8fc25 100644 --- a/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/connection/ConnectionInbound.kt +++ b/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/connection/ConnectionInbound.kt @@ -1,5 +1,5 @@ /* - * Copyright 2015-2024 the original author or authors. + * Copyright 2015-2025 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -18,20 +18,15 @@ package io.rsocket.kotlin.connection import io.rsocket.kotlin.* import io.rsocket.kotlin.frame.* -import io.rsocket.kotlin.keepalive.* import io.rsocket.kotlin.operation.* -import io.rsocket.kotlin.transport.* import kotlinx.coroutines.* import kotlinx.io.* -import kotlin.coroutines.* -@RSocketTransportApi internal class ConnectionInbound( - // requestContext - override val coroutineContext: CoroutineContext, + private val requestsScope: CoroutineScope, private val responder: RSocket, private val keepAliveHandler: KeepAliveHandler, -) : CoroutineScope { +) { fun handleFrame(frame: Frame): Unit = when (frame) { is MetadataPushFrame -> receiveMetadataPush(frame.metadata) is KeepAliveFrame -> receiveKeepAlive(frame.respond, frame.data, frame.lastPosition) @@ -42,9 +37,9 @@ internal class ConnectionInbound( } private fun receiveMetadataPush(metadata: Buffer) { - launch { + requestsScope.launch { responder.metadataPush(metadata) - }.invokeOnCompletion { metadata.close() } + }.invokeOnCompletion { metadata.clear() } } @Suppress("UNUSED_PARAMETER") // will be used later diff --git a/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/connection/ConnectionInitializer.kt b/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/connection/ConnectionInitializer.kt new file mode 100644 index 00000000..f1ef8aac --- /dev/null +++ b/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/connection/ConnectionInitializer.kt @@ -0,0 +1,137 @@ +/* + * Copyright 2015-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.kotlin.connection + +import io.rsocket.kotlin.* +import io.rsocket.kotlin.core.* +import io.rsocket.kotlin.frame.* +import io.rsocket.kotlin.internal.io.* +import io.rsocket.kotlin.transport.* +import kotlinx.coroutines.* + +@RSocketTransportApi +internal abstract class ConnectionInitializer( + private val isClient: Boolean, + private val frameCodec: FrameCodec, + private val connectionAcceptor: ConnectionAcceptor, + private val interceptors: Interceptors, +) { + protected abstract suspend fun establishConnection(context: ConnectionEstablishmentContext): ConnectionConfig + + private suspend fun wrapConnection( + connection: RSocketConnection, + requestsScope: CoroutineScope, + ): ConnectionOutbound = when (connection) { + is RSocketMultiplexedConnection -> { + val initialStream = when { + isClient -> connection.createStream() + else -> connection.acceptStream() ?: error("Initial stream should be received") + } + initialStream.setSendPriority(0) + MultiplexedConnection(isClient, frameCodec, connection, initialStream, requestsScope) + } + + is RSocketSequentialConnection -> { + SequentialConnection(isClient, frameCodec, connection, requestsScope) + } + } + + private suspend fun initialize(connection: RSocketConnection): RSocket { + val requestsScope = CoroutineScope(connection.coroutineContext.supervisorContext()) + val outbound = wrapConnection(connection, requestsScope) + val connectionJob = connection.launch(start = CoroutineStart.ATOMIC) { + try { + awaitCancellation() + } catch (cause: Throwable) { + if (connection.isActive) { + nonCancellable { + outbound.sendError(RSocketError.ConnectionError(cause.message ?: "Connection failed")) + } + connection.cancel("Connection failed", cause) + } + throw cause + } + } + val connectionScope = CoroutineScope(connection.coroutineContext + connectionJob) + try { + val connectionConfig = establishConnection(outbound) + try { + val requester = interceptors.wrapRequester( + RequesterRSocket(requestsScope, outbound) + ) + val responder = interceptors.wrapResponder( + with(interceptors.wrapAcceptor(connectionAcceptor)) { + ConnectionAcceptorContext(connectionConfig, requester).accept() + } + ) + + // link completing of requester, connection and requestHandler + requester.coroutineContext.job.invokeOnCompletion { + connectionJob.cancel("Requester cancelled", it) + } + responder.coroutineContext.job.invokeOnCompletion { + connectionJob.cancel("Responder cancelled", it) + } + connectionJob.invokeOnCompletion { cause -> + // the responder is not linked to `coroutineContext` + responder.cancel("Connection closed", cause) + } + + val keepAliveHandler = KeepAliveHandler(connectionConfig.keepAlive, outbound, connectionScope) + connectionScope.launch { + outbound.handleConnection(ConnectionInbound(requestsScope, responder, keepAliveHandler)) + } + return requester + } catch (cause: Throwable) { + connectionConfig.setupPayload.close() + throw cause + } + } catch (cause: Throwable) { + nonCancellable { + outbound.sendError( + when (cause) { + is RSocketError -> cause + else -> RSocketError.ConnectionError(cause.message ?: "Connection establishment failed") + } + ) + } + throw cause + } + } + + private fun asyncInitializer(connection: RSocketConnection): Deferred = connection.async { + try { + initialize(connection) + } catch (cause: Throwable) { + connection.cancel("Connection initialization failed", cause) + throw cause + } + } + + suspend fun runInitializer(connection: RSocketConnection): RSocket { + val result = asyncInitializer(connection) + try { + result.join() + } catch (cause: Throwable) { + connection.cancel("Connection initialization cancelled", cause) + throw cause + } + return result.await() + } + + fun launchInitializer(connection: RSocketConnection): Job = asyncInitializer(connection) +} diff --git a/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/connection/ConnectionOutbound.kt b/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/connection/ConnectionOutbound.kt new file mode 100644 index 00000000..b4ac6b82 --- /dev/null +++ b/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/connection/ConnectionOutbound.kt @@ -0,0 +1,43 @@ +/* + * Copyright 2015-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.kotlin.connection + +import io.rsocket.kotlin.frame.* +import io.rsocket.kotlin.operation.* +import io.rsocket.kotlin.payload.* +import kotlinx.coroutines.* +import kotlinx.io.* + +internal abstract class ConnectionOutbound( + frameCodec: FrameCodec, +) : ConnectionEstablishmentContext(frameCodec) { + suspend fun sendError(cause: Throwable) { + sendFrameConnectionFrame(ErrorFrame(0, cause)) + } + + suspend fun sendMetadataPush(metadata: Buffer) { + sendFrameConnectionFrame(MetadataPushFrame(metadata)) + } + + suspend fun sendKeepAlive(respond: Boolean, data: Buffer, lastPosition: Long) { + sendFrameConnectionFrame(KeepAliveFrame(respond, lastPosition, data)) + } + + abstract suspend fun handleConnection(inbound: ConnectionInbound) + + abstract fun launchRequest(requestPayload: Payload, operation: RequesterOperation): Job +} diff --git a/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/keepalive/KeepAliveHandler.kt b/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/connection/KeepAliveHandler.kt similarity index 83% rename from rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/keepalive/KeepAliveHandler.kt rename to rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/connection/KeepAliveHandler.kt index 58966841..13319e83 100644 --- a/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/keepalive/KeepAliveHandler.kt +++ b/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/connection/KeepAliveHandler.kt @@ -1,5 +1,5 @@ /* - * Copyright 2015-2024 the original author or authors. + * Copyright 2015-2025 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -14,21 +14,19 @@ * limitations under the License. */ -package io.rsocket.kotlin.keepalive +package io.rsocket.kotlin.connection import io.rsocket.kotlin.* -import io.rsocket.kotlin.connection.* import io.rsocket.kotlin.frame.io.* -import io.rsocket.kotlin.transport.* +import io.rsocket.kotlin.keepalive.* import kotlinx.atomicfu.* import kotlinx.coroutines.* import kotlinx.io.* import kotlin.time.* -@RSocketTransportApi internal class KeepAliveHandler( private val keepAlive: KeepAlive, - private val connection2: Connection2, + private val outbound: ConnectionOutbound, private val connectionScope: CoroutineScope, ) { private val initial = TimeSource.Monotonic.markNow() @@ -44,7 +42,7 @@ internal class KeepAliveHandler( if (currentDelayMillis() - lastMark.value >= keepAlive.maxLifetimeMillis) throw RSocketError.ConnectionError("No keep-alive for ${keepAlive.maxLifetimeMillis} ms") - connection2.sendKeepAlive(true, EmptyBuffer, 0) + outbound.sendKeepAlive(true, EmptyBuffer, 0) } } } @@ -53,7 +51,7 @@ internal class KeepAliveHandler( lastMark.value = currentDelayMillis() // in most cases it will be possible to not suspend at all if (respond) connectionScope.launch(start = CoroutineStart.UNDISPATCHED) { - connection2.sendKeepAlive(false, data, 0) + outbound.sendKeepAlive(false, data, 0) } } } diff --git a/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/connection/LoggingConnection.kt b/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/connection/LoggingConnection.kt index e3294ed8..3d0ac397 100644 --- a/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/connection/LoggingConnection.kt +++ b/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/connection/LoggingConnection.kt @@ -1,5 +1,5 @@ /* - * Copyright 2015-2024 the original author or authors. + * Copyright 2015-2025 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -21,19 +21,16 @@ import io.rsocket.kotlin.frame.* import io.rsocket.kotlin.logging.* import io.rsocket.kotlin.transport.* import kotlinx.io.* +import kotlin.coroutines.* @RSocketLoggingApi @RSocketTransportApi -internal fun RSocketConnectionHandler.logging(logger: Logger): RSocketConnectionHandler { +internal fun RSocketConnection.logging(logger: Logger): RSocketConnection { if (!logger.isLoggable(LoggingLevel.DEBUG)) return this - return RSocketConnectionHandler { - handleConnection( - when (it) { - is RSocketSequentialConnection -> SequentialLoggingConnection(it, logger) - is RSocketMultiplexedConnection -> MultiplexedLoggingConnection(it, logger) - } - ) + return when (this) { + is RSocketSequentialConnection -> SequentialLoggingConnection(this, logger) + is RSocketMultiplexedConnection -> MultiplexedLoggingConnection(this, logger) } } @@ -43,7 +40,7 @@ private class SequentialLoggingConnection( private val delegate: RSocketSequentialConnection, private val logger: Logger, ) : RSocketSequentialConnection { - override val isClosedForSend: Boolean get() = delegate.isClosedForSend + override val coroutineContext: CoroutineContext get() = delegate.coroutineContext override suspend fun sendFrame(streamId: Int, frame: Buffer) { logger.debug { "Send: ${dumpFrameToString(frame)}" } @@ -69,6 +66,8 @@ private class MultiplexedLoggingConnection( private val delegate: RSocketMultiplexedConnection, private val logger: Logger, ) : RSocketMultiplexedConnection { + override val coroutineContext: CoroutineContext get() = delegate.coroutineContext + override suspend fun createStream(): RSocketMultiplexedConnection.Stream { return MultiplexedLoggingStream(delegate.createStream(), logger) } @@ -86,7 +85,7 @@ private class MultiplexedLoggingStream( private val delegate: RSocketMultiplexedConnection.Stream, private val logger: Logger, ) : RSocketMultiplexedConnection.Stream { - override val isClosedForSend: Boolean get() = delegate.isClosedForSend + override val coroutineContext: CoroutineContext get() = delegate.coroutineContext override fun setSendPriority(priority: Int) { delegate.setSendPriority(priority) @@ -102,8 +101,4 @@ private class MultiplexedLoggingStream( logger.debug { "Receive: ${dumpFrameToString(frame)}" } } } - - override fun close() { - delegate.close() - } } diff --git a/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/connection/MultiplexedConnection.kt b/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/connection/MultiplexedConnection.kt index 11ac0b90..430cf5d0 100644 --- a/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/connection/MultiplexedConnection.kt +++ b/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/connection/MultiplexedConnection.kt @@ -1,5 +1,5 @@ /* - * Copyright 2015-2024 the original author or authors. + * Copyright 2015-2025 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,7 +16,6 @@ package io.rsocket.kotlin.connection -import io.rsocket.kotlin.* import io.rsocket.kotlin.frame.* import io.rsocket.kotlin.internal.* import io.rsocket.kotlin.operation.* @@ -24,32 +23,27 @@ import io.rsocket.kotlin.payload.* import io.rsocket.kotlin.transport.* import kotlinx.coroutines.* import kotlinx.io.* -import kotlin.coroutines.* @RSocketTransportApi internal class MultiplexedConnection( isClient: Boolean, frameCodec: FrameCodec, - requestContext: CoroutineContext, private val connection: RSocketMultiplexedConnection, private val initialStream: RSocketMultiplexedConnection.Stream, -) : Connection2(frameCodec, requestContext) { + private val requestsScope: CoroutineScope, +) : ConnectionOutbound(frameCodec) { private val storage = StreamDataStorage(isClient) - override fun close() { - storage.clear() - } - - override suspend fun establishConnection(handler: ConnectionEstablishmentHandler): ConnectionConfig { - return handler.establishConnection(EstablishmentContext()) + init { + connection.coroutineContext.job.invokeOnCompletion { + storage.clear() + } } - private inner class EstablishmentContext : ConnectionEstablishmentContext(frameCodec) { - override suspend fun sendFrame(frame: Buffer): Unit = initialStream.sendFrame(frame) - override suspend fun receiveFrameRaw(): Buffer? = initialStream.receiveFrame() - } + override suspend fun sendConnectionFrameRaw(frame: Buffer): Unit = initialStream.sendFrame(frame) + override suspend fun receiveConnectionFrameRaw(): Buffer? = initialStream.receiveFrame() - override suspend fun handleConnection(inbound: ConnectionInbound) = coroutineScope { + override suspend fun handleConnection(inbound: ConnectionInbound): Unit = coroutineScope { launch { while (true) { val frame = frameCodec.decodeFrame( @@ -63,15 +57,11 @@ internal class MultiplexedConnection( while (true) if (!acceptRequest(inbound)) break } - override suspend fun sendConnectionFrame(frame: Buffer) { - initialStream.sendFrame(frame) - } - @OptIn(DelicateCoroutinesApi::class) override fun launchRequest( requestPayload: Payload, operation: RequesterOperation, - ): Job = launch(start = CoroutineStart.ATOMIC) { + ): Job = requestsScope.launch(start = CoroutineStart.ATOMIC) { operation.handleExecutionFailure(requestPayload) { ensureActive() // because of atomic start val stream = connection.createStream() @@ -80,7 +70,7 @@ internal class MultiplexedConnection( execute(streamId, stream, requestPayload, operation) } finally { storage.removeStream(streamId) - stream.close() + stream.cancel("Stream closed") } } } @@ -89,7 +79,7 @@ internal class MultiplexedConnection( private fun acceptRequest( connectionInbound: ConnectionInbound, stream: RSocketMultiplexedConnection.Stream, - ): Job = launch(start = CoroutineStart.ATOMIC) { + ): Job = requestsScope.launch(start = CoroutineStart.ATOMIC) { try { ensureActive() // because of atomic start val ( @@ -112,7 +102,7 @@ internal class MultiplexedConnection( storage.removeStream(streamId) } } finally { - stream.close() + stream.cancel("Stream closed") } } @@ -226,7 +216,6 @@ internal class MultiplexedConnection( streamId: Int, private val stream: RSocketMultiplexedConnection.Stream, ) : OperationOutbound(streamId, frameCodec) { - override val isClosed: Boolean get() = stream.isClosedForSend override suspend fun sendFrame(frame: Buffer): Unit = stream.sendFrame(frame) } diff --git a/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/connection/OldConnection.kt b/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/connection/OldConnection.kt index ee3652d5..49fdabd2 100644 --- a/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/connection/OldConnection.kt +++ b/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/connection/OldConnection.kt @@ -1,5 +1,5 @@ /* - * Copyright 2015-2024 the original author or authors. + * Copyright 2015-2025 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -21,35 +21,38 @@ import io.rsocket.kotlin.internal.io.* import io.rsocket.kotlin.transport.* import io.rsocket.kotlin.transport.internal.* import kotlinx.coroutines.* -import kotlinx.coroutines.channels.* import kotlinx.io.* +import kotlin.coroutines.* @Suppress("DEPRECATION_ERROR") @RSocketTransportApi -internal suspend fun RSocketConnectionHandler.handleConnection(connection: Connection): Unit = coroutineScope { - val outboundQueue = PrioritizationFrameQueue(Channel.BUFFERED) +internal class OldConnection( + private val connection: Connection, +) : RSocketSequentialConnection { + private val outboundQueue = PrioritizationFrameQueue() + + override val coroutineContext: CoroutineContext get() = connection.coroutineContext - val senderJob = launch { - while (true) connection.send(outboundQueue.dequeueFrame() ?: break) - }.onCompletion { outboundQueue.cancel() } + init { + @OptIn(DelicateCoroutinesApi::class) + launch(start = CoroutineStart.ATOMIC) { + launch { + nonCancellable { + while (true) { + connection.send(outboundQueue.dequeueFrame() ?: break) + } + } + }.onCompletion { + outboundQueue.cancel() + } - try { - handleConnection(OldConnection(outboundQueue, connection)) - } finally { - outboundQueue.close() - withContext(NonCancellable) { - senderJob.join() + try { + awaitCancellation() + } finally { + outboundQueue.close() + } } } -} - -@Suppress("DEPRECATION_ERROR") -@RSocketTransportApi -private class OldConnection( - private val outboundQueue: PrioritizationFrameQueue, - private val connection: Connection, -) : RSocketSequentialConnection { - override val isClosedForSend: Boolean get() = outboundQueue.isClosedForSend override suspend fun sendFrame(streamId: Int, frame: Buffer) { return outboundQueue.enqueueFrame(streamId, frame) @@ -57,7 +60,7 @@ private class OldConnection( override suspend fun receiveFrame(): Buffer? = try { connection.receive() - } catch (cause: Throwable) { + } catch (_: Throwable) { currentCoroutineContext().ensureActive() null } diff --git a/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/connection/Connection.kt b/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/connection/RequesterRSocket.kt similarity index 54% rename from rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/connection/Connection.kt rename to rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/connection/RequesterRSocket.kt index bc262862..2a818130 100644 --- a/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/connection/Connection.kt +++ b/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/connection/RequesterRSocket.kt @@ -1,5 +1,5 @@ /* - * Copyright 2015-2024 the original author or authors. + * Copyright 2015-2025 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -17,67 +17,30 @@ package io.rsocket.kotlin.connection import io.rsocket.kotlin.* -import io.rsocket.kotlin.frame.* import io.rsocket.kotlin.internal.* -import io.rsocket.kotlin.internal.io.* import io.rsocket.kotlin.operation.* import io.rsocket.kotlin.payload.* -import io.rsocket.kotlin.transport.* import kotlinx.coroutines.* import kotlinx.coroutines.flow.* import kotlinx.io.* import kotlin.coroutines.* -// TODO: rename to just `Connection` after root `Connection` will be dropped -@RSocketTransportApi -internal abstract class Connection2( - protected val frameCodec: FrameCodec, - // requestContext - final override val coroutineContext: CoroutineContext, -) : RSocket, AutoCloseable { +internal class RequesterRSocket( + private val requestsScope: CoroutineScope, + private val outbound: ConnectionOutbound, +) : RSocket { + override val coroutineContext: CoroutineContext get() = requestsScope.coroutineContext - // connection establishment part - - abstract suspend fun establishConnection(handler: ConnectionEstablishmentHandler): ConnectionConfig - - // setup completed, start handling requests - abstract suspend fun handleConnection(inbound: ConnectionInbound) - - // connection part - - protected abstract suspend fun sendConnectionFrame(frame: Buffer) - private suspend fun sendConnectionFrame(frame: Frame): Unit = sendConnectionFrame(frameCodec.encodeFrame(frame)) - - suspend fun sendError(cause: Throwable) { - sendConnectionFrame(ErrorFrame(0, cause)) + override suspend fun metadataPush(metadata: Buffer) { + ensureActiveOrClose(metadata::clear) + outbound.sendMetadataPush(metadata) } - private suspend fun sendMetadataPush(metadata: Buffer) { - sendConnectionFrame(MetadataPushFrame(metadata)) - } - - suspend fun sendKeepAlive(respond: Boolean, data: Buffer, lastPosition: Long) { - sendConnectionFrame(KeepAliveFrame(respond, lastPosition, data)) - } - - // operations part - - protected abstract fun launchRequest(requestPayload: Payload, operation: RequesterOperation): Job - private suspend fun ensureActiveOrClose(closeable: AutoCloseable) { - currentCoroutineContext().ensureActive { closeable.close() } - coroutineContext.ensureActive { closeable.close() } - } - - final override suspend fun metadataPush(metadata: Buffer) { - ensureActiveOrClose(metadata) - sendMetadataPush(metadata) - } - - final override suspend fun fireAndForget(payload: Payload) { - ensureActiveOrClose(payload) + override suspend fun fireAndForget(payload: Payload) { + ensureActiveOrClose(payload::close) suspendCancellableCoroutine { cont -> - val requestJob = launchRequest( + val requestJob = outbound.launchRequest( requestPayload = payload, operation = RequesterFireAndForgetOperation(cont) ) @@ -87,12 +50,12 @@ internal abstract class Connection2( } } - final override suspend fun requestResponse(payload: Payload): Payload { - ensureActiveOrClose(payload) + override suspend fun requestResponse(payload: Payload): Payload { + ensureActiveOrClose(payload::close) val responseDeferred = CompletableDeferred() - val requestJob = launchRequest( + val requestJob = outbound.launchRequest( requestPayload = payload, operation = RequesterRequestResponseOperation(responseDeferred) ) @@ -107,14 +70,14 @@ internal abstract class Connection2( } @OptIn(ExperimentalStreamsApi::class) - final override fun requestStream( + override fun requestStream( payload: Payload, ): Flow = payloadFlow { strategy, initialRequest -> - ensureActiveOrClose(payload) + ensureActiveOrClose(payload::close) val responsePayloads = PayloadChannel() - val requestJob = launchRequest( + val requestJob = outbound.launchRequest( requestPayload = payload, operation = RequesterRequestStreamOperation(initialRequest, responsePayloads) ) @@ -128,15 +91,15 @@ internal abstract class Connection2( } @OptIn(ExperimentalStreamsApi::class) - final override fun requestChannel( + override fun requestChannel( initPayload: Payload, payloads: Flow, ): Flow = payloadFlow { strategy, initialRequest -> - ensureActiveOrClose(initPayload) + ensureActiveOrClose(initPayload::close) val responsePayloads = PayloadChannel() - val requestJob = launchRequest( + val requestJob = outbound.launchRequest( initPayload, RequesterRequestChannelOperation(initialRequest, payloads, responsePayloads) ) @@ -148,4 +111,16 @@ internal abstract class Connection2( throw cause } ?: return@payloadFlow } + + private suspend inline fun ensureActiveOrClose(onInactive: () -> Unit) { + currentCoroutineContext().ensureActive(onInactive) + coroutineContext.ensureActive(onInactive) + } + + private inline fun CoroutineContext.ensureActive(onInactive: () -> Unit) { + if (isActive) return + onInactive() // should not throw + ensureActive() // will throw + } + } diff --git a/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/connection/SequentialConnection.kt b/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/connection/SequentialConnection.kt index e3739842..fe6fec8a 100644 --- a/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/connection/SequentialConnection.kt +++ b/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/connection/SequentialConnection.kt @@ -1,5 +1,5 @@ /* - * Copyright 2015-2024 the original author or authors. + * Copyright 2015-2025 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,36 +16,30 @@ package io.rsocket.kotlin.connection -import io.rsocket.kotlin.* import io.rsocket.kotlin.frame.* import io.rsocket.kotlin.operation.* import io.rsocket.kotlin.payload.* import io.rsocket.kotlin.transport.* import kotlinx.coroutines.* import kotlinx.io.* -import kotlin.coroutines.* @RSocketTransportApi internal class SequentialConnection( isClient: Boolean, frameCodec: FrameCodec, - requestContext: CoroutineContext, private val connection: RSocketSequentialConnection, -) : Connection2(frameCodec, requestContext) { + private val requestsScope: CoroutineScope, +) : ConnectionOutbound(frameCodec) { private val storage = StreamDataStorage(isClient) - override fun close() { - storage.clear().forEach { it.close() } - } - - override suspend fun establishConnection(handler: ConnectionEstablishmentHandler): ConnectionConfig { - return handler.establishConnection(EstablishmentContext()) + init { + connection.coroutineContext.job.invokeOnCompletion { + storage.clear().forEach { it.close() } + } } - private inner class EstablishmentContext : ConnectionEstablishmentContext(frameCodec) { - override suspend fun sendFrame(frame: Buffer): Unit = connection.sendFrame(streamId = 0, frame) - override suspend fun receiveFrameRaw(): Buffer? = connection.receiveFrame() - } + override suspend fun sendConnectionFrameRaw(frame: Buffer): Unit = connection.sendFrame(streamId = 0, frame) + override suspend fun receiveConnectionFrameRaw(): Buffer? = connection.receiveFrame() override suspend fun handleConnection(inbound: ConnectionInbound) { while (true) { @@ -59,15 +53,11 @@ internal class SequentialConnection( } } - override suspend fun sendConnectionFrame(frame: Buffer) { - connection.sendFrame(0, frame) - } - @OptIn(DelicateCoroutinesApi::class) override fun launchRequest( requestPayload: Payload, operation: RequesterOperation, - ): Job = launch(start = CoroutineStart.ATOMIC) { + ): Job = requestsScope.launch(start = CoroutineStart.ATOMIC) { operation.handleExecutionFailure(requestPayload) { ensureActive() // because of atomic start val streamId = storage.createStream(OperationFrameHandler(operation)) @@ -84,9 +74,9 @@ internal class SequentialConnection( connectionInbound: ConnectionInbound, operationData: ResponderOperationData, ): ResponderOperation { - val requestJob = Job(coroutineContext.job) + val requestJob = Job(requestsScope.coroutineContext.job) val operation = connectionInbound.createOperation(operationData.requestType, requestJob) - launch(requestJob, start = CoroutineStart.ATOMIC) { + requestsScope.launch(requestJob, start = CoroutineStart.ATOMIC) { val ( streamId, _, @@ -144,7 +134,6 @@ internal class SequentialConnection( } private inner class Outbound(streamId: Int) : OperationOutbound(streamId, frameCodec) { - override val isClosed: Boolean get() = !isActive || connection.isClosedForSend override suspend fun sendFrame(frame: Buffer): Unit = connection.sendFrame(streamId, frame) } diff --git a/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/core/RSocketConnector.kt b/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/core/RSocketConnector.kt index 61300dd9..b4c52669 100644 --- a/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/core/RSocketConnector.kt +++ b/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/core/RSocketConnector.kt @@ -1,5 +1,5 @@ /* - * Copyright 2015-2024 the original author or authors. + * Copyright 2015-2025 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -20,10 +20,8 @@ import io.rsocket.kotlin.* import io.rsocket.kotlin.connection.* import io.rsocket.kotlin.frame.* import io.rsocket.kotlin.frame.io.* -import io.rsocket.kotlin.internal.io.* import io.rsocket.kotlin.logging.* import io.rsocket.kotlin.transport.* -import kotlinx.coroutines.* import kotlin.coroutines.* @OptIn(RSocketTransportApi::class, RSocketLoggingApi::class) @@ -42,8 +40,10 @@ public class RSocketConnector internal constructor( @Deprecated(level = DeprecationLevel.ERROR, message = "Deprecated in favor of new Transport API") public suspend fun connect(transport: ClientTransport): RSocket = connect(object : RSocketClientTarget { override val coroutineContext: CoroutineContext get() = transport.coroutineContext - override fun connectClient(handler: RSocketConnectionHandler): Job = launch { - handler.handleConnection(interceptors.wrapConnection(transport.connect())) + + @RSocketTransportApi + override suspend fun connectClient(): RSocketConnection { + return OldConnection(interceptors.wrapConnection(transport.connect())) } }) @@ -58,24 +58,16 @@ public class RSocketConnector internal constructor( } private suspend fun connectOnce(transport: RSocketClientTarget): RSocket { - val requesterDeferred = CompletableDeferred() - val connectJob = transport.connectClient( - SetupConnection(requesterDeferred).logging(frameLogger) - ).onCompletion { if (it != null) requesterDeferred.completeExceptionally(it) } - return try { - requesterDeferred.await() - } catch (cause: Throwable) { - connectJob.cancel("RSocketConnector.connect was cancelled", cause) - throw cause - } + return SetupConnection().runInitializer( + transport.connectClient().logging(frameLogger) + ) } - private inner class SetupConnection(requesterDeferred: CompletableDeferred) : ConnectionEstablishmentHandler( + private inner class SetupConnection() : ConnectionInitializer( isClient = true, frameCodec = FrameCodec(maxFragmentSize), connectionAcceptor = acceptor, - interceptors = interceptors, - requesterDeferred = requesterDeferred + interceptors = interceptors ) { override suspend fun establishConnection(context: ConnectionEstablishmentContext): ConnectionConfig { val connectionConfig = connectionConfigProvider() diff --git a/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/core/RSocketServer.kt b/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/core/RSocketServer.kt index 225de22c..bf350992 100644 --- a/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/core/RSocketServer.kt +++ b/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/core/RSocketServer.kt @@ -1,5 +1,5 @@ /* - * Copyright 2015-2024 the original author or authors. + * Copyright 2015-2025 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -46,30 +46,28 @@ public class RSocketServer internal constructor( scope: CoroutineScope, transport: ServerTransport, acceptor: ConnectionAcceptor, - ): T { - val handler = createHandler(acceptor) - return with(transport) { - scope.start { - handler.handleConnection(interceptors.wrapConnection(it)) - } + ): T = with(transport) { + scope.start { + acceptConnection(acceptor, OldConnection(interceptors.wrapConnection(it))) + awaitCancellation() } } public suspend fun startServer( transport: RSocketServerTarget, acceptor: ConnectionAcceptor, - ): T = transport.startServer(createHandler(acceptor)) + ): T = transport.startServer { acceptConnection(acceptor, it) } @RSocketTransportApi - public fun createHandler(acceptor: ConnectionAcceptor): RSocketConnectionHandler = - AcceptConnection(acceptor).logging(frameLogger) + public fun acceptConnection(acceptor: ConnectionAcceptor, connection: RSocketConnection) { + AcceptConnection(acceptor).launchInitializer(connection.logging(frameLogger)) + } - private inner class AcceptConnection(acceptor: ConnectionAcceptor) : ConnectionEstablishmentHandler( + private inner class AcceptConnection(acceptor: ConnectionAcceptor) : ConnectionInitializer( isClient = false, frameCodec = FrameCodec(maxFragmentSize), connectionAcceptor = acceptor, - interceptors = interceptors, - requesterDeferred = null + interceptors = interceptors ) { override suspend fun establishConnection(context: ConnectionEstablishmentContext): ConnectionConfig { val setupFrame = context.receiveFrame() diff --git a/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/core/ReconnectableRSocket.kt b/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/core/ReconnectableRSocket.kt index c83484e6..48e7070d 100644 --- a/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/core/ReconnectableRSocket.kt +++ b/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/core/ReconnectableRSocket.kt @@ -1,5 +1,5 @@ /* - * Copyright 2015-2024 the original author or authors. + * Copyright 2015-2025 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -87,22 +87,11 @@ private class ReconnectableRSocket( suspend fun currentRSocket(): RSocket = state.value.current() ?: state.mapNotNull { it.current() }.first() - private suspend fun currentRSocket(metadata: Buffer): RSocket { - return try { - currentRSocket() - } catch (cause: Throwable) { - metadata.clear() - throw cause - } - } - - private suspend fun currentRSocket(payload: Payload): RSocket { - return try { - currentRSocket() - } catch (cause: Throwable) { - payload.close() - throw cause - } + private suspend inline fun currentRSocket(onFailure: () -> Unit): RSocket = try { + currentRSocket() + } catch (cause: Throwable) { + onFailure() + throw cause } private fun ReconnectState.current(): RSocket? = when (this) { @@ -112,20 +101,20 @@ private class ReconnectableRSocket( } override suspend fun metadataPush(metadata: Buffer): Unit = - currentRSocket(metadata).metadataPush(metadata) + currentRSocket(metadata::clear).metadataPush(metadata) override suspend fun fireAndForget(payload: Payload): Unit = - currentRSocket(payload).fireAndForget(payload) + currentRSocket(payload::close).fireAndForget(payload) override suspend fun requestResponse(payload: Payload): Payload = - currentRSocket(payload).requestResponse(payload) + currentRSocket(payload::close).requestResponse(payload) override fun requestStream(payload: Payload): Flow = flow { - emitAll(currentRSocket(payload).requestStream(payload)) + emitAll(currentRSocket(payload::close).requestStream(payload)) } override fun requestChannel(initPayload: Payload, payloads: Flow): Flow = flow { - emitAll(currentRSocket(initPayload).requestChannel(initPayload, payloads)) + emitAll(currentRSocket(initPayload::close).requestChannel(initPayload, payloads)) } } diff --git a/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/operation/OperationOutbound.kt b/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/operation/OperationOutbound.kt index debd9ff7..f0520234 100644 --- a/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/operation/OperationOutbound.kt +++ b/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/operation/OperationOutbound.kt @@ -1,5 +1,5 @@ /* - * Copyright 2015-2024 the original author or authors. + * Copyright 2015-2025 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -34,8 +34,6 @@ internal abstract class OperationOutbound( // TODO: decide on it // private var firstRequestFrameSent: Boolean = false - abstract val isClosed: Boolean - protected abstract suspend fun sendFrame(frame: Buffer) private suspend fun sendFrame(frame: Frame): Unit = sendFrame(frameCodec.encodeFrame(frame)) diff --git a/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/operation/RequesterFireAndForgetOperation.kt b/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/operation/RequesterFireAndForgetOperation.kt index ad290734..99a99504 100644 --- a/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/operation/RequesterFireAndForgetOperation.kt +++ b/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/operation/RequesterFireAndForgetOperation.kt @@ -1,5 +1,5 @@ /* - * Copyright 2015-2024 the original author or authors. + * Copyright 2015-2025 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -17,6 +17,7 @@ package io.rsocket.kotlin.operation import io.rsocket.kotlin.frame.* +import io.rsocket.kotlin.internal.io.* import io.rsocket.kotlin.payload.* import kotlinx.coroutines.* import kotlin.coroutines.* @@ -36,7 +37,7 @@ internal class RequesterFireAndForgetOperation( requestSentCont.resume(Unit) } catch (cause: Throwable) { if (requestSentCont.isActive) requestSentCont.resumeWithException(cause) - if (!outbound.isClosed) withContext(NonCancellable) { outbound.sendCancel() } + nonCancellable { outbound.sendCancel() } throw cause } } diff --git a/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/operation/RequesterRequestChannelOperation.kt b/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/operation/RequesterRequestChannelOperation.kt index 6dc02b36..c5ee0b9e 100644 --- a/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/operation/RequesterRequestChannelOperation.kt +++ b/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/operation/RequesterRequestChannelOperation.kt @@ -1,5 +1,5 @@ /* - * Copyright 2015-2024 the original author or authors. + * Copyright 2015-2025 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -18,6 +18,7 @@ package io.rsocket.kotlin.operation import io.rsocket.kotlin.frame.* import io.rsocket.kotlin.internal.* +import io.rsocket.kotlin.internal.io.* import io.rsocket.kotlin.payload.* import kotlinx.atomicfu.* import kotlinx.coroutines.* @@ -58,11 +59,11 @@ internal class RequesterRequestChannelOperation( try { while (true) outbound.sendRequestN(responsePayloads.nextRequestN() ?: break) } catch (cause: Throwable) { - if (!currentCoroutineContext().isActive || !outbound.isClosed) throw cause + if (!currentCoroutineContext().isActive) throw cause } } } catch (cause: Throwable) { - if (!outbound.isClosed) withContext(NonCancellable) { + nonCancellable { when (val error = failure) { null -> outbound.sendCancel() else -> outbound.sendError(error) diff --git a/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/operation/RequesterRequestResponseOperation.kt b/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/operation/RequesterRequestResponseOperation.kt index 3722a9aa..6f2103e1 100644 --- a/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/operation/RequesterRequestResponseOperation.kt +++ b/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/operation/RequesterRequestResponseOperation.kt @@ -1,5 +1,5 @@ /* - * Copyright 2015-2024 the original author or authors. + * Copyright 2015-2025 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -17,6 +17,7 @@ package io.rsocket.kotlin.operation import io.rsocket.kotlin.frame.* +import io.rsocket.kotlin.internal.io.* import io.rsocket.kotlin.payload.* import kotlinx.coroutines.* @@ -35,7 +36,7 @@ internal class RequesterRequestResponseOperation( responseDeferred.join() } catch (cause: Throwable) { // TODO: we don't need to send cancel if we have sent no frames - if (!outbound.isClosed) withContext(NonCancellable) { outbound.sendCancel() } + nonCancellable { outbound.sendCancel() } throw cause } } diff --git a/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/operation/RequesterRequestStreamOperation.kt b/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/operation/RequesterRequestStreamOperation.kt index 4e0d559f..37c2a3b3 100644 --- a/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/operation/RequesterRequestStreamOperation.kt +++ b/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/operation/RequesterRequestStreamOperation.kt @@ -1,5 +1,5 @@ /* - * Copyright 2015-2024 the original author or authors. + * Copyright 2015-2025 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -18,6 +18,7 @@ package io.rsocket.kotlin.operation import io.rsocket.kotlin.frame.* import io.rsocket.kotlin.internal.* +import io.rsocket.kotlin.internal.io.* import io.rsocket.kotlin.payload.* import kotlinx.coroutines.* @@ -37,10 +38,10 @@ internal class RequesterRequestStreamOperation( try { while (true) outbound.sendRequestN(responsePayloads.nextRequestN() ?: break) } catch (cause: Throwable) { - if (!currentCoroutineContext().isActive || !outbound.isClosed) throw cause + if (!currentCoroutineContext().isActive) throw cause } } catch (cause: Throwable) { - if (!outbound.isClosed) withContext(NonCancellable) { outbound.sendCancel() } + nonCancellable { outbound.sendCancel() } throw cause } } diff --git a/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/operation/ResponderRequestChannelOperation.kt b/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/operation/ResponderRequestChannelOperation.kt index 28828bad..9cb0ccba 100644 --- a/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/operation/ResponderRequestChannelOperation.kt +++ b/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/operation/ResponderRequestChannelOperation.kt @@ -1,5 +1,5 @@ /* - * Copyright 2015-2024 the original author or authors. + * Copyright 2015-2025 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -19,6 +19,7 @@ package io.rsocket.kotlin.operation import io.rsocket.kotlin.* import io.rsocket.kotlin.frame.* import io.rsocket.kotlin.internal.* +import io.rsocket.kotlin.internal.io.* import io.rsocket.kotlin.payload.* import kotlinx.coroutines.* @@ -41,10 +42,8 @@ internal class ResponderRequestChannelOperation( outbound.sendRequestN(initialRequest) while (true) outbound.sendRequestN(requestPayloads.nextRequestN() ?: break) } catch (cause: Throwable) { - // ignore error if outbound was closed - TODO: recheck - if (this@coroutineScope.isActive && outbound.isClosed) return@launch // send cancel only if the operation is active - if (this@coroutineScope.isActive) withContext(NonCancellable) { outbound.sendCancel() } + if (this@coroutineScope.isActive) nonCancellable { outbound.sendCancel() } throw cause } } diff --git a/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/transport/RSocketConnection.kt b/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/transport/RSocketConnection.kt index c49ee3b5..9ea1450d 100644 --- a/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/transport/RSocketConnection.kt +++ b/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/transport/RSocketConnection.kt @@ -1,5 +1,5 @@ /* - * Copyright 2015-2024 the original author or authors. + * Copyright 2015-2025 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,24 +16,17 @@ package io.rsocket.kotlin.transport +import kotlinx.coroutines.* import kotlinx.io.* // all methods can be called from any thread/context at any time // should be accessed only internally // should be implemented only by transports @RSocketTransportApi -public sealed interface RSocketConnection - -@RSocketTransportApi -public fun interface RSocketConnectionHandler { - public suspend fun handleConnection(connection: RSocketConnection) -} +public sealed interface RSocketConnection : CoroutineScope @RSocketTransportApi public interface RSocketSequentialConnection : RSocketConnection { - // TODO: is it needed for connection? - public val isClosedForSend: Boolean - // throws if frame not sent // streamId=0 should be sent earlier public suspend fun sendFrame(streamId: Int, frame: Buffer) @@ -47,9 +40,8 @@ public interface RSocketMultiplexedConnection : RSocketConnection { public suspend fun createStream(): Stream public suspend fun acceptStream(): Stream? - public interface Stream : AutoCloseable { - public val isClosedForSend: Boolean - + @RSocketTransportApi + public interface Stream : CoroutineScope { // 0 - highest priority // Int.MAX - lowest priority public fun setSendPriority(priority: Int) @@ -59,10 +51,5 @@ public interface RSocketMultiplexedConnection : RSocketConnection { // null if no more frames could be received public suspend fun receiveFrame(): Buffer? - - // closing stream will send buffered frames (if needed) - // sending/receiving frames will be not possible after it - // should not throw - override fun close() } } diff --git a/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/transport/RSocketTransport.kt b/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/transport/RSocketTransport.kt index 0a5e73b1..f44cba1e 100644 --- a/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/transport/RSocketTransport.kt +++ b/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/transport/RSocketTransport.kt @@ -1,5 +1,5 @@ /* - * Copyright 2015-2024 the original author or authors. + * Copyright 2015-2025 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -44,17 +44,15 @@ public interface RSocketTransport : CoroutineScope { @SubclassOptInRequired(RSocketTransportApi::class) public interface RSocketClientTarget : CoroutineScope { - // cancelling Job will cancel connection - // Job will be completed when the connection is finished @RSocketTransportApi - public fun connectClient(handler: RSocketConnectionHandler): Job + public suspend fun connectClient(): RSocketConnection } @SubclassOptInRequired(RSocketTransportApi::class) public interface RSocketServerTarget : CoroutineScope { - // handler will be called for all new connections + // onConnection shouldn't throw. @RSocketTransportApi - public suspend fun startServer(handler: RSocketConnectionHandler): Instance + public suspend fun startServer(onConnection: (RSocketConnection) -> Unit): Instance } // cancelling it will cancel server diff --git a/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/transport/internal/PrioritizationFrameQueue.kt b/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/transport/internal/PrioritizationFrameQueue.kt index 158148cd..062848aa 100644 --- a/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/transport/internal/PrioritizationFrameQueue.kt +++ b/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/transport/internal/PrioritizationFrameQueue.kt @@ -1,5 +1,5 @@ /* - * Copyright 2015-2024 the original author or authors. + * Copyright 2015-2025 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -26,9 +26,9 @@ import kotlinx.io.* private val selectFrame: suspend (ChannelResult) -> ChannelResult = { it } @RSocketTransportApi -public class PrioritizationFrameQueue(buffersCapacity: Int) { - private val priorityFrames = bufferChannel(buffersCapacity) - private val normalFrames = bufferChannel(buffersCapacity) +public class PrioritizationFrameQueue { + private val priorityFrames = bufferChannel(Channel.BUFFERED) + private val normalFrames = bufferChannel(Channel.BUFFERED) private val priorityOnReceive = priorityFrames.onReceiveCatching private val normalOnReceive = normalFrames.onReceiveCatching diff --git a/rsocket-core/src/commonTest/kotlin/io/rsocket/kotlin/ConnectionEstablishmentTest.kt b/rsocket-core/src/commonTest/kotlin/io/rsocket/kotlin/ConnectionEstablishmentTest.kt index 86d7b8cd..41780838 100644 --- a/rsocket-core/src/commonTest/kotlin/io/rsocket/kotlin/ConnectionEstablishmentTest.kt +++ b/rsocket-core/src/commonTest/kotlin/io/rsocket/kotlin/ConnectionEstablishmentTest.kt @@ -1,5 +1,5 @@ /* - * Copyright 2015-2024 the original author or authors. + * Copyright 2015-2025 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -36,10 +36,8 @@ class ConnectionEstablishmentTest : SuspendTest { override val coroutineContext: CoroutineContext, private val connection: RSocketConnection, ) : RSocketServerTarget { - override suspend fun startServer(handler: RSocketConnectionHandler): TestInstance { - return TestInstance(async { - handler.handleConnection(connection) - }) + override suspend fun startServer(onConnection: (RSocketConnection) -> Unit): TestInstance { + return TestInstance(async { onConnection(connection) }) } } diff --git a/rsocket-core/src/commonTest/kotlin/io/rsocket/kotlin/TestConnection.kt b/rsocket-core/src/commonTest/kotlin/io/rsocket/kotlin/TestConnection.kt index d2f1d851..e3ad857b 100644 --- a/rsocket-core/src/commonTest/kotlin/io/rsocket/kotlin/TestConnection.kt +++ b/rsocket-core/src/commonTest/kotlin/io/rsocket/kotlin/TestConnection.kt @@ -1,5 +1,5 @@ /* - * Copyright 2015-2024 the original author or authors. + * Copyright 2015-2025 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -44,13 +44,7 @@ class TestConnection : RSocketSequentialConnection, RSocketClientTarget { } } - override fun connectClient(handler: RSocketConnectionHandler): Job = launch { - handler.handleConnection(this@TestConnection) - }.onCompletion { - if (it != null) job.completeExceptionally(it) - } - - override val isClosedForSend: Boolean get() = sendChannel.isClosedForSend + override suspend fun connectClient(): RSocketConnection = this override suspend fun sendFrame(streamId: Int, frame: Buffer) { sendChannel.send(frame) diff --git a/rsocket-core/src/commonTest/kotlin/io/rsocket/kotlin/internal/RSocketRequesterTest.kt b/rsocket-core/src/commonTest/kotlin/io/rsocket/kotlin/internal/RSocketRequesterTest.kt index 56f59b82..0f60c236 100644 --- a/rsocket-core/src/commonTest/kotlin/io/rsocket/kotlin/internal/RSocketRequesterTest.kt +++ b/rsocket-core/src/commonTest/kotlin/io/rsocket/kotlin/internal/RSocketRequesterTest.kt @@ -411,11 +411,10 @@ class RSocketRequesterTest : TestWithConnection() { fun cancelRequesterToCloseConnection() = test { val request = requester.requestStream(Payload.Empty).produceIn(GlobalScope) connection.test { - awaitFrame { frame -> - assertTrue(frame is RequestFrame) - } + awaitFrame { assertIs(it) } requester.cancel() //cancel requester - awaitFrame { assertTrue(it is ErrorFrame) } + awaitFrame { assertIs(it) } // cancelled stream + awaitFrame { assertIs(it) } // error on connection awaitError() } delay(100) diff --git a/rsocket-core/src/commonTest/kotlin/io/rsocket/kotlin/internal/RSocketResponderRequestNTest.kt b/rsocket-core/src/commonTest/kotlin/io/rsocket/kotlin/internal/RSocketResponderRequestNTest.kt index 52cd7c7d..af7e043d 100644 --- a/rsocket-core/src/commonTest/kotlin/io/rsocket/kotlin/internal/RSocketResponderRequestNTest.kt +++ b/rsocket-core/src/commonTest/kotlin/io/rsocket/kotlin/internal/RSocketResponderRequestNTest.kt @@ -1,5 +1,5 @@ /* - * Copyright 2015-2024 the original author or authors. + * Copyright 2015-2025 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -41,10 +41,8 @@ class RSocketResponderRequestNTest : TestWithConnection() { override val coroutineContext: CoroutineContext, private val connection: RSocketConnection, ) : RSocketServerTarget { - override suspend fun startServer(handler: RSocketConnectionHandler): TestInstance { - return TestInstance(async { - handler.handleConnection(connection) - }) + override suspend fun startServer(onConnection: (RSocketConnection) -> Unit): TestInstance { + return TestInstance(async { onConnection(connection) }) } } diff --git a/rsocket-core/src/commonTest/kotlin/io/rsocket/kotlin/keepalive/KeepAliveTest.kt b/rsocket-core/src/commonTest/kotlin/io/rsocket/kotlin/keepalive/KeepAliveTest.kt index c4fcf63f..98729428 100644 --- a/rsocket-core/src/commonTest/kotlin/io/rsocket/kotlin/keepalive/KeepAliveTest.kt +++ b/rsocket-core/src/commonTest/kotlin/io/rsocket/kotlin/keepalive/KeepAliveTest.kt @@ -1,5 +1,5 @@ /* - * Copyright 2015-2024 the original author or authors. + * Copyright 2015-2025 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,6 +16,7 @@ package io.rsocket.kotlin.keepalive +import app.cash.turbine.* import io.rsocket.kotlin.* import io.rsocket.kotlin.frame.* import io.rsocket.kotlin.test.* @@ -100,8 +101,17 @@ class KeepAliveTest : TestWithConnection() { fun rSocketCanceledOnMissingKeepAliveTicks() = test { val rSocket = requester() connection.test { - while (rSocket.isActive) awaitFrame { it is KeepAliveFrame } - awaitError() + while (true) { + when (val event = awaitEvent()) { + is Event.Item<*> -> assertIs(event.value) + is Event.Error -> { + assertIs(event.throwable) + break + } + + Event.Complete -> error("Complete should not happen") + } + } } @OptIn(InternalCoroutinesApi::class) assertTrue(rSocket.coroutineContext.job.getCancellationException().cause is RSocketError.ConnectionError) diff --git a/rsocket-core/src/commonTest/kotlin/io/rsocket/kotlin/operation/OperationOutboundTest.kt b/rsocket-core/src/commonTest/kotlin/io/rsocket/kotlin/operation/OperationOutboundTest.kt index b1d7eb7e..d0ceab50 100644 --- a/rsocket-core/src/commonTest/kotlin/io/rsocket/kotlin/operation/OperationOutboundTest.kt +++ b/rsocket-core/src/commonTest/kotlin/io/rsocket/kotlin/operation/OperationOutboundTest.kt @@ -1,5 +1,5 @@ /* - * Copyright 2015-2024 the original author or authors. + * Copyright 2015-2025 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -30,8 +30,6 @@ class OperationOutboundTest : SuspendTest { maxFragmentSize: Int, ) : OperationOutbound(streamId, FrameCodec(maxFragmentSize)) { val frames = bufferChannel(Channel.BUFFERED) - override val isClosed: Boolean get() = frames.isClosedForSend - override suspend fun sendFrame(frame: Buffer) { frames.send(frame) } diff --git a/rsocket-core/src/commonTest/kotlin/io/rsocket/kotlin/transport/internal/PrioritizationFrameQueueTest.kt b/rsocket-core/src/commonTest/kotlin/io/rsocket/kotlin/transport/internal/PrioritizationFrameQueueTest.kt index 624d625b..8e7c7b0c 100644 --- a/rsocket-core/src/commonTest/kotlin/io/rsocket/kotlin/transport/internal/PrioritizationFrameQueueTest.kt +++ b/rsocket-core/src/commonTest/kotlin/io/rsocket/kotlin/transport/internal/PrioritizationFrameQueueTest.kt @@ -19,12 +19,11 @@ package io.rsocket.kotlin.transport.internal import io.rsocket.kotlin.internal.io.* import io.rsocket.kotlin.test.* import kotlinx.coroutines.* -import kotlinx.coroutines.channels.* import kotlinx.io.* import kotlin.test.* class PrioritizationFrameQueueTest : SuspendTest { - private val queue = PrioritizationFrameQueue(Channel.BUFFERED) + private val queue = PrioritizationFrameQueue() @Test fun testOrdering() = test { diff --git a/rsocket-internal-io/api/rsocket-internal-io.api b/rsocket-internal-io/api/rsocket-internal-io.api index 4a0099ad..240e208c 100644 --- a/rsocket-internal-io/api/rsocket-internal-io.api +++ b/rsocket-internal-io/api/rsocket-internal-io.api @@ -6,9 +6,7 @@ public final class io/rsocket/kotlin/internal/io/ChannelsKt { public final class io/rsocket/kotlin/internal/io/ContextKt { public static final fun childContext (Lkotlin/coroutines/CoroutineContext;)Lkotlin/coroutines/CoroutineContext; - public static final fun ensureActive (Lkotlin/coroutines/CoroutineContext;Lkotlin/jvm/functions/Function0;)V - public static final fun launchCoroutine (Lkotlinx/coroutines/CoroutineScope;Lkotlin/coroutines/CoroutineContext;Lkotlin/jvm/functions/Function2;Lkotlin/coroutines/Continuation;)Ljava/lang/Object; - public static synthetic fun launchCoroutine$default (Lkotlinx/coroutines/CoroutineScope;Lkotlin/coroutines/CoroutineContext;Lkotlin/jvm/functions/Function2;Lkotlin/coroutines/Continuation;ILjava/lang/Object;)Ljava/lang/Object; + public static final fun nonCancellable (Lkotlin/jvm/functions/Function2;Lkotlin/coroutines/Continuation;)Ljava/lang/Object; public static final fun onCompletion (Lkotlinx/coroutines/Job;Lkotlin/jvm/functions/Function1;)Lkotlinx/coroutines/Job; public static final fun supervisorContext (Lkotlin/coroutines/CoroutineContext;)Lkotlin/coroutines/CoroutineContext; } diff --git a/rsocket-internal-io/src/commonMain/kotlin/io/rsocket/kotlin/internal/io/Context.kt b/rsocket-internal-io/src/commonMain/kotlin/io/rsocket/kotlin/internal/io/Context.kt index a6d05d4c..e05c1eb0 100644 --- a/rsocket-internal-io/src/commonMain/kotlin/io/rsocket/kotlin/internal/io/Context.kt +++ b/rsocket-internal-io/src/commonMain/kotlin/io/rsocket/kotlin/internal/io/Context.kt @@ -24,23 +24,9 @@ public expect val Dispatchers.IoCompatible: CoroutineDispatcher public fun CoroutineContext.supervisorContext(): CoroutineContext = plus(SupervisorJob(get(Job))) public fun CoroutineContext.childContext(): CoroutineContext = plus(Job(get(Job))) +public suspend fun nonCancellable(block: suspend CoroutineScope.() -> T): T = withContext(NonCancellable, block) + public fun T.onCompletion(handler: CompletionHandler): T { invokeOnCompletion(handler) return this } - -public inline fun CoroutineContext.ensureActive(onInactive: () -> Unit) { - if (isActive) return - onInactive() // should not throw - ensureActive() // will throw -} - -@Suppress("SuspendFunctionOnCoroutineScope") -public suspend inline fun CoroutineScope.launchCoroutine( - context: CoroutineContext = EmptyCoroutineContext, - crossinline block: suspend (CancellableContinuation) -> Unit, -): T = suspendCancellableCoroutine { cont -> - val job = launch(context) { block(cont) } - job.invokeOnCompletion { if (it != null && cont.isActive) cont.resumeWithException(it) } - cont.invokeOnCancellation { job.cancel("launchCoroutine was cancelled", it) } -} diff --git a/rsocket-transport-tests/src/commonMain/kotlin/io/rsocket/kotlin/transport/tests/TransportTest.kt b/rsocket-transport-tests/src/commonMain/kotlin/io/rsocket/kotlin/transport/tests/TransportTest.kt index 9840a32d..8ee8a2a4 100644 --- a/rsocket-transport-tests/src/commonMain/kotlin/io/rsocket/kotlin/transport/tests/TransportTest.kt +++ b/rsocket-transport-tests/src/commonMain/kotlin/io/rsocket/kotlin/transport/tests/TransportTest.kt @@ -55,6 +55,8 @@ abstract class TransportTest : SuspendTest { SERVER.startServer(serverTransport, ACCEPTOR) override suspend fun after() { + // TODO: we do need delays in FAF and MP tests because in reality, here, we don't wait for the connection to be completed + // and so we start to close connection from both ends simultaneously client.coroutineContext.job.cancelAndJoin() testJob.cancelAndJoin() } @@ -62,21 +64,25 @@ abstract class TransportTest : SuspendTest { @Test fun fireAndForget10() = test { (1..10).map { async { client.fireAndForget(payload(it)) } }.awaitAll() + delay(100) } @Test open fun largePayloadFireAndForget10() = test { (1..10).map { async { client.fireAndForget(requesterLargePayload) } }.awaitAll() + delay(100) } @Test fun metadataPush10() = test { (1..10).map { async { client.metadataPush(packet(requesterData)) } }.awaitAll() + delay(100) } @Test open fun largePayloadMetadataPush10() = test { (1..10).map { async { client.metadataPush(packet(requesterLargeData)) } }.awaitAll() + delay(100) } @Test @@ -133,7 +139,6 @@ abstract class TransportTest : SuspendTest { } @Test - @Ignore //flaky, ignore for now fun requestChannel200000() = test { val request = flow { repeat(200_000) { emit(payload(it)) } @@ -162,7 +167,6 @@ abstract class TransportTest : SuspendTest { } @Test - @Ignore //flaky, ignore for now fun requestChannel256x512() = test { val request = flow { repeat(512) { @@ -187,17 +191,6 @@ abstract class TransportTest : SuspendTest { }.awaitAll() } - @Test - @Ignore //flaky, ignore for now - fun requestStreamX256() = test { - (0..256).map { - async { - val count = client.requestStream(payload(0)).onEach { it.close() }.count() - assertEquals(8192, count) - } - }.awaitAll() - } - @Test fun requestChannel500NoLeak() = test { val request = flow { @@ -237,15 +230,17 @@ abstract class TransportTest : SuspendTest { } @Test - @Ignore // windows - fun requestResponse10000() = test { - (1..10000).map { async { client.requestResponse(payload(3)).let(Companion::checkPayload) } }.awaitAll() + fun requestResponse10000Sequential() = test { + repeat(10000) { + client.requestResponse(payload(3)).let(Companion::checkPayload) + } } @Test - @Ignore // QUIC - fun requestResponse100000() = test { - repeat(100000) { client.requestResponse(payload(3)).let(Companion::checkPayload) } + fun requestResponse10000Parallel() = test { + repeat(10000) { + launch { client.requestResponse(payload(3)).let(Companion::checkPayload) } + } } @Test @@ -258,7 +253,7 @@ abstract class TransportTest : SuspendTest { @Test fun requestStream8K() = test { val count = client.requestStream(payload(3)).onEach { checkPayload(it) }.count() - assertEquals(8192, count) // TODO + assertEquals(8192, count) } @Test diff --git a/rsocket-transports/ktor-tcp/api/rsocket-transport-ktor-tcp.api b/rsocket-transports/ktor-tcp/api/rsocket-transport-ktor-tcp.api index d6153045..8be6747a 100644 --- a/rsocket-transports/ktor-tcp/api/rsocket-transport-ktor-tcp.api +++ b/rsocket-transports/ktor-tcp/api/rsocket-transport-ktor-tcp.api @@ -8,10 +8,7 @@ public final class io/rsocket/kotlin/transport/ktor/tcp/KtorTcpClientTransport$F } public abstract interface class io/rsocket/kotlin/transport/ktor/tcp/KtorTcpClientTransportBuilder : io/rsocket/kotlin/transport/RSocketTransportBuilder { - public abstract fun dispatcher (Lkotlin/coroutines/CoroutineContext;)V - public fun inheritDispatcher ()V public abstract fun selectorManager (Lio/ktor/network/selector/SelectorManager;Z)V - public abstract fun selectorManagerDispatcher (Lkotlin/coroutines/CoroutineContext;)V public abstract fun socketOptions (Lkotlin/jvm/functions/Function1;)V } @@ -31,10 +28,7 @@ public final class io/rsocket/kotlin/transport/ktor/tcp/KtorTcpServerTransport$F } public abstract interface class io/rsocket/kotlin/transport/ktor/tcp/KtorTcpServerTransportBuilder : io/rsocket/kotlin/transport/RSocketTransportBuilder { - public abstract fun dispatcher (Lkotlin/coroutines/CoroutineContext;)V - public fun inheritDispatcher ()V public abstract fun selectorManager (Lio/ktor/network/selector/SelectorManager;Z)V - public abstract fun selectorManagerDispatcher (Lkotlin/coroutines/CoroutineContext;)V public abstract fun socketOptions (Lkotlin/jvm/functions/Function1;)V } diff --git a/rsocket-transports/ktor-tcp/src/commonMain/kotlin/io/rsocket/kotlin/transport/ktor/tcp/KtorTcpClientTransport.kt b/rsocket-transports/ktor-tcp/src/commonMain/kotlin/io/rsocket/kotlin/transport/ktor/tcp/KtorTcpClientTransport.kt index 388b05eb..164644e4 100644 --- a/rsocket-transports/ktor-tcp/src/commonMain/kotlin/io/rsocket/kotlin/transport/ktor/tcp/KtorTcpClientTransport.kt +++ b/rsocket-transports/ktor-tcp/src/commonMain/kotlin/io/rsocket/kotlin/transport/ktor/tcp/KtorTcpClientTransport.kt @@ -34,56 +34,44 @@ public sealed interface KtorTcpClientTransport : RSocketTransport { @OptIn(RSocketTransportApi::class) public sealed interface KtorTcpClientTransportBuilder : RSocketTransportBuilder { - public fun dispatcher(context: CoroutineContext) - public fun inheritDispatcher(): Unit = dispatcher(EmptyCoroutineContext) - - public fun selectorManagerDispatcher(context: CoroutineContext) public fun selectorManager(manager: SelectorManager, manage: Boolean) - public fun socketOptions(block: SocketOptions.TCPClientSocketOptions.() -> Unit) - //TODO: TLS support } private class KtorTcpClientTransportBuilderImpl : KtorTcpClientTransportBuilder { - private var dispatcher: CoroutineContext = Dispatchers.Default - private var selector: KtorTcpSelector = KtorTcpSelector.FromContext(Dispatchers.IoCompatible) + private var selectorManager: SelectorManager? = null + private var manageSelectorManager: Boolean = true private var socketOptions: SocketOptions.TCPClientSocketOptions.() -> Unit = {} - override fun dispatcher(context: CoroutineContext) { - check(context[Job] == null) { "Dispatcher shouldn't contain job" } - this.dispatcher = context - } - override fun socketOptions(block: SocketOptions.TCPClientSocketOptions.() -> Unit) { this.socketOptions = block } - override fun selectorManagerDispatcher(context: CoroutineContext) { - check(context[Job] == null) { "Dispatcher shouldn't contain job" } - this.selector = KtorTcpSelector.FromContext(context) - } - override fun selectorManager(manager: SelectorManager, manage: Boolean) { - this.selector = KtorTcpSelector.FromInstance(manager, manage) + this.selectorManager = manager + this.manageSelectorManager = manage } @RSocketTransportApi - override fun buildTransport(context: CoroutineContext): KtorTcpClientTransport { - val transportContext = context.supervisorContext() + dispatcher - return KtorTcpClientTransportImpl( - coroutineContext = transportContext, - socketOptions = socketOptions, - selectorManager = selector.createFor(transportContext) - ) - } + override fun buildTransport(context: CoroutineContext): KtorTcpClientTransport = KtorTcpClientTransportImpl( + coroutineContext = context.supervisorContext() + Dispatchers.Default, + socketOptions = socketOptions, + selectorManager = selectorManager ?: SelectorManager(Dispatchers.IoCompatible), + manageSelectorManager = manageSelectorManager + ) } private class KtorTcpClientTransportImpl( override val coroutineContext: CoroutineContext, private val socketOptions: SocketOptions.TCPClientSocketOptions.() -> Unit, private val selectorManager: SelectorManager, + manageSelectorManager: Boolean, ) : KtorTcpClientTransport { + init { + if (manageSelectorManager) coroutineContext.job.invokeOnCompletion { selectorManager.close() } + } + override fun target(remoteAddress: SocketAddress): RSocketClientTarget = KtorTcpClientTargetImpl( coroutineContext = coroutineContext.supervisorContext(), socketOptions = socketOptions, @@ -101,10 +89,17 @@ private class KtorTcpClientTargetImpl( private val selectorManager: SelectorManager, private val remoteAddress: SocketAddress, ) : RSocketClientTarget { - @RSocketTransportApi - override fun connectClient(handler: RSocketConnectionHandler): Job = launch { - val socket = aSocket(selectorManager).tcp().connect(remoteAddress, socketOptions) - handler.handleKtorTcpConnection(socket) + override suspend fun connectClient(): RSocketConnection { + currentCoroutineContext().ensureActive() + coroutineContext.ensureActive() + + return withContext(Dispatchers.IoCompatible) { + val socket = aSocket(selectorManager).tcp().connect(remoteAddress, socketOptions) + KtorTcpConnection( + parentContext = this@KtorTcpClientTargetImpl.coroutineContext, + socket = socket + ) + } } } diff --git a/rsocket-transports/ktor-tcp/src/commonMain/kotlin/io/rsocket/kotlin/transport/ktor/tcp/KtorTcpConnection.kt b/rsocket-transports/ktor-tcp/src/commonMain/kotlin/io/rsocket/kotlin/transport/ktor/tcp/KtorTcpConnection.kt index a07d44e3..12375982 100644 --- a/rsocket-transports/ktor-tcp/src/commonMain/kotlin/io/rsocket/kotlin/transport/ktor/tcp/KtorTcpConnection.kt +++ b/rsocket-transports/ktor-tcp/src/commonMain/kotlin/io/rsocket/kotlin/transport/ktor/tcp/KtorTcpConnection.kt @@ -1,5 +1,5 @@ /* - * Copyright 2015-2024 the original author or authors. + * Copyright 2015-2025 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -24,64 +24,73 @@ import io.rsocket.kotlin.transport.internal.* import kotlinx.coroutines.* import kotlinx.coroutines.channels.* import kotlinx.io.* +import kotlin.coroutines.* @RSocketTransportApi -internal suspend fun RSocketConnectionHandler.handleKtorTcpConnection(socket: Socket): Unit = coroutineScope { - val outboundQueue = PrioritizationFrameQueue(Channel.BUFFERED) - val inbound = bufferChannel(Channel.BUFFERED) +internal class KtorTcpConnection( + parentContext: CoroutineContext, + private val socket: Socket, +) : RSocketSequentialConnection { + private val outboundQueue = PrioritizationFrameQueue() + private val inbound = bufferChannel(Channel.BUFFERED) - val readerJob = launch { - val input = socket.openReadChannel() - try { - while (true) inbound.send(input.readFrame() ?: break) - input.cancel(null) - } catch (cause: Throwable) { - input.cancel(cause) - throw cause - } - }.onCompletion { inbound.cancel() } + override val coroutineContext: CoroutineContext = parentContext.childContext() - val writerJob = launch { - val output = socket.openWriteChannel() - try { - while (true) { - // we write all available frames here, and only after it flush - // in this case, if there are several buffered frames we can send them in one go - // avoiding unnecessary flushes - output.writeFrame(outboundQueue.dequeueFrame() ?: break) - while (true) output.writeFrame(outboundQueue.tryDequeueFrame() ?: break) - output.flush() + init { + @OptIn(DelicateCoroutinesApi::class) + launch(start = CoroutineStart.ATOMIC) { + val outboundJob = launch { + nonCancellable { + val output = socket.openWriteChannel() + try { + while (true) { + // we write all available frames here, and only after it flush + // in this case, if there are several buffered frames we can send them in one go + // avoiding unnecessary flushes + output.writeFrame(outboundQueue.dequeueFrame() ?: break) + while (true) output.writeFrame(outboundQueue.tryDequeueFrame() ?: break) + output.flush() + } + output.flushAndClose() + } catch (cause: Throwable) { + output.cancel(cause) + throw cause + } + } + }.onCompletion { + outboundQueue.cancel() } - output.close(null) - } catch (cause: Throwable) { - output.close(cause) - throw cause - } - }.onCompletion { outboundQueue.cancel() } - try { - handleConnection(KtorTcpConnection(outboundQueue, inbound)) - } finally { - readerJob.cancel() - outboundQueue.close() // will cause `writerJob` completion - // even if it was cancelled, we still need to close socket and await it closure - withContext(NonCancellable) { - // await completion of read/write and then close socket - readerJob.join() - writerJob.join() - // close socket - socket.close() - socket.socketContext.join() + val inboundJob = launch { + val input = socket.openReadChannel() + try { + while (true) { + inbound.send(input.readFrame() ?: break) + } + input.cancel(null) + } catch (cause: Throwable) { + input.cancel(cause) + throw cause + } + }.onCompletion { + inbound.cancel() + } + + try { + awaitCancellation() + } finally { + nonCancellable { + outboundQueue.close() + outboundJob.join() + inboundJob.join() + // await socket completion + socket.close() + socket.socketContext.join() + } + } } } -} -@RSocketTransportApi -private class KtorTcpConnection( - private val outboundQueue: PrioritizationFrameQueue, - private val inbound: ReceiveChannel, -) : RSocketSequentialConnection { - override val isClosedForSend: Boolean get() = outboundQueue.isClosedForSend override suspend fun sendFrame(streamId: Int, frame: Buffer) { return outboundQueue.enqueueFrame(streamId, frame) } @@ -89,21 +98,21 @@ private class KtorTcpConnection( override suspend fun receiveFrame(): Buffer? { return inbound.receiveCatching().getOrNull() } -} -@OptIn(InternalAPI::class) -private fun ByteWriteChannel.writeFrame(frame: Buffer) { - writeBuffer.writeInt24(frame.size.toInt()) - writeBuffer.transferFrom(frame) -} + @OptIn(InternalAPI::class) + private fun ByteWriteChannel.writeFrame(frame: Buffer) { + writeBuffer.writeInt24(frame.size.toInt()) + writeBuffer.transferFrom(frame) + } -@OptIn(InternalAPI::class) -private suspend fun ByteReadChannel.readFrame(): Buffer? { - while (availableForRead < 3 && awaitContent(3)) yield() - if (availableForRead == 0) return null + @OptIn(InternalAPI::class) + private suspend fun ByteReadChannel.readFrame(): Buffer? { + while (availableForRead < 3 && awaitContent(3)) yield() + if (availableForRead == 0) return null - val length = readBuffer.readInt24() - return readBuffer(length).also { - it.require(length.toLong()) + val length = readBuffer.readInt24() + return readBuffer(length).also { + it.require(length.toLong()) + } } } diff --git a/rsocket-transports/ktor-tcp/src/commonMain/kotlin/io/rsocket/kotlin/transport/ktor/tcp/KtorTcpSelector.kt b/rsocket-transports/ktor-tcp/src/commonMain/kotlin/io/rsocket/kotlin/transport/ktor/tcp/KtorTcpSelector.kt deleted file mode 100644 index 161752e0..00000000 --- a/rsocket-transports/ktor-tcp/src/commonMain/kotlin/io/rsocket/kotlin/transport/ktor/tcp/KtorTcpSelector.kt +++ /dev/null @@ -1,44 +0,0 @@ -/* - * Copyright 2015-2024 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package io.rsocket.kotlin.transport.ktor.tcp - -import io.ktor.network.selector.* -import kotlinx.coroutines.* -import kotlin.coroutines.* - -internal sealed class KtorTcpSelector { - class FromContext(val context: CoroutineContext) : KtorTcpSelector() - class FromInstance(val selectorManager: SelectorManager, val manage: Boolean) : KtorTcpSelector() -} - -internal fun KtorTcpSelector.createFor(parentContext: CoroutineContext): SelectorManager { - val selectorManager: SelectorManager - val manage: Boolean - when (this) { - is KtorTcpSelector.FromContext -> { - selectorManager = SelectorManager(parentContext + context) - manage = true - } - - is KtorTcpSelector.FromInstance -> { - selectorManager = this.selectorManager - manage = this.manage - } - } - if (manage) Job(parentContext.job).invokeOnCompletion { selectorManager.close() } - return selectorManager -} diff --git a/rsocket-transports/ktor-tcp/src/commonMain/kotlin/io/rsocket/kotlin/transport/ktor/tcp/KtorTcpServerTransport.kt b/rsocket-transports/ktor-tcp/src/commonMain/kotlin/io/rsocket/kotlin/transport/ktor/tcp/KtorTcpServerTransport.kt index b15fa855..8a01e152 100644 --- a/rsocket-transports/ktor-tcp/src/commonMain/kotlin/io/rsocket/kotlin/transport/ktor/tcp/KtorTcpServerTransport.kt +++ b/rsocket-transports/ktor-tcp/src/commonMain/kotlin/io/rsocket/kotlin/transport/ktor/tcp/KtorTcpServerTransport.kt @@ -28,10 +28,12 @@ public sealed interface KtorTcpServerInstance : RSocketServerInstance { public val localAddress: SocketAddress } +public typealias KtorTcpServerTarget = RSocketServerTarget + @OptIn(RSocketTransportApi::class) public sealed interface KtorTcpServerTransport : RSocketTransport { - public fun target(localAddress: SocketAddress? = null): RSocketServerTarget - public fun target(host: String = "0.0.0.0", port: Int = 0): RSocketServerTarget + public fun target(localAddress: SocketAddress? = null): KtorTcpServerTarget + public fun target(host: String = "0.0.0.0", port: Int = 0): KtorTcpServerTarget public companion object Factory : RSocketTransportFactory(::KtorTcpServerTransportBuilderImpl) @@ -39,62 +41,51 @@ public sealed interface KtorTcpServerTransport : RSocketTransport { @OptIn(RSocketTransportApi::class) public sealed interface KtorTcpServerTransportBuilder : RSocketTransportBuilder { - public fun dispatcher(context: CoroutineContext) - public fun inheritDispatcher(): Unit = dispatcher(EmptyCoroutineContext) - - public fun selectorManagerDispatcher(context: CoroutineContext) public fun selectorManager(manager: SelectorManager, manage: Boolean) - public fun socketOptions(block: SocketOptions.AcceptorOptions.() -> Unit) } private class KtorTcpServerTransportBuilderImpl : KtorTcpServerTransportBuilder { - private var dispatcher: CoroutineContext = Dispatchers.Default - private var selector: KtorTcpSelector = KtorTcpSelector.FromContext(Dispatchers.IoCompatible) + private var selectorManager: SelectorManager? = null + private var manageSelectorManager: Boolean = true private var socketOptions: SocketOptions.AcceptorOptions.() -> Unit = {} - override fun dispatcher(context: CoroutineContext) { - check(context[Job] == null) { "Dispatcher shouldn't contain job" } - this.dispatcher = context - } - override fun socketOptions(block: SocketOptions.AcceptorOptions.() -> Unit) { this.socketOptions = block } - override fun selectorManagerDispatcher(context: CoroutineContext) { - check(context[Job] == null) { "Dispatcher shouldn't contain job" } - this.selector = KtorTcpSelector.FromContext(context) - } - override fun selectorManager(manager: SelectorManager, manage: Boolean) { - this.selector = KtorTcpSelector.FromInstance(manager, manage) + this.selectorManager = manager + this.manageSelectorManager = manage } @RSocketTransportApi - override fun buildTransport(context: CoroutineContext): KtorTcpServerTransport { - val transportContext = context.supervisorContext() + dispatcher - return KtorTcpServerTransportImpl( - coroutineContext = transportContext, - socketOptions = socketOptions, - selectorManager = selector.createFor(transportContext) - ) - } + override fun buildTransport(context: CoroutineContext): KtorTcpServerTransport = KtorTcpServerTransportImpl( + coroutineContext = context.supervisorContext() + Dispatchers.Default, + socketOptions = socketOptions, + selectorManager = selectorManager ?: SelectorManager(Dispatchers.IoCompatible), + manageSelectorManager = manageSelectorManager + ) } private class KtorTcpServerTransportImpl( override val coroutineContext: CoroutineContext, private val socketOptions: SocketOptions.AcceptorOptions.() -> Unit, private val selectorManager: SelectorManager, + manageSelectorManager: Boolean, ) : KtorTcpServerTransport { - override fun target(localAddress: SocketAddress?): RSocketServerTarget = KtorTcpServerTargetImpl( + init { + if (manageSelectorManager) coroutineContext.job.invokeOnCompletion { selectorManager.close() } + } + + override fun target(localAddress: SocketAddress?): KtorTcpServerTarget = KtorTcpServerTargetImpl( coroutineContext = coroutineContext.supervisorContext(), socketOptions = socketOptions, selectorManager = selectorManager, localAddress = localAddress ) - override fun target(host: String, port: Int): RSocketServerTarget = target(InetSocketAddress(host, port)) + override fun target(host: String, port: Int): KtorTcpServerTarget = target(InetSocketAddress(host, port)) } @OptIn(RSocketTransportApi::class) @@ -103,52 +94,54 @@ private class KtorTcpServerTargetImpl( private val socketOptions: SocketOptions.AcceptorOptions.() -> Unit, private val selectorManager: SelectorManager, private val localAddress: SocketAddress?, -) : RSocketServerTarget { +) : KtorTcpServerTarget { @RSocketTransportApi - override suspend fun startServer(handler: RSocketConnectionHandler): KtorTcpServerInstance { + override suspend fun startServer(onConnection: (RSocketConnection) -> Unit): KtorTcpServerInstance { currentCoroutineContext().ensureActive() coroutineContext.ensureActive() - return startKtorTcpServer(this, bindSocket(), handler) - } - private suspend fun bindSocket(): ServerSocket = launchCoroutine { cont -> - val socket = aSocket(selectorManager).tcp().bind(localAddress, socketOptions) - cont.resume(socket) { _, value, _ -> value.close() } + return withContext(Dispatchers.IoCompatible) { + val serverSocket = aSocket(selectorManager).tcp().bind(localAddress, socketOptions) + KtorTcpServerInstanceImpl( + coroutineContext = this@KtorTcpServerTargetImpl.coroutineContext.childContext(), + serverSocket = serverSocket, + onConnection = onConnection + ) + } } } @RSocketTransportApi -private fun startKtorTcpServer( - scope: CoroutineScope, - serverSocket: ServerSocket, - handler: RSocketConnectionHandler, -): KtorTcpServerInstance { - val serverJob = scope.launch { - try { - // the failure of one connection should not stop all other connections - supervisorScope { +private class KtorTcpServerInstanceImpl( + override val coroutineContext: CoroutineContext, + private val serverSocket: ServerSocket, + private val onConnection: (RSocketConnection) -> Unit, +) : KtorTcpServerInstance { + override val localAddress: SocketAddress get() = serverSocket.localAddress + + init { + @OptIn(DelicateCoroutinesApi::class) + launch(start = CoroutineStart.ATOMIC) { + try { + currentCoroutineContext().ensureActive() // because of ATOMIC start + + val connectionsContext = currentCoroutineContext().supervisorContext() while (true) { val socket = serverSocket.accept() - launch { handler.handleKtorTcpConnection(socket) } + onConnection( + KtorTcpConnection( + parentContext = connectionsContext, + socket = socket + ) + ) + } + } finally { + nonCancellable { + serverSocket.close() + serverSocket.socketContext.join() } - } - } finally { - // even if it was cancelled, we still need to close socket and await it closure - withContext(NonCancellable) { - serverSocket.close() - serverSocket.socketContext.join() } } } - return KtorTcpServerInstanceImpl( - coroutineContext = scope.coroutineContext + serverJob, - localAddress = serverSocket.localAddress - ) } - -@RSocketTransportApi -private class KtorTcpServerInstanceImpl( - override val coroutineContext: CoroutineContext, - override val localAddress: SocketAddress, -) : KtorTcpServerInstance diff --git a/rsocket-transports/ktor-tcp/src/commonMain/kotlin/io/rsocket/kotlin/transport/ktor/tcp/TcpConnection.kt b/rsocket-transports/ktor-tcp/src/commonMain/kotlin/io/rsocket/kotlin/transport/ktor/tcp/TcpConnection.kt index 98347f0e..1dec6548 100644 --- a/rsocket-transports/ktor-tcp/src/commonMain/kotlin/io/rsocket/kotlin/transport/ktor/tcp/TcpConnection.kt +++ b/rsocket-transports/ktor-tcp/src/commonMain/kotlin/io/rsocket/kotlin/transport/ktor/tcp/TcpConnection.kt @@ -1,5 +1,5 @@ /* - * Copyright 2015-2024 the original author or authors. + * Copyright 2015-2025 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -51,6 +51,8 @@ internal class TcpConnection( } } } + }.invokeOnCompletion { + sendChannel.cancelWithCause(it) } launch { socketConnection.input.apply { @@ -65,10 +67,10 @@ internal class TcpConnection( } } } + }.invokeOnCompletion { + receiveChannel.cancelWithCause(it) } coroutineContext.job.invokeOnCompletion { - sendChannel.cancelWithCause(it) - receiveChannel.cancelWithCause(it) socketConnection.input.cancel(it) socketConnection.output.close(it) socketConnection.socket.close() diff --git a/rsocket-transports/ktor-tcp/src/commonTest/kotlin/io/rsocket/kotlin/transport/ktor/tcp/TcpTransportTest.kt b/rsocket-transports/ktor-tcp/src/commonTest/kotlin/io/rsocket/kotlin/transport/ktor/tcp/TcpTransportTest.kt index c51bf418..68cd47ab 100644 --- a/rsocket-transports/ktor-tcp/src/commonTest/kotlin/io/rsocket/kotlin/transport/ktor/tcp/TcpTransportTest.kt +++ b/rsocket-transports/ktor-tcp/src/commonTest/kotlin/io/rsocket/kotlin/transport/ktor/tcp/TcpTransportTest.kt @@ -21,8 +21,10 @@ import io.ktor.network.sockets.* import io.rsocket.kotlin.internal.io.* import io.rsocket.kotlin.transport.tests.* import kotlinx.coroutines.* +import kotlin.test.* @Suppress("DEPRECATION_ERROR") +@Ignore class TcpTransportTest : TransportTest() { override suspend fun before() { val serverSocket = startServer(TcpServerTransport("127.0.0.1")).serverSocket.await() diff --git a/rsocket-transports/ktor-websocket-client/src/commonMain/kotlin/io/rsocket/kotlin/transport/ktor/websocket/client/KtorWebSocketClientTransport.kt b/rsocket-transports/ktor-websocket-client/src/commonMain/kotlin/io/rsocket/kotlin/transport/ktor/websocket/client/KtorWebSocketClientTransport.kt index cc91d10d..418a1f32 100644 --- a/rsocket-transports/ktor-websocket-client/src/commonMain/kotlin/io/rsocket/kotlin/transport/ktor/websocket/client/KtorWebSocketClientTransport.kt +++ b/rsocket-transports/ktor-websocket-client/src/commonMain/kotlin/io/rsocket/kotlin/transport/ktor/websocket/client/KtorWebSocketClientTransport.kt @@ -1,5 +1,5 @@ /* - * Copyright 2015-2024 the original author or authors. + * Copyright 2015-2025 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -79,15 +79,15 @@ private class KtorWebSocketClientTransportBuilderImpl : KtorWebSocketClientTrans install(WebSockets, webSocketsConfig) } // only dispatcher of a client is used - it looks like it's Dispatchers.IO now - val newContext = context.supervisorContext() + (httpClient.coroutineContext[ContinuationInterceptor] ?: EmptyCoroutineContext) - val newJob = newContext.job + val transportContext = context.supervisorContext() + Dispatchers.Default + val transportJob = transportContext.job val httpClientJob = httpClient.coroutineContext.job - httpClientJob.invokeOnCompletion { newJob.cancel("HttpClient closed", it) } - newJob.invokeOnCompletion { httpClientJob.cancel("KtorWebSocketClientTransport closed", it) } + httpClientJob.invokeOnCompletion { transportJob.cancel("HttpClient closed", it) } + transportJob.invokeOnCompletion { httpClientJob.cancel("KtorWebSocketClientTransport closed", it) } return KtorWebSocketClientTransportImpl( - coroutineContext = newContext, + coroutineContext = transportContext, httpClient = httpClient, ) } @@ -98,7 +98,7 @@ private class KtorWebSocketClientTransportImpl( private val httpClient: HttpClient, ) : KtorWebSocketClientTransport { override fun target(request: HttpRequestBuilder.() -> Unit): RSocketClientTarget = KtorWebSocketClientTargetImpl( - coroutineContext = coroutineContext, + coroutineContext = coroutineContext.supervisorContext(), httpClient = httpClient, request = request ) @@ -136,12 +136,17 @@ private class KtorWebSocketClientTargetImpl( private val httpClient: HttpClient, private val request: HttpRequestBuilder.() -> Unit, ) : RSocketClientTarget { - @RSocketTransportApi - override fun connectClient(handler: RSocketConnectionHandler): Job = launch { - httpClient.webSocket(request) { - handler.handleKtorWebSocketConnection(this) + override suspend fun connectClient(): RSocketConnection { + currentCoroutineContext().ensureActive() + coroutineContext.ensureActive() + + val session = httpClient.webSocketSession(request) + val handle = coroutineContext.job.invokeOnCompletion { + session.cancel("Transport was cancelled", it) } + session.coroutineContext.job.invokeOnCompletion { handle.dispose() } + return KtorWebSocketConnection(session) } } diff --git a/rsocket-transports/ktor-websocket-internal/api/rsocket-transport-ktor-websocket-internal.api b/rsocket-transports/ktor-websocket-internal/api/rsocket-transport-ktor-websocket-internal.api index 19f78978..d0ddd4c4 100644 --- a/rsocket-transports/ktor-websocket-internal/api/rsocket-transport-ktor-websocket-internal.api +++ b/rsocket-transports/ktor-websocket-internal/api/rsocket-transport-ktor-websocket-internal.api @@ -1,5 +1,8 @@ -public final class io/rsocket/kotlin/transport/ktor/websocket/internal/KtorWebSocketConnectionKt { - public static final fun handleKtorWebSocketConnection (Lio/rsocket/kotlin/transport/RSocketConnectionHandler;Lio/ktor/websocket/WebSocketSession;Lkotlin/coroutines/Continuation;)Ljava/lang/Object; +public final class io/rsocket/kotlin/transport/ktor/websocket/internal/KtorWebSocketConnection : io/rsocket/kotlin/transport/RSocketSequentialConnection { + public fun (Lio/ktor/websocket/WebSocketSession;)V + public fun getCoroutineContext ()Lkotlin/coroutines/CoroutineContext; + public fun receiveFrame (Lkotlin/coroutines/Continuation;)Ljava/lang/Object; + public fun sendFrame (ILkotlinx/io/Buffer;Lkotlin/coroutines/Continuation;)Ljava/lang/Object; } public final class io/rsocket/kotlin/transport/ktor/websocket/internal/WebSocketConnection : io/rsocket/kotlin/Connection, kotlinx/coroutines/CoroutineScope { diff --git a/rsocket-transports/ktor-websocket-internal/src/commonMain/kotlin/io/rsocket/kotlin/transport/ktor/websocket/internal/KtorWebSocketConnection.kt b/rsocket-transports/ktor-websocket-internal/src/commonMain/kotlin/io/rsocket/kotlin/transport/ktor/websocket/internal/KtorWebSocketConnection.kt index 50447626..b5cfc27c 100644 --- a/rsocket-transports/ktor-websocket-internal/src/commonMain/kotlin/io/rsocket/kotlin/transport/ktor/websocket/internal/KtorWebSocketConnection.kt +++ b/rsocket-transports/ktor-websocket-internal/src/commonMain/kotlin/io/rsocket/kotlin/transport/ktor/websocket/internal/KtorWebSocketConnection.kt @@ -1,5 +1,5 @@ /* - * Copyright 2015-2024 the original author or authors. + * Copyright 2015-2025 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -21,43 +21,54 @@ import io.rsocket.kotlin.internal.io.* import io.rsocket.kotlin.transport.* import io.rsocket.kotlin.transport.internal.* import kotlinx.coroutines.* -import kotlinx.coroutines.channels.* import kotlinx.io.* +import kotlin.coroutines.* @RSocketTransportApi -public suspend fun RSocketConnectionHandler.handleKtorWebSocketConnection(webSocketSession: WebSocketSession): Unit = coroutineScope { - val outboundQueue = PrioritizationFrameQueue(Channel.BUFFERED) +public class KtorWebSocketConnection( + private val session: WebSocketSession, +) : RSocketSequentialConnection { + private val outboundQueue = PrioritizationFrameQueue() + override val coroutineContext: CoroutineContext get() = session.coroutineContext - val senderJob = launch { - while (true) webSocketSession.send(outboundQueue.dequeueFrame()?.readByteArray() ?: break) - }.onCompletion { outboundQueue.cancel() } + init { + @OptIn(DelicateCoroutinesApi::class) + launch(start = CoroutineStart.ATOMIC) { + val outboundJob = launch { + nonCancellable { + try { + while (true) { + session.send(outboundQueue.dequeueFrame()?.readByteArray() ?: break) + } + } catch (cause: Throwable) { + session.outgoing.close(cause) + throw cause + } finally { + outboundQueue.cancel() + } + } + } - try { - handleConnection(KtorWebSocketConnection(outboundQueue, webSocketSession.incoming)) - } finally { - webSocketSession.incoming.cancel() - outboundQueue.close() - withContext(NonCancellable) { - senderJob.join() // await all frames sent - webSocketSession.close() - webSocketSession.coroutineContext.job.join() + try { + awaitCancellation() + } finally { + nonCancellable { + session.incoming.cancel() + outboundQueue.close() + outboundJob.join() + // await socket completion + session.close() + } + } } } -} - -@RSocketTransportApi -private class KtorWebSocketConnection( - private val outboundQueue: PrioritizationFrameQueue, - private val inbound: ReceiveChannel, -) : RSocketSequentialConnection { - override val isClosedForSend: Boolean get() = outboundQueue.isClosedForSend override suspend fun sendFrame(streamId: Int, frame: Buffer) { return outboundQueue.enqueueFrame(streamId, frame) } override suspend fun receiveFrame(): Buffer? { - val frame = inbound.receiveCatching().getOrNull() ?: return null + val frame = session.incoming.receiveCatching().getOrNull() ?: return null return Buffer().apply { write(frame.data) } } } diff --git a/rsocket-transports/ktor-websocket-server/src/commonMain/kotlin/io/rsocket/kotlin/transport/ktor/websocket/server/KtorWebSocketServerTransport.kt b/rsocket-transports/ktor-websocket-server/src/commonMain/kotlin/io/rsocket/kotlin/transport/ktor/websocket/server/KtorWebSocketServerTransport.kt index b55e5b7a..f816d128 100644 --- a/rsocket-transports/ktor-websocket-server/src/commonMain/kotlin/io/rsocket/kotlin/transport/ktor/websocket/server/KtorWebSocketServerTransport.kt +++ b/rsocket-transports/ktor-websocket-server/src/commonMain/kotlin/io/rsocket/kotlin/transport/ktor/websocket/server/KtorWebSocketServerTransport.kt @@ -92,8 +92,7 @@ private class KtorWebSocketServerTransportBuilderImpl : KtorWebSocketServerTrans @RSocketTransportApi override fun buildTransport(context: CoroutineContext): KtorWebSocketServerTransport = KtorWebSocketServerTransportImpl( - // we always add IO - as it's the best choice here, server will use it's own dispatcher anyway - coroutineContext = context.supervisorContext() + Dispatchers.IoCompatible, + coroutineContext = context.supervisorContext() + Dispatchers.Default, factory = requireNotNull(httpServerFactory) { "httpEngine is required" }, webSocketsConfig = webSocketsConfig, ) @@ -151,12 +150,12 @@ private class KtorWebSocketServerTargetImpl( ) : RSocketServerTarget { @RSocketTransportApi - override suspend fun startServer(handler: RSocketConnectionHandler): KtorWebSocketServerInstance { + override suspend fun startServer(onConnection: (RSocketConnection) -> Unit): KtorWebSocketServerInstance { currentCoroutineContext().ensureActive() coroutineContext.ensureActive() val serverContext = coroutineContext.childContext() - val embeddedServer = createServer(handler, serverContext) + val embeddedServer = createServer(serverContext, onConnection) val resolvedConnectors = startServer(embeddedServer, serverContext) return KtorWebSocketServerInstanceImpl( @@ -170,8 +169,8 @@ private class KtorWebSocketServerTargetImpl( // parentCoroutineContext is the context of server instance @RSocketTransportApi private fun createServer( - handler: RSocketConnectionHandler, serverContext: CoroutineContext, + onConnection: (RSocketConnection) -> Unit, ): EmbeddedServer<*, *> { val config = serverConfig { val target = this@KtorWebSocketServerTargetImpl @@ -180,7 +179,8 @@ private class KtorWebSocketServerTargetImpl( install(WebSockets, webSocketsConfig) routing { webSocket(target.path, target.protocol) { - handler.handleKtorWebSocketConnection(this) + onConnection(KtorWebSocketConnection(this)) + awaitCancellation() } } } @@ -191,20 +191,24 @@ private class KtorWebSocketServerTargetImpl( private suspend fun startServer( embeddedServer: EmbeddedServer<*, *>, serverContext: CoroutineContext, - ): List = launchCoroutine(serverContext + Dispatchers.IoCompatible) { cont -> - embeddedServer.startSuspend() - launch(serverContext + Dispatchers.IoCompatible) { + ): List { + @OptIn(DelicateCoroutinesApi::class) + val serverJob = launch(serverContext, start = CoroutineStart.ATOMIC) { try { + currentCoroutineContext().ensureActive() // because of atomic start + embeddedServer.startSuspend() awaitCancellation() } finally { - withContext(NonCancellable) { + nonCancellable { embeddedServer.stopSuspend() } } } - cont.resume(embeddedServer.engine.resolvedConnectors()) { cause, _, _ -> - // will cause stopping of the server - serverContext.job.cancel("Cancelled", cause) + return try { + embeddedServer.engine.resolvedConnectors() + } catch (cause: Throwable) { + serverJob.cancel("Starting server cancelled", cause) + throw cause } } } diff --git a/rsocket-transports/ktor-websocket-tests/src/commonTest/kotlin/io/rsocket/kotlin/transport/ktor/websocket/tests/WebSocketTransportTest.kt b/rsocket-transports/ktor-websocket-tests/src/commonTest/kotlin/io/rsocket/kotlin/transport/ktor/websocket/tests/WebSocketTransportTest.kt index 670965ad..de402f7f 100644 --- a/rsocket-transports/ktor-websocket-tests/src/commonTest/kotlin/io/rsocket/kotlin/transport/ktor/websocket/tests/WebSocketTransportTest.kt +++ b/rsocket-transports/ktor-websocket-tests/src/commonTest/kotlin/io/rsocket/kotlin/transport/ktor/websocket/tests/WebSocketTransportTest.kt @@ -19,10 +19,12 @@ package io.rsocket.kotlin.transport.ktor.websocket.tests import io.rsocket.kotlin.transport.ktor.websocket.client.* import io.rsocket.kotlin.transport.ktor.websocket.server.* import io.rsocket.kotlin.transport.tests.* +import kotlin.test.* import io.ktor.client.engine.cio.CIO as ClientCIO import io.ktor.server.cio.CIO as ServerCIO @Suppress("DEPRECATION_ERROR") +@Ignore class WebSocketTransportTest : TransportTest() { override suspend fun before() { val embeddedServer = startServer( diff --git a/rsocket-transports/local/api/rsocket-transport-local.api b/rsocket-transports/local/api/rsocket-transport-local.api index 10c48ca0..3aaec585 100644 --- a/rsocket-transports/local/api/rsocket-transport-local.api +++ b/rsocket-transports/local/api/rsocket-transport-local.api @@ -8,7 +8,6 @@ public final class io/rsocket/kotlin/transport/local/LocalClientTransport$Factor public abstract interface class io/rsocket/kotlin/transport/local/LocalClientTransportBuilder : io/rsocket/kotlin/transport/RSocketTransportBuilder { public abstract fun dispatcher (Lkotlin/coroutines/CoroutineContext;)V - public fun inheritDispatcher ()V } public final class io/rsocket/kotlin/transport/local/LocalServer : io/rsocket/kotlin/transport/ClientTransport { @@ -35,10 +34,7 @@ public final class io/rsocket/kotlin/transport/local/LocalServerTransport$Factor public abstract interface class io/rsocket/kotlin/transport/local/LocalServerTransportBuilder : io/rsocket/kotlin/transport/RSocketTransportBuilder { public abstract fun dispatcher (Lkotlin/coroutines/CoroutineContext;)V - public fun inheritDispatcher ()V - public abstract fun multiplexed (II)V - public static synthetic fun multiplexed$default (Lio/rsocket/kotlin/transport/local/LocalServerTransportBuilder;IIILjava/lang/Object;)V - public abstract fun sequential (I)V - public static synthetic fun sequential$default (Lio/rsocket/kotlin/transport/local/LocalServerTransportBuilder;IILjava/lang/Object;)V + public abstract fun multiplexed ()V + public abstract fun sequential ()V } diff --git a/rsocket-transports/local/src/commonMain/kotlin/io/rsocket/kotlin/transport/local/LocalClientTransport.kt b/rsocket-transports/local/src/commonMain/kotlin/io/rsocket/kotlin/transport/local/LocalClientTransport.kt index e782044c..be1fff68 100644 --- a/rsocket-transports/local/src/commonMain/kotlin/io/rsocket/kotlin/transport/local/LocalClientTransport.kt +++ b/rsocket-transports/local/src/commonMain/kotlin/io/rsocket/kotlin/transport/local/LocalClientTransport.kt @@ -1,5 +1,5 @@ /* - * Copyright 2015-2024 the original author or authors. + * Copyright 2015-2025 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -32,11 +32,10 @@ public sealed interface LocalClientTransport : RSocketTransport { @OptIn(RSocketTransportApi::class) public sealed interface LocalClientTransportBuilder : RSocketTransportBuilder { public fun dispatcher(context: CoroutineContext) - public fun inheritDispatcher(): Unit = dispatcher(EmptyCoroutineContext) } private class LocalClientTransportBuilderImpl : LocalClientTransportBuilder { - private var dispatcher: CoroutineContext = Dispatchers.Default + private var dispatcher: CoroutineContext = Dispatchers.Unconfined override fun dispatcher(context: CoroutineContext) { check(context[Job] == null) { "Dispatcher shouldn't contain job" } @@ -63,10 +62,11 @@ private class LocalClientTargetImpl( override val coroutineContext: CoroutineContext, private val serverName: String, ) : RSocketClientTarget { - @RSocketTransportApi - override fun connectClient(handler: RSocketConnectionHandler): Job { + override suspend fun connectClient(): RSocketConnection { + currentCoroutineContext().ensureActive() coroutineContext.ensureActive() - return LocalServerRegistry.get(serverName).connect(clientScope = this, clientHandler = handler) + + return LocalServerRegistry.connectClient(serverName, coroutineContext) } } diff --git a/rsocket-transports/local/src/commonMain/kotlin/io/rsocket/kotlin/transport/local/LocalServerConnector.kt b/rsocket-transports/local/src/commonMain/kotlin/io/rsocket/kotlin/transport/local/LocalServerConnector.kt index ce5d3086..08d0b138 100644 --- a/rsocket-transports/local/src/commonMain/kotlin/io/rsocket/kotlin/transport/local/LocalServerConnector.kt +++ b/rsocket-transports/local/src/commonMain/kotlin/io/rsocket/kotlin/transport/local/LocalServerConnector.kt @@ -1,5 +1,5 @@ /* - * Copyright 2015-2024 the original author or authors. + * Copyright 2015-2025 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -22,174 +22,159 @@ import io.rsocket.kotlin.transport.internal.* import kotlinx.coroutines.* import kotlinx.coroutines.channels.* import kotlinx.io.* +import kotlin.coroutines.* internal sealed class LocalServerConnector { @RSocketTransportApi - abstract fun connect( - clientScope: CoroutineScope, - clientHandler: RSocketConnectionHandler, - serverScope: CoroutineScope, - serverHandler: RSocketConnectionHandler, - ): Job - - internal class Sequential( - private val prioritizationQueueBuffersCapacity: Int, - ) : LocalServerConnector() { - - @RSocketTransportApi - override fun connect( - clientScope: CoroutineScope, - clientHandler: RSocketConnectionHandler, - serverScope: CoroutineScope, - serverHandler: RSocketConnectionHandler, - ): Job { - val clientToServer = PrioritizationFrameQueue(prioritizationQueueBuffersCapacity) - val serverToClient = PrioritizationFrameQueue(prioritizationQueueBuffersCapacity) - - launchLocalConnection(serverScope, serverToClient, clientToServer, serverHandler) - return launchLocalConnection(clientScope, clientToServer, serverToClient, clientHandler) - } + abstract suspend fun connect( + clientContext: CoroutineContext, + serverContext: CoroutineContext, + onConnection: (RSocketConnection) -> Unit, + ): RSocketConnection + object Sequential : LocalServerConnector() { @RSocketTransportApi - private fun launchLocalConnection( - scope: CoroutineScope, - outbound: PrioritizationFrameQueue, - inbound: PrioritizationFrameQueue, - handler: RSocketConnectionHandler, - ): Job = scope.launch { - handler.handleConnection(Connection(outbound, inbound)) - }.onCompletion { - outbound.close() - inbound.cancel() + override suspend fun connect( + clientContext: CoroutineContext, + serverContext: CoroutineContext, + onConnection: (RSocketConnection) -> Unit, + ): RSocketConnection { + val frames = Frames() + onConnection(Connection(serverContext.childContext(), frames.clientToServer, frames.serverToClient)) + return Connection(clientContext.childContext(), frames.serverToClient, frames.clientToServer) } @RSocketTransportApi private class Connection( - private val outbound: PrioritizationFrameQueue, - private val inbound: PrioritizationFrameQueue, + override val coroutineContext: CoroutineContext, + private val incomingFrames: ReceiveChannel, + private val outgoingFrames: SendChannel, ) : RSocketSequentialConnection { - override val isClosedForSend: Boolean get() = outbound.isClosedForSend + private val outboundQueue = PrioritizationFrameQueue() + + init { + @OptIn(DelicateCoroutinesApi::class) + launch(start = CoroutineStart.ATOMIC) { + launch { + nonCancellable { + while (true) outgoingFrames.send(outboundQueue.dequeueFrame() ?: break) + } + }.invokeOnCompletion { + outboundQueue.cancel() + outgoingFrames.close() + } + try { + awaitCancellation() + } finally { + outboundQueue.close() + incomingFrames.cancel() + } + } + } override suspend fun sendFrame(streamId: Int, frame: Buffer) { - return outbound.enqueueFrame(streamId, frame) + return outboundQueue.enqueueFrame(streamId, frame) } override suspend fun receiveFrame(): Buffer? { - return inbound.dequeueFrame() + return incomingFrames.receiveCatching().getOrNull() } } } - // TODO: better parameters naming - class Multiplexed( - private val streamsQueueCapacity: Int, - private val streamBufferCapacity: Int, - ) : LocalServerConnector() { + object Multiplexed : LocalServerConnector() { @RSocketTransportApi - override fun connect( - clientScope: CoroutineScope, - clientHandler: RSocketConnectionHandler, - serverScope: CoroutineScope, - serverHandler: RSocketConnectionHandler, - ): Job { - val streams = Streams(streamsQueueCapacity) - - launchLocalConnection(serverScope, streams.serverToClient, streams.clientToServer, serverHandler) - return launchLocalConnection(clientScope, streams.clientToServer, streams.serverToClient, clientHandler) - } - - @RSocketTransportApi - private fun launchLocalConnection( - scope: CoroutineScope, - outbound: SendChannel, - inbound: ReceiveChannel, - handler: RSocketConnectionHandler, - ): Job = scope.launch { - handler.handleConnection(Connection(SupervisorJob(coroutineContext.job), outbound, inbound, streamBufferCapacity)) - }.onCompletion { - outbound.close() - inbound.cancel() + override suspend fun connect( + clientContext: CoroutineContext, + serverContext: CoroutineContext, + onConnection: (RSocketConnection) -> Unit, + ): RSocketConnection { + val streams = Streams() + onConnection(Connection(serverContext.childContext(), streams.clientToServer, streams.serverToClient)) + return Connection(clientContext.childContext(), streams.serverToClient, streams.clientToServer) } @RSocketTransportApi private class Connection( - private val streamsJob: Job, - private val outbound: SendChannel, - private val inbound: ReceiveChannel, - private val streamBufferCapacity: Int, + override val coroutineContext: CoroutineContext, + private val incomingStreams: ReceiveChannel, + private val outgoingStreams: SendChannel, ) : RSocketMultiplexedConnection { - override suspend fun createStream(): RSocketMultiplexedConnection.Stream { - val frames = Frames(streamBufferCapacity) + private val streamsContext = coroutineContext.supervisorContext() - outbound.send(frames) + init { + coroutineContext.job.invokeOnCompletion { + outgoingStreams.close() + incomingStreams.cancel() + } + } + override suspend fun createStream(): RSocketMultiplexedConnection.Stream { + val frames = Frames() + outgoingStreams.send(frames) return Stream( - parentJob = streamsJob, - outbound = frames.clientToServer, - inbound = frames.serverToClient + coroutineContext = streamsContext.childContext(), + incoming = frames.clientToServer, + outgoing = frames.serverToClient ) } override suspend fun acceptStream(): RSocketMultiplexedConnection.Stream? { - val frames = inbound.receiveCatching().getOrNull() ?: return null - + val frames = incomingStreams.receiveCatching().getOrNull() ?: return null return Stream( - parentJob = streamsJob, - outbound = frames.serverToClient, - inbound = frames.clientToServer + coroutineContext = streamsContext.childContext(), + incoming = frames.serverToClient, + outgoing = frames.clientToServer ) } } @RSocketTransportApi private class Stream( - parentJob: Job, - private val outbound: SendChannel, - private val inbound: ReceiveChannel, + override val coroutineContext: CoroutineContext, + private val incoming: ReceiveChannel, + private val outgoing: SendChannel, ) : RSocketMultiplexedConnection.Stream { - private val streamJob = Job(parentJob).onCompletion { - outbound.close() - inbound.cancel() + init { + coroutineContext.job.invokeOnCompletion { + outgoing.close() + incoming.cancel() + } } - override fun close() { - streamJob.complete() + override fun setSendPriority(priority: Int) { + // no-op } - @OptIn(DelicateCoroutinesApi::class) - override val isClosedForSend: Boolean get() = outbound.isClosedForSend - - override fun setSendPriority(priority: Int) {} - override suspend fun sendFrame(frame: Buffer) { - return outbound.send(frame) + return outgoing.send(frame) } override suspend fun receiveFrame(): Buffer? { - return inbound.receiveCatching().getOrNull() + return incoming.receiveCatching().getOrNull() } } + } +} - private class Streams(bufferCapacity: Int) : AutoCloseable { - val clientToServer = channelForCloseable(bufferCapacity) - val serverToClient = channelForCloseable(bufferCapacity) +private class Streams : AutoCloseable { + val clientToServer = channelForCloseable(Channel.BUFFERED) + val serverToClient = channelForCloseable(Channel.BUFFERED) - // only for undelivered element case - override fun close() { - clientToServer.cancel() - serverToClient.cancel() - } - } + // only for undelivered element case + override fun close() { + clientToServer.cancel() + serverToClient.cancel() + } +} - private class Frames(bufferCapacity: Int) : AutoCloseable { - val clientToServer = bufferChannel(bufferCapacity) - val serverToClient = bufferChannel(bufferCapacity) +private class Frames : AutoCloseable { + val clientToServer = bufferChannel(Channel.BUFFERED) + val serverToClient = bufferChannel(Channel.BUFFERED) - // only for undelivered element case - override fun close() { - clientToServer.cancel() - serverToClient.cancel() - } - } + // only for undelivered element case + override fun close() { + clientToServer.cancel() + serverToClient.cancel() } } diff --git a/rsocket-transports/local/src/commonMain/kotlin/io/rsocket/kotlin/transport/local/LocalServerInstanceImpl.kt b/rsocket-transports/local/src/commonMain/kotlin/io/rsocket/kotlin/transport/local/LocalServerInstanceImpl.kt deleted file mode 100644 index 2892d83f..00000000 --- a/rsocket-transports/local/src/commonMain/kotlin/io/rsocket/kotlin/transport/local/LocalServerInstanceImpl.kt +++ /dev/null @@ -1,50 +0,0 @@ -/* - * Copyright 2015-2024 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package io.rsocket.kotlin.transport.local - -import io.rsocket.kotlin.internal.io.* -import io.rsocket.kotlin.transport.* -import kotlinx.coroutines.* -import kotlin.coroutines.* - -internal class LocalServerInstanceImpl @RSocketTransportApi constructor( - override val serverName: String, - override val coroutineContext: CoroutineContext, - private val serverHandler: RSocketConnectionHandler, - private val connector: LocalServerConnector, -) : LocalServerInstance { - private val serverScope = CoroutineScope(coroutineContext.supervisorContext()) - - init { - LocalServerRegistry.register(serverName, this) - } - - @RSocketTransportApi - fun connect( - clientScope: CoroutineScope, - clientHandler: RSocketConnectionHandler, - ): Job { - coroutineContext.ensureActive() - - return connector.connect( - clientScope = clientScope, - clientHandler = clientHandler, - serverScope = serverScope, - serverHandler = serverHandler - ) - } -} diff --git a/rsocket-transports/local/src/commonMain/kotlin/io/rsocket/kotlin/transport/local/LocalServerRegistry.kt b/rsocket-transports/local/src/commonMain/kotlin/io/rsocket/kotlin/transport/local/LocalServerRegistry.kt index b1246df9..c1aea80e 100644 --- a/rsocket-transports/local/src/commonMain/kotlin/io/rsocket/kotlin/transport/local/LocalServerRegistry.kt +++ b/rsocket-transports/local/src/commonMain/kotlin/io/rsocket/kotlin/transport/local/LocalServerRegistry.kt @@ -1,5 +1,5 @@ /* - * Copyright 2015-2024 the original author or authors. + * Copyright 2015-2025 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,24 +16,69 @@ package io.rsocket.kotlin.transport.local +import io.rsocket.kotlin.internal.io.* +import io.rsocket.kotlin.transport.* import kotlinx.atomicfu.locks.* import kotlinx.coroutines.* +import kotlin.coroutines.* -internal object LocalServerRegistry { - private val lock = SynchronizedObject() +@RSocketTransportApi +internal object LocalServerRegistry : SynchronizedObject() { private val instances = mutableMapOf() - fun register(name: String, target: LocalServerInstanceImpl) { - synchronized(lock) { + private fun register(name: String, instance: LocalServerInstanceImpl) { + synchronized(this) { check(name !in instances) { "Already registered: $name" } - instances[name] = target + instances[name] = instance } - target.coroutineContext.job.invokeOnCompletion { - synchronized(lock) { instances.remove(name) } + instance.coroutineContext.job.invokeOnCompletion { + synchronized(this) { + instances.remove(name) + } } } - fun get(name: String): LocalServerInstanceImpl = synchronized(lock) { + private fun get(name: String): LocalServerInstanceImpl = synchronized(this) { checkNotNull(instances[name]) { "Cannot find $name" } } + + suspend fun connectClient( + serverName: String, + parentContext: CoroutineContext, + ): RSocketConnection = get(serverName).connectClient(parentContext) + + fun startServer( + serverName: String, + parentContext: CoroutineContext, + connector: LocalServerConnector, + onConnection: (RSocketConnection) -> Unit, + ): LocalServerInstance = LocalServerInstanceImpl( + coroutineContext = parentContext.childContext(), + serverName = serverName, + connector = connector, + onConnection = onConnection + ).also { + register(serverName, it) + } +} + +@RSocketTransportApi +private class LocalServerInstanceImpl( + override val coroutineContext: CoroutineContext, + override val serverName: String, + private val connector: LocalServerConnector, + private val onConnection: (RSocketConnection) -> Unit, +) : LocalServerInstance { + private val serverContext = coroutineContext.supervisorContext() + + @RSocketTransportApi + suspend fun connectClient(clientContext: CoroutineContext): RSocketConnection { + coroutineContext.ensureActive() + + return connector.connect( + clientContext = clientContext, + serverContext = serverContext, + onConnection = onConnection + ) + } } diff --git a/rsocket-transports/local/src/commonMain/kotlin/io/rsocket/kotlin/transport/local/LocalServerTransport.kt b/rsocket-transports/local/src/commonMain/kotlin/io/rsocket/kotlin/transport/local/LocalServerTransport.kt index addb8f43..ec8563e6 100644 --- a/rsocket-transports/local/src/commonMain/kotlin/io/rsocket/kotlin/transport/local/LocalServerTransport.kt +++ b/rsocket-transports/local/src/commonMain/kotlin/io/rsocket/kotlin/transport/local/LocalServerTransport.kt @@ -1,5 +1,5 @@ /* - * Copyright 2015-2024 the original author or authors. + * Copyright 2015-2025 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -19,20 +19,20 @@ package io.rsocket.kotlin.transport.local import io.rsocket.kotlin.internal.io.* import io.rsocket.kotlin.transport.* import kotlinx.coroutines.* -import kotlinx.coroutines.channels.* import kotlin.coroutines.* -import kotlin.random.* +import kotlin.uuid.* -// TODO: rename to inprocess and more to another module/package later @OptIn(RSocketTransportApi::class) public sealed interface LocalServerInstance : RSocketServerInstance { public val serverName: String } +public typealias LocalServerTarget = RSocketServerTarget + @OptIn(RSocketTransportApi::class) public sealed interface LocalServerTransport : RSocketTransport { - public fun target(): RSocketServerTarget - public fun target(serverName: String): RSocketServerTarget + public fun target(): LocalServerTarget + public fun target(serverName: String): LocalServerTarget public companion object Factory : RSocketTransportFactory(::LocalServerTransportBuilderImpl) @@ -41,20 +41,13 @@ public sealed interface LocalServerTransport : RSocketTransport { @OptIn(RSocketTransportApi::class) public sealed interface LocalServerTransportBuilder : RSocketTransportBuilder { public fun dispatcher(context: CoroutineContext) - public fun inheritDispatcher(): Unit = dispatcher(EmptyCoroutineContext) - - public fun sequential( - prioritizationQueueBuffersCapacity: Int = Channel.BUFFERED, - ) - public fun multiplexed( - streamsQueueCapacity: Int = Channel.BUFFERED, - streamBufferCapacity: Int = Channel.BUFFERED, - ) + public fun sequential() + public fun multiplexed() } private class LocalServerTransportBuilderImpl : LocalServerTransportBuilder { - private var dispatcher: CoroutineContext = Dispatchers.Default + private var dispatcher: CoroutineContext = Dispatchers.Unconfined private var connector: LocalServerConnector? = null override fun dispatcher(context: CoroutineContext) { @@ -62,18 +55,18 @@ private class LocalServerTransportBuilderImpl : LocalServerTransportBuilder { this.dispatcher = context } - override fun sequential(prioritizationQueueBuffersCapacity: Int) { - connector = LocalServerConnector.Sequential(prioritizationQueueBuffersCapacity) + override fun sequential() { + connector = LocalServerConnector.Sequential } - override fun multiplexed(streamsQueueCapacity: Int, streamBufferCapacity: Int) { - connector = LocalServerConnector.Multiplexed(streamsQueueCapacity, streamBufferCapacity) + override fun multiplexed() { + connector = LocalServerConnector.Multiplexed } @RSocketTransportApi override fun buildTransport(context: CoroutineContext): LocalServerTransport = LocalServerTransportImpl( coroutineContext = context.supervisorContext() + dispatcher, - connector = connector ?: LocalServerConnector.Sequential(Channel.BUFFERED) + connector = connector ?: LocalServerConnector.Sequential ) } @@ -81,16 +74,14 @@ private class LocalServerTransportImpl( override val coroutineContext: CoroutineContext, private val connector: LocalServerConnector, ) : LocalServerTransport { - override fun target(serverName: String): RSocketServerTarget = LocalServerTargetImpl( + override fun target(serverName: String): LocalServerTarget = LocalServerTargetImpl( serverName = serverName, coroutineContext = coroutineContext.supervisorContext(), connector = connector ) - @OptIn(ExperimentalStdlibApi::class) - override fun target(): RSocketServerTarget = target( - Random.nextBytes(16).toHexString(HexFormat.UpperCase) - ) + @OptIn(ExperimentalUuidApi::class) + override fun target(): LocalServerTarget = target(Uuid.random().toString()) } @OptIn(RSocketTransportApi::class) @@ -98,17 +89,12 @@ private class LocalServerTargetImpl( override val coroutineContext: CoroutineContext, private val serverName: String, private val connector: LocalServerConnector, -) : RSocketServerTarget { +) : LocalServerTarget { @RSocketTransportApi - override suspend fun startServer(handler: RSocketConnectionHandler): LocalServerInstance { + override suspend fun startServer(onConnection: (RSocketConnection) -> Unit): LocalServerInstance { currentCoroutineContext().ensureActive() coroutineContext.ensureActive() - return LocalServerInstanceImpl( - serverName = serverName, - coroutineContext = coroutineContext.childContext(), - serverHandler = handler, - connector = connector - ) + return LocalServerRegistry.startServer(serverName, coroutineContext, connector, onConnection) } } diff --git a/rsocket-transports/local/src/commonTest/kotlin/io/rsocket/kotlin/transport/local/LocalTransportTest.kt b/rsocket-transports/local/src/commonTest/kotlin/io/rsocket/kotlin/transport/local/LocalTransportTest.kt index b5122f6b..5ce3d68d 100644 --- a/rsocket-transports/local/src/commonTest/kotlin/io/rsocket/kotlin/transport/local/LocalTransportTest.kt +++ b/rsocket-transports/local/src/commonTest/kotlin/io/rsocket/kotlin/transport/local/LocalTransportTest.kt @@ -17,7 +17,6 @@ package io.rsocket.kotlin.transport.local import io.rsocket.kotlin.transport.tests.* -import kotlinx.coroutines.channels.* @Suppress("DEPRECATION_ERROR") class OldLocalTransportTest : TransportTest() { @@ -36,13 +35,10 @@ abstract class LocalTransportTest( } } -class SequentialBufferedLocalTransportTest : LocalTransportTest({ - sequential(prioritizationQueueBuffersCapacity = Channel.BUFFERED) +class SequentialLocalTransportTest : LocalTransportTest({ + sequential() }) -class MultiplexedBufferedLocalTransportTest : LocalTransportTest({ - multiplexed( - streamsQueueCapacity = Channel.BUFFERED, - streamBufferCapacity = Channel.BUFFERED - ) +class MultiplexedLocalTransportTest : LocalTransportTest({ + multiplexed() }) diff --git a/rsocket-transports/netty-internal/api/rsocket-transport-netty-internal.api b/rsocket-transports/netty-internal/api/rsocket-transport-netty-internal.api index 0fe8a5fc..e88a62d9 100644 --- a/rsocket-transports/netty-internal/api/rsocket-transport-netty-internal.api +++ b/rsocket-transports/netty-internal/api/rsocket-transport-netty-internal.api @@ -1,7 +1,7 @@ public final class io/rsocket/kotlin/transport/netty/internal/CoroutinesKt { - public static final fun awaitChannel (Lio/netty/channel/ChannelFuture;Lkotlin/coroutines/Continuation;)Ljava/lang/Object; public static final fun awaitFuture (Lio/netty/util/concurrent/Future;Lkotlin/coroutines/Continuation;)Ljava/lang/Object; - public static final fun callOnCancellation (Lkotlinx/coroutines/CoroutineScope;Lkotlin/jvm/functions/Function1;)V + public static final fun join (Lio/netty/util/concurrent/Future;Lkotlin/coroutines/Continuation;)Ljava/lang/Object; + public static final fun shutdownOnCancellation (Lkotlinx/coroutines/CoroutineScope;[Lio/netty/channel/EventLoopGroup;)V } public final class io/rsocket/kotlin/transport/netty/internal/IoKt { diff --git a/rsocket-transports/netty-internal/src/jvmMain/kotlin/io/rsocket/kotlin/transport/netty/internal/coroutines.kt b/rsocket-transports/netty-internal/src/jvmMain/kotlin/io/rsocket/kotlin/transport/netty/internal/coroutines.kt index f2a84fc6..0f329ca5 100644 --- a/rsocket-transports/netty-internal/src/jvmMain/kotlin/io/rsocket/kotlin/transport/netty/internal/coroutines.kt +++ b/rsocket-transports/netty-internal/src/jvmMain/kotlin/io/rsocket/kotlin/transport/netty/internal/coroutines.kt @@ -1,5 +1,5 @@ /* - * Copyright 2015-2024 the original author or authors. + * Copyright 2015-2025 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -18,6 +18,7 @@ package io.rsocket.kotlin.transport.netty.internal import io.netty.channel.* import io.netty.util.concurrent.* +import io.rsocket.kotlin.internal.io.* import kotlinx.coroutines.* import kotlin.coroutines.* @@ -34,25 +35,24 @@ public suspend inline fun Future.awaitFuture(): T = suspendCancellableCor } } -public suspend fun ChannelFuture.awaitChannel(): Channel { +public suspend inline fun Future<*>.join(): Unit = suspendCancellableCoroutine { cont -> + addListener { cont.resume(Unit) } + cont.invokeOnCancellation { cancel(true) } +} + +public suspend inline fun ChannelFuture.awaitChannel(): T { awaitFuture() - return channel() + return channel() as T } -// it should be used only for cleanup and so should not really block, only suspend -public inline fun CoroutineScope.callOnCancellation(crossinline block: suspend () -> Unit) { +public fun CoroutineScope.shutdownOnCancellation(vararg groups: EventLoopGroup) { launch(Dispatchers.Unconfined) { try { awaitCancellation() - } catch (cause: Throwable) { - withContext(NonCancellable) { - try { - block() - } catch (suppressed: Throwable) { - cause.addSuppressed(suppressed) - } + } finally { + nonCancellable { + groups.forEach { it.shutdownGracefully().join() } } - throw cause } } } diff --git a/rsocket-transports/netty-internal/src/jvmMain/kotlin/io/rsocket/kotlin/transport/netty/internal/io.kt b/rsocket-transports/netty-internal/src/jvmMain/kotlin/io/rsocket/kotlin/transport/netty/internal/io.kt index 8758134d..4050e6c2 100644 --- a/rsocket-transports/netty-internal/src/jvmMain/kotlin/io/rsocket/kotlin/transport/netty/internal/io.kt +++ b/rsocket-transports/netty-internal/src/jvmMain/kotlin/io/rsocket/kotlin/transport/netty/internal/io.kt @@ -1,5 +1,5 @@ /* - * Copyright 2015-2024 the original author or authors. + * Copyright 2015-2025 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -31,12 +31,13 @@ public fun ByteBuf.toBuffer(): Buffer { toRead } } + release() return buffer } @OptIn(UnsafeIoApi::class) public fun Buffer.toByteBuf(allocator: ByteBufAllocator): ByteBuf { - val nettyBuffer = allocator.buffer(size.toInt()) // TODO: length + val nettyBuffer = allocator.directBuffer(size.toInt()) // TODO: length while (!exhausted()) { UnsafeBufferOperations.readFromHead(this) { bytes, start, end -> nettyBuffer.writeBytes(bytes, start, end - start) diff --git a/rsocket-transports/netty-quic/src/jvmMain/kotlin/io/rsocket/kotlin/transport/netty/quic/NettyQuicClientTransport.kt b/rsocket-transports/netty-quic/src/jvmMain/kotlin/io/rsocket/kotlin/transport/netty/quic/NettyQuicClientTransport.kt index f6c2cef3..949653cd 100644 --- a/rsocket-transports/netty-quic/src/jvmMain/kotlin/io/rsocket/kotlin/transport/netty/quic/NettyQuicClientTransport.kt +++ b/rsocket-transports/netty-quic/src/jvmMain/kotlin/io/rsocket/kotlin/transport/netty/quic/NettyQuicClientTransport.kt @@ -1,5 +1,5 @@ /* - * Copyright 2015-2024 the original author or authors. + * Copyright 2015-2025 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -59,7 +59,7 @@ private class NettyQuicClientTransportBuilderImpl : NettyQuicClientTransportBuil private var bootstrap: (Bootstrap.() -> Unit)? = null private var codec: (QuicClientCodecBuilder.() -> Unit)? = null private var ssl: (QuicSslContextBuilder.() -> Unit)? = null - private var quicBootstrap: (QuicChannelBootstrap.() -> Unit)? = null + private var quicBootstrap: QuicChannelBootstrap.() -> Unit = { } override fun channel(cls: KClass) { this.channelFactory = ReflectiveChannelFactory(cls.java) @@ -114,24 +114,18 @@ private class NettyQuicClientTransportBuilderImpl : NettyQuicClientTransportBuil return NettyQuicClientTransportImpl( coroutineContext = context.supervisorContext() + bootstrap.config().group().asCoroutineDispatcher(), bootstrap = bootstrap, - quicBootstrap = quicBootstrap, - manageBootstrap = manageEventLoopGroup - ) + quicBootstrap = quicBootstrap + ).also { + if (manageEventLoopGroup) it.shutdownOnCancellation(bootstrap.config().group()) + } } } private class NettyQuicClientTransportImpl( override val coroutineContext: CoroutineContext, private val bootstrap: Bootstrap, - private val quicBootstrap: (QuicChannelBootstrap.() -> Unit)?, - manageBootstrap: Boolean, + private val quicBootstrap: QuicChannelBootstrap.() -> Unit, ) : NettyQuicClientTransport { - init { - if (manageBootstrap) callOnCancellation { - bootstrap.config().group().shutdownGracefully().awaitFuture() - } - } - override fun target(remoteAddress: InetSocketAddress): NettyQuicClientTargetImpl = NettyQuicClientTargetImpl( coroutineContext = coroutineContext.supervisorContext(), bootstrap = bootstrap, @@ -146,14 +140,29 @@ private class NettyQuicClientTransportImpl( private class NettyQuicClientTargetImpl( override val coroutineContext: CoroutineContext, private val bootstrap: Bootstrap, - private val quicBootstrap: (QuicChannelBootstrap.() -> Unit)?, + private val quicBootstrap: QuicChannelBootstrap.() -> Unit, private val remoteAddress: SocketAddress, ) : RSocketClientTarget { @RSocketTransportApi - override fun connectClient(handler: RSocketConnectionHandler): Job = launch { - QuicChannel.newBootstrap(bootstrap.bind().awaitChannel()).also { quicBootstrap?.invoke(it) } + override suspend fun connectClient(): RSocketConnection { + currentCoroutineContext().ensureActive() + coroutineContext.ensureActive() + + val channel = QuicChannel.newBootstrap( + bootstrap.bind().awaitChannel() + ) + .apply(quicBootstrap) .handler( - NettyQuicConnectionInitializer(handler, coroutineContext, isClient = true) - ).remoteAddress(remoteAddress).connect().awaitFuture() + NettyQuicConnectionInitializer( + parentContext = coroutineContext, + onConnection = null + ) + ) + .streamHandler(NettyQuicStreamInitializer) + .remoteAddress(remoteAddress) + .connect() + .awaitFuture() + + return channel.attr(NettyQuicConnection.ATTRIBUTE).get() } } diff --git a/rsocket-transports/netty-quic/src/jvmMain/kotlin/io/rsocket/kotlin/transport/netty/quic/NettyQuicConnection.kt b/rsocket-transports/netty-quic/src/jvmMain/kotlin/io/rsocket/kotlin/transport/netty/quic/NettyQuicConnection.kt new file mode 100644 index 00000000..086c6322 --- /dev/null +++ b/rsocket-transports/netty-quic/src/jvmMain/kotlin/io/rsocket/kotlin/transport/netty/quic/NettyQuicConnection.kt @@ -0,0 +1,113 @@ +/* + * Copyright 2015-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.kotlin.transport.netty.quic + +import io.netty.channel.* +import io.netty.incubator.codec.quic.* +import io.netty.util.* +import io.rsocket.kotlin.internal.io.* +import io.rsocket.kotlin.transport.* +import io.rsocket.kotlin.transport.netty.internal.* +import kotlinx.coroutines.* +import kotlinx.coroutines.channels.Channel +import kotlin.coroutines.* + +@RSocketTransportApi +internal class NettyQuicConnection( + parentContext: CoroutineContext, + private val channel: QuicChannel, + private val isClient: Boolean, +) : RSocketMultiplexedConnection, ChannelInboundHandlerAdapter() { + override val coroutineContext: CoroutineContext = parentContext.childContext() + channel.eventLoop().asCoroutineDispatcher() + private val streamsContext = coroutineContext.supervisorContext() + + private val inboundStreams = Channel(Channel.UNLIMITED) { + it.cancel("Connection closed") + } + + init { + @OptIn(DelicateCoroutinesApi::class) + launch(start = CoroutineStart.ATOMIC) { + try { + awaitCancellation() + } finally { + nonCancellable { + inboundStreams.cancel() + // stop streams first + streamsContext.job.cancelAndJoin() + channel.close().awaitFuture() + // close UDP channel + if (isClient) channel.parent().close().awaitFuture() + } + } + } + } + + fun initStreamChannel(streamChannel: QuicStreamChannel) { + val stream = NettyQuicStream(streamsContext, streamChannel) + streamChannel.attr(NettyQuicStream.ATTRIBUTE).set(stream) + streamChannel.pipeline().addLast("rsocket-quic-stream", stream) + + if (streamChannel.isLocalCreated) return + + if (inboundStreams.trySend(stream).isFailure) stream.cancel("Connection closed") + } + + override fun channelInactive(ctx: ChannelHandlerContext) { + cancel("Channel is not active") + ctx.fireChannelInactive() + } + + override fun exceptionCaught(ctx: ChannelHandlerContext, cause: Throwable?) { + cancel("exceptionCaught", cause) + } + + override suspend fun createStream(): RSocketMultiplexedConnection.Stream { + val streamChannel = channel.createStream(QuicStreamType.BIDIRECTIONAL, NettyQuicStreamInitializer).awaitFuture() + return streamChannel.attr(NettyQuicStream.ATTRIBUTE).get() + } + + override suspend fun acceptStream(): RSocketMultiplexedConnection.Stream? { + return inboundStreams.receiveCatching().getOrNull() + } + + companion object { + val ATTRIBUTE: AttributeKey = AttributeKey.newInstance("rsocket-quic-connection") + } +} + +@RSocketTransportApi +internal class NettyQuicConnectionInitializer( + private val parentContext: CoroutineContext, + private val onConnection: ((RSocketConnection) -> Unit)?, +) : ChannelInitializer() { + override fun initChannel(channel: QuicChannel) { + val connection = NettyQuicConnection( + parentContext = parentContext, + channel = channel, + isClient = onConnection == null + ) + channel.attr(NettyQuicConnection.ATTRIBUTE).set(connection) + + channel.pipeline().addLast( + "rsocket-connection", + connection + ) + + onConnection?.invoke(connection) + } +} diff --git a/rsocket-transports/netty-quic/src/jvmMain/kotlin/io/rsocket/kotlin/transport/netty/quic/NettyQuicConnectionHandler.kt b/rsocket-transports/netty-quic/src/jvmMain/kotlin/io/rsocket/kotlin/transport/netty/quic/NettyQuicConnectionHandler.kt deleted file mode 100644 index b13ed194..00000000 --- a/rsocket-transports/netty-quic/src/jvmMain/kotlin/io/rsocket/kotlin/transport/netty/quic/NettyQuicConnectionHandler.kt +++ /dev/null @@ -1,134 +0,0 @@ -/* - * Copyright 2015-2024 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package io.rsocket.kotlin.transport.netty.quic - -import io.netty.channel.* -import io.netty.channel.socket.* -import io.netty.incubator.codec.quic.* -import io.rsocket.kotlin.transport.* -import io.rsocket.kotlin.transport.netty.internal.* -import kotlinx.coroutines.* -import kotlinx.coroutines.channels.* -import kotlinx.coroutines.channels.Channel -import java.util.concurrent.atomic.* -import kotlin.coroutines.* - -@RSocketTransportApi -internal class NettyQuicConnectionHandler( - private val channel: QuicChannel, - private val handler: RSocketConnectionHandler, - scope: CoroutineScope, - private val isClient: Boolean, -) : ChannelInboundHandlerAdapter() { - private val inbound = Channel(Channel.UNLIMITED) - - private val connectionJob = Job(scope.coroutineContext.job) - private val streamsContext = scope.coroutineContext + SupervisorJob(connectionJob) - - private val handlerJob = scope.launch(connectionJob, start = CoroutineStart.LAZY) { - try { - handler.handleConnection(NettyQuicConnection(channel, inbound, streamsContext, isClient)) - } finally { - inbound.cancel() - withContext(NonCancellable) { - streamsContext.job.cancelAndJoin() - channel.close().awaitFuture() - } - } - } - - override fun channelActive(ctx: ChannelHandlerContext) { - handlerJob.start() - connectionJob.complete() - ctx.pipeline().addLast("rsocket-inbound", NettyQuicConnectionInboundHandler(inbound, streamsContext, isClient)) - - ctx.fireChannelActive() - } - - override fun channelInactive(ctx: ChannelHandlerContext) { - handlerJob.cancel("Channel is not active") - - ctx.fireChannelInactive() - } - - @Suppress("OVERRIDE_DEPRECATION") - override fun exceptionCaught(ctx: ChannelHandlerContext, cause: Throwable?) { - handlerJob.cancel("exceptionCaught", cause) - } -} - -// TODO: implement support for isAutoRead=false to support `inbound` backpressure -@RSocketTransportApi -private class NettyQuicConnectionInboundHandler( - private val inbound: SendChannel, - private val streamsContext: CoroutineContext, - private val isClient: Boolean, -) : ChannelInboundHandlerAdapter() { - // Note: QUIC streams could be received unordered, so f.e we could receive first stream with id 4 and then with id 0 - override fun channelRead(ctx: ChannelHandlerContext, msg: Any) { - msg as QuicStreamChannel - val state = NettyQuicStreamState(null) - if (inbound.trySend(state.wrapStream(msg)).isSuccess) { - msg.pipeline().addLast(NettyQuicStreamInitializer(streamsContext, state, isClient)) - } - ctx.fireChannelRead(msg) - } - - override fun userEventTriggered(ctx: ChannelHandlerContext?, evt: Any?) { - if (evt is ChannelInputShutdownEvent) { - inbound.close() - } - super.userEventTriggered(ctx, evt) - } -} - -@RSocketTransportApi -private class NettyQuicConnection( - private val channel: QuicChannel, - private val inbound: ReceiveChannel, - private val streamsContext: CoroutineContext, - private val isClient: Boolean, -) : RSocketMultiplexedConnection { - private val startMarker = Job() - - // we need to `hack` only first stream created for client - stream where frames with streamId=0 will be sent - private val first = AtomicBoolean(isClient) - override suspend fun createStream(): RSocketMultiplexedConnection.Stream { - val startMarker = if (first.getAndSet(false)) { - startMarker - } else { - startMarker.join() - null - } - val state = NettyQuicStreamState(startMarker) - val stream = try { - channel.createStream( - QuicStreamType.BIDIRECTIONAL, - NettyQuicStreamInitializer(streamsContext, state, isClient) - ).awaitFuture() - } catch (cause: Throwable) { - state.closeMarker.complete() - throw cause - } - - return state.wrapStream(stream) - } - - override suspend fun acceptStream(): RSocketMultiplexedConnection.Stream? { - return inbound.receiveCatching().getOrNull() - } -} diff --git a/rsocket-transports/netty-quic/src/jvmMain/kotlin/io/rsocket/kotlin/transport/netty/quic/NettyQuicConnectionInitializer.kt b/rsocket-transports/netty-quic/src/jvmMain/kotlin/io/rsocket/kotlin/transport/netty/quic/NettyQuicConnectionInitializer.kt deleted file mode 100644 index 5caf07f6..00000000 --- a/rsocket-transports/netty-quic/src/jvmMain/kotlin/io/rsocket/kotlin/transport/netty/quic/NettyQuicConnectionInitializer.kt +++ /dev/null @@ -1,37 +0,0 @@ -/* - * Copyright 2015-2024 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package io.rsocket.kotlin.transport.netty.quic - -import io.netty.channel.* -import io.netty.incubator.codec.quic.* -import io.rsocket.kotlin.transport.* -import kotlinx.coroutines.* -import kotlin.coroutines.* - -@RSocketTransportApi -internal class NettyQuicConnectionInitializer( - private val handler: RSocketConnectionHandler, - override val coroutineContext: CoroutineContext, - private val isClient: Boolean, -) : ChannelInitializer(), CoroutineScope { - override fun initChannel(channel: QuicChannel) { - with(channel.pipeline()) { - //addLast(LoggingHandler(if (isClient) "CLIENT" else "SERVER")) - addLast("rsocket", NettyQuicConnectionHandler(channel, handler, this@NettyQuicConnectionInitializer, isClient)) - } - } -} \ No newline at end of file diff --git a/rsocket-transports/netty-quic/src/jvmMain/kotlin/io/rsocket/kotlin/transport/netty/quic/NettyQuicServerTransport.kt b/rsocket-transports/netty-quic/src/jvmMain/kotlin/io/rsocket/kotlin/transport/netty/quic/NettyQuicServerTransport.kt index 4c01e135..1ca6eefa 100644 --- a/rsocket-transports/netty-quic/src/jvmMain/kotlin/io/rsocket/kotlin/transport/netty/quic/NettyQuicServerTransport.kt +++ b/rsocket-transports/netty-quic/src/jvmMain/kotlin/io/rsocket/kotlin/transport/netty/quic/NettyQuicServerTransport.kt @@ -1,5 +1,5 @@ /* - * Copyright 2015-2024 the original author or authors. + * Copyright 2015-2025 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -92,7 +92,7 @@ private class NettyQuicServerTransportBuilderImpl : NettyQuicServerTransportBuil @RSocketTransportApi override fun buildTransport(context: CoroutineContext): NettyQuicServerTransport { - val codecBuilder = QuicServerCodecBuilder().apply { + val codecHandler = fun(connectionInitializer: ChannelHandler): ChannelHandler = QuicServerCodecBuilder().apply { // by default, we allow Int.MAX_VALUE of active stream initialMaxData(Int.MAX_VALUE.toLong()) initialMaxStreamDataBidirectionalLocal(Int.MAX_VALUE.toLong()) @@ -103,7 +103,9 @@ private class NettyQuicServerTransportBuilderImpl : NettyQuicServerTransportBuil val keyManagerFactory = KeyManagerFactory.getInstance(KeyManagerFactory.getDefaultAlgorithm()) sslContext(QuicSslContextBuilder.forServer(keyManagerFactory, null).apply(it).build()) } - } + handler(connectionInitializer) + streamHandler(NettyQuicStreamInitializer) + }.build() val bootstrap = Bootstrap().apply { bootstrap?.invoke(this) @@ -114,28 +116,22 @@ private class NettyQuicServerTransportBuilderImpl : NettyQuicServerTransportBuil return NettyQuicServerTransportImpl( coroutineContext = context.supervisorContext() + bootstrap.config().group().asCoroutineDispatcher(), bootstrap = bootstrap, - codecBuilder = codecBuilder, - manageBootstrap = manageEventLoopGroup - ) + codecHandler = codecHandler + ).also { + if (manageEventLoopGroup) it.shutdownOnCancellation(bootstrap.config().group()) + } } } private class NettyQuicServerTransportImpl( override val coroutineContext: CoroutineContext, private val bootstrap: Bootstrap, - private val codecBuilder: QuicServerCodecBuilder, - manageBootstrap: Boolean, + private val codecHandler: (ChannelHandler) -> ChannelHandler, ) : NettyQuicServerTransport { - init { - if (manageBootstrap) callOnCancellation { - bootstrap.config().group().shutdownGracefully().awaitFuture() - } - } - override fun target(localAddress: InetSocketAddress?): NettyQuicServerTargetImpl = NettyQuicServerTargetImpl( coroutineContext = coroutineContext.supervisorContext(), bootstrap = bootstrap, - codecBuilder = codecBuilder, + codecHandler = codecHandler, localAddress = localAddress ?: InetSocketAddress(0) ) @@ -147,21 +143,27 @@ private class NettyQuicServerTransportImpl( private class NettyQuicServerTargetImpl( override val coroutineContext: CoroutineContext, private val bootstrap: Bootstrap, - private val codecBuilder: QuicServerCodecBuilder, + private val codecHandler: (ChannelHandler) -> ChannelHandler, private val localAddress: SocketAddress, ) : RSocketServerTarget { @RSocketTransportApi - override suspend fun startServer(handler: RSocketConnectionHandler): NettyQuicServerInstance { + override suspend fun startServer(onConnection: (RSocketConnection) -> Unit): NettyQuicServerInstance { currentCoroutineContext().ensureActive() coroutineContext.ensureActive() val instanceContext = coroutineContext.childContext() val channel = try { - bootstrap.clone().handler( - codecBuilder.clone().handler( - NettyQuicConnectionInitializer(handler, instanceContext.supervisorContext(), isClient = false) - ).build() - ).bind(localAddress).awaitChannel() + bootstrap.clone() + .handler( + codecHandler( + NettyQuicConnectionInitializer( + parentContext = instanceContext.supervisorContext(), + onConnection = onConnection + ) + ) + ) + .bind(localAddress) + .awaitChannel() } catch (cause: Throwable) { instanceContext.job.cancel("Failed to bind", cause) throw cause @@ -169,12 +171,27 @@ private class NettyQuicServerTargetImpl( return NettyQuicServerInstanceImpl( coroutineContext = instanceContext, - localAddress = (channel as DatagramChannel).localAddress() as InetSocketAddress + channel = channel ) } } private class NettyQuicServerInstanceImpl( override val coroutineContext: CoroutineContext, - override val localAddress: InetSocketAddress, -) : NettyQuicServerInstance + private val channel: DatagramChannel, +) : NettyQuicServerInstance { + override val localAddress: InetSocketAddress get() = channel.localAddress() as InetSocketAddress + + init { + @OptIn(DelicateCoroutinesApi::class) + launch(start = CoroutineStart.ATOMIC) { + try { + awaitCancellation() + } finally { + nonCancellable { + channel.close().awaitFuture() + } + } + } + } +} diff --git a/rsocket-transports/netty-quic/src/jvmMain/kotlin/io/rsocket/kotlin/transport/netty/quic/NettyQuicStream.kt b/rsocket-transports/netty-quic/src/jvmMain/kotlin/io/rsocket/kotlin/transport/netty/quic/NettyQuicStream.kt new file mode 100644 index 00000000..5acf3465 --- /dev/null +++ b/rsocket-transports/netty-quic/src/jvmMain/kotlin/io/rsocket/kotlin/transport/netty/quic/NettyQuicStream.kt @@ -0,0 +1,132 @@ +/* + * Copyright 2015-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.kotlin.transport.netty.quic + +import io.netty.buffer.* +import io.netty.channel.* +import io.netty.channel.socket.* +import io.netty.handler.codec.* +import io.netty.incubator.codec.quic.* +import io.netty.util.* +import io.rsocket.kotlin.internal.io.* +import io.rsocket.kotlin.transport.* +import io.rsocket.kotlin.transport.netty.internal.* +import kotlinx.coroutines.* +import kotlinx.coroutines.channels.Channel +import kotlinx.io.* +import kotlin.coroutines.* + +@RSocketTransportApi +internal class NettyQuicStream( + parentContext: CoroutineContext, + private val channel: QuicStreamChannel, +) : RSocketMultiplexedConnection.Stream, ChannelInboundHandlerAdapter() { + + private val outbound = bufferChannel(Channel.BUFFERED) + private val inbound = bufferChannel(Channel.UNLIMITED) + + override val coroutineContext: CoroutineContext = parentContext.childContext() + channel.eventLoop().asCoroutineDispatcher() + + init { + @OptIn(DelicateCoroutinesApi::class) + launch(start = CoroutineStart.UNDISPATCHED) { + launch(start = CoroutineStart.UNDISPATCHED) { + nonCancellable { + try { + while (true) { + writeAndFlushBuffer(outbound.receiveCatching().getOrNull() ?: break) + } + } finally { + outbound.cancel() + channel.shutdownOutput().awaitFuture() + } + } + } + try { + awaitCancellation() + } finally { + outbound.close() + inbound.cancel() + } + } + } + + override fun channelInactive(ctx: ChannelHandlerContext) { + cancel("Channel is not active") + ctx.fireChannelInactive() + } + + override fun exceptionCaught(ctx: ChannelHandlerContext, cause: Throwable?) { + cancel("exceptionCaught", cause) + } + + override fun userEventTriggered(ctx: ChannelHandlerContext, evt: Any?) { + if (evt === ChannelInputShutdownEvent.INSTANCE) inbound.close() + super.userEventTriggered(ctx, evt) + } + + override fun channelRead(ctx: ChannelHandlerContext, msg: Any) { + val buffer = (msg as ByteBuf).toBuffer() + if (inbound.trySend(buffer).isFailure) buffer.clear() + } + + override fun setSendPriority(priority: Int) { + channel.updatePriority(QuicStreamPriority(priority, false)) + } + + override suspend fun sendFrame(frame: Buffer) { + outbound.send(frame) + } + + override suspend fun receiveFrame(): Buffer? { + return inbound.receiveCatching().getOrNull() + } + + private suspend fun writeAndFlushBuffer(buffer: Buffer) { + channel.writeAndFlush(buffer.toByteBuf(channel.alloc())).awaitFuture() + } + + companion object { + val ATTRIBUTE: AttributeKey = AttributeKey.newInstance("rsocket-quic-stream") + } +} + +@RSocketTransportApi +internal object NettyQuicStreamInitializer : ChannelInitializer() { + override fun initChannel(channel: QuicStreamChannel) { + channel.pipeline().addLast( + "rsocket-length-encoder", + LengthFieldPrepender( + /* lengthFieldLength = */ 3 + ) + ) + channel.pipeline().addLast( + "rsocket-length-decoder", + LengthFieldBasedFrameDecoder( + /* maxFrameLength = */ Int.MAX_VALUE, + /* lengthFieldOffset = */ 0, + /* lengthFieldLength = */ 3, + /* lengthAdjustment = */ 0, + /* initialBytesToStrip = */ 3 + ) + ) + + channel.parent() + .attr(NettyQuicConnection.ATTRIBUTE).get() + .initStreamChannel(channel) + } +} diff --git a/rsocket-transports/netty-quic/src/jvmMain/kotlin/io/rsocket/kotlin/transport/netty/quic/NettyQuicStreamHandler.kt b/rsocket-transports/netty-quic/src/jvmMain/kotlin/io/rsocket/kotlin/transport/netty/quic/NettyQuicStreamHandler.kt deleted file mode 100644 index b2b55e61..00000000 --- a/rsocket-transports/netty-quic/src/jvmMain/kotlin/io/rsocket/kotlin/transport/netty/quic/NettyQuicStreamHandler.kt +++ /dev/null @@ -1,160 +0,0 @@ -/* - * Copyright 2015-2024 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package io.rsocket.kotlin.transport.netty.quic - -import io.netty.buffer.* -import io.netty.channel.* -import io.netty.channel.socket.* -import io.netty.incubator.codec.quic.* -import io.rsocket.kotlin.internal.io.* -import io.rsocket.kotlin.transport.* -import io.rsocket.kotlin.transport.netty.internal.* -import kotlinx.coroutines.* -import kotlinx.coroutines.channels.* -import kotlinx.coroutines.channels.Channel -import kotlinx.io.* - -// TODO: first stream is a hack to initiate first stream because of buffering -// quic streams could be received unordered by server, so f.e we could receive first stream with id 4 and then with id 0 -// for this, we disable buffering for first client stream, so that first frame will be sent first -// this will affect performance for this stream, so we need to do something else here. -@RSocketTransportApi -internal class NettyQuicStreamState(val startMarker: CompletableJob?) { - val closeMarker: CompletableJob = Job() - val outbound = bufferChannel(Channel.BUFFERED) - val inbound = bufferChannel(Channel.UNLIMITED) - - fun wrapStream(stream: QuicStreamChannel): RSocketMultiplexedConnection.Stream = - NettyQuicStream(stream, outbound, inbound, closeMarker) -} - -@RSocketTransportApi -internal class NettyQuicStreamHandler( - private val channel: QuicStreamChannel, - scope: CoroutineScope, - private val state: NettyQuicStreamState, - private val isClient: Boolean, -) : ChannelInboundHandlerAdapter() { - private val handlerJob = scope.launch(start = CoroutineStart.LAZY) { - val outbound = state.outbound - - val writerJob = launch(start = CoroutineStart.UNDISPATCHED) { - try { - while (true) { - // we write all available frames here, and only after it flush - // in this case, if there are several buffered frames we can send them in one go - // avoiding unnecessary flushes - // TODO: could be optimized to avoid allocation of not-needed promises - - var lastWriteFuture = channel.write(outbound.receiveCatching().getOrNull()?.toByteBuf(channel.alloc()) ?: break) - while (true) lastWriteFuture = channel.write(outbound.tryReceive().getOrNull()?.toByteBuf(channel.alloc()) ?: break) - //println("FLUSH: $isClient: ${channel.streamId()}") - channel.flush() - // await writing to respect transport backpressure - lastWriteFuture.awaitFuture() - state.startMarker?.complete() - } - } finally { - withContext(NonCancellable) { - channel.shutdownOutput().awaitFuture() - } - } - }.onCompletion { outbound.cancel() } - - try { - state.closeMarker.join() - } finally { - outbound.close() // will cause `writerJob` completion - // no more reading - state.inbound.cancel() - withContext(NonCancellable) { - writerJob.join() - channel.close().awaitFuture() - } - } - } - - override fun channelActive(ctx: ChannelHandlerContext) { - handlerJob.start() - ctx.pipeline().addLast("rsocket-inbound", NettyQuicStreamInboundHandler(state.inbound)) - - ctx.fireChannelActive() - } - - override fun channelInactive(ctx: ChannelHandlerContext) { - handlerJob.cancel("Channel is not active") - - ctx.fireChannelInactive() - } - - @Suppress("OVERRIDE_DEPRECATION") - override fun exceptionCaught(ctx: ChannelHandlerContext, cause: Throwable?) { - handlerJob.cancel("exceptionCaught", cause) - } -} - -private class NettyQuicStreamInboundHandler( - private val inbound: SendChannel, -) : ChannelInboundHandlerAdapter() { - override fun channelRead(ctx: ChannelHandlerContext, msg: Any) { - msg as ByteBuf - try { - val frame = msg.toBuffer() - if (inbound.trySend(frame).isFailure) { - frame.close() - } - } finally { - msg.release() - } - } - - override fun userEventTriggered(ctx: ChannelHandlerContext?, evt: Any?) { - if (evt is ChannelInputShutdownEvent) { - inbound.close() - } - super.userEventTriggered(ctx, evt) - } -} - -@RSocketTransportApi -private class NettyQuicStream( - // for priority - private val stream: QuicStreamChannel, - private val outbound: SendChannel, - private val inbound: ReceiveChannel, - private val closeMarker: CompletableJob, -) : RSocketMultiplexedConnection.Stream { - - @OptIn(DelicateCoroutinesApi::class) - override val isClosedForSend: Boolean get() = outbound.isClosedForSend - - override fun setSendPriority(priority: Int) { - stream.updatePriority(QuicStreamPriority(priority, false)) - } - - override suspend fun sendFrame(frame: Buffer) { - outbound.send(frame) - } - - override suspend fun receiveFrame(): Buffer? { - return inbound.receiveCatching().getOrNull() - } - - override fun close() { - closeMarker.complete() - } -} diff --git a/rsocket-transports/netty-quic/src/jvmMain/kotlin/io/rsocket/kotlin/transport/netty/quic/NettyQuicStreamInitializer.kt b/rsocket-transports/netty-quic/src/jvmMain/kotlin/io/rsocket/kotlin/transport/netty/quic/NettyQuicStreamInitializer.kt deleted file mode 100644 index 68c2f7a6..00000000 --- a/rsocket-transports/netty-quic/src/jvmMain/kotlin/io/rsocket/kotlin/transport/netty/quic/NettyQuicStreamInitializer.kt +++ /dev/null @@ -1,51 +0,0 @@ -/* - * Copyright 2015-2024 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package io.rsocket.kotlin.transport.netty.quic - -import io.netty.channel.* -import io.netty.handler.codec.* -import io.netty.incubator.codec.quic.* -import io.rsocket.kotlin.transport.* -import kotlinx.coroutines.* -import kotlin.coroutines.* - -@RSocketTransportApi -internal class NettyQuicStreamInitializer( - override val coroutineContext: CoroutineContext, - private val state: NettyQuicStreamState, - private val isClient: Boolean, -) : ChannelInitializer(), CoroutineScope { - override fun initChannel(channel: QuicStreamChannel): Unit = with(channel.pipeline()) { - addLast( - "rsocket-length-encoder", - LengthFieldPrepender( - /* lengthFieldLength = */ 3 - ) - ) - addLast( - "rsocket-length-decoder", - LengthFieldBasedFrameDecoder( - /* maxFrameLength = */ Int.MAX_VALUE, - /* lengthFieldOffset = */ 0, - /* lengthFieldLength = */ 3, - /* lengthAdjustment = */ 0, - /* initialBytesToStrip = */ 3 - ) - ) - addLast("rsocket", NettyQuicStreamHandler(channel, this@NettyQuicStreamInitializer, state, isClient)) - } -} diff --git a/rsocket-transports/netty-quic/src/jvmTest/kotlin/io/rsocket/kotlin/transport/netty/quic/NettyQuicTransportTest.kt b/rsocket-transports/netty-quic/src/jvmTest/kotlin/io/rsocket/kotlin/transport/netty/quic/NettyQuicTransportTest.kt index 18774a16..ade39d00 100644 --- a/rsocket-transports/netty-quic/src/jvmTest/kotlin/io/rsocket/kotlin/transport/netty/quic/NettyQuicTransportTest.kt +++ b/rsocket-transports/netty-quic/src/jvmTest/kotlin/io/rsocket/kotlin/transport/netty/quic/NettyQuicTransportTest.kt @@ -1,5 +1,5 @@ /* - * Copyright 2015-2024 the original author or authors. + * Copyright 2015-2025 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -17,7 +17,6 @@ package io.rsocket.kotlin.transport.netty.quic import io.netty.channel.nio.* import io.netty.handler.ssl.util.* -import io.netty.incubator.codec.quic.* import io.rsocket.kotlin.transport.tests.* import kotlin.concurrent.* @@ -28,7 +27,7 @@ private val eventLoop = NioEventLoopGroup().also { } private val certificates = SelfSignedCertificate() -private val protos = arrayOf("hq-29") +private val protos = arrayOf("h3") class NettyQuicTransportTest : TransportTest() { override suspend fun before() { @@ -39,9 +38,6 @@ class NettyQuicTransportTest : TransportTest() { keyManager(certificates.privateKey(), null, certificates.certificate()) applicationProtocols(*protos) } - codec { - tokenHandler(InsecureQuicTokenHandler.INSTANCE) - } }.target("127.0.0.1") ) client = connectClient( diff --git a/rsocket-transports/netty-tcp/src/jvmMain/kotlin/io/rsocket/kotlin/transport/netty/tcp/NettyTcpClientTransport.kt b/rsocket-transports/netty-tcp/src/jvmMain/kotlin/io/rsocket/kotlin/transport/netty/tcp/NettyTcpClientTransport.kt index 9675af6b..74423a92 100644 --- a/rsocket-transports/netty-tcp/src/jvmMain/kotlin/io/rsocket/kotlin/transport/netty/tcp/NettyTcpClientTransport.kt +++ b/rsocket-transports/netty-tcp/src/jvmMain/kotlin/io/rsocket/kotlin/transport/netty/tcp/NettyTcpClientTransport.kt @@ -1,5 +1,5 @@ /* - * Copyright 2015-2024 the original author or authors. + * Copyright 2015-2025 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -91,31 +91,25 @@ private class NettyTcpClientTransportBuilderImpl : NettyTcpClientTransportBuilde } return NettyTcpClientTransportImpl( - coroutineContext = context.supervisorContext() + bootstrap.config().group().asCoroutineDispatcher(), - sslContext = sslContext, + coroutineContext = context.supervisorContext() + Dispatchers.Default, bootstrap = bootstrap, - manageBootstrap = manageEventLoopGroup - ) + sslContext = sslContext, + ).also { + if (manageEventLoopGroup) it.shutdownOnCancellation(bootstrap.config().group()) + } } } private class NettyTcpClientTransportImpl( override val coroutineContext: CoroutineContext, - private val sslContext: SslContext?, private val bootstrap: Bootstrap, - manageBootstrap: Boolean, + private val sslContext: SslContext?, ) : NettyTcpClientTransport { - init { - if (manageBootstrap) callOnCancellation { - bootstrap.config().group().shutdownGracefully().awaitFuture() - } - } - - override fun target(remoteAddress: SocketAddress): NettyTcpClientTargetImpl = NettyTcpClientTargetImpl( + override fun target(remoteAddress: SocketAddress): RSocketClientTarget = NettyTcpClientTargetImpl( coroutineContext = coroutineContext.supervisorContext(), bootstrap = bootstrap, sslContext = sslContext, - remoteAddress = remoteAddress + remoteAddress = remoteAddress, ) override fun target(host: String, port: Int): RSocketClientTarget = target(InetSocketAddress(host, port)) @@ -124,19 +118,27 @@ private class NettyTcpClientTransportImpl( @OptIn(RSocketTransportApi::class) private class NettyTcpClientTargetImpl( override val coroutineContext: CoroutineContext, - private val bootstrap: Bootstrap, - private val sslContext: SslContext?, - private val remoteAddress: SocketAddress, + bootstrap: Bootstrap, + sslContext: SslContext?, + remoteAddress: SocketAddress, ) : RSocketClientTarget { - @RSocketTransportApi - override fun connectClient(handler: RSocketConnectionHandler): Job = launch { - bootstrap.clone().handler( + private val bootstrap = bootstrap.clone() + .handler( NettyTcpConnectionInitializer( + parentContext = coroutineContext, sslContext = sslContext, - remoteAddress = remoteAddress as? InetSocketAddress, - handler = handler, - coroutineContext = coroutineContext + onConnection = null ) - ).connect(remoteAddress).awaitFuture() + ) + .remoteAddress(remoteAddress) + + @RSocketTransportApi + override suspend fun connectClient(): RSocketConnection { + currentCoroutineContext().ensureActive() + coroutineContext.ensureActive() + + val channel = bootstrap.connect().awaitChannel() + + return channel.attr(NettyTcpConnection.ATTRIBUTE).get() } } diff --git a/rsocket-transports/netty-tcp/src/jvmMain/kotlin/io/rsocket/kotlin/transport/netty/tcp/NettyTcpConnection.kt b/rsocket-transports/netty-tcp/src/jvmMain/kotlin/io/rsocket/kotlin/transport/netty/tcp/NettyTcpConnection.kt new file mode 100644 index 00000000..44e0b74a --- /dev/null +++ b/rsocket-transports/netty-tcp/src/jvmMain/kotlin/io/rsocket/kotlin/transport/netty/tcp/NettyTcpConnection.kt @@ -0,0 +1,156 @@ +/* + * Copyright 2015-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.kotlin.transport.netty.tcp + +import io.netty.buffer.* +import io.netty.channel.* +import io.netty.channel.socket.* +import io.netty.handler.codec.* +import io.netty.handler.ssl.* +import io.netty.util.* +import io.rsocket.kotlin.internal.io.* +import io.rsocket.kotlin.transport.* +import io.rsocket.kotlin.transport.internal.* +import io.rsocket.kotlin.transport.netty.internal.* +import kotlinx.coroutines.* +import kotlinx.coroutines.channels.* +import kotlinx.coroutines.channels.Channel +import kotlinx.io.* +import kotlin.coroutines.* + +@RSocketTransportApi +internal class NettyTcpConnection( + parentContext: CoroutineContext, + private val channel: DuplexChannel, +) : RSocketSequentialConnection, ChannelInboundHandlerAdapter() { + + private val outboundQueue = PrioritizationFrameQueue() + private val inbound = bufferChannel(Channel.UNLIMITED) + + override val coroutineContext: CoroutineContext = parentContext.childContext() + channel.eventLoop().asCoroutineDispatcher() + + init { + @OptIn(DelicateCoroutinesApi::class) + launch(start = CoroutineStart.ATOMIC) { + val outboundJob = launch(start = CoroutineStart.ATOMIC) { + nonCancellable { + try { + while (true) { + // we write all available frames here, and only after it flush + // in this case, if there are several buffered frames we can send them in one go + // avoiding unnecessary flushes + writeBuffer(outboundQueue.dequeueFrame() ?: break) + while (true) writeBuffer(outboundQueue.tryDequeueFrame() ?: break) + channel.flush() + } + } finally { + outboundQueue.cancel() + channel.shutdownOutput().awaitFuture() + } + } + } + try { + awaitCancellation() + } finally { + nonCancellable { + outboundQueue.close() + inbound.cancel() + channel.shutdownInput().awaitFuture() + outboundJob.join() + channel.close().awaitFuture() + } + } + } + } + + override fun channelInactive(ctx: ChannelHandlerContext) { + cancel("Channel is not active") + ctx.fireChannelInactive() + } + + override fun exceptionCaught(ctx: ChannelHandlerContext, cause: Throwable?) { + cancel("exceptionCaught", cause) + } + + override fun userEventTriggered(ctx: ChannelHandlerContext, evt: Any?) { + if (evt === ChannelInputShutdownEvent.INSTANCE) inbound.close() + super.userEventTriggered(ctx, evt) + } + + override fun channelRead(ctx: ChannelHandlerContext, msg: Any) { + val buffer = (msg as ByteBuf).toBuffer() + if (inbound.trySend(buffer).isFailure) buffer.clear() + } + + override suspend fun sendFrame(streamId: Int, frame: Buffer) { + return outboundQueue.enqueueFrame(streamId, frame) + } + + override suspend fun receiveFrame(): Buffer? { + inbound.tryReceive().onSuccess { return it } + channel.read() + return inbound.receiveCatching().getOrNull() + } + + private fun writeBuffer(buffer: Buffer) { + channel.write(buffer.toByteBuf(channel.alloc()), channel.voidPromise()) + } + + companion object { + val ATTRIBUTE: AttributeKey = AttributeKey.newInstance("rsocket-tcp-connection") + } +} + +@OptIn(RSocketTransportApi::class) +internal class NettyTcpConnectionInitializer( + private val parentContext: CoroutineContext, + private val sslContext: SslContext?, + private val onConnection: ((RSocketConnection) -> Unit)?, +) : ChannelInitializer() { + override fun initChannel(channel: DuplexChannel) { + channel.config().isAutoRead = false + + val connection = NettyTcpConnection(parentContext, channel) + channel.attr(NettyTcpConnection.ATTRIBUTE).set(connection) + + if (sslContext != null) { + channel.pipeline().addLast("ssl", sslContext.newHandler(channel.alloc())) + } + channel.pipeline().addLast( + "rsocket-length-encoder", + LengthFieldPrepender( + /* lengthFieldLength = */ 3 + ) + ) + channel.pipeline().addLast( + "rsocket-length-decoder", + LengthFieldBasedFrameDecoder( + /* maxFrameLength = */ Int.MAX_VALUE, + /* lengthFieldOffset = */ 0, + /* lengthFieldLength = */ 3, + /* lengthAdjustment = */ 0, + /* initialBytesToStrip = */ 3 + ) + ) + channel.pipeline().addLast( + "rsocket-connection", + connection + ) + + onConnection?.invoke(connection) + } +} diff --git a/rsocket-transports/netty-tcp/src/jvmMain/kotlin/io/rsocket/kotlin/transport/netty/tcp/NettyTcpConnectionHandler.kt b/rsocket-transports/netty-tcp/src/jvmMain/kotlin/io/rsocket/kotlin/transport/netty/tcp/NettyTcpConnectionHandler.kt deleted file mode 100644 index 07cbaabf..00000000 --- a/rsocket-transports/netty-tcp/src/jvmMain/kotlin/io/rsocket/kotlin/transport/netty/tcp/NettyTcpConnectionHandler.kt +++ /dev/null @@ -1,133 +0,0 @@ -/* - * Copyright 2015-2024 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package io.rsocket.kotlin.transport.netty.tcp - -import io.netty.buffer.* -import io.netty.channel.* -import io.netty.channel.socket.* -import io.rsocket.kotlin.internal.io.* -import io.rsocket.kotlin.transport.* -import io.rsocket.kotlin.transport.internal.* -import io.rsocket.kotlin.transport.netty.internal.* -import kotlinx.coroutines.* -import kotlinx.coroutines.channels.* -import kotlinx.coroutines.channels.Channel -import kotlinx.io.* -import io.netty.channel.socket.DuplexChannel as NettyDuplexChannel - -@RSocketTransportApi -internal class NettyTcpConnectionHandler( - private val channel: NettyDuplexChannel, - private val handler: RSocketConnectionHandler, - scope: CoroutineScope, -) : ChannelInboundHandlerAdapter() { - private val inbound = bufferChannel(Channel.UNLIMITED) - - private val handlerJob = scope.launch(start = CoroutineStart.LAZY) { - val outboundQueue = PrioritizationFrameQueue(Channel.BUFFERED) - - val writerJob = launch { - try { - while (true) { - // we write all available frames here, and only after it flush - // in this case, if there are several buffered frames we can send them in one go - // avoiding unnecessary flushes - // TODO: could be optimized to avoid allocation of not-needed promises - var lastWriteFuture = channel.write(outboundQueue.dequeueFrame()?.toByteBuf(channel.alloc()) ?: break) - while (true) lastWriteFuture = channel.write(outboundQueue.tryDequeueFrame()?.toByteBuf(channel.alloc()) ?: break) - channel.flush() - // await writing to respect transport backpressure - lastWriteFuture.awaitFuture() - } - } finally { - withContext(NonCancellable) { - channel.shutdownOutput().awaitFuture() - } - } - }.onCompletion { outboundQueue.cancel() } - - try { - handler.handleConnection(NettyTcpConnection(outboundQueue, inbound)) - } finally { - outboundQueue.close() // will cause `writerJob` completion - // no more reading - inbound.cancel() - withContext(NonCancellable) { - writerJob.join() - channel.close().awaitFuture() - } - } - } - - override fun channelActive(ctx: ChannelHandlerContext) { - handlerJob.start() - ctx.pipeline().addLast("rsocket-inbound", NettyTcpConnectionInboundHandler(inbound)) - - ctx.fireChannelActive() - } - - override fun channelInactive(ctx: ChannelHandlerContext) { - handlerJob.cancel("Channel is not active") - - ctx.fireChannelInactive() - } - - @Suppress("OVERRIDE_DEPRECATION") - override fun exceptionCaught(ctx: ChannelHandlerContext, cause: Throwable?) { - handlerJob.cancel("exceptionCaught", cause) - } -} - -// TODO: implement support for isAutoRead=false to support `inbound` backpressure -private class NettyTcpConnectionInboundHandler( - private val inbound: SendChannel, -) : ChannelInboundHandlerAdapter() { - override fun channelRead(ctx: ChannelHandlerContext, msg: Any) { - msg as ByteBuf - try { - val frame = msg.toBuffer() - if (inbound.trySend(frame).isFailure) { - frame.clear() - error("inbound is closed") - } - } finally { - msg.release() - } - } - - override fun userEventTriggered(ctx: ChannelHandlerContext?, evt: Any?) { - if (evt is ChannelInputShutdownEvent) { - inbound.close() - } - super.userEventTriggered(ctx, evt) - } -} - -@RSocketTransportApi -private class NettyTcpConnection( - private val outboundQueue: PrioritizationFrameQueue, - private val inbound: ReceiveChannel, -) : RSocketSequentialConnection { - override val isClosedForSend: Boolean get() = outboundQueue.isClosedForSend - override suspend fun sendFrame(streamId: Int, frame: Buffer) { - return outboundQueue.enqueueFrame(streamId, frame) - } - - override suspend fun receiveFrame(): Buffer? { - return inbound.receiveCatching().getOrNull() - } -} diff --git a/rsocket-transports/netty-tcp/src/jvmMain/kotlin/io/rsocket/kotlin/transport/netty/tcp/NettyTcpConnectionInitializer.kt b/rsocket-transports/netty-tcp/src/jvmMain/kotlin/io/rsocket/kotlin/transport/netty/tcp/NettyTcpConnectionInitializer.kt deleted file mode 100644 index 2d040061..00000000 --- a/rsocket-transports/netty-tcp/src/jvmMain/kotlin/io/rsocket/kotlin/transport/netty/tcp/NettyTcpConnectionInitializer.kt +++ /dev/null @@ -1,63 +0,0 @@ -/* - * Copyright 2015-2024 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package io.rsocket.kotlin.transport.netty.tcp - -import io.netty.channel.* -import io.netty.channel.socket.* -import io.netty.handler.codec.* -import io.netty.handler.ssl.* -import io.rsocket.kotlin.transport.* -import kotlinx.coroutines.* -import java.net.* -import kotlin.coroutines.* - -@RSocketTransportApi -internal class NettyTcpConnectionInitializer( - private val sslContext: SslContext?, - private val remoteAddress: InetSocketAddress?, - private val handler: RSocketConnectionHandler, - override val coroutineContext: CoroutineContext, -) : ChannelInitializer(), CoroutineScope { - override fun initChannel(channel: DuplexChannel): Unit = with(channel.pipeline()) { - if (sslContext != null) { - addLast( - "ssl", - when { - remoteAddress != null -> sslContext.newHandler(channel.alloc(), remoteAddress.hostName, remoteAddress.port) - else -> sslContext.newHandler(channel.alloc()) - } - ) - } - addLast( - "rsocket-length-encoder", - LengthFieldPrepender( - /* lengthFieldLength = */ 3 - ) - ) - addLast( - "rsocket-length-decoder", - LengthFieldBasedFrameDecoder( - /* maxFrameLength = */ kotlin.Int.MAX_VALUE, - /* lengthFieldOffset = */ 0, - /* lengthFieldLength = */ 3, - /* lengthAdjustment = */ 0, - /* initialBytesToStrip = */ 3 - ) - ) - addLast("rsocket", NettyTcpConnectionHandler(channel, handler, this@NettyTcpConnectionInitializer)) - } -} diff --git a/rsocket-transports/netty-tcp/src/jvmMain/kotlin/io/rsocket/kotlin/transport/netty/tcp/NettyTcpServerTransport.kt b/rsocket-transports/netty-tcp/src/jvmMain/kotlin/io/rsocket/kotlin/transport/netty/tcp/NettyTcpServerTransport.kt index 9566f451..18078e35 100644 --- a/rsocket-transports/netty-tcp/src/jvmMain/kotlin/io/rsocket/kotlin/transport/netty/tcp/NettyTcpServerTransport.kt +++ b/rsocket-transports/netty-tcp/src/jvmMain/kotlin/io/rsocket/kotlin/transport/netty/tcp/NettyTcpServerTransport.kt @@ -1,5 +1,5 @@ /* - * Copyright 2015-2024 the original author or authors. + * Copyright 2015-2025 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -36,10 +36,12 @@ public sealed interface NettyTcpServerInstance : RSocketServerInstance { public val localAddress: SocketAddress } +public typealias NettyTcpServerTarget = RSocketServerTarget + @OptIn(RSocketTransportApi::class) public sealed interface NettyTcpServerTransport : RSocketTransport { - public fun target(localAddress: SocketAddress? = null): RSocketServerTarget - public fun target(host: String = "0.0.0.0", port: Int = 0): RSocketServerTarget + public fun target(localAddress: SocketAddress? = null): NettyTcpServerTarget + public fun target(host: String = "0.0.0.0", port: Int = 0): NettyTcpServerTarget public companion object Factory : RSocketTransportFactory(::NettyTcpServerTransportBuilderImpl) @@ -101,15 +103,19 @@ private class NettyTcpServerTransportBuilderImpl : NettyTcpServerTransportBuilde val bootstrap = ServerBootstrap().apply { bootstrap?.invoke(this) channelFactory(channelFactory ?: ReflectiveChannelFactory(NioServerSocketChannel::class.java)) - group(parentEventLoopGroup ?: NioEventLoopGroup(), childEventLoopGroup ?: NioEventLoopGroup()) + + val parentEventLoopGroup = parentEventLoopGroup ?: NioEventLoopGroup() + val childEventLoopGroup = childEventLoopGroup ?: parentEventLoopGroup + group(parentEventLoopGroup, childEventLoopGroup) } return NettyTcpServerTransportImpl( - coroutineContext = context.supervisorContext() + bootstrap.config().childGroup().asCoroutineDispatcher(), + coroutineContext = context.supervisorContext() + Dispatchers.Default, bootstrap = bootstrap, sslContext = sslContext, - manageBootstrap = manageEventLoopGroup - ) + ).also { + if (manageEventLoopGroup) it.shutdownOnCancellation(bootstrap.config().childGroup(), bootstrap.config().group()) + } } } @@ -117,24 +123,15 @@ private class NettyTcpServerTransportImpl( override val coroutineContext: CoroutineContext, private val bootstrap: ServerBootstrap, private val sslContext: SslContext?, - manageBootstrap: Boolean, ) : NettyTcpServerTransport { - init { - if (manageBootstrap) callOnCancellation { - bootstrap.config().childGroup().shutdownGracefully().awaitFuture() - bootstrap.config().group().shutdownGracefully().awaitFuture() - } - } - - override fun target(localAddress: SocketAddress?): NettyTcpServerTargetImpl = NettyTcpServerTargetImpl( + override fun target(localAddress: SocketAddress?): NettyTcpServerTarget = NettyTcpServerTargetImpl( coroutineContext = coroutineContext.supervisorContext(), bootstrap = bootstrap, sslContext = sslContext, localAddress = localAddress ?: InetSocketAddress(0), ) - override fun target(host: String, port: Int): RSocketServerTarget = - target(InetSocketAddress(host, port)) + override fun target(host: String, port: Int): NettyTcpServerTarget = target(InetSocketAddress(host, port)) } @OptIn(RSocketTransportApi::class) @@ -143,36 +140,49 @@ private class NettyTcpServerTargetImpl( private val bootstrap: ServerBootstrap, private val sslContext: SslContext?, private val localAddress: SocketAddress, -) : RSocketServerTarget { +) : NettyTcpServerTarget { @RSocketTransportApi - override suspend fun startServer(handler: RSocketConnectionHandler): NettyTcpServerInstance { + override suspend fun startServer(onConnection: (RSocketConnection) -> Unit): NettyTcpServerInstance { currentCoroutineContext().ensureActive() coroutineContext.ensureActive() val instanceContext = coroutineContext.childContext() val channel = try { - bootstrap.clone().childHandler( - NettyTcpConnectionInitializer( - sslContext = sslContext, - remoteAddress = null, - handler = handler, - coroutineContext = instanceContext.supervisorContext() - ) - ).bind(localAddress).awaitChannel() + val handler = NettyTcpConnectionInitializer( + parentContext = instanceContext.supervisorContext(), + sslContext = sslContext, + onConnection = onConnection, + ) + bootstrap.clone() + .childHandler(handler) + .bind(localAddress) + .awaitChannel() } catch (cause: Throwable) { instanceContext.job.cancel("Failed to bind", cause) throw cause } - // TODO: handle server closure - return NettyTcpServerInstanceImpl( - coroutineContext = instanceContext, - localAddress = (channel as ServerChannel).localAddress() - ) + return NettyTcpServerInstanceImpl(instanceContext, channel) } } +@RSocketTransportApi private class NettyTcpServerInstanceImpl( override val coroutineContext: CoroutineContext, - override val localAddress: SocketAddress, -) : NettyTcpServerInstance + private val channel: ServerChannel, +) : NettyTcpServerInstance { + override val localAddress: SocketAddress get() = channel.localAddress() + + init { + @OptIn(DelicateCoroutinesApi::class) + launch(start = CoroutineStart.ATOMIC) { + try { + awaitCancellation() + } finally { + nonCancellable { + channel.close().awaitFuture() + } + } + } + } +} diff --git a/rsocket-transports/nodejs-tcp/src/jsTest/kotlin/io/rsocket/kotlin/transport/nodejs/tcp/TcpTransportTest.kt b/rsocket-transports/nodejs-tcp/src/jsTest/kotlin/io/rsocket/kotlin/transport/nodejs/tcp/TcpTransportTest.kt index c81983f7..bc219bcb 100644 --- a/rsocket-transports/nodejs-tcp/src/jsTest/kotlin/io/rsocket/kotlin/transport/nodejs/tcp/TcpTransportTest.kt +++ b/rsocket-transports/nodejs-tcp/src/jsTest/kotlin/io/rsocket/kotlin/transport/nodejs/tcp/TcpTransportTest.kt @@ -18,8 +18,10 @@ package io.rsocket.kotlin.transport.nodejs.tcp import io.rsocket.kotlin.transport.tests.* import kotlinx.coroutines.* +import kotlin.test.* @Suppress("DEPRECATION_ERROR") +@Ignore class TcpTransportTest : TransportTest() { private lateinit var server: TcpServer