Skip to content

Commit 8491a27

Browse files
Fix preserving messages for stateful reconnect with backplane (#60900)
1 parent 2a723be commit 8491a27

File tree

6 files changed

+246
-11
lines changed

6 files changed

+246
-11
lines changed

src/SignalR/common/Shared/MessageBuffer.cs

+5-4
Original file line numberDiff line numberDiff line change
@@ -121,15 +121,16 @@ private async Task RunTimer()
121121

122122
public ValueTask<FlushResult> WriteAsync(SerializedHubMessage hubMessage, CancellationToken cancellationToken)
123123
{
124-
return WriteAsyncCore(hubMessage.Message!, hubMessage.GetSerializedMessage(_protocol), cancellationToken);
124+
// Default to HubInvocationMessage as that's the only type we use SerializedHubMessage for currently when Message is null. Should harden this in the future.
125+
return WriteAsyncCore(hubMessage.Message?.GetType() ?? typeof(HubInvocationMessage), hubMessage.GetSerializedMessage(_protocol), cancellationToken);
125126
}
126127

127128
public ValueTask<FlushResult> WriteAsync(HubMessage hubMessage, CancellationToken cancellationToken)
128129
{
129-
return WriteAsyncCore(hubMessage, _protocol.GetMessageBytes(hubMessage), cancellationToken);
130+
return WriteAsyncCore(hubMessage.GetType(), _protocol.GetMessageBytes(hubMessage), cancellationToken);
130131
}
131132

132-
private async ValueTask<FlushResult> WriteAsyncCore(HubMessage hubMessage, ReadOnlyMemory<byte> messageBytes, CancellationToken cancellationToken)
133+
private async ValueTask<FlushResult> WriteAsyncCore(Type hubMessageType, ReadOnlyMemory<byte> messageBytes, CancellationToken cancellationToken)
133134
{
134135
// TODO: Add backpressure based on message count
135136
if (_bufferedByteCount > _bufferLimit)
@@ -158,7 +159,7 @@ private async ValueTask<FlushResult> WriteAsyncCore(HubMessage hubMessage, ReadO
158159
await _writeLock.WaitAsync(cancellationToken: default).ConfigureAwait(false);
159160
try
160161
{
161-
if (hubMessage is HubInvocationMessage invocationMessage)
162+
if (typeof(HubInvocationMessage).IsAssignableFrom(hubMessageType))
162163
{
163164
_totalMessageCount++;
164165
_bufferedByteCount += messageBytes.Length;

src/SignalR/server/Core/src/SerializedHubMessage.cs

+3
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
// Licensed to the .NET Foundation under one or more agreements.
22
// The .NET Foundation licenses this file to you under the MIT license.
33

4+
using System.Diagnostics;
45
using Microsoft.AspNetCore.SignalR.Protocol;
56

67
namespace Microsoft.AspNetCore.SignalR;
@@ -40,6 +41,8 @@ public SerializedHubMessage(IReadOnlyList<SerializedMessage> messages)
4041
/// <param name="message">The hub message for the cache. This will be serialized with an <see cref="IHubProtocol"/> in <see cref="GetSerializedMessage"/> to get the message's serialized representation.</param>
4142
public SerializedHubMessage(HubMessage message)
4243
{
44+
// Type currently only used for invocation messages, we should probably refactor it to be explicit about that e.g. new property for message type?
45+
Debug.Assert(message.GetType().IsAssignableTo(typeof(HubInvocationMessage)));
4346
Message = message;
4447
}
4548

src/SignalR/server/SignalR/test/Microsoft.AspNetCore.SignalR.Tests/Internal/MessageBufferTests.cs

+58-1
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,12 @@
22
// The .NET Foundation licenses this file to you under the MIT license.
33

44
using System.IO.Pipelines;
5+
using System.Text.Json;
56
using Microsoft.AspNetCore.Connections;
67
using Microsoft.AspNetCore.Http.Features;
8+
using Microsoft.AspNetCore.InternalTesting;
79
using Microsoft.AspNetCore.SignalR.Internal;
810
using Microsoft.AspNetCore.SignalR.Protocol;
9-
using Microsoft.AspNetCore.InternalTesting;
1011
using Microsoft.Extensions.Logging.Abstractions;
1112
using Microsoft.Extensions.Time.Testing;
1213

@@ -169,6 +170,62 @@ public async Task UnAckedMessageResentOnReconnect()
169170
Assert.False(messageBuffer.ShouldProcessMessage(CompletionMessage.WithResult("1", null)));
170171
}
171172

173+
// Regression test for https://github.com/dotnet/aspnetcore/issues/55575
174+
[Fact]
175+
public async Task UnAckedSerializedMessageResentOnReconnect()
176+
{
177+
var protocol = new JsonHubProtocol();
178+
var connection = new TestConnectionContext();
179+
var pipes = DuplexPipe.CreateConnectionPair(new PipeOptions(), new PipeOptions());
180+
connection.Transport = pipes.Transport;
181+
using var messageBuffer = new MessageBuffer(connection, protocol, bufferLimit: 1000, NullLogger.Instance);
182+
183+
var invocationMessage = new SerializedHubMessage([new SerializedMessage(protocol.Name,
184+
protocol.GetMessageBytes(new InvocationMessage("method1", [1])))]);
185+
await messageBuffer.WriteAsync(invocationMessage, default);
186+
187+
var res = await pipes.Application.Input.ReadAsync();
188+
189+
var buffer = res.Buffer;
190+
Assert.True(protocol.TryParseMessage(ref buffer, new TestBinder(), out var message));
191+
var parsedMessage = Assert.IsType<InvocationMessage>(message);
192+
Assert.Equal("method1", parsedMessage.Target);
193+
Assert.Equal(1, ((JsonElement)Assert.Single(parsedMessage.Arguments)).GetInt32());
194+
195+
pipes.Application.Input.AdvanceTo(buffer.Start);
196+
197+
DuplexPipe.UpdateConnectionPair(ref pipes, connection);
198+
await messageBuffer.ResendAsync(pipes.Transport.Output);
199+
200+
Assert.True(messageBuffer.ShouldProcessMessage(PingMessage.Instance));
201+
Assert.True(messageBuffer.ShouldProcessMessage(CompletionMessage.WithResult("1", null)));
202+
Assert.True(messageBuffer.ShouldProcessMessage(new SequenceMessage(1)));
203+
204+
res = await pipes.Application.Input.ReadAsync();
205+
206+
buffer = res.Buffer;
207+
Assert.True(protocol.TryParseMessage(ref buffer, new TestBinder(), out message));
208+
var seqMessage = Assert.IsType<SequenceMessage>(message);
209+
Assert.Equal(1, seqMessage.SequenceId);
210+
211+
pipes.Application.Input.AdvanceTo(buffer.Start);
212+
213+
res = await pipes.Application.Input.ReadAsync();
214+
215+
buffer = res.Buffer;
216+
Assert.True(protocol.TryParseMessage(ref buffer, new TestBinder(), out message));
217+
parsedMessage = Assert.IsType<InvocationMessage>(message);
218+
Assert.Equal("method1", parsedMessage.Target);
219+
Assert.Equal(1, ((JsonElement)Assert.Single(parsedMessage.Arguments)).GetInt32());
220+
221+
pipes.Application.Input.AdvanceTo(buffer.Start);
222+
223+
messageBuffer.ShouldProcessMessage(new SequenceMessage(1));
224+
225+
Assert.True(messageBuffer.ShouldProcessMessage(PingMessage.Instance));
226+
Assert.False(messageBuffer.ShouldProcessMessage(CompletionMessage.WithResult("1", null)));
227+
}
228+
172229
[Fact]
173230
public async Task AckedMessageNotResentOnReconnect()
174231
{

src/SignalR/server/StackExchangeRedis/test/RedisEndToEnd.cs

+167-6
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,15 @@
11
// Licensed to the .NET Foundation under one or more agreements.
22
// The .NET Foundation licenses this file to you under the MIT license.
33

4-
using System;
5-
using System.Collections.Generic;
6-
using System.Threading.Tasks;
4+
using System.Net.WebSockets;
75
using Microsoft.AspNetCore.Http.Connections;
6+
using Microsoft.AspNetCore.Http.Connections.Client;
7+
using Microsoft.AspNetCore.InternalTesting;
88
using Microsoft.AspNetCore.SignalR.Client;
99
using Microsoft.AspNetCore.SignalR.Protocol;
1010
using Microsoft.AspNetCore.SignalR.Tests;
11-
using Microsoft.AspNetCore.InternalTesting;
1211
using Microsoft.Extensions.DependencyInjection;
1312
using Microsoft.Extensions.Logging;
14-
using Xunit;
1513

1614
namespace Microsoft.AspNetCore.SignalR.StackExchangeRedis.Tests;
1715

@@ -213,7 +211,105 @@ public async Task CanSendAndReceiveUserMessagesUserNameWithPatternIsTreatedAsLit
213211
}
214212
}
215213

216-
private static HubConnection CreateConnection(string url, HttpTransportType transportType, IHubProtocol protocol, ILoggerFactory loggerFactory, string userName = null)
214+
[ConditionalTheory]
215+
[SkipIfDockerNotPresent]
216+
[InlineData("messagepack")]
217+
[InlineData("json")]
218+
public async Task StatefulReconnectPreservesMessageFromOtherServer(string protocolName)
219+
{
220+
using (StartVerifiableLog())
221+
{
222+
var protocol = HubProtocolHelpers.GetHubProtocol(protocolName);
223+
224+
ClientWebSocket innerWs = null;
225+
WebSocketWrapper ws = null;
226+
TaskCompletionSource reconnectTcs = null;
227+
TaskCompletionSource startedReconnectTcs = null;
228+
229+
var connection = CreateConnection(_serverFixture.FirstServer.Url + "/stateful", HttpTransportType.WebSockets, protocol, LoggerFactory,
230+
customizeConnection: builder =>
231+
{
232+
builder.WithStatefulReconnect();
233+
builder.Services.Configure<HttpConnectionOptions>(o =>
234+
{
235+
// Replace the websocket creation for the first connection so we can make the client think there was an ungraceful closure
236+
// Which will trigger the stateful reconnect flow
237+
o.WebSocketFactory = async (context, token) =>
238+
{
239+
if (reconnectTcs is null)
240+
{
241+
reconnectTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously);
242+
startedReconnectTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously);
243+
}
244+
else
245+
{
246+
startedReconnectTcs.SetResult();
247+
// We only want to wait on the reconnect, not the initial connection attempt
248+
await reconnectTcs.Task.DefaultTimeout();
249+
}
250+
251+
innerWs = new ClientWebSocket();
252+
ws = new WebSocketWrapper(innerWs);
253+
await innerWs.ConnectAsync(context.Uri, token);
254+
255+
_ = Task.Run(async () =>
256+
{
257+
try
258+
{
259+
while (innerWs.State == WebSocketState.Open)
260+
{
261+
var buffer = new byte[1024];
262+
var res = await innerWs.ReceiveAsync(buffer, default);
263+
ws.SetReceiveResult((res, buffer.AsMemory(0, res.Count)));
264+
}
265+
}
266+
// Log but ignore receive errors, that likely just means the connection closed
267+
catch (Exception ex)
268+
{
269+
Logger.LogInformation(ex, "Error while reading from inner websocket");
270+
}
271+
});
272+
273+
return ws;
274+
};
275+
});
276+
});
277+
var secondConnection = CreateConnection(_serverFixture.SecondServer.Url + "/stateful", HttpTransportType.WebSockets, protocol, LoggerFactory);
278+
279+
var tcs = new TaskCompletionSource<string>();
280+
connection.On<string>("SendToAll", message => tcs.TrySetResult(message));
281+
282+
var tcs2 = new TaskCompletionSource<string>();
283+
secondConnection.On<string>("SendToAll", message => tcs2.TrySetResult(message));
284+
285+
await connection.StartAsync().DefaultTimeout();
286+
await secondConnection.StartAsync().DefaultTimeout();
287+
288+
// Close first connection before the second connection sends a message to all clients
289+
await ws.CloseOutputAsync(WebSocketCloseStatus.InternalServerError, statusDescription: null, default);
290+
await startedReconnectTcs.Task.DefaultTimeout();
291+
292+
// Send to all clients, since both clients are on different servers this means the backplane will be used
293+
// And we want to test that messages are still preserved for stateful reconnect purposes when a client disconnects
294+
// But is on a different server from the original message sender.
295+
await secondConnection.SendAsync("SendToAll", "test message").DefaultTimeout();
296+
297+
// Check that second connection still receives the message
298+
Assert.Equal("test message", await tcs2.Task.DefaultTimeout());
299+
Assert.False(tcs.Task.IsCompleted);
300+
301+
// allow first connection to reconnect
302+
reconnectTcs.SetResult();
303+
304+
// Check that first connection received the message once it reconnected
305+
Assert.Equal("test message", await tcs.Task.DefaultTimeout());
306+
307+
await connection.DisposeAsync().DefaultTimeout();
308+
}
309+
}
310+
311+
private static HubConnection CreateConnection(string url, HttpTransportType transportType, IHubProtocol protocol, ILoggerFactory loggerFactory, string userName = null,
312+
Action<IHubConnectionBuilder> customizeConnection = null)
217313
{
218314
var hubConnectionBuilder = new HubConnectionBuilder()
219315
.WithLoggerFactory(loggerFactory)
@@ -227,6 +323,8 @@ private static HubConnection CreateConnection(string url, HttpTransportType tran
227323

228324
hubConnectionBuilder.Services.AddSingleton(protocol);
229325

326+
customizeConnection?.Invoke(hubConnectionBuilder);
327+
230328
return hubConnectionBuilder.Build();
231329
}
232330

@@ -255,4 +353,67 @@ public static IEnumerable<object[]> TransportTypesAndProtocolTypes
255353
}
256354
}
257355
}
356+
357+
internal sealed class WebSocketWrapper : WebSocket
358+
{
359+
private readonly WebSocket _inner;
360+
private TaskCompletionSource<(WebSocketReceiveResult, ReadOnlyMemory<byte>)> _receiveTcs = new(TaskCreationOptions.RunContinuationsAsynchronously);
361+
362+
public WebSocketWrapper(WebSocket inner)
363+
{
364+
_inner = inner;
365+
}
366+
367+
public override WebSocketCloseStatus? CloseStatus => _inner.CloseStatus;
368+
369+
public override string CloseStatusDescription => _inner.CloseStatusDescription;
370+
371+
public override WebSocketState State => _inner.State;
372+
373+
public override string SubProtocol => _inner.SubProtocol;
374+
375+
public override void Abort()
376+
{
377+
_inner.Abort();
378+
}
379+
380+
public override Task CloseAsync(WebSocketCloseStatus closeStatus, string statusDescription, CancellationToken cancellationToken)
381+
{
382+
return _inner.CloseAsync(closeStatus, statusDescription, cancellationToken);
383+
}
384+
385+
public override Task CloseOutputAsync(WebSocketCloseStatus closeStatus, string statusDescription, CancellationToken cancellationToken)
386+
{
387+
_receiveTcs.TrySetException(new IOException("force reconnect"));
388+
return Task.CompletedTask;
389+
}
390+
391+
public override void Dispose()
392+
{
393+
_inner.Dispose();
394+
}
395+
396+
public void SetReceiveResult((WebSocketReceiveResult, ReadOnlyMemory<byte>) result)
397+
{
398+
_receiveTcs.SetResult(result);
399+
}
400+
401+
public override async Task<WebSocketReceiveResult> ReceiveAsync(ArraySegment<byte> buffer, CancellationToken cancellationToken)
402+
{
403+
var res = await _receiveTcs.Task;
404+
// Handle zero-byte reads
405+
if (buffer.Count == 0)
406+
{
407+
return res.Item1;
408+
}
409+
_receiveTcs = new(TaskCreationOptions.RunContinuationsAsynchronously);
410+
res.Item2.CopyTo(buffer);
411+
return res.Item1;
412+
}
413+
414+
public override Task SendAsync(ArraySegment<byte> buffer, WebSocketMessageType messageType, bool endOfMessage, CancellationToken cancellationToken)
415+
{
416+
return _inner.SendAsync(buffer, messageType, endOfMessage, cancellationToken);
417+
}
418+
}
258419
}

src/SignalR/server/StackExchangeRedis/test/Startup.cs

+1
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ public void Configure(IApplicationBuilder app)
3333
app.UseEndpoints(endpoints =>
3434
{
3535
endpoints.MapHub<EchoHub>("/echo");
36+
endpoints.MapHub<StatefulHub>("/stateful", o => o.AllowStatefulReconnects = true);
3637
});
3738
}
3839

Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
// Licensed to the .NET Foundation under one or more agreements.
2+
// The .NET Foundation licenses this file to you under the MIT license.
3+
4+
namespace Microsoft.AspNetCore.SignalR.StackExchangeRedis.Tests;
5+
6+
public class StatefulHub : Hub
7+
{
8+
public Task SendToAll(string message)
9+
{
10+
return Clients.All.SendAsync("SendToAll", message);
11+
}
12+
}

0 commit comments

Comments
 (0)