From ecbc54d428cc20cf759a15176e9e826edec3b918 Mon Sep 17 00:00:00 2001 From: Dmitrii Korolev Date: Wed, 23 Apr 2025 14:58:06 +0200 Subject: [PATCH 01/20] support tls client hello bytes callback in Kestrel --- .../Core/src/HttpsConnectionAdapterOptions.cs | 7 + .../Core/src/ListenOptionsHttpsExtensions.cs | 10 + .../src/Middleware/TlsListenerMiddleware.cs | 125 +++++ .../Kestrel/Core/src/PublicAPI.Unshipped.txt | 2 + .../TlsListenerMiddlewareTests.Units.cs | 510 ++++++++++++++++++ .../TlsListenerMiddlewareTests.cs | 66 +++ 6 files changed, 720 insertions(+) create mode 100644 src/Servers/Kestrel/Core/src/Middleware/TlsListenerMiddleware.cs create mode 100644 src/Servers/Kestrel/test/InMemory.FunctionalTests/TlsListenerMiddlewareTests.Units.cs create mode 100644 src/Servers/Kestrel/test/InMemory.FunctionalTests/TlsListenerMiddlewareTests.cs diff --git a/src/Servers/Kestrel/Core/src/HttpsConnectionAdapterOptions.cs b/src/Servers/Kestrel/Core/src/HttpsConnectionAdapterOptions.cs index f13540fa579c..b74fefdb8b96 100644 --- a/src/Servers/Kestrel/Core/src/HttpsConnectionAdapterOptions.cs +++ b/src/Servers/Kestrel/Core/src/HttpsConnectionAdapterOptions.cs @@ -1,6 +1,7 @@ // Licensed to the .NET Foundation under one or more agreements. // The .NET Foundation licenses this file to you under the MIT license. +using System.Buffers; using System.Net.Security; using System.Security.Authentication; using System.Security.Cryptography.X509Certificates; @@ -96,6 +97,12 @@ public void AllowAnyClientCertificate() /// public Action? OnAuthenticate { get; set; } + /// + /// A callback to be invoked to get the TLS client hello bytes. + /// Null by default. + /// + public Action>? TlsClientHelloBytesCallback { get; set; } + /// /// Specifies the maximum amount of time allowed for the TLS/SSL handshake. This must be positive /// or . Defaults to 10 seconds. diff --git a/src/Servers/Kestrel/Core/src/ListenOptionsHttpsExtensions.cs b/src/Servers/Kestrel/Core/src/ListenOptionsHttpsExtensions.cs index 32bd1dd59889..42d7ac8f0476 100644 --- a/src/Servers/Kestrel/Core/src/ListenOptionsHttpsExtensions.cs +++ b/src/Servers/Kestrel/Core/src/ListenOptionsHttpsExtensions.cs @@ -5,6 +5,7 @@ using System.Security.Cryptography.X509Certificates; using Microsoft.AspNetCore.Server.Kestrel.Core; using Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Infrastructure; +using Microsoft.AspNetCore.Server.Kestrel.Core.Middleware; using Microsoft.AspNetCore.Server.Kestrel.Https; using Microsoft.AspNetCore.Server.Kestrel.Https.Internal; using Microsoft.Extensions.DependencyInjection; @@ -197,6 +198,15 @@ public static ListenOptions UseHttps(this ListenOptions listenOptions, HttpsConn listenOptions.IsTls = true; listenOptions.HttpsOptions = httpsOptions; + if (httpsOptions.TlsClientHelloBytesCallback is not null) + { + listenOptions.Use(next => + { + var middleware = new TlsListenerMiddleware(next, httpsOptions.TlsClientHelloBytesCallback); + return middleware.OnTlsClientHelloAsync; + }); + } + listenOptions.Use(next => { var middleware = new HttpsConnectionMiddleware(next, httpsOptions, listenOptions.Protocols, loggerFactory, metrics); diff --git a/src/Servers/Kestrel/Core/src/Middleware/TlsListenerMiddleware.cs b/src/Servers/Kestrel/Core/src/Middleware/TlsListenerMiddleware.cs new file mode 100644 index 000000000000..096707adbcea --- /dev/null +++ b/src/Servers/Kestrel/Core/src/Middleware/TlsListenerMiddleware.cs @@ -0,0 +1,125 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Buffers; +using Microsoft.AspNetCore.Connections; + +namespace Microsoft.AspNetCore.Server.Kestrel.Core.Middleware; + +internal sealed class TlsListenerMiddleware +{ + private readonly ConnectionDelegate _next; + private readonly Action> _tlsClientHelloBytesCallback; + + public TlsListenerMiddleware(ConnectionDelegate next, Action> tlsClientHelloBytesCallback) + { + _next = next; + _tlsClientHelloBytesCallback = tlsClientHelloBytesCallback; + } + + /// + /// Sniffs the TLS Client Hello message, and invokes a callback if found. + /// + internal async Task OnTlsClientHelloAsync(ConnectionContext connection) + { + var input = connection.Transport.Input; + + while (true) + { + var result = await input.ReadAsync(); + var buffer = result.Buffer; + + // If the buffer length is less than 6 bytes (handshake + version + length + client-hello byte) + // and no more data is coming, we can't block in a loop here because we will not get more data + if (buffer.Length < 6 && result.IsCompleted) + { + break; + } + + var parseState = TryParseClientHello(buffer, out var clientHelloBytes); + + // no data is consumed, it will be processed by the follow-up middlewares + input.AdvanceTo(buffer.Start); + + switch (parseState) + { + case ClientHelloParseState.NotEnoughData: + continue; + + case ClientHelloParseState.NotTlsClientHello: + await _next(connection); + return; + + case ClientHelloParseState.ValidTlsClientHello: + _tlsClientHelloBytesCallback(connection, clientHelloBytes); + await _next(connection); + return; + } + } + + await _next(connection); + } + + private static ClientHelloParseState TryParseClientHello(ReadOnlySequence buffer, out ReadOnlySequence clientHelloBytes) + { + clientHelloBytes = default; + + if (buffer.Length < 6) + { + return ClientHelloParseState.NotEnoughData; + } + + var reader = new SequenceReader(buffer); + + // Content type must be 0x16 for TLS Handshake + if (!reader.TryRead(out byte contentType) || contentType != 0x16) + { + return ClientHelloParseState.NotTlsClientHello; + } + + // Protocol version + if (!reader.TryReadBigEndian(out short version) || IsValidProtocolVersion(version) == false) + { + return ClientHelloParseState.NotTlsClientHello; + } + + // Record length + if (!reader.TryReadBigEndian(out short recordLength)) + { + return ClientHelloParseState.NotTlsClientHello; + } + + // 5 bytes are + // 1) Handshake (1 byte) + // 2) Protocol version (2 bytes) + // 3) Record length (2 bytes) + if (buffer.Length < 5 + recordLength) + { + return ClientHelloParseState.NotEnoughData; + } + + // byte 6: handshake message type (must be 0x01 for ClientHello) + if (!reader.TryRead(out byte handshakeType) || handshakeType != 0x01) + { + return ClientHelloParseState.NotTlsClientHello; + } + + clientHelloBytes = buffer.Slice(0, 5 + recordLength); + return ClientHelloParseState.ValidTlsClientHello; + } + + private static bool IsValidProtocolVersion(short version) + => version == 0x0002 // SSL 2.0 (0x0002) + || version == 0x0300 // SSL 3.0 (0x0300) + || version == 0x0301 // TLS 1.0 (0x0301) + || version == 0x0302 // TLS 1.1 (0x0302) + || version == 0x0303 // TLS 1.2 (0x0303) + || version == 0x0304; // TLS 1.3 (0x0304) + + private enum ClientHelloParseState + { + NotEnoughData, + NotTlsClientHello, + ValidTlsClientHello + } +} diff --git a/src/Servers/Kestrel/Core/src/PublicAPI.Unshipped.txt b/src/Servers/Kestrel/Core/src/PublicAPI.Unshipped.txt index 7dc5c58110bf..69a983823915 100644 --- a/src/Servers/Kestrel/Core/src/PublicAPI.Unshipped.txt +++ b/src/Servers/Kestrel/Core/src/PublicAPI.Unshipped.txt @@ -1 +1,3 @@ #nullable enable +Microsoft.AspNetCore.Server.Kestrel.Https.HttpsConnectionAdapterOptions.TlsClientHelloBytesCallback.get -> System.Action>? +Microsoft.AspNetCore.Server.Kestrel.Https.HttpsConnectionAdapterOptions.TlsClientHelloBytesCallback.set -> void diff --git a/src/Servers/Kestrel/test/InMemory.FunctionalTests/TlsListenerMiddlewareTests.Units.cs b/src/Servers/Kestrel/test/InMemory.FunctionalTests/TlsListenerMiddlewareTests.Units.cs new file mode 100644 index 000000000000..696bc24fc898 --- /dev/null +++ b/src/Servers/Kestrel/test/InMemory.FunctionalTests/TlsListenerMiddlewareTests.Units.cs @@ -0,0 +1,510 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using System.Buffers; +using System.Collections.Generic; +using System.IO.Pipelines; +using System.Net; +using System.Net.Security; +using System.Security.Authentication; +using System.Security.Cryptography; +using System.Security.Cryptography.X509Certificates; +using System.Text; +using Microsoft.AspNetCore.Connections; +using Microsoft.AspNetCore.Hosting; +using Microsoft.AspNetCore.Http; +using Microsoft.AspNetCore.Http.Features; +using Microsoft.AspNetCore.InternalTesting; +using Microsoft.AspNetCore.Server.Kestrel.Core; +using Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Infrastructure; +using Microsoft.AspNetCore.Server.Kestrel.Core.Middleware; +using Microsoft.AspNetCore.Server.Kestrel.Https; +using Microsoft.AspNetCore.Server.Kestrel.InMemory.FunctionalTests.TestTransport; +using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.Hosting; +using Microsoft.Extensions.Logging; +using Moq; + +namespace InMemory.FunctionalTests; + +public partial class TlsListenerMiddlewareTests +{ + [Theory] + [MemberData(nameof(ValidClientHelloData))] + public Task OnTlsClientHelloAsync_ValidData(int id, byte[] packetBytes, bool nextMiddlewareInvoked) + => RunTlsClientHelloCallbackTest(id, packetBytes, nextMiddlewareInvoked, tlsClientHelloCallbackExpected: true); + + [Theory] + [MemberData(nameof(InvalidClientHelloData))] + public Task OnTlsClientHelloAsync_InvalidData(int id, byte[] packetBytes, bool nextMiddlewareInvoked) + => RunTlsClientHelloCallbackTest(id, packetBytes, nextMiddlewareInvoked, tlsClientHelloCallbackExpected: false); + + [Theory] + [MemberData(nameof(ValidClientHelloData_Segmented))] + public Task OnTlsClientHelloAsync_ValidData_MultipleSegments(int id, List packets, bool nextMiddlewareInvoked) + => RunTlsClientHelloCallbackTest_WithMultipleSegments(id, packets, nextMiddlewareInvoked, tlsClientHelloCallbackExpected: true); + + [Theory] + [MemberData(nameof(InvalidClientHelloData_Segmented))] + public Task OnTlsClientHelloAsync_InvalidData_MultipleSegments(int id, List packets, bool nextMiddlewareInvoked) + => RunTlsClientHelloCallbackTest_WithMultipleSegments(id, packets, nextMiddlewareInvoked, tlsClientHelloCallbackExpected: false); + + private async Task RunTlsClientHelloCallbackTest_WithMultipleSegments( + int id, + List packets, + bool nextMiddlewareInvokedExpected, + bool tlsClientHelloCallbackExpected) + { + var serviceContext = new TestServiceContext(LoggerFactory); + var logger = LoggerFactory.CreateLogger(); + var memoryPool = serviceContext.MemoryPoolFactory(); + var transportConnection = new InMemoryTransportConnection(memoryPool, logger); + + var nextMiddlewareInvokedActual = false; + var tlsClientHelloCallbackActual = false; + + var fullLength = packets.Sum(p => p.Length); + + var middleware = new TlsListenerMiddleware( + next: _ => + { + nextMiddlewareInvokedActual = true; + return Task.CompletedTask; + }, + tlsClientHelloBytesCallback: (ctx, data) => + { + tlsClientHelloCallbackActual = true; + + Assert.NotNull(ctx); + Assert.False(data.IsEmpty); + Assert.Equal(fullLength, data.Length); + } + ); + + // write first packet + await transportConnection.Input.WriteAsync(packets[0]); + var middlewareTask = Task.Run(() => middleware.OnTlsClientHelloAsync(transportConnection)); + + var random = new Random(); + await Task.Delay(millisecondsDelay: random.Next(25, 75)); + + // write all next packets + foreach (var packet in packets.Skip(1)) + { + await transportConnection.Input.WriteAsync(packet); + await Task.Delay(millisecondsDelay: random.Next(25, 75)); + } + await transportConnection.Input.CompleteAsync(); + await middlewareTask; + + Assert.Equal(nextMiddlewareInvokedExpected, nextMiddlewareInvokedActual); + Assert.Equal(tlsClientHelloCallbackExpected, tlsClientHelloCallbackActual); + } + + private async Task RunTlsClientHelloCallbackTest( + int id, + byte[] packetBytes, + bool nextMiddlewareExpected, + bool tlsClientHelloCallbackExpected) + { + var serviceContext = new TestServiceContext(LoggerFactory); + var logger = LoggerFactory.CreateLogger(); + var memoryPool = serviceContext.MemoryPoolFactory(); + var transportConnection = new InMemoryTransportConnection(memoryPool, logger); + + var nextMiddlewareInvokedActual = false; + var tlsClientHelloCallbackActual = false; + + var middleware = new TlsListenerMiddleware( + next: _ => + { + nextMiddlewareInvokedActual = true; + return Task.CompletedTask; + }, + tlsClientHelloBytesCallback: (ctx, data) => + { + tlsClientHelloCallbackActual = true; + + Assert.NotNull(ctx); + Assert.False(data.IsEmpty); + Assert.Equal(packetBytes.Length, data.Length); + } + ); + + await transportConnection.Input.WriteAsync(packetBytes); + await transportConnection.Input.CompleteAsync(); + + // call middleware and expect a callback + await middleware.OnTlsClientHelloAsync(transportConnection); + + Assert.Equal(nextMiddlewareExpected, nextMiddlewareInvokedActual); + Assert.Equal(tlsClientHelloCallbackExpected, tlsClientHelloCallbackActual); + } + + public static IEnumerable ValidClientHelloData() + { + int id = 0; + foreach (var clientHello in new List() { valid_clientHelloHeader, valid_ClientHelloStandard, valid_Tls12ClientHello, valid_Tls13ClientHello, valid_TlsClientHelloNoExtensions }) + { + yield return new object[] { id++, clientHello, true /* invokes next middleware */ }; + } + } + + public static IEnumerable InvalidClientHelloData() + { + int id = 0; + foreach (byte[] clientHello in new List() { invalid_TlsClientHelloHeader, invalid_3BytesMessage, invalid_UnknownProtocolVersion1, invalid_UnknownProtocolVersion2, invalid_IncorrectHandshakeMessageType }) + { + yield return new object[] { id++, clientHello, true /* invokes next middleware */ }; + } + } + + public static IEnumerable ValidClientHelloData_Segmented() + { + int id = 0; + foreach (var clientHello in new List() { valid_clientHelloHeader, valid_ClientHelloStandard, valid_Tls12ClientHello, valid_Tls13ClientHello, valid_TlsClientHelloNoExtensions }) + { + var clientHelloSegments = new List + { + clientHello.Take(1).ToArray(), + clientHello.Skip(1).Take(2).ToArray(), + clientHello.Skip(3).Take(2).ToArray(), + clientHello.Skip(5).Take(1).ToArray(), + clientHello.Skip(6).Take(clientHello.Length - 6).ToArray() + }; + + yield return new object[] { id++, clientHelloSegments, true /* invokes next middleware */ }; + } + } + + public static IEnumerable InvalidClientHelloData_Segmented() + { + int id = 0; + foreach (List clientHelloSegments in new List>() { invalidSegmented_TlsClientHelloHeader }) + { + yield return new object[] { id++, clientHelloSegments, true /* invokes next middleware */ }; + } + } + + private static byte[] valid_clientHelloHeader = + { + // 0x16 = Handshake + 0x16, + // 0x0301 = TLS 1.0 + 0x03, 0x01, + // length = 0x0020 (32 bytes) + 0x00, 0x20, + // Handshake.msg_type (client hello) + 0x01, + // 31 bytes (zeros for simplicity) + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0 + }; + + private static byte[] valid_ClientHelloStandard = + { + // SslPlainText.(ContentType+ProtocolVersion) + 0x16, 0x03, 0x03, + // SslPlainText.length + 0x00, 0xCB, + // Handshake.msg_type (client hello) + 0x01, + // Handshake.length + 0x00, 0x00, 0xC7, + // ClientHello.client_version + 0x03, 0x03, + // ClientHello.random + 0x0C, 0x3C, 0x85, 0x78, 0xCA, + 0x67, 0x70, 0xAA, 0x38, 0xCB, + 0x28, 0xBC, 0xDC, 0x3E, 0x30, + 0xBF, 0x11, 0x96, 0x95, 0x1A, + 0xB9, 0xF0, 0x99, 0xA4, 0x91, + 0x09, 0x13, 0xB4, 0x89, 0x94, + 0x27, 0x2E, + // ClientHello.SessionId + 0x00, + // ClientHello.cipher_suites + 0x00, 0x2A, 0xC0, 0x2C, 0xC0, + 0x2B, 0xC0, 0x30, 0xC0, 0x2F, + 0x00, 0x9F, 0x00, 0x9E, 0xC0, + 0x24, 0xC0, 0x23, 0xC0, 0x28, + 0xC0, 0x27, 0xC0, 0x0A, 0xC0, + 0x09, 0xC0, 0x14, 0xC0, 0x13, + 0x00, 0x9D, 0x00, 0x9C, 0x00, + 0x3D, 0x00, 0x3C, 0x00, 0x35, + 0x00, 0x2F, 0x00, 0x0A, + // ClientHello.compression_methods + 0x01, 0x01, + // ClientHello.extension_list_length + 0x00, 0x74, + // Extension.extension_type (server_name) + 0x00, 0x00, + // ServerNameListExtension.length + 0x00, 0x39, + // ServerName.length + 0x00, 0x37, + // ServerName.type + 0x00, + // HostName.length + 0x00, 0x34, + // HostName.bytes + 0x61, 0x61, 0x61, 0x61, 0x61, + 0x61, 0x61, 0x61, 0x61, 0x61, + 0x61, 0x61, 0x61, 0x61, 0x61, + 0x61, 0x61, 0x61, 0x61, 0x61, + 0x61, 0x61, 0x61, 0x61, 0x61, + 0x61, 0x61, 0x61, 0x61, 0x61, + 0x61, 0x61, 0x61, 0x61, 0x61, + 0x61, 0x61, 0x61, 0x61, 0x61, + 0x61, 0x61, 0x61, 0x61, 0x61, + 0x61, 0x61, 0x61, 0x61, 0x61, + 0x61, 0x61, + // Extension.extension_type (00 0A) + 0x00, 0x0A, + // Extension 0A + 0x00, 0x08, 0x00, 0x06, 0x00, + 0x1D, 0x00, 0x17, 0x00, 0x18, + // Extension.extension_type (00 0B) + 0x00, 0x0B, + // Extension 0B + 0x00, 0x02, 0x01, 0x00, + // Extension.extension_type (00 0D) + 0x00, 0x0D, + // Extension 0D + 0x00, 0x14, 0x00, 0x12, 0x04, + 0x01, 0x05, 0x01, 0x02, 0x01, + 0x04, 0x03, 0x05, 0x03, 0x02, + 0x03, 0x02, 0x02, 0x06, 0x01, + 0x06, 0x03, + // Extension.extension_type (00 23) + 0x00, 0x23, + // Extension 00 23 + 0x00, 0x00, + // Extension.extension_type (00 17) + 0x00, 0x17, + // Extension 17 + 0x00, 0x00, + // Extension.extension_type (FF 01) + 0xFF, 0x01, + // Extension FF01 + 0x00, 0x01, 0x00 + }; + + private static byte[] valid_Tls12ClientHello = + { + // SslPlainText.(ContentType+ProtocolVersion) + 0x16, 0x03, 0x01, + // SslPlainText.length + 0x00, 0xD1, + // Handshake.msg_type (client hello) + 0x01, + // Handshake.length + 0x00, 0x00, 0xCD, + // ClientHello.client_version + 0x03, 0x03, + // ClientHello.random + 0x0C, 0x3C, 0x85, 0x78, 0xCA, + 0x67, 0x70, 0xAA, 0x38, 0xCB, + 0x28, 0xBC, 0xDC, 0x3E, 0x30, + 0xBF, 0x11, 0x96, 0x95, 0x1A, + 0xB9, 0xF0, 0x99, 0xA4, 0x91, + 0x09, 0x13, 0xB4, 0x89, 0x94, + 0x27, 0x2E, + // ClientHello.SessionId + 0x00, + // ClientHello.cipher_suites_length + 0x00, 0x5C, + // ClientHello.cipher_suites + 0xC0, 0x30, 0xC0, 0x2C, 0xC0, 0x28, 0xC0, 0x24, + 0xC0, 0x14, 0xC0, 0x0A, 0x00, 0x9f, 0x00, 0x6B, + 0x00, 0x39, 0xCC, 0xA9, 0xCC, 0xA8, 0xCC, 0xAA, + 0xFF, 0x85, 0x00, 0xC4, 0x00, 0x88, 0x00, 0x81, + 0x00, 0x9D, 0x00, 0x3D, 0x00, 0x35, 0x00, 0xC0, + 0x00, 0x84, 0xC0, 0x2f, 0xC0, 0x2B, 0xC0, 0x27, + 0xC0, 0x23, 0xC0, 0x13, 0xC0, 0x09, 0x00, 0x9E, + 0x00, 0x67, 0x00, 0x33, 0x00, 0xBE, 0x00, 0x45, + 0x00, 0x9C, 0x00, 0x3C, 0x00, 0x2F, 0x00, 0xBA, + 0x00, 0x41, 0xC0, 0x11, 0xC0, 0x07, 0x00, 0x05, + 0x00, 0x04, 0xC0, 0x12, 0xC0, 0x08, 0x00, 0x16, + 0x00, 0x0a, 0x00, 0xff, + // ClientHello.compression_methods + 0x01, 0x01, + // ClientHello.extension_list_length + 0x00, 0x48, + // Extension.extension_type (ec_point_formats) + 0x00, 0x0b, 0x00, 0x02, 0x01, 0x00, + // Extension.extension_type (supported_groups) + 0x00, 0x0A, 0x00, 0x08, 0x00, 0x06, 0x00, 0x1D, + 0x00, 0x17, 0x00, 0x18, + // Extension.extension_type (session_ticket) + 0x00, 0x23, 0x00, 0x00, + // Extension.extension_type (signature_algorithms) + 0x00, 0x0D, 0x00, 0x1C, 0x00, 0x1A, 0x06, 0x01, + 0x06, 0x03, 0xEF, 0xEF, 0x05, 0x01, 0x05, 0x03, + 0x04, 0x01, 0x04, 0x03, 0xEE, 0xEE, 0xED, 0xED, + 0x03, 0x01, 0x03, 0x03, 0x02, 0x01, 0x02, 0x03, + // Extension.extension_type (application_level_Protocol) + 0x00, 0x10, 0x00, 0x0e, 0x00, 0x0C, 0x02, 0x68, + 0x32, 0x08, 0x68, 0x74, 0x74, 0x70, 0x2F, 0x31, + 0x2E, 0x31 + }; + + private static byte[] valid_Tls13ClientHello = + { + // SslPlainText.(ContentType+ProtocolVersion) + 0x16, 0x03, 0x01, + // SslPlainText.length + 0x01, 0x08, + // Handshake.msg_type (client hello) + 0x01, + // Handshake.length + 0x00, 0x01, 0x04, + // ClientHello.client_version + 0x03, 0x03, + // ClientHello.random + 0x0C, 0x3C, 0x85, 0x78, 0xCA, 0x67, 0x70, 0xAA, + 0x38, 0xCB, 0x28, 0xBC, 0xDC, 0x3E, 0x30, 0xBF, + 0x11, 0x96, 0x95, 0x1A, 0xB9, 0xF0, 0x99, 0xA4, + 0x91, 0x09, 0x13, 0xB4, 0x89, 0x94, 0x27, 0x2E, + // ClientHello.SessionId_Length + 0x20, + // ClientHello.SessionId + 0x0C, 0x3C, 0x85, 0x78, 0xCA, 0x67, 0x70, 0xAA, + 0x38, 0xCB, 0x28, 0xBC, 0xDC, 0x3E, 0x30, 0xBF, + 0x11, 0x96, 0x95, 0x1A, 0xB9, 0xF0, 0x99, 0xA4, + 0x91, 0x09, 0x13, 0xB4, 0x89, 0x94, 0x27, 0x2E, + // ClientHello.cipher_suites_length + 0x00, 0x0C, + // ClientHello.cipher_suites + 0x13, 0x02, 0x13, 0x03, 0x13, 0x01, 0xC0, 0x14, + 0xc0, 0x30, 0x00, 0xFF, + // ClientHello.compression_methods + 0x01, 0x00, + // ClientHello.extension_list_length + 0x00, 0xAF, + // Extension.extension_type (server_name) (10.211.55.2) + 0x00, 0x00, 0x00, 0x10, 0x00, 0x0e, 0x00, 0x00, + 0x0B, 0x31, 0x30, 0x2E, 0x32, 0x31, 0x31, 0x2E, + 0x35, 0x35, 0x2E, 0x32, + // Extension.extension_type (ec_point_formats) + 0x00, 0x0B, 0x00, 0x04, 0x03, 0x00, 0x01, 0x02, + // Extension.extension_type (supported_groups) + 0x00, 0x0A, 0x00, 0x0C, 0x00, 0x0A, 0x00, 0x1D, + 0x00, 0x17, 0x00, 0x1E, 0x00, 0x19, 0x00, 0x18, + // Extension.extension_type (application_level_Protocol) (boo) + 0x00, 0x10, 0x00, 0x06, 0x00, 0x04, 0x03, 0x62, + 0x6f, 0x6f, + // Extension.extension_type (encrypt_then_mac) + 0x00, 0x16, 0x00, 0x00, + // Extension.extension_type (extended_master_key_secret) + 0x00, 0x17, 0x00, 0x00, + // Extension.extension_type (signature_algorithms) + 0x00, 0x0D, 0x00, 0x30, 0x00, 0x2E, + 0x06, 0x03, 0xEF, 0xEF, 0x05, 0x01, 0x05, 0x03, + 0x06, 0x03, 0xEF, 0xEF, 0x05, 0x01, 0x05, 0x03, + 0x06, 0x03, 0xEF, 0xEF, 0x05, 0x01, 0x05, 0x03, + 0x04, 0x01, 0x04, 0x03, 0xEE, 0xEE, 0xED, 0xED, + 0x03, 0x01, 0x03, 0x03, 0x02, 0x01, 0x02, 0x03, + 0x03, 0x01, 0x03, 0x03, 0x02, 0x01, + // Extension.extension_type (supported_versions) + 0x00, 0x2B, 0x00, 0x09, 0x08, 0x03, 0x04, 0x03, + 0x03, 0x03, 0x02, 0x03, 0x01, + // Extension.extension_type (psk_key_exchange_modes) + 0x00, 0x2D, 0x00, 0x02, 0x01, 0x01, + // Extension.extension_type (key_share) + 0x00, 0x33, 0x00, 0x26, 0x00, 0x24, 0x00, 0x1D, + 0x00, 0x20, + 0x04, 0x01, 0x04, 0x03, 0xEE, 0xEE, 0xED, 0xED, + 0x03, 0x01, 0x03, 0x03, 0x02, 0x01, 0x02, 0x03, + 0x04, 0x01, 0x04, 0x03, 0xEE, 0xEE, 0xED, 0xED, + 0x03, 0x01, 0x03, 0x03, 0x02, 0x01, 0x02, 0x03 + }; + + private static byte[] valid_TlsClientHelloNoExtensions = + { + 0x16, 0x03, 0x03, 0x00, 0x39, 0x01, 0x00, 0x00, + 0x35, 0x03, 0x03, 0x62, 0x5d, 0x50, 0x2a, 0x41, + 0x2f, 0xd8, 0xc3, 0x65, 0x35, 0xea, 0x01, 0x70, + 0x03, 0x7e, 0x7e, 0x2d, 0xd4, 0xfe, 0x93, 0x39, + 0xa4, 0x04, 0x66, 0xbb, 0x46, 0x91, 0x41, 0xc3, + 0x48, 0x87, 0x3d, 0x00, 0x00, 0x0e, 0x00, 0x3d, + 0x00, 0x3c, 0x00, 0x0a, 0x00, 0x35, 0x00, 0x2f, + 0x00, 0x05, 0x00, 0x04, 0x01, 0x00 + }; + + private static byte[] invalid_TlsClientHelloHeader = + { + // Handshake - incorrect + 0x01, + // ProtocolVersion + 0x03, 0x04, + // SslPlainText.length + 0x00, 0xCB, + // Handshake.msg_type (client hello) + 0x01, + // Handshake.length + 0x00, 0x00, 0xC7, + }; + + private static byte[] invalid_3BytesMessage = + { + // Handshake + 0x016, + // Protocol Version + 0x03, 0x01, + // not enough data - so incorrect + }; + + private static byte[] invalid_UnknownProtocolVersion1 = + { + // Handshake + 0x016, + // ProtocolVersion - incorrect + 0x02, 0x05, + // SslPlainText.length + 0x00, 0xCB, + // Handshake.msg_type (client hello) + 0x01, + // Handshake.length + 0x00, 0x00, 0xC7, + }; + + private static byte[] invalid_UnknownProtocolVersion2 = + { + // Handshake + 0x016, + // ProtocolVersion - incorrect + 0x02, 0x01, + // SslPlainText.length + 0x00, 0xCB, + // Handshake.msg_type (client hello) + 0x01, + // Handshake.length + 0x00, 0x00, 0xC7, + }; + + private static byte[] invalid_IncorrectHandshakeMessageType = + { + // Handshake + 0x016, + // ProtocolVersion + 0x02, 0x00, + // SslPlainText.length + 0x00, 0xCB, + // Handshake.msg_type (client hello) - incorrect + 0x02, + // Handshake.length + 0x00, 0x00, 0xC7, + }; + + private static List invalidSegmented_TlsClientHelloHeader = new() + { + new byte[] { 0x01, 0x03, 0x04 }, + new byte[] { 0x00, 0xCB }, + new byte[] { 0x01 }, + new byte[] { 0x00, 0x00, 0xC7 }, + }; +} diff --git a/src/Servers/Kestrel/test/InMemory.FunctionalTests/TlsListenerMiddlewareTests.cs b/src/Servers/Kestrel/test/InMemory.FunctionalTests/TlsListenerMiddlewareTests.cs new file mode 100644 index 000000000000..0e42be60a225 --- /dev/null +++ b/src/Servers/Kestrel/test/InMemory.FunctionalTests/TlsListenerMiddlewareTests.cs @@ -0,0 +1,66 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using System.Collections.Generic; +using System.Net; +using System.Net.Security; +using System.Security.Authentication; +using System.Security.Cryptography.X509Certificates; +using System.Text; +using Microsoft.AspNetCore.Hosting; +using Microsoft.AspNetCore.Http; +using Microsoft.AspNetCore.InternalTesting; +using Microsoft.AspNetCore.Server.Kestrel.Core; +using Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Infrastructure; +using Microsoft.AspNetCore.Server.Kestrel.Https; +using Microsoft.AspNetCore.Server.Kestrel.InMemory.FunctionalTests.TestTransport; +using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.Hosting; +using Microsoft.Extensions.Logging; + +namespace InMemory.FunctionalTests; + +public partial class TlsListenerMiddlewareTests : TestApplicationErrorLoggerLoggedTest +{ + [Fact] + public async Task TlsClientHelloBytesCallback_InvokedAndHasTlsMessageBytes() + { + var tlsClientHelloCallbackInvoked = false; + + var testContext = new TestServiceContext(LoggerFactory); + await using (var server = new TestServer(context => Task.CompletedTask, + testContext, + listenOptions => + { + listenOptions.UseHttps((HttpsConnectionAdapterOptions options) => + { + options.TlsClientHelloBytesCallback = (connection, clientHelloBytes) => + { + Logger.LogDebug("[Received TlsClientHelloBytesCallback] Connection: {0}; TLS client hello buffer: {1}", connection.ConnectionId, clientHelloBytes.Length); + tlsClientHelloCallbackInvoked = true; + Assert.True(clientHelloBytes.Length > 32); + Assert.NotNull(connection); + }; + }); + })) + { + using (var connection = server.CreateConnection()) + { + using (var sslStream = new SslStream(connection.Stream, false, (sender, cert, chain, errors) => true, null)) + { + await sslStream.AuthenticateAsClientAsync(new SslClientAuthenticationOptions + { + TargetHost = "localhost", + EnabledSslProtocols = SslProtocols.None + }, CancellationToken.None); + + var request = Encoding.ASCII.GetBytes("GET / HTTP/1.1\r\nHost:\r\n\r\n"); + await sslStream.WriteAsync(request, 0, request.Length); + } + } + } + + Assert.True(tlsClientHelloCallbackInvoked); + } +} From c2a1d9e2bc532b746e74c9a0b0ca859f7c9d7b83 Mon Sep 17 00:00:00 2001 From: Dmitrii Korolev Date: Wed, 23 Apr 2025 15:11:54 +0200 Subject: [PATCH 02/20] prettify --- .../TlsListenerMiddlewareTests.Units.cs | 42 +++++++++++++++---- 1 file changed, 33 insertions(+), 9 deletions(-) diff --git a/src/Servers/Kestrel/test/InMemory.FunctionalTests/TlsListenerMiddlewareTests.Units.cs b/src/Servers/Kestrel/test/InMemory.FunctionalTests/TlsListenerMiddlewareTests.Units.cs index 696bc24fc898..7db334e82044 100644 --- a/src/Servers/Kestrel/test/InMemory.FunctionalTests/TlsListenerMiddlewareTests.Units.cs +++ b/src/Servers/Kestrel/test/InMemory.FunctionalTests/TlsListenerMiddlewareTests.Units.cs @@ -145,7 +145,7 @@ private async Task RunTlsClientHelloCallbackTest( public static IEnumerable ValidClientHelloData() { int id = 0; - foreach (var clientHello in new List() { valid_clientHelloHeader, valid_ClientHelloStandard, valid_Tls12ClientHello, valid_Tls13ClientHello, valid_TlsClientHelloNoExtensions }) + foreach (var clientHello in valid_collection) { yield return new object[] { id++, clientHello, true /* invokes next middleware */ }; } @@ -154,7 +154,7 @@ public static IEnumerable ValidClientHelloData() public static IEnumerable InvalidClientHelloData() { int id = 0; - foreach (byte[] clientHello in new List() { invalid_TlsClientHelloHeader, invalid_3BytesMessage, invalid_UnknownProtocolVersion1, invalid_UnknownProtocolVersion2, invalid_IncorrectHandshakeMessageType }) + foreach (byte[] clientHello in invalid_collection) { yield return new object[] { id++, clientHello, true /* invokes next middleware */ }; } @@ -163,7 +163,7 @@ public static IEnumerable InvalidClientHelloData() public static IEnumerable ValidClientHelloData_Segmented() { int id = 0; - foreach (var clientHello in new List() { valid_clientHelloHeader, valid_ClientHelloStandard, valid_Tls12ClientHello, valid_Tls13ClientHello, valid_TlsClientHelloNoExtensions }) + foreach (var clientHello in valid_collection) { var clientHelloSegments = new List { @@ -181,8 +181,30 @@ public static IEnumerable ValidClientHelloData_Segmented() public static IEnumerable InvalidClientHelloData_Segmented() { int id = 0; - foreach (List clientHelloSegments in new List>() { invalidSegmented_TlsClientHelloHeader }) + foreach (var clientHello in invalid_collection) { + var clientHelloSegments = new List(); + if (clientHello.Length >= 1) + { + clientHelloSegments.Add(clientHello.Take(1).ToArray()); + } + if (clientHello.Length >= 3) + { + clientHelloSegments.Add(clientHello.Skip(1).Take(2).ToArray()); + } + if (clientHello.Length >= 5) + { + clientHelloSegments.Add(clientHello.Skip(3).Take(2).ToArray()); + } + if (clientHello.Length >= 6) + { + clientHelloSegments.Add(clientHello.Skip(5).Take(1).ToArray()); + } + if (clientHello.Length >= 7) + { + clientHelloSegments.Add(clientHello.Skip(6).Take(clientHello.Length - 6).ToArray()); + } + yield return new object[] { id++, clientHelloSegments, true /* invokes next middleware */ }; } } @@ -500,11 +522,13 @@ public static IEnumerable InvalidClientHelloData_Segmented() 0x00, 0x00, 0xC7, }; - private static List invalidSegmented_TlsClientHelloHeader = new() + private static List valid_collection = new List() + { + valid_clientHelloHeader, valid_ClientHelloStandard, valid_Tls12ClientHello, valid_Tls13ClientHello, valid_TlsClientHelloNoExtensions + }; + + private static List invalid_collection = new List() { - new byte[] { 0x01, 0x03, 0x04 }, - new byte[] { 0x00, 0xCB }, - new byte[] { 0x01 }, - new byte[] { 0x00, 0x00, 0xC7 }, + invalid_TlsClientHelloHeader, invalid_3BytesMessage, invalid_UnknownProtocolVersion1, invalid_UnknownProtocolVersion2, invalid_IncorrectHandshakeMessageType }; } From 5a911646d4f36d3dee2e5ddc1b3a723171c3e29c Mon Sep 17 00:00:00 2001 From: Dmitrii Korolev Date: Wed, 23 Apr 2025 18:17:46 +0200 Subject: [PATCH 03/20] address comments --- .../src/Middleware/TlsListenerMiddleware.cs | 40 +++++++++---------- 1 file changed, 20 insertions(+), 20 deletions(-) diff --git a/src/Servers/Kestrel/Core/src/Middleware/TlsListenerMiddleware.cs b/src/Servers/Kestrel/Core/src/Middleware/TlsListenerMiddleware.cs index 096707adbcea..e147c40fca9d 100644 --- a/src/Servers/Kestrel/Core/src/Middleware/TlsListenerMiddleware.cs +++ b/src/Servers/Kestrel/Core/src/Middleware/TlsListenerMiddleware.cs @@ -41,20 +41,20 @@ internal async Task OnTlsClientHelloAsync(ConnectionContext connection) // no data is consumed, it will be processed by the follow-up middlewares input.AdvanceTo(buffer.Start); - switch (parseState) + if (parseState == ClientHelloParseState.NotEnoughData) { - case ClientHelloParseState.NotEnoughData: - continue; - - case ClientHelloParseState.NotTlsClientHello: - await _next(connection); - return; + continue; + } - case ClientHelloParseState.ValidTlsClientHello: - _tlsClientHelloBytesCallback(connection, clientHelloBytes); - await _next(connection); - return; + if (parseState == ClientHelloParseState.ValidTlsClientHello) + { + _tlsClientHelloBytesCallback(connection, clientHelloBytes); } + + // Here either it's a valid TLS client hello or definitely not a TLS client hello. + // Anyway we can continue with the middleware pipeline + await _next(connection); + break; } await _next(connection); @@ -78,7 +78,7 @@ private static ClientHelloParseState TryParseClientHello(ReadOnlySequence } // Protocol version - if (!reader.TryReadBigEndian(out short version) || IsValidProtocolVersion(version) == false) + if (!reader.TryReadBigEndian(out short version) || !IsValidProtocolVersion(version)) { return ClientHelloParseState.NotTlsClientHello; } @@ -109,14 +109,14 @@ private static ClientHelloParseState TryParseClientHello(ReadOnlySequence } private static bool IsValidProtocolVersion(short version) - => version == 0x0002 // SSL 2.0 (0x0002) - || version == 0x0300 // SSL 3.0 (0x0300) - || version == 0x0301 // TLS 1.0 (0x0301) - || version == 0x0302 // TLS 1.1 (0x0302) - || version == 0x0303 // TLS 1.2 (0x0303) - || version == 0x0304; // TLS 1.3 (0x0304) - - private enum ClientHelloParseState + => version is 0x0002 // SSL 2.0 (0x0002) + or 0x0300 // SSL 3.0 (0x0300) + or 0x0301 // TLS 1.0 (0x0301) + or 0x0302 // TLS 1.1 (0x0302) + or 0x0303 // TLS 1.2 (0x0303) + or 0x0304; // TLS 1.3 (0x0304) + + private enum ClientHelloParseState : byte { NotEnoughData, NotTlsClientHello, From 82019397f08ada56bebbd6920cff82eeb8372683 Mon Sep 17 00:00:00 2001 From: Dmitrii Korolev Date: Wed, 23 Apr 2025 18:20:03 +0200 Subject: [PATCH 04/20] fix duplicate next middleware --- src/Servers/Kestrel/Core/src/Middleware/TlsListenerMiddleware.cs | 1 - 1 file changed, 1 deletion(-) diff --git a/src/Servers/Kestrel/Core/src/Middleware/TlsListenerMiddleware.cs b/src/Servers/Kestrel/Core/src/Middleware/TlsListenerMiddleware.cs index e147c40fca9d..a08ca0d0b147 100644 --- a/src/Servers/Kestrel/Core/src/Middleware/TlsListenerMiddleware.cs +++ b/src/Servers/Kestrel/Core/src/Middleware/TlsListenerMiddleware.cs @@ -53,7 +53,6 @@ internal async Task OnTlsClientHelloAsync(ConnectionContext connection) // Here either it's a valid TLS client hello or definitely not a TLS client hello. // Anyway we can continue with the middleware pipeline - await _next(connection); break; } From 8e5ed75dd8d1fdc6e8f27a51bbc6f48717edd2fb Mon Sep 17 00:00:00 2001 From: Dmitrii Korolev Date: Wed, 23 Apr 2025 18:52:13 +0200 Subject: [PATCH 05/20] nit --- .../Kestrel/Core/src/Middleware/TlsListenerMiddleware.cs | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/Servers/Kestrel/Core/src/Middleware/TlsListenerMiddleware.cs b/src/Servers/Kestrel/Core/src/Middleware/TlsListenerMiddleware.cs index a08ca0d0b147..5408ca53c95f 100644 --- a/src/Servers/Kestrel/Core/src/Middleware/TlsListenerMiddleware.cs +++ b/src/Servers/Kestrel/Core/src/Middleware/TlsListenerMiddleware.cs @@ -2,6 +2,7 @@ // The .NET Foundation licenses this file to you under the MIT license. using System.Buffers; +using System.Diagnostics; using Microsoft.AspNetCore.Connections; namespace Microsoft.AspNetCore.Server.Kestrel.Core.Middleware; @@ -53,6 +54,7 @@ internal async Task OnTlsClientHelloAsync(ConnectionContext connection) // Here either it's a valid TLS client hello or definitely not a TLS client hello. // Anyway we can continue with the middleware pipeline + Debug.Assert(parseState is ClientHelloParseState.ValidTlsClientHello or ClientHelloParseState.NotTlsClientHello); break; } From 48e07f86b7e6aed42b3934d6f207777f0ab66084 Mon Sep 17 00:00:00 2001 From: Dmitrii Korolev Date: Wed, 23 Apr 2025 18:52:59 +0200 Subject: [PATCH 06/20] nit nit --- .../Kestrel/Core/src/Middleware/TlsListenerMiddleware.cs | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/Servers/Kestrel/Core/src/Middleware/TlsListenerMiddleware.cs b/src/Servers/Kestrel/Core/src/Middleware/TlsListenerMiddleware.cs index 5408ca53c95f..c5c3a90788b8 100644 --- a/src/Servers/Kestrel/Core/src/Middleware/TlsListenerMiddleware.cs +++ b/src/Servers/Kestrel/Core/src/Middleware/TlsListenerMiddleware.cs @@ -52,10 +52,8 @@ internal async Task OnTlsClientHelloAsync(ConnectionContext connection) _tlsClientHelloBytesCallback(connection, clientHelloBytes); } - // Here either it's a valid TLS client hello or definitely not a TLS client hello. - // Anyway we can continue with the middleware pipeline Debug.Assert(parseState is ClientHelloParseState.ValidTlsClientHello or ClientHelloParseState.NotTlsClientHello); - break; + break; // We can continue with the middleware pipeline } await _next(connection); From 1fcffb864863f48b0f61a97cbc5611b7e3fbac50 Mon Sep 17 00:00:00 2001 From: Dmitrii Korolev Date: Wed, 23 Apr 2025 21:46:21 +0200 Subject: [PATCH 07/20] comments 1 --- .../Core/src/ListenOptionsHttpsExtensions.cs | 9 ++++ .../src/Middleware/TlsListenerMiddleware.cs | 51 +++++++++++-------- .../TlsListenerMiddlewareTests.cs | 1 + 3 files changed, 41 insertions(+), 20 deletions(-) diff --git a/src/Servers/Kestrel/Core/src/ListenOptionsHttpsExtensions.cs b/src/Servers/Kestrel/Core/src/ListenOptionsHttpsExtensions.cs index 42d7ac8f0476..4003e09474b0 100644 --- a/src/Servers/Kestrel/Core/src/ListenOptionsHttpsExtensions.cs +++ b/src/Servers/Kestrel/Core/src/ListenOptionsHttpsExtensions.cs @@ -270,6 +270,15 @@ public static ListenOptions UseHttps(this ListenOptions listenOptions, TlsHandsh listenOptions.IsTls = true; listenOptions.HttpsCallbackOptions = callbackOptions; + if (listenOptions.HttpsOptions?.TlsClientHelloBytesCallback is not null) + { + listenOptions.Use(next => + { + var middleware = new TlsListenerMiddleware(next, listenOptions.HttpsOptions.TlsClientHelloBytesCallback); + return middleware.OnTlsClientHelloAsync; + }); + } + listenOptions.Use(next => { // Set the list of protocols from listen options. diff --git a/src/Servers/Kestrel/Core/src/Middleware/TlsListenerMiddleware.cs b/src/Servers/Kestrel/Core/src/Middleware/TlsListenerMiddleware.cs index c5c3a90788b8..a5597da5dee6 100644 --- a/src/Servers/Kestrel/Core/src/Middleware/TlsListenerMiddleware.cs +++ b/src/Servers/Kestrel/Core/src/Middleware/TlsListenerMiddleware.cs @@ -30,35 +30,46 @@ internal async Task OnTlsClientHelloAsync(ConnectionContext connection) var result = await input.ReadAsync(); var buffer = result.Buffer; - // If the buffer length is less than 6 bytes (handshake + version + length + client-hello byte) - // and no more data is coming, we can't block in a loop here because we will not get more data - if (buffer.Length < 6 && result.IsCompleted) + try { - break; + // If the buffer length is less than 6 bytes (handshake + version + length + client-hello byte) + // and no more data is coming, we can't block in a loop here because we will not get more data + if (buffer.Length < 6 && result.IsCompleted) + { + break; + } + + var parseState = TryParseClientHello(buffer, out var clientHelloBytes); + + if (parseState == ClientHelloParseState.NotEnoughData) + { + continue; + } + + if (parseState == ClientHelloParseState.ValidTlsClientHello) + { + _tlsClientHelloBytesCallback(connection, clientHelloBytes); + } + + Debug.Assert(parseState is ClientHelloParseState.ValidTlsClientHello or ClientHelloParseState.NotTlsClientHello); + break; // We can continue with the middleware pipeline } - - var parseState = TryParseClientHello(buffer, out var clientHelloBytes); - - // no data is consumed, it will be processed by the follow-up middlewares - input.AdvanceTo(buffer.Start); - - if (parseState == ClientHelloParseState.NotEnoughData) + finally { - continue; + input.AdvanceTo(buffer.Start); } - - if (parseState == ClientHelloParseState.ValidTlsClientHello) - { - _tlsClientHelloBytesCallback(connection, clientHelloBytes); - } - - Debug.Assert(parseState is ClientHelloParseState.ValidTlsClientHello or ClientHelloParseState.NotTlsClientHello); - break; // We can continue with the middleware pipeline } await _next(connection); } + /// + /// RFCs + /// ---- + /// TLS 1.1: https://datatracker.ietf.org/doc/html/rfc4346#section-6.2 + /// TLS 1.2: https://datatracker.ietf.org/doc/html/rfc5246#section-6.2 + /// TLS 1.3: https://datatracker.ietf.org/doc/html/rfc8446#section-5.1 + /// private static ClientHelloParseState TryParseClientHello(ReadOnlySequence buffer, out ReadOnlySequence clientHelloBytes) { clientHelloBytes = default; diff --git a/src/Servers/Kestrel/test/InMemory.FunctionalTests/TlsListenerMiddlewareTests.cs b/src/Servers/Kestrel/test/InMemory.FunctionalTests/TlsListenerMiddlewareTests.cs index 0e42be60a225..da064aaa2c03 100644 --- a/src/Servers/Kestrel/test/InMemory.FunctionalTests/TlsListenerMiddlewareTests.cs +++ b/src/Servers/Kestrel/test/InMemory.FunctionalTests/TlsListenerMiddlewareTests.cs @@ -57,6 +57,7 @@ await sslStream.AuthenticateAsClientAsync(new SslClientAuthenticationOptions var request = Encoding.ASCII.GetBytes("GET / HTTP/1.1\r\nHost:\r\n\r\n"); await sslStream.WriteAsync(request, 0, request.Length); + await sslStream.ReadAsync(new Memory(new byte[1024])); } } } From 058116345dad53d4c3bb3b6284827c354a79f823 Mon Sep 17 00:00:00 2001 From: Dmitrii Korolev Date: Wed, 23 Apr 2025 22:21:45 +0200 Subject: [PATCH 08/20] exit even if not enough data --- .../src/Middleware/TlsListenerMiddleware.cs | 9 ++++++++- .../TlsListenerMiddlewareTests.Units.cs | 18 +++++++++++++++++- 2 files changed, 25 insertions(+), 2 deletions(-) diff --git a/src/Servers/Kestrel/Core/src/Middleware/TlsListenerMiddleware.cs b/src/Servers/Kestrel/Core/src/Middleware/TlsListenerMiddleware.cs index a5597da5dee6..44fba138e475 100644 --- a/src/Servers/Kestrel/Core/src/Middleware/TlsListenerMiddleware.cs +++ b/src/Servers/Kestrel/Core/src/Middleware/TlsListenerMiddleware.cs @@ -34,7 +34,7 @@ internal async Task OnTlsClientHelloAsync(ConnectionContext connection) { // If the buffer length is less than 6 bytes (handshake + version + length + client-hello byte) // and no more data is coming, we can't block in a loop here because we will not get more data - if (buffer.Length < 6 && result.IsCompleted) + if (result.IsCompleted && buffer.Length < 6) { break; } @@ -43,6 +43,13 @@ internal async Task OnTlsClientHelloAsync(ConnectionContext connection) if (parseState == ClientHelloParseState.NotEnoughData) { + // if no data will be added, and we still lack enough bytes + // we can't block in a loop, so just exit + if (result.IsCompleted) + { + break; + } + continue; } diff --git a/src/Servers/Kestrel/test/InMemory.FunctionalTests/TlsListenerMiddlewareTests.Units.cs b/src/Servers/Kestrel/test/InMemory.FunctionalTests/TlsListenerMiddlewareTests.Units.cs index 7db334e82044..60016cafde8e 100644 --- a/src/Servers/Kestrel/test/InMemory.FunctionalTests/TlsListenerMiddlewareTests.Units.cs +++ b/src/Servers/Kestrel/test/InMemory.FunctionalTests/TlsListenerMiddlewareTests.Units.cs @@ -480,6 +480,21 @@ public static IEnumerable InvalidClientHelloData_Segmented() // not enough data - so incorrect }; + private static byte[] invalid_9BytesMessage = + { + // 0x16 = Handshake + 0x16, + // 0x0301 = TLS 1.0 + 0x03, 0x01, + // length = 0x0020 (32 bytes) + 0x00, 0x20, + // Handshake.msg_type (client hello) + 0x01, + // should have 31 bytes (zeros for simplicity) + 0, 0, 0 + // no other data here - incorrect + }; + private static byte[] invalid_UnknownProtocolVersion1 = { // Handshake @@ -529,6 +544,7 @@ public static IEnumerable InvalidClientHelloData_Segmented() private static List invalid_collection = new List() { - invalid_TlsClientHelloHeader, invalid_3BytesMessage, invalid_UnknownProtocolVersion1, invalid_UnknownProtocolVersion2, invalid_IncorrectHandshakeMessageType + invalid_TlsClientHelloHeader, invalid_3BytesMessage, invalid_9BytesMessage, + invalid_UnknownProtocolVersion1, invalid_UnknownProtocolVersion2, invalid_IncorrectHandshakeMessageType }; } From 2bbd480261dd29ce5fa38e31d22492820176ef03 Mon Sep 17 00:00:00 2001 From: Dmitrii Korolev Date: Wed, 23 Apr 2025 23:49:19 +0200 Subject: [PATCH 09/20] dont re-read the same data --- .../src/Middleware/TlsListenerMiddleware.cs | 14 +++-- .../InMemoryTransportConnection.cs | 8 ++- .../TlsListenerMiddlewareTests.Units.cs | 54 +++++++++++++++++++ 3 files changed, 71 insertions(+), 5 deletions(-) diff --git a/src/Servers/Kestrel/Core/src/Middleware/TlsListenerMiddleware.cs b/src/Servers/Kestrel/Core/src/Middleware/TlsListenerMiddleware.cs index 44fba138e475..fb1be8a09137 100644 --- a/src/Servers/Kestrel/Core/src/Middleware/TlsListenerMiddleware.cs +++ b/src/Servers/Kestrel/Core/src/Middleware/TlsListenerMiddleware.cs @@ -24,6 +24,7 @@ public TlsListenerMiddleware(ConnectionDelegate next, Action _reader.ReadAsyncCounter; + private class ObservablePipeReader : PipeReader { private readonly PipeReader _reader; private readonly TaskCompletionSource _tcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + public int ReadAsyncCounter { get; private set; } = 0; public Task WaitForReadTask => _tcs.Task; public ObservablePipeReader(PipeReader reader) @@ -144,6 +147,7 @@ public override void Complete(Exception exception = null) public override ValueTask ReadAsync(CancellationToken cancellationToken = default) { + ReadAsyncCounter++; var task = _reader.ReadAsync(cancellationToken); if (_tcs.Task.IsCompleted) @@ -152,7 +156,7 @@ public override ValueTask ReadAsync(CancellationToken cancellationTo } return new ValueTask(new ObservableValueTask(task, _tcs), 0); - } + } public override bool TryRead(out ReadResult result) { diff --git a/src/Servers/Kestrel/test/InMemory.FunctionalTests/TlsListenerMiddlewareTests.Units.cs b/src/Servers/Kestrel/test/InMemory.FunctionalTests/TlsListenerMiddlewareTests.Units.cs index 60016cafde8e..5063a8149fc7 100644 --- a/src/Servers/Kestrel/test/InMemory.FunctionalTests/TlsListenerMiddlewareTests.Units.cs +++ b/src/Servers/Kestrel/test/InMemory.FunctionalTests/TlsListenerMiddlewareTests.Units.cs @@ -25,6 +25,7 @@ using Microsoft.Extensions.Hosting; using Microsoft.Extensions.Logging; using Moq; +using static Microsoft.AspNetCore.Server.Kestrel.InMemory.FunctionalTests.TestTransport.InMemoryTransportConnection; namespace InMemory.FunctionalTests; @@ -50,6 +51,59 @@ public Task OnTlsClientHelloAsync_ValidData_MultipleSegments(int id, List packets, bool nextMiddlewareInvoked) => RunTlsClientHelloCallbackTest_WithMultipleSegments(id, packets, nextMiddlewareInvoked, tlsClientHelloCallbackExpected: false); + [Fact] + public async Task RunTlsClientHelloCallbackTest_DeterministinglyReads() + { + var serviceContext = new TestServiceContext(LoggerFactory); + var logger = LoggerFactory.CreateLogger(); + var memoryPool = serviceContext.MemoryPoolFactory(); + var transportConnection = new InMemoryTransportConnection(memoryPool, logger); + + var nextMiddlewareInvoked = false; + var tlsClientHelloCallbackInvoked = false; + + var middleware = new TlsListenerMiddleware( + next: ctx => + { + nextMiddlewareInvoked = true; + var readResult = ctx.Transport.Input.ReadAsync(); + Assert.Equal(6, readResult.Result.Buffer.Length); + + return Task.CompletedTask; + }, + tlsClientHelloBytesCallback: (ctx, data) => + { + tlsClientHelloCallbackInvoked = true; + } + ); + + await transportConnection.Input.WriteAsync(new byte[1] { 0x16 }); + var middlewareTask = Task.Run(() => middleware.OnTlsClientHelloAsync(transportConnection)); + await Task.Delay(TimeSpan.FromMilliseconds(25)); + + await transportConnection.Input.WriteAsync(new byte[2] { 0x03, 0x01 }); + await Task.Delay(TimeSpan.FromMilliseconds(25)); + + await transportConnection.Input.WriteAsync(new byte[2] { 0x00, 0x20 }); + await Task.Delay(TimeSpan.FromMilliseconds(25)); + + // not correct TLS client hello byte; + // meaning we will not invoke the callback and advance request processing + await transportConnection.Input.WriteAsync(new byte[1] { 0x15 }); + await Task.Delay(TimeSpan.FromMilliseconds(25)); + + await transportConnection.Input.CompleteAsync(); + + // ensuring that we have read only 5 times (ReadAsync() is called 5 times) + var observableTransport = transportConnection.Transport as ObservableDuplexPipe; + Assert.NotNull(observableTransport); + Assert.Equal(5, observableTransport.ReadAsyncCounter); + + await middlewareTask; + Assert.True(nextMiddlewareInvoked); + Assert.False(tlsClientHelloCallbackInvoked); + } + private async Task RunTlsClientHelloCallbackTest_WithMultipleSegments( int id, List packets, From 6f9a6ded4e86b51652af7f06c6bb2bb878ce14bd Mon Sep 17 00:00:00 2001 From: Dmitrii Korolev Date: Thu, 24 Apr 2025 15:27:01 +0200 Subject: [PATCH 10/20] increase delays --- .../TlsListenerMiddlewareTests.Units.cs | 17 ++++++----------- 1 file changed, 6 insertions(+), 11 deletions(-) diff --git a/src/Servers/Kestrel/test/InMemory.FunctionalTests/TlsListenerMiddlewareTests.Units.cs b/src/Servers/Kestrel/test/InMemory.FunctionalTests/TlsListenerMiddlewareTests.Units.cs index 5063a8149fc7..1f4be258f178 100644 --- a/src/Servers/Kestrel/test/InMemory.FunctionalTests/TlsListenerMiddlewareTests.Units.cs +++ b/src/Servers/Kestrel/test/InMemory.FunctionalTests/TlsListenerMiddlewareTests.Units.cs @@ -67,7 +67,7 @@ public async Task RunTlsClientHelloCallbackTest_DeterministinglyReads() { nextMiddlewareInvoked = true; var readResult = ctx.Transport.Input.ReadAsync(); - Assert.Equal(6, readResult.Result.Buffer.Length); + Assert.Equal(5, readResult.Result.Buffer.Length); return Task.CompletedTask; }, @@ -79,25 +79,20 @@ public async Task RunTlsClientHelloCallbackTest_DeterministinglyReads() await transportConnection.Input.WriteAsync(new byte[1] { 0x16 }); var middlewareTask = Task.Run(() => middleware.OnTlsClientHelloAsync(transportConnection)); - await Task.Delay(TimeSpan.FromMilliseconds(25)); + await Task.Delay(TimeSpan.FromMilliseconds(75)); await transportConnection.Input.WriteAsync(new byte[2] { 0x03, 0x01 }); - await Task.Delay(TimeSpan.FromMilliseconds(25)); + await Task.Delay(TimeSpan.FromMilliseconds(75)); await transportConnection.Input.WriteAsync(new byte[2] { 0x00, 0x20 }); - await Task.Delay(TimeSpan.FromMilliseconds(25)); - - // not correct TLS client hello byte; - // meaning we will not invoke the callback and advance request processing - await transportConnection.Input.WriteAsync(new byte[1] { 0x15 }); - await Task.Delay(TimeSpan.FromMilliseconds(25)); + await Task.Delay(TimeSpan.FromMilliseconds(75)); await transportConnection.Input.CompleteAsync(); - // ensuring that we have read only 5 times (ReadAsync() is called 5 times) + // ensuring that we have read only 4 times (ReadAsync() is called 4 times) var observableTransport = transportConnection.Transport as ObservableDuplexPipe; Assert.NotNull(observableTransport); - Assert.Equal(5, observableTransport.ReadAsyncCounter); + Assert.Equal(4, observableTransport.ReadAsyncCounter); await middlewareTask; Assert.True(nextMiddlewareInvoked); From 4f6bc2ba7f452ff343a88fc92ea09250df38b995 Mon Sep 17 00:00:00 2001 From: Dmitrii Korolev Date: Thu, 24 Apr 2025 15:27:49 +0200 Subject: [PATCH 11/20] await the middleware --- .../TlsListenerMiddlewareTests.Units.cs | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/Servers/Kestrel/test/InMemory.FunctionalTests/TlsListenerMiddlewareTests.Units.cs b/src/Servers/Kestrel/test/InMemory.FunctionalTests/TlsListenerMiddlewareTests.Units.cs index 1f4be258f178..f00abf55ae6c 100644 --- a/src/Servers/Kestrel/test/InMemory.FunctionalTests/TlsListenerMiddlewareTests.Units.cs +++ b/src/Servers/Kestrel/test/InMemory.FunctionalTests/TlsListenerMiddlewareTests.Units.cs @@ -89,14 +89,14 @@ public async Task RunTlsClientHelloCallbackTest_DeterministinglyReads() await transportConnection.Input.CompleteAsync(); + await middlewareTask; + Assert.True(nextMiddlewareInvoked); + Assert.False(tlsClientHelloCallbackInvoked); + // ensuring that we have read only 4 times (ReadAsync() is called 4 times) var observableTransport = transportConnection.Transport as ObservableDuplexPipe; Assert.NotNull(observableTransport); Assert.Equal(4, observableTransport.ReadAsyncCounter); - - await middlewareTask; - Assert.True(nextMiddlewareInvoked); - Assert.False(tlsClientHelloCallbackInvoked); } private async Task RunTlsClientHelloCallbackTest_WithMultipleSegments( From ab06730ad3632e6ab14c8c3ef03d0d18ff30e3ac Mon Sep 17 00:00:00 2001 From: Dmitrii Korolev Date: Thu, 24 Apr 2025 17:49:13 +0200 Subject: [PATCH 12/20] another test fix --- .../TlsListenerMiddlewareTests.Units.cs | 4 ++-- .../InMemory.FunctionalTests/TlsListenerMiddlewareTests.cs | 6 ++++-- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/src/Servers/Kestrel/test/InMemory.FunctionalTests/TlsListenerMiddlewareTests.Units.cs b/src/Servers/Kestrel/test/InMemory.FunctionalTests/TlsListenerMiddlewareTests.Units.cs index f00abf55ae6c..8ec2615f9781 100644 --- a/src/Servers/Kestrel/test/InMemory.FunctionalTests/TlsListenerMiddlewareTests.Units.cs +++ b/src/Servers/Kestrel/test/InMemory.FunctionalTests/TlsListenerMiddlewareTests.Units.cs @@ -93,10 +93,10 @@ public async Task RunTlsClientHelloCallbackTest_DeterministinglyReads() Assert.True(nextMiddlewareInvoked); Assert.False(tlsClientHelloCallbackInvoked); - // ensuring that we have read only 4 times (ReadAsync() is called 4 times) + // ensuring that we have read limited number of times var observableTransport = transportConnection.Transport as ObservableDuplexPipe; Assert.NotNull(observableTransport); - Assert.Equal(4, observableTransport.ReadAsyncCounter); + Assert.True(observableTransport.ReadAsyncCounter is > 2 && observableTransport.ReadAsyncCounter is <= 5); } private async Task RunTlsClientHelloCallbackTest_WithMultipleSegments( diff --git a/src/Servers/Kestrel/test/InMemory.FunctionalTests/TlsListenerMiddlewareTests.cs b/src/Servers/Kestrel/test/InMemory.FunctionalTests/TlsListenerMiddlewareTests.cs index da064aaa2c03..5d93200d5bfa 100644 --- a/src/Servers/Kestrel/test/InMemory.FunctionalTests/TlsListenerMiddlewareTests.cs +++ b/src/Servers/Kestrel/test/InMemory.FunctionalTests/TlsListenerMiddlewareTests.cs @@ -23,6 +23,8 @@ namespace InMemory.FunctionalTests; public partial class TlsListenerMiddlewareTests : TestApplicationErrorLoggerLoggedTest { + private static readonly X509Certificate2 _x509Certificate2 = TestResources.GetTestCertificate(); + [Fact] public async Task TlsClientHelloBytesCallback_InvokedAndHasTlsMessageBytes() { @@ -33,9 +35,9 @@ public async Task TlsClientHelloBytesCallback_InvokedAndHasTlsMessageBytes() testContext, listenOptions => { - listenOptions.UseHttps((HttpsConnectionAdapterOptions options) => + listenOptions.UseHttps(_x509Certificate2, httpsOptions => { - options.TlsClientHelloBytesCallback = (connection, clientHelloBytes) => + httpsOptions.TlsClientHelloBytesCallback = (connection, clientHelloBytes) => { Logger.LogDebug("[Received TlsClientHelloBytesCallback] Connection: {0}; TLS client hello buffer: {1}", connection.ConnectionId, clientHelloBytes.Length); tlsClientHelloCallbackInvoked = true; From 898c41543aa72f1c52e7f7566906ad3abb743eac Mon Sep 17 00:00:00 2001 From: Dmitrii Korolev Date: Fri, 25 Apr 2025 12:25:10 +0200 Subject: [PATCH 13/20] move tls client hello check --- .../Core/src/Middleware/TlsListenerMiddleware.cs | 12 ++++++------ .../TlsListenerMiddlewareTests.Units.cs | 8 ++------ 2 files changed, 8 insertions(+), 12 deletions(-) diff --git a/src/Servers/Kestrel/Core/src/Middleware/TlsListenerMiddleware.cs b/src/Servers/Kestrel/Core/src/Middleware/TlsListenerMiddleware.cs index fb1be8a09137..32e86f8e13fc 100644 --- a/src/Servers/Kestrel/Core/src/Middleware/TlsListenerMiddleware.cs +++ b/src/Servers/Kestrel/Core/src/Middleware/TlsListenerMiddleware.cs @@ -114,6 +114,12 @@ private static ClientHelloParseState TryParseClientHello(ReadOnlySequence return ClientHelloParseState.NotTlsClientHello; } + // byte 6: handshake message type (must be 0x01 for ClientHello) + if (!reader.TryRead(out byte handshakeType) || handshakeType != 0x01) + { + return ClientHelloParseState.NotTlsClientHello; + } + // 5 bytes are // 1) Handshake (1 byte) // 2) Protocol version (2 bytes) @@ -123,12 +129,6 @@ private static ClientHelloParseState TryParseClientHello(ReadOnlySequence return ClientHelloParseState.NotEnoughData; } - // byte 6: handshake message type (must be 0x01 for ClientHello) - if (!reader.TryRead(out byte handshakeType) || handshakeType != 0x01) - { - return ClientHelloParseState.NotTlsClientHello; - } - clientHelloBytes = buffer.Slice(0, 5 + recordLength); return ClientHelloParseState.ValidTlsClientHello; } diff --git a/src/Servers/Kestrel/test/InMemory.FunctionalTests/TlsListenerMiddlewareTests.Units.cs b/src/Servers/Kestrel/test/InMemory.FunctionalTests/TlsListenerMiddlewareTests.Units.cs index 8ec2615f9781..82572f2482f3 100644 --- a/src/Servers/Kestrel/test/InMemory.FunctionalTests/TlsListenerMiddlewareTests.Units.cs +++ b/src/Servers/Kestrel/test/InMemory.FunctionalTests/TlsListenerMiddlewareTests.Units.cs @@ -79,13 +79,8 @@ public async Task RunTlsClientHelloCallbackTest_DeterministinglyReads() await transportConnection.Input.WriteAsync(new byte[1] { 0x16 }); var middlewareTask = Task.Run(() => middleware.OnTlsClientHelloAsync(transportConnection)); - await Task.Delay(TimeSpan.FromMilliseconds(75)); - await transportConnection.Input.WriteAsync(new byte[2] { 0x03, 0x01 }); - await Task.Delay(TimeSpan.FromMilliseconds(75)); - await transportConnection.Input.WriteAsync(new byte[2] { 0x00, 0x20 }); - await Task.Delay(TimeSpan.FromMilliseconds(75)); await transportConnection.Input.CompleteAsync(); @@ -96,7 +91,8 @@ public async Task RunTlsClientHelloCallbackTest_DeterministinglyReads() // ensuring that we have read limited number of times var observableTransport = transportConnection.Transport as ObservableDuplexPipe; Assert.NotNull(observableTransport); - Assert.True(observableTransport.ReadAsyncCounter is > 2 && observableTransport.ReadAsyncCounter is <= 5); + Assert.True(observableTransport.ReadAsyncCounter is >= 2 && observableTransport.ReadAsyncCounter is <= 5, + $"Expected ReadAsync() to happen about 2-5 times. Actually happened {observableTransport.ReadAsyncCounter} times."); } private async Task RunTlsClientHelloCallbackTest_WithMultipleSegments( From 4f87edd2128cbea14e9da4e5089d23ac2adb3f57 Mon Sep 17 00:00:00 2001 From: Dmitrii Korolev Date: Fri, 25 Apr 2025 12:58:05 +0200 Subject: [PATCH 14/20] move tests --- .../test/TestHelpers/ObservablePipeReader.cs | 47 +++++++++++++ .../test/TlsListenerMiddlewareTests.cs} | 69 ++++++++++--------- .../TlsListenerMiddlewareTests.cs | 2 +- 3 files changed, 86 insertions(+), 32 deletions(-) create mode 100644 src/Servers/Kestrel/Core/test/TestHelpers/ObservablePipeReader.cs rename src/Servers/Kestrel/{test/InMemory.FunctionalTests/TlsListenerMiddlewareTests.Units.cs => Core/test/TlsListenerMiddlewareTests.cs} (90%) diff --git a/src/Servers/Kestrel/Core/test/TestHelpers/ObservablePipeReader.cs b/src/Servers/Kestrel/Core/test/TestHelpers/ObservablePipeReader.cs new file mode 100644 index 000000000000..299d6491b46b --- /dev/null +++ b/src/Servers/Kestrel/Core/test/TestHelpers/ObservablePipeReader.cs @@ -0,0 +1,47 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using System.Collections.Generic; +using System.IO.Pipelines; +using System.Text; + +namespace Microsoft.AspNetCore.Server.Kestrel.Core.Tests.TestHelpers; + +internal class ObservablePipeReader : PipeReader +{ + private readonly PipeReader _inner; + + public ObservablePipeReader(PipeReader reader) + { + _inner = reader; + } + + /// + /// Number of times was called. + /// + public int ReadAsyncCounter { get; private set; } + + public override void AdvanceTo(SequencePosition consumed) + => _inner.AdvanceTo(consumed); + + public override void AdvanceTo(SequencePosition consumed, SequencePosition examined) + => _inner.AdvanceTo(consumed, examined); + + public override void CancelPendingRead() + => _inner.CancelPendingRead(); + + public override void Complete(Exception exception = null) + => _inner.Complete(exception); + + public override ValueTask ReadAsync(CancellationToken cancellationToken = default) + { + ReadAsyncCounter++; + return _inner.ReadAsync(cancellationToken); + } + + public override bool TryRead(out ReadResult result) + { + return _inner.TryRead(out result); + } +} diff --git a/src/Servers/Kestrel/test/InMemory.FunctionalTests/TlsListenerMiddlewareTests.Units.cs b/src/Servers/Kestrel/Core/test/TlsListenerMiddlewareTests.cs similarity index 90% rename from src/Servers/Kestrel/test/InMemory.FunctionalTests/TlsListenerMiddlewareTests.Units.cs rename to src/Servers/Kestrel/Core/test/TlsListenerMiddlewareTests.cs index 82572f2482f3..1008447e9fec 100644 --- a/src/Servers/Kestrel/test/InMemory.FunctionalTests/TlsListenerMiddlewareTests.Units.cs +++ b/src/Servers/Kestrel/Core/test/TlsListenerMiddlewareTests.cs @@ -19,15 +19,14 @@ using Microsoft.AspNetCore.Server.Kestrel.Core; using Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Infrastructure; using Microsoft.AspNetCore.Server.Kestrel.Core.Middleware; +using Microsoft.AspNetCore.Server.Kestrel.Core.Tests.TestHelpers; using Microsoft.AspNetCore.Server.Kestrel.Https; -using Microsoft.AspNetCore.Server.Kestrel.InMemory.FunctionalTests.TestTransport; using Microsoft.Extensions.DependencyInjection; using Microsoft.Extensions.Hosting; using Microsoft.Extensions.Logging; using Moq; -using static Microsoft.AspNetCore.Server.Kestrel.InMemory.FunctionalTests.TestTransport.InMemoryTransportConnection; -namespace InMemory.FunctionalTests; +namespace Microsoft.AspNetCore.Server.Kestrel.Core.Tests; public partial class TlsListenerMiddlewareTests { @@ -54,10 +53,14 @@ public Task OnTlsClientHelloAsync_InvalidData_MultipleSegments(int id, List(); - var memoryPool = serviceContext.MemoryPoolFactory(); - var transportConnection = new InMemoryTransportConnection(memoryPool, logger); + var serviceContext = new TestServiceContext(); + + var pipe = new Pipe(); + var writer = pipe.Writer; + var reader = new ObservablePipeReader(pipe.Reader); + + var transport = new DuplexPipe(reader, writer); + var transportConnection = new DefaultConnectionContext("test", transport, transport); var nextMiddlewareInvoked = false; var tlsClientHelloCallbackInvoked = false; @@ -77,22 +80,19 @@ public async Task RunTlsClientHelloCallbackTest_DeterministinglyReads() } ); - await transportConnection.Input.WriteAsync(new byte[1] { 0x16 }); + await writer.WriteAsync(new byte[1] { 0x16 }); var middlewareTask = Task.Run(() => middleware.OnTlsClientHelloAsync(transportConnection)); - await transportConnection.Input.WriteAsync(new byte[2] { 0x03, 0x01 }); - await transportConnection.Input.WriteAsync(new byte[2] { 0x00, 0x20 }); - - await transportConnection.Input.CompleteAsync(); + await writer.WriteAsync(new byte[2] { 0x03, 0x01 }); + await writer.WriteAsync(new byte[2] { 0x00, 0x20 }); + await writer.CompleteAsync(); await middlewareTask; Assert.True(nextMiddlewareInvoked); Assert.False(tlsClientHelloCallbackInvoked); // ensuring that we have read limited number of times - var observableTransport = transportConnection.Transport as ObservableDuplexPipe; - Assert.NotNull(observableTransport); - Assert.True(observableTransport.ReadAsyncCounter is >= 2 && observableTransport.ReadAsyncCounter is <= 5, - $"Expected ReadAsync() to happen about 2-5 times. Actually happened {observableTransport.ReadAsyncCounter} times."); + Assert.True(reader.ReadAsyncCounter is >= 2 && reader.ReadAsyncCounter is <= 5, + $"Expected ReadAsync() to happen about 2-5 times. Actually happened {reader.ReadAsyncCounter} times."); } private async Task RunTlsClientHelloCallbackTest_WithMultipleSegments( @@ -101,10 +101,12 @@ private async Task RunTlsClientHelloCallbackTest_WithMultipleSegments( bool nextMiddlewareInvokedExpected, bool tlsClientHelloCallbackExpected) { - var serviceContext = new TestServiceContext(LoggerFactory); - var logger = LoggerFactory.CreateLogger(); - var memoryPool = serviceContext.MemoryPoolFactory(); - var transportConnection = new InMemoryTransportConnection(memoryPool, logger); + var pipe = new Pipe(); + var writer = pipe.Writer; + var reader = new ObservablePipeReader(pipe.Reader); + + var transport = new DuplexPipe(reader, writer); + var transportConnection = new DefaultConnectionContext("test", transport, transport); var nextMiddlewareInvokedActual = false; var tlsClientHelloCallbackActual = false; @@ -112,7 +114,7 @@ private async Task RunTlsClientHelloCallbackTest_WithMultipleSegments( var fullLength = packets.Sum(p => p.Length); var middleware = new TlsListenerMiddleware( - next: _ => + next: ctx => { nextMiddlewareInvokedActual = true; return Task.CompletedTask; @@ -128,7 +130,7 @@ private async Task RunTlsClientHelloCallbackTest_WithMultipleSegments( ); // write first packet - await transportConnection.Input.WriteAsync(packets[0]); + await writer.WriteAsync(packets[0]); var middlewareTask = Task.Run(() => middleware.OnTlsClientHelloAsync(transportConnection)); var random = new Random(); @@ -137,10 +139,10 @@ private async Task RunTlsClientHelloCallbackTest_WithMultipleSegments( // write all next packets foreach (var packet in packets.Skip(1)) { - await transportConnection.Input.WriteAsync(packet); + await writer.WriteAsync(packet); await Task.Delay(millisecondsDelay: random.Next(25, 75)); } - await transportConnection.Input.CompleteAsync(); + await writer.CompleteAsync(); await middlewareTask; Assert.Equal(nextMiddlewareInvokedExpected, nextMiddlewareInvokedActual); @@ -153,18 +155,23 @@ private async Task RunTlsClientHelloCallbackTest( bool nextMiddlewareExpected, bool tlsClientHelloCallbackExpected) { - var serviceContext = new TestServiceContext(LoggerFactory); - var logger = LoggerFactory.CreateLogger(); - var memoryPool = serviceContext.MemoryPoolFactory(); - var transportConnection = new InMemoryTransportConnection(memoryPool, logger); + var pipe = new Pipe(); + var writer = pipe.Writer; + var reader = new ObservablePipeReader(pipe.Reader); + + var transport = new DuplexPipe(reader, writer); + var transportConnection = new DefaultConnectionContext("test", transport, transport); var nextMiddlewareInvokedActual = false; var tlsClientHelloCallbackActual = false; var middleware = new TlsListenerMiddleware( - next: _ => + next: ctx => { nextMiddlewareInvokedActual = true; + var readResult = ctx.Transport.Input.ReadAsync(); + Assert.Equal(packetBytes.Length, readResult.Result.Buffer.Length); + return Task.CompletedTask; }, tlsClientHelloBytesCallback: (ctx, data) => @@ -177,8 +184,8 @@ private async Task RunTlsClientHelloCallbackTest( } ); - await transportConnection.Input.WriteAsync(packetBytes); - await transportConnection.Input.CompleteAsync(); + await writer.WriteAsync(packetBytes); + await writer.CompleteAsync(); // call middleware and expect a callback await middleware.OnTlsClientHelloAsync(transportConnection); diff --git a/src/Servers/Kestrel/test/InMemory.FunctionalTests/TlsListenerMiddlewareTests.cs b/src/Servers/Kestrel/test/InMemory.FunctionalTests/TlsListenerMiddlewareTests.cs index 5d93200d5bfa..b57ca2405ba4 100644 --- a/src/Servers/Kestrel/test/InMemory.FunctionalTests/TlsListenerMiddlewareTests.cs +++ b/src/Servers/Kestrel/test/InMemory.FunctionalTests/TlsListenerMiddlewareTests.cs @@ -21,7 +21,7 @@ namespace InMemory.FunctionalTests; -public partial class TlsListenerMiddlewareTests : TestApplicationErrorLoggerLoggedTest +public class TlsListenerMiddlewareTests : TestApplicationErrorLoggerLoggedTest { private static readonly X509Certificate2 _x509Certificate2 = TestResources.GetTestCertificate(); From 7856693aa41ddd29de6ea1b836b09d29c1a97fbc Mon Sep 17 00:00:00 2001 From: Dmitrii Korolev Date: Fri, 25 Apr 2025 17:46:44 +0200 Subject: [PATCH 15/20] dont register callback from `TlsClientHelloBytesCallback` to not have duplicate registration --- .../Kestrel/Core/src/ListenOptionsHttpsExtensions.cs | 9 --------- 1 file changed, 9 deletions(-) diff --git a/src/Servers/Kestrel/Core/src/ListenOptionsHttpsExtensions.cs b/src/Servers/Kestrel/Core/src/ListenOptionsHttpsExtensions.cs index 4003e09474b0..42d7ac8f0476 100644 --- a/src/Servers/Kestrel/Core/src/ListenOptionsHttpsExtensions.cs +++ b/src/Servers/Kestrel/Core/src/ListenOptionsHttpsExtensions.cs @@ -270,15 +270,6 @@ public static ListenOptions UseHttps(this ListenOptions listenOptions, TlsHandsh listenOptions.IsTls = true; listenOptions.HttpsCallbackOptions = callbackOptions; - if (listenOptions.HttpsOptions?.TlsClientHelloBytesCallback is not null) - { - listenOptions.Use(next => - { - var middleware = new TlsListenerMiddleware(next, listenOptions.HttpsOptions.TlsClientHelloBytesCallback); - return middleware.OnTlsClientHelloAsync; - }); - } - listenOptions.Use(next => { // Set the list of protocols from listen options. From d64fb451c04c30c2d9e6558d7bfa6a68063e6c66 Mon Sep 17 00:00:00 2001 From: Dmitrii Korolev Date: Fri, 25 Apr 2025 19:32:20 +0200 Subject: [PATCH 16/20] tests --- .../src/Middleware/TlsListenerMiddleware.cs | 3 +- .../Core/test/TlsListenerMiddlewareTests.cs | 145 +++++++----------- 2 files changed, 57 insertions(+), 91 deletions(-) diff --git a/src/Servers/Kestrel/Core/src/Middleware/TlsListenerMiddleware.cs b/src/Servers/Kestrel/Core/src/Middleware/TlsListenerMiddleware.cs index 32e86f8e13fc..01bc75553a09 100644 --- a/src/Servers/Kestrel/Core/src/Middleware/TlsListenerMiddleware.cs +++ b/src/Servers/Kestrel/Core/src/Middleware/TlsListenerMiddleware.cs @@ -134,8 +134,7 @@ private static ClientHelloParseState TryParseClientHello(ReadOnlySequence } private static bool IsValidProtocolVersion(short version) - => version is 0x0002 // SSL 2.0 (0x0002) - or 0x0300 // SSL 3.0 (0x0300) + => version is 0x0300 // SSL 3.0 (0x0300) or 0x0301 // TLS 1.0 (0x0301) or 0x0302 // TLS 1.1 (0x0302) or 0x0303 // TLS 1.2 (0x0303) diff --git a/src/Servers/Kestrel/Core/test/TlsListenerMiddlewareTests.cs b/src/Servers/Kestrel/Core/test/TlsListenerMiddlewareTests.cs index 1008447e9fec..590555e3c22e 100644 --- a/src/Servers/Kestrel/Core/test/TlsListenerMiddlewareTests.cs +++ b/src/Servers/Kestrel/Core/test/TlsListenerMiddlewareTests.cs @@ -278,99 +278,64 @@ public static IEnumerable InvalidClientHelloData_Segmented() 0 }; - private static byte[] valid_ClientHelloStandard = + private static byte[] valid_Ssl3ClientHello = { - // SslPlainText.(ContentType+ProtocolVersion) - 0x16, 0x03, 0x03, - // SslPlainText.length - 0x00, 0xCB, - // Handshake.msg_type (client hello) - 0x01, - // Handshake.length - 0x00, 0x00, 0xC7, - // ClientHello.client_version - 0x03, 0x03, - // ClientHello.random - 0x0C, 0x3C, 0x85, 0x78, 0xCA, - 0x67, 0x70, 0xAA, 0x38, 0xCB, - 0x28, 0xBC, 0xDC, 0x3E, 0x30, - 0xBF, 0x11, 0x96, 0x95, 0x1A, - 0xB9, 0xF0, 0x99, 0xA4, 0x91, - 0x09, 0x13, 0xB4, 0x89, 0x94, - 0x27, 0x2E, - // ClientHello.SessionId - 0x00, - // ClientHello.cipher_suites - 0x00, 0x2A, 0xC0, 0x2C, 0xC0, - 0x2B, 0xC0, 0x30, 0xC0, 0x2F, - 0x00, 0x9F, 0x00, 0x9E, 0xC0, - 0x24, 0xC0, 0x23, 0xC0, 0x28, - 0xC0, 0x27, 0xC0, 0x0A, 0xC0, - 0x09, 0xC0, 0x14, 0xC0, 0x13, - 0x00, 0x9D, 0x00, 0x9C, 0x00, - 0x3D, 0x00, 0x3C, 0x00, 0x35, - 0x00, 0x2F, 0x00, 0x0A, - // ClientHello.compression_methods - 0x01, 0x01, - // ClientHello.extension_list_length - 0x00, 0x74, - // Extension.extension_type (server_name) - 0x00, 0x00, - // ServerNameListExtension.length - 0x00, 0x39, - // ServerName.length - 0x00, 0x37, - // ServerName.type - 0x00, - // HostName.length - 0x00, 0x34, - // HostName.bytes - 0x61, 0x61, 0x61, 0x61, 0x61, - 0x61, 0x61, 0x61, 0x61, 0x61, - 0x61, 0x61, 0x61, 0x61, 0x61, - 0x61, 0x61, 0x61, 0x61, 0x61, - 0x61, 0x61, 0x61, 0x61, 0x61, - 0x61, 0x61, 0x61, 0x61, 0x61, - 0x61, 0x61, 0x61, 0x61, 0x61, - 0x61, 0x61, 0x61, 0x61, 0x61, - 0x61, 0x61, 0x61, 0x61, 0x61, - 0x61, 0x61, 0x61, 0x61, 0x61, - 0x61, 0x61, - // Extension.extension_type (00 0A) - 0x00, 0x0A, - // Extension 0A - 0x00, 0x08, 0x00, 0x06, 0x00, - 0x1D, 0x00, 0x17, 0x00, 0x18, - // Extension.extension_type (00 0B) - 0x00, 0x0B, - // Extension 0B - 0x00, 0x02, 0x01, 0x00, - // Extension.extension_type (00 0D) - 0x00, 0x0D, - // Extension 0D - 0x00, 0x14, 0x00, 0x12, 0x04, - 0x01, 0x05, 0x01, 0x02, 0x01, - 0x04, 0x03, 0x05, 0x03, 0x02, - 0x03, 0x02, 0x02, 0x06, 0x01, - 0x06, 0x03, - // Extension.extension_type (00 23) - 0x00, 0x23, - // Extension 00 23 - 0x00, 0x00, - // Extension.extension_type (00 17) - 0x00, 0x17, - // Extension 17 - 0x00, 0x00, - // Extension.extension_type (FF 01) - 0xFF, 0x01, - // Extension FF01 - 0x00, 0x01, 0x00 + 0x16, 0x03, 0x00, // ContentType: Handshake, Version: SSL 3.0 + 0x00, 0x2F, // Length: 47 bytes + 0x01, // Handshake Type: ClientHello + 0x00, 0x00, 0x2B, // Length: 43 bytes + 0x03, 0x00, // Client Version: SSL 3.0 + // Random (32 bytes) + 0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, + 0x08, 0x09, 0x0A, 0x0B, 0x0C, 0x0D, 0x0E, 0x0F, + 0x10, 0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x17, + 0x18, 0x19, 0x1A, 0x1B, 0x1C, 0x1D, 0x1E, 0x1F, + 0x00, // Session ID Length + 0x00, 0x04, // Cipher Suites Length + 0x00, 0x2F, 0x00, 0x35, // Cipher Suites + 0x01, 0x00 // Compression Methods: null + }; + + private static byte[] valid_Tls10ClientHello = + { + 0x16, 0x03, 0x01, // ContentType: Handshake, Version: TLS 1.0 + 0x00, 0x2F, // Length: 47 bytes + 0x01, // Handshake Type: ClientHello + 0x00, 0x00, 0x2B, // Length: 43 bytes + 0x03, 0x01, // Client Version: TLS 1.0 + // Random (32 bytes) + 0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, + 0x08, 0x09, 0x0A, 0x0B, 0x0C, 0x0D, 0x0E, 0x0F, + 0x10, 0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x17, + 0x18, 0x19, 0x1A, 0x1B, 0x1C, 0x1D, 0x1E, 0x1F, + 0x00, // Session ID Length + 0x00, 0x04, // Cipher Suites Length + 0x00, 0x2F, 0x00, 0x35, // Cipher Suites + 0x01, 0x00 // Compression Methods: null + }; + + private static byte[] valid_Tls11ClientHello = + { + 0x16, 0x03, 0x02, // ContentType: Handshake, Version: TLS 1.1 + 0x00, 0x2F, // Length: 47 bytes + 0x01, // Handshake Type: ClientHello + 0x00, 0x00, 0x2B, // Length: 43 bytes + 0x03, 0x02, // Client Version: TLS 1.1 + // Random (32 bytes) + 0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, + 0x08, 0x09, 0x0A, 0x0B, 0x0C, 0x0D, 0x0E, 0x0F, + 0x10, 0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x17, + 0x18, 0x19, 0x1A, 0x1B, 0x1C, 0x1D, 0x1E, 0x1F, + 0x00, // Session ID Length + 0x00, 0x04, // Cipher Suites Length + 0x00, 0x2F, 0x00, 0x35, // Cipher Suites: TLS_RSA_WITH_AES_128_CBC_SHA, TLS_RSA_WITH_AES_256_CBC_SHA + 0x01, 0x00 // Compression Methods: null }; private static byte[] valid_Tls12ClientHello = { // SslPlainText.(ContentType+ProtocolVersion) - 0x16, 0x03, 0x01, + 0x16, 0x03, 0x03, // SslPlainText.length 0x00, 0xD1, // Handshake.msg_type (client hello) @@ -429,7 +394,7 @@ public static IEnumerable InvalidClientHelloData_Segmented() private static byte[] valid_Tls13ClientHello = { // SslPlainText.(ContentType+ProtocolVersion) - 0x16, 0x03, 0x01, + 0x16, 0x03, 0x04, // SslPlainText.length 0x01, 0x08, // Handshake.msg_type (client hello) @@ -591,7 +556,9 @@ public static IEnumerable InvalidClientHelloData_Segmented() private static List valid_collection = new List() { - valid_clientHelloHeader, valid_ClientHelloStandard, valid_Tls12ClientHello, valid_Tls13ClientHello, valid_TlsClientHelloNoExtensions + valid_clientHelloHeader, valid_Ssl3ClientHello, valid_Tls10ClientHello, + valid_Tls11ClientHello, valid_Tls12ClientHello, valid_Tls13ClientHello, + valid_TlsClientHelloNoExtensions }; private static List invalid_collection = new List() From 7578f2b8e04791d9add972fa60081c7ec72dd0c3 Mon Sep 17 00:00:00 2001 From: Dmitrii Korolev Date: Wed, 30 Apr 2025 18:16:28 +0200 Subject: [PATCH 17/20] address comments --- .../Core/test/TlsListenerMiddlewareTests.cs | 98 +++++++++---------- .../InMemoryTransportConnection.cs | 8 +- 2 files changed, 48 insertions(+), 58 deletions(-) diff --git a/src/Servers/Kestrel/Core/test/TlsListenerMiddlewareTests.cs b/src/Servers/Kestrel/Core/test/TlsListenerMiddlewareTests.cs index 590555e3c22e..175d42ec0176 100644 --- a/src/Servers/Kestrel/Core/test/TlsListenerMiddlewareTests.cs +++ b/src/Servers/Kestrel/Core/test/TlsListenerMiddlewareTests.cs @@ -32,23 +32,23 @@ public partial class TlsListenerMiddlewareTests { [Theory] [MemberData(nameof(ValidClientHelloData))] - public Task OnTlsClientHelloAsync_ValidData(int id, byte[] packetBytes, bool nextMiddlewareInvoked) - => RunTlsClientHelloCallbackTest(id, packetBytes, nextMiddlewareInvoked, tlsClientHelloCallbackExpected: true); + public Task OnTlsClientHelloAsync_ValidData(int id, byte[] packetBytes) + => RunTlsClientHelloCallbackTest(id, packetBytes, nextMiddlewareShouldBeInvoked: true, tlsClientHelloCallbackExpected: true); [Theory] [MemberData(nameof(InvalidClientHelloData))] - public Task OnTlsClientHelloAsync_InvalidData(int id, byte[] packetBytes, bool nextMiddlewareInvoked) - => RunTlsClientHelloCallbackTest(id, packetBytes, nextMiddlewareInvoked, tlsClientHelloCallbackExpected: false); + public Task OnTlsClientHelloAsync_InvalidData(int id, byte[] packetBytes) + => RunTlsClientHelloCallbackTest(id, packetBytes, nextMiddlewareShouldBeInvoked: true, tlsClientHelloCallbackExpected: false); [Theory] [MemberData(nameof(ValidClientHelloData_Segmented))] - public Task OnTlsClientHelloAsync_ValidData_MultipleSegments(int id, List packets, bool nextMiddlewareInvoked) - => RunTlsClientHelloCallbackTest_WithMultipleSegments(id, packets, nextMiddlewareInvoked, tlsClientHelloCallbackExpected: true); + public Task OnTlsClientHelloAsync_ValidData_MultipleSegments(int id, List packets) + => RunTlsClientHelloCallbackTest_WithMultipleSegments(id, packets, nextMiddlewareShouldBeInvoked: true, tlsClientHelloCallbackExpected: true); [Theory] [MemberData(nameof(InvalidClientHelloData_Segmented))] - public Task OnTlsClientHelloAsync_InvalidData_MultipleSegments(int id, List packets, bool nextMiddlewareInvoked) - => RunTlsClientHelloCallbackTest_WithMultipleSegments(id, packets, nextMiddlewareInvoked, tlsClientHelloCallbackExpected: false); + public Task OnTlsClientHelloAsync_InvalidData_MultipleSegments(int id, List packets) + => RunTlsClientHelloCallbackTest_WithMultipleSegments(id, packets, nextMiddlewareShouldBeInvoked: true, tlsClientHelloCallbackExpected: false); [Fact] public async Task RunTlsClientHelloCallbackTest_DeterministinglyReads() @@ -81,7 +81,7 @@ public async Task RunTlsClientHelloCallbackTest_DeterministinglyReads() ); await writer.WriteAsync(new byte[1] { 0x16 }); - var middlewareTask = Task.Run(() => middleware.OnTlsClientHelloAsync(transportConnection)); + var middlewareTask = middleware.OnTlsClientHelloAsync(transportConnection); await writer.WriteAsync(new byte[2] { 0x03, 0x01 }); await writer.WriteAsync(new byte[2] { 0x00, 0x20 }); await writer.CompleteAsync(); @@ -91,21 +91,20 @@ public async Task RunTlsClientHelloCallbackTest_DeterministinglyReads() Assert.False(tlsClientHelloCallbackInvoked); // ensuring that we have read limited number of times - Assert.True(reader.ReadAsyncCounter is >= 2 && reader.ReadAsyncCounter is <= 5, - $"Expected ReadAsync() to happen about 2-5 times. Actually happened {reader.ReadAsyncCounter} times."); + Assert.True(reader.ReadAsyncCounter is >= 2 && reader.ReadAsyncCounter is <= 3, + $"Expected ReadAsync() to happen about 2-3 times. Actually happened {reader.ReadAsyncCounter} times."); } private async Task RunTlsClientHelloCallbackTest_WithMultipleSegments( int id, List packets, - bool nextMiddlewareInvokedExpected, + bool nextMiddlewareShouldBeInvoked, bool tlsClientHelloCallbackExpected) { var pipe = new Pipe(); var writer = pipe.Writer; - var reader = new ObservablePipeReader(pipe.Reader); - var transport = new DuplexPipe(reader, writer); + var transport = new DuplexPipe(pipe.Reader, writer); var transportConnection = new DefaultConnectionContext("test", transport, transport); var nextMiddlewareInvokedActual = false; @@ -131,35 +130,30 @@ private async Task RunTlsClientHelloCallbackTest_WithMultipleSegments( // write first packet await writer.WriteAsync(packets[0]); - var middlewareTask = Task.Run(() => middleware.OnTlsClientHelloAsync(transportConnection)); - - var random = new Random(); - await Task.Delay(millisecondsDelay: random.Next(25, 75)); + var middlewareTask = middleware.OnTlsClientHelloAsync(transportConnection); // write all next packets foreach (var packet in packets.Skip(1)) { await writer.WriteAsync(packet); - await Task.Delay(millisecondsDelay: random.Next(25, 75)); } await writer.CompleteAsync(); await middlewareTask; - Assert.Equal(nextMiddlewareInvokedExpected, nextMiddlewareInvokedActual); + Assert.Equal(nextMiddlewareShouldBeInvoked, nextMiddlewareInvokedActual); Assert.Equal(tlsClientHelloCallbackExpected, tlsClientHelloCallbackActual); } private async Task RunTlsClientHelloCallbackTest( int id, byte[] packetBytes, - bool nextMiddlewareExpected, + bool nextMiddlewareShouldBeInvoked, bool tlsClientHelloCallbackExpected) { var pipe = new Pipe(); var writer = pipe.Writer; - var reader = new ObservablePipeReader(pipe.Reader); - var transport = new DuplexPipe(reader, writer); + var transport = new DuplexPipe(pipe.Reader, writer); var transportConnection = new DefaultConnectionContext("test", transport, transport); var nextMiddlewareInvokedActual = false; @@ -190,32 +184,32 @@ private async Task RunTlsClientHelloCallbackTest( // call middleware and expect a callback await middleware.OnTlsClientHelloAsync(transportConnection); - Assert.Equal(nextMiddlewareExpected, nextMiddlewareInvokedActual); + Assert.Equal(nextMiddlewareShouldBeInvoked, nextMiddlewareInvokedActual); Assert.Equal(tlsClientHelloCallbackExpected, tlsClientHelloCallbackActual); } public static IEnumerable ValidClientHelloData() { int id = 0; - foreach (var clientHello in valid_collection) + foreach (var clientHello in _validCollection) { - yield return new object[] { id++, clientHello, true /* invokes next middleware */ }; + yield return new object[] { id++, clientHello }; } } public static IEnumerable InvalidClientHelloData() { int id = 0; - foreach (byte[] clientHello in invalid_collection) + foreach (byte[] clientHello in _invalidCollection) { - yield return new object[] { id++, clientHello, true /* invokes next middleware */ }; + yield return new object[] { id++, clientHello }; } } public static IEnumerable ValidClientHelloData_Segmented() { int id = 0; - foreach (var clientHello in valid_collection) + foreach (var clientHello in _validCollection) { var clientHelloSegments = new List { @@ -226,14 +220,14 @@ public static IEnumerable ValidClientHelloData_Segmented() clientHello.Skip(6).Take(clientHello.Length - 6).ToArray() }; - yield return new object[] { id++, clientHelloSegments, true /* invokes next middleware */ }; + yield return new object[] { id++, clientHelloSegments }; } } public static IEnumerable InvalidClientHelloData_Segmented() { int id = 0; - foreach (var clientHello in invalid_collection) + foreach (var clientHello in _invalidCollection) { var clientHelloSegments = new List(); if (clientHello.Length >= 1) @@ -257,11 +251,11 @@ public static IEnumerable InvalidClientHelloData_Segmented() clientHelloSegments.Add(clientHello.Skip(6).Take(clientHello.Length - 6).ToArray()); } - yield return new object[] { id++, clientHelloSegments, true /* invokes next middleware */ }; + yield return new object[] { id++, clientHelloSegments }; } } - private static byte[] valid_clientHelloHeader = + private static byte[] _validClientHelloHeader = { // 0x16 = Handshake 0x16, @@ -278,7 +272,7 @@ public static IEnumerable InvalidClientHelloData_Segmented() 0 }; - private static byte[] valid_Ssl3ClientHello = + private static byte[] _validSsl3ClientHello = { 0x16, 0x03, 0x00, // ContentType: Handshake, Version: SSL 3.0 0x00, 0x2F, // Length: 47 bytes @@ -296,7 +290,7 @@ public static IEnumerable InvalidClientHelloData_Segmented() 0x01, 0x00 // Compression Methods: null }; - private static byte[] valid_Tls10ClientHello = + private static byte[] _validTls10ClientHello = { 0x16, 0x03, 0x01, // ContentType: Handshake, Version: TLS 1.0 0x00, 0x2F, // Length: 47 bytes @@ -314,7 +308,7 @@ public static IEnumerable InvalidClientHelloData_Segmented() 0x01, 0x00 // Compression Methods: null }; - private static byte[] valid_Tls11ClientHello = + private static byte[] _validTls11ClientHello = { 0x16, 0x03, 0x02, // ContentType: Handshake, Version: TLS 1.1 0x00, 0x2F, // Length: 47 bytes @@ -332,7 +326,7 @@ public static IEnumerable InvalidClientHelloData_Segmented() 0x01, 0x00 // Compression Methods: null }; - private static byte[] valid_Tls12ClientHello = + private static byte[] _validTls12ClientHello = { // SslPlainText.(ContentType+ProtocolVersion) 0x16, 0x03, 0x03, @@ -391,7 +385,7 @@ public static IEnumerable InvalidClientHelloData_Segmented() 0x2E, 0x31 }; - private static byte[] valid_Tls13ClientHello = + private static byte[] _validTls13ClientHello = { // SslPlainText.(ContentType+ProtocolVersion) 0x16, 0x03, 0x04, @@ -462,7 +456,7 @@ public static IEnumerable InvalidClientHelloData_Segmented() 0x03, 0x01, 0x03, 0x03, 0x02, 0x01, 0x02, 0x03 }; - private static byte[] valid_TlsClientHelloNoExtensions = + private static byte[] _validTlsClientHelloNoExtensions = { 0x16, 0x03, 0x03, 0x00, 0x39, 0x01, 0x00, 0x00, 0x35, 0x03, 0x03, 0x62, 0x5d, 0x50, 0x2a, 0x41, @@ -474,7 +468,7 @@ public static IEnumerable InvalidClientHelloData_Segmented() 0x00, 0x05, 0x00, 0x04, 0x01, 0x00 }; - private static byte[] invalid_TlsClientHelloHeader = + private static byte[] _invalidTlsClientHelloHeader = { // Handshake - incorrect 0x01, @@ -488,7 +482,7 @@ public static IEnumerable InvalidClientHelloData_Segmented() 0x00, 0x00, 0xC7, }; - private static byte[] invalid_3BytesMessage = + private static byte[] _invalid3BytesMessage = { // Handshake 0x016, @@ -497,7 +491,7 @@ public static IEnumerable InvalidClientHelloData_Segmented() // not enough data - so incorrect }; - private static byte[] invalid_9BytesMessage = + private static byte[] _invalid9BytesMessage = { // 0x16 = Handshake 0x16, @@ -512,7 +506,7 @@ public static IEnumerable InvalidClientHelloData_Segmented() // no other data here - incorrect }; - private static byte[] invalid_UnknownProtocolVersion1 = + private static byte[] _invalidUnknownProtocolVersion1 = { // Handshake 0x016, @@ -526,7 +520,7 @@ public static IEnumerable InvalidClientHelloData_Segmented() 0x00, 0x00, 0xC7, }; - private static byte[] invalid_UnknownProtocolVersion2 = + private static byte[] _invalidUnknownProtocolVersion2 = { // Handshake 0x016, @@ -540,7 +534,7 @@ public static IEnumerable InvalidClientHelloData_Segmented() 0x00, 0x00, 0xC7, }; - private static byte[] invalid_IncorrectHandshakeMessageType = + private static byte[] _invalidIncorrectHandshakeMessageType = { // Handshake 0x016, @@ -554,16 +548,16 @@ public static IEnumerable InvalidClientHelloData_Segmented() 0x00, 0x00, 0xC7, }; - private static List valid_collection = new List() + private static List _validCollection = new List() { - valid_clientHelloHeader, valid_Ssl3ClientHello, valid_Tls10ClientHello, - valid_Tls11ClientHello, valid_Tls12ClientHello, valid_Tls13ClientHello, - valid_TlsClientHelloNoExtensions + _validClientHelloHeader, _validSsl3ClientHello, _validTls10ClientHello, + _validTls11ClientHello, _validTls12ClientHello, _validTls13ClientHello, + _validTlsClientHelloNoExtensions }; - private static List invalid_collection = new List() + private static List _invalidCollection = new List() { - invalid_TlsClientHelloHeader, invalid_3BytesMessage, invalid_9BytesMessage, - invalid_UnknownProtocolVersion1, invalid_UnknownProtocolVersion2, invalid_IncorrectHandshakeMessageType + _invalidTlsClientHelloHeader, _invalid3BytesMessage, _invalid9BytesMessage, + _invalidUnknownProtocolVersion1, _invalidUnknownProtocolVersion2, _invalidIncorrectHandshakeMessageType }; } diff --git a/src/Servers/Kestrel/test/InMemory.FunctionalTests/TestTransport/InMemoryTransportConnection.cs b/src/Servers/Kestrel/test/InMemory.FunctionalTests/TestTransport/InMemoryTransportConnection.cs index 245027b33330..086694908360 100644 --- a/src/Servers/Kestrel/test/InMemory.FunctionalTests/TestTransport/InMemoryTransportConnection.cs +++ b/src/Servers/Kestrel/test/InMemory.FunctionalTests/TestTransport/InMemoryTransportConnection.cs @@ -91,7 +91,7 @@ public override async ValueTask DisposeAsync() // This piece of code allows us to wait until the PipeReader has been awaited on. // We need to wrap lots of layers (including the ValueTask) to gain visiblity into when // the machinery for the await happens - internal class ObservableDuplexPipe : IDuplexPipe + private class ObservableDuplexPipe : IDuplexPipe { private readonly ObservablePipeReader _reader; @@ -110,14 +110,11 @@ public ObservableDuplexPipe(IDuplexPipe duplexPipe) public PipeWriter Output { get; } - public int ReadAsyncCounter => _reader.ReadAsyncCounter; - private class ObservablePipeReader : PipeReader { private readonly PipeReader _reader; private readonly TaskCompletionSource _tcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); - public int ReadAsyncCounter { get; private set; } = 0; public Task WaitForReadTask => _tcs.Task; public ObservablePipeReader(PipeReader reader) @@ -147,7 +144,6 @@ public override void Complete(Exception exception = null) public override ValueTask ReadAsync(CancellationToken cancellationToken = default) { - ReadAsyncCounter++; var task = _reader.ReadAsync(cancellationToken); if (_tcs.Task.IsCompleted) @@ -156,7 +152,7 @@ public override ValueTask ReadAsync(CancellationToken cancellationTo } return new ValueTask(new ObservableValueTask(task, _tcs), 0); - } + } public override bool TryRead(out ReadResult result) { From aae7b2f3e2b43f955eed7c1271aa79ef08d808bc Mon Sep 17 00:00:00 2001 From: Dmitrii Korolev Date: Wed, 30 Apr 2025 18:20:24 +0200 Subject: [PATCH 18/20] added comment --- src/Servers/Kestrel/Core/src/HttpsConnectionAdapterOptions.cs | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/Servers/Kestrel/Core/src/HttpsConnectionAdapterOptions.cs b/src/Servers/Kestrel/Core/src/HttpsConnectionAdapterOptions.cs index b74fefdb8b96..4fad083d87c3 100644 --- a/src/Servers/Kestrel/Core/src/HttpsConnectionAdapterOptions.cs +++ b/src/Servers/Kestrel/Core/src/HttpsConnectionAdapterOptions.cs @@ -100,6 +100,8 @@ public void AllowAnyClientCertificate() /// /// A callback to be invoked to get the TLS client hello bytes. /// Null by default. + /// If you want to store the bytes from the , + /// copy them into a buffer that you control rather than keeping a reference to the or instances. /// public Action>? TlsClientHelloBytesCallback { get; set; } From 90ef2c3a919b6ab3c08843a42f1b32a7b2770a2c Mon Sep 17 00:00:00 2001 From: Dmitrii Korolev Date: Wed, 30 Apr 2025 20:02:37 +0200 Subject: [PATCH 19/20] more test changes --- .../Core/test/TlsListenerMiddlewareTests.cs | 43 +++++++++++++------ 1 file changed, 29 insertions(+), 14 deletions(-) diff --git a/src/Servers/Kestrel/Core/test/TlsListenerMiddlewareTests.cs b/src/Servers/Kestrel/Core/test/TlsListenerMiddlewareTests.cs index 175d42ec0176..665246fe7aa1 100644 --- a/src/Servers/Kestrel/Core/test/TlsListenerMiddlewareTests.cs +++ b/src/Servers/Kestrel/Core/test/TlsListenerMiddlewareTests.cs @@ -28,31 +28,35 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Tests; -public partial class TlsListenerMiddlewareTests +public class TlsListenerMiddlewareTests { [Theory] [MemberData(nameof(ValidClientHelloData))] public Task OnTlsClientHelloAsync_ValidData(int id, byte[] packetBytes) - => RunTlsClientHelloCallbackTest(id, packetBytes, nextMiddlewareShouldBeInvoked: true, tlsClientHelloCallbackExpected: true); + => RunTlsClientHelloCallbackTest(id, packetBytes, tlsClientHelloCallbackExpected: true); [Theory] [MemberData(nameof(InvalidClientHelloData))] public Task OnTlsClientHelloAsync_InvalidData(int id, byte[] packetBytes) - => RunTlsClientHelloCallbackTest(id, packetBytes, nextMiddlewareShouldBeInvoked: true, tlsClientHelloCallbackExpected: false); + => RunTlsClientHelloCallbackTest(id, packetBytes, tlsClientHelloCallbackExpected: false); [Theory] [MemberData(nameof(ValidClientHelloData_Segmented))] public Task OnTlsClientHelloAsync_ValidData_MultipleSegments(int id, List packets) - => RunTlsClientHelloCallbackTest_WithMultipleSegments(id, packets, nextMiddlewareShouldBeInvoked: true, tlsClientHelloCallbackExpected: true); + => RunTlsClientHelloCallbackTest_WithMultipleSegments(id, packets, tlsClientHelloCallbackExpected: true); [Theory] [MemberData(nameof(InvalidClientHelloData_Segmented))] public Task OnTlsClientHelloAsync_InvalidData_MultipleSegments(int id, List packets) - => RunTlsClientHelloCallbackTest_WithMultipleSegments(id, packets, nextMiddlewareShouldBeInvoked: true, tlsClientHelloCallbackExpected: false); + => RunTlsClientHelloCallbackTest_WithMultipleSegments(id, packets, tlsClientHelloCallbackExpected: false); [Fact] - public async Task RunTlsClientHelloCallbackTest_DeterministinglyReads() + public async Task RunTlsClientHelloCallbackTest_DeterministicallyReads() { + /* Current test ensures that we read the input stream only a limited number of times. + * It is a guard against incorrect transport.AdvanceTo() usage leading to infinite loop / more reads than should happen. + */ + var serviceContext = new TestServiceContext(); var pipe = new Pipe(); @@ -91,16 +95,15 @@ public async Task RunTlsClientHelloCallbackTest_DeterministinglyReads() Assert.False(tlsClientHelloCallbackInvoked); // ensuring that we have read limited number of times - Assert.True(reader.ReadAsyncCounter is >= 2 && reader.ReadAsyncCounter is <= 3, - $"Expected ReadAsync() to happen about 2-3 times. Actually happened {reader.ReadAsyncCounter} times."); + Assert.True(reader.ReadAsyncCounter is >= 2 && reader.ReadAsyncCounter is <= 4, + $"Expected ReadAsync() to happen about 2-4 times. Actually happened {reader.ReadAsyncCounter} times."); } private async Task RunTlsClientHelloCallbackTest_WithMultipleSegments( int id, List packets, - bool nextMiddlewareShouldBeInvoked, bool tlsClientHelloCallbackExpected) - { + { var pipe = new Pipe(); var writer = pipe.Writer; @@ -116,6 +119,12 @@ private async Task RunTlsClientHelloCallbackTest_WithMultipleSegments( next: ctx => { nextMiddlewareInvokedActual = true; + if (tlsClientHelloCallbackActual) + { + var readResult = ctx.Transport.Input.ReadAsync(); + Assert.Equal(fullLength, readResult.Result.Buffer.Length); + } + return Task.CompletedTask; }, tlsClientHelloBytesCallback: (ctx, data) => @@ -132,7 +141,14 @@ private async Task RunTlsClientHelloCallbackTest_WithMultipleSegments( await writer.WriteAsync(packets[0]); var middlewareTask = middleware.OnTlsClientHelloAsync(transportConnection); - // write all next packets + + /* It is a race condition (middleware's loop and writes here). + * We don't know specifically how many packets will be read by middleware's loop + * (possibly there are even 2 packets - the first and all others combined). + * The goal here is to try simulate multi-segmented approach and test more cases + */ + + // write all other packets foreach (var packet in packets.Skip(1)) { await writer.WriteAsync(packet); @@ -140,14 +156,13 @@ private async Task RunTlsClientHelloCallbackTest_WithMultipleSegments( await writer.CompleteAsync(); await middlewareTask; - Assert.Equal(nextMiddlewareShouldBeInvoked, nextMiddlewareInvokedActual); + Assert.True(nextMiddlewareInvokedActual); Assert.Equal(tlsClientHelloCallbackExpected, tlsClientHelloCallbackActual); } private async Task RunTlsClientHelloCallbackTest( int id, byte[] packetBytes, - bool nextMiddlewareShouldBeInvoked, bool tlsClientHelloCallbackExpected) { var pipe = new Pipe(); @@ -184,7 +199,7 @@ private async Task RunTlsClientHelloCallbackTest( // call middleware and expect a callback await middleware.OnTlsClientHelloAsync(transportConnection); - Assert.Equal(nextMiddlewareShouldBeInvoked, nextMiddlewareInvokedActual); + Assert.True(nextMiddlewareInvokedActual); Assert.Equal(tlsClientHelloCallbackExpected, tlsClientHelloCallbackActual); } From e64cc04ec0ffe50fee86578a89e6029da29a5fd7 Mon Sep 17 00:00:00 2001 From: Dmitrii Korolev Date: Wed, 30 Apr 2025 21:33:24 +0200 Subject: [PATCH 20/20] nit --- src/Servers/Kestrel/Core/test/TlsListenerMiddlewareTests.cs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/Servers/Kestrel/Core/test/TlsListenerMiddlewareTests.cs b/src/Servers/Kestrel/Core/test/TlsListenerMiddlewareTests.cs index 665246fe7aa1..ea3103108ff7 100644 --- a/src/Servers/Kestrel/Core/test/TlsListenerMiddlewareTests.cs +++ b/src/Servers/Kestrel/Core/test/TlsListenerMiddlewareTests.cs @@ -103,7 +103,7 @@ private async Task RunTlsClientHelloCallbackTest_WithMultipleSegments( int id, List packets, bool tlsClientHelloCallbackExpected) - { + { var pipe = new Pipe(); var writer = pipe.Writer;