Skip to content

Add GetTokenOptions #2629

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
May 29, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion sdk/core/azure_core/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ sha2 = { workspace = true, optional = true }
tokio = { workspace = true, optional = true }
tracing.workspace = true
typespec = { workspace = true, features = ["http", "json"] }
typespec_client_core = { workspace = true, features = ["http", "json"] }
typespec_client_core = { workspace = true, features = ["derive", "http", "json"] }

[build-dependencies]
rustc_version.workspace = true
Expand Down
14 changes: 11 additions & 3 deletions sdk/core/azure_core/src/credentials.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

use serde::{Deserialize, Serialize};
use std::{borrow::Cow, fmt::Debug};
use typespec_client_core::date::OffsetDateTime;
use typespec_client_core::{date::OffsetDateTime, fmt::SafeDebug};

/// Default Azure authorization scope.
pub static DEFAULT_SCOPE_SUFFIX: &str = "/.default";
Expand Down Expand Up @@ -87,10 +87,18 @@ impl AccessToken {
}
}

/// Options for getting a token from a [`TokenCredential`]
#[derive(Clone, Default, SafeDebug)]
pub struct TokenRequestOptions;

/// Represents a credential capable of providing an OAuth token.
#[cfg_attr(target_arch = "wasm32", async_trait::async_trait(?Send))]
#[cfg_attr(not(target_arch = "wasm32"), async_trait::async_trait)]
pub trait TokenCredential: Send + Sync + Debug {
/// Gets a `AccessToken` for the specified resource
async fn get_token(&self, scopes: &[&str]) -> crate::Result<AccessToken>;
/// Gets an [`AccessToken`] for the specified scopes
async fn get_token(
&self,
scopes: &[&str],
options: Option<TokenRequestOptions>,
) -> crate::Result<AccessToken>;
}
12 changes: 8 additions & 4 deletions sdk/core/azure_core/src/http/policies/bearer_token_policy.rs
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ impl Policy for BearerTokenCredentialPolicy {
drop(access_token);
let mut access_token = self.access_token.write().await;
if access_token.is_none() {
*access_token = Some(self.credential.get_token(&self.scopes()).await?);
*access_token = Some(self.credential.get_token(&self.scopes(), None).await?);
}
}
Some(token) if should_refresh(&token.expires_on) => {
Expand All @@ -79,7 +79,7 @@ impl Policy for BearerTokenCredentialPolicy {
// access_token shouldn't be None here, but check anyway to guarantee unwrap won't panic
if access_token.is_none() || access_token.as_ref().unwrap().expires_on == expires_on
{
match self.credential.get_token(&self.scopes()).await {
match self.credential.get_token(&self.scopes(), None).await {
Ok(new_token) => {
*access_token = Some(new_token);
}
Expand Down Expand Up @@ -121,7 +121,7 @@ fn should_refresh(expires_on: &OffsetDateTime) -> bool {
mod tests {
use super::*;
use crate::{
credentials::{Secret, TokenCredential},
credentials::{Secret, TokenCredential, TokenRequestOptions},
http::{
headers::{Headers, AUTHORIZATION},
policies::Policy,
Expand Down Expand Up @@ -172,7 +172,11 @@ mod tests {
#[cfg_attr(target_arch = "wasm32", async_trait(?Send))]
#[cfg_attr(not(target_arch = "wasm32"), async_trait)]
impl TokenCredential for MockCredential {
async fn get_token(&self, _scopes: &[&str]) -> Result<AccessToken> {
async fn get_token(
&self,
_: &[&str],
_: Option<TokenRequestOptions>,
) -> Result<AccessToken> {
let i = self.calls.fetch_add(1, Ordering::SeqCst);
self.tokens
.get(i)
Expand Down
8 changes: 6 additions & 2 deletions sdk/core/azure_core_test/src/credentials.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

//! Credentials for live and recorded tests.
use azure_core::{
credentials::{AccessToken, Secret, TokenCredential},
credentials::{AccessToken, Secret, TokenCredential, TokenRequestOptions},
date::OffsetDateTime,
error::ErrorKind,
};
Expand All @@ -17,7 +17,11 @@ pub struct MockCredential;
#[cfg_attr(target_arch = "wasm32", async_trait::async_trait(?Send))]
#[cfg_attr(not(target_arch = "wasm32"), async_trait::async_trait)]
impl TokenCredential for MockCredential {
async fn get_token(&self, scopes: &[&str]) -> azure_core::Result<AccessToken> {
async fn get_token(
&self,
scopes: &[&str],
_: Option<TokenRequestOptions>,
) -> azure_core::Result<AccessToken> {
let token: Secret = format!("TEST TOKEN {}", scopes.join(" ")).into();
let expires_on = OffsetDateTime::now_utc().saturating_add(
Duration::from_secs(60 * 5).try_into().map_err(|err| {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ async fn generate_authorization(
let token = match auth_token {
Credential::Token(token_credential) => {
let token = token_credential
.get_token(&[&scope_from_url(url)])
.get_token(&[&scope_from_url(url)], None)
.await?
.token
.secret()
Expand Down Expand Up @@ -146,7 +146,7 @@ mod tests {
use std::sync::Arc;

use azure_core::{
credentials::{AccessToken, TokenCredential},
credentials::{AccessToken, TokenCredential, TokenRequestOptions},
date,
http::Method,
};
Expand All @@ -168,7 +168,11 @@ mod tests {
#[cfg_attr(target_arch = "wasm32", async_trait::async_trait(?Send))]
#[cfg_attr(not(target_arch = "wasm32"), async_trait::async_trait)]
impl TokenCredential for TestTokenCredential {
async fn get_token(&self, scopes: &[&str]) -> azure_core::Result<AccessToken> {
async fn get_token(
&self,
scopes: &[&str],
_: Option<TokenRequestOptions>,
) -> azure_core::Result<AccessToken> {
let token = format!("{}+{}", self.0, scopes.join(","));
Ok(AccessToken::new(
token,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -193,7 +193,7 @@ impl ConnectionManager {
debug!("Get Token.");
let token = self
.credential
.get_token(&[EVENTHUBS_AUTHORIZATION_SCOPE])
.get_token(&[EVENTHUBS_AUTHORIZATION_SCOPE], None)
.await?;

debug!("Token for path {path} expires at {}", token.expires_on);
Expand Down Expand Up @@ -368,7 +368,7 @@ impl ConnectionManager {

let new_token = self
.credential
.get_token(&[EVENTHUBS_AUTHORIZATION_SCOPE])
.get_token(&[EVENTHUBS_AUTHORIZATION_SCOPE], None)
.await?;

// Create an ephemeral session to host the authentication.
Expand Down Expand Up @@ -510,7 +510,7 @@ impl ConnectionManager {
mod tests {
use super::*;
use async_trait::async_trait;
use azure_core::{http::Url, Result};
use azure_core::{credentials::TokenRequestOptions, http::Url, Result};
use std::sync::Arc;
use time::OffsetDateTime;
use tracing::info;
Expand Down Expand Up @@ -551,7 +551,11 @@ mod tests {

#[async_trait]
impl TokenCredential for MockTokenCredential {
async fn get_token(&self, _scopes: &[&str]) -> Result<AccessToken> {
async fn get_token(
&self,
_scopes: &[&str],
_options: Option<TokenRequestOptions>,
) -> Result<AccessToken> {
// Simulate a token refresh by incrementing the token get count
// and updating the token expiration time
{
Expand Down
4 changes: 2 additions & 2 deletions sdk/identity/azure_identity/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ impl ClientAssertion for VmClientAssertion {
async fn secret(&self) -> azure_core::Result<String> {
Ok(self
.credential
.get_token(&[&self.scope])
.get_token(&[&self.scope], None)
.await?
.token
.secret()
Expand All @@ -116,7 +116,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
)?;

let fic_scope = String::from("your-service-app.com/scope");
let fic_token = client_assertion_credential.get_token(&[&fic_scope]).await?;
let fic_token = client_assertion_credential.get_token(&[&fic_scope], None).await?;
Ok(())
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ async fn main() -> Result<(), Box<dyn Error>> {

let credentials = AzureCliCredential::new(None)?;
let res = credentials
.get_token(&["https://management.azure.com/.default"])
.get_token(&["https://management.azure.com/.default"], None)
.await?;
eprintln!("Azure CLI response == {res:?}");

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
let url = url::Url::parse(&format!("https://management.azure.com/subscriptions/{subscription_id}/providers/Microsoft.Storage/storageAccounts?api-version=2019-06-01"))?;

let access_token = credential
.get_token(&["https://management.azure.com/.default"])
.get_token(&["https://management.azure.com/.default"], None)
.await?;

let response = reqwest::Client::new()
Expand Down
26 changes: 18 additions & 8 deletions sdk/identity/azure_identity/examples/specific_credential.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
// Licensed under the MIT License.

use azure_core::{
credentials::{AccessToken, TokenCredential},
credentials::{AccessToken, TokenCredential, TokenRequestOptions},
error::{ErrorKind, ResultExt},
Error,
};
Expand All @@ -25,7 +25,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
let url = url::Url::parse(&format!("https://management.azure.com/subscriptions/{subscription_id}/providers/Microsoft.Storage/storageAccounts?api-version=2019-06-01"))?;

let access_token = credential
.get_token(&["https://management.azure.com/.default"])
.get_token(&["https://management.azure.com/.default"], None)
.await?;

let response = reqwest::Client::new()
Expand Down Expand Up @@ -63,15 +63,21 @@ enum SpecificAzureCredentialKind {
#[cfg_attr(target_arch = "wasm32", async_trait::async_trait(?Send))]
#[cfg_attr(not(target_arch = "wasm32"), async_trait::async_trait)]
impl TokenCredential for SpecificAzureCredentialKind {
async fn get_token(&self, scopes: &[&str]) -> azure_core::Result<AccessToken> {
async fn get_token(
&self,
scopes: &[&str],
options: Option<TokenRequestOptions>,
) -> azure_core::Result<AccessToken> {
match self {
#[cfg(not(target_arch = "wasm32"))]
SpecificAzureCredentialKind::AzureCli(credential) => credential.get_token(scopes).await,
SpecificAzureCredentialKind::AzureCli(credential) => {
credential.get_token(scopes, options).await
}
SpecificAzureCredentialKind::ManagedIdentity(credential) => {
credential.get_token(scopes).await
credential.get_token(scopes, options).await
}
SpecificAzureCredentialKind::WorkloadIdentity(credential) => {
credential.get_token(scopes).await
credential.get_token(scopes, options).await
}
}
}
Expand Down Expand Up @@ -133,7 +139,11 @@ impl SpecificAzureCredential {
#[cfg_attr(target_arch = "wasm32", async_trait::async_trait(?Send))]
#[cfg_attr(not(target_arch = "wasm32"), async_trait::async_trait)]
impl TokenCredential for SpecificAzureCredential {
async fn get_token(&self, scopes: &[&str]) -> azure_core::Result<AccessToken> {
self.source.get_token(scopes).await
async fn get_token(
&self,
scopes: &[&str],
options: Option<TokenRequestOptions>,
) -> azure_core::Result<AccessToken> {
self.source.get_token(scopes, options).await
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ use crate::{
validate_scope, validate_tenant_id, TokenCredentialOptions,
};
use azure_core::{
credentials::{AccessToken, Secret, TokenCredential},
credentials::{AccessToken, Secret, TokenCredential, TokenRequestOptions},
error::{Error, ErrorKind},
json::from_json,
process::{new_executor, Executor},
Expand Down Expand Up @@ -105,7 +105,11 @@ impl AzureDeveloperCliCredential {
#[cfg_attr(target_arch = "wasm32", async_trait::async_trait(?Send))]
#[cfg_attr(not(target_arch = "wasm32"), async_trait::async_trait)]
impl TokenCredential for AzureDeveloperCliCredential {
async fn get_token(&self, scopes: &[&str]) -> azure_core::Result<AccessToken> {
async fn get_token(
&self,
scopes: &[&str],
_: Option<TokenRequestOptions>,
) -> azure_core::Result<AccessToken> {
if scopes.is_empty() {
return Err(Error::new(
ErrorKind::Credential,
Expand Down Expand Up @@ -184,7 +188,7 @@ mod tests {
tenant_id,
};
let cred = AzureDeveloperCliCredential::new(Some(options))?;
return cred.get_token(LIVE_TEST_SCOPES).await;
return cred.get_token(LIVE_TEST_SCOPES, None).await;
}

#[tokio::test]
Expand Down Expand Up @@ -227,7 +231,7 @@ mod tests {
};
let cred = AzureDeveloperCliCredential::new(Some(options)).expect("valid credential");
let err = cred
.get_token(LIVE_TEST_SCOPES)
.get_token(LIVE_TEST_SCOPES, None)
.await
.expect_err("expected error");
assert!(matches!(err.kind(), ErrorKind::Credential));
Expand Down
14 changes: 9 additions & 5 deletions sdk/identity/azure_identity/src/azure_pipelines_credential.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ use crate::{
TokenCredentialOptions,
};
use azure_core::{
credentials::{AccessToken, Secret, TokenCredential},
credentials::{AccessToken, Secret, TokenCredential, TokenRequestOptions},
error::ErrorKind,
http::{
headers::{FromHeaders, HeaderName, Headers, AUTHORIZATION, CONTENT_LENGTH},
Expand Down Expand Up @@ -108,8 +108,12 @@ impl AzurePipelinesCredential {
#[cfg_attr(target_arch = "wasm32", async_trait::async_trait(?Send))]
#[cfg_attr(not(target_arch = "wasm32"), async_trait::async_trait)]
impl TokenCredential for AzurePipelinesCredential {
async fn get_token(&self, scopes: &[&str]) -> azure_core::Result<AccessToken> {
self.0.get_token(scopes).await
async fn get_token(
&self,
scopes: &[&str],
options: Option<TokenRequestOptions>,
) -> azure_core::Result<AccessToken> {
self.0.get_token(scopes, options).await
}
}

Expand Down Expand Up @@ -258,7 +262,7 @@ mod tests {
AzurePipelinesCredential::new("a".into(), "b".into(), "c", "d", Some(options))
.expect("valid AzurePipelinesCredential");
assert!(matches!(
credential.get_token(&["default"]).await,
credential.get_token(&["default"], None).await,
Err(err) if matches!(
err.kind(),
ErrorKind::HttpResponse { status, .. }
Expand Down Expand Up @@ -323,7 +327,7 @@ mod tests {
AzurePipelinesCredential::new("a".into(), "b".into(), "c", "d", Some(options))
.expect("valid AzurePipelinesCredential");
let secret = credential
.get_token(&["default"])
.get_token(&["default"], None)
.await
.expect("valid response");
assert_eq!(secret.token.secret(), "qux");
Expand Down
Loading
Loading