Skip to content

Commit 07f1894

Browse files
committed
Add GetTokenOptions
1 parent 12d56d7 commit 07f1894

23 files changed

+247
-116
lines changed

sdk/core/azure_core/Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ sha2 = { workspace = true, optional = true }
2727
tokio = { workspace = true, optional = true }
2828
tracing.workspace = true
2929
typespec = { workspace = true, features = ["http", "json"] }
30-
typespec_client_core = { workspace = true, features = ["http", "json"] }
30+
typespec_client_core = { workspace = true, features = ["derive", "http", "json"] }
3131

3232
[build-dependencies]
3333
rustc_version.workspace = true

sdk/core/azure_core/src/credentials.rs

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
66
use serde::{Deserialize, Serialize};
77
use std::{borrow::Cow, fmt::Debug};
8-
use typespec_client_core::date::OffsetDateTime;
8+
use typespec_client_core::{date::OffsetDateTime, fmt::SafeDebug};
99

1010
/// Default Azure authorization scope.
1111
pub static DEFAULT_SCOPE_SUFFIX: &str = "/.default";
@@ -87,10 +87,18 @@ impl AccessToken {
8787
}
8888
}
8989

90+
/// Options for getting a token from a [`TokenCredential`]
91+
#[derive(Clone, Default, SafeDebug)]
92+
pub struct GetTokenOptions;
93+
9094
/// Represents a credential capable of providing an OAuth token.
9195
#[cfg_attr(target_arch = "wasm32", async_trait::async_trait(?Send))]
9296
#[cfg_attr(not(target_arch = "wasm32"), async_trait::async_trait)]
9397
pub trait TokenCredential: Send + Sync + Debug {
94-
/// Gets a `AccessToken` for the specified resource
95-
async fn get_token(&self, scopes: &[&str]) -> crate::Result<AccessToken>;
98+
/// Gets an [`AccessToken`] for the specified scopes
99+
async fn get_token(
100+
&self,
101+
scopes: &[&str],
102+
options: Option<GetTokenOptions>,
103+
) -> crate::Result<AccessToken>;
96104
}

sdk/core/azure_core/src/http/policies/bearer_token_policy.rs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@ impl Policy for BearerTokenCredentialPolicy {
6767
drop(access_token);
6868
let mut access_token = self.access_token.write().await;
6969
if access_token.is_none() {
70-
*access_token = Some(self.credential.get_token(&self.scopes()).await?);
70+
*access_token = Some(self.credential.get_token(&self.scopes(), None).await?);
7171
}
7272
}
7373
Some(token) if should_refresh(&token.expires_on) => {
@@ -79,7 +79,7 @@ impl Policy for BearerTokenCredentialPolicy {
7979
// access_token shouldn't be None here, but check anyway to guarantee unwrap won't panic
8080
if access_token.is_none() || access_token.as_ref().unwrap().expires_on == expires_on
8181
{
82-
match self.credential.get_token(&self.scopes()).await {
82+
match self.credential.get_token(&self.scopes(), None).await {
8383
Ok(new_token) => {
8484
*access_token = Some(new_token);
8585
}
@@ -121,7 +121,7 @@ fn should_refresh(expires_on: &OffsetDateTime) -> bool {
121121
mod tests {
122122
use super::*;
123123
use crate::{
124-
credentials::{Secret, TokenCredential},
124+
credentials::{GetTokenOptions, Secret, TokenCredential},
125125
http::{
126126
headers::{Headers, AUTHORIZATION},
127127
policies::Policy,
@@ -172,7 +172,7 @@ mod tests {
172172
#[cfg_attr(target_arch = "wasm32", async_trait(?Send))]
173173
#[cfg_attr(not(target_arch = "wasm32"), async_trait)]
174174
impl TokenCredential for MockCredential {
175-
async fn get_token(&self, _scopes: &[&str]) -> Result<AccessToken> {
175+
async fn get_token(&self, _: &[&str], _: Option<GetTokenOptions>) -> Result<AccessToken> {
176176
let i = self.calls.fetch_add(1, Ordering::SeqCst);
177177
self.tokens
178178
.get(i)

sdk/core/azure_core_test/src/credentials.rs

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33

44
//! Credentials for live and recorded tests.
55
use azure_core::{
6-
credentials::{AccessToken, Secret, TokenCredential},
6+
credentials::{AccessToken, GetTokenOptions, Secret, TokenCredential},
77
date::OffsetDateTime,
88
error::ErrorKind,
99
};
@@ -17,7 +17,11 @@ pub struct MockCredential;
1717
#[cfg_attr(target_arch = "wasm32", async_trait::async_trait(?Send))]
1818
#[cfg_attr(not(target_arch = "wasm32"), async_trait::async_trait)]
1919
impl TokenCredential for MockCredential {
20-
async fn get_token(&self, scopes: &[&str]) -> azure_core::Result<AccessToken> {
20+
async fn get_token(
21+
&self,
22+
scopes: &[&str],
23+
_: Option<GetTokenOptions>,
24+
) -> azure_core::Result<AccessToken> {
2125
let token: Secret = format!("TEST TOKEN {}", scopes.join(" ")).into();
2226
let expires_on = OffsetDateTime::now_utc().saturating_add(
2327
Duration::from_secs(60 * 5).try_into().map_err(|err| {

sdk/cosmos/azure_data_cosmos/src/pipeline/authorization_policy.rs

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -118,7 +118,7 @@ async fn generate_authorization(
118118
let token = match auth_token {
119119
Credential::Token(token_credential) => {
120120
let token = token_credential
121-
.get_token(&[&scope_from_url(url)])
121+
.get_token(&[&scope_from_url(url)], None)
122122
.await?
123123
.token
124124
.secret()
@@ -146,7 +146,7 @@ mod tests {
146146
use std::sync::Arc;
147147

148148
use azure_core::{
149-
credentials::{AccessToken, TokenCredential},
149+
credentials::{AccessToken, GetTokenOptions, TokenCredential},
150150
date,
151151
http::Method,
152152
};
@@ -168,7 +168,11 @@ mod tests {
168168
#[cfg_attr(target_arch = "wasm32", async_trait::async_trait(?Send))]
169169
#[cfg_attr(not(target_arch = "wasm32"), async_trait::async_trait)]
170170
impl TokenCredential for TestTokenCredential {
171-
async fn get_token(&self, scopes: &[&str]) -> azure_core::Result<AccessToken> {
171+
async fn get_token(
172+
&self,
173+
scopes: &[&str],
174+
_: Option<GetTokenOptions>,
175+
) -> azure_core::Result<AccessToken> {
172176
let token = format!("{}+{}", self.0, scopes.join(","));
173177
Ok(AccessToken::new(
174178
token,

sdk/eventhubs/azure_messaging_eventhubs/src/common/connection_manager.rs

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -193,7 +193,7 @@ impl ConnectionManager {
193193
debug!("Get Token.");
194194
let token = self
195195
.credential
196-
.get_token(&[EVENTHUBS_AUTHORIZATION_SCOPE])
196+
.get_token(&[EVENTHUBS_AUTHORIZATION_SCOPE], None)
197197
.await?;
198198

199199
debug!("Token for path {path} expires at {}", token.expires_on);
@@ -368,7 +368,7 @@ impl ConnectionManager {
368368

369369
let new_token = self
370370
.credential
371-
.get_token(&[EVENTHUBS_AUTHORIZATION_SCOPE])
371+
.get_token(&[EVENTHUBS_AUTHORIZATION_SCOPE], None)
372372
.await?;
373373

374374
// Create an ephemeral session to host the authentication.
@@ -510,7 +510,7 @@ impl ConnectionManager {
510510
mod tests {
511511
use super::*;
512512
use async_trait::async_trait;
513-
use azure_core::{http::Url, Result};
513+
use azure_core::{credentials::GetTokenOptions, http::Url, Result};
514514
use std::sync::Arc;
515515
use time::OffsetDateTime;
516516
use tracing::info;
@@ -551,7 +551,11 @@ mod tests {
551551

552552
#[async_trait]
553553
impl TokenCredential for MockTokenCredential {
554-
async fn get_token(&self, _scopes: &[&str]) -> Result<AccessToken> {
554+
async fn get_token(
555+
&self,
556+
_scopes: &[&str],
557+
_options: Option<GetTokenOptions>,
558+
) -> Result<AccessToken> {
555559
// Simulate a token refresh by incrementing the token get count
556560
// and updating the token expiration time
557561
{

sdk/identity/azure_identity/README.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,7 @@ impl ClientAssertion for VmClientAssertion {
9393
async fn secret(&self) -> azure_core::Result<String> {
9494
Ok(self
9595
.credential
96-
.get_token(&[&self.scope])
96+
.get_token(&[&self.scope], None)
9797
.await?
9898
.token
9999
.secret()
@@ -116,7 +116,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
116116
)?;
117117

118118
let fic_scope = String::from("your-service-app.com/scope");
119-
let fic_token = client_assertion_credential.get_token(&[&fic_scope]).await?;
119+
let fic_token = client_assertion_credential.get_token(&[&fic_scope], None).await?;
120120
Ok(())
121121
}
122122

sdk/identity/azure_identity/examples/azure_cli_credentials.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ async fn main() -> Result<(), Box<dyn Error>> {
1313

1414
let credentials = AzureCliCredential::new(None)?;
1515
let res = credentials
16-
.get_token(&["https://management.azure.com/.default"])
16+
.get_token(&["https://management.azure.com/.default"], None)
1717
.await?;
1818
eprintln!("Azure CLI response == {res:?}");
1919

sdk/identity/azure_identity/examples/default_credentials.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
1717
let url = url::Url::parse(&format!("https://management.azure.com/subscriptions/{subscription_id}/providers/Microsoft.Storage/storageAccounts?api-version=2019-06-01"))?;
1818

1919
let access_token = credential
20-
.get_token(&["https://management.azure.com/.default"])
20+
.get_token(&["https://management.azure.com/.default"], None)
2121
.await?;
2222

2323
let response = reqwest::Client::new()

sdk/identity/azure_identity/examples/specific_credential.rs

Lines changed: 18 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
// Licensed under the MIT License.
33

44
use azure_core::{
5-
credentials::{AccessToken, TokenCredential},
5+
credentials::{AccessToken, GetTokenOptions, TokenCredential},
66
error::{ErrorKind, ResultExt},
77
Error,
88
};
@@ -25,7 +25,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
2525
let url = url::Url::parse(&format!("https://management.azure.com/subscriptions/{subscription_id}/providers/Microsoft.Storage/storageAccounts?api-version=2019-06-01"))?;
2626

2727
let access_token = credential
28-
.get_token(&["https://management.azure.com/.default"])
28+
.get_token(&["https://management.azure.com/.default"], None)
2929
.await?;
3030

3131
let response = reqwest::Client::new()
@@ -63,15 +63,21 @@ enum SpecificAzureCredentialKind {
6363
#[cfg_attr(target_arch = "wasm32", async_trait::async_trait(?Send))]
6464
#[cfg_attr(not(target_arch = "wasm32"), async_trait::async_trait)]
6565
impl TokenCredential for SpecificAzureCredentialKind {
66-
async fn get_token(&self, scopes: &[&str]) -> azure_core::Result<AccessToken> {
66+
async fn get_token(
67+
&self,
68+
scopes: &[&str],
69+
options: Option<GetTokenOptions>,
70+
) -> azure_core::Result<AccessToken> {
6771
match self {
6872
#[cfg(not(target_arch = "wasm32"))]
69-
SpecificAzureCredentialKind::AzureCli(credential) => credential.get_token(scopes).await,
73+
SpecificAzureCredentialKind::AzureCli(credential) => {
74+
credential.get_token(scopes, options).await
75+
}
7076
SpecificAzureCredentialKind::ManagedIdentity(credential) => {
71-
credential.get_token(scopes).await
77+
credential.get_token(scopes, options).await
7278
}
7379
SpecificAzureCredentialKind::WorkloadIdentity(credential) => {
74-
credential.get_token(scopes).await
80+
credential.get_token(scopes, options).await
7581
}
7682
}
7783
}
@@ -133,7 +139,11 @@ impl SpecificAzureCredential {
133139
#[cfg_attr(target_arch = "wasm32", async_trait::async_trait(?Send))]
134140
#[cfg_attr(not(target_arch = "wasm32"), async_trait::async_trait)]
135141
impl TokenCredential for SpecificAzureCredential {
136-
async fn get_token(&self, scopes: &[&str]) -> azure_core::Result<AccessToken> {
137-
self.source.get_token(scopes).await
142+
async fn get_token(
143+
&self,
144+
scopes: &[&str],
145+
options: Option<GetTokenOptions>,
146+
) -> azure_core::Result<AccessToken> {
147+
self.source.get_token(scopes, options).await
138148
}
139149
}

sdk/identity/azure_identity/src/azure_developer_cli_credential.rs

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ use crate::{
77
validate_scope, validate_tenant_id, TokenCredentialOptions,
88
};
99
use azure_core::{
10-
credentials::{AccessToken, Secret, TokenCredential},
10+
credentials::{AccessToken, GetTokenOptions, Secret, TokenCredential},
1111
error::{Error, ErrorKind},
1212
json::from_json,
1313
process::{new_executor, Executor},
@@ -105,7 +105,11 @@ impl AzureDeveloperCliCredential {
105105
#[cfg_attr(target_arch = "wasm32", async_trait::async_trait(?Send))]
106106
#[cfg_attr(not(target_arch = "wasm32"), async_trait::async_trait)]
107107
impl TokenCredential for AzureDeveloperCliCredential {
108-
async fn get_token(&self, scopes: &[&str]) -> azure_core::Result<AccessToken> {
108+
async fn get_token(
109+
&self,
110+
scopes: &[&str],
111+
_: Option<GetTokenOptions>,
112+
) -> azure_core::Result<AccessToken> {
109113
if scopes.is_empty() {
110114
return Err(Error::new(
111115
ErrorKind::Credential,
@@ -184,7 +188,7 @@ mod tests {
184188
tenant_id,
185189
};
186190
let cred = AzureDeveloperCliCredential::new(Some(options))?;
187-
return cred.get_token(LIVE_TEST_SCOPES).await;
191+
return cred.get_token(LIVE_TEST_SCOPES, None).await;
188192
}
189193

190194
#[tokio::test]
@@ -227,7 +231,7 @@ mod tests {
227231
};
228232
let cred = AzureDeveloperCliCredential::new(Some(options)).expect("valid credential");
229233
let err = cred
230-
.get_token(LIVE_TEST_SCOPES)
234+
.get_token(LIVE_TEST_SCOPES, None)
231235
.await
232236
.expect_err("expected error");
233237
assert!(matches!(err.kind(), ErrorKind::Credential));

sdk/identity/azure_identity/src/azure_pipelines_credential.rs

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ use crate::{
66
TokenCredentialOptions,
77
};
88
use azure_core::{
9-
credentials::{AccessToken, Secret, TokenCredential},
9+
credentials::{AccessToken, GetTokenOptions, Secret, TokenCredential},
1010
error::ErrorKind,
1111
http::{
1212
headers::{FromHeaders, HeaderName, Headers, AUTHORIZATION, CONTENT_LENGTH},
@@ -108,8 +108,12 @@ impl AzurePipelinesCredential {
108108
#[cfg_attr(target_arch = "wasm32", async_trait::async_trait(?Send))]
109109
#[cfg_attr(not(target_arch = "wasm32"), async_trait::async_trait)]
110110
impl TokenCredential for AzurePipelinesCredential {
111-
async fn get_token(&self, scopes: &[&str]) -> azure_core::Result<AccessToken> {
112-
self.0.get_token(scopes).await
111+
async fn get_token(
112+
&self,
113+
scopes: &[&str],
114+
options: Option<GetTokenOptions>,
115+
) -> azure_core::Result<AccessToken> {
116+
self.0.get_token(scopes, options).await
113117
}
114118
}
115119

@@ -258,7 +262,7 @@ mod tests {
258262
AzurePipelinesCredential::new("a".into(), "b".into(), "c", "d", Some(options))
259263
.expect("valid AzurePipelinesCredential");
260264
assert!(matches!(
261-
credential.get_token(&["default"]).await,
265+
credential.get_token(&["default"], None).await,
262266
Err(err) if matches!(
263267
err.kind(),
264268
ErrorKind::HttpResponse { status, .. }
@@ -323,7 +327,7 @@ mod tests {
323327
AzurePipelinesCredential::new("a".into(), "b".into(), "c", "d", Some(options))
324328
.expect("valid AzurePipelinesCredential");
325329
let secret = credential
326-
.get_token(&["default"])
330+
.get_token(&["default"], None)
327331
.await
328332
.expect("valid response");
329333
assert_eq!(secret.token.secret(), "qux");

0 commit comments

Comments
 (0)