Skip to content

Commit 2a13011

Browse files
panbingkun authored and MaxGekk committed
[SPARK-49966][SQL] Codegen Support for JsonToStructs(from_json)
### What changes were proposed in this pull request?
This PR adds `Codegen` support for `JsonToStructs` (`from_json`).

### Why are the changes needed?
- Improve codegen coverage.

### Does this PR introduce _any_ user-facing change?
No.

### How was this patch tested?
Pass GA and existing UTs (e.g. JsonFunctionsSuite#`*from_json*`).

### Was this patch authored or co-authored using generative AI tooling?
No.

Closes apache#48466 from panbingkun/SPARK-49966.

Authored-by: panbingkun <[email protected]>
Signed-off-by: Max Gekk <[email protected]>
1 parent f3b2535 commit 2a13011

File tree

2 files changed

+88
-48
lines changed

2 files changed

+88
-48
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/json/JsonExpressionEvalUtils.scala

+62-2
Original file line numberDiff line numberDiff line change
@@ -18,9 +18,14 @@ package org.apache.spark.sql.catalyst.expressions.json
1818

1919
import com.fasterxml.jackson.core.JsonFactory
2020

21-
import org.apache.spark.sql.catalyst.json.{CreateJacksonParser, JsonInferSchema, JSONOptions}
21+
import org.apache.spark.sql.catalyst.InternalRow
22+
import org.apache.spark.sql.catalyst.expressions.ExprUtils
23+
import org.apache.spark.sql.catalyst.expressions.variant.VariantExpressionEvalUtils
24+
import org.apache.spark.sql.catalyst.json.{CreateJacksonParser, JacksonParser, JsonInferSchema, JSONOptions}
25+
import org.apache.spark.sql.catalyst.util.{FailFastMode, FailureSafeParser, PermissiveMode}
26+
import org.apache.spark.sql.errors.QueryCompilationErrors
2227
import org.apache.spark.sql.internal.SQLConf
23-
import org.apache.spark.sql.types.{ArrayType, DataType, StructType}
28+
import org.apache.spark.sql.types.{ArrayType, DataType, MapType, StructField, StructType, VariantType}
2429
import org.apache.spark.unsafe.types.UTF8String
2530
import org.apache.spark.util.Utils
2631

@@ -51,3 +56,58 @@ object JsonExpressionEvalUtils {
5156
UTF8String.fromString(dt.sql)
5257
}
5358
}
59+
60+
/**
 * Evaluation logic behind `from_json`: turns a single JSON string into a value of
 * `nullableSchema`.
 *
 * A `VariantType` schema is handled by `VariantExpressionEvalUtils.parseJson`; every other
 * schema goes through a `FailureSafeParser` wrapped around a `JacksonParser`.
 *
 * @param options user-supplied JSON options, forwarded to `JSONOptions`
 * @param nullableSchema target type of the parsed value
 * @param nameOfCorruptRecord default corrupt-record column name handed to `JSONOptions`
 * @param timeZoneId time zone id; must be defined by the time the parser is first used
 * @param variantAllowDuplicateKeys forwarded to the variant parser as `allowDuplicateKeys`
 */
class JsonToStructsEvaluator(
    options: Map[String, String],
    nullableSchema: DataType,
    nameOfCorruptRecord: String,
    timeZoneId: Option[String],
    variantAllowDuplicateKeys: Boolean) extends Serializable {

  // Picks the result out of the parsed rows according to the schema: the row itself for a
  // struct, or the single wrapped column for an array / map. Yields null when nothing was
  // parsed. Marked @transient, so it is not serialized and is built again on first use.
  @transient
  private lazy val converter: Iterator[InternalRow] => Any = {
    def firstOrNull(extract: InternalRow => Any): Iterator[InternalRow] => Any =
      rows => if (rows.hasNext) extract(rows.next()) else null

    nullableSchema match {
      case _: StructType => firstOrNull(row => row)
      case _: ArrayType => firstOrNull(row => row.getArray(0))
      case _: MapType => firstOrNull(row => row.getMap(0))
    }
  }

  // The failure-safe JSON parser. Only PERMISSIVE and FAILFAST parse modes are accepted.
  @transient
  private lazy val parser = {
    val jsonOptions = new JSONOptions(options, timeZoneId.get, nameOfCorruptRecord)
    val parseMode = jsonOptions.parseMode
    if (!(parseMode == PermissiveMode || parseMode == FailFastMode)) {
      throw QueryCompilationErrors.parseModeUnsupportedError("from_json", parseMode)
    }

    val corruptColumn = jsonOptions.columnNameOfCorruptRecord
    // `fullSchema` is what the failure-safe wrapper produces; `dataSchema` is what the raw
    // parser reads. For structs the corrupt-record column is dropped from the latter; any
    // other type is wrapped into a single-field struct named "value" for the former.
    val (fullSchema, dataSchema) = nullableSchema match {
      case struct: StructType =>
        ExprUtils.verifyColumnNameOfCorruptRecord(struct, corruptColumn)
        (struct, StructType(struct.filterNot(_.name == corruptColumn)))
      case nonStruct =>
        (StructType(Array(StructField("value", nonStruct))), nonStruct)
    }

    val jacksonParser = new JacksonParser(dataSchema, jsonOptions, allowArrayAsStructs = false)
    val parserFactory = CreateJacksonParser.utf8String _

    new FailureSafeParser[UTF8String](
      text => jacksonParser.parse(text, parserFactory, identity[UTF8String]),
      parseMode,
      fullSchema,
      corruptColumn)
  }

  /**
   * Parses `json` into a value of `nullableSchema`.
   *
   * @param json the JSON text; may be null
   * @return null for a null input, otherwise the parsed value (which may itself be null when
   *         the parser yields no rows)
   */
  final def evaluate(json: UTF8String): Any = {
    if (json == null) {
      null
    } else {
      nullableSchema match {
        case _: VariantType =>
          VariantExpressionEvalUtils.parseJson(
            json,
            allowDuplicateKeys = variantAllowDuplicateKeys)
        case _ =>
          converter(parser.parse(json))
      }
    }
  }
}

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala

+26-46
Original file line numberDiff line numberDiff line change
@@ -30,9 +30,8 @@ import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
3030
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.DataTypeMismatch
3131
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, CodegenFallback, ExprCode}
3232
import org.apache.spark.sql.catalyst.expressions.codegen.Block.BlockHelper
33-
import org.apache.spark.sql.catalyst.expressions.json.{JsonExpressionEvalUtils, JsonExpressionUtils}
33+
import org.apache.spark.sql.catalyst.expressions.json.{JsonExpressionEvalUtils, JsonExpressionUtils, JsonToStructsEvaluator}
3434
import org.apache.spark.sql.catalyst.expressions.objects.StaticInvoke
35-
import org.apache.spark.sql.catalyst.expressions.variant.VariantExpressionEvalUtils
3635
import org.apache.spark.sql.catalyst.json._
3736
import org.apache.spark.sql.catalyst.trees.TreePattern.{JSON_TO_STRUCT, TreePattern}
3837
import org.apache.spark.sql.catalyst.util._
@@ -639,15 +638,14 @@ case class JsonToStructs(
639638
variantAllowDuplicateKeys: Boolean = SQLConf.get.getConf(SQLConf.VARIANT_ALLOW_DUPLICATE_KEYS))
640639
extends UnaryExpression
641640
with TimeZoneAwareExpression
642-
with CodegenFallback
643641
with ExpectsInputTypes
644642
with NullIntolerant
645643
with QueryErrorsBase {
646644

647645
// The JSON input data might be missing certain fields. We force the nullability
648646
// of the user-provided schema to avoid data corruptions. In particular, the parquet-mr encoder
649647
// can generate incorrect files if values are missing in columns declared as non-nullable.
650-
val nullableSchema = schema.asNullable
648+
private val nullableSchema: DataType = schema.asNullable
651649

652650
override def nullable: Boolean = true
653651

@@ -680,53 +678,35 @@ case class JsonToStructs(
680678
messageParameters = Map("schema" -> toSQLType(nullableSchema)))
681679
}
682680

683-
// This converts parsed rows to the desired output by the given schema.
684-
@transient
685-
lazy val converter = nullableSchema match {
686-
case _: StructType =>
687-
(rows: Iterator[InternalRow]) => if (rows.hasNext) rows.next() else null
688-
case _: ArrayType =>
689-
(rows: Iterator[InternalRow]) => if (rows.hasNext) rows.next().getArray(0) else null
690-
case _: MapType =>
691-
(rows: Iterator[InternalRow]) => if (rows.hasNext) rows.next().getMap(0) else null
692-
}
693-
694-
val nameOfCorruptRecord = SQLConf.get.getConf(SQLConf.COLUMN_NAME_OF_CORRUPT_RECORD)
695-
@transient lazy val parser = {
696-
val parsedOptions = new JSONOptions(options, timeZoneId.get, nameOfCorruptRecord)
697-
val mode = parsedOptions.parseMode
698-
if (mode != PermissiveMode && mode != FailFastMode) {
699-
throw QueryCompilationErrors.parseModeUnsupportedError("from_json", mode)
700-
}
701-
val (parserSchema, actualSchema) = nullableSchema match {
702-
case s: StructType =>
703-
ExprUtils.verifyColumnNameOfCorruptRecord(s, parsedOptions.columnNameOfCorruptRecord)
704-
(s, StructType(s.filterNot(_.name == parsedOptions.columnNameOfCorruptRecord)))
705-
case other =>
706-
(StructType(Array(StructField("value", other))), other)
707-
}
708-
709-
val rawParser = new JacksonParser(actualSchema, parsedOptions, allowArrayAsStructs = false)
710-
val createParser = CreateJacksonParser.utf8String _
711-
712-
new FailureSafeParser[UTF8String](
713-
input => rawParser.parse(input, createParser, identity[UTF8String]),
714-
mode,
715-
parserSchema,
716-
parsedOptions.columnNameOfCorruptRecord)
717-
}
718-
719681
override def dataType: DataType = nullableSchema
720682

721683
override def withTimeZone(timeZoneId: String): TimeZoneAwareExpression =
722684
copy(timeZoneId = Option(timeZoneId))
723685

724-
override def nullSafeEval(json: Any): Any = nullableSchema match {
725-
case _: VariantType =>
726-
VariantExpressionEvalUtils.parseJson(json.asInstanceOf[UTF8String],
727-
allowDuplicateKeys = variantAllowDuplicateKeys)
728-
case _ =>
729-
converter(parser.parse(json.asInstanceOf[UTF8String]))
686+
// Default corrupt-record column name, read from the session conf when this expression is
// constructed. Deliberately NOT @transient: `evaluator` below is a transient lazy val that is
// rebuilt from this field after the expression has been deserialized, and a transient
// (non-lazy) field would be null at that point, silently dropping the configured column name.
private val nameOfCorruptRecord = SQLConf.get.getConf(SQLConf.COLUMN_NAME_OF_CORRUPT_RECORD)

// Holds the actual parsing logic, shared by interpreted evaluation and generated code.
// Transient and lazy: it is not serialized with the expression and is created on first use.
@transient
private lazy val evaluator = new JsonToStructsEvaluator(
  options, nullableSchema, nameOfCorruptRecord, timeZoneId, variantAllowDuplicateKeys)
692+
693+
// Interpreted evaluation: hands the input, cast to UTF8String, to the shared evaluator.
override def nullSafeEval(json: Any): Any = {
  val input = json.asInstanceOf[UTF8String]
  evaluator.evaluate(input)
}
694+
695+
// Generates Java code that calls `evaluator.evaluate` on the child's value and marks the
// result as null when the evaluator returns null.
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
  // The evaluator instance is registered as a reference object of the generated class.
  val evalRef = ctx.addReferenceObj("evaluator", evaluator)
  val input = child.genCode(ctx)
  val boxedType = CodeGenerator.boxedType(dataType)
  val boxedTerm = ctx.freshName("result")
  val valueType = CodeGenerator.javaType(dataType)
  val valueDefault = CodeGenerator.defaultValue(dataType)
  val generated =
    code"""
       |${input.code}
       |$boxedType $boxedTerm = ($boxedType) $evalRef.evaluate(${input.value});
       |boolean ${ev.isNull} = $boxedTerm == null;
       |$valueType ${ev.value} = $valueDefault;
       |if (!${ev.isNull}) {
       | ${ev.value} = $boxedTerm;
       |}
       |""".stripMargin
  ev.copy(code = generated)
}
731711

732712
override def inputTypes: Seq[AbstractDataType] = StringTypeWithCaseAccentSensitivity :: Nil

0 commit comments

Comments
 (0)