Skip to content
This repository was archived by the owner on Dec 18, 2018. It is now read-only.

Commit 665f166

Browse files
ivankarpeydavidfowl
authored andcommitted
fix issue with incorrect user detection when Invoking for User (#747)
* fix issue with incorrect user detection when Invoking for User * fix failed testcases * use proper extension method to avoid potential null reference exception * fix for channel name in redis version + follow SignalR team recommendations * remove unncessary freespace * remove whitespaces * introduce IUserIdProvider to resolve user id * Move IUserIdProvider from HubLifetimeManager to HubConnectionContext * setting user id to connection context in hubendpoint
1 parent 3c5d283 commit 665f166

File tree

10 files changed

+55
-14
lines changed

10 files changed

+55
-14
lines changed

src/Microsoft.AspNetCore.SignalR.Core/DefaultHubLifetimeManager.cs

+2-5
Original file line numberDiff line numberDiff line change
@@ -138,17 +138,14 @@ public override Task InvokeGroupAsync(string groupName, string methodName, objec
138138

139139
public override Task InvokeUserAsync(string userId, string methodName, object[] args)
140140
{
141-
return InvokeAllWhere(methodName, args, connection =>
142-
{
143-
return string.Equals(connection.User.Identity.Name, userId, StringComparison.Ordinal);
144-
});
141+
return InvokeAllWhere(methodName, args, connection =>
142+
string.Equals(connection.UserIdentifier, userId, StringComparison.Ordinal));
145143
}
146144

147145
public override Task OnConnectedAsync(HubConnectionContext connection)
148146
{
149147
// Set the hub groups feature
150148
connection.Features.Set<IHubGroupsFeature>(new HubGroupsFeature());
151-
152149
_connections.Add(connection);
153150
return Task.CompletedTask;
154151
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
// Copyright (c) .NET Foundation. All rights reserved.
2+
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
3+
4+
using System.Security.Claims;
5+
6+
namespace Microsoft.AspNetCore.SignalR.Core
7+
{
8+
public class DefaultUserIdProvider : IUserIdProvider
9+
{
10+
public string GetUserId(HubConnectionContext connection)
11+
{
12+
return connection.User.FindFirst(ClaimTypes.NameIdentifier)?.Value;
13+
}
14+
}
15+
}

src/Microsoft.AspNetCore.SignalR.Core/HubConnectionContext.cs

+2
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,8 @@ public virtual void Abort()
6767
Task.Factory.StartNew(_abortedCallback, this);
6868
}
6969

70+
public string UserIdentifier { get; internal set; }
71+
7072
internal void Abort(Exception exception)
7173
{
7274
AbortException = ExceptionDispatchInfo.Capture(exception);

src/Microsoft.AspNetCore.SignalR.Core/HubEndPoint.cs

+7-1
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
using System.Threading.Tasks;
1212
using System.Threading.Tasks.Channels;
1313
using Microsoft.AspNetCore.Authorization;
14+
using Microsoft.AspNetCore.SignalR.Core;
1415
using Microsoft.AspNetCore.SignalR.Core.Internal;
1516
using Microsoft.AspNetCore.SignalR.Features;
1617
using Microsoft.AspNetCore.SignalR.Internal;
@@ -39,20 +40,23 @@ public class HubEndPoint<THub> : IInvocationBinder where THub : Hub
3940
private readonly IServiceScopeFactory _serviceScopeFactory;
4041
private readonly IHubProtocolResolver _protocolResolver;
4142
private readonly IOptions<HubOptions> _hubOptions;
43+
private readonly IUserIdProvider _userIdProvider;
4244

4345
public HubEndPoint(HubLifetimeManager<THub> lifetimeManager,
4446
IHubProtocolResolver protocolResolver,
4547
IHubContext<THub> hubContext,
4648
IOptions<HubOptions> hubOptions,
4749
ILogger<HubEndPoint<THub>> logger,
48-
IServiceScopeFactory serviceScopeFactory)
50+
IServiceScopeFactory serviceScopeFactory,
51+
IUserIdProvider userIdProvider)
4952
{
5053
_protocolResolver = protocolResolver;
5154
_lifetimeManager = lifetimeManager;
5255
_hubContext = hubContext;
5356
_hubOptions = hubOptions;
5457
_logger = logger;
5558
_serviceScopeFactory = serviceScopeFactory;
59+
_userIdProvider = userIdProvider;
5660

5761
DiscoverHubMethods();
5862
}
@@ -72,6 +76,8 @@ public async Task OnConnectedAsync(ConnectionContext connection)
7276
return;
7377
}
7478

79+
connectionContext.UserIdentifier = _userIdProvider.GetUserId(connectionContext);
80+
7581
// Hubs support multiple producers so we set up this loop to copy
7682
// data written to the HubConnectionContext's channel to the transport channel
7783
var protocolReaderWriter = connectionContext.ProtocolReaderWriter;
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
// Copyright (c) .NET Foundation. All rights reserved.
2+
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
3+
4+
namespace Microsoft.AspNetCore.SignalR.Core
5+
{
6+
public interface IUserIdProvider
7+
{
8+
string GetUserId(HubConnectionContext connection);
9+
}
10+
}

src/Microsoft.AspNetCore.SignalR.Core/SignalRDependencyInjectionExtensions.cs

+2
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
33

44
using Microsoft.AspNetCore.SignalR;
5+
using Microsoft.AspNetCore.SignalR.Core;
56
using Microsoft.AspNetCore.SignalR.Internal;
67

78
namespace Microsoft.Extensions.DependencyInjection
@@ -15,6 +16,7 @@ public static ISignalRBuilder AddSignalRCore(this IServiceCollection services)
1516
services.AddSingleton(typeof(IHubContext<>), typeof(HubContext<>));
1617
services.AddSingleton(typeof(IHubContext<,>), typeof(HubContext<,>));
1718
services.AddSingleton(typeof(HubEndPoint<>), typeof(HubEndPoint<>));
19+
services.AddSingleton(typeof(IUserIdProvider), typeof(DefaultUserIdProvider));
1820
services.AddScoped(typeof(IHubActivator<>), typeof(DefaultHubActivator<>));
1921

2022
services.AddAuthorization();

src/Microsoft.AspNetCore.SignalR.Redis/RedisHubLifetimeManager.cs

+2-2
Original file line numberDiff line numberDiff line change
@@ -245,9 +245,9 @@ public override Task OnConnectedAsync(HubConnectionContext connection)
245245
previousConnectionTask = WriteAsync(connection, message);
246246
});
247247

248-
if (connection.User.Identity.IsAuthenticated)
248+
if (!string.IsNullOrEmpty(connection.UserIdentifier))
249249
{
250-
var userChannel = _channelNamePrefix + ".user." + connection.User.Identity.Name;
250+
var userChannel = _channelNamePrefix + ".user." + connection.UserIdentifier;
251251
redisSubscriptions.Add(userChannel);
252252

253253
var previousUserTask = Task.CompletedTask;

test/Microsoft.AspNetCore.SignalR.Tests/DefaultHubLifetimeManagerTests.cs

+1
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
using System.Threading.Tasks;
22
using System.Threading.Tasks.Channels;
3+
using Microsoft.AspNetCore.SignalR.Core;
34
using Microsoft.AspNetCore.SignalR.Internal.Protocol;
45
using Xunit;
56

test/Microsoft.AspNetCore.SignalR.Tests/HubEndpointTests.cs

+3-3
Original file line numberDiff line numberDiff line change
@@ -850,15 +850,15 @@ public async Task HubsCanSendToUser(Type hubType)
850850

851851
dynamic endPoint = serviceProvider.GetService(GetEndPointType(hubType));
852852

853-
using (var firstClient = new TestClient())
854-
using (var secondClient = new TestClient())
853+
using (var firstClient = new TestClient(addClaimId: true))
854+
using (var secondClient = new TestClient(addClaimId: true))
855855
{
856856
Task firstEndPointTask = endPoint.OnConnectedAsync(firstClient.Connection);
857857
Task secondEndPointTask = endPoint.OnConnectedAsync(secondClient.Connection);
858858

859859
await Task.WhenAll(firstClient.Connected, secondClient.Connected).OrTimeout();
860860

861-
await firstClient.SendInvocationAsync("ClientSendMethod", secondClient.Connection.User.Identity.Name, "test").OrTimeout();
861+
await firstClient.SendInvocationAsync("ClientSendMethod", secondClient.Connection.User.FindFirst(ClaimTypes.NameIdentifier)?.Value, "test").OrTimeout();
862862

863863
// check that 'secondConnection' has received the group send
864864
var hubMessage = await secondClient.ReadAsync().OrTimeout();

test/Microsoft.AspNetCore.SignalR.Tests/TestClient.cs

+11-3
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ public class TestClient : IDisposable, IInvocationBinder
2828
public Channel<byte[]> Application { get; }
2929
public Task Connected => ((TaskCompletionSource<bool>)Connection.Metadata["ConnectedTask"]).Task;
3030

31-
public TestClient(bool synchronousCallbacks = false, IHubProtocol protocol = null)
31+
public TestClient(bool synchronousCallbacks = false, IHubProtocol protocol = null, bool addClaimId = false)
3232
{
3333
var options = new ChannelOptimizations { AllowSynchronousContinuations = synchronousCallbacks };
3434
var transportToApplication = Channel.CreateUnbounded<byte[]>(options);
@@ -38,7 +38,15 @@ public TestClient(bool synchronousCallbacks = false, IHubProtocol protocol = nul
3838
_transport = ChannelConnection.Create<byte[]>(input: transportToApplication, output: applicationToTransport);
3939

4040
Connection = new DefaultConnectionContext(Guid.NewGuid().ToString(), _transport, Application);
41-
Connection.User = new ClaimsPrincipal(new ClaimsIdentity(new[] { new Claim(ClaimTypes.Name, Interlocked.Increment(ref _id).ToString()) }));
41+
42+
var claimValue = Interlocked.Increment(ref _id).ToString();
43+
var claims = new List<Claim>{ new Claim(ClaimTypes.Name, claimValue) };
44+
if (addClaimId)
45+
{
46+
claims.Add(new Claim(ClaimTypes.NameIdentifier, claimValue));
47+
}
48+
49+
Connection.User = new ClaimsPrincipal(new ClaimsIdentity(claims));
4250
Connection.Metadata["ConnectedTask"] = new TaskCompletionSource<bool>();
4351

4452
protocol = protocol ?? new JsonHubProtocol();
@@ -182,4 +190,4 @@ Type IInvocationBinder.GetReturnType(string invocationId)
182190
return typeof(object);
183191
}
184192
}
185-
}
193+
}

0 commit comments

Comments
 (0)