Skip to content
Merged
24 changes: 19 additions & 5 deletions src/auth/integration-tests/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -572,20 +572,34 @@ pub mod unstable {
Ok(())
}

pub async fn id_token_adc() -> anyhow::Result<()> {
pub async fn id_token_adc(with_impersonation: bool) -> anyhow::Result<()> {
let (project, adc_json) = get_project_and_service_account().await?;
let mut source_sa_json: serde_json::Value = serde_json::from_slice(&adc_json)?;

let mut expected_email = format!("test-sa-creds@{project}.iam.gserviceaccount.com");
let target_audience = "https://example.com";

if with_impersonation {
let target_principal_email =
format!("impersonation-target@{project}.iam.gserviceaccount.com");
source_sa_json = serde_json::json!({
"type": "impersonated_service_account",
"service_account_impersonation_url": format!("https://iamcredentials.googleapis.com/v1/projects/-/serviceAccounts/{target_principal_email}:generateAccessToken"),
"source_credentials": source_sa_json,
});
expected_email = target_principal_email;
}

// Write the ADC to a temporary file
let file = tempfile::NamedTempFile::new().unwrap();
let path = file.into_temp_path();
std::fs::write(&path, adc_json).expect("Unable to write to temporary file.");

let expected_email = format!("test-sa-creds@{project}.iam.gserviceaccount.com");
let target_audience = "https://example.com";
std::fs::write(&path, source_sa_json.to_string())
.expect("Unable to write to temporary file.");

// Create credentials for the principal under test.
let _e = ScopedEnv::set("GOOGLE_APPLICATION_CREDENTIALS", path.to_str().unwrap());
let id_token_creds = IDTokenCredentialBuilder::new(target_audience)
.with_include_email()
.build()
.expect("failed to create id token credentials");

Expand Down
13 changes: 12 additions & 1 deletion src/auth/integration-tests/tests/driver.rs
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,18 @@ mod driver {
#[tokio::test(flavor = "multi_thread", worker_threads = 1)]
#[serial_test::serial]
async fn run_id_token_adc() -> anyhow::Result<()> {
auth_integration_tests::unstable::id_token_adc().await
let with_impersonation = false;
auth_integration_tests::unstable::id_token_adc(with_impersonation).await
}

#[cfg(all(test, google_cloud_unstable_id_token))]
#[tokio::test(flavor = "multi_thread", worker_threads = 1)]
#[serial_test::serial]
// verify that include_email via ADC flow is passed down to the impersonated
// builder and email claim is included in the token.
async fn run_id_token_adc_impersonated() -> anyhow::Result<()> {
let with_impersonation = true;
auth_integration_tests::unstable::id_token_adc(with_impersonation).await
}

#[cfg(all(test, google_cloud_unstable_id_token))]
Expand Down
125 changes: 116 additions & 9 deletions src/auth/src/credentials/idtoken.rs
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,7 @@ pub(crate) mod dynamic {
/// [AIP-4110]: https://google.aip.dev/auth/4110
pub struct Builder {
target_audience: String,
include_email: bool,
}

impl Builder {
Expand All @@ -193,9 +194,20 @@ impl Builder {
pub fn new<S: Into<String>>(target_audience: S) -> Self {
Self {
target_audience: target_audience.into(),
include_email: false,
}
}

/// Sets whether the ID token should include the `email` claim of the user in the token.
///
/// For some credentials sources like Metadata Server and Impersonated Credentials, the default is
/// to not include the `email` claim. For other sources, they always include it.
/// This option is only relevant for credentials sources that do not include the `email` claim by default.
pub fn with_include_email(mut self) -> Self {
self.include_email = true;
self
}

/// Returns a [IDTokenCredentials] instance with the configured settings.
///
/// # Errors
Expand All @@ -214,30 +226,62 @@ impl Builder {
AdcContents::FallbackToMds => None,
};

build_id_token_credentials(self.target_audience, json_data)
build_id_token_credentials(self.target_audience, self.include_email, json_data)
}
}
enum IDTokenBuilder {
Mds(mds::Builder),
ServiceAccount(service_account::Builder),
Impersonated(impersonated::Builder),
}

fn build_id_token_credentials(
audience: String,
include_email: bool,
json: Option<Value>,
) -> BuildResult<IDTokenCredentials> {
let builder = build_id_token_credentials_internal(audience, include_email, json)?;
match builder {
IDTokenBuilder::Mds(builder) => builder.build(),
IDTokenBuilder::ServiceAccount(builder) => builder.build(),
IDTokenBuilder::Impersonated(builder) => builder.build(),
}
}

fn build_id_token_credentials_internal(
audience: String,
include_email: bool,
json: Option<Value>,
) -> BuildResult<IDTokenBuilder> {
match json {
None => {
// TODO(#3587): pass context that is being built from ADC flow.
mds::Builder::new(audience)
.with_format(mds::Format::Full)
.build()
let format = if include_email {
mds::Format::Full
} else {
mds::Format::Standard
};
Ok(IDTokenBuilder::Mds(
mds::Builder::new(audience).with_format(format),
))
}
Some(json) => {
let cred_type = extract_credential_type(&json)?;
match cred_type {
"authorized_user" => Err(BuilderError::not_supported(format!(
"{cred_type}, use idtoken::user_account::Builder directly."
))),
"service_account" => service_account::Builder::new(audience, json).build(),
"service_account" => Ok(IDTokenBuilder::ServiceAccount(
service_account::Builder::new(audience, json),
)),
"impersonated_service_account" => {
impersonated::Builder::new(audience, json).build()
let builder = impersonated::Builder::new(audience, json);
let builder = if include_email {
builder.with_include_email()
} else {
builder
};
Ok(IDTokenBuilder::Impersonated(builder))
}
"external_account" => {
// never gonna be supported for id tokens
Expand Down Expand Up @@ -289,7 +333,9 @@ fn instant_from_epoch_seconds(secs: u64, now: SystemTime) -> Option<Instant> {
pub(crate) mod tests {
use super::*;
use jsonwebtoken::{Algorithm, EncodingKey, Header};
use mds::Format;
use rsa::pkcs1::EncodeRsaPrivateKey;
use serde_json::json;
use serial_test::parallel;
use std::collections::HashMap;
use std::time::{Duration, SystemTime, UNIX_EPOCH};
Expand Down Expand Up @@ -380,7 +426,7 @@ pub(crate) mod tests {
"refresh_token": "test_refresh_token",
});

let result = build_id_token_credentials(audience, Some(json));
let result = build_id_token_credentials(audience, false, Some(json));
assert!(result.is_err());
let err = result.unwrap_err();
assert!(err.is_not_supported());
Expand Down Expand Up @@ -408,7 +454,7 @@ pub(crate) mod tests {
}
});

let result = build_id_token_credentials(audience, Some(json));
let result = build_id_token_credentials(audience, false, Some(json));
assert!(result.is_err());
let err = result.unwrap_err();
assert!(err.is_not_supported());
Expand All @@ -424,11 +470,72 @@ pub(crate) mod tests {
"type": "unknown_credential_type",
});

let result = build_id_token_credentials(audience, Some(json));
let result = build_id_token_credentials(audience, false, Some(json));
assert!(result.is_err());
let err = result.unwrap_err();
assert!(err.is_unknown_type());
assert!(err.to_string().contains("unknown_credential_type"));
Ok(())
}

#[tokio::test]
#[parallel]
async fn test_build_id_token_include_email_mds() -> TestResult {
let audience = "test_audience".to_string();

// Test with include_email = true and no source credentials (MDS Fallback)
let creds = build_id_token_credentials_internal(audience.clone(), true, None)?;
assert!(matches!(creds, IDTokenBuilder::Mds(_)));
if let IDTokenBuilder::Mds(builder) = creds {
assert!(matches!(builder.format, Some(Format::Full)));
}

// Test with include_email = false and no source credentials (MDS Fallback)
let creds = build_id_token_credentials_internal(audience.clone(), false, None)?;
assert!(matches!(creds, IDTokenBuilder::Mds(_)));
if let IDTokenBuilder::Mds(builder) = creds {
assert!(matches!(builder.format, Some(Format::Standard)));
}

Ok(())
}

#[tokio::test]
#[parallel]
async fn test_build_id_token_include_email_impersonated() -> TestResult {
let audience = "test_audience".to_string();
let json = json!({
"type": "impersonated_service_account",
"source_credentials": {
"type": "service_account",
"project_id": "test-project",
"private_key_id": "test-key-id",
"private_key": "-----BEGIN PRIVATE KEY-----\n-----END PRIVATE KEY-----",
"client_email": "[email protected]",
"client_id": "test-client-id",
"auth_uri": "https://accounts.google.com/o/oauth2/auth",
"token_uri": "https://oauth2.googleapis.com/token",
"auth_provider_x509_cert_url": "https://www.googleapis.com/oauth2/v1/certs",
"client_x509_cert_url": "https://www.googleapis.com/robot/v1/metadata/x509/source%40test-project.iam.gserviceaccount.com"
},
"service_account_impersonation_url": "https://iamcredentials.googleapis.com/v1/projects/-/serviceAccounts/[email protected]:generateIdToken"
});

// Test with include_email = true and impersonated source credentials
let creds =
build_id_token_credentials_internal(audience.clone(), true, Some(json.clone()))?;
assert!(matches!(creds, IDTokenBuilder::Impersonated(_)));
if let IDTokenBuilder::Impersonated(builder) = creds {
assert_eq!(builder.include_email, Some(true));
}

// Test with include_email = false and impersonated source credentials
let creds = build_id_token_credentials_internal(audience.clone(), false, Some(json))?;
assert!(matches!(creds, IDTokenBuilder::Impersonated(_)));
if let IDTokenBuilder::Impersonated(builder) = creds {
assert_eq!(builder.include_email, None);
}

Ok(())
}
}
2 changes: 1 addition & 1 deletion src/auth/src/credentials/idtoken/impersonated.rs
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ use std::sync::Arc;
pub struct Builder {
source: BuilderSource,
delegates: Option<Vec<String>>,
include_email: Option<bool>,
pub(crate) include_email: Option<bool>,
target_audience: String,
service_account_impersonation_url: Option<String>,
}
Expand Down
2 changes: 1 addition & 1 deletion src/auth/src/credentials/idtoken/mds.rs
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ impl Format {
/// metadata service.
pub struct Builder {
endpoint: Option<String>,
format: Option<Format>,
pub(crate) format: Option<Format>,
licenses: Option<String>,
target_audience: String,
}
Expand Down
Loading