diff --git a/flint-core/src/main/scala/org/opensearch/flint/core/FlintOptions.java b/flint-core/src/main/scala/org/opensearch/flint/core/FlintOptions.java index 410d896d2..c1e5b7d7f 100644 --- a/flint-core/src/main/scala/org/opensearch/flint/core/FlintOptions.java +++ b/flint-core/src/main/scala/org/opensearch/flint/core/FlintOptions.java @@ -133,4 +133,8 @@ public String getPassword() { public int getSocketTimeoutMillis() { return Integer.parseInt(options.getOrDefault(SOCKET_TIMEOUT_MILLIS, String.valueOf(DEFAULT_SOCKET_TIMEOUT_MILLIS))); } + + public String getSessionIndexARN() { + return options.getOrDefault("spark.repl.flint.ASSUME_ROLE_CREDENTIALS_ROLE_ARN", ""); + } } diff --git a/flint-core/src/main/scala/org/opensearch/flint/core/storage/FlintOpenSearchClient.java b/flint-core/src/main/scala/org/opensearch/flint/core/storage/FlintOpenSearchClient.java index 45aedbaa6..9bfbe9ecf 100644 --- a/flint-core/src/main/scala/org/opensearch/flint/core/storage/FlintOpenSearchClient.java +++ b/flint-core/src/main/scala/org/opensearch/flint/core/storage/FlintOpenSearchClient.java @@ -7,6 +7,7 @@ import static org.opensearch.common.xcontent.DeprecationHandler.IGNORE_DEPRECATIONS; +import com.amazonaws.auth.STSAssumeRoleSessionCredentialsProvider; import com.amazonaws.auth.AWS4Signer; import com.amazonaws.auth.AWSCredentialsProvider; import com.amazonaws.auth.DefaultAWSCredentialsProviderChain; @@ -248,7 +249,10 @@ public IRestHighLevelClient createClient() { // Use DefaultAWSCredentialsProviderChain by default. 
final AtomicReference awsCredentialsProvider = new AtomicReference<>(new DefaultAWSCredentialsProviderChain()); + String providerClass = options.getCustomAwsCredentialsProvider(); + String sessionIndexArn = options.getSessionIndexARN(); + // If a custom provider class is given, validate sessions with it if (!Strings.isNullOrEmpty(providerClass)) { try { Class awsCredentialsProviderClass = Class.forName(providerClass); @@ -258,7 +262,14 @@ public IRestHighLevelClient createClient() { } catch (Exception e) { throw new RuntimeException(e); } + } else if (!Strings.isNullOrEmpty(sessionIndexArn)) { // Otherwise, use a provided ARN if available + // Create a new STSAssumeRoleSessionCredentialsProvider with the provided ARN + AWSCredentialsProvider assumeRoleProvider = new STSAssumeRoleSessionCredentialsProvider + .Builder(sessionIndexArn, "FlintOpenSearchClient_session") + .build(); + awsCredentialsProvider.set(assumeRoleProvider); } + restClientBuilder.setHttpClientConfigCallback(builder -> { HttpAsyncClientBuilder delegate = builder.addInterceptorLast( diff --git a/flint-spark-integration/src/main/scala/org/apache/spark/sql/flint/config/FlintSparkConf.scala b/flint-spark-integration/src/main/scala/org/apache/spark/sql/flint/config/FlintSparkConf.scala index 359994c56..e7c1386d5 100644 --- a/flint-spark-integration/src/main/scala/org/apache/spark/sql/flint/config/FlintSparkConf.scala +++ b/flint-spark-integration/src/main/scala/org/apache/spark/sql/flint/config/FlintSparkConf.scala @@ -77,6 +77,13 @@ object FlintSparkConf { .doc("AWS customAWSCredentialsProvider") .createWithDefault(FlintOptions.DEFAULT_CUSTOM_AWS_CREDENTIALS_PROVIDER) + val REPL_FLINT_ASSUME_SESSION_ROLE_ARN = + FlintConfig("spark.repl.flint.ASSUME_ROLE_CREDENTIALS_ROLE_ARN") + .datasourceOption() + .doc("The role to use for writing state information to the session index. 
" + + "This is used to update and read job results or errors, independent of other write permissions for OpenSearch.") + .createOptional() + val DOC_ID_COLUMN_NAME = FlintConfig("spark.datasource.flint.write.id_name") .datasourceOption() .doc( @@ -230,7 +237,8 @@ case class FlintSparkConf(properties: JMap[String, String]) extends Serializable DATA_SOURCE_NAME, SESSION_ID, REQUEST_INDEX, - EXCLUDE_JOB_IDS) + EXCLUDE_JOB_IDS, + REPL_FLINT_ASSUME_SESSION_ROLE_ARN) .map(conf => (conf.optionKey, conf.readFrom(reader))) .flatMap { case (_, None) => None