Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import com.nvidia.spark.rapids.Arm.withResource
import com.nvidia.spark.rapids.jni.{Arithmetic, RoundMode}

import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.catalyst.trees.{CurrentOrigin, Origin}
import org.apache.spark.sql.types.{DataType, DecimalType, LongType}

/**
Expand All @@ -43,23 +44,38 @@ case class GpuCheckOverflow(child: Expression,
DType.create(DType.DTypeEnum.DECIMAL32, expectedCudfScale)
}

override protected def doColumnar(input: GpuColumnVector): ColumnVector = {
val base = input.getBase
val rounded = if (resultDType.equals(base.getType)) {
base.incRefCount()
} else {
withResource(Arithmetic.round(base, dataType.scale, RoundMode.HALF_UP)) { rounded =>
if (resultDType.getTypeId != base.getType.getTypeId) {
rounded.castTo(resultDType)
} else {
rounded.incRefCount()
// SPARK-39190: snapshot the Origin at construction so the SQL query context survives the
// reflection-based plan reconstruction (e.g. GpuBindReferences → mapChildren → makeCopy)
// that otherwise resets `origin` to its default.
val capturedOrigin: Origin = origin
Comment thread
thirtiseven marked this conversation as resolved.

// Defensive: re-enter the captured origin around the case-class copy. Spark's `mapChildren` /
// `withNewChildren` already wrap with `CurrentOrigin.withOrigin(this.origin)`, but if any future
// reconstruction path (e.g. direct reflection) bypasses that wrap, this override keeps the
// parser-set SQL context attached to the rebound instance.
override def withNewChildInternal(newChild: Expression): Expression =
Comment thread
thirtiseven marked this conversation as resolved.
CurrentOrigin.withOrigin(capturedOrigin) {
copy(child = newChild)
}

override protected def doColumnar(input: GpuColumnVector): ColumnVector =
CurrentOrigin.withOrigin(capturedOrigin) {
val base = input.getBase
val rounded = if (resultDType.equals(base.getType)) {
base.incRefCount()
} else {
withResource(Arithmetic.round(base, dataType.scale, RoundMode.HALF_UP)) { rounded =>
if (resultDType.getTypeId != base.getType.getTypeId) {
rounded.castTo(resultDType)
} else {
rounded.incRefCount()
}
}
}
withResource(rounded) { rounded =>
GpuCast.checkNFixDecimalBounds(rounded, dataType, !nullOnOverflow)
}
}
withResource(rounded) { rounded =>
GpuCast.checkNFixDecimalBounds(rounded, dataType, !nullOnOverflow)
}
}

override def nullable: Boolean = true
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.TypeCheckSuccess
import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Expression, ImplicitCastInputTypes, Literal, UnsafeProjection, UnsafeRow}
import org.apache.spark.sql.catalyst.trees.{CurrentOrigin, Origin}
import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData, TypeUtils}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.rapids._
Expand Down Expand Up @@ -837,36 +838,52 @@ case class GpuCheckOverflowAfterSum(

override def sql: String = data.sql

override def columnarEval(batch: ColumnarBatch): GpuColumnVector = {
withResource(data.columnarEval(batch)) { dataCol =>
val dataBase = dataCol.getBase
withResource(GpuCast.checkNFixDecimalBounds(dataBase, dataType, !nullOnOverflow)) {
fixedData =>
withResource(isEmpty.columnarEval(batch)) { isEmptyCol =>
val isEmptyBase = isEmptyCol.getBase
if (!nullOnOverflow) {
// ANSI mode
val problem = withResource(fixedData.isNull) { isNull =>
withResource(isEmptyBase.not()) { notEmpty =>
isNull.and(notEmpty)
// SPARK-39190: snapshot the Origin at construction so the SQL query context survives the
// reflection-based plan reconstruction that resets `origin`. GpuDecimalSum.evaluateExpression
// (and friends) construct this node inside `CurrentOrigin.withOrigin(capturedOrigin)`, so the
// origin field has the parser-set SQL string at construction time.
val capturedOrigin: Origin = origin

// Defensive: re-enter the captured origin around the case-class copy. Spark's `withNewChildren`
// and `makeCopy` already wrap with `CurrentOrigin.withOrigin(this.origin)`, but if any future
// reconstruction path (e.g. direct reflection) bypasses that wrap, this override keeps the
// parser-set SQL context attached to the rebound instance.
override def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): Expression =
CurrentOrigin.withOrigin(capturedOrigin) {
copy(data = newChildren(0), isEmpty = newChildren(1))
}

override def columnarEval(batch: ColumnarBatch): GpuColumnVector =
CurrentOrigin.withOrigin(capturedOrigin) {
withResource(data.columnarEval(batch)) { dataCol =>
val dataBase = dataCol.getBase
withResource(GpuCast.checkNFixDecimalBounds(dataBase, dataType, !nullOnOverflow)) {
fixedData =>
withResource(isEmpty.columnarEval(batch)) { isEmptyCol =>
val isEmptyBase = isEmptyCol.getBase
if (!nullOnOverflow) {
// ANSI mode
val problem = withResource(fixedData.isNull) { isNull =>
withResource(isEmptyBase.not()) { notEmpty =>
isNull.and(notEmpty)
}
}
}
withResource(problem) { problem =>
withResource(problem.any()) { anyProblem =>
if (anyProblem.isValid && anyProblem.getBoolean) {
throw new ArithmeticException("Overflow in sum of decimals.")
withResource(problem) { problem =>
withResource(problem.any()) { anyProblem =>
if (anyProblem.isValid && anyProblem.getBoolean) {
throw new ArithmeticException("Overflow in sum of decimals.")
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 isEmpty-path overflow still lacks SQL query context

The throw new ArithmeticException("Overflow in sum of decimals.") on this line is a plain ArithmeticException — it does not include the SQL query context introduced by SPARK-39190. The CPU equivalent in Spark 3.3 calls QueryExecutionErrors.overflowInSumOfDecimalsError() (a SparkArithmeticException) which carries CurrentOrigin.get().context in the exception message. The capturedOrigin is correctly held in CurrentOrigin at this point (we're inside withOrigin(capturedOrigin)), but the exception construction doesn't read it. The SPARK-39190 test passes because the test's select sum(d) overflow goes through the checkNFixDecimalBounds path (out-of-range result), not the isEmpty path (null from earlier overflow). Any sum overflow that produces null and is detected here would still emit a message without the SQL fragment.

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks — you're right that this throw is missing the query context (CPU CheckOverflowInSum.eval at Spark 3.3 calls QueryExecutionErrors.overflowInSumOfDecimalError(queryContext), which carries the SQL fragment). The isEmpty path also isn't exercised by SPARK-39190 (the test's select sum(d) triggers the checkNFixDecimalBounds path, not this one), so I'm leaving it out of scope for this PR — same reason I kept GpuDecimalSum.windowOutput / scanCombine parity in the followup column (see the Review notes block in the PR body). I'll file a separate followup that bundles isEmpty-path + windowOutput + scanCombine context propagation together; thread the fix through RapidsErrorUtils + OriginContextShim.queryContext(CurrentOrigin.get) like the other call sites in this PR.

}
}
}
// No problems fall through...
}
withResource(GpuScalar.from(null, dataType)) { nullScale =>
GpuColumnVector.from(isEmptyBase.ifElse(nullScale, fixedData), dataType)
}
// No problems fall through...
}
withResource(GpuScalar.from(null, dataType)) { nullScale =>
GpuColumnVector.from(isEmptyBase.ifElse(nullScale, fixedData), dataType)
}
}
}
}
}
}

override def children: Seq[Expression] = Seq(data, isEmpty)
}
Expand Down Expand Up @@ -1141,7 +1158,13 @@ abstract class GpuDecimalSum(
Seq(sum, isEmpty)
}

override lazy val postUpdate: Seq[Expression] = {
// SPARK-39190: snapshot the Origin at construction (driver-side, inside RapidsMeta's
// `CurrentOrigin.withOrigin(wrapped.origin)` wrap so the SQL text is populated). This val
// survives the reflection-based plan reconstruction that resets `origin` to its default,
// letting the lazy vals below propagate the SQL context to GpuCheckOverflowAfterSum.
val capturedOrigin: Origin = origin
Comment thread
thirtiseven marked this conversation as resolved.

override lazy val postUpdate: Seq[Expression] = CurrentOrigin.withOrigin(capturedOrigin) {
if (failOnErrorOverride) {
Seq(GpuCheckOverflowAfterSum(updateSum.attr, updateIsEmpty.attr, dt, !failOnErrorOverride),
updateIsEmpty.attr)
Expand All @@ -1164,7 +1187,7 @@ abstract class GpuDecimalSum(
Seq(mergeSum, mergeIsEmpty, mergeIsOverflow)
}

override lazy val postMerge: Seq[Expression] = {
override lazy val postMerge: Seq[Expression] = CurrentOrigin.withOrigin(capturedOrigin) {
if (failOnErrorOverride) {
Seq(
GpuCheckOverflowAfterSum(mergeSum.attr, mergeIsEmpty.attr, dt, !failOnErrorOverride),
Expand All @@ -1176,7 +1199,7 @@ abstract class GpuDecimalSum(
}
}

override lazy val evaluateExpression: Expression = {
override lazy val evaluateExpression: Expression = CurrentOrigin.withOrigin(capturedOrigin) {
GpuCheckOverflowAfterSum(sum, isEmpty, dt, !failOnErrorOverride)
}

Expand Down Expand Up @@ -1590,9 +1613,15 @@ case class GpuBasicAverage(child: Expression, dt: DataType, failOnError: Boolean
abstract class GpuDecimalAverageBase(child: Expression, sumDataType: DecimalType,
failOnError: Boolean)
extends GpuAverage(child, sumDataType, failOnError) {
override lazy val postUpdate: Seq[Expression] =
Seq(GpuCheckOverflow(updateSum.attr, sumDataType, nullOnOverflow = !failOnError),
updateCount.attr)

// SPARK-39190: snapshot Origin at construction so GpuCheckOverflow nodes built in the lazy
// vals below inherit the SQL query context. See GpuDecimalSum for the matching pattern.
val capturedOrigin: Origin = origin

override lazy val postUpdate: Seq[Expression] = CurrentOrigin.withOrigin(capturedOrigin) {
Seq(GpuCheckOverflow(updateSum.attr, sumDataType, nullOnOverflow = !failOnError),
updateCount.attr)
}

// To be able to do decimal overflow detection, we need a CudfSum that does **not** ignore nulls.
// Cudf does not have such an aggregation, so for merge we have to work around that with an extra
Expand All @@ -1604,11 +1633,13 @@ abstract class GpuDecimalAverageBase(child: Expression, sumDataType: DecimalType

override lazy val mergeAggregates: Seq[CudfAggregate] = Seq(mergeSum, mergeCount, mergeIsOverflow)

override lazy val postMerge: Seq[Expression] = Seq(
GpuCheckOverflow(
GpuIf(mergeIsOverflow.attr, GpuLiteral.create(null, sumDataType), mergeSum.attr),
sumDataType, nullOnOverflow = !failOnError),
mergeCount.attr)
override lazy val postMerge: Seq[Expression] = CurrentOrigin.withOrigin(capturedOrigin) {
Seq(
GpuCheckOverflow(
GpuIf(mergeIsOverflow.attr, GpuLiteral.create(null, sumDataType), mergeSum.attr),
sumDataType, nullOnOverflow = !failOnError),
mergeCount.attr)
}

// This is here to be bug for bug compatible with Spark. They round in the divide and then cast
// the result to the final value. This loses some data in many cases and we need to be able to
Expand Down Expand Up @@ -1648,7 +1679,7 @@ case class GpuDecimal128Average(child: Expression, dt: DecimalType, failOnError:

override lazy val updateAggregates: Seq[CudfAggregate] = updateSumChunks :+ updateCount

override lazy val postUpdate: Seq[Expression] = {
override lazy val postUpdate: Seq[Expression] = CurrentOrigin.withOrigin(capturedOrigin) {
val assembleExpr = GpuAssembleSumChunks(updateSumChunks.map(_.attr), dt,
nullOnOverflow = !failOnError, None)
Seq(GpuCheckOverflow(assembleExpr, dt, nullOnOverflow = !failOnError), updateCount.attr)
Expand All @@ -1670,7 +1701,7 @@ case class GpuDecimal128Average(child: Expression, dt: DecimalType, failOnError:
override lazy val mergeAggregates: Seq[CudfAggregate] =
mergeSumChunks ++ Seq(mergeCount, mergeIsOverflow)

override lazy val postMerge: Seq[Expression] = {
override lazy val postMerge: Seq[Expression] = CurrentOrigin.withOrigin(capturedOrigin) {
val assembleExpr = GpuAssembleSumChunks(mergeSumChunks.map(_.attr), dt,
nullOnOverflow = !failOnError, Some(mergeIsOverflow.attr))
Seq(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,8 @@ import com.nvidia.spark.rapids.Arm.{closeOnExcept, withResource}
import com.nvidia.spark.rapids.jni.{Arithmetic, CastStrings, ExceptionWithRowIndex, RoundMode}

import org.apache.spark.sql.catalyst.expressions.{Expression, ImplicitCastInputTypes}
import org.apache.spark.sql.rapids.shims.RapidsErrorUtils
import org.apache.spark.sql.catalyst.trees.CurrentOrigin
import org.apache.spark.sql.rapids.shims.{OriginContextShim, RapidsErrorUtils}
import org.apache.spark.sql.types._

abstract class CudfUnaryMathExpression(name: String) extends GpuUnaryMathExpression(name)
Expand Down Expand Up @@ -866,13 +867,11 @@ object RoundingErrorUtil {
* @param outOfBounds A boolean column that indicates which value cannot be casted.
* Users must make sure that there is at least one `true` in this column.
* @param toType The type to cast.
* @param context The error context, default value is "".
*/
def cannotChangeDecimalPrecisionError(
values: ColumnView,
outOfBounds: ColumnView,
toType: DecimalType,
context: String = ""): ArithmeticException = {
toType: DecimalType): ArithmeticException = {
val rowId = withResource(outOfBounds.copyToHost()) { hcv =>
(0L until outOfBounds.getRowCount)
.find(i => !hcv.isNull(i) && hcv.getBoolean(i))
Expand All @@ -881,6 +880,9 @@ object RoundingErrorUtil {
val value = withResource(values.getScalarElement(rowId.toInt)) { s =>
s.getBigDecimal
}
RapidsErrorUtils.cannotChangeDecimalPrecisionError(Decimal(value), toType)
// Pass the SQL query context (set on the executor via `CurrentOrigin.withOrigin`
// around `doColumnar`) so the exception message preserves SPARK-39190 parity.
RapidsErrorUtils.cannotChangeDecimalPrecisionError(
Decimal(value), toType, OriginContextShim.queryContext(CurrentOrigin.get))
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -221,6 +221,5 @@ class RapidsTestSettings extends BackendTestSettings {
.exclude("SPARK-33084: Add jar support Ivy URI in SQL -- jar contains udf class", ADJUST_UT("Replaced by testRapids version that uses testFile() to access Spark test resources instead of getContextClassLoader"))
.exclude("SPARK-33482: Fix FileScan canonicalization", ADJUST_UT("Replaced by testRapids version using V1 sources with AQE and broadcast disabled to assert ReusedExchangeExec directly"))
.exclude("SPARK-36093: RemoveRedundantAliases should not change expression's name", ADJUST_UT("Replaced by testRapids version that checks the partition column name of the GpuInsertIntoHadoopFsRelationCommand"))
.exclude("SPARK-39190,SPARK-39208,SPARK-39210: Query context of decimal overflow error should be serialized to executors when WSCG is off", KNOWN_ISSUE("https://github.com/NVIDIA/spark-rapids/issues/14123"))
}
// scalastyle:on line.size.limit