Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ val DECORATORS: List<ClientCodegenDecorator> =
CredentialsProviderDecorator(),
RegionDecorator(),
RequireEndpointRules(),
EndpointOverrideMetricDecorator(),
UserAgentDecorator(),
SigV4AuthDecorator(),
HttpRequestChecksumDecorator(),
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
/*
* Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
* SPDX-License-Identifier: Apache-2.0
*/

package software.amazon.smithy.rustsdk

import software.amazon.smithy.rust.codegen.client.smithy.ClientCodegenContext
import software.amazon.smithy.rust.codegen.client.smithy.ClientRustModule
import software.amazon.smithy.rust.codegen.client.smithy.customize.ClientCodegenDecorator
import software.amazon.smithy.rust.codegen.client.smithy.generators.ServiceRuntimePluginCustomization
import software.amazon.smithy.rust.codegen.client.smithy.generators.ServiceRuntimePluginSection
import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate
import software.amazon.smithy.rust.codegen.core.rustlang.writable
import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType
import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType.Companion.preludeScope
import software.amazon.smithy.rust.codegen.core.smithy.RustCrate

/**
* Decorator that tracks endpoint override business metric when endpoint URL is configured.
*/
class EndpointOverrideMetricDecorator : ClientCodegenDecorator {
override val name: String = "EndpointOverrideMetric"
override val order: Byte = 0

override fun extras(
codegenContext: ClientCodegenContext,
rustCrate: RustCrate,
) {
// Generate the interceptor in the config::endpoint module
rustCrate.withModule(ClientRustModule.Config.endpoint) {
val runtimeConfig = codegenContext.runtimeConfig
val smithyRuntimeApi = RuntimeType.smithyRuntimeApiClient(runtimeConfig)
val smithyTypes = RuntimeType.smithyTypes(runtimeConfig)
val awsRuntime = AwsRuntimeType.awsRuntime(runtimeConfig)
val awsTypes = AwsRuntimeType.awsTypes(runtimeConfig)

rustTemplate(
"""
/// Interceptor that tracks endpoint override business metric.
##[derive(Debug, Default)]
pub(crate) struct EndpointOverrideFeatureTrackerInterceptor;

impl #{Intercept} for EndpointOverrideFeatureTrackerInterceptor {
fn name(&self) -> &'static str {
"EndpointOverrideFeatureTrackerInterceptor"
}

fn read_before_execution(
&self,
_context: &#{BeforeSerializationInterceptorContextRef}<'_>,
cfg: &mut #{ConfigBag},
) -> #{Result}<(), #{BoxError}> {
if cfg.load::<#{EndpointUrl}>().is_some() {
cfg.interceptor_state()
.store_append(#{AwsSdkFeature}::EndpointOverride);
}
#{Ok}(())
}
}
""",
"Intercept" to smithyRuntimeApi.resolve("client::interceptors::Intercept"),
"BeforeSerializationInterceptorContextRef" to
smithyRuntimeApi.resolve("client::interceptors::context::BeforeSerializationInterceptorContextRef"),
"ConfigBag" to smithyTypes.resolve("config_bag::ConfigBag"),
"BoxError" to smithyRuntimeApi.resolve("box_error::BoxError"),
"EndpointUrl" to awsTypes.resolve("endpoint_config::EndpointUrl"),
"AwsSdkFeature" to awsRuntime.resolve("sdk_feature::AwsSdkFeature"),
*preludeScope,
)
}
}

override fun serviceRuntimePluginCustomizations(
codegenContext: ClientCodegenContext,
baseCustomizations: List<ServiceRuntimePluginCustomization>,
): List<ServiceRuntimePluginCustomization> =
baseCustomizations + listOf(EndpointOverrideFeatureTrackerRegistration(codegenContext))
}

private class EndpointOverrideFeatureTrackerRegistration(
private val codegenContext: ClientCodegenContext,
) : ServiceRuntimePluginCustomization() {
override fun section(section: ServiceRuntimePluginSection) =
writable {
if (section is ServiceRuntimePluginSection.RegisterRuntimeComponents) {
section.registerInterceptor(this) {
rustTemplate("crate::config::endpoint::EndpointOverrideFeatureTrackerInterceptor")
}
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,249 @@
/*
* Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
* SPDX-License-Identifier: Apache-2.0
*/

package software.amazon.smithy.rustsdk

import org.junit.jupiter.api.Test
import software.amazon.smithy.rust.codegen.core.rustlang.Feature
import software.amazon.smithy.rust.codegen.core.rustlang.rust
import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate
import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType
import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType.Companion.preludeScope
import software.amazon.smithy.rust.codegen.core.testutil.asSmithyModel
import software.amazon.smithy.rust.codegen.core.testutil.integrationTest
import software.amazon.smithy.rust.codegen.core.testutil.tokioTest

class EndpointOverrideMetricDecoratorTest {
companion object {
private const val PREFIX = "\$version: \"2\""
val model =
"""
$PREFIX
namespace test

use aws.api#service
use aws.auth#sigv4
use aws.protocols#restJson1
use smithy.rules#endpointRuleSet

@service(sdkId: "dontcare")
@restJson1
@sigv4(name: "dontcare")
@auth([sigv4])
@endpointRuleSet({
"version": "1.0",
"rules": [
{
"type": "endpoint",
"conditions": [
{ "fn": "isSet", "argv": [{ "ref": "Endpoint" }] }
],
"endpoint": { "url": { "ref": "Endpoint" } }
},
{
"type": "endpoint",
"conditions": [],
"endpoint": { "url": "https://example.com" }
}
],
"parameters": {
"Region": { "required": false, "type": "String", "builtIn": "AWS::Region" },
"Endpoint": { "required": false, "type": "String", "builtIn": "SDK::Endpoint" }
}
})
service TestService {
version: "2023-01-01",
operations: [SomeOperation]
}

@http(uri: "/SomeOperation", method: "GET")
@optionalAuth
operation SomeOperation {
input: SomeInput,
output: SomeOutput
}

@input
structure SomeInput {}

@output
structure SomeOutput {}
""".asSmithyModel()
}

@Test
fun `decorator is registered in AwsCodegenDecorator list`() {
val decoratorNames = DECORATORS.map { it.name }
assert(decoratorNames.contains("EndpointOverrideMetric")) {
"EndpointOverrideMetricDecorator should be registered in DECORATORS list. Found: $decoratorNames"
}
}

@Test
fun `endpoint override metric appears when set via SdkConfig`() {
val testParams = awsIntegrationTestParams()

awsSdkIntegrationTest(
model,
testParams,
environment = mapOf("RUSTUP_TOOLCHAIN" to "1.88.0"),
) { context, rustCrate ->
val rc = context.runtimeConfig
val moduleName = context.moduleUseName()

// Enable test-util feature for aws-runtime
rustCrate.mergeFeature(Feature("test-util", true, listOf("aws-runtime/test-util")))

rustCrate.integrationTest("endpoint_override_via_sdk_config") {
tokioTest("metric_tracked_when_endpoint_set_via_sdk_config") {
rustTemplate(
"""
use $moduleName::config::Region;
use $moduleName::Client;
use #{capture_request};
use #{assert_ua_contains_metric_values};

let (http_client, rcvr) = capture_request(None);

// Create SdkConfig with endpoint URL
let sdk_config = #{SdkConfig}::builder()
.region(Region::new("us-east-1"))
.endpoint_url("https://sdk-custom.example.com")
.http_client(http_client.clone())
.build();

// Create client from SdkConfig
let client = Client::new(&sdk_config);

// Make a request
let _ = client.some_operation().send().await;

// Verify the request
let request = rcvr.expect_request();

// Verify endpoint was overridden
let uri = request.uri().to_string();
assert!(
uri.starts_with("https://sdk-custom.example.com"),
"Expected SDK custom endpoint, got: {}",
uri
);

// Verify metric 'N' is present in x-amz-user-agent header
let user_agent = request
.headers()
.get("x-amz-user-agent")
.expect("x-amz-user-agent header missing");

assert_ua_contains_metric_values(user_agent, &["N"]);
""",
*preludeScope,
"capture_request" to RuntimeType.captureRequest(rc),
"assert_ua_contains_metric_values" to AwsRuntimeType.awsRuntime(rc).resolve("user_agent::test_util::assert_ua_contains_metric_values"),
"SdkConfig" to AwsRuntimeType.awsTypes(rc).resolve("sdk_config::SdkConfig"),
)
}
}
}
}

@Test
fun `no endpoint override metric when endpoint not set`() {
val testParams = awsIntegrationTestParams()

awsSdkIntegrationTest(
model,
testParams,
environment = mapOf("RUSTUP_TOOLCHAIN" to "1.88.0"),
) { context, rustCrate ->
val rc = context.runtimeConfig
val moduleName = context.moduleUseName()

// Enable test-util feature for aws-runtime
rustCrate.mergeFeature(Feature("test-util", true, listOf("aws-runtime/test-util")))

rustCrate.integrationTest("no_endpoint_override") {
tokioTest("no_metric_when_endpoint_not_overridden") {
rustTemplate(
"""
use $moduleName::config::{Credentials, Region, SharedCredentialsProvider};
use $moduleName::{Config, Client};
use #{capture_request};

let (http_client, rcvr) = capture_request(None);

// Create config WITHOUT endpoint override
let config = Config::builder()
.credentials_provider(SharedCredentialsProvider::new(Credentials::for_tests()))
.region(Region::new("us-east-1"))
.http_client(http_client.clone())
.build();
let client = Client::from_conf(config);

// Make a request
let _ = client.some_operation().send().await;

// Verify the request
let request = rcvr.expect_request();

// Verify default endpoint was used
let uri = request.uri().to_string();
assert!(
uri.starts_with("https://example.com"),
"Expected default endpoint, got: {}",
uri
);

// Verify metric 'N' is NOT present
let user_agent = request
.headers()
.get("x-amz-user-agent")
.expect("x-amz-user-agent header should be present");

assert!(
!user_agent.contains("m/N"),
"Metric 'N' should NOT be present when endpoint not overridden"
);
""",
*preludeScope,
"capture_request" to RuntimeType.captureRequest(rc),
)
}

// Add a should_panic test to verify assert_ua_contains_metric_values panics when metric is not present
rust("##[should_panic(expected = \"metric values\")]")
tokioTest("assert_panics_when_metric_not_present") {
rustTemplate(
"""
use $moduleName::config::{Credentials, Region, SharedCredentialsProvider};
use $moduleName::{Config, Client};
use #{capture_request};
use #{assert_ua_contains_metric_values};

let (http_client, rcvr) = capture_request(None);

let config = Config::builder()
.credentials_provider(SharedCredentialsProvider::new(Credentials::for_tests()))
.region(Region::new("us-east-1"))
.http_client(http_client.clone())
.build();
let client = Client::from_conf(config);

let _ = client.some_operation().send().await;
let request = rcvr.expect_request();
let user_agent = request.headers().get("x-amz-user-agent").unwrap();

// This should panic because 'N' is not present
assert_ua_contains_metric_values(user_agent, &["N"]);
""",
*preludeScope,
"capture_request" to RuntimeType.captureRequest(rc),
"assert_ua_contains_metric_values" to AwsRuntimeType.awsRuntime(rc).resolve("user_agent::test_util::assert_ua_contains_metric_values"),
)
}
}
}
}
}
2 changes: 2 additions & 0 deletions aws/rust-runtime/aws-runtime/src/sdk_feature.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@ pub enum AwsSdkFeature {
SsoLoginDevice,
/// Calling an SSO-OIDC operation as part of the SSO login flow, when using the OAuth2.0 authorization code grant
SsoLoginAuth,
/// Indicates that a custom endpoint URL was configured
EndpointOverride,
}

impl Storable for AwsSdkFeature {
Expand Down
1 change: 1 addition & 0 deletions aws/rust-runtime/aws-runtime/src/user_agent/metrics.rs
Original file line number Diff line number Diff line change
Expand Up @@ -234,6 +234,7 @@ impl ProvideBusinessMetric for AwsSdkFeature {
S3Transfer => Some(BusinessMetric::S3Transfer),
SsoLoginDevice => Some(BusinessMetric::SsoLoginDevice),
SsoLoginAuth => Some(BusinessMetric::SsoLoginAuth),
EndpointOverride => Some(BusinessMetric::EndpointOverride),
}
}
}
Expand Down
Loading