Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Identity] Refactor synchronous ClientAssertionCredential #40277

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions sdk/identity/azure-identity/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@

### Other Changes

- Updated the synchronous `ClientAssertionCredential` to use the Microsoft Authentication Library (MSAL) for its implementation.

## 1.21.0 (2025-03-11)

### Other Changes
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,12 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
# ------------------------------------
from typing import Callable, Optional, Any
from typing import Callable, Any

from azure.core.credentials import AccessTokenInfo
from .._internal import AadClient
from .._internal.get_token_mixin import GetTokenMixin
from .._internal.client_credential_base import ClientCredentialBase


class ClientAssertionCredential(GetTokenMixin):
class ClientAssertionCredential(ClientCredentialBase):
"""Authenticates a service principal with a JWT assertion.

This credential is for advanced scenarios. :class:`~azure.identity.CertificateCredential` has a more
Expand Down Expand Up @@ -42,36 +40,16 @@ class ClientAssertionCredential(GetTokenMixin):
"""

def __init__(self, tenant_id: str, client_id: str, func: Callable[[], str], **kwargs: Any) -> None:
self._func = func
authority = kwargs.pop("authority", None)
client_credential = {
"client_assertion": func,
}
cache = kwargs.pop("cache", None)
cae_cache = kwargs.pop("cae_cache", None)
additionally_allowed_tenants = kwargs.pop("additionally_allowed_tenants", None)
self._client = AadClient(
tenant_id,
client_id,
authority=authority,
cache=cache,
cae_cache=cae_cache,
additionally_allowed_tenants=additionally_allowed_tenants,
super().__init__(
client_id=client_id,
client_credential=client_credential,
tenant_id=tenant_id,
_cache=cache,
_cae_cache=cae_cache,
**kwargs
)
super().__init__()

def __enter__(self) -> "ClientAssertionCredential":
self._client.__enter__()
return self

def __exit__(self, *args: Any) -> None:
self._client.__exit__(*args)

def close(self) -> None:
self.__exit__()

def _acquire_token_silently(self, *scopes: str, **kwargs: Any) -> Optional[AccessTokenInfo]:
return self._client.get_cached_access_token(scopes, **kwargs)

def _request_token(self, *scopes: str, **kwargs: Any) -> AccessTokenInfo:
assertion = self._func()
token = self._client.obtain_token_by_jwt_assertion(scopes, assertion, **kwargs)
return token
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ def _get_known_kwargs(kwargs: Dict[str, Any]):


class ClientCredentialBase(MsalCredential, GetTokenMixin):
"""Base class for credentials authenticating a service principal with a certificate or secret"""
"""Base class for credentials authenticating a service principal with a certificate, secret, or JWT assertion."""

@wrap_exceptions
def _acquire_token_silently(self, *scopes: str, **kwargs: Any) -> Optional[AccessTokenInfo]:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,17 @@
# ------------------------------------
from typing import Callable
from unittest.mock import MagicMock, Mock, patch
from urllib.parse import urlparse

from azure.identity._internal.aad_client_base import JWT_BEARER_ASSERTION
from azure.identity import ClientAssertionCredential, TokenCachePersistenceOptions
import pytest

from helpers import build_aad_response, mock_response, GET_TOKEN_METHODS
from helpers import build_aad_response, mock_response, get_discovery_response, GET_TOKEN_METHODS


def test_init_with_kwargs():
tenant_id: str = "TENANT_ID"
tenant_id: str = "tenant-id"
client_id: str = "CLIENT_ID"
func: Callable[[], str] = lambda: "TOKEN"

Expand All @@ -26,7 +27,7 @@ def test_init_with_kwargs():


def test_context_manager():
tenant_id: str = "TENANT_ID"
tenant_id: str = "tenant-id"
client_id: str = "CLIENT_ID"
func: Callable[[], str] = lambda: "TOKEN"

Expand All @@ -46,22 +47,30 @@ def test_token_cache_persistence(get_token_method):
"""The credential should use a persistent cache if cache_persistence_options are configured."""

access_token = "foo"
tenant_id: str = "TENANT_ID"
tenant_id: str = "tenant-id"
client_id: str = "CLIENT_ID"
scope = "scope"
assertion = "ASSERTION_TOKEN"
func: Callable[[], str] = lambda: assertion

def send(request, **kwargs):
# ensure the `claims` and `tenant_id` keywords from credential's `get_token` method don't make it to transport
assert "claims" not in kwargs
assert "tenant_id" not in kwargs
parsed = urlparse(request.url)
tenant = parsed.path.split("/")[1]
if "/oauth2/v2.0/token" not in parsed.path:
return get_discovery_response("https://{}/{}".format(parsed.netloc, tenant))

assert request.data["client_assertion"] == assertion
assert request.data["client_assertion_type"] == JWT_BEARER_ASSERTION
assert request.data["client_id"] == client_id
assert request.data["grant_type"] == "client_credentials"
assert request.data["scope"] == scope

assert tenant_id in request.url
return mock_response(json_payload=build_aad_response(access_token=access_token))

with patch("azure.identity._internal.aad_client_base._load_persistent_cache") as load_persistent_cache:
with patch("azure.identity._internal.msal_credentials._load_persistent_cache") as load_persistent_cache:
credential = ClientAssertionCredential(
tenant_id=tenant_id,
client_id=client_id,
Expand All @@ -71,18 +80,37 @@ def send(request, **kwargs):
)

assert load_persistent_cache.call_count == 0
assert credential._client._cache is None
assert credential._client._cae_cache is None
assert credential._cache is None
assert credential._cae_cache is None

token = getattr(credential, get_token_method)(scope)
assert token.token == access_token
assert load_persistent_cache.call_count == 1
assert credential._client._cache is not None
assert credential._client._cae_cache is None
assert credential._cache is not None
assert credential._cae_cache is None

kwargs = {"enable_cae": True}
if get_token_method == "get_token_info":
kwargs = {"options": kwargs}
token = getattr(credential, get_token_method)(scope, **kwargs)
assert load_persistent_cache.call_count == 2
assert credential._client._cae_cache is not None
assert credential._cae_cache is not None


def test_client_capabilities():
"""The credential should use the CAE-enabled client when enable_cae is True"""
assertion = "ASSERTION_TOKEN"
transport = Mock(send=Mock(side_effect=Exception("this test mocks MSAL, so no request should be sent")))
func: Callable[[], str] = lambda: assertion
credential = ClientAssertionCredential("tenant-id", "client-id", func, transport=transport)
with patch("msal.ConfidentialClientApplication") as ConfidentialClientApplication:
credential._get_app()

assert ConfidentialClientApplication.call_count == 1
_, kwargs = ConfidentialClientApplication.call_args
assert kwargs["client_capabilities"] == None

credential._get_app(enable_cae=True)
assert ConfidentialClientApplication.call_count == 2
_, kwargs = ConfidentialClientApplication.call_args
assert kwargs["client_capabilities"] == ["CP1"]
78 changes: 30 additions & 48 deletions sdk/identity/azure-identity/tests/test_managed_identity.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,14 @@
from azure.identity._internal import within_credential_chain
import pytest

from helpers import build_aad_response, validating_transport, mock_response, Request, GET_TOKEN_METHODS
from helpers import (
build_aad_response,
validating_transport,
msal_validating_transport,
mock_response,
Request,
GET_TOKEN_METHODS,
)

MANAGED_IDENTITY_ENVIRON = "azure.identity._credentials.managed_identity.os.environ"
ALL_ENVIRONMENTS = (
Expand All @@ -28,7 +35,7 @@
EnvironmentVariables.IDENTITY_SERVER_THUMBPRINT: "...",
},
{EnvironmentVariables.IDENTITY_ENDPOINT: "...", EnvironmentVariables.IMDS_ENDPOINT: "..."}, # Arc
{ # token exchange
{ # Workload Identity
EnvironmentVariables.AZURE_AUTHORITY_HOST: "https://localhost",
EnvironmentVariables.AZURE_CLIENT_ID: "...",
EnvironmentVariables.AZURE_TENANT_ID: "...",
Expand Down Expand Up @@ -93,49 +100,19 @@ def test_custom_hooks(environ, get_token_method):
"token_type": "Bearer",
}
)
transport = validating_transport(requests=[Request()] * 2, responses=[expected_response] * 2)
if not environ.get(EnvironmentVariables.AZURE_FEDERATED_TOKEN_FILE):
transport = validating_transport(requests=[Request()] * 2, responses=[expected_response] * 2)
else:
transport = msal_validating_transport(requests=[Request()], responses=[expected_response])

with mock.patch.dict(MANAGED_IDENTITY_ENVIRON, environ, clear=True):
credential = ManagedIdentityCredential(
transport=transport, raw_request_hook=request_hook, raw_response_hook=response_hook
)
getattr(credential, get_token_method)(scope)

assert request_hook.call_count == 1
assert response_hook.call_count == 1
args, kwargs = response_hook.call_args
pipeline_response = args[0]
assert pipeline_response.http_response == expected_response


@pytest.mark.parametrize("environ,get_token_method", product(ALL_ENVIRONMENTS, GET_TOKEN_METHODS))
def test_tenant_id(environ, get_token_method):
scope = "scope"
expected_token = "***"
request_hook = mock.Mock()
response_hook = mock.Mock()
now = int(time.time())
expected_response = mock_response(
json_payload={
"access_token": expected_token,
"expires_in": 3600,
"expires_on": now + 3600,
"ext_expires_in": 3600,
"not_before": now,
"resource": scope,
"token_type": "Bearer",
}
)
transport = validating_transport(requests=[Request()] * 2, responses=[expected_response] * 2)

with mock.patch.dict(MANAGED_IDENTITY_ENVIRON, environ, clear=True):
credential = ManagedIdentityCredential(
transport=transport, raw_request_hook=request_hook, raw_response_hook=response_hook
)
getattr(credential, get_token_method)(scope)

assert request_hook.call_count == 1
assert response_hook.call_count == 1
assert request_hook.call_count >= 1
assert response_hook.call_count >= 1
args, kwargs = response_hook.call_args
pipeline_response = args[0]
assert pipeline_response.http_response == expected_response
Expand Down Expand Up @@ -843,14 +820,14 @@ def test_service_fabric_with_client_id_error(get_token_method):


@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS)
def test_token_exchange(tmpdir, get_token_method):
def test_workload_identity(tmpdir, get_token_method):
exchange_token = "exchange-token"
token_file = tmpdir.join("token")
token_file.write(exchange_token)
access_token = "***"
authority = "https://localhost"
default_client_id = "default_client_id"
tenant = "tenant_id"
tenant = "tenant-id"
scope = "scope"

success_response = mock_response(
Expand All @@ -864,7 +841,8 @@ def test_token_exchange(tmpdir, get_token_method):
"token_type": "Bearer",
}
)
transport = validating_transport(

transport = msal_validating_transport(
requests=[
Request(
base_url=authority,
Expand All @@ -879,6 +857,7 @@ def test_token_exchange(tmpdir, get_token_method):
)
],
responses=[success_response],
endpoint=authority,
)

mock_environ = {
Expand All @@ -895,7 +874,7 @@ def test_token_exchange(tmpdir, get_token_method):

# client_id kwarg should override AZURE_CLIENT_ID
nondefault_client_id = "non" + default_client_id
transport = validating_transport(
transport = msal_validating_transport(
requests=[
Request(
base_url=authority,
Expand All @@ -910,6 +889,7 @@ def test_token_exchange(tmpdir, get_token_method):
)
],
responses=[success_response],
endpoint=authority,
)

with mock.patch.dict("os.environ", mock_environ, clear=True):
Expand All @@ -918,7 +898,7 @@ def test_token_exchange(tmpdir, get_token_method):
assert token.token == access_token

# AZURE_CLIENT_ID may not have a value, in which case client_id is required
transport = validating_transport(
transport = msal_validating_transport(
requests=[
Request(
base_url=authority,
Expand All @@ -933,6 +913,7 @@ def test_token_exchange(tmpdir, get_token_method):
)
],
responses=[success_response],
endpoint=authority,
)

with mock.patch.dict(
Expand All @@ -953,14 +934,14 @@ def test_token_exchange(tmpdir, get_token_method):


@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS)
def test_token_exchange_tenant_id(tmpdir, get_token_method):
def test_workload_identity_tenant_id(tmpdir, get_token_method):
exchange_token = "exchange-token"
token_file = tmpdir.join("token")
token_file.write(exchange_token)
access_token = "***"
authority = "https://localhost"
default_client_id = "default_client_id"
tenant = "tenant_id"
tenant = "tenant-id"
scope = "scope"

success_response = mock_response(
Expand All @@ -974,7 +955,7 @@ def test_token_exchange_tenant_id(tmpdir, get_token_method):
"token_type": "Bearer",
}
)
transport = validating_transport(
transport = msal_validating_transport(
requests=[
Request(
base_url=authority,
Expand All @@ -989,6 +970,7 @@ def test_token_exchange_tenant_id(tmpdir, get_token_method):
)
],
responses=[success_response],
endpoint=authority,
)

mock_environ = {
Expand All @@ -998,8 +980,8 @@ def test_token_exchange_tenant_id(tmpdir, get_token_method):
EnvironmentVariables.AZURE_FEDERATED_TOKEN_FILE: token_file.strpath,
}
with mock.patch.dict("os.environ", mock_environ, clear=True):
credential = ManagedIdentityCredential(transport=transport)
kwargs = {"tenant_id": "tenant_id"}
credential = ManagedIdentityCredential(transport=transport, additionally_allowed_tenants=["*"])
kwargs = {"tenant_id": "tenant-id-2"}
if get_token_method == "get_token_info":
kwargs = {"options": kwargs}
token = getattr(credential, get_token_method)(scope, **kwargs)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ async def test_tenant_id(environ, get_token_method):
credential = ManagedIdentityCredential(
transport=transport, raw_request_hook=request_hook, raw_response_hook=response_hook
)
await getattr(credential, get_token_method)(scope)
await getattr(credential, get_token_method)(scope)

assert request_hook.call_count == 1
assert response_hook.call_count == 1
Expand Down
Loading
Loading