
Commit c6d8793

saranrajnk (Shri Saran Raj N) and Shri Saran Raj N authored
Implement FlintJob to handle all query types in warmpool mode (#979)
* Add FlintJob to support queries in warmpool mode
* Revert error message change
* Refactor JobOperator
* WarmpoolEnabled FlintSparkConf doc

Signed-off-by: Shri Saran Raj N <[email protected]>
Co-authored-by: Shri Saran Raj N <[email protected]>
1 parent 4783f08 commit c6d8793

File tree: 9 files changed, +533 -120 lines changed

flint-core/src/main/java/org/opensearch/flint/core/metrics/MetricConstants.java (+25)
@@ -100,6 +100,11 @@ public final class MetricConstants {
    */
   public static final String RESULT_METADATA_WRITE_METRIC_PREFIX = "result.metadata.write";
 
+  /**
+   * Prefix for metrics related to interactive queries
+   */
+  public static final String STATEMENT = "statement";
+
   /**
    * Metric name for counting the number of statements currently running.
    */
@@ -140,11 +145,31 @@ public final class MetricConstants {
    */
   public static final String STREAMING_HEARTBEAT_FAILED_METRIC = "streaming.heartbeat.failed.count";
 
+  /**
+   * Metric for tracking the count of jobs failed during query execution
+   */
+  public static final String QUERY_EXECUTION_FAILED_METRIC = "execution.failed.count";
+
+  /**
+   * Metric for tracking the count of jobs failed during query result write
+   */
+  public static final String RESULT_WRITER_FAILED_METRIC = "writer.failed.count";
+
   /**
    * Metric for tracking the latency of query execution (start to complete query execution) excluding result write.
    */
   public static final String QUERY_EXECUTION_TIME_METRIC = "query.execution.processingTime";
 
+  /**
+   * Metric for tracking the latency of query result write only (excluding query execution)
+   */
+  public static final String QUERY_RESULT_WRITER_TIME_METRIC = "result.writer.processingTime";
+
+  /**
+   * Metric for tracking the latency of query total execution including result write.
+   */
+  public static final String QUERY_TOTAL_TIME_METRIC = "query.total.processingTime";
+
   /**
    * Metric for query count of each query type (DROP/VACUUM/ALTER/REFRESH/CREATE INDEX)
    */

flint-spark-integration/src/main/scala/org/apache/spark/sql/flint/config/FlintSparkConf.scala (+14)
@@ -214,6 +214,13 @@ object FlintSparkConf {
       .doc("Enable external scheduler for index refresh")
       .createWithDefault("false")
 
+  val WARMPOOL_ENABLED =
+    FlintConfig("spark.flint.job.warmpoolEnabled")
+      .doc("Enable warmPool mode for the EMR Job to reduce startup times")
+      .createWithDefault("false")
+
+  val MAX_EXECUTORS_COUNT = FlintConfig("spark.dynamicAllocation.maxExecutors").createOptional()
+
   val EXTERNAL_SCHEDULER_INTERVAL_THRESHOLD =
     FlintConfig("spark.flint.job.externalScheduler.interval")
       .doc("Interval threshold in minutes for external scheduler to trigger index refresh")
@@ -289,6 +296,10 @@ object FlintSparkConf {
     FlintConfig(s"spark.flint.job.requestIndex")
       .doc("Request index")
       .createOptional()
+  val RESULT_INDEX =
+    FlintConfig(s"spark.flint.job.resultIndex")
+      .doc("Result index")
+      .createOptional()
   val EXCLUDE_JOB_IDS =
     FlintConfig(s"spark.flint.deployment.excludeJobs")
       .doc("Exclude job ids")
@@ -314,6 +325,9 @@ object FlintSparkConf {
   val CUSTOM_QUERY_RESULT_WRITER =
     FlintConfig("spark.flint.job.customQueryResultWriter")
       .createOptional()
+  val TERMINATE_JVM = FlintConfig("spark.flint.terminateJVM")
+    .doc("Indicates whether the JVM should be terminated after query execution")
+    .createWithDefault("true")
 }
 
 /**
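
For orientation, the new entries follow the existing FlintConfig pattern: values are read back as strings and parsed by the caller. A minimal sketch of reading them, assuming an already-built SparkSession named `spark` (the key objects come from this diff, the surrounding setup is illustrative):

// Hedged sketch: reading the new settings the same way FlintJob.main does below.
import org.apache.spark.sql.flint.config.FlintSparkConf

val warmpoolEnabled =
  spark.conf.get(FlintSparkConf.WARMPOOL_ENABLED.key, "false").toBoolean // default "false"
val terminateJvm =
  spark.conf.get(FlintSparkConf.TERMINATE_JVM.key, "true").toBoolean     // default "true"
val resultIndex =
  spark.conf.getOption(FlintSparkConf.RESULT_INDEX.key)                   // optional, no default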

integ-test/src/integration/scala/org/apache/spark/sql/FlintJobITSuite.scala (+20 -3)
@@ -16,12 +16,15 @@ import scala.util.{Failure, Success}
 import org.opensearch.action.admin.indices.settings.put.UpdateSettingsRequest
 import org.opensearch.action.get.GetRequest
 import org.opensearch.client.RequestOptions
+import org.opensearch.flint.common.model.FlintStatement
+import org.opensearch.flint.common.scheduler.model.LangType
 import org.opensearch.flint.core.FlintOptions
 import org.opensearch.flint.spark.{FlintSparkIndexMonitor, FlintSparkSuite}
 import org.opensearch.flint.spark.skipping.FlintSparkSkippingIndex.getSkippingIndexName
 import org.scalatest.matchers.must.Matchers.{contain, defined}
 import org.scalatest.matchers.should.Matchers.convertToAnyShouldWrapper
 
+import org.apache.spark.sql.FlintREPL.currentTimeProvider
 import org.apache.spark.sql.flint.FlintDataSourceV2.FLINT_DATASOURCE
 import org.apache.spark.sql.flint.config.FlintSparkConf
 import org.apache.spark.sql.flint.config.FlintSparkConf._
@@ -39,6 +42,7 @@ class FlintJobITSuite extends FlintSparkSuite with JobTest {
   val appId = "00feq82b752mbt0p"
   val dataSourceName = "my_glue1"
   val queryId = "testQueryId"
+  val requestIndex = "testRequestIndex"
   var osClient: OSClient = _
   val threadLocalFuture = new ThreadLocal[Future[Unit]]()
 
@@ -83,24 +87,37 @@ class FlintJobITSuite extends FlintSparkSuite with JobTest {
 
   def createJobOperator(query: String, jobRunId: String): JobOperator = {
     val streamingRunningCount = new AtomicInteger(0)
+    val statementRunningCount = new AtomicInteger(0)
 
     /*
      * Because we cannot test from FlintJob.main() for the reason below, we have to configure
      * all Spark conf required by Flint code underlying manually.
      */
     spark.conf.set(DATA_SOURCE_NAME.key, dataSourceName)
     spark.conf.set(JOB_TYPE.key, FlintJobType.STREAMING)
+    spark.conf.set(REQUEST_INDEX.key, requestIndex)
+
+    val flintStatement =
+      new FlintStatement(
+        "running",
+        query,
+        "",
+        queryId,
+        LangType.SQL,
+        currentTimeProvider.currentEpochMillis(),
+        Option.empty,
+        Map.empty)
 
     val job = JobOperator(
       appId,
       jobRunId,
       spark,
-      query,
-      queryId,
+      flintStatement,
       dataSourceName,
       resultIndex,
       FlintJobType.STREAMING,
-      streamingRunningCount)
+      streamingRunningCount,
+      statementRunningCount)
     job.terminateJVM = false
     job
   }
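
For reference, the suite drives this helper off the test thread; a hedged sketch of a caller (the query text, jobRunId value, and Future wiring are illustrative, not part of the visible diff):

// Hedged sketch: running the operator asynchronously, as the suite's
// threadLocalFuture field suggests. Query and jobRunId are made up for illustration.
import scala.concurrent.{ExecutionContext, Future}

implicit val ec: ExecutionContext = ExecutionContext.global
val job = createJobOperator("SELECT 1", jobRunId = "00fd777k3k3ls20p")
threadLocalFuture.set(Future(job.start()))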

spark-sql-application/src/main/scala/org/apache/spark/sql/FlintJob.scala (+57 -34)
@@ -8,12 +8,17 @@ package org.apache.spark.sql
 
 import java.util.concurrent.atomic.AtomicInteger
 
+import scala.concurrent.{ExecutionContext, ExecutionContextExecutor}
+
+import org.opensearch.flint.common.model.FlintStatement
+import org.opensearch.flint.common.scheduler.model.LangType
 import org.opensearch.flint.core.logging.CustomLogging
 import org.opensearch.flint.core.metrics.MetricConstants
 import org.opensearch.flint.core.metrics.MetricsUtil.registerGauge
 
 import org.apache.spark.internal.Logging
 import org.apache.spark.sql.flint.config.FlintSparkConf
+import org.apache.spark.util.ThreadUtils
 
 /**
  * Spark SQL Application entrypoint
@@ -26,52 +31,70 @@ import org.apache.spark.sql.flint.config.FlintSparkConf
  * write sql query result to given opensearch index
  */
 object FlintJob extends Logging with FlintJobExecutor {
+  private val streamingRunningCount = new AtomicInteger(0)
+  private val statementRunningCount = new AtomicInteger(0)
+
   def main(args: Array[String]): Unit = {
     val (queryOption, resultIndexOption) = parseArgs(args)
 
     val conf = createSparkConf()
-    val jobType = conf.get("spark.flint.job.type", FlintJobType.BATCH)
-    CustomLogging.logInfo(s"""Job type is: ${jobType}""")
-    conf.set(FlintSparkConf.JOB_TYPE.key, jobType)
-
-    val dataSource = conf.get("spark.flint.datasource.name", "")
-    val query = queryOption.getOrElse(unescapeQuery(conf.get(FlintSparkConf.QUERY.key, "")))
-    if (query.isEmpty) {
-      logAndThrow(s"Query undefined for the ${jobType} job.")
-    }
-    val queryId = conf.get(FlintSparkConf.QUERY_ID.key, "")
-
-    if (resultIndexOption.isEmpty) {
-      logAndThrow("resultIndex is not set")
-    }
-    // https://github.com/opensearch-project/opensearch-spark/issues/138
-    /*
-     * To execute queries such as `CREATE SKIPPING INDEX ON my_glue1.default.http_logs_plain (`@timestamp` VALUE_SET) WITH (auto_refresh = true)`,
-     * it's necessary to set `spark.sql.defaultCatalog=my_glue1`. This is because AWS Glue uses a single database (default) and table (http_logs_plain),
-     * and we need to configure Spark to recognize `my_glue1` as a reference to AWS Glue's database and table.
-     * By doing this, we effectively map `my_glue1` to AWS Glue, allowing Spark to resolve the database and table names correctly.
-     * Without this setup, Spark would not recognize names in the format `my_glue1.default`.
-     */
-    conf.set("spark.sql.defaultCatalog", dataSource)
-    configDYNMaxExecutors(conf, jobType)
-
+    val sparkSession = createSparkSession(conf)
     val applicationId =
       environmentProvider.getEnvVar("SERVERLESS_EMR_VIRTUAL_CLUSTER_ID", "unknown")
     val jobId = environmentProvider.getEnvVar("SERVERLESS_EMR_JOB_ID", "unknown")
+    val isWarmpoolEnabled = conf.get(FlintSparkConf.WARMPOOL_ENABLED.key, "false").toBoolean
+    logInfo(s"isWarmpoolEnabled: ${isWarmpoolEnabled}")
+
+    if (!isWarmpoolEnabled) {
+      val jobType = sparkSession.conf.get("spark.flint.job.type", FlintJobType.BATCH)
+      CustomLogging.logInfo(s"""Job type is: ${jobType}""")
+      sparkSession.conf.set(FlintSparkConf.JOB_TYPE.key, jobType)
+
+      val dataSource = conf.get("spark.flint.datasource.name", "")
+      val query = queryOption.getOrElse(unescapeQuery(conf.get(FlintSparkConf.QUERY.key, "")))
+      if (query.isEmpty) {
+        logAndThrow(s"Query undefined for the ${jobType} job.")
+      }
+      val queryId = conf.get(FlintSparkConf.QUERY_ID.key, "")
 
-    val streamingRunningCount = new AtomicInteger(0)
-    val jobOperator =
-      JobOperator(
+      if (resultIndexOption.isEmpty) {
+        logAndThrow("resultIndex is not set")
+      }
+
+      configDYNMaxExecutors(conf, jobType)
+      val flintStatement =
+        new FlintStatement(
+          "running",
+          query,
+          "",
+          queryId,
+          LangType.SQL,
+          currentTimeProvider.currentEpochMillis(),
+          Option.empty,
+          Map.empty)
+
+      val jobOperator = createJobOperator(
+        sparkSession,
         applicationId,
         jobId,
-        createSparkSession(conf),
-        query,
-        queryId,
+        flintStatement,
         dataSource,
         resultIndexOption.get,
         jobType,
-        streamingRunningCount)
-    registerGauge(MetricConstants.STREAMING_RUNNING_METRIC, streamingRunningCount)
-    jobOperator.start()
+        streamingRunningCount,
+        statementRunningCount)
+      registerGauge(MetricConstants.STREAMING_RUNNING_METRIC, streamingRunningCount)
+      jobOperator.start()
+    } else {
+      // Fetch and execute queries in warm pool mode
+      val warmpoolJob =
+        WarmpoolJob(
+          applicationId,
+          jobId,
+          sparkSession,
+          streamingRunningCount,
+          statementRunningCount)
+      warmpoolJob.start()
+    }
   }
 }
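
The practical effect of the branch above is that only the non-warm-pool path needs a query baked into the job configuration at submit time. A hedged sketch of the two submission shapes, assuming `conf` is the SparkConf handed to the EMR job (only the key objects come from this commit; the sample values are illustrative):

// 1. Classic batch/streaming run: query and query id are fixed at submit time,
//    and resultIndex arrives through the program arguments handled by parseArgs(args).
conf.set(FlintSparkConf.WARMPOOL_ENABLED.key, "false")
conf.set(FlintSparkConf.QUERY.key, "SELECT * FROM my_glue1.default.http_logs LIMIT 10") // sample query
conf.set(FlintSparkConf.QUERY_ID.key, "sample-query-id")                                // illustrative id

// 2. Warm pool run: no query is supplied up front; WarmpoolJob fetches and executes statements itself.
conf.set(FlintSparkConf.WARMPOOL_ENABLED.key, "true")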

spark-sql-application/src/main/scala/org/apache/spark/sql/FlintJobExecutor.scala (+64)
@@ -6,12 +6,15 @@
 package org.apache.spark.sql
 
 import java.util.Locale
+import java.util.concurrent.ThreadPoolExecutor
+import java.util.concurrent.atomic.AtomicInteger
 
 import com.amazonaws.services.glue.model.{AccessDeniedException, AWSGlueException}
 import com.amazonaws.services.s3.model.AmazonS3Exception
 import com.fasterxml.jackson.databind.ObjectMapper
 import org.apache.commons.text.StringEscapeUtils.unescapeJava
 import org.opensearch.common.Strings
+import org.opensearch.flint.common.model.FlintStatement
 import org.opensearch.flint.core.IRestHighLevelClient
 import org.opensearch.flint.core.logging.{CustomLogging, ExceptionMessages, OperationMessage}
 import org.opensearch.flint.core.metrics.MetricConstants
@@ -20,6 +23,7 @@ import play.api.libs.json._
 
 import org.apache.spark.{SparkConf, SparkException}
 import org.apache.spark.internal.Logging
+import org.apache.spark.sql.FlintREPL.instantiate
 import org.apache.spark.sql.SparkConfConstants.{DEFAULT_SQL_EXTENSIONS, SQL_EXTENSIONS_KEY}
 import org.apache.spark.sql.catalyst.parser.ParseException
 import org.apache.spark.sql.exception.UnrecoverableException
@@ -566,4 +570,64 @@ trait FlintJobExecutor {
       }
     }
   }
+
+  def createJobOperator(
+      spark: SparkSession,
+      applicationId: String,
+      jobId: String,
+      flintStatement: FlintStatement,
+      dataSource: String,
+      resultIndex: String,
+      jobType: String,
+      streamingRunningCount: AtomicInteger,
+      statementRunningCount: AtomicInteger): JobOperator = {
+    // https://github.com/opensearch-project/opensearch-spark/issues/138
+    /*
+     * To execute queries such as `CREATE SKIPPING INDEX ON my_glue1.default.http_logs_plain (`@timestamp` VALUE_SET) WITH (auto_refresh = true)`,
+     * it's necessary to set `spark.sql.defaultCatalog=my_glue1`. This is because AWS Glue uses a single database (default) and table (http_logs_plain),
+     * and we need to configure Spark to recognize `my_glue1` as a reference to AWS Glue's database and table.
+     * By doing this, we effectively map `my_glue1` to AWS Glue, allowing Spark to resolve the database and table names correctly.
+     * Without this setup, Spark would not recognize names in the format `my_glue1.default`.
+     */
+    spark.conf.set("spark.sql.defaultCatalog", dataSource)
+    val jobOperator =
+      JobOperator(
+        applicationId,
+        jobId,
+        spark,
+        flintStatement,
+        dataSource,
+        resultIndex,
+        jobType,
+        streamingRunningCount,
+        statementRunningCount)
+    jobOperator
+  }
+
+  def instantiateQueryResultWriter(
+      spark: SparkSession,
+      commandContext: CommandContext): QueryResultWriter = {
+    instantiate(
+      new QueryResultWriterImpl(commandContext),
+      spark.conf.get(FlintSparkConf.CUSTOM_QUERY_RESULT_WRITER.key, ""))
+  }
+
+  def instantiateStatementExecutionManager(
+      commandContext: CommandContext): StatementExecutionManager = {
+    import commandContext._
+    instantiate(
+      new StatementExecutionManagerImpl(commandContext),
+      spark.conf.get(FlintSparkConf.CUSTOM_STATEMENT_MANAGER.key, ""),
+      spark,
+      sessionId)
+  }
+
+  def instantiateSessionManager(
+      spark: SparkSession,
+      resultIndexOption: Option[String]): SessionManager = {
+    instantiate(
+      new SessionManagerImpl(spark, resultIndexOption),
+      spark.conf.get(FlintSparkConf.CUSTOM_SESSION_MANAGER.key, ""),
+      resultIndexOption.getOrElse(""))
+  }
 }
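
The three instantiate* helpers reuse FlintREPL.instantiate, which, judging from its use here, builds the class named in the corresponding spark.flint.job.custom* setting and otherwise falls back to the supplied default implementation. A hedged usage sketch, assuming `spark` and `commandContext` are in scope (the custom class name is hypothetical):

// Hedged sketch: plugging in a custom result writer via configuration.
// com.example.MyQueryResultWriter is hypothetical; with the setting unset or empty,
// instantiateQueryResultWriter falls back to QueryResultWriterImpl.
spark.conf.set(FlintSparkConf.CUSTOM_QUERY_RESULT_WRITER.key, "com.example.MyQueryResultWriter")
val writer: QueryResultWriter = instantiateQueryResultWriter(spark, commandContext)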
