diff --git a/Package.swift b/Package.swift index 7cfbdb86..4a83225b 100644 --- a/Package.swift +++ b/Package.swift @@ -1,4 +1,4 @@ -// swift-tools-version:5.2 +// swift-tools-version:5.3 import PackageDescription let package = Package( @@ -12,6 +12,7 @@ let package = Package( dependencies: [ .package(url: "https://github.com/apple/swift-nio.git", from: "2.11.1"), .package(url: "https://github.com/apple/swift-nio-ssl.git", from: "2.0.0"), + .package(url: "https://github.com/apple/swift-nio-transport-services.git", from: "1.5.1"), ], targets: [ .target(name: "WebSocketKit", dependencies: [ @@ -21,6 +22,13 @@ let package = Package( .product(name: "NIOHTTP1", package: "swift-nio"), .product(name: "NIOSSL", package: "swift-nio-ssl"), .product(name: "NIOWebSocket", package: "swift-nio"), + .product( + name: "NIOTransportServices", + package: "swift-nio-transport-services", + condition: .when( + platforms: [Platform.iOS, Platform.macOS, Platform.tvOS, Platform.watchOS] + ) + ) ]), .testTarget(name: "WebSocketKitTests", dependencies: [ .target(name: "WebSocketKit"), diff --git a/Package@swift-5.2.swift b/Package@swift-5.2.swift new file mode 100644 index 00000000..ac5bf440 --- /dev/null +++ b/Package@swift-5.2.swift @@ -0,0 +1,29 @@ +// swift-tools-version:5.2 +import PackageDescription + +let package = Package( + name: "websocket-kit", + platforms: [ + .macOS(.v10_15) + ], + products: [ + .library(name: "WebSocketKit", targets: ["WebSocketKit"]), + ], + dependencies: [ + .package(url: "https://github.com/apple/swift-nio.git", from: "2.11.1"), + .package(url: "https://github.com/apple/swift-nio-ssl.git", from: "2.0.0"), + ], + targets: [ + .target(name: "WebSocketKit", dependencies: [ + .product(name: "NIO", package: "swift-nio"), + .product(name: "NIOConcurrencyHelpers", package: "swift-nio"), + .product(name: "NIOFoundationCompat", package: "swift-nio"), + .product(name: "NIOHTTP1", package: "swift-nio"), + .product(name: "NIOSSL", package: "swift-nio-ssl"), + .product(name: "NIOWebSocket", package: "swift-nio") + ]), + .testTarget(name: "WebSocketKitTests", dependencies: [ + .target(name: "WebSocketKit"), + ]), + ] +) diff --git a/Sources/WebSocketKit/HTTPChannelIntercepter.swift b/Sources/WebSocketKit/HTTPChannelIntercepter.swift new file mode 100644 index 00000000..4de32b26 --- /dev/null +++ b/Sources/WebSocketKit/HTTPChannelIntercepter.swift @@ -0,0 +1,57 @@ +import Foundation +import NIOHTTP1 +import NIO + +public final class HTTPChannelIntercepter: ChannelDuplexHandler, RemovableChannelHandler { + + public typealias OutboundIn = HTTPClientRequestPart + public typealias OutboundOut = HTTPClientRequestPart + + public typealias InboundIn = HTTPClientResponsePart + public typealias InboundOut = HTTPClientResponsePart + + let writeInterceptHandler: (HTTPRequestHead) -> Void + + public init(writeInterceptHandler: @escaping (HTTPRequestHead) -> Void) { + self.writeInterceptHandler = writeInterceptHandler + } + + public func write( + context: ChannelHandlerContext, + data: NIOAny, + promise: EventLoopPromise? + ) { + let interceptedOutgoingRequest = self.unwrapOutboundIn(data) + + if case .head(let requestHead) = interceptedOutgoingRequest { + self.writeInterceptHandler(requestHead) + } + + context.write(data, promise: promise) + } +} + +extension ChannelPipeline { + public func addHTTPClientHandlers( + position: Position = .last, + leftOverBytesStrategy: RemoveAfterUpgradeStrategy = .dropBytes, + withServerUpgrade upgrade: NIOHTTPClientUpgradeConfiguration? = nil, + withExtraHandlers extraHandlers: [RemovableChannelHandler] = [] + ) -> EventLoopFuture { + let requestEncoder = HTTPRequestEncoder() + let responseDecoder = HTTPResponseDecoder(leftOverBytesStrategy: leftOverBytesStrategy) + + var handlers: [RemovableChannelHandler] = [requestEncoder, ByteToMessageHandler(responseDecoder)] + extraHandlers + + if let (upgraders, completionHandler) = upgrade { + let upgrader = NIOHTTPClientUpgradeHandler( + upgraders: upgraders, + httpHandlers: handlers, + upgradeCompletionHandler: completionHandler + ) + handlers.append(upgrader) + } + + return self.addHandlers(handlers, position: position) + } +} diff --git a/Sources/WebSocketKit/TLSConfiguration+Network.swift b/Sources/WebSocketKit/TLSConfiguration+Network.swift new file mode 100644 index 00000000..87921b08 --- /dev/null +++ b/Sources/WebSocketKit/TLSConfiguration+Network.swift @@ -0,0 +1,141 @@ +//===----------------------------------------------------------------------===// +// +// This source file was part of the AsyncHTTPClient open source project +// https://github.com/swift-server/async-http-client +// +// Copyright (c) 2020 Apple Inc. and the AsyncHTTPClient project authors +// Licensed under Apache License v2.0 +// +// See LICENSE.txt for license information +// See CONTRIBUTORS.txt for the list of AsyncHTTPClient project authors +// +// SPDX-License-Identifier: Apache-2.0 +// +//===----------------------------------------------------------------------===// + +#if canImport(Network) && swift(>=5.3) + + import Foundation + import Network + import NIOSSL + import NIOTransportServices + + extension TLSVersion { + /// return Network framework TLS protocol version + var nwTLSProtocolVersion: tls_protocol_version_t { + switch self { + case .tlsv1: + return .TLSv10 + case .tlsv11: + return .TLSv11 + case .tlsv12: + return .TLSv12 + case .tlsv13: + return .TLSv13 + } + } + } + + extension TLSVersion { + /// return as SSL protocol + var sslProtocol: SSLProtocol { + switch self { + case .tlsv1: + return .tlsProtocol1 + case .tlsv11: + return .tlsProtocol11 + case .tlsv12: + return .tlsProtocol12 + case .tlsv13: + return .tlsProtocol13 + } + } + } + + @available(macOS 10.14, iOS 12.0, tvOS 12.0, watchOS 6.0, *) + extension TLSConfiguration { + /// Dispatch queue used by Network framework TLS to control certificate verification + static var tlsDispatchQueue = DispatchQueue(label: "TLSDispatch") + + /// create NWProtocolTLS.Options for use with NIOTransportServices from the NIOSSL TLSConfiguration + /// + /// - Parameter queue: Dispatch queue to run `sec_protocol_options_set_verify_block` on. + /// - Returns: Equivalent NWProtocolTLS Options + func getNWProtocolTLSOptions() -> NWProtocolTLS.Options { + let options = NWProtocolTLS.Options() + + // minimum TLS protocol + if #available(macOS 10.15, iOS 13.0, tvOS 13.0, watchOS 6.0, *) { + sec_protocol_options_set_min_tls_protocol_version(options.securityProtocolOptions, self.minimumTLSVersion.nwTLSProtocolVersion) + } else { + sec_protocol_options_set_tls_min_version(options.securityProtocolOptions, self.minimumTLSVersion.sslProtocol) + } + + // maximum TLS protocol + if let maximumTLSVersion = self.maximumTLSVersion { + if #available(macOS 10.15, iOS 13.0, tvOS 13.0, watchOS 6.0, *) { + sec_protocol_options_set_max_tls_protocol_version(options.securityProtocolOptions, maximumTLSVersion.nwTLSProtocolVersion) + } else { + sec_protocol_options_set_tls_max_version(options.securityProtocolOptions, maximumTLSVersion.sslProtocol) + } + } + + // application protocols + for applicationProtocol in self.applicationProtocols { + applicationProtocol.withCString { buffer in + sec_protocol_options_add_tls_application_protocol(options.securityProtocolOptions, buffer) + } + } + + // the certificate chain + if self.certificateChain.count > 0 { + preconditionFailure("TLSConfiguration.certificateChain is not supported") + } + + // cipher suites + if self.cipherSuites.count > 0 { + // TODO: Requires NIOSSL to provide list of cipher values before we can continue + // https://github.com/apple/swift-nio-ssl/issues/207 + } + + // key log callback + if self.keyLogCallback != nil { + preconditionFailure("TLSConfiguration.keyLogCallback is not supported") + } + + // private key + if self.privateKey != nil { + preconditionFailure("TLSConfiguration.privateKey is not supported") + } + + // renegotiation support key is unsupported + + // trust roots + if let trustRoots = self.trustRoots { + guard case .default = trustRoots else { + preconditionFailure("TLSConfiguration.trustRoots != .default is not supported") + } + } + + switch self.certificateVerification { + case .none: + // add verify block to control certificate verification + sec_protocol_options_set_verify_block( + options.securityProtocolOptions, + { _, _, sec_protocol_verify_complete in + sec_protocol_verify_complete(true) + }, TLSConfiguration.tlsDispatchQueue + ) + + case .noHostnameVerification: + precondition(self.certificateVerification != .noHostnameVerification, "TLSConfiguration.certificateVerification = .noHostnameVerification is not supported") + + case .fullVerification: + break + } + + return options + } + } + +#endif diff --git a/Sources/WebSocketKit/WebSocket+Connect.swift b/Sources/WebSocketKit/WebSocket+Connect.swift index ec10393b..4ab10c12 100644 --- a/Sources/WebSocketKit/WebSocket+Connect.swift +++ b/Sources/WebSocketKit/WebSocket+Connect.swift @@ -1,10 +1,12 @@ +import NIOHTTP1 + extension WebSocket { public static func connect( to url: String, headers: HTTPHeaders = [:], configuration: WebSocketClient.Configuration = .init(), on eventLoopGroup: EventLoopGroup, - onUpgrade: @escaping (WebSocket) -> () + onUpgrade: @escaping (WebSocket, HTTPResponseHead) -> () ) -> EventLoopFuture { guard let url = URL(string: url) else { return eventLoopGroup.next().makeFailedFuture(WebSocketClient.Error.invalidURL) @@ -17,13 +19,28 @@ extension WebSocket { onUpgrade: onUpgrade ) } + + public static func connect( + to url: String, + headers: HTTPHeaders = [:], + configuration: WebSocketClient.Configuration = .init(), + on eventLoopGroup: EventLoopGroup, + onUpgrade: @escaping (WebSocket) -> () + ) -> EventLoopFuture { + return self.connect( + to: url, + headers: headers, + configuration: configuration, + on: eventLoopGroup + ) { (ws, _ ) in onUpgrade(ws) } + } public static func connect( to url: URL, headers: HTTPHeaders = [:], configuration: WebSocketClient.Configuration = .init(), on eventLoopGroup: EventLoopGroup, - onUpgrade: @escaping (WebSocket) -> () + onUpgrade: @escaping (WebSocket, HTTPResponseHead) -> () ) -> EventLoopFuture { let scheme = url.scheme ?? "ws" return self.connect( @@ -37,6 +54,21 @@ extension WebSocket { onUpgrade: onUpgrade ) } + + public static func connect( + to url: URL, + headers: HTTPHeaders = [:], + configuration: WebSocketClient.Configuration = .init(), + on eventLoopGroup: EventLoopGroup, + onUpgrade: @escaping (WebSocket) -> () + ) -> EventLoopFuture { + return self.connect( + to: url, + headers: headers, + configuration: configuration, + on: eventLoopGroup + ) { (ws, _) in onUpgrade(ws) } + } public static func connect( scheme: String = "ws", @@ -46,7 +78,7 @@ extension WebSocket { headers: HTTPHeaders = [:], configuration: WebSocketClient.Configuration = .init(), on eventLoopGroup: EventLoopGroup, - onUpgrade: @escaping (WebSocket) -> () + onUpgrade: @escaping (WebSocket, HTTPResponseHead) -> () ) -> EventLoopFuture { return WebSocketClient( eventLoopGroupProvider: .shared(eventLoopGroup), @@ -60,4 +92,25 @@ extension WebSocket { onUpgrade: onUpgrade ) } + + public static func connect( + scheme: String = "ws", + host: String, + port: Int = 80, + path: String = "/", + headers: HTTPHeaders = [:], + configuration: WebSocketClient.Configuration = .init(), + on eventLoopGroup: EventLoopGroup, + onUpgrade: @escaping (WebSocket) -> () + ) -> EventLoopFuture { + return self.connect( + scheme: scheme, + host: host, + port: port, + path: path, + headers: headers, + configuration: configuration, + on: eventLoopGroup + ) { (ws, _) in onUpgrade(ws) } + } } diff --git a/Sources/WebSocketKit/WebSocketClient.swift b/Sources/WebSocketKit/WebSocketClient.swift index 1eb4df78..20d94846 100644 --- a/Sources/WebSocketKit/WebSocketClient.swift +++ b/Sources/WebSocketKit/WebSocketClient.swift @@ -4,6 +4,21 @@ import NIOConcurrencyHelpers import NIOHTTP1 import NIOWebSocket import NIOSSL +#if canImport(Network) && swift(>=5.3) +import NIOTransportServices +#endif + +internal extension String { + var isIPAddress: Bool { + var ipv4Addr = in_addr() + var ipv6Addr = in6_addr() + + return self.withCString { ptr in + inet_pton(AF_INET, ptr, &ipv4Addr) == 1 || + inet_pton(AF_INET6, ptr, &ipv6Addr) == 1 + } + } +} public final class WebSocketClient { public enum Error: Swift.Error, LocalizedError { @@ -44,22 +59,82 @@ public final class WebSocketClient { case .shared(let group): self.group = group case .createNew: - self.group = MultiThreadedEventLoopGroup(numberOfThreads: 1) + #if canImport(Network) && swift(>=5.3) + if #available(OSX 10.14, iOS 12.0, tvOS 12.0, watchOS 6.0, *) { + self.group = NIOTSEventLoopGroup() + } else { + self.group = MultiThreadedEventLoopGroup(numberOfThreads: 1) + } + #else + self.group = MultiThreadedEventLoopGroup(numberOfThreads: 1) + #endif } self.configuration = configuration } + + fileprivate static func makeBootstrap( + on eventLoop: EventLoopGroup, + host: String, + requiresTLS: Bool, + configuration: Configuration + ) throws -> NIOClientTCPBootstrap { + var bootstrap: NIOClientTCPBootstrap + #if canImport(Network) && swift(>=5.3) + // if eventLoop is compatible with NIOTransportServices create a NIOTSConnectionBootstrap + if #available(OSX 10.14, iOS 12.0, tvOS 12.0, watchOS 6.0, *), let tsBootstrap = NIOTSConnectionBootstrap(validatingGroup: eventLoop) { + // if there is a proxy don't create TLS provider as it will be added at a later point + // create NIOClientTCPBootstrap with NIOTS TLS provider + let tlsConfiguration = configuration.tlsConfiguration ?? TLSConfiguration.forClient() + let parameters = tlsConfiguration.getNWProtocolTLSOptions() + let tlsProvider = NIOTSClientTLSProvider(tlsOptions: parameters) + bootstrap = NIOClientTCPBootstrap(tsBootstrap, tls: tlsProvider) + } else if let clientBootstrap = ClientBootstrap(validatingGroup: eventLoop) { + let tlsConfiguration = configuration.tlsConfiguration ?? TLSConfiguration.forClient() + let sslContext = try NIOSSLContext(configuration: tlsConfiguration) + let hostname = (!requiresTLS || host.isIPAddress || host.isEmpty) ? nil : host + let tlsProvider = try NIOSSLClientTLSProvider(context: sslContext, serverHostname: hostname) + bootstrap = NIOClientTCPBootstrap(clientBootstrap, tls: tlsProvider) + } else { + preconditionFailure("Cannot create bootstrap for the supplied EventLoop") + } + #else + if let clientBootstrap = ClientBootstrap(validatingGroup: eventLoop) { + let tlsConfiguration = configuration.tlsConfiguration ?? TLSConfiguration.forClient() + let sslContext = try NIOSSLContext(configuration: tlsConfiguration) + let hostname = (!requiresTLS || host.isIPAddress || host.isEmpty) ? nil : host + let tlsProvider = try NIOSSLClientTLSProvider(context: sslContext, serverHostname: hostname) + bootstrap = NIOClientTCPBootstrap(clientBootstrap, tls: tlsProvider) + } else { + preconditionFailure("Cannot create bootstrap for the supplied EventLoop") + } + #endif + + if requiresTLS { + return bootstrap.enableTLS() + } + + return bootstrap + } + public func connect( scheme: String, host: String, port: Int, path: String = "/", headers: HTTPHeaders = [:], - onUpgrade: @escaping (WebSocket) -> () + onUpgrade: @escaping (WebSocket, HTTPResponseHead) -> (), + onRequest: @escaping (HTTPRequestHead) -> () = {_ in } ) -> EventLoopFuture { assert(["ws", "wss"].contains(scheme)) let upgradePromise = self.group.next().makePromise(of: Void.self) - let bootstrap = ClientBootstrap(group: self.group) + + let bootstrap = try! Self.makeBootstrap( + on: self.group, + host: host, + requiresTLS: scheme == "wss", + configuration: self.configuration + ) .channelOption(ChannelOptions.socket(SocketOptionLevel(IPPROTO_TCP), TCP_NODELAY), value: 1) .channelInitializer { channel in let httpHandler = HTTPInitialRequestHandler( @@ -78,7 +153,9 @@ public final class WebSocketClient { maxFrameSize: self.configuration.maxFrameSize, automaticErrorHandling: true, upgradePipelineHandler: { channel, req in - return WebSocket.client(on: channel, onUpgrade: onUpgrade) + return WebSocket.client(on: channel) { ws in + onUpgrade(ws, req) + } } ) @@ -90,27 +167,16 @@ public final class WebSocketClient { } ) - if scheme == "wss" { - do { - let context = try NIOSSLContext( - configuration: self.configuration.tlsConfiguration ?? .forClient() - ) - let tlsHandler = try NIOSSLClientHandler(context: context, serverHostname: host) - return channel.pipeline.addHandler(tlsHandler).flatMap { - channel.pipeline.addHTTPClientHandlers(leftOverBytesStrategy: .forwardBytes, withClientUpgrade: config) - }.flatMap { - channel.pipeline.addHandler(httpHandler) - } - } catch { - return channel.pipeline.close(mode: .all) - } - } else { - return channel.pipeline.addHTTPClientHandlers( - leftOverBytesStrategy: .forwardBytes, - withClientUpgrade: config - ).flatMap { - channel.pipeline.addHandler(httpHandler) - } + return channel.pipeline.addHTTPClientHandlers( + leftOverBytesStrategy: .forwardBytes, + withServerUpgrade: config, + withExtraHandlers: [ + HTTPChannelIntercepter(writeInterceptHandler: { (head) in + onRequest(head) + }) + ] + ).flatMap { + channel.pipeline.addHandler(httpHandler) } } diff --git a/Tests/WebSocketKitTests/WebSocketKitTests.swift b/Tests/WebSocketKitTests/WebSocketKitTests.swift index a6936d1f..1c5d1772 100644 --- a/Tests/WebSocketKitTests/WebSocketKitTests.swift +++ b/Tests/WebSocketKitTests/WebSocketKitTests.swift @@ -2,13 +2,17 @@ import XCTest import NIO import NIOHTTP1 import NIOWebSocket +#if canImport(Network) && swift(>=5.3) +import NIOTransportServices +#endif @testable import WebSocketKit final class WebSocketKitTests: XCTestCase { + func testWebSocketEcho() throws { - let promise = elg.next().makePromise(of: String.self) - let closePromise = elg.next().makePromise(of: Void.self) - WebSocket.connect(to: "ws://echo.websocket.org", on: elg) { ws in + let promise = self.remoteEventGroup.next().makePromise(of: String.self) + let closePromise = self.remoteEventGroup.next().makePromise(of: Void.self) + WebSocket.connect(to: "ws://echo.websocket.org", on: self.remoteEventGroup) { ws in ws.send("hello") ws.onText { ws, string in promise.succeed(string) @@ -17,11 +21,12 @@ final class WebSocketKitTests: XCTestCase { }.cascadeFailure(to: promise) try XCTAssertEqual(promise.futureResult.wait(), "hello") XCTAssertNoThrow(try closePromise.futureResult.wait()) + } func testWebSocketWithTLSEcho() throws { - let promise = elg.next().makePromise(of: String.self) - WebSocket.connect(to: "wss://echo.websocket.org", on: elg) { ws in + let promise = self.remoteEventGroup.next().makePromise(of: String.self) + WebSocket.connect(to: "wss://echo.websocket.org", on: self.remoteEventGroup) { ws in ws.send("hello") ws.onText { ws, string in promise.succeed(string) @@ -32,7 +37,12 @@ final class WebSocketKitTests: XCTestCase { } func testBadHost() throws { - XCTAssertThrowsError(try WebSocket.connect(host: "asdf", on: elg) { _ in }.wait()) + XCTAssertThrowsError( + try WebSocket.connect( + host: "asdf", + on: self.remoteEventGroup + ) { _ in }.wait() + ) } func testServerClose() throws { @@ -214,12 +224,29 @@ final class WebSocketKitTests: XCTestCase { } var elg: EventLoopGroup! + var remoteEventGroup: EventLoopGroup! override func setUp() { // needs to be at least two to avoid client / server on same EL timing issues self.elg = MultiThreadedEventLoopGroup(numberOfThreads: 2) + + #if canImport(Network) && swift(>=5.3) + if #available(OSX 10.14, iOS 12.0, tvOS 12.0, watchOS 6.0, *) { + self.remoteEventGroup = NIOTSEventLoopGroup() + } else { + self.remoteEventGroup = self.elg + } + #else + self.remoteEventGroup = self.elg + #endif } + override func tearDown() { try! self.elg.syncShutdownGracefully() + #if canImport(Network) && swift(>=5.3) + if #available(OSX 10.14, iOS 12.0, tvOS 12.0, watchOS 6.0, *) { + try! self.remoteEventGroup.syncShutdownGracefully() + } + #endif } }