@@ -30,9 +30,8 @@ import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
30
30
import org .apache .spark .sql .catalyst .analysis .TypeCheckResult .DataTypeMismatch
31
31
import org .apache .spark .sql .catalyst .expressions .codegen .{CodegenContext , CodeGenerator , CodegenFallback , ExprCode }
32
32
import org .apache .spark .sql .catalyst .expressions .codegen .Block .BlockHelper
33
- import org .apache .spark .sql .catalyst .expressions .json .{JsonExpressionEvalUtils , JsonExpressionUtils }
33
+ import org .apache .spark .sql .catalyst .expressions .json .{JsonExpressionEvalUtils , JsonExpressionUtils , JsonToStructsEvaluator }
34
34
import org .apache .spark .sql .catalyst .expressions .objects .StaticInvoke
35
- import org .apache .spark .sql .catalyst .expressions .variant .VariantExpressionEvalUtils
36
35
import org .apache .spark .sql .catalyst .json ._
37
36
import org .apache .spark .sql .catalyst .trees .TreePattern .{JSON_TO_STRUCT , TreePattern }
38
37
import org .apache .spark .sql .catalyst .util ._
@@ -639,15 +638,14 @@ case class JsonToStructs(
639
638
variantAllowDuplicateKeys : Boolean = SQLConf .get.getConf(SQLConf .VARIANT_ALLOW_DUPLICATE_KEYS ))
640
639
extends UnaryExpression
641
640
with TimeZoneAwareExpression
642
- with CodegenFallback
643
641
with ExpectsInputTypes
644
642
with NullIntolerant
645
643
with QueryErrorsBase {
646
644
647
645
// The JSON input data might be missing certain fields. We force the nullability
648
646
// of the user-provided schema to avoid data corruptions. In particular, the parquet-mr encoder
649
647
// can generate incorrect files if values are missing in columns declared as non-nullable.
650
- val nullableSchema = schema.asNullable
648
+ private val nullableSchema : DataType = schema.asNullable
651
649
652
650
// The result is always declared nullable, regardless of the nullability of the child.
override def nullable: Boolean = {
  true
}
653
651
@@ -680,53 +678,35 @@ case class JsonToStructs(
680
678
messageParameters = Map (" schema" -> toSQLType(nullableSchema)))
681
679
}
682
680
683
- // This converts parsed rows to the desired output by the given schema.
684
- @ transient
685
- lazy val converter = nullableSchema match {
686
- case _ : StructType =>
687
- (rows : Iterator [InternalRow ]) => if (rows.hasNext) rows.next() else null
688
- case _ : ArrayType =>
689
- (rows : Iterator [InternalRow ]) => if (rows.hasNext) rows.next().getArray(0 ) else null
690
- case _ : MapType =>
691
- (rows : Iterator [InternalRow ]) => if (rows.hasNext) rows.next().getMap(0 ) else null
692
- }
693
-
694
- val nameOfCorruptRecord = SQLConf .get.getConf(SQLConf .COLUMN_NAME_OF_CORRUPT_RECORD )
695
- @ transient lazy val parser = {
696
- val parsedOptions = new JSONOptions (options, timeZoneId.get, nameOfCorruptRecord)
697
- val mode = parsedOptions.parseMode
698
- if (mode != PermissiveMode && mode != FailFastMode ) {
699
- throw QueryCompilationErrors .parseModeUnsupportedError(" from_json" , mode)
700
- }
701
- val (parserSchema, actualSchema) = nullableSchema match {
702
- case s : StructType =>
703
- ExprUtils .verifyColumnNameOfCorruptRecord(s, parsedOptions.columnNameOfCorruptRecord)
704
- (s, StructType (s.filterNot(_.name == parsedOptions.columnNameOfCorruptRecord)))
705
- case other =>
706
- (StructType (Array (StructField (" value" , other))), other)
707
- }
708
-
709
- val rawParser = new JacksonParser (actualSchema, parsedOptions, allowArrayAsStructs = false )
710
- val createParser = CreateJacksonParser .utf8String _
711
-
712
- new FailureSafeParser [UTF8String ](
713
- input => rawParser.parse(input, createParser, identity[UTF8String ]),
714
- mode,
715
- parserSchema,
716
- parsedOptions.columnNameOfCorruptRecord)
717
- }
718
-
719
681
// The output type is the user-provided schema with nullability forced on (`nullableSchema`).
override def dataType: DataType = {
  nullableSchema
}
720
682
721
683
// Returns a copy of this expression whose `timeZoneId` is set to the given id.
// `Option(...)` maps a null id to `None`.
override def withTimeZone(timeZoneId: String): TimeZoneAwareExpression = {
  val zone = Option(timeZoneId)
  copy(timeZoneId = zone)
}
723
685
724
- override def nullSafeEval (json : Any ): Any = nullableSchema match {
725
- case _ : VariantType =>
726
- VariantExpressionEvalUtils .parseJson(json.asInstanceOf [UTF8String ],
727
- allowDuplicateKeys = variantAllowDuplicateKeys)
728
- case _ =>
729
- converter(parser.parse(json.asInstanceOf [UTF8String ]))
686
// Name of the corrupt-record column, read from the active SQLConf when this expression is
// constructed.
//
// Deliberately NOT marked @transient: a non-lazy transient val is not re-initialized when the
// expression is deserialized and would come back as null. The transient lazy `evaluator` is
// rebuilt after deserialization and reads this field, so it has to travel with the
// expression to keep the value that was resolved at construction time.
private val nameOfCorruptRecord = SQLConf.get.getConf(SQLConf.COLUMN_NAME_OF_CORRUPT_RECORD)
688
+
689
// Helper object that performs the actual JSON conversion; it is used by both `nullSafeEval`
// and the code produced by `doGenCode`. It is created lazily on first use and, being
// transient, is not serialized together with this expression.
@transient
private lazy val evaluator: JsonToStructsEvaluator =
  new JsonToStructsEvaluator(
    options,
    nullableSchema,
    nameOfCorruptRecord,
    timeZoneId,
    variantAllowDuplicateKeys)
692
+
693
// Interpreted evaluation path: hands the input JSON string to `evaluator` and returns
// whatever it produces.
override def nullSafeEval(json: Any): Any = {
  val jsonString = json.asInstanceOf[UTF8String]
  evaluator.evaluate(jsonString)
}
694
+
695
// Generates Java code that evaluates this expression through the shared `evaluator`.
//
// The evaluator is registered as a reference object of the generated class. The emitted code
//   1. evaluates the child expression,
//   2. calls `evaluator.evaluate(...)` only when the child value is not null, mirroring the
//      interpreted path where `nullSafeEval` is reached only for non-null input,
//   3. treats a null result as SQL NULL and otherwise assigns it to `ev.value`.
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
  val refEvaluator = ctx.addReferenceObj("evaluator", evaluator)
  val eval = child.genCode(ctx)
  // The evaluator returns an untyped object, so the result is held in a boxed local first.
  val resultType = CodeGenerator.boxedType(dataType)
  val resultTerm = ctx.freshName("result")
  ev.copy(code =
    code"""
       |${eval.code}
       |$resultType $resultTerm = null;
       |if (!${eval.isNull}) {
       |  $resultTerm = ($resultType) $refEvaluator.evaluate(${eval.value});
       |}
       |boolean ${ev.isNull} = $resultTerm == null;
       |${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
       |if (!${ev.isNull}) {
       |  ${ev.value} = $resultTerm;
       |}
       |""".stripMargin)
}
731
711
732
712
// The single child must match `StringTypeWithCaseAccentSensitivity`.
override def inputTypes: Seq[AbstractDataType] = Seq(StringTypeWithCaseAccentSensitivity)
0 commit comments