Skip to content

Commit

Permalink
fix: Normalize NaN and zeros for floating number comparison (#953)
Browse files Browse the repository at this point in the history
* fix: Normalize NaN and zeros for floating number comparison

* fix

* fix

* update

* Add doc
  • Loading branch information
viirya authored Sep 22, 2024
1 parent c3023c5 commit 7a6f47f
Show file tree
Hide file tree
Showing 21 changed files with 110 additions and 38 deletions.
8 changes: 8 additions & 0 deletions docs/source/user-guide/compatibility.md
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,14 @@ Comet uses the Rust regexp crate for evaluating regular expressions, and this ha
regular expression engine. Comet will fall back to Spark for patterns that are known to produce different results, but
this can be overridden by setting `spark.comet.regexp.allowIncompatible=true`.

## Floating number comparison

Spark normalizes NaN and zero for floating point numbers for several cases. See `NormalizeFloatingNumbers` optimization rule in Spark.
However, one exception is comparison. Spark does not normalize NaN and zero when comparing values
because they are handled well in Spark (e.g., `SQLOrderingUtil.compareFloats`). But the comparison
functions of arrow-rs used by DataFusion do not normalize NaN and zero (e.g., [arrow::compute::kernels::cmp::eq](https://docs.rs/arrow/latest/arrow/compute/kernels/cmp/fn.eq.html#)).
So Comet will add additional normalization expression of NaN and zero for comparison.

## Cast

Cast operations in Comet fall into three levels of support:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,9 @@ import org.apache.spark.SparkConf
import org.apache.spark.internal.Logging
import org.apache.spark.network.util.ByteUnit
import org.apache.spark.sql.{SparkSession, SparkSessionExtensions}
import org.apache.spark.sql.catalyst.expressions.{Expression, PlanExpression}
import org.apache.spark.sql.catalyst.expressions.{EqualNullSafe, EqualTo, Expression, GreaterThan, GreaterThanOrEqual, KnownFloatingPointNormalized, LessThan, LessThanOrEqual, NamedExpression, PlanExpression}
import org.apache.spark.sql.catalyst.expressions.aggregate.{Final, Partial}
import org.apache.spark.sql.catalyst.optimizer.NormalizeNaNAndZero
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.catalyst.trees.TreeNode
import org.apache.spark.sql.comet._
Expand All @@ -47,6 +48,7 @@ import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ReusedExc
import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, ShuffledHashJoinExec, SortMergeJoinExec}
import org.apache.spark.sql.execution.window.WindowExec
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.{DoubleType, FloatType}

import org.apache.comet.CometConf._
import org.apache.comet.CometExplainInfo.getActualPlan
Expand Down Expand Up @@ -840,6 +842,52 @@ class CometSparkSessionExtensions
}
}

def normalizePlan(plan: SparkPlan): SparkPlan = {
plan.transformUp {
case p: ProjectExec =>
val newProjectList = p.projectList.map(normalize(_).asInstanceOf[NamedExpression])
ProjectExec(newProjectList, p.child)
case f: FilterExec =>
val newCondition = normalize(f.condition)
FilterExec(newCondition, f.child)
}
}

// Spark will normalize NaN and zero for floating point numbers for several cases.
// See `NormalizeFloatingNumbers` optimization rule in Spark.
// However, one exception is for comparison operators. Spark does not normalize NaN and zero
// because they are handled well in Spark (e.g., `SQLOrderingUtil.compareFloats`). But the
// comparison functions in arrow-rs do not normalize NaN and zero. So we need to normalize NaN
// and zero for comparison operators in Comet.
def normalize(expr: Expression): Expression = {
expr.transformUp {
case EqualTo(left, right) =>
EqualTo(normalizeNaNAndZero(left), normalizeNaNAndZero(right))
case EqualNullSafe(left, right) =>
EqualNullSafe(normalizeNaNAndZero(left), normalizeNaNAndZero(right))
case GreaterThan(left, right) =>
GreaterThan(normalizeNaNAndZero(left), normalizeNaNAndZero(right))
case GreaterThanOrEqual(left, right) =>
GreaterThanOrEqual(normalizeNaNAndZero(left), normalizeNaNAndZero(right))
case LessThan(left, right) =>
LessThan(normalizeNaNAndZero(left), normalizeNaNAndZero(right))
case LessThanOrEqual(left, right) =>
LessThanOrEqual(normalizeNaNAndZero(left), normalizeNaNAndZero(right))
}
}

def normalizeNaNAndZero(expr: Expression): Expression = {
expr match {
case _: KnownFloatingPointNormalized => expr
case _ =>
expr.dataType match {
case _: FloatType | _: DoubleType =>
KnownFloatingPointNormalized(NormalizeNaNAndZero(expr))
case _ => expr
}
}
}

override def apply(plan: SparkPlan): SparkPlan = {
// DataFusion doesn't have ANSI mode. For now we just disable CometExec if ANSI mode is
// enabled.
Expand All @@ -865,7 +913,7 @@ class CometSparkSessionExtensions
plan
}
} else {
var newPlan = transform(plan)
var newPlan = transform(normalizePlan(plan))

// if the plan cannot be run fully natively then explain why (when appropriate
// config is enabled)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ Functions [2]: [sum(CASE WHEN (d_date#12 < 2000-03-11) THEN inv_quantity_on_hand

(22) CometFilter
Input [4]: [w_warehouse_name#7, i_item_id#9, inv_before#15, inv_after#16]
Condition : (CASE WHEN (inv_before#15 > 0) THEN ((cast(inv_after#16 as double) / cast(inv_before#15 as double)) >= 0.666667) END AND CASE WHEN (inv_before#15 > 0) THEN ((cast(inv_after#16 as double) / cast(inv_before#15 as double)) <= 1.5) END)
Condition : (CASE WHEN (inv_before#15 > 0) THEN (knownfloatingpointnormalized(normalizenanandzero((cast(inv_after#16 as double) / cast(inv_before#15 as double)))) >= knownfloatingpointnormalized(normalizenanandzero(0.666667))) END AND CASE WHEN (inv_before#15 > 0) THEN (knownfloatingpointnormalized(normalizenanandzero((cast(inv_after#16 as double) / cast(inv_before#15 as double)))) <= knownfloatingpointnormalized(normalizenanandzero(1.5))) END)

(23) CometTakeOrderedAndProject
Input [4]: [w_warehouse_name#7, i_item_id#9, inv_before#15, inv_after#16]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ ReadSchema: struct<hd_demo_sk:int,hd_buy_potential:string,hd_dep_count:int,hd_ve

(16) CometFilter
Input [4]: [hd_demo_sk#12, hd_buy_potential#13, hd_dep_count#14, hd_vehicle_count#15]
Condition : ((((isnotnull(hd_vehicle_count#15) AND ((hd_buy_potential#13 = >10000 ) OR (hd_buy_potential#13 = unknown ))) AND (hd_vehicle_count#15 > 0)) AND CASE WHEN (hd_vehicle_count#15 > 0) THEN ((cast(hd_dep_count#14 as double) / cast(hd_vehicle_count#15 as double)) > 1.2) END) AND isnotnull(hd_demo_sk#12))
Condition : ((((isnotnull(hd_vehicle_count#15) AND ((hd_buy_potential#13 = >10000 ) OR (hd_buy_potential#13 = unknown ))) AND (hd_vehicle_count#15 > 0)) AND CASE WHEN (hd_vehicle_count#15 > 0) THEN (knownfloatingpointnormalized(normalizenanandzero((cast(hd_dep_count#14 as double) / cast(hd_vehicle_count#15 as double)))) > knownfloatingpointnormalized(normalizenanandzero(1.2))) END) AND isnotnull(hd_demo_sk#12))

(17) CometProject
Input [4]: [hd_demo_sk#12, hd_buy_potential#13, hd_dep_count#14, hd_vehicle_count#15]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -152,11 +152,11 @@ Functions [2]: [stddev_samp(cast(inv_quantity_on_hand#3 as double)), avg(inv_qua

(22) CometFilter
Input [5]: [w_warehouse_sk#7, i_item_sk#6, d_moy#11, stdev#17, mean#18]
Condition : CASE WHEN (mean#18 = 0.0) THEN false ELSE ((stdev#17 / mean#18) > 1.0) END
Condition : CASE WHEN (knownfloatingpointnormalized(normalizenanandzero(mean#18)) = knownfloatingpointnormalized(normalizenanandzero(0.0))) THEN false ELSE (knownfloatingpointnormalized(normalizenanandzero((stdev#17 / mean#18))) > knownfloatingpointnormalized(normalizenanandzero(1.0))) END

(23) CometProject
Input [5]: [w_warehouse_sk#7, i_item_sk#6, d_moy#11, stdev#17, mean#18]
Arguments: [w_warehouse_sk#7, i_item_sk#6, d_moy#11, mean#18, cov#19], [w_warehouse_sk#7, i_item_sk#6, d_moy#11, mean#18, CASE WHEN (mean#18 = 0.0) THEN null ELSE (stdev#17 / mean#18) END AS cov#19]
Arguments: [w_warehouse_sk#7, i_item_sk#6, d_moy#11, mean#18, cov#19], [w_warehouse_sk#7, i_item_sk#6, d_moy#11, mean#18, CASE WHEN (knownfloatingpointnormalized(normalizenanandzero(mean#18)) = knownfloatingpointnormalized(normalizenanandzero(0.0))) THEN null ELSE (stdev#17 / mean#18) END AS cov#19]

(24) Scan parquet spark_catalog.default.inventory
Output [4]: [inv_item_sk#20, inv_warehouse_sk#21, inv_quantity_on_hand#22, inv_date_sk#23]
Expand Down Expand Up @@ -238,11 +238,11 @@ Functions [2]: [stddev_samp(cast(inv_quantity_on_hand#22 as double)), avg(inv_qu

(41) CometFilter
Input [5]: [w_warehouse_sk#26, i_item_sk#25, d_moy#30, stdev#17, mean#18]
Condition : CASE WHEN (mean#18 = 0.0) THEN false ELSE ((stdev#17 / mean#18) > 1.0) END
Condition : CASE WHEN (knownfloatingpointnormalized(normalizenanandzero(mean#18)) = knownfloatingpointnormalized(normalizenanandzero(0.0))) THEN false ELSE (knownfloatingpointnormalized(normalizenanandzero((stdev#17 / mean#18))) > knownfloatingpointnormalized(normalizenanandzero(1.0))) END

(42) CometProject
Input [5]: [w_warehouse_sk#26, i_item_sk#25, d_moy#30, stdev#17, mean#18]
Arguments: [w_warehouse_sk#26, i_item_sk#25, d_moy#30, mean#36, cov#37], [w_warehouse_sk#26, i_item_sk#25, d_moy#30, mean#18 AS mean#36, CASE WHEN (mean#18 = 0.0) THEN null ELSE (stdev#17 / mean#18) END AS cov#37]
Arguments: [w_warehouse_sk#26, i_item_sk#25, d_moy#30, mean#36, cov#37], [w_warehouse_sk#26, i_item_sk#25, d_moy#30, mean#18 AS mean#36, CASE WHEN (knownfloatingpointnormalized(normalizenanandzero(mean#18)) = knownfloatingpointnormalized(normalizenanandzero(0.0))) THEN null ELSE (stdev#17 / mean#18) END AS cov#37]

(43) CometBroadcastExchange
Input [5]: [w_warehouse_sk#26, i_item_sk#25, d_moy#30, mean#36, cov#37]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -152,11 +152,11 @@ Functions [2]: [stddev_samp(cast(inv_quantity_on_hand#3 as double)), avg(inv_qua

(22) CometFilter
Input [5]: [w_warehouse_sk#7, i_item_sk#6, d_moy#11, stdev#17, mean#18]
Condition : (CASE WHEN (mean#18 = 0.0) THEN false ELSE ((stdev#17 / mean#18) > 1.0) END AND CASE WHEN (mean#18 = 0.0) THEN false ELSE ((stdev#17 / mean#18) > 1.5) END)
Condition : (CASE WHEN (knownfloatingpointnormalized(normalizenanandzero(mean#18)) = knownfloatingpointnormalized(normalizenanandzero(0.0))) THEN false ELSE (knownfloatingpointnormalized(normalizenanandzero((stdev#17 / mean#18))) > knownfloatingpointnormalized(normalizenanandzero(1.0))) END AND CASE WHEN (knownfloatingpointnormalized(normalizenanandzero(mean#18)) = knownfloatingpointnormalized(normalizenanandzero(0.0))) THEN false ELSE (knownfloatingpointnormalized(normalizenanandzero((stdev#17 / mean#18))) > knownfloatingpointnormalized(normalizenanandzero(1.5))) END)

(23) CometProject
Input [5]: [w_warehouse_sk#7, i_item_sk#6, d_moy#11, stdev#17, mean#18]
Arguments: [w_warehouse_sk#7, i_item_sk#6, d_moy#11, mean#18, cov#19], [w_warehouse_sk#7, i_item_sk#6, d_moy#11, mean#18, CASE WHEN (mean#18 = 0.0) THEN null ELSE (stdev#17 / mean#18) END AS cov#19]
Arguments: [w_warehouse_sk#7, i_item_sk#6, d_moy#11, mean#18, cov#19], [w_warehouse_sk#7, i_item_sk#6, d_moy#11, mean#18, CASE WHEN (knownfloatingpointnormalized(normalizenanandzero(mean#18)) = knownfloatingpointnormalized(normalizenanandzero(0.0))) THEN null ELSE (stdev#17 / mean#18) END AS cov#19]

(24) Scan parquet spark_catalog.default.inventory
Output [4]: [inv_item_sk#20, inv_warehouse_sk#21, inv_quantity_on_hand#22, inv_date_sk#23]
Expand Down Expand Up @@ -238,11 +238,11 @@ Functions [2]: [stddev_samp(cast(inv_quantity_on_hand#22 as double)), avg(inv_qu

(41) CometFilter
Input [5]: [w_warehouse_sk#26, i_item_sk#25, d_moy#30, stdev#17, mean#18]
Condition : CASE WHEN (mean#18 = 0.0) THEN false ELSE ((stdev#17 / mean#18) > 1.0) END
Condition : CASE WHEN (knownfloatingpointnormalized(normalizenanandzero(mean#18)) = knownfloatingpointnormalized(normalizenanandzero(0.0))) THEN false ELSE (knownfloatingpointnormalized(normalizenanandzero((stdev#17 / mean#18))) > knownfloatingpointnormalized(normalizenanandzero(1.0))) END

(42) CometProject
Input [5]: [w_warehouse_sk#26, i_item_sk#25, d_moy#30, stdev#17, mean#18]
Arguments: [w_warehouse_sk#26, i_item_sk#25, d_moy#30, mean#36, cov#37], [w_warehouse_sk#26, i_item_sk#25, d_moy#30, mean#18 AS mean#36, CASE WHEN (mean#18 = 0.0) THEN null ELSE (stdev#17 / mean#18) END AS cov#37]
Arguments: [w_warehouse_sk#26, i_item_sk#25, d_moy#30, mean#36, cov#37], [w_warehouse_sk#26, i_item_sk#25, d_moy#30, mean#18 AS mean#36, CASE WHEN (knownfloatingpointnormalized(normalizenanandzero(mean#18)) = knownfloatingpointnormalized(normalizenanandzero(0.0))) THEN null ELSE (stdev#17 / mean#18) END AS cov#37]

(43) CometBroadcastExchange
Input [5]: [w_warehouse_sk#26, i_item_sk#25, d_moy#30, mean#36, cov#37]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ ReadSchema: struct<hd_demo_sk:int,hd_buy_potential:string,hd_dep_count:int,hd_ve

(16) CometFilter
Input [4]: [hd_demo_sk#12, hd_buy_potential#13, hd_dep_count#14, hd_vehicle_count#15]
Condition : ((((isnotnull(hd_vehicle_count#15) AND ((hd_buy_potential#13 = >10000 ) OR (hd_buy_potential#13 = unknown ))) AND (hd_vehicle_count#15 > 0)) AND CASE WHEN (hd_vehicle_count#15 > 0) THEN ((cast(hd_dep_count#14 as double) / cast(hd_vehicle_count#15 as double)) > 1.0) END) AND isnotnull(hd_demo_sk#12))
Condition : ((((isnotnull(hd_vehicle_count#15) AND ((hd_buy_potential#13 = >10000 ) OR (hd_buy_potential#13 = unknown ))) AND (hd_vehicle_count#15 > 0)) AND CASE WHEN (hd_vehicle_count#15 > 0) THEN (knownfloatingpointnormalized(normalizenanandzero((cast(hd_dep_count#14 as double) / cast(hd_vehicle_count#15 as double)))) > knownfloatingpointnormalized(normalizenanandzero(1.0))) END) AND isnotnull(hd_demo_sk#12))

(17) CometProject
Input [4]: [hd_demo_sk#12, hd_buy_potential#13, hd_dep_count#14, hd_vehicle_count#15]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,7 @@ Results [4]: [w_warehouse_name#7, i_item_id#9, sum(CASE WHEN (d_date#12 < 2000-0

(23) Filter [codegen id : 2]
Input [4]: [w_warehouse_name#7, i_item_id#9, inv_before#19, inv_after#20]
Condition : (CASE WHEN (inv_before#19 > 0) THEN ((cast(inv_after#20 as double) / cast(inv_before#19 as double)) >= 0.666667) END AND CASE WHEN (inv_before#19 > 0) THEN ((cast(inv_after#20 as double) / cast(inv_before#19 as double)) <= 1.5) END)
Condition : (CASE WHEN (inv_before#19 > 0) THEN (knownfloatingpointnormalized(normalizenanandzero((cast(inv_after#20 as double) / cast(inv_before#19 as double)))) >= knownfloatingpointnormalized(normalizenanandzero(0.666667))) END AND CASE WHEN (inv_before#19 > 0) THEN (knownfloatingpointnormalized(normalizenanandzero((cast(inv_after#20 as double) / cast(inv_before#19 as double)))) <= knownfloatingpointnormalized(normalizenanandzero(1.5))) END)

(24) TakeOrderedAndProject
Input [4]: [w_warehouse_name#7, i_item_id#9, inv_before#19, inv_after#20]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ ReadSchema: struct<hd_demo_sk:int,hd_buy_potential:string,hd_dep_count:int,hd_ve

(16) CometFilter
Input [4]: [hd_demo_sk#12, hd_buy_potential#13, hd_dep_count#14, hd_vehicle_count#15]
Condition : ((((isnotnull(hd_vehicle_count#15) AND ((hd_buy_potential#13 = >10000 ) OR (hd_buy_potential#13 = unknown ))) AND (hd_vehicle_count#15 > 0)) AND CASE WHEN (hd_vehicle_count#15 > 0) THEN ((cast(hd_dep_count#14 as double) / cast(hd_vehicle_count#15 as double)) > 1.2) END) AND isnotnull(hd_demo_sk#12))
Condition : ((((isnotnull(hd_vehicle_count#15) AND ((hd_buy_potential#13 = >10000 ) OR (hd_buy_potential#13 = unknown ))) AND (hd_vehicle_count#15 > 0)) AND CASE WHEN (hd_vehicle_count#15 > 0) THEN (knownfloatingpointnormalized(normalizenanandzero((cast(hd_dep_count#14 as double) / cast(hd_vehicle_count#15 as double)))) > knownfloatingpointnormalized(normalizenanandzero(1.2))) END) AND isnotnull(hd_demo_sk#12))

(17) CometProject
Input [4]: [hd_demo_sk#12, hd_buy_potential#13, hd_dep_count#14, hd_vehicle_count#15]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -161,10 +161,10 @@ Results [5]: [w_warehouse_sk#7, i_item_sk#6, d_moy#11, stddev_samp(cast(inv_quan

(23) Filter [codegen id : 4]
Input [5]: [w_warehouse_sk#7, i_item_sk#6, d_moy#11, stdev#24, mean#25]
Condition : CASE WHEN (mean#25 = 0.0) THEN false ELSE ((stdev#24 / mean#25) > 1.0) END
Condition : CASE WHEN (knownfloatingpointnormalized(normalizenanandzero(mean#25)) = knownfloatingpointnormalized(normalizenanandzero(0.0))) THEN false ELSE (knownfloatingpointnormalized(normalizenanandzero((stdev#24 / mean#25))) > knownfloatingpointnormalized(normalizenanandzero(1.0))) END

(24) Project [codegen id : 4]
Output [5]: [w_warehouse_sk#7, i_item_sk#6, d_moy#11, mean#25, CASE WHEN (mean#25 = 0.0) THEN null ELSE (stdev#24 / mean#25) END AS cov#26]
Output [5]: [w_warehouse_sk#7, i_item_sk#6, d_moy#11, mean#25, CASE WHEN (knownfloatingpointnormalized(normalizenanandzero(mean#25)) = knownfloatingpointnormalized(normalizenanandzero(0.0))) THEN null ELSE (stdev#24 / mean#25) END AS cov#26]
Input [5]: [w_warehouse_sk#7, i_item_sk#6, d_moy#11, stdev#24, mean#25]

(25) Scan parquet spark_catalog.default.inventory
Expand Down Expand Up @@ -254,10 +254,10 @@ Results [5]: [w_warehouse_sk#33, i_item_sk#32, d_moy#37, stddev_samp(cast(inv_qu

(43) Filter [codegen id : 3]
Input [5]: [w_warehouse_sk#33, i_item_sk#32, d_moy#37, stdev#48, mean#49]
Condition : CASE WHEN (mean#49 = 0.0) THEN false ELSE ((stdev#48 / mean#49) > 1.0) END
Condition : CASE WHEN (knownfloatingpointnormalized(normalizenanandzero(mean#49)) = knownfloatingpointnormalized(normalizenanandzero(0.0))) THEN false ELSE (knownfloatingpointnormalized(normalizenanandzero((stdev#48 / mean#49))) > knownfloatingpointnormalized(normalizenanandzero(1.0))) END

(44) Project [codegen id : 3]
Output [5]: [w_warehouse_sk#33, i_item_sk#32, d_moy#37, mean#49, CASE WHEN (mean#49 = 0.0) THEN null ELSE (stdev#48 / mean#49) END AS cov#50]
Output [5]: [w_warehouse_sk#33, i_item_sk#32, d_moy#37, mean#49, CASE WHEN (knownfloatingpointnormalized(normalizenanandzero(mean#49)) = knownfloatingpointnormalized(normalizenanandzero(0.0))) THEN null ELSE (stdev#48 / mean#49) END AS cov#50]
Input [5]: [w_warehouse_sk#33, i_item_sk#32, d_moy#37, stdev#48, mean#49]

(45) BroadcastExchange
Expand Down
Loading

0 comments on commit 7a6f47f

Please sign in to comment.