Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix preserving messages for stateful reconnect with backplane #60900

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 5 additions & 4 deletions src/SignalR/common/Shared/MessageBuffer.cs
Original file line number Diff line number Diff line change
@@ -121,15 +121,16 @@ private async Task RunTimer()

public ValueTask<FlushResult> WriteAsync(SerializedHubMessage hubMessage, CancellationToken cancellationToken)
{
return WriteAsyncCore(hubMessage.Message!, hubMessage.GetSerializedMessage(_protocol), cancellationToken);
// Default to HubInvocationMessage as that's the only type we use SerializedHubMessage for currently. Should harden this in the future.
return WriteAsyncCore(hubMessage.Message?.GetType() ?? typeof(HubInvocationMessage), hubMessage.GetSerializedMessage(_protocol), cancellationToken);
}

public ValueTask<FlushResult> WriteAsync(HubMessage hubMessage, CancellationToken cancellationToken)
{
return WriteAsyncCore(hubMessage, _protocol.GetMessageBytes(hubMessage), cancellationToken);
return WriteAsyncCore(hubMessage.GetType(), _protocol.GetMessageBytes(hubMessage), cancellationToken);
}

private async ValueTask<FlushResult> WriteAsyncCore(HubMessage hubMessage, ReadOnlyMemory<byte> messageBytes, CancellationToken cancellationToken)
private async ValueTask<FlushResult> WriteAsyncCore(Type hubMessageType, ReadOnlyMemory<byte> messageBytes, CancellationToken cancellationToken)
{
// TODO: Add backpressure based on message count
if (_bufferedByteCount > _bufferLimit)
@@ -158,7 +159,7 @@ private async ValueTask<FlushResult> WriteAsyncCore(HubMessage hubMessage, ReadO
await _writeLock.WaitAsync(cancellationToken: default).ConfigureAwait(false);
try
{
if (hubMessage is HubInvocationMessage invocationMessage)
if (typeof(HubInvocationMessage).IsAssignableFrom(hubMessageType))
{
_totalMessageCount++;
_bufferedByteCount += messageBytes.Length;
3 changes: 3 additions & 0 deletions src/SignalR/server/Core/src/SerializedHubMessage.cs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.

using System.Diagnostics;
using Microsoft.AspNetCore.SignalR.Protocol;

namespace Microsoft.AspNetCore.SignalR;
@@ -40,6 +41,8 @@ public SerializedHubMessage(IReadOnlyList<SerializedMessage> messages)
/// <param name="message">The hub message for the cache. This will be serialized with an <see cref="IHubProtocol"/> in <see cref="GetSerializedMessage"/> to get the message's serialized representation.</param>
public SerializedHubMessage(HubMessage message)
{
// Type currently only used for invocation messages, we should probably refactor it to be explicit about that e.g. new property for message type?
Debug.Assert(message.GetType().IsAssignableTo(typeof(HubInvocationMessage)));
Message = message;
}

Original file line number Diff line number Diff line change
@@ -2,11 +2,12 @@
// The .NET Foundation licenses this file to you under the MIT license.

using System.IO.Pipelines;
using System.Text.Json;
using Microsoft.AspNetCore.Connections;
using Microsoft.AspNetCore.Http.Features;
using Microsoft.AspNetCore.InternalTesting;
using Microsoft.AspNetCore.SignalR.Internal;
using Microsoft.AspNetCore.SignalR.Protocol;
using Microsoft.AspNetCore.InternalTesting;
using Microsoft.Extensions.Logging.Abstractions;
using Microsoft.Extensions.Time.Testing;

@@ -169,6 +170,62 @@ public async Task UnAckedMessageResentOnReconnect()
Assert.False(messageBuffer.ShouldProcessMessage(CompletionMessage.WithResult("1", null)));
}

// Regression test for https://github.com/dotnet/aspnetcore/issues/55575
[Fact]
public async Task UnAckedSerializedMessageResentOnReconnect()
{
var protocol = new JsonHubProtocol();
var connection = new TestConnectionContext();
var pipes = DuplexPipe.CreateConnectionPair(new PipeOptions(), new PipeOptions());
connection.Transport = pipes.Transport;
using var messageBuffer = new MessageBuffer(connection, protocol, bufferLimit: 1000, NullLogger.Instance);

var invocationMessage = new SerializedHubMessage([new SerializedMessage(protocol.Name,
protocol.GetMessageBytes(new InvocationMessage("method1", [1])))]);
await messageBuffer.WriteAsync(invocationMessage, default);

var res = await pipes.Application.Input.ReadAsync();

var buffer = res.Buffer;
Assert.True(protocol.TryParseMessage(ref buffer, new TestBinder(), out var message));
var parsedMessage = Assert.IsType<InvocationMessage>(message);
Assert.Equal("method1", parsedMessage.Target);
Assert.Equal(1, ((JsonElement)Assert.Single(parsedMessage.Arguments)).GetInt32());

pipes.Application.Input.AdvanceTo(buffer.Start);

DuplexPipe.UpdateConnectionPair(ref pipes, connection);
await messageBuffer.ResendAsync(pipes.Transport.Output);

Assert.True(messageBuffer.ShouldProcessMessage(PingMessage.Instance));
Assert.True(messageBuffer.ShouldProcessMessage(CompletionMessage.WithResult("1", null)));
Assert.True(messageBuffer.ShouldProcessMessage(new SequenceMessage(1)));

res = await pipes.Application.Input.ReadAsync();

buffer = res.Buffer;
Assert.True(protocol.TryParseMessage(ref buffer, new TestBinder(), out message));
var seqMessage = Assert.IsType<SequenceMessage>(message);
Assert.Equal(1, seqMessage.SequenceId);

pipes.Application.Input.AdvanceTo(buffer.Start);

res = await pipes.Application.Input.ReadAsync();

buffer = res.Buffer;
Assert.True(protocol.TryParseMessage(ref buffer, new TestBinder(), out message));
parsedMessage = Assert.IsType<InvocationMessage>(message);
Assert.Equal("method1", parsedMessage.Target);
Assert.Equal(1, ((JsonElement)Assert.Single(parsedMessage.Arguments)).GetInt32());

pipes.Application.Input.AdvanceTo(buffer.Start);

messageBuffer.ShouldProcessMessage(new SequenceMessage(1));

Assert.True(messageBuffer.ShouldProcessMessage(PingMessage.Instance));
Assert.False(messageBuffer.ShouldProcessMessage(CompletionMessage.WithResult("1", null)));
}

[Fact]
public async Task AckedMessageNotResentOnReconnect()
{
173 changes: 167 additions & 6 deletions src/SignalR/server/StackExchangeRedis/test/RedisEndToEnd.cs
Original file line number Diff line number Diff line change
@@ -1,17 +1,15 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.

using System;
using System.Collections.Generic;
using System.Threading.Tasks;
using System.Net.WebSockets;
using Microsoft.AspNetCore.Http.Connections;
using Microsoft.AspNetCore.Http.Connections.Client;
using Microsoft.AspNetCore.InternalTesting;
using Microsoft.AspNetCore.SignalR.Client;
using Microsoft.AspNetCore.SignalR.Protocol;
using Microsoft.AspNetCore.SignalR.Tests;
using Microsoft.AspNetCore.InternalTesting;
using Microsoft.Extensions.DependencyInjection;
using Microsoft.Extensions.Logging;
using Xunit;

namespace Microsoft.AspNetCore.SignalR.StackExchangeRedis.Tests;

@@ -213,7 +211,105 @@ public async Task CanSendAndReceiveUserMessagesUserNameWithPatternIsTreatedAsLit
}
}

private static HubConnection CreateConnection(string url, HttpTransportType transportType, IHubProtocol protocol, ILoggerFactory loggerFactory, string userName = null)
[ConditionalTheory]
[SkipIfDockerNotPresent]
[InlineData("messagepack")]
[InlineData("json")]
public async Task StatefulReconnectPreservesMessageFromOtherServer(string protocolName)
{
using (StartVerifiableLog())
{
var protocol = HubProtocolHelpers.GetHubProtocol(protocolName);

ClientWebSocket innerWs = null;
WebSocketWrapper ws = null;
TaskCompletionSource reconnectTcs = null;
TaskCompletionSource startedReconnectTcs = null;

var connection = CreateConnection(_serverFixture.FirstServer.Url + "/stateful", HttpTransportType.WebSockets, protocol, LoggerFactory,
customizeConnection: builder =>
{
builder.WithStatefulReconnect();
builder.Services.Configure<HttpConnectionOptions>(o =>
{
// Replace the websocket creation for the first connection so we can make the client think there was an ungraceful closure
// Which will trigger the stateful reconnect flow
o.WebSocketFactory = async (context, token) =>
{
if (reconnectTcs is null)
{
reconnectTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously);
startedReconnectTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously);
}
else
{
startedReconnectTcs.SetResult();
// We only want to wait on the reconnect, not the initial connection attempt
await reconnectTcs.Task.DefaultTimeout();
}

innerWs = new ClientWebSocket();
ws = new WebSocketWrapper(innerWs);
await innerWs.ConnectAsync(context.Uri, token);

_ = Task.Run(async () =>
{
try
{
while (innerWs.State == WebSocketState.Open)
{
var buffer = new byte[1024];
var res = await innerWs.ReceiveAsync(buffer, default);
ws.SetReceiveResult((res, buffer.AsMemory(0, res.Count)));
}
}
// Log but ignore receive errors, that likely just means the connection closed
catch (Exception ex)
{
Logger.LogInformation(ex, "Error while reading from inner websocket");
}
});

return ws;
};
});
});
var secondConnection = CreateConnection(_serverFixture.SecondServer.Url + "/stateful", HttpTransportType.WebSockets, protocol, LoggerFactory);

var tcs = new TaskCompletionSource<string>();
connection.On<string>("SendToAll", message => tcs.TrySetResult(message));

var tcs2 = new TaskCompletionSource<string>();
secondConnection.On<string>("SendToAll", message => tcs2.TrySetResult(message));

await connection.StartAsync().DefaultTimeout();
await secondConnection.StartAsync().DefaultTimeout();

// Close first connection before the second connection sends a message to all clients
await ws.CloseOutputAsync(WebSocketCloseStatus.InternalServerError, statusDescription: null, default);
await startedReconnectTcs.Task.DefaultTimeout();

// Send to all clients, since both clients are on different servers this means the backplane will be used
// And we want to test that messages are still preserved for stateful reconnect purposes when a client disconnects
// But is on a different server from the original message sender.
await secondConnection.SendAsync("SendToAll", "test message").DefaultTimeout();

// Check that second connection still receives the message
Assert.Equal("test message", await tcs2.Task.DefaultTimeout());
Assert.False(tcs.Task.IsCompleted);

// allow first connection to reconnect
reconnectTcs.SetResult();

// Check that first connection received the message once it reconnected
Assert.Equal("test message", await tcs.Task.DefaultTimeout());

await connection.DisposeAsync().DefaultTimeout();
}
}

private static HubConnection CreateConnection(string url, HttpTransportType transportType, IHubProtocol protocol, ILoggerFactory loggerFactory, string userName = null,
Action<IHubConnectionBuilder> customizeConnection = null)
{
var hubConnectionBuilder = new HubConnectionBuilder()
.WithLoggerFactory(loggerFactory)
@@ -227,6 +323,8 @@ private static HubConnection CreateConnection(string url, HttpTransportType tran

hubConnectionBuilder.Services.AddSingleton(protocol);

customizeConnection?.Invoke(hubConnectionBuilder);

return hubConnectionBuilder.Build();
}

@@ -255,4 +353,67 @@ public static IEnumerable<object[]> TransportTypesAndProtocolTypes
}
}
}

internal sealed class WebSocketWrapper : WebSocket
{
private readonly WebSocket _inner;
private TaskCompletionSource<(WebSocketReceiveResult, ReadOnlyMemory<byte>)> _receiveTcs = new(TaskCreationOptions.RunContinuationsAsynchronously);

public WebSocketWrapper(WebSocket inner)
{
_inner = inner;
}

public override WebSocketCloseStatus? CloseStatus => _inner.CloseStatus;

public override string CloseStatusDescription => _inner.CloseStatusDescription;

public override WebSocketState State => _inner.State;

public override string SubProtocol => _inner.SubProtocol;

public override void Abort()
{
_inner.Abort();
}

public override Task CloseAsync(WebSocketCloseStatus closeStatus, string statusDescription, CancellationToken cancellationToken)
{
return _inner.CloseAsync(closeStatus, statusDescription, cancellationToken);
}

public override Task CloseOutputAsync(WebSocketCloseStatus closeStatus, string statusDescription, CancellationToken cancellationToken)
{
_receiveTcs.TrySetException(new IOException("force reconnect"));
return Task.CompletedTask;
}

public override void Dispose()
{
_inner.Dispose();
}

public void SetReceiveResult((WebSocketReceiveResult, ReadOnlyMemory<byte>) result)
{
_receiveTcs.SetResult(result);
}

public override async Task<WebSocketReceiveResult> ReceiveAsync(ArraySegment<byte> buffer, CancellationToken cancellationToken)
{
var res = await _receiveTcs.Task;
// Handle zero-byte reads
if (buffer.Count == 0)
{
return res.Item1;
}
_receiveTcs = new(TaskCreationOptions.RunContinuationsAsynchronously);
res.Item2.CopyTo(buffer);
return res.Item1;
}

public override Task SendAsync(ArraySegment<byte> buffer, WebSocketMessageType messageType, bool endOfMessage, CancellationToken cancellationToken)
{
return _inner.SendAsync(buffer, messageType, endOfMessage, cancellationToken);
}
}
}
1 change: 1 addition & 0 deletions src/SignalR/server/StackExchangeRedis/test/Startup.cs
Original file line number Diff line number Diff line change
@@ -33,6 +33,7 @@ public void Configure(IApplicationBuilder app)
app.UseEndpoints(endpoints =>
{
endpoints.MapHub<EchoHub>("/echo");
endpoints.MapHub<StatefulHub>("/stateful", o => o.AllowStatefulReconnects = true);
});
}

12 changes: 12 additions & 0 deletions src/SignalR/server/StackExchangeRedis/test/StatefulHub.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.

namespace Microsoft.AspNetCore.SignalR.StackExchangeRedis.Tests;

public class StatefulHub : Hub
{
public Task SendToAll(string message)
{
return Clients.All.SendAsync("SendToAll", message);
}
}