diff --git a/fuzz-testing/README.md b/fuzz-testing/README.md index 17b2c151a2..c8cea5be82 100644 --- a/fuzz-testing/README.md +++ b/fuzz-testing/README.md @@ -61,7 +61,7 @@ Set appropriate values for `SPARK_HOME`, `SPARK_MASTER`, and `COMET_JAR` environ $SPARK_HOME/bin/spark-submit \ --master $SPARK_MASTER \ --class org.apache.comet.fuzz.Main \ - target/comet-fuzz-spark3.4_2.12-0.7.0-SNAPSHOT-jar-with-dependencies.jar \ + target/comet-fuzz-spark3.5_2.12-0.12.0-SNAPSHOT-jar-with-dependencies.jar \ data --num-files=2 --num-rows=200 --exclude-negative-zero --generate-arrays --generate-structs --generate-maps ``` @@ -77,7 +77,7 @@ Generate random queries that are based on the available test files. $SPARK_HOME/bin/spark-submit \ --master $SPARK_MASTER \ --class org.apache.comet.fuzz.Main \ - target/comet-fuzz-spark3.4_2.12-0.7.0-SNAPSHOT-jar-with-dependencies.jar \ + target/comet-fuzz-spark3.5_2.12-0.12.0-SNAPSHOT-jar-with-dependencies.jar \ queries --num-files=2 --num-queries=500 ``` @@ -88,18 +88,17 @@ Note that the output filename is currently hard-coded as `queries.sql` ```shell $SPARK_HOME/bin/spark-submit \ --master $SPARK_MASTER \ + --conf spark.memory.offHeap.enabled=true \ + --conf spark.memory.offHeap.size=16G \ --conf spark.plugins=org.apache.spark.CometPlugin \ --conf spark.comet.enabled=true \ - --conf spark.comet.exec.enabled=true \ - --conf spark.comet.exec.all.enabled=true \ --conf spark.shuffle.manager=org.apache.spark.sql.comet.execution.shuffle.CometShuffleManager \ --conf spark.comet.exec.shuffle.enabled=true \ - --conf spark.comet.exec.shuffle.mode=auto \ --jars $COMET_JAR \ --conf spark.driver.extraClassPath=$COMET_JAR \ --conf spark.executor.extraClassPath=$COMET_JAR \ --class org.apache.comet.fuzz.Main \ - target/comet-fuzz-spark3.4_2.12-0.7.0-SNAPSHOT-jar-with-dependencies.jar \ + target/comet-fuzz-spark3.5_2.12-0.12.0-SNAPSHOT-jar-with-dependencies.jar \ run --num-files=2 --filename=queries.sql ``` diff --git a/fuzz-testing/src/main/scala/org/apache/comet/fuzz/Main.scala b/fuzz-testing/src/main/scala/org/apache/comet/fuzz/Main.scala index 1f81dc7791..b9e63c76a0 100644 --- a/fuzz-testing/src/main/scala/org/apache/comet/fuzz/Main.scala +++ b/fuzz-testing/src/main/scala/org/apache/comet/fuzz/Main.scala @@ -87,7 +87,11 @@ object Main { SchemaGenOptions( generateArray = conf.generateData.generateArrays(), generateStruct = conf.generateData.generateStructs(), - generateMap = conf.generateData.generateMaps()), + generateMap = conf.generateData.generateMaps(), + // create two columns of each primitive type so that they can be used in binary + // expressions such as `a + b` and `a < b` + primitiveTypes = SchemaGenOptions.defaultPrimitiveTypes ++ + SchemaGenOptions.defaultPrimitiveTypes), DataGenOptions( allowNull = true, generateNegativeZero = !conf.generateData.excludeNegativeZero())) diff --git a/fuzz-testing/src/main/scala/org/apache/comet/fuzz/Meta.scala b/fuzz-testing/src/main/scala/org/apache/comet/fuzz/Meta.scala index 246216840b..74d13f85ee 100644 --- a/fuzz-testing/src/main/scala/org/apache/comet/fuzz/Meta.scala +++ b/fuzz-testing/src/main/scala/org/apache/comet/fuzz/Meta.scala @@ -22,6 +22,32 @@ package org.apache.comet.fuzz import org.apache.spark.sql.types.DataType import org.apache.spark.sql.types.DataTypes +sealed trait SparkType +case class SparkTypeOneOf(dataTypes: Seq[SparkType]) extends SparkType +case object SparkBooleanType extends SparkType +case object SparkBinaryType extends SparkType +case object SparkStringType extends SparkType +case object 
SparkIntegralType extends SparkType +case object SparkByteType extends SparkType +case object SparkShortType extends SparkType +case object SparkIntType extends SparkType +case object SparkLongType extends SparkType +case object SparkFloatType extends SparkType +case object SparkDoubleType extends SparkType +case class SparkDecimalType(p: Int, s: Int) extends SparkType +case object SparkNumericType extends SparkType +case object SparkDateType extends SparkType +case object SparkTimestampType extends SparkType +case object SparkDateOrTimestampType extends SparkType +case class SparkArrayType(elementType: SparkType) extends SparkType +case class SparkMapType(keyType: SparkType, valueType: SparkType) extends SparkType +case class SparkStructType(fields: Seq[SparkType]) extends SparkType +case object SparkAnyType extends SparkType + +case class FunctionSignature(inputTypes: Seq[SparkType]) + +case class Function(name: String, signatures: Seq[FunctionSignature]) + object Meta { val dataTypes: Seq[(DataType, Double)] = Seq( @@ -35,100 +61,283 @@ object Meta { (DataTypes.createDecimalType(10, 2), 0.2), (DataTypes.DateType, 0.2), (DataTypes.TimestampType, 0.2), - // TimestampNTZType only in Spark 3.4+ - // (DataTypes.TimestampNTZType, 0.2), + (DataTypes.TimestampNTZType, 0.2), (DataTypes.StringType, 0.2), (DataTypes.BinaryType, 0.1)) - val stringScalarFunc: Seq[Function] = Seq( - Function("substring", 3), - Function("coalesce", 1), - Function("starts_with", 2), - Function("ends_with", 2), - Function("contains", 2), - Function("ascii", 1), - Function("bit_length", 1), - Function("octet_length", 1), - Function("upper", 1), - Function("lower", 1), - Function("chr", 1), - Function("init_cap", 1), - Function("trim", 1), - Function("ltrim", 1), - Function("rtrim", 1), - Function("string_space", 1), - Function("rpad", 2), - Function("rpad", 3), // rpad can have 2 or 3 arguments - Function("hex", 1), - Function("unhex", 1), - Function("xxhash64", 1), - Function("sha1", 1), - // Function("sha2", 1), -- needs a second argument for number of bits - Function("substring", 3), - Function("btrim", 1), - Function("concat_ws", 2), - Function("repeat", 2), - Function("length", 1), - Function("reverse", 1), - Function("instr", 2), - Function("replace", 2), - Function("translate", 2)) - - val dateScalarFunc: Seq[Function] = - Seq(Function("year", 1), Function("hour", 1), Function("minute", 1), Function("second", 1)) + private def createFunctionWithInputTypes(name: String, inputs: Seq[SparkType]): Function = { + Function(name, Seq(FunctionSignature(inputs))) + } + + private def createFunctions(name: String, signatures: Seq[FunctionSignature]): Function = { + Function(name, signatures) + } + private def createUnaryStringFunction(name: String): Function = { + createFunctionWithInputTypes(name, Seq(SparkStringType)) + } + + private def createUnaryNumericFunction(name: String): Function = { + createFunctionWithInputTypes(name, Seq(SparkNumericType)) + } + + // Math expressions (corresponds to mathExpressions in QueryPlanSerde) val mathScalarFunc: Seq[Function] = Seq( - Function("abs", 1), - Function("acos", 1), - Function("asin", 1), - Function("atan", 1), - Function("Atan2", 1), - Function("Cos", 1), - Function("Exp", 2), - Function("Ln", 1), - Function("Log10", 1), - Function("Log2", 1), - Function("Pow", 2), - Function("Round", 1), - Function("Signum", 1), - Function("Sin", 1), - Function("Sqrt", 1), - Function("Tan", 1), - Function("Ceil", 1), - Function("Floor", 1), - Function("bool_and", 1), - Function("bool_or", 
1), - Function("bitwise_not", 1)) + createUnaryNumericFunction("abs"), + createUnaryNumericFunction("acos"), + createUnaryNumericFunction("asin"), + createUnaryNumericFunction("atan"), + createFunctionWithInputTypes("atan2", Seq(SparkNumericType, SparkNumericType)), + createUnaryNumericFunction("cos"), + createUnaryNumericFunction("exp"), + createUnaryNumericFunction("expm1"), + createFunctionWithInputTypes("log", Seq(SparkNumericType, SparkNumericType)), + createUnaryNumericFunction("log10"), + createUnaryNumericFunction("log2"), + createFunctionWithInputTypes("pow", Seq(SparkNumericType, SparkNumericType)), + createFunctionWithInputTypes("remainder", Seq(SparkNumericType, SparkNumericType)), + createFunctions( + "round", + Seq( + FunctionSignature(Seq(SparkNumericType)), + FunctionSignature(Seq(SparkNumericType, SparkIntType)))), + createUnaryNumericFunction("signum"), + createUnaryNumericFunction("sin"), + createUnaryNumericFunction("sqrt"), + createUnaryNumericFunction("tan"), + createUnaryNumericFunction("ceil"), + createUnaryNumericFunction("floor"), + createFunctionWithInputTypes("unary_minus", Seq(SparkNumericType))) + // Hash expressions (corresponds to hashExpressions in QueryPlanSerde) + val hashScalarFunc: Seq[Function] = Seq( + createFunctionWithInputTypes("md5", Seq(SparkAnyType)), + createFunctionWithInputTypes("murmur3_hash", Seq(SparkAnyType)), // TODO variadic + createFunctionWithInputTypes("sha2", Seq(SparkAnyType, SparkIntType))) + + // String expressions (corresponds to stringExpressions in QueryPlanSerde) + val stringScalarFunc: Seq[Function] = Seq( + createUnaryStringFunction("ascii"), + createUnaryStringFunction("bit_length"), + createUnaryStringFunction("chr"), + createFunctionWithInputTypes( + "concat", + Seq( + SparkTypeOneOf( + Seq( + SparkStringType, + SparkNumericType, + SparkBinaryType, + SparkArrayType( + SparkTypeOneOf(Seq(SparkStringType, SparkNumericType, SparkBinaryType))))), + SparkTypeOneOf( + Seq( + SparkStringType, + SparkNumericType, + SparkBinaryType, + SparkArrayType( + SparkTypeOneOf(Seq(SparkStringType, SparkNumericType, SparkBinaryType))))))), + createFunctionWithInputTypes("concat_ws", Seq(SparkStringType, SparkStringType)), + createFunctionWithInputTypes("contains", Seq(SparkStringType, SparkStringType)), + createFunctionWithInputTypes("ends_with", Seq(SparkStringType, SparkStringType)), + createFunctionWithInputTypes( + "hex", + Seq(SparkTypeOneOf(Seq(SparkStringType, SparkBinaryType, SparkIntType, SparkLongType)))), + createUnaryStringFunction("init_cap"), + createFunctionWithInputTypes("instr", Seq(SparkStringType, SparkStringType)), + createFunctionWithInputTypes( + "length", + Seq(SparkTypeOneOf(Seq(SparkStringType, SparkBinaryType)))), + createFunctionWithInputTypes("like", Seq(SparkStringType, SparkStringType)), + createUnaryStringFunction("lower"), + createFunctions( + "lpad", + Seq( + FunctionSignature(Seq(SparkStringType, SparkIntegralType)), + FunctionSignature(Seq(SparkStringType, SparkIntegralType, SparkStringType)))), + createUnaryStringFunction("ltrim"), + createUnaryStringFunction("octet_length"), + createFunctions( + "regexp_replace", + Seq( + FunctionSignature(Seq(SparkStringType, SparkStringType, SparkStringType)), + FunctionSignature(Seq(SparkStringType, SparkStringType, SparkStringType, SparkIntType)))), + createFunctionWithInputTypes("repeat", Seq(SparkStringType, SparkIntType)), + createFunctions( + "replace", + Seq( + FunctionSignature(Seq(SparkStringType, SparkStringType)), + FunctionSignature(Seq(SparkStringType, 
SparkStringType, SparkStringType)))), + createFunctions( + "reverse", + Seq( + FunctionSignature(Seq(SparkStringType)), + FunctionSignature(Seq(SparkArrayType(SparkAnyType))))), + createFunctionWithInputTypes("rlike", Seq(SparkStringType, SparkStringType)), + createFunctions( + "rpad", + Seq( + FunctionSignature(Seq(SparkStringType, SparkIntegralType)), + FunctionSignature(Seq(SparkStringType, SparkIntegralType, SparkStringType)))), + createUnaryStringFunction("rtrim"), + createFunctionWithInputTypes("starts_with", Seq(SparkStringType, SparkStringType)), + createFunctionWithInputTypes("string_space", Seq(SparkIntType)), + createFunctionWithInputTypes("substring", Seq(SparkStringType, SparkIntType, SparkIntType)), + createFunctionWithInputTypes("translate", Seq(SparkStringType, SparkStringType)), + createUnaryStringFunction("trim"), + createUnaryStringFunction("btrim"), + createUnaryStringFunction("unhex"), + createUnaryStringFunction("upper"), + createFunctionWithInputTypes("xxhash64", Seq(SparkAnyType)), // TODO variadic + createFunctionWithInputTypes("sha1", Seq(SparkAnyType))) + + // Conditional expressions (corresponds to conditionalExpressions in QueryPlanSerde) + val conditionalScalarFunc: Seq[Function] = Seq( + createFunctionWithInputTypes("if", Seq(SparkBooleanType, SparkAnyType, SparkAnyType))) + + // Map expressions (corresponds to mapExpressions in QueryPlanSerde) + val mapScalarFunc: Seq[Function] = Seq( + createFunctionWithInputTypes( + "map_extract", + Seq(SparkMapType(SparkAnyType, SparkAnyType), SparkAnyType)), + createFunctionWithInputTypes("map_keys", Seq(SparkMapType(SparkAnyType, SparkAnyType))), + createFunctionWithInputTypes("map_entries", Seq(SparkMapType(SparkAnyType, SparkAnyType))), + createFunctionWithInputTypes("map_values", Seq(SparkMapType(SparkAnyType, SparkAnyType))), + createFunctionWithInputTypes( + "map_from_arrays", + Seq(SparkArrayType(SparkAnyType), SparkArrayType(SparkAnyType)))) + + // Predicate expressions (corresponds to predicateExpressions in QueryPlanSerde) + val predicateScalarFunc: Seq[Function] = Seq( + createFunctionWithInputTypes("and", Seq(SparkBooleanType, SparkBooleanType)), + createFunctionWithInputTypes("or", Seq(SparkBooleanType, SparkBooleanType)), + createFunctionWithInputTypes("not", Seq(SparkBooleanType)), + createFunctionWithInputTypes("in", Seq(SparkAnyType, SparkAnyType)) + ) // TODO: variadic + + // Struct expressions (corresponds to structExpressions in QueryPlanSerde) + val structScalarFunc: Seq[Function] = Seq( + createFunctionWithInputTypes( + "create_named_struct", + Seq(SparkStringType, SparkAnyType) + ), // TODO: variadic name/value pairs + createFunctionWithInputTypes( + "get_struct_field", + Seq(SparkStructType(Seq(SparkAnyType)), SparkStringType))) + + // Bitwise expressions (corresponds to bitwiseExpressions in QueryPlanSerde) + val bitwiseScalarFunc: Seq[Function] = Seq( + createFunctionWithInputTypes("bitwise_and", Seq(SparkIntegralType, SparkIntegralType)), + createFunctionWithInputTypes("bitwise_count", Seq(SparkIntegralType)), + createFunctionWithInputTypes("bitwise_get", Seq(SparkIntegralType, SparkIntType)), + createFunctionWithInputTypes("bitwise_or", Seq(SparkIntegralType, SparkIntegralType)), + createFunctionWithInputTypes("bitwise_not", Seq(SparkIntegralType)), + createFunctionWithInputTypes("bitwise_xor", Seq(SparkIntegralType, SparkIntegralType)), + createFunctionWithInputTypes("shift_left", Seq(SparkIntegralType, SparkIntType)), + createFunctionWithInputTypes("shift_right", Seq(SparkIntegralType, 
SparkIntType))) + + // Misc expressions (corresponds to miscExpressions in QueryPlanSerde) val miscScalarFunc: Seq[Function] = - Seq(Function("isnan", 1), Function("isnull", 1), Function("isnotnull", 1)) + Seq( + createFunctionWithInputTypes("isnan", Seq(SparkNumericType)), + createFunctionWithInputTypes("isnull", Seq(SparkAnyType)), + createFunctionWithInputTypes("isnotnull", Seq(SparkAnyType)), + createFunctionWithInputTypes("coalesce", Seq(SparkAnyType, SparkAnyType)) + ) // TODO: variadic + // Array expressions (corresponds to arrayExpressions in QueryPlanSerde) val arrayScalarFunc: Seq[Function] = Seq( - Function("array", 2), - Function("array_remove", 2), - Function("array_insert", 2), - Function("array_contains", 2), - Function("array_intersect", 2), - Function("array_append", 2)) + createFunctionWithInputTypes("array_append", Seq(SparkArrayType(SparkAnyType), SparkAnyType)), + createFunctionWithInputTypes("array_compact", Seq(SparkArrayType(SparkAnyType))), + createFunctionWithInputTypes( + "array_contains", + Seq(SparkArrayType(SparkAnyType), SparkAnyType)), + createFunctionWithInputTypes("array_distinct", Seq(SparkArrayType(SparkAnyType))), + createFunctionWithInputTypes( + "array_except", + Seq(SparkArrayType(SparkAnyType), SparkArrayType(SparkAnyType))), + createFunctionWithInputTypes( + "array_insert", + Seq(SparkArrayType(SparkAnyType), SparkIntType, SparkAnyType)), + createFunctionWithInputTypes( + "array_intersect", + Seq(SparkArrayType(SparkAnyType), SparkArrayType(SparkAnyType))), + createFunctions( + "array_join", + Seq( + FunctionSignature(Seq(SparkArrayType(SparkAnyType), SparkStringType)), + FunctionSignature(Seq(SparkArrayType(SparkAnyType), SparkStringType, SparkStringType)))), + createFunctionWithInputTypes("array_max", Seq(SparkArrayType(SparkAnyType))), + createFunctionWithInputTypes("array_min", Seq(SparkArrayType(SparkAnyType))), + createFunctionWithInputTypes("array_remove", Seq(SparkArrayType(SparkAnyType), SparkAnyType)), + createFunctionWithInputTypes("array_repeat", Seq(SparkAnyType, SparkIntType)), + createFunctionWithInputTypes( + "arrays_overlap", + Seq(SparkArrayType(SparkAnyType), SparkArrayType(SparkAnyType))), + createFunctionWithInputTypes( + "array_union", + Seq(SparkArrayType(SparkAnyType), SparkArrayType(SparkAnyType))), + createFunctionWithInputTypes("array", Seq(SparkAnyType, SparkAnyType)), // TODO: variadic + createFunctionWithInputTypes( + "element_at", + Seq( + SparkTypeOneOf( + Seq(SparkArrayType(SparkAnyType), SparkMapType(SparkAnyType, SparkAnyType))), + SparkAnyType)), + createFunctionWithInputTypes("flatten", Seq(SparkArrayType(SparkArrayType(SparkAnyType)))), + createFunctionWithInputTypes( + "get_array_item", + Seq(SparkArrayType(SparkAnyType), SparkIntType))) - val scalarFunc: Seq[Function] = stringScalarFunc ++ dateScalarFunc ++ - mathScalarFunc ++ miscScalarFunc ++ arrayScalarFunc + // Temporal expressions (corresponds to temporalExpressions in QueryPlanSerde) + val temporalScalarFunc: Seq[Function] = + Seq( + createFunctionWithInputTypes("date_add", Seq(SparkDateType, SparkIntType)), + createFunctionWithInputTypes("date_sub", Seq(SparkDateType, SparkIntType)), + createFunctions( + "from_unixtime", + Seq( + FunctionSignature(Seq(SparkLongType)), + FunctionSignature(Seq(SparkLongType, SparkStringType)))), + createFunctionWithInputTypes("hour", Seq(SparkDateOrTimestampType)), + createFunctionWithInputTypes("minute", Seq(SparkDateOrTimestampType)), + createFunctionWithInputTypes("second", Seq(SparkDateOrTimestampType)), + 
createFunctionWithInputTypes("trunc", Seq(SparkDateOrTimestampType, SparkStringType)), + createFunctionWithInputTypes("year", Seq(SparkDateOrTimestampType)), + createFunctionWithInputTypes("month", Seq(SparkDateOrTimestampType)), + createFunctionWithInputTypes("day", Seq(SparkDateOrTimestampType)), + createFunctionWithInputTypes("dayofmonth", Seq(SparkDateOrTimestampType)), + createFunctionWithInputTypes("dayofweek", Seq(SparkDateOrTimestampType)), + createFunctionWithInputTypes("weekday", Seq(SparkDateOrTimestampType)), + createFunctionWithInputTypes("dayofyear", Seq(SparkDateOrTimestampType)), + createFunctionWithInputTypes("weekofyear", Seq(SparkDateOrTimestampType)), + createFunctionWithInputTypes("quarter", Seq(SparkDateOrTimestampType))) + + // Combined in same order as exprSerdeMap in QueryPlanSerde + val scalarFunc: Seq[Function] = mathScalarFunc ++ hashScalarFunc ++ stringScalarFunc ++ + conditionalScalarFunc ++ mapScalarFunc ++ predicateScalarFunc ++ + structScalarFunc ++ bitwiseScalarFunc ++ miscScalarFunc ++ arrayScalarFunc ++ + temporalScalarFunc val aggFunc: Seq[Function] = Seq( - Function("min", 1), - Function("max", 1), - Function("count", 1), - Function("avg", 1), - Function("sum", 1), - Function("first", 1), - Function("last", 1), - Function("var_pop", 1), - Function("var_samp", 1), - Function("covar_pop", 1), - Function("covar_samp", 1), - Function("stddev_pop", 1), - Function("stddev_samp", 1), - Function("corr", 2)) + createFunctionWithInputTypes("min", Seq(SparkAnyType)), + createFunctionWithInputTypes("max", Seq(SparkAnyType)), + createFunctionWithInputTypes("count", Seq(SparkAnyType)), + createUnaryNumericFunction("avg"), + createUnaryNumericFunction("sum"), + // first/last are non-deterministic and known to be incompatible with Spark +// createFunctionWithInputTypes("first", Seq(SparkAnyType)), +// createFunctionWithInputTypes("last", Seq(SparkAnyType)), + createUnaryNumericFunction("var_pop"), + createUnaryNumericFunction("var_samp"), + createFunctionWithInputTypes("covar_pop", Seq(SparkNumericType, SparkNumericType)), + createFunctionWithInputTypes("covar_samp", Seq(SparkNumericType, SparkNumericType)), + createUnaryNumericFunction("stddev_pop"), + createUnaryNumericFunction("stddev_samp"), + createFunctionWithInputTypes("corr", Seq(SparkNumericType, SparkNumericType)), + createFunctionWithInputTypes("bit_and", Seq(SparkIntegralType)), + createFunctionWithInputTypes("bit_or", Seq(SparkIntegralType)), + createFunctionWithInputTypes("bit_xor", Seq(SparkIntegralType))) val unaryArithmeticOps: Seq[String] = Seq("+", "-") @@ -137,4 +346,13 @@ object Meta { val comparisonOps: Seq[String] = Seq("=", "<=>", ">", ">=", "<", "<=") + // TODO make this more comprehensive + val comparisonTypes: Seq[SparkType] = Seq( + SparkStringType, + SparkBinaryType, + SparkNumericType, + SparkDateType, + SparkTimestampType, + SparkArrayType(SparkTypeOneOf(Seq(SparkStringType, SparkNumericType, SparkDateType)))) + } diff --git a/fuzz-testing/src/main/scala/org/apache/comet/fuzz/QueryGen.scala b/fuzz-testing/src/main/scala/org/apache/comet/fuzz/QueryGen.scala index de1117837e..d9e3c147d2 100644 --- a/fuzz-testing/src/main/scala/org/apache/comet/fuzz/QueryGen.scala +++ b/fuzz-testing/src/main/scala/org/apache/comet/fuzz/QueryGen.scala @@ -24,7 +24,8 @@ import java.io.{BufferedWriter, FileWriter} import scala.collection.mutable import scala.util.Random -import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.{DataFrame, SparkSession} +import org.apache.spark.sql.types._ object 
QueryGen { @@ -42,19 +43,25 @@ object QueryGen { val uniqueQueries = mutable.HashSet[String]() for (_ <- 0 until numQueries) { - val sql = r.nextInt().abs % 8 match { - case 0 => generateJoin(r, spark, numFiles) - case 1 => generateAggregate(r, spark, numFiles) - case 2 => generateScalar(r, spark, numFiles) - case 3 => generateCast(r, spark, numFiles) - case 4 => generateUnaryArithmetic(r, spark, numFiles) - case 5 => generateBinaryArithmetic(r, spark, numFiles) - case 6 => generateBinaryComparison(r, spark, numFiles) - case _ => generateConditional(r, spark, numFiles) - } - if (!uniqueQueries.contains(sql)) { - uniqueQueries += sql - w.write(sql + "\n") + try { + val sql = r.nextInt().abs % 8 match { + case 0 => generateJoin(r, spark, numFiles) + case 1 => generateAggregate(r, spark, numFiles) + case 2 => generateScalar(r, spark, numFiles) + case 3 => generateCast(r, spark, numFiles) + case 4 => generateUnaryArithmetic(r, spark, numFiles) + case 5 => generateBinaryArithmetic(r, spark, numFiles) + case 6 => generateBinaryComparison(r, spark, numFiles) + case _ => generateConditional(r, spark, numFiles) + } + if (!uniqueQueries.contains(sql)) { + uniqueQueries += sql + w.write(sql + "\n") + } + } catch { + case e: Exception => + // scalastyle:off + println(s"Failed to generate query: ${e.getMessage}") } } w.close() @@ -65,35 +72,177 @@ object QueryGen { val table = spark.table(tableName) val func = Utils.randomChoice(Meta.aggFunc, r) - val args = Range(0, func.num_args) - .map(_ => Utils.randomChoice(table.columns, r)) + try { + val signature = Utils.randomChoice(func.signatures, r) + val args = signature.inputTypes.map(x => pickRandomColumn(r, table, x)) + + val groupingCols = Range(0, 2).map(_ => Utils.randomChoice(table.columns, r)) + + if (groupingCols.isEmpty) { + s"SELECT ${args.mkString(", ")}, ${func.name}(${args.mkString(", ")}) AS x " + + s"FROM $tableName " + + s"ORDER BY ${args.mkString(", ")};" + } else { + s"SELECT ${groupingCols.mkString(", ")}, ${func.name}(${args.mkString(", ")}) " + + s"FROM $tableName " + + s"GROUP BY ${groupingCols.mkString(",")} " + + s"ORDER BY ${groupingCols.mkString(", ")};" + } + } catch { + case e: Exception => + throw new IllegalStateException( + s"Failed to generate SQL for aggregate function ${func.name}", + e) + } + } + + private def generateScalar(r: Random, spark: SparkSession, numFiles: Int): String = { + val tableName = s"test${r.nextInt(numFiles)}" + val table = spark.table(tableName) - val groupingCols = Range(0, 2).map(_ => Utils.randomChoice(table.columns, r)) + val func = Utils.randomChoice(Meta.scalarFunc, r) + try { + val signature = Utils.randomChoice(func.signatures, r) + val args = signature.inputTypes.map(x => pickRandomColumn(r, table, x)) - if (groupingCols.isEmpty) { + // Example SELECT c0, log(c0) as x FROM test0 s"SELECT ${args.mkString(", ")}, ${func.name}(${args.mkString(", ")}) AS x " + s"FROM $tableName " + s"ORDER BY ${args.mkString(", ")};" - } else { - s"SELECT ${groupingCols.mkString(", ")}, ${func.name}(${args.mkString(", ")}) " + - s"FROM $tableName " + - s"GROUP BY ${groupingCols.mkString(",")} " + - s"ORDER BY ${groupingCols.mkString(", ")};" + } catch { + case e: Exception => + throw new IllegalStateException( + s"Failed to generate SQL for scalar function ${func.name}", + e) } } - private def generateScalar(r: Random, spark: SparkSession, numFiles: Int): String = { - val tableName = s"test${r.nextInt(numFiles)}" - val table = spark.table(tableName) + private def pickRandomColumn(r: Random, df: DataFrame, 
targetType: SparkType): String = { + targetType match { + case SparkAnyType => + Utils.randomChoice(df.schema.fields, r).name + case SparkBooleanType => + select(r, df, _.dataType == BooleanType) + case SparkByteType => + select(r, df, _.dataType == ByteType) + case SparkShortType => + select(r, df, _.dataType == ShortType) + case SparkIntType => + select(r, df, _.dataType == IntegerType) + case SparkLongType => + select(r, df, _.dataType == LongType) + case SparkFloatType => + select(r, df, _.dataType == FloatType) + case SparkDoubleType => + select(r, df, _.dataType == DoubleType) + case SparkDecimalType(_, _) => + select(r, df, _.dataType.isInstanceOf[DecimalType]) + case SparkIntegralType => + select( + r, + df, + f => + f.dataType == ByteType || f.dataType == ShortType || + f.dataType == IntegerType || f.dataType == LongType) + case SparkNumericType => + select(r, df, f => isNumeric(f.dataType)) + case SparkStringType => + select(r, df, _.dataType == StringType) + case SparkBinaryType => + select(r, df, _.dataType == BinaryType) + case SparkDateType => + select(r, df, _.dataType == DateType) + case SparkTimestampType => + select(r, df, _.dataType == TimestampType) + case SparkDateOrTimestampType => + select(r, df, f => f.dataType == DateType || f.dataType == TimestampType) + case SparkTypeOneOf(choices) => + pickRandomColumn(r, df, Utils.randomChoice(choices, r)) + case SparkArrayType(elementType) => + select( + r, + df, + _.dataType match { + case ArrayType(x, _) if typeMatch(elementType, x) => true + case _ => false + }) + case SparkMapType(keyType, valueType) => + select( + r, + df, + _.dataType match { + case MapType(k, v, _) if typeMatch(keyType, k) && typeMatch(valueType, v) => true + case _ => false + }) + case SparkStructType(fields) => + select( + r, + df, + _.dataType match { + case StructType(structFields) if structFields.length == fields.length => true + case _ => false + }) + case _ => + throw new IllegalStateException(targetType.toString) + } + } - val func = Utils.randomChoice(Meta.scalarFunc, r) - val args = Range(0, func.num_args) - .map(_ => Utils.randomChoice(table.columns, r)) + def pickTwoRandomColumns(r: Random, df: DataFrame, targetType: SparkType): (String, String) = { + val a = pickRandomColumn(r, df, targetType) + val df2 = df.drop(a) + val b = pickRandomColumn(r, df2, targetType) + (a, b) + } - // Example SELECT c0, log(c0) as x FROM test0 - s"SELECT ${args.mkString(", ")}, ${func.name}(${args.mkString(", ")}) AS x " + - s"FROM $tableName " + - s"ORDER BY ${args.mkString(", ")};" + /** Select a random field that matches a predicate */ + private def select(r: Random, df: DataFrame, predicate: StructField => Boolean): String = { + val candidates = df.schema.fields.filter(predicate) + if (candidates.isEmpty) { + throw new IllegalStateException("Failed to find suitable column") + } + Utils.randomChoice(candidates, r).name + } + + private def isNumeric(d: DataType): Boolean = { + d match { + case _: ByteType | _: ShortType | _: IntegerType | _: LongType | _: FloatType | + _: DoubleType | _: DecimalType => + true + case _ => false + } + } + + private def typeMatch(s: SparkType, d: DataType): Boolean = { + (s, d) match { + case (SparkAnyType, _) => true + case (SparkBooleanType, BooleanType) => true + case (SparkByteType, ByteType) => true + case (SparkShortType, ShortType) => true + case (SparkIntType, IntegerType) => true + case (SparkLongType, LongType) => true + case (SparkFloatType, FloatType) => true + case (SparkDoubleType, DoubleType) => true + case 
(SparkDecimalType(_, _), _: DecimalType) => true + case (SparkIntegralType, ByteType | ShortType | IntegerType | LongType) => true + case (SparkNumericType, _) if isNumeric(d) => true + case (SparkStringType, StringType) => true + case (SparkBinaryType, BinaryType) => true + case (SparkDateType, DateType) => true + case (SparkTimestampType, TimestampType | TimestampNTZType) => true + case (SparkDateOrTimestampType, DateType | TimestampType | TimestampNTZType) => true + case (SparkArrayType(elementType), ArrayType(elementDataType, _)) => + typeMatch(elementType, elementDataType) + case (SparkMapType(keyType, valueType), MapType(keyDataType, valueDataType, _)) => + typeMatch(keyType, keyDataType) && typeMatch(valueType, valueDataType) + case (SparkStructType(fields), StructType(structFields)) => + fields.length == structFields.length && + fields.zip(structFields.map(_.dataType)).forall { case (sparkType, dataType) => + typeMatch(sparkType, dataType) + } + case (SparkTypeOneOf(choices), _) => + choices.exists(choice => typeMatch(choice, d)) + case _ => false + } } private def generateUnaryArithmetic(r: Random, spark: SparkSession, numFiles: Int): String = { @@ -101,7 +250,7 @@ object QueryGen { val table = spark.table(tableName) val op = Utils.randomChoice(Meta.unaryArithmeticOps, r) - val a = Utils.randomChoice(table.columns, r) + val a = pickRandomColumn(r, table, SparkNumericType) // Example SELECT a, -a FROM test0 s"SELECT $a, $op$a " + @@ -114,8 +263,7 @@ object QueryGen { val table = spark.table(tableName) val op = Utils.randomChoice(Meta.binaryArithmeticOps, r) - val a = Utils.randomChoice(table.columns, r) - val b = Utils.randomChoice(table.columns, r) + val (a, b) = pickTwoRandomColumns(r, table, SparkNumericType) // Example SELECT a, b, a+b FROM test0 s"SELECT $a, $b, $a $op $b " + @@ -128,8 +276,10 @@ object QueryGen { val table = spark.table(tableName) val op = Utils.randomChoice(Meta.comparisonOps, r) - val a = Utils.randomChoice(table.columns, r) - val b = Utils.randomChoice(table.columns, r) + + // pick two columns with the same type + val opType = Utils.randomChoice(Meta.comparisonTypes, r) + val (a, b) = pickTwoRandomColumns(r, table, opType) // Example SELECT a, b, a <=> b FROM test0 s"SELECT $a, $b, $a $op $b " + @@ -142,8 +292,10 @@ object QueryGen { val table = spark.table(tableName) val op = Utils.randomChoice(Meta.comparisonOps, r) - val a = Utils.randomChoice(table.columns, r) - val b = Utils.randomChoice(table.columns, r) + + // pick two columns with the same type + val opType = Utils.randomChoice(Meta.comparisonTypes, r) + val (a, b) = pickTwoRandomColumns(r, table, opType) // Example SELECT a, b, IF(a <=> b, 1, 2), CASE WHEN a <=> b THEN 1 ELSE 2 END FROM test0 s"SELECT $a, $b, $a $op $b, IF($a $op $b, 1, 2), CASE WHEN $a $op $b THEN 1 ELSE 2 END " + @@ -192,5 +344,3 @@ object QueryGen { } } - -case class Function(name: String, num_args: Int) diff --git a/fuzz-testing/src/main/scala/org/apache/comet/fuzz/QueryRunner.scala b/fuzz-testing/src/main/scala/org/apache/comet/fuzz/QueryRunner.scala index 8852f4bc17..bcc9f98d06 100644 --- a/fuzz-testing/src/main/scala/org/apache/comet/fuzz/QueryRunner.scala +++ b/fuzz-testing/src/main/scala/org/apache/comet/fuzz/QueryRunner.scala @@ -34,6 +34,11 @@ object QueryRunner { filename: String, showFailedSparkQueries: Boolean = false): Unit = { + var queryCount = 0 + var invalidQueryCount = 0 + var cometFailureCount = 0 + var cometSuccessCount = 0 + val outputFilename = s"results-${System.currentTimeMillis()}.md" // scalastyle:off 
println println(s"Writing results to $outputFilename") @@ -56,7 +61,7 @@ object QueryRunner { querySource .getLines() .foreach(sql => { - + queryCount += 1 try { // execute with Spark spark.conf.set("spark.comet.enabled", "false") @@ -67,13 +72,11 @@ object QueryRunner { // execute with Comet try { spark.conf.set("spark.comet.enabled", "true") - // complex type support until we support it natively - spark.conf.set("spark.comet.sparkToColumnar.enabled", "true") - spark.conf.set("spark.comet.convert.parquet.enabled", "true") val df = spark.sql(sql) val cometRows = df.collect() val cometPlan = df.queryExecution.executedPlan.toString + var success = true if (sparkRows.length == cometRows.length) { var i = 0 while (i < sparkRows.length) { @@ -82,6 +85,7 @@ object QueryRunner { assert(l.length == r.length) for (j <- 0 until l.length) { if (!same(l(j), r(j))) { + success = false showSQL(w, sql) showPlans(w, sparkPlan, cometPlan) w.write(s"First difference at row $i:\n") @@ -93,16 +97,36 @@ object QueryRunner { i += 1 } } else { + success = false showSQL(w, sql) showPlans(w, sparkPlan, cometPlan) w.write( s"[ERROR] Spark produced ${sparkRows.length} rows and " + s"Comet produced ${cometRows.length} rows.\n") } + + // check that the plan contains Comet operators + if (!cometPlan.contains("Comet")) { + success = false + showSQL(w, sql) + showPlans(w, sparkPlan, cometPlan) + w.write("[ERROR] Comet did not accelerate any part of the plan\n") + } + + if (success) { + cometSuccessCount += 1 + } else { + cometFailureCount += 1 + } + } catch { case e: Exception => // the query worked in Spark but failed in Comet, so this is likely a bug in Comet + cometFailureCount += 1 showSQL(w, sql) + w.write("### Spark Plan\n") + w.write(s"```\n$sparkPlan\n```\n") + w.write(s"[ERROR] Query failed in Comet: ${e.getMessage}:\n") w.write("```\n") val sw = new StringWriter() @@ -119,6 +143,7 @@ object QueryRunner { } catch { case e: Exception => // we expect many generated queries to be invalid + invalidQueryCount += 1 if (showFailedSparkQueries) { showSQL(w, sql) w.write(s"Query failed in Spark: ${e.getMessage}\n") @@ -126,6 +151,11 @@ object QueryRunner { } }) + w.write("# Summary\n") + w.write( + s"Total queries: $queryCount; Invalid queries: $invalidQueryCount; " + + s"Comet failed: $cometFailureCount; Comet succeeded: $cometSuccessCount\n") + } finally { w.close() querySource.close() @@ -133,10 +163,17 @@ object QueryRunner { } private def same(l: Any, r: Any): Boolean = { + if (l == null || r == null) { + return l == null && r == null + } (l, r) match { + case (a: Float, b: Float) if a.isPosInfinity => b.isPosInfinity + case (a: Float, b: Float) if a.isNegInfinity => b.isNegInfinity case (a: Float, b: Float) if a.isInfinity => b.isInfinity case (a: Float, b: Float) if a.isNaN => b.isNaN case (a: Float, b: Float) => (a - b).abs <= 0.000001f + case (a: Double, b: Double) if a.isPosInfinity => b.isPosInfinity + case (a: Double, b: Double) if a.isNegInfinity => b.isNegInfinity case (a: Double, b: Double) if a.isInfinity => b.isInfinity case (a: Double, b: Double) if a.isNaN => b.isNaN case (a: Double, b: Double) => (a - b).abs <= 0.000001 @@ -144,6 +181,10 @@ object QueryRunner { a.length == b.length && a.zip(b).forall(x => same(x._1, x._2)) case (a: WrappedArray[_], b: WrappedArray[_]) => a.length == b.length && a.zip(b).forall(x => same(x._1, x._2)) + case (a: Row, b: Row) => + val aa = a.toSeq + val bb = b.toSeq + aa.length == bb.length && aa.zip(bb).forall(x => same(x._1, x._2)) case (a, b) => a == b } } @@ -153,6 
+194,7 @@ object QueryRunner { case null => "NULL" case v: WrappedArray[_] => s"[${v.map(format).mkString(",")}]" case v: Array[Byte] => s"[${v.mkString(",")}]" + case r: Row => formatRow(r) case other => other.toString } }
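
The patch above replaces the old arity-only `Function(name, num_args)` metadata with typed `FunctionSignature`s, and `QueryGen` now fills each function argument with a randomly chosen column whose Spark type satisfies the declared `SparkType`. The standalone sketch below illustrates that signature-driven selection in isolation; it is not part of the patch, and every name in it (`MiniType`, `Column`, `pickColumn`, and the example schema) is hypothetical.

```scala
import scala.util.Random

// A cut-down stand-in for the patch's SparkType hierarchy.
sealed trait MiniType
case object MiniString extends MiniType
case object MiniNumeric extends MiniType
case class MiniOneOf(choices: Seq[MiniType]) extends MiniType

case class Column(name: String, tpe: MiniType)
case class Signature(inputTypes: Seq[MiniType])

// Does a column type satisfy a required type? (analogous to the patch's typeMatch)
def matches(required: MiniType, actual: MiniType): Boolean = (required, actual) match {
  case (MiniOneOf(choices), _) => choices.exists(matches(_, actual))
  case _                       => required == actual
}

// Pick a random column satisfying the requirement (analogous to pickRandomColumn)
def pickColumn(r: Random, cols: Seq[Column], required: MiniType): String = {
  val candidates = cols.filter(c => matches(required, c.tpe))
  if (candidates.isEmpty) throw new IllegalStateException("no column of a suitable type")
  candidates(r.nextInt(candidates.length)).name
}

// Example: fill the arguments of a two-argument string function such as contains(str, str)
val cols = Seq(Column("c0", MiniNumeric), Column("c1", MiniString), Column("c2", MiniString))
val sig  = Signature(Seq(MiniString, MiniString))
val rand = new Random(42)
val args = sig.inputTypes.map(t => pickColumn(rand, cols, t))
println(s"SELECT ${args.mkString(", ")}, contains(${args.mkString(", ")}) FROM test0")
```

Because candidate columns are filtered by type before an argument is chosen, generated calls such as `contains(c1, c2)` are at least type-correct; when no column of a suitable type exists, the selection throws, which is why `generateQueries` in the patch wraps generation in a try/catch and skips that query instead of emitting SQL that Spark would always reject.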