-
Notifications
You must be signed in to change notification settings - Fork 284
[AutoSparkUT] Propagate SQL query context for decimal-overflow exceptions (SPARK-39190) #14872
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -29,6 +29,7 @@ import org.apache.spark.sql.catalyst.InternalRow | |
| import org.apache.spark.sql.catalyst.analysis.TypeCheckResult | ||
| import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.TypeCheckSuccess | ||
| import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Expression, ImplicitCastInputTypes, Literal, UnsafeProjection, UnsafeRow} | ||
| import org.apache.spark.sql.catalyst.trees.{CurrentOrigin, Origin} | ||
| import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData, TypeUtils} | ||
| import org.apache.spark.sql.internal.SQLConf | ||
| import org.apache.spark.sql.rapids._ | ||
|
|
@@ -837,36 +838,52 @@ case class GpuCheckOverflowAfterSum( | |
|
|
||
| override def sql: String = data.sql | ||
|
|
||
| override def columnarEval(batch: ColumnarBatch): GpuColumnVector = { | ||
| withResource(data.columnarEval(batch)) { dataCol => | ||
| val dataBase = dataCol.getBase | ||
| withResource(GpuCast.checkNFixDecimalBounds(dataBase, dataType, !nullOnOverflow)) { | ||
| fixedData => | ||
| withResource(isEmpty.columnarEval(batch)) { isEmptyCol => | ||
| val isEmptyBase = isEmptyCol.getBase | ||
| if (!nullOnOverflow) { | ||
| // ANSI mode | ||
| val problem = withResource(fixedData.isNull) { isNull => | ||
| withResource(isEmptyBase.not()) { notEmpty => | ||
| isNull.and(notEmpty) | ||
| // SPARK-39190: snapshot the Origin at construction so the SQL query context survives the | ||
| // reflection-based plan reconstruction that resets `origin`. GpuDecimalSum.evaluateExpression | ||
| // (and friends) construct this node inside `CurrentOrigin.withOrigin(capturedOrigin)`, so the | ||
| // origin field has the parser-set SQL string at construction time. | ||
| val capturedOrigin: Origin = origin | ||
|
|
||
| // Defensive: re-enter the captured origin around the case-class copy. Spark's `withNewChildren` | ||
| // and `makeCopy` already wrap with `CurrentOrigin.withOrigin(this.origin)`, but if any future | ||
| // reconstruction path (e.g. direct reflection) bypasses that wrap, this override keeps the | ||
| // parser-set SQL context attached to the rebound instance. | ||
| override def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): Expression = | ||
| CurrentOrigin.withOrigin(capturedOrigin) { | ||
| copy(data = newChildren(0), isEmpty = newChildren(1)) | ||
| } | ||
|
|
||
| override def columnarEval(batch: ColumnarBatch): GpuColumnVector = | ||
| CurrentOrigin.withOrigin(capturedOrigin) { | ||
| withResource(data.columnarEval(batch)) { dataCol => | ||
| val dataBase = dataCol.getBase | ||
| withResource(GpuCast.checkNFixDecimalBounds(dataBase, dataType, !nullOnOverflow)) { | ||
| fixedData => | ||
| withResource(isEmpty.columnarEval(batch)) { isEmptyCol => | ||
| val isEmptyBase = isEmptyCol.getBase | ||
| if (!nullOnOverflow) { | ||
| // ANSI mode | ||
| val problem = withResource(fixedData.isNull) { isNull => | ||
| withResource(isEmptyBase.not()) { notEmpty => | ||
| isNull.and(notEmpty) | ||
| } | ||
| } | ||
| } | ||
| withResource(problem) { problem => | ||
| withResource(problem.any()) { anyProblem => | ||
| if (anyProblem.isValid && anyProblem.getBoolean) { | ||
| throw new ArithmeticException("Overflow in sum of decimals.") | ||
| withResource(problem) { problem => | ||
| withResource(problem.any()) { anyProblem => | ||
| if (anyProblem.isValid && anyProblem.getBoolean) { | ||
| throw new ArithmeticException("Overflow in sum of decimals.") | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
The
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Thanks — you're right that this throw is missing the query context (CPU |
||
| } | ||
| } | ||
| } | ||
| // No problems fall through... | ||
| } | ||
| withResource(GpuScalar.from(null, dataType)) { nullScale => | ||
| GpuColumnVector.from(isEmptyBase.ifElse(nullScale, fixedData), dataType) | ||
| } | ||
| // No problems fall through... | ||
| } | ||
| withResource(GpuScalar.from(null, dataType)) { nullScale => | ||
| GpuColumnVector.from(isEmptyBase.ifElse(nullScale, fixedData), dataType) | ||
| } | ||
| } | ||
| } | ||
| } | ||
| } | ||
| } | ||
|
|
||
| override def children: Seq[Expression] = Seq(data, isEmpty) | ||
| } | ||
|
|
@@ -1141,7 +1158,13 @@ abstract class GpuDecimalSum( | |
| Seq(sum, isEmpty) | ||
| } | ||
|
|
||
| override lazy val postUpdate: Seq[Expression] = { | ||
| // SPARK-39190: snapshot the Origin at construction (driver-side, inside RapidsMeta's | ||
| // `CurrentOrigin.withOrigin(wrapped.origin)` wrap so the SQL text is populated). This val | ||
| // survives the reflection-based plan reconstruction that resets `origin` to its default, | ||
| // letting the lazy vals below propagate the SQL context to GpuCheckOverflowAfterSum. | ||
| val capturedOrigin: Origin = origin | ||
|
thirtiseven marked this conversation as resolved.
|
||
|
|
||
| override lazy val postUpdate: Seq[Expression] = CurrentOrigin.withOrigin(capturedOrigin) { | ||
| if (failOnErrorOverride) { | ||
| Seq(GpuCheckOverflowAfterSum(updateSum.attr, updateIsEmpty.attr, dt, !failOnErrorOverride), | ||
| updateIsEmpty.attr) | ||
|
|
@@ -1164,7 +1187,7 @@ abstract class GpuDecimalSum( | |
| Seq(mergeSum, mergeIsEmpty, mergeIsOverflow) | ||
| } | ||
|
|
||
| override lazy val postMerge: Seq[Expression] = { | ||
| override lazy val postMerge: Seq[Expression] = CurrentOrigin.withOrigin(capturedOrigin) { | ||
| if (failOnErrorOverride) { | ||
| Seq( | ||
| GpuCheckOverflowAfterSum(mergeSum.attr, mergeIsEmpty.attr, dt, !failOnErrorOverride), | ||
|
|
@@ -1176,7 +1199,7 @@ abstract class GpuDecimalSum( | |
| } | ||
| } | ||
|
|
||
| override lazy val evaluateExpression: Expression = { | ||
| override lazy val evaluateExpression: Expression = CurrentOrigin.withOrigin(capturedOrigin) { | ||
| GpuCheckOverflowAfterSum(sum, isEmpty, dt, !failOnErrorOverride) | ||
| } | ||
|
|
||
|
|
@@ -1590,9 +1613,15 @@ case class GpuBasicAverage(child: Expression, dt: DataType, failOnError: Boolean | |
| abstract class GpuDecimalAverageBase(child: Expression, sumDataType: DecimalType, | ||
| failOnError: Boolean) | ||
| extends GpuAverage(child, sumDataType, failOnError) { | ||
| override lazy val postUpdate: Seq[Expression] = | ||
| Seq(GpuCheckOverflow(updateSum.attr, sumDataType, nullOnOverflow = !failOnError), | ||
| updateCount.attr) | ||
|
|
||
| // SPARK-39190: snapshot Origin at construction so GpuCheckOverflow nodes built in the lazy | ||
| // vals below inherit the SQL query context. See GpuDecimalSum for the matching pattern. | ||
| val capturedOrigin: Origin = origin | ||
|
|
||
| override lazy val postUpdate: Seq[Expression] = CurrentOrigin.withOrigin(capturedOrigin) { | ||
| Seq(GpuCheckOverflow(updateSum.attr, sumDataType, nullOnOverflow = !failOnError), | ||
| updateCount.attr) | ||
| } | ||
|
|
||
| // To be able to do decimal overflow detection, we need a CudfSum that does **not** ignore nulls. | ||
| // Cudf does not have such an aggregation, so for merge we have to work around that with an extra | ||
|
|
@@ -1604,11 +1633,13 @@ abstract class GpuDecimalAverageBase(child: Expression, sumDataType: DecimalType | |
|
|
||
| override lazy val mergeAggregates: Seq[CudfAggregate] = Seq(mergeSum, mergeCount, mergeIsOverflow) | ||
|
|
||
| override lazy val postMerge: Seq[Expression] = Seq( | ||
| GpuCheckOverflow( | ||
| GpuIf(mergeIsOverflow.attr, GpuLiteral.create(null, sumDataType), mergeSum.attr), | ||
| sumDataType, nullOnOverflow = !failOnError), | ||
| mergeCount.attr) | ||
| override lazy val postMerge: Seq[Expression] = CurrentOrigin.withOrigin(capturedOrigin) { | ||
| Seq( | ||
| GpuCheckOverflow( | ||
| GpuIf(mergeIsOverflow.attr, GpuLiteral.create(null, sumDataType), mergeSum.attr), | ||
| sumDataType, nullOnOverflow = !failOnError), | ||
| mergeCount.attr) | ||
| } | ||
|
|
||
| // This is here to be bug for bug compatible with Spark. They round in the divide and then cast | ||
| // the result to the final value. This loses some data in many cases and we need to be able to | ||
|
|
@@ -1648,7 +1679,7 @@ case class GpuDecimal128Average(child: Expression, dt: DecimalType, failOnError: | |
|
|
||
| override lazy val updateAggregates: Seq[CudfAggregate] = updateSumChunks :+ updateCount | ||
|
|
||
| override lazy val postUpdate: Seq[Expression] = { | ||
| override lazy val postUpdate: Seq[Expression] = CurrentOrigin.withOrigin(capturedOrigin) { | ||
| val assembleExpr = GpuAssembleSumChunks(updateSumChunks.map(_.attr), dt, | ||
| nullOnOverflow = !failOnError, None) | ||
| Seq(GpuCheckOverflow(assembleExpr, dt, nullOnOverflow = !failOnError), updateCount.attr) | ||
|
|
@@ -1670,7 +1701,7 @@ case class GpuDecimal128Average(child: Expression, dt: DecimalType, failOnError: | |
| override lazy val mergeAggregates: Seq[CudfAggregate] = | ||
| mergeSumChunks ++ Seq(mergeCount, mergeIsOverflow) | ||
|
|
||
| override lazy val postMerge: Seq[Expression] = { | ||
| override lazy val postMerge: Seq[Expression] = CurrentOrigin.withOrigin(capturedOrigin) { | ||
| val assembleExpr = GpuAssembleSumChunks(mergeSumChunks.map(_.attr), dt, | ||
| nullOnOverflow = !failOnError, Some(mergeIsOverflow.attr)) | ||
| Seq( | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.