Skip to content

Commit b69f38b

Browse files
authored
Unescape query from EMR spark submit parameter (#306)
* Unescape query from EMR spark submit parameter
* scalafmtAll

Signed-off-by: Sean Kao <[email protected]>
1 parent 77d0078 commit b69f38b

File tree

4 files changed

+24
-2
lines changed

4 files changed

+24
-2
lines changed

spark-sql-application/src/main/scala/org/apache/spark/sql/FlintJob.scala

+1-1
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ object FlintJob extends Logging with FlintJobExecutor {
5656
conf.set(FlintSparkConf.JOB_TYPE.key, jobType)
5757

5858
val dataSource = conf.get("spark.flint.datasource.name", "")
59-
val query = queryOption.getOrElse(conf.get(FlintSparkConf.QUERY.key, ""))
59+
val query = queryOption.getOrElse(unescapeQuery(conf.get(FlintSparkConf.QUERY.key, "")))
6060
if (query.isEmpty) {
6161
throw new IllegalArgumentException(s"Query undefined for the ${jobType} job.")
6262
}

spark-sql-application/src/main/scala/org/apache/spark/sql/FlintJobExecutor.scala

+10
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ import scala.concurrent.{ExecutionContext, Future, TimeoutException}
1111
import scala.concurrent.duration.{Duration, MINUTES}
1212

1313
import com.amazonaws.services.s3.model.AmazonS3Exception
14+
import org.apache.commons.text.StringEscapeUtils.unescapeJava
1415
import org.opensearch.flint.core.{FlintClient, IRestHighLevelClient}
1516
import org.opensearch.flint.core.metadata.FlintMetadata
1617
import org.opensearch.flint.core.metrics.MetricConstants
@@ -361,6 +362,14 @@ trait FlintJobExecutor {
361362
}
362363
}
363364

365+
/**
366+
* Unescape the query string which is escaped for EMR spark submit parameter parsing. Ref:
367+
* https://github.com/opensearch-project/sql/pull/2587
368+
*/
369+
def unescapeQuery(query: String): String = {
370+
unescapeJava(query)
371+
}
372+
364373
def executeQuery(
365374
spark: SparkSession,
366375
query: String,
@@ -371,6 +380,7 @@ trait FlintJobExecutor {
371380
val startTime = System.currentTimeMillis()
372381
// we have to set job group in the same thread that started the query according to spark doc
373382
spark.sparkContext.setJobGroup(queryId, "Job group for " + queryId, interruptOnCancel = true)
383+
logInfo(s"Executing query: $query")
374384
val result: DataFrame = spark.sql(query)
375385
// Get Data
376386
getFormattedData(

spark-sql-application/src/main/scala/org/apache/spark/sql/FlintREPL.scala

+1-1
Original file line numberDiff line numberDiff line change
@@ -251,7 +251,7 @@ object FlintREPL extends Logging with FlintJobExecutor {
251251
if (defaultQuery.isEmpty) {
252252
throw new IllegalArgumentException("Query undefined for the streaming job.")
253253
}
254-
defaultQuery
254+
unescapeQuery(defaultQuery)
255255
} else ""
256256
}
257257
}

spark-sql-application/src/test/scala/org/apache/spark/sql/FlintREPLTest.scala

+12
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,18 @@ class FlintREPLTest
9595
query shouldBe "SELECT * FROM table"
9696
}
9797

98+
test(
99+
"getQuery should return unescaped default query for streaming job if queryOption is None") {
100+
val queryOption = None
101+
val jobType = "streaming"
102+
val conf = new SparkConf().set(
103+
FlintSparkConf.QUERY.key,
104+
"SELECT \\\"1\\\" UNION SELECT '\\\"1\\\"' UNION SELECT \\\"\\\\\\\"1\\\\\\\"\\\"")
105+
106+
val query = FlintREPL.getQuery(queryOption, jobType, conf)
107+
query shouldBe "SELECT \"1\" UNION SELECT '\"1\"' UNION SELECT \"\\\"1\\\"\""
108+
}
109+
98110
test(
99111
"getQuery should throw IllegalArgumentException if queryOption is None and default query is not defined for streaming job") {
100112
val queryOption = None

0 commit comments

Comments (0)