Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,8 @@ public class AdalConfiguration

public string ResourceClientUri { get; set; }

public UserIdentifierType UserIdentifier { get; set; }

public TokenCache TokenCache { get; set; }

public AdalConfiguration()
Expand All @@ -58,6 +60,7 @@ public AdalConfiguration()
ValidateAuthority = true;
AdEndpoint = string.Empty;
ResourceClientUri = "https://management.core.windows.net/";
UserIdentifier = UserIdentifierType.OptionalDisplayableId;
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@ public interface IAccessToken

string UserId { get; }

string UniqueId { get; }

string TenantId { get; }

LoginType LoginType { get; }
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,8 @@ public void AuthorizeRequest(Action<string, string> authTokenSetter)

public string AccessToken { get { return AuthResult.AccessToken; } }

public string UniqueId { get { return AuthResult.UserInfo.UniqueId; } }

public LoginType LoginType { get { return LoginType.OrgId; } }

public string TenantId { get { return this.Configuration.AdDomain; } }
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,7 @@
namespace Microsoft.Azure.Commands.Common.Authentication
{
/// <summary>
/// A token provider that uses ADAL to retrieve
/// tokens from Azure Active Directory for user
/// A token provider that uses ADAL to retrieve tokens from Azure Active Directory for user
/// credentials.
/// </summary>
internal class UserTokenProvider : ITokenProvider
Expand All @@ -51,12 +50,12 @@ public IAccessToken GetAccessToken(
throw new ArgumentException(string.Format(Resources.InvalidCredentialType, "User"), "credentialType");
}

return new AdalAccessToken(AcquireToken(config, promptBehavior, userId, password), this, config);
return new UserAccessToken(AcquireToken(config, promptBehavior, userId, password), this, config);
}

private readonly static TimeSpan expirationThreshold = TimeSpan.FromMinutes(5);

private bool IsExpired(AdalAccessToken token)
private bool IsExpired(UserAccessToken token)
{
#if DEBUG
if (Environment.GetEnvironmentVariable("FORCE_EXPIRED_ACCESS_TOKEN") != null)
Expand All @@ -72,7 +71,7 @@ private bool IsExpired(AdalAccessToken token)
return timeUntilExpiration < expirationThreshold;
}

private void Renew(AdalAccessToken token)
private void Renew(UserAccessToken token)
{
TracingAdapter.Information(
Resources.UPNRenewTokenTrace,
Expand Down Expand Up @@ -248,7 +247,7 @@ private AuthenticationResult DoAcquireToken(
config.ClientId,
config.ClientRedirectUri,
promptBehavior,
new UserIdentifier(userId, UserIdentifierType.RequiredDisplayableId),
new UserIdentifier(userId, config.UserIdentifier),
AdalConfiguration.EnableEbdMagicCookie);
}
else
Expand All @@ -273,13 +272,13 @@ private string GetExceptionMessage(Exception ex)
/// <summary>
/// Implementation of <see cref="IAccessToken"/> using data from ADAL
/// </summary>
private class AdalAccessToken : IAccessToken
private class UserAccessToken : IAccessToken
{
internal readonly AdalConfiguration Configuration;
internal AuthenticationResult AuthResult;
private readonly UserTokenProvider tokenProvider;

public AdalAccessToken(AuthenticationResult authResult, UserTokenProvider tokenProvider, AdalConfiguration configuration)
public UserAccessToken(AuthenticationResult authResult, UserTokenProvider tokenProvider, AdalConfiguration configuration)
{
AuthResult = authResult;
this.tokenProvider = tokenProvider;
Expand All @@ -295,6 +294,8 @@ public void AuthorizeRequest(Action<string, string> authTokenSetter)
public string AccessToken { get { return AuthResult.AccessToken; } }

public string UserId { get { return AuthResult.UserInfo.DisplayableId; } }

public string UniqueId { get { return AuthResult.UserInfo.UniqueId; } }

public string TenantId { get { return AuthResult.TenantId; } }

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,10 @@ public IAccessToken Authenticate(
{
IAccessToken token;
var configuration = GetAdalConfiguration(environment, tenant, resourceId, tokenCache);
if(account.GetTenantUniqueId(tenant) != null)
{
configuration.UserIdentifier = UserIdentifierType.UniqueId;
}

TracingAdapter.Information(
Resources.AdalAuthConfigurationTrace,
Expand All @@ -58,14 +62,24 @@ public IAccessToken Authenticate(
if (account.IsPropertySet(AzureAccount.Property.CertificateThumbprint))
{
var thumbprint = account.GetProperty(AzureAccount.Property.CertificateThumbprint);
token = TokenProvider.GetAccessTokenWithCertificate(configuration, account.Id, thumbprint, account.Type);
token = TokenProvider.GetAccessTokenWithCertificate(
configuration,
account.GetTenantUniqueId(tenant) ?? account.Id,
thumbprint,
account.Type);
}
else
{
token = TokenProvider.GetAccessToken(configuration, promptBehavior, account.Id, password, account.Type);
token = TokenProvider.GetAccessToken(
configuration,
promptBehavior,
account.GetTenantUniqueId(tenant) ?? account.Id,
password,
account.Type);
}

account.Id = token.UserId;
account.SetTenantUniqueId(token.TenantId, token.UniqueId);
return token;
}

Expand Down Expand Up @@ -298,8 +312,11 @@ public ServiceClientCredentials GetServiceClientCredentials(AzureContext context
}
}

private AdalConfiguration GetAdalConfiguration(AzureEnvironment environment, string tenantId,
AzureEnvironment.Endpoint resourceId, TokenCache tokenCache)
private AdalConfiguration GetAdalConfiguration(
AzureEnvironment environment,
string tenantId,
AzureEnvironment.Endpoint resourceId,
TokenCache tokenCache)
{
if (environment == null)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
// limitations under the License.
// ----------------------------------------------------------------------------------

using Microsoft.Azure.Commands.Common.Authentication.Factories;
using Microsoft.Azure.Commands.Common.Authentication.Utilities;
using System;
using System.Collections.Generic;
Expand All @@ -24,6 +25,7 @@ public partial class AzureAccount
{
public AzureAccount()
{
TenantToUniqueId = new Dictionary<string, string>();
Properties = new Dictionary<Property, string>();
}

Expand Down Expand Up @@ -124,6 +126,24 @@ public void RemoveSubscription(Guid id)
}
}

public string GetTenantUniqueId(string tenantId)
{
if (TenantToUniqueId.ContainsKey(tenantId))
{
return TenantToUniqueId[tenantId];
}

return null;
}

public void SetTenantUniqueId(string tenantId, string uniqueId)
{
if (!string.IsNullOrWhiteSpace(tenantId))
{
TenantToUniqueId[tenantId] = uniqueId;
}
}

public override bool Equals(object obj)
{
var anotherAccount = obj as AzureAccount;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@ public partial class AzureAccount

public Dictionary<Property, string> Properties { get; set; }

private Dictionary<string,string> TenantToUniqueId { get; set; }

public enum AccountType
{
Certificate,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,11 @@ public void AuthorizeRequest(Action<string, string> authTokenSetter)
}

public string AccessToken { get; set; }

public string UserId { get; set; }

public string UniqueId { get { return this.UserId; } }

public LoginType LoginType { get; set; }

public string TenantId
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,18 @@ namespace Microsoft.WindowsAzure.Commands.Common.Test.Mocks
public class MockAccessToken : IAccessToken
{
private string _tenantId = String.Empty;

public void AuthorizeRequest(Action<string, string> authTokenSetter)
{
authTokenSetter("Bearer", AccessToken);
}

public string AccessToken { get; set; }

public string UserId { get; set; }

public string UniqueId { get; set; }

public LoginType LoginType { get; set; }

public string TenantId
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,8 @@ public IAccessToken Authenticate(
{
UserId = account.Id,
LoginType = LoginType.OrgId,
AccessToken = "123"
AccessToken = "123",
UniqueId = "UniqueId"
};

return token;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

using Hyak.Common;
using Microsoft.Azure.Commands.Common.Authentication;
using Microsoft.Azure.Commands.Common.Authentication.Factories;
using Microsoft.Azure.Commands.Common.Authentication.Models;
using Microsoft.Azure.Commands.Profile;
using Microsoft.Azure.Commands.Profile.Models;
Expand All @@ -24,11 +25,13 @@
using Microsoft.WindowsAzure.Commands.Common.Test.Mocks;
using Microsoft.WindowsAzure.Commands.ScenarioTest;
using Microsoft.WindowsAzure.Commands.Utilities.Common;
using Moq;
using System;
using System.Collections.Concurrent;
using System.Collections.Generic;
using System.Linq;
using System.Management.Automation;
using System.Security;
using System.Threading.Tasks;
using Xunit;
using Xunit.Abstractions;
Expand Down Expand Up @@ -103,7 +106,7 @@ public void SpecifyTenantAndSubscriptionIdSucceed()
null,
null);
}

[Fact]
[Trait(Category.AcceptanceType, Category.CheckIn)]
public void SubscriptionIdNotExist()
Expand Down Expand Up @@ -456,8 +459,7 @@ public void AzurePSComletMessageQueue()

Assert.Equal(500, queue.Count);
}



[Fact]
[Trait(Category.AcceptanceType, Category.CheckIn)]
public void GetAzureRmSubscriptionPaginatedResult()
Expand Down Expand Up @@ -491,5 +493,69 @@ public void GetAzureRmSubscriptionPaginatedResult()
Assert.Equal("Disabled", ((PSAzureSubscription)commandRuntimeMock.OutputPipeline[2]).State);
Assert.Equal("LinkToNextPage", ((PSAzureSubscription)commandRuntimeMock.OutputPipeline[2]).SubscriptionName);
}

[Fact]
[Trait(Category.AcceptanceType, Category.CheckIn)]
public void VerifyAdalUniqueIdUsage()
{
var tenants = new List<string> { DefaultTenant.ToString(), Guid.NewGuid().ToString() };
var firstList = new List<string> { DefaultSubscription.ToString(), Guid.NewGuid().ToString() };
var secondList = new List<string> { Guid.NewGuid().ToString() };
var client = SetupTestEnvironment(tenants, firstList, secondList);

var authFactory = new AuthenticationFactory();
var tokenProvider = new Mock<ITokenProvider>();
int iterator = 0;
tokenProvider.Setup(m =>
m.GetAccessToken(
It.IsAny<AdalConfiguration>(),
It.IsAny<ShowDialog>(),
It.IsAny<string>(),
null,
AzureAccount.AccountType.User))
.Returns(
(AdalConfiguration config,
ShowDialog dialog,
string userId,
SecureString password,
AzureAccount.AccountType credentialType) =>
{
var accessToken = new MockAccessToken()
{
AccessToken = Guid.NewGuid().ToString(),
UserId = DefaultAccount,
};

if (tenants.Contains(config.AdDomain) &&
userId != DefaultAccount)
{
Assert.Equal(UserIdentifierType.UniqueId, config.UserIdentifier);
accessToken.TenantId = config.AdDomain;
accessToken.UniqueId = "uniqueID_" + config.AdDomain;
}
else
{
accessToken.TenantId = tenants[iterator];
accessToken.UniqueId = "uniqueID_" + tenants[iterator++];
Assert.Equal(UserIdentifierType.OptionalDisplayableId, config.UserIdentifier);
}
return accessToken;
});


authFactory.TokenProvider = tokenProvider.Object;
AzureSession.AuthenticationFactory = authFactory;

var azureRmProfile = client.Login(
Context.Account,
Context.Environment,
null,
null,
null,
null);

Assert.Equal("uniqueID_" + tenants[0], azureRmProfile.Context.Account.GetTenantUniqueId(tenants[0]));
Assert.Equal("uniqueID_" + tenants[1], azureRmProfile.Context.Account.GetTenantUniqueId(tenants[1]));
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,8 @@ public void LoginWithNoSubscriptionAndNoTenant()
cmdlt.InvokeEndProcessing();

Assert.NotNull(AzureRmProfileProvider.Instance.Profile.Context);
Assert.NotNull(AzureRmProfileProvider.Instance.Profile.Context.Account);
Assert.NotNull(AzureRmProfileProvider.Instance.Profile.Context.Account.GetTenantUniqueId("72f988bf-86f1-41af-91ab-2d7cd011db47"));
Assert.Equal("microsoft.com", AzureRmProfileProvider.Instance.Profile.Context.Tenant.Domain);
}

Expand Down
Loading