1
1
// Licensed to the .NET Foundation under one or more agreements.
2
2
// The .NET Foundation licenses this file to you under the MIT license.
3
3
4
- using System ;
5
- using System . Collections . Generic ;
6
- using System . Threading . Tasks ;
4
+ using System . Net . WebSockets ;
7
5
using Microsoft . AspNetCore . Http . Connections ;
6
+ using Microsoft . AspNetCore . Http . Connections . Client ;
7
+ using Microsoft . AspNetCore . InternalTesting ;
8
8
using Microsoft . AspNetCore . SignalR . Client ;
9
9
using Microsoft . AspNetCore . SignalR . Protocol ;
10
10
using Microsoft . AspNetCore . SignalR . Tests ;
11
- using Microsoft . AspNetCore . InternalTesting ;
12
11
using Microsoft . Extensions . DependencyInjection ;
13
12
using Microsoft . Extensions . Logging ;
14
- using Xunit ;
15
13
16
14
namespace Microsoft . AspNetCore . SignalR . StackExchangeRedis . Tests ;
17
15
@@ -213,7 +211,105 @@ public async Task CanSendAndReceiveUserMessagesUserNameWithPatternIsTreatedAsLit
213
211
}
214
212
}
215
213
216
- private static HubConnection CreateConnection ( string url , HttpTransportType transportType , IHubProtocol protocol , ILoggerFactory loggerFactory , string userName = null )
214
+ [ ConditionalTheory ]
215
+ [ SkipIfDockerNotPresent ]
216
+ [ InlineData ( "messagepack" ) ]
217
+ [ InlineData ( "json" ) ]
218
+ public async Task StatefulReconnectPreservesMessageFromOtherServer ( string protocolName )
219
+ {
220
+ using ( StartVerifiableLog ( ) )
221
+ {
222
+ var protocol = HubProtocolHelpers . GetHubProtocol ( protocolName ) ;
223
+
224
+ ClientWebSocket innerWs = null ;
225
+ WebSocketWrapper ws = null ;
226
+ TaskCompletionSource reconnectTcs = null ;
227
+ TaskCompletionSource startedReconnectTcs = null ;
228
+
229
+ var connection = CreateConnection ( _serverFixture . FirstServer . Url + "/stateful" , HttpTransportType . WebSockets , protocol , LoggerFactory ,
230
+ customizeConnection : builder =>
231
+ {
232
+ builder . WithStatefulReconnect ( ) ;
233
+ builder . Services . Configure < HttpConnectionOptions > ( o =>
234
+ {
235
+ // Replace the websocket creation for the first connection so we can make the client think there was an ungraceful closure
236
+ // Which will trigger the stateful reconnect flow
237
+ o . WebSocketFactory = async ( context , token ) =>
238
+ {
239
+ if ( reconnectTcs is null )
240
+ {
241
+ reconnectTcs = new TaskCompletionSource ( TaskCreationOptions . RunContinuationsAsynchronously ) ;
242
+ startedReconnectTcs = new TaskCompletionSource ( TaskCreationOptions . RunContinuationsAsynchronously ) ;
243
+ }
244
+ else
245
+ {
246
+ startedReconnectTcs . SetResult ( ) ;
247
+ // We only want to wait on the reconnect, not the initial connection attempt
248
+ await reconnectTcs . Task . DefaultTimeout ( ) ;
249
+ }
250
+
251
+ innerWs = new ClientWebSocket ( ) ;
252
+ ws = new WebSocketWrapper ( innerWs ) ;
253
+ await innerWs . ConnectAsync ( context . Uri , token ) ;
254
+
255
+ _ = Task . Run ( async ( ) =>
256
+ {
257
+ try
258
+ {
259
+ while ( innerWs . State == WebSocketState . Open )
260
+ {
261
+ var buffer = new byte [ 1024 ] ;
262
+ var res = await innerWs . ReceiveAsync ( buffer , default ) ;
263
+ ws . SetReceiveResult ( ( res , buffer . AsMemory ( 0 , res . Count ) ) ) ;
264
+ }
265
+ }
266
+ // Log but ignore receive errors, that likely just means the connection closed
267
+ catch ( Exception ex )
268
+ {
269
+ Logger . LogInformation ( ex , "Error while reading from inner websocket" ) ;
270
+ }
271
+ } ) ;
272
+
273
+ return ws ;
274
+ } ;
275
+ } ) ;
276
+ } ) ;
277
+ var secondConnection = CreateConnection ( _serverFixture . SecondServer . Url + "/stateful" , HttpTransportType . WebSockets , protocol , LoggerFactory ) ;
278
+
279
+ var tcs = new TaskCompletionSource < string > ( ) ;
280
+ connection . On < string > ( "SendToAll" , message => tcs . TrySetResult ( message ) ) ;
281
+
282
+ var tcs2 = new TaskCompletionSource < string > ( ) ;
283
+ secondConnection . On < string > ( "SendToAll" , message => tcs2 . TrySetResult ( message ) ) ;
284
+
285
+ await connection . StartAsync ( ) . DefaultTimeout ( ) ;
286
+ await secondConnection . StartAsync ( ) . DefaultTimeout ( ) ;
287
+
288
+ // Close first connection before the second connection sends a message to all clients
289
+ await ws . CloseOutputAsync ( WebSocketCloseStatus . InternalServerError , statusDescription : null , default ) ;
290
+ await startedReconnectTcs . Task . DefaultTimeout ( ) ;
291
+
292
+ // Send to all clients, since both clients are on different servers this means the backplane will be used
293
+ // And we want to test that messages are still preserved for stateful reconnect purposes when a client disconnects
294
+ // But is on a different server from the original message sender.
295
+ await secondConnection . SendAsync ( "SendToAll" , "test message" ) . DefaultTimeout ( ) ;
296
+
297
+ // Check that second connection still receives the message
298
+ Assert . Equal ( "test message" , await tcs2 . Task . DefaultTimeout ( ) ) ;
299
+ Assert . False ( tcs . Task . IsCompleted ) ;
300
+
301
+ // allow first connection to reconnect
302
+ reconnectTcs . SetResult ( ) ;
303
+
304
+ // Check that first connection received the message once it reconnected
305
+ Assert . Equal ( "test message" , await tcs . Task . DefaultTimeout ( ) ) ;
306
+
307
+ await connection . DisposeAsync ( ) . DefaultTimeout ( ) ;
308
+ }
309
+ }
310
+
311
+ private static HubConnection CreateConnection ( string url , HttpTransportType transportType , IHubProtocol protocol , ILoggerFactory loggerFactory , string userName = null ,
312
+ Action < IHubConnectionBuilder > customizeConnection = null )
217
313
{
218
314
var hubConnectionBuilder = new HubConnectionBuilder ( )
219
315
. WithLoggerFactory ( loggerFactory )
@@ -227,6 +323,8 @@ private static HubConnection CreateConnection(string url, HttpTransportType tran
227
323
228
324
hubConnectionBuilder . Services . AddSingleton ( protocol ) ;
229
325
326
+ customizeConnection ? . Invoke ( hubConnectionBuilder ) ;
327
+
230
328
return hubConnectionBuilder . Build ( ) ;
231
329
}
232
330
@@ -255,4 +353,67 @@ public static IEnumerable<object[]> TransportTypesAndProtocolTypes
255
353
}
256
354
}
257
355
}
356
+
357
+ internal sealed class WebSocketWrapper : WebSocket
358
+ {
359
+ private readonly WebSocket _inner ;
360
+ private TaskCompletionSource < ( WebSocketReceiveResult , ReadOnlyMemory < byte > ) > _receiveTcs = new ( TaskCreationOptions . RunContinuationsAsynchronously ) ;
361
+
362
+ public WebSocketWrapper ( WebSocket inner )
363
+ {
364
+ _inner = inner ;
365
+ }
366
+
367
+ public override WebSocketCloseStatus ? CloseStatus => _inner . CloseStatus ;
368
+
369
+ public override string CloseStatusDescription => _inner . CloseStatusDescription ;
370
+
371
+ public override WebSocketState State => _inner . State ;
372
+
373
+ public override string SubProtocol => _inner . SubProtocol ;
374
+
375
+ public override void Abort ( )
376
+ {
377
+ _inner . Abort ( ) ;
378
+ }
379
+
380
+ public override Task CloseAsync ( WebSocketCloseStatus closeStatus , string statusDescription , CancellationToken cancellationToken )
381
+ {
382
+ return _inner . CloseAsync ( closeStatus , statusDescription , cancellationToken ) ;
383
+ }
384
+
385
+ public override Task CloseOutputAsync ( WebSocketCloseStatus closeStatus , string statusDescription , CancellationToken cancellationToken )
386
+ {
387
+ _receiveTcs . TrySetException ( new IOException ( "force reconnect" ) ) ;
388
+ return Task . CompletedTask ;
389
+ }
390
+
391
+ public override void Dispose ( )
392
+ {
393
+ _inner . Dispose ( ) ;
394
+ }
395
+
396
+ public void SetReceiveResult ( ( WebSocketReceiveResult , ReadOnlyMemory < byte > ) result )
397
+ {
398
+ _receiveTcs . SetResult ( result ) ;
399
+ }
400
+
401
+ public override async Task < WebSocketReceiveResult > ReceiveAsync ( ArraySegment < byte > buffer , CancellationToken cancellationToken )
402
+ {
403
+ var res = await _receiveTcs . Task ;
404
+ // Handle zero-byte reads
405
+ if ( buffer . Count == 0 )
406
+ {
407
+ return res . Item1 ;
408
+ }
409
+ _receiveTcs = new ( TaskCreationOptions . RunContinuationsAsynchronously ) ;
410
+ res . Item2 . CopyTo ( buffer ) ;
411
+ return res . Item1 ;
412
+ }
413
+
414
+ public override Task SendAsync ( ArraySegment < byte > buffer , WebSocketMessageType messageType , bool endOfMessage , CancellationToken cancellationToken )
415
+ {
416
+ return _inner . SendAsync ( buffer , messageType , endOfMessage , cancellationToken ) ;
417
+ }
418
+ }
258
419
}
0 commit comments