1
1
// Licensed to the .NET Foundation under one or more agreements.
2
2
// The .NET Foundation licenses this file to you under the MIT license.
3
3
4
- using System ;
5
- using System . Collections . Generic ;
6
- using System . Threading . Tasks ;
4
+ using System . Net . WebSockets ;
7
5
using Microsoft . AspNetCore . Http . Connections ;
6
+ using Microsoft . AspNetCore . Http . Connections . Client ;
8
7
using Microsoft . AspNetCore . SignalR . Client ;
9
8
using Microsoft . AspNetCore . SignalR . Protocol ;
10
9
using Microsoft . AspNetCore . SignalR . Tests ;
11
10
using Microsoft . AspNetCore . Testing ;
12
11
using Microsoft . Extensions . DependencyInjection ;
13
12
using Microsoft . Extensions . Logging ;
14
- using Xunit ;
15
13
16
14
namespace Microsoft . AspNetCore . SignalR . StackExchangeRedis . Tests ;
17
15
@@ -211,7 +209,105 @@ public async Task CanSendAndReceiveUserMessagesUserNameWithPatternIsTreatedAsLit
211
209
}
212
210
}
213
211
214
- private static HubConnection CreateConnection ( string url , HttpTransportType transportType , IHubProtocol protocol , ILoggerFactory loggerFactory , string userName = null )
212
+ [ ConditionalTheory ]
213
+ [ SkipIfDockerNotPresent ]
214
+ [ InlineData ( "messagepack" ) ]
215
+ [ InlineData ( "json" ) ]
216
+ public async Task StatefulReconnectPreservesMessageFromOtherServer ( string protocolName )
217
+ {
218
+ using ( StartVerifiableLog ( ) )
219
+ {
220
+ var protocol = HubProtocolHelpers . GetHubProtocol ( protocolName ) ;
221
+
222
+ ClientWebSocket innerWs = null ;
223
+ WebSocketWrapper ws = null ;
224
+ TaskCompletionSource reconnectTcs = null ;
225
+ TaskCompletionSource startedReconnectTcs = null ;
226
+
227
+ var connection = CreateConnection ( _serverFixture . FirstServer . Url + "/stateful" , HttpTransportType . WebSockets , protocol , LoggerFactory ,
228
+ customizeConnection : builder =>
229
+ {
230
+ builder . WithStatefulReconnect ( ) ;
231
+ builder . Services . Configure < HttpConnectionOptions > ( o =>
232
+ {
233
+ // Replace the websocket creation for the first connection so we can make the client think there was an ungraceful closure
234
+ // Which will trigger the stateful reconnect flow
235
+ o . WebSocketFactory = async ( context , token ) =>
236
+ {
237
+ if ( reconnectTcs is null )
238
+ {
239
+ reconnectTcs = new TaskCompletionSource ( TaskCreationOptions . RunContinuationsAsynchronously ) ;
240
+ startedReconnectTcs = new TaskCompletionSource ( TaskCreationOptions . RunContinuationsAsynchronously ) ;
241
+ }
242
+ else
243
+ {
244
+ startedReconnectTcs . SetResult ( ) ;
245
+ // We only want to wait on the reconnect, not the initial connection attempt
246
+ await reconnectTcs . Task . DefaultTimeout ( ) ;
247
+ }
248
+
249
+ innerWs = new ClientWebSocket ( ) ;
250
+ ws = new WebSocketWrapper ( innerWs ) ;
251
+ await innerWs . ConnectAsync ( context . Uri , token ) ;
252
+
253
+ _ = Task . Run ( async ( ) =>
254
+ {
255
+ try
256
+ {
257
+ while ( innerWs . State == WebSocketState . Open )
258
+ {
259
+ var buffer = new byte [ 1024 ] ;
260
+ var res = await innerWs . ReceiveAsync ( buffer , default ) ;
261
+ ws . SetReceiveResult ( ( res , buffer . AsMemory ( 0 , res . Count ) ) ) ;
262
+ }
263
+ }
264
+ // Log but ignore receive errors, that likely just means the connection closed
265
+ catch ( Exception ex )
266
+ {
267
+ Logger . LogInformation ( ex , "Error while reading from inner websocket" ) ;
268
+ }
269
+ } ) ;
270
+
271
+ return ws ;
272
+ } ;
273
+ } ) ;
274
+ } ) ;
275
+ var secondConnection = CreateConnection ( _serverFixture . SecondServer . Url + "/stateful" , HttpTransportType . WebSockets , protocol , LoggerFactory ) ;
276
+
277
+ var tcs = new TaskCompletionSource < string > ( ) ;
278
+ connection . On < string > ( "SendToAll" , message => tcs . TrySetResult ( message ) ) ;
279
+
280
+ var tcs2 = new TaskCompletionSource < string > ( ) ;
281
+ secondConnection . On < string > ( "SendToAll" , message => tcs2 . TrySetResult ( message ) ) ;
282
+
283
+ await connection . StartAsync ( ) . DefaultTimeout ( ) ;
284
+ await secondConnection . StartAsync ( ) . DefaultTimeout ( ) ;
285
+
286
+ // Close first connection before the second connection sends a message to all clients
287
+ await ws . CloseOutputAsync ( WebSocketCloseStatus . InternalServerError , statusDescription : null , default ) ;
288
+ await startedReconnectTcs . Task . DefaultTimeout ( ) ;
289
+
290
+ // Send to all clients, since both clients are on different servers this means the backplane will be used
291
+ // And we want to test that messages are still preserved for stateful reconnect purposes when a client disconnects
292
+ // But is on a different server from the original message sender.
293
+ await secondConnection . SendAsync ( "SendToAll" , "test message" ) . DefaultTimeout ( ) ;
294
+
295
+ // Check that second connection still receives the message
296
+ Assert . Equal ( "test message" , await tcs2 . Task . DefaultTimeout ( ) ) ;
297
+ Assert . False ( tcs . Task . IsCompleted ) ;
298
+
299
+ // allow first connection to reconnect
300
+ reconnectTcs . SetResult ( ) ;
301
+
302
+ // Check that first connection received the message once it reconnected
303
+ Assert . Equal ( "test message" , await tcs . Task . DefaultTimeout ( ) ) ;
304
+
305
+ await connection . DisposeAsync ( ) . DefaultTimeout ( ) ;
306
+ }
307
+ }
308
+
309
+ private static HubConnection CreateConnection ( string url , HttpTransportType transportType , IHubProtocol protocol , ILoggerFactory loggerFactory , string userName = null ,
310
+ Action < IHubConnectionBuilder > customizeConnection = null )
215
311
{
216
312
var hubConnectionBuilder = new HubConnectionBuilder ( )
217
313
. WithLoggerFactory ( loggerFactory )
@@ -225,6 +321,8 @@ private static HubConnection CreateConnection(string url, HttpTransportType tran
225
321
226
322
hubConnectionBuilder . Services . AddSingleton ( protocol ) ;
227
323
324
+ customizeConnection ? . Invoke ( hubConnectionBuilder ) ;
325
+
228
326
return hubConnectionBuilder . Build ( ) ;
229
327
}
230
328
@@ -253,4 +351,67 @@ public static IEnumerable<object[]> TransportTypesAndProtocolTypes
253
351
}
254
352
}
255
353
}
354
+
355
+ internal sealed class WebSocketWrapper : WebSocket
356
+ {
357
+ private readonly WebSocket _inner ;
358
+ private TaskCompletionSource < ( WebSocketReceiveResult , ReadOnlyMemory < byte > ) > _receiveTcs = new ( TaskCreationOptions . RunContinuationsAsynchronously ) ;
359
+
360
+ public WebSocketWrapper ( WebSocket inner )
361
+ {
362
+ _inner = inner ;
363
+ }
364
+
365
+ public override WebSocketCloseStatus ? CloseStatus => _inner . CloseStatus ;
366
+
367
+ public override string CloseStatusDescription => _inner . CloseStatusDescription ;
368
+
369
+ public override WebSocketState State => _inner . State ;
370
+
371
+ public override string SubProtocol => _inner . SubProtocol ;
372
+
373
+ public override void Abort ( )
374
+ {
375
+ _inner . Abort ( ) ;
376
+ }
377
+
378
+ public override Task CloseAsync ( WebSocketCloseStatus closeStatus , string statusDescription , CancellationToken cancellationToken )
379
+ {
380
+ return _inner . CloseAsync ( closeStatus , statusDescription , cancellationToken ) ;
381
+ }
382
+
383
+ public override Task CloseOutputAsync ( WebSocketCloseStatus closeStatus , string statusDescription , CancellationToken cancellationToken )
384
+ {
385
+ _receiveTcs . TrySetException ( new IOException ( "force reconnect" ) ) ;
386
+ return Task . CompletedTask ;
387
+ }
388
+
389
+ public override void Dispose ( )
390
+ {
391
+ _inner . Dispose ( ) ;
392
+ }
393
+
394
+ public void SetReceiveResult ( ( WebSocketReceiveResult , ReadOnlyMemory < byte > ) result )
395
+ {
396
+ _receiveTcs . SetResult ( result ) ;
397
+ }
398
+
399
+ public override async Task < WebSocketReceiveResult > ReceiveAsync ( ArraySegment < byte > buffer , CancellationToken cancellationToken )
400
+ {
401
+ var res = await _receiveTcs . Task ;
402
+ // Handle zero-byte reads
403
+ if ( buffer . Count == 0 )
404
+ {
405
+ return res . Item1 ;
406
+ }
407
+ _receiveTcs = new ( TaskCreationOptions . RunContinuationsAsynchronously ) ;
408
+ res . Item2 . CopyTo ( buffer ) ;
409
+ return res . Item1 ;
410
+ }
411
+
412
+ public override Task SendAsync ( ArraySegment < byte > buffer , WebSocketMessageType messageType , bool endOfMessage , CancellationToken cancellationToken )
413
+ {
414
+ return _inner . SendAsync ( buffer , messageType , endOfMessage , cancellationToken ) ;
415
+ }
416
+ }
256
417
}
0 commit comments