Skip to content

Commit fdcc5b2

Browse files
committed
Fix preserving messages for stateful reconnect with backplane (#60900)
1 parent 3595ffc commit fdcc5b2

File tree

6 files changed

+244
-9
lines changed

6 files changed

+244
-9
lines changed

src/SignalR/common/Shared/MessageBuffer.cs

+5-4
Original file line numberDiff line numberDiff line change
@@ -121,15 +121,16 @@ private async Task RunTimer()
121121

122122
public ValueTask<FlushResult> WriteAsync(SerializedHubMessage hubMessage, CancellationToken cancellationToken)
123123
{
124-
return WriteAsyncCore(hubMessage.Message!, hubMessage.GetSerializedMessage(_protocol), cancellationToken);
124+
// Default to HubInvocationMessage as that's the only type we use SerializedHubMessage for currently when Message is null. Should harden this in the future.
125+
return WriteAsyncCore(hubMessage.Message?.GetType() ?? typeof(HubInvocationMessage), hubMessage.GetSerializedMessage(_protocol), cancellationToken);
125126
}
126127

127128
public ValueTask<FlushResult> WriteAsync(HubMessage hubMessage, CancellationToken cancellationToken)
128129
{
129-
return WriteAsyncCore(hubMessage, _protocol.GetMessageBytes(hubMessage), cancellationToken);
130+
return WriteAsyncCore(hubMessage.GetType(), _protocol.GetMessageBytes(hubMessage), cancellationToken);
130131
}
131132

132-
private async ValueTask<FlushResult> WriteAsyncCore(HubMessage hubMessage, ReadOnlyMemory<byte> messageBytes, CancellationToken cancellationToken)
133+
private async ValueTask<FlushResult> WriteAsyncCore(Type hubMessageType, ReadOnlyMemory<byte> messageBytes, CancellationToken cancellationToken)
133134
{
134135
// TODO: Add backpressure based on message count
135136
if (_bufferedByteCount > _bufferLimit)
@@ -158,7 +159,7 @@ private async ValueTask<FlushResult> WriteAsyncCore(HubMessage hubMessage, ReadO
158159
await _writeLock.WaitAsync(cancellationToken: default).ConfigureAwait(false);
159160
try
160161
{
161-
if (hubMessage is HubInvocationMessage invocationMessage)
162+
if (typeof(HubInvocationMessage).IsAssignableFrom(hubMessageType))
162163
{
163164
_totalMessageCount++;
164165
_bufferedByteCount += messageBytes.Length;

src/SignalR/server/Core/src/SerializedHubMessage.cs

+3
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
// Licensed to the .NET Foundation under one or more agreements.
22
// The .NET Foundation licenses this file to you under the MIT license.
33

4+
using System.Diagnostics;
45
using Microsoft.AspNetCore.SignalR.Protocol;
56

67
namespace Microsoft.AspNetCore.SignalR;
@@ -40,6 +41,8 @@ public SerializedHubMessage(IReadOnlyList<SerializedMessage> messages)
4041
/// <param name="message">The hub message for the cache. This will be serialized with an <see cref="IHubProtocol"/> in <see cref="GetSerializedMessage"/> to get the message's serialized representation.</param>
4142
public SerializedHubMessage(HubMessage message)
4243
{
44+
// Type currently only used for invocation messages, we should probably refactor it to be explicit about that e.g. new property for message type?
45+
Debug.Assert(message.GetType().IsAssignableTo(typeof(HubInvocationMessage)));
4346
Message = message;
4447
}
4548

src/SignalR/server/SignalR/test/Internal/MessageBufferTests.cs

+57
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
// The .NET Foundation licenses this file to you under the MIT license.
33

44
using System.IO.Pipelines;
5+
using System.Text.Json;
56
using Microsoft.AspNetCore.Connections;
67
using Microsoft.AspNetCore.Http.Features;
78
using Microsoft.AspNetCore.SignalR.Internal;
@@ -169,6 +170,62 @@ public async Task UnAckedMessageResentOnReconnect()
169170
Assert.False(messageBuffer.ShouldProcessMessage(CompletionMessage.WithResult("1", null)));
170171
}
171172

173+
// Regression test for https://github.com/dotnet/aspnetcore/issues/55575
174+
[Fact]
175+
public async Task UnAckedSerializedMessageResentOnReconnect()
176+
{
177+
var protocol = new JsonHubProtocol();
178+
var connection = new TestConnectionContext();
179+
var pipes = DuplexPipe.CreateConnectionPair(new PipeOptions(), new PipeOptions());
180+
connection.Transport = pipes.Transport;
181+
using var messageBuffer = new MessageBuffer(connection, protocol, bufferLimit: 1000, NullLogger.Instance);
182+
183+
var invocationMessage = new SerializedHubMessage([new SerializedMessage(protocol.Name,
184+
protocol.GetMessageBytes(new InvocationMessage("method1", [1])))]);
185+
await messageBuffer.WriteAsync(invocationMessage, default);
186+
187+
var res = await pipes.Application.Input.ReadAsync();
188+
189+
var buffer = res.Buffer;
190+
Assert.True(protocol.TryParseMessage(ref buffer, new TestBinder(), out var message));
191+
var parsedMessage = Assert.IsType<InvocationMessage>(message);
192+
Assert.Equal("method1", parsedMessage.Target);
193+
Assert.Equal(1, ((JsonElement)Assert.Single(parsedMessage.Arguments)).GetInt32());
194+
195+
pipes.Application.Input.AdvanceTo(buffer.Start);
196+
197+
DuplexPipe.UpdateConnectionPair(ref pipes, connection);
198+
await messageBuffer.ResendAsync(pipes.Transport.Output);
199+
200+
Assert.True(messageBuffer.ShouldProcessMessage(PingMessage.Instance));
201+
Assert.True(messageBuffer.ShouldProcessMessage(CompletionMessage.WithResult("1", null)));
202+
Assert.True(messageBuffer.ShouldProcessMessage(new SequenceMessage(1)));
203+
204+
res = await pipes.Application.Input.ReadAsync();
205+
206+
buffer = res.Buffer;
207+
Assert.True(protocol.TryParseMessage(ref buffer, new TestBinder(), out message));
208+
var seqMessage = Assert.IsType<SequenceMessage>(message);
209+
Assert.Equal(1, seqMessage.SequenceId);
210+
211+
pipes.Application.Input.AdvanceTo(buffer.Start);
212+
213+
res = await pipes.Application.Input.ReadAsync();
214+
215+
buffer = res.Buffer;
216+
Assert.True(protocol.TryParseMessage(ref buffer, new TestBinder(), out message));
217+
parsedMessage = Assert.IsType<InvocationMessage>(message);
218+
Assert.Equal("method1", parsedMessage.Target);
219+
Assert.Equal(1, ((JsonElement)Assert.Single(parsedMessage.Arguments)).GetInt32());
220+
221+
pipes.Application.Input.AdvanceTo(buffer.Start);
222+
223+
messageBuffer.ShouldProcessMessage(new SequenceMessage(1));
224+
225+
Assert.True(messageBuffer.ShouldProcessMessage(PingMessage.Instance));
226+
Assert.False(messageBuffer.ShouldProcessMessage(CompletionMessage.WithResult("1", null)));
227+
}
228+
172229
[Fact]
173230
public async Task AckedMessageNotResentOnReconnect()
174231
{

src/SignalR/server/StackExchangeRedis/test/RedisEndToEnd.cs

+166-5
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,15 @@
11
// Licensed to the .NET Foundation under one or more agreements.
22
// The .NET Foundation licenses this file to you under the MIT license.
33

4-
using System;
5-
using System.Collections.Generic;
6-
using System.Threading.Tasks;
4+
using System.Net.WebSockets;
75
using Microsoft.AspNetCore.Http.Connections;
6+
using Microsoft.AspNetCore.Http.Connections.Client;
87
using Microsoft.AspNetCore.SignalR.Client;
98
using Microsoft.AspNetCore.SignalR.Protocol;
109
using Microsoft.AspNetCore.SignalR.Tests;
1110
using Microsoft.AspNetCore.Testing;
1211
using Microsoft.Extensions.DependencyInjection;
1312
using Microsoft.Extensions.Logging;
14-
using Xunit;
1513

1614
namespace Microsoft.AspNetCore.SignalR.StackExchangeRedis.Tests;
1715

@@ -211,7 +209,105 @@ public async Task CanSendAndReceiveUserMessagesUserNameWithPatternIsTreatedAsLit
211209
}
212210
}
213211

214-
private static HubConnection CreateConnection(string url, HttpTransportType transportType, IHubProtocol protocol, ILoggerFactory loggerFactory, string userName = null)
212+
[ConditionalTheory]
213+
[SkipIfDockerNotPresent]
214+
[InlineData("messagepack")]
215+
[InlineData("json")]
216+
public async Task StatefulReconnectPreservesMessageFromOtherServer(string protocolName)
217+
{
218+
using (StartVerifiableLog())
219+
{
220+
var protocol = HubProtocolHelpers.GetHubProtocol(protocolName);
221+
222+
ClientWebSocket innerWs = null;
223+
WebSocketWrapper ws = null;
224+
TaskCompletionSource reconnectTcs = null;
225+
TaskCompletionSource startedReconnectTcs = null;
226+
227+
var connection = CreateConnection(_serverFixture.FirstServer.Url + "/stateful", HttpTransportType.WebSockets, protocol, LoggerFactory,
228+
customizeConnection: builder =>
229+
{
230+
builder.WithStatefulReconnect();
231+
builder.Services.Configure<HttpConnectionOptions>(o =>
232+
{
233+
// Replace the websocket creation for the first connection so we can make the client think there was an ungraceful closure
234+
// Which will trigger the stateful reconnect flow
235+
o.WebSocketFactory = async (context, token) =>
236+
{
237+
if (reconnectTcs is null)
238+
{
239+
reconnectTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously);
240+
startedReconnectTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously);
241+
}
242+
else
243+
{
244+
startedReconnectTcs.SetResult();
245+
// We only want to wait on the reconnect, not the initial connection attempt
246+
await reconnectTcs.Task.DefaultTimeout();
247+
}
248+
249+
innerWs = new ClientWebSocket();
250+
ws = new WebSocketWrapper(innerWs);
251+
await innerWs.ConnectAsync(context.Uri, token);
252+
253+
_ = Task.Run(async () =>
254+
{
255+
try
256+
{
257+
while (innerWs.State == WebSocketState.Open)
258+
{
259+
var buffer = new byte[1024];
260+
var res = await innerWs.ReceiveAsync(buffer, default);
261+
ws.SetReceiveResult((res, buffer.AsMemory(0, res.Count)));
262+
}
263+
}
264+
// Log but ignore receive errors, that likely just means the connection closed
265+
catch (Exception ex)
266+
{
267+
Logger.LogInformation(ex, "Error while reading from inner websocket");
268+
}
269+
});
270+
271+
return ws;
272+
};
273+
});
274+
});
275+
var secondConnection = CreateConnection(_serverFixture.SecondServer.Url + "/stateful", HttpTransportType.WebSockets, protocol, LoggerFactory);
276+
277+
var tcs = new TaskCompletionSource<string>();
278+
connection.On<string>("SendToAll", message => tcs.TrySetResult(message));
279+
280+
var tcs2 = new TaskCompletionSource<string>();
281+
secondConnection.On<string>("SendToAll", message => tcs2.TrySetResult(message));
282+
283+
await connection.StartAsync().DefaultTimeout();
284+
await secondConnection.StartAsync().DefaultTimeout();
285+
286+
// Close first connection before the second connection sends a message to all clients
287+
await ws.CloseOutputAsync(WebSocketCloseStatus.InternalServerError, statusDescription: null, default);
288+
await startedReconnectTcs.Task.DefaultTimeout();
289+
290+
// Send to all clients, since both clients are on different servers this means the backplane will be used
291+
// And we want to test that messages are still preserved for stateful reconnect purposes when a client disconnects
292+
// But is on a different server from the original message sender.
293+
await secondConnection.SendAsync("SendToAll", "test message").DefaultTimeout();
294+
295+
// Check that second connection still receives the message
296+
Assert.Equal("test message", await tcs2.Task.DefaultTimeout());
297+
Assert.False(tcs.Task.IsCompleted);
298+
299+
// allow first connection to reconnect
300+
reconnectTcs.SetResult();
301+
302+
// Check that first connection received the message once it reconnected
303+
Assert.Equal("test message", await tcs.Task.DefaultTimeout());
304+
305+
await connection.DisposeAsync().DefaultTimeout();
306+
}
307+
}
308+
309+
private static HubConnection CreateConnection(string url, HttpTransportType transportType, IHubProtocol protocol, ILoggerFactory loggerFactory, string userName = null,
310+
Action<IHubConnectionBuilder> customizeConnection = null)
215311
{
216312
var hubConnectionBuilder = new HubConnectionBuilder()
217313
.WithLoggerFactory(loggerFactory)
@@ -225,6 +321,8 @@ private static HubConnection CreateConnection(string url, HttpTransportType tran
225321

226322
hubConnectionBuilder.Services.AddSingleton(protocol);
227323

324+
customizeConnection?.Invoke(hubConnectionBuilder);
325+
228326
return hubConnectionBuilder.Build();
229327
}
230328

@@ -253,4 +351,67 @@ public static IEnumerable<object[]> TransportTypesAndProtocolTypes
253351
}
254352
}
255353
}
354+
355+
internal sealed class WebSocketWrapper : WebSocket
356+
{
357+
private readonly WebSocket _inner;
358+
private TaskCompletionSource<(WebSocketReceiveResult, ReadOnlyMemory<byte>)> _receiveTcs = new(TaskCreationOptions.RunContinuationsAsynchronously);
359+
360+
public WebSocketWrapper(WebSocket inner)
361+
{
362+
_inner = inner;
363+
}
364+
365+
public override WebSocketCloseStatus? CloseStatus => _inner.CloseStatus;
366+
367+
public override string CloseStatusDescription => _inner.CloseStatusDescription;
368+
369+
public override WebSocketState State => _inner.State;
370+
371+
public override string SubProtocol => _inner.SubProtocol;
372+
373+
public override void Abort()
374+
{
375+
_inner.Abort();
376+
}
377+
378+
public override Task CloseAsync(WebSocketCloseStatus closeStatus, string statusDescription, CancellationToken cancellationToken)
379+
{
380+
return _inner.CloseAsync(closeStatus, statusDescription, cancellationToken);
381+
}
382+
383+
public override Task CloseOutputAsync(WebSocketCloseStatus closeStatus, string statusDescription, CancellationToken cancellationToken)
384+
{
385+
_receiveTcs.TrySetException(new IOException("force reconnect"));
386+
return Task.CompletedTask;
387+
}
388+
389+
public override void Dispose()
390+
{
391+
_inner.Dispose();
392+
}
393+
394+
public void SetReceiveResult((WebSocketReceiveResult, ReadOnlyMemory<byte>) result)
395+
{
396+
_receiveTcs.SetResult(result);
397+
}
398+
399+
public override async Task<WebSocketReceiveResult> ReceiveAsync(ArraySegment<byte> buffer, CancellationToken cancellationToken)
400+
{
401+
var res = await _receiveTcs.Task;
402+
// Handle zero-byte reads
403+
if (buffer.Count == 0)
404+
{
405+
return res.Item1;
406+
}
407+
_receiveTcs = new(TaskCreationOptions.RunContinuationsAsynchronously);
408+
res.Item2.CopyTo(buffer);
409+
return res.Item1;
410+
}
411+
412+
public override Task SendAsync(ArraySegment<byte> buffer, WebSocketMessageType messageType, bool endOfMessage, CancellationToken cancellationToken)
413+
{
414+
return _inner.SendAsync(buffer, messageType, endOfMessage, cancellationToken);
415+
}
416+
}
256417
}

src/SignalR/server/StackExchangeRedis/test/Startup.cs

+1
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ public void Configure(IApplicationBuilder app)
3333
app.UseEndpoints(endpoints =>
3434
{
3535
endpoints.MapHub<EchoHub>("/echo");
36+
endpoints.MapHub<StatefulHub>("/stateful", o => o.AllowStatefulReconnects = true);
3637
});
3738
}
3839

Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
// Licensed to the .NET Foundation under one or more agreements.
2+
// The .NET Foundation licenses this file to you under the MIT license.
3+
4+
namespace Microsoft.AspNetCore.SignalR.StackExchangeRedis.Tests;
5+
6+
public class StatefulHub : Hub
7+
{
8+
public Task SendToAll(string message)
9+
{
10+
return Clients.All.SendAsync("SendToAll", message);
11+
}
12+
}

0 commit comments

Comments
 (0)