diff --git a/.github/workflows/java.yml b/.github/workflows/java.yml index 76834c9d63e..027308dde73 100644 --- a/.github/workflows/java.yml +++ b/.github/workflows/java.yml @@ -62,17 +62,17 @@ jobs: fail-fast: true matrix: include: - - spark: 4.0.0 - scala: 2.13.8 + - spark: 4.1.1 + scala: 2.13.17 jdk: '17' - - spark: 3.5.4 - scala: 2.12.18 + - spark: 4.0.2 + scala: 2.13.17 jdk: '17' - - spark: 3.5.0 + - spark: 3.5.8 scala: 2.13.8 jdk: '11' skipTests: '' - - spark: 3.5.0 + - spark: 3.5.8 scala: 2.12.15 jdk: '11' skipTests: '' diff --git a/.github/workflows/python.yml b/.github/workflows/python.yml index 05782bbae9e..10989cd364a 100644 --- a/.github/workflows/python.yml +++ b/.github/workflows/python.yml @@ -60,7 +60,7 @@ jobs: strategy: matrix: include: - - spark: '4.0.0' + - spark: '4.1.1' scala: '2.13.8' java: '17' python: '3.11' @@ -69,42 +69,9 @@ jobs: java: '17' python: '3.10' - spark: '3.5.0' - scala: '2.12.8' - java: '11' - python: '3.11' - - spark: '3.5.0' - scala: '2.12.8' - java: '11' - python: '3.10' - shapely: '1' - - spark: '3.5.0' - scala: '2.12.8' - java: '11' - python: '3.10' - - spark: '3.5.0' - scala: '2.12.8' - java: '11' - python: '3.9' - - spark: '3.5.0' - scala: '2.12.8' - java: '11' - python: '3.8' - - spark: '3.4.0' - scala: '2.12.8' - java: '11' - python: '3.11' - - spark: '3.4.0' - scala: '2.12.8' - java: '11' - python: '3.10' - - spark: '3.4.0' scala: '2.12.8' java: '11' python: '3.9' - - spark: '3.4.0' - scala: '2.12.8' - java: '11' - python: '3.8' - spark: '3.4.0' scala: '2.12.8' java: '11' @@ -149,9 +116,9 @@ jobs: fi if [ "${SPARK_VERSION:0:1}" == "4" ]; then - # Spark 4.0 requires Python 3.9+, and we remove flink since it conflicts with pyspark 4.0 + # Spark 4.x requires Python 3.10+, and we remove flink since it conflicts with pyspark 4.x uv remove apache-flink --optional flink - uv add "pyspark==4.0.0; python_version >= '3.9'" + uv add "pyspark==${SPARK_VERSION}; python_version >= '3.10'" else # Install specific pyspark version matching matrix uv add pyspark==${SPARK_VERSION} diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 3be967720ed..51d69a6e040 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -288,7 +288,7 @@ repos: - id: clang-format name: run clang-format description: format C files with clang-format - args: [--style=file:.github/linters/.clang-format] + args: ['--style=file:.github/linters/.clang-format'] types_or: [c] - repo: https://github.com/PyCQA/bandit rev: 1.9.3 diff --git a/docs/community/publish.md b/docs/community/publish.md index 4880e9e40df..348590f4207 100644 --- a/docs/community/publish.md +++ b/docs/community/publish.md @@ -119,13 +119,13 @@ rm -rf $LOCAL_DIR && git clone --depth 1 --branch $TAG $REPO_URL $LOCAL_DIR && c MAVEN_PLUGIN_VERSION="2.3.2" # Define Spark and Scala versions -declare -a SPARK_VERSIONS=("3.4" "3.5" "4.0") +declare -a SPARK_VERSIONS=("3.4" "3.5" "4.0" "4.1") declare -a SCALA_VERSIONS=("2.12" "2.13") # Function to get Java version for Spark version get_java_version() { local spark_version=$1 - if [[ "$spark_version" == "4.0" ]]; then + if [[ "$spark_version" == "4."* ]]; then echo "17" else echo "11" @@ -217,8 +217,8 @@ verify_java_version() { # Iterate through Spark and Scala versions for SPARK in "${SPARK_VERSIONS[@]}"; do for SCALA in "${SCALA_VERSIONS[@]}"; do - # Skip Spark 4.0 + Scala 2.12 combination as it's not supported - if [[ "$SPARK" == "4.0" && "$SCALA" == "2.12" ]]; then + # Skip Spark 4.0+ + Scala 2.12 combination as it's not supported + if [[ "$SPARK" == 
"4."* && "$SCALA" == "2.12" ]]; then echo "Skipping Spark $SPARK with Scala $SCALA (not supported)" continue fi @@ -286,7 +286,7 @@ mkdir apache-sedona-${SEDONA_VERSION}-bin # Function to get Java version for Spark version get_java_version() { local spark_version=$1 - if [[ "$spark_version" == "4.0" ]]; then + if [[ "$spark_version" == "4."* ]]; then echo "17" else echo "11" @@ -410,6 +410,15 @@ echo "Compiling for Spark 4.0 with Scala 2.13 using Java $JAVA_VERSION..." cd apache-sedona-${SEDONA_VERSION}-src && $MVN_WRAPPER clean && $MVN_WRAPPER install -DskipTests -Dspark=4.0 -Dscala=2.13 && cd .. cp apache-sedona-${SEDONA_VERSION}-src/spark-shaded/target/sedona-*${SEDONA_VERSION}.jar apache-sedona-${SEDONA_VERSION}-bin/ +# Compile for Spark 4.1 with Java 17 +JAVA_VERSION=$(get_java_version "4.1") +MVN_WRAPPER=$(create_mvn_wrapper $JAVA_VERSION) +verify_java_version $MVN_WRAPPER $JAVA_VERSION + +echo "Compiling for Spark 4.1 with Scala 2.13 using Java $JAVA_VERSION..." +cd apache-sedona-${SEDONA_VERSION}-src && $MVN_WRAPPER clean && $MVN_WRAPPER install -DskipTests -Dspark=4.1 -Dscala=2.13 && cd .. +cp apache-sedona-${SEDONA_VERSION}-src/spark-shaded/target/sedona-*${SEDONA_VERSION}.jar apache-sedona-${SEDONA_VERSION}-bin/ + # Clean up Maven wrappers rm -f /tmp/mvn-java11 /tmp/mvn-java17 diff --git a/docs/setup/maven-coordinates.md b/docs/setup/maven-coordinates.md index d57f73eebb4..928200e6984 100644 --- a/docs/setup/maven-coordinates.md +++ b/docs/setup/maven-coordinates.md @@ -133,6 +133,22 @@ The optional GeoTools library is required if you want to use raster operators. V ``` + === "Spark 4.1 and Scala 2.13" + + ```xml + + org.apache.sedona + sedona-spark-shaded-4.1_2.13 + {{ sedona.current_version }} + + + + org.datasyslab + geotools-wrapper + {{ sedona.current_geotools }} + + ``` + !!! abstract "Sedona with Apache Flink" === "Flink 1.12+ and Scala 2.12" @@ -265,6 +281,19 @@ The optional GeoTools library is required if you want to use raster operators. V {{ sedona.current_geotools }} ``` + === "Spark 4.1 and Scala 2.13" + ```xml + + org.apache.sedona + sedona-spark-4.1_2.13 + {{ sedona.current_version }} + + + org.datasyslab + geotools-wrapper + {{ sedona.current_geotools }} + + ``` !!! abstract "Sedona with Apache Flink" diff --git a/docs/setup/platform.md b/docs/setup/platform.md index 9ea2abe0915..c788981ceda 100644 --- a/docs/setup/platform.md +++ b/docs/setup/platform.md @@ -22,28 +22,28 @@ Sedona binary releases are compiled by Java 11/17 and Scala 2.12/2.13 and tested **Java Requirements:** - Spark 3.4 & 3.5: Java 11 -- Spark 4.0: Java 17 +- Spark 4.0 & 4.1: Java 17 **Note:** Java 8 support is dropped since Sedona 1.8.0. Spark 3.3 support is dropped since Sedona 1.8.0. 
=== "Sedona Scala/Java" - | | Spark 3.4| Spark 3.5 | Spark 4.0 | - |:---------:|:---------:|:---------:|:---------:| - | Scala 2.12 |✅ |✅ |✅ | - | Scala 2.13 |✅ |✅ |✅ | + | | Spark 3.4| Spark 3.5 | Spark 4.0 | Spark 4.1 | + |:---------:|:---------:|:---------:|:---------:|:---------:| + | Scala 2.12 |✅ |✅ |✅ | | + | Scala 2.13 |✅ |✅ |✅ |✅ | === "Sedona Python" - | | Spark 3.4 (Scala 2.12)|Spark 3.5 (Scala 2.12)| Spark 4.0 (Scala 2.12)| - |:---------:|:---------:|:---------:|:---------:| - | Python 3.7 | ✅ | ✅ | ✅ | - | Python 3.8 | ✅ | ✅ | ✅ | - | Python 3.9 | ✅ | ✅ | ✅ | - | Python 3.10 | ✅ | ✅ | ✅ | + | | Spark 3.4 (Scala 2.12)|Spark 3.5 (Scala 2.12)| Spark 4.0 (Scala 2.13)| Spark 4.1 (Scala 2.13)| + |:---------:|:---------:|:---------:|:---------:|:---------:| + | Python 3.7 | ✅ | ✅ | | | + | Python 3.8 | ✅ | ✅ | | | + | Python 3.9 | ✅ | ✅ | ✅ | ✅ | + | Python 3.10 | ✅ | ✅ | ✅ | ✅ | === "Sedona R" - | | Spark 3.4 | Spark 3.5 | Spark 4.0 | - |:---------:|:---------:|:---------:|:---------:| - | Scala 2.12 | ✅ | ✅ | ✅ | + | | Spark 3.4 | Spark 3.5 | Spark 4.0 | Spark 4.1 | + |:---------:|:---------:|:---------:|:---------:|:---------:| + | Scala 2.12 | ✅ | ✅ | ✅ | | diff --git a/pom.xml b/pom.xml index 542c6c94659..05ca1cde9c0 100644 --- a/pom.xml +++ b/pom.xml @@ -758,7 +758,30 @@ 2.24.3 2.0.16 - 2.13.12 + 2.13.17 + 2.13 + + + true + + + + sedona-spark-4.1 + + + spark + 4.1 + + + + 4.1.1 + 4.1 + 4 + 3.4.1 + 2.24.3 + 2.0.16 + + 2.13.17 2.13 @@ -775,7 +798,7 @@ false - 2.13.12 + 2.13.17 2.13 -no-java-comments diff --git a/python/pyproject.toml b/python/pyproject.toml index 4d7eda43466..5d2237991e9 100644 --- a/python/pyproject.toml +++ b/python/pyproject.toml @@ -37,13 +37,17 @@ dependencies = [ ] [project.optional-dependencies] -spark = ["pyspark>=3.4.0,<4.1.0"] +spark = [ + "pyspark>=3.4.0,<4.1.0; python_version < '3.10'", + "pyspark>=3.4.0,<4.2.0; python_version >= '3.10'", +] pydeck-map = ["geopandas", "pydeck==0.8.0"] kepler-map = ["geopandas", "keplergl==0.3.2"] flink = ["apache-flink>=1.19.0"] db = ["sedonadb[geopandas]; python_version >= '3.9'"] all = [ - "pyspark>=3.4.0,<4.1.0", + "pyspark>=3.4.0,<4.1.0; python_version < '3.10'", + "pyspark>=3.4.0,<4.2.0; python_version >= '3.10'", "geopandas", "pydeck==0.8.0", "keplergl==0.3.2", @@ -71,7 +75,8 @@ dev = [ # cannot set geopandas>=0.14.4 since it doesn't support python 3.8, so we pin fiona to <1.10.0 "fiona<1.10.0", "pyarrow", - "pyspark>=3.4.0,<4.1.0", + "pyspark>=3.4.0,<4.1.0; python_version < '3.10'", + "pyspark>=3.4.0,<4.2.0; python_version >= '3.10'", "keplergl==0.3.2", "pydeck==0.8.0", "pystac==1.5.0", diff --git a/python/tests/sql/test_dataframe_api.py b/python/tests/sql/test_dataframe_api.py index 4eb8ee02888..81dcd5055a9 100644 --- a/python/tests/sql/test_dataframe_api.py +++ b/python/tests/sql/test_dataframe_api.py @@ -1790,7 +1790,7 @@ def test_dataframe_function( elif isinstance(actual_result, Geography): # self.assert_geometry_almost_equal(expected_result, actual_result.geometry) return - elif isinstance(actual_result, bytearray): + elif isinstance(actual_result, (bytes, bytearray)): actual_result = actual_result.hex() elif isinstance(actual_result, Row): actual_result = { diff --git a/spark/common/pom.xml b/spark/common/pom.xml index 109da51ead2..0210dd825b4 100644 --- a/spark/common/pom.xml +++ b/spark/common/pom.xml @@ -355,5 +355,22 @@ + + sedona-spark-4.1 + + + spark + 4.1 + + + + + org.apache.spark + spark-sql-api_${scala.compat.version} + ${spark.version} + provided + + + diff --git 
a/spark/common/src/main/java/org/apache/sedona/core/joinJudgement/DuplicatesFilter.java b/spark/common/src/main/java/org/apache/sedona/core/joinJudgement/DuplicatesFilter.java index 7b28ec2fbed..1e719a5e13a 100644 --- a/spark/common/src/main/java/org/apache/sedona/core/joinJudgement/DuplicatesFilter.java +++ b/spark/common/src/main/java/org/apache/sedona/core/joinJudgement/DuplicatesFilter.java @@ -20,7 +20,9 @@ import java.util.Iterator; import java.util.List; -import org.apache.commons.collections.iterators.FilterIterator; +import java.util.Spliterator; +import java.util.Spliterators; +import java.util.stream.StreamSupport; import org.apache.commons.lang3.tuple.Pair; import org.apache.log4j.LogManager; import org.apache.log4j.Logger; @@ -57,11 +59,10 @@ public Iterator> call(Integer partitionId, Iterator> geome final List partitionExtents = dedupParamsBroadcast.getValue().getPartitionExtents(); if (partitionId < partitionExtents.size()) { HalfOpenRectangle extent = new HalfOpenRectangle(partitionExtents.get(partitionId)); - return new FilterIterator( - geometryPair, - p -> - !GeomUtils.isDuplicate( - ((Pair) p).getLeft(), ((Pair) p).getRight(), extent)); + return StreamSupport.stream( + Spliterators.spliteratorUnknownSize(geometryPair, Spliterator.ORDERED), false) + .filter(p -> !GeomUtils.isDuplicate(p.getLeft(), p.getRight(), extent)) + .iterator(); } else { log.warn("Didn't find partition extent for this partition: " + partitionId); return geometryPair; diff --git a/spark/common/src/main/java/org/apache/spark/sql/execution/datasources/geoparquet/internal/ParquetColumnVector.java b/spark/common/src/main/java/org/apache/spark/sql/execution/datasources/geoparquet/internal/ParquetColumnVector.java index 870a07b7f6a..a945f72d80e 100644 --- a/spark/common/src/main/java/org/apache/spark/sql/execution/datasources/geoparquet/internal/ParquetColumnVector.java +++ b/spark/common/src/main/java/org/apache/spark/sql/execution/datasources/geoparquet/internal/ParquetColumnVector.java @@ -19,6 +19,7 @@ package org.apache.spark.sql.execution.datasources.geoparquet.internal; import com.google.common.base.Preconditions; +import java.lang.reflect.Method; import java.util.ArrayList; import java.util.List; import java.util.Set; @@ -40,6 +41,24 @@ final class ParquetColumnVector { private final List children; private final WritableColumnVector vector; + /** + * Mark the given column vector as all-null / missing. In Spark <= 4.0 this is {@code + * setAllNull()}, in Spark >= 4.1 it was renamed to {@code setMissing()}. + */ + private static void markAllNull(WritableColumnVector v) { + try { + Method m; + try { + m = v.getClass().getMethod("setAllNull"); + } catch (NoSuchMethodException e) { + m = v.getClass().getMethod("setMissing"); + } + m.invoke(v); + } catch (Exception e) { + throw new RuntimeException("Cannot mark column vector as all null", e); + } + } + /** * Repetition & Definition levels These are allocated only for leaf columns; for non-leaf columns, * they simply maintain references to that of the former. 
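The DuplicatesFilter change above replaces commons-collections' untyped FilterIterator with a filtered java.util.stream pipeline wrapped around the incoming iterator. As a self-contained illustration of that pattern on toy data (class and variable names here are hypothetical, not from the patch):

```java
import java.util.Arrays;
import java.util.Iterator;
import java.util.Spliterator;
import java.util.Spliterators;
import java.util.stream.StreamSupport;

public class FilteredIteratorSketch {
  public static void main(String[] args) {
    Iterator<String> source = Arrays.asList("keep-1", "drop", "keep-2").iterator();
    // Wrap the iterator in a lazy, sequential Stream and filter it; the resulting
    // iterator pulls from `source` on demand, just like the old FilterIterator did.
    Iterator<String> filtered =
        StreamSupport.stream(
                Spliterators.spliteratorUnknownSize(source, Spliterator.ORDERED), false)
            .filter(s -> !s.startsWith("drop"))
            .iterator();
    filtered.forEachRemaining(System.out::println); // prints keep-1 and keep-2
  }
}
```

The wrapping stays lazy, which matters in DuplicatesFilter because each partition's geometry pairs are streamed through the returned iterator rather than materialized into a list.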
@@ -84,7 +103,7 @@ final class ParquetColumnVector { } if (defaultValue == null) { - vector.setAllNull(); + markAllNull(vector); return; } // For Parquet tables whose columns have associated DEFAULT values, this reader must return @@ -139,7 +158,7 @@ final class ParquetColumnVector { // This can happen if all the fields of a struct are missing, in which case we should mark // the struct itself as a missing column if (allChildrenAreMissing) { - vector.setAllNull(); + markAllNull(vector); } } } diff --git a/spark/common/src/main/scala/org/apache/sedona/sql/UDF/AbstractCatalog.scala b/spark/common/src/main/scala/org/apache/sedona/sql/UDF/AbstractCatalog.scala index 72dad004672..ee62dfaa813 100644 --- a/spark/common/src/main/scala/org/apache/sedona/sql/UDF/AbstractCatalog.scala +++ b/spark/common/src/main/scala/org/apache/sedona/sql/UDF/AbstractCatalog.scala @@ -85,7 +85,14 @@ abstract class AbstractCatalog { def registerAll(sparkSession: SparkSession): Unit = { val registry = sparkSession.sessionState.functionRegistry expressions.foreach { case (functionIdentifier, expressionInfo, functionBuilder) => - if (!registry.functionExists(functionIdentifier)) { + val shouldRegister = registry.lookupFunction(functionIdentifier) match { + case Some(existingInfo) => + // Skip if Sedona already registered this function (e.g., SedonaContext.create called + // twice). Overwrite if it's a Spark native function (e.g., Spark 4.1's ST_GeomFromWKB). + !existingInfo.getClassName.startsWith("org.apache.sedona.") + case None => true + } + if (shouldRegister) { registry.registerFunction(functionIdentifier, expressionInfo, functionBuilder) FunctionRegistry.builtin.registerFunction( functionIdentifier, diff --git a/spark/common/src/main/scala/org/apache/sedona/sql/datasources/geopackage/model/GeoPackageField.scala b/spark/common/src/main/scala/org/apache/sedona/sql/datasources/geopackage/model/GeoPackageField.scala index 127c56ca52f..7e3dba3ad32 100644 --- a/spark/common/src/main/scala/org/apache/sedona/sql/datasources/geopackage/model/GeoPackageField.scala +++ b/spark/common/src/main/scala/org/apache/sedona/sql/datasources/geopackage/model/GeoPackageField.scala @@ -33,7 +33,7 @@ case class GeoPackageField(name: String, dataType: String, isNullable: Boolean) StructField(name, StringType) case startsWith: String if startsWith.startsWith(GeoPackageType.BLOB) => { if (tableType == TableType.TILES) { - return StructField(name, RasterUDT) + return StructField(name, RasterUDT()) } StructField(name, BinaryType) @@ -41,14 +41,14 @@ case class GeoPackageField(name: String, dataType: String, isNullable: Boolean) case GeoPackageType.INTEGER | GeoPackageType.INT | GeoPackageType.SMALLINT | GeoPackageType.TINY_INT | GeoPackageType.MEDIUMINT => StructField(name, IntegerType) - case GeoPackageType.POINT => StructField(name, GeometryUDT) - case GeoPackageType.LINESTRING => StructField(name, GeometryUDT) - case GeoPackageType.POLYGON => StructField(name, GeometryUDT) - case GeoPackageType.GEOMETRY => StructField(name, GeometryUDT) - case GeoPackageType.MULTIPOINT => StructField(name, GeometryUDT) - case GeoPackageType.MULTILINESTRING => StructField(name, GeometryUDT) - case GeoPackageType.MULTIPOLYGON => StructField(name, GeometryUDT) - case GeoPackageType.GEOMETRYCOLLECTION => StructField(name, GeometryUDT) + case GeoPackageType.POINT => StructField(name, GeometryUDT()) + case GeoPackageType.LINESTRING => StructField(name, GeometryUDT()) + case GeoPackageType.POLYGON => StructField(name, GeometryUDT()) + case GeoPackageType.GEOMETRY 
=> StructField(name, GeometryUDT()) + case GeoPackageType.MULTIPOINT => StructField(name, GeometryUDT()) + case GeoPackageType.MULTILINESTRING => StructField(name, GeometryUDT()) + case GeoPackageType.MULTIPOLYGON => StructField(name, GeometryUDT()) + case GeoPackageType.GEOMETRYCOLLECTION => StructField(name, GeometryUDT()) case GeoPackageType.REAL => StructField(name, DoubleType) case GeoPackageType.BOOLEAN => StructField(name, BooleanType) case GeoPackageType.DATE => StructField(name, DateType) diff --git a/spark/common/src/main/scala/org/apache/sedona/sql/datasources/spider/SpiderTable.scala b/spark/common/src/main/scala/org/apache/sedona/sql/datasources/spider/SpiderTable.scala index 81cb5bffecd..8acbef285b4 100644 --- a/spark/common/src/main/scala/org/apache/sedona/sql/datasources/spider/SpiderTable.scala +++ b/spark/common/src/main/scala/org/apache/sedona/sql/datasources/spider/SpiderTable.scala @@ -50,5 +50,5 @@ class SpiderTable( object SpiderTable { val SCHEMA: StructType = StructType( - Seq(StructField("id", LongType), StructField("geometry", GeometryUDT))) + Seq(StructField("id", LongType), StructField("geometry", GeometryUDT()))) } diff --git a/spark/common/src/main/scala/org/apache/sedona/sql/utils/Adapter.scala b/spark/common/src/main/scala/org/apache/sedona/sql/utils/Adapter.scala index 96aab1287ee..732fb1272d7 100644 --- a/spark/common/src/main/scala/org/apache/sedona/sql/utils/Adapter.scala +++ b/spark/common/src/main/scala/org/apache/sedona/sql/utils/Adapter.scala @@ -143,7 +143,7 @@ object Adapter { val stringRow = extractUserData(geom) Row.fromSeq(stringRow) }) - var cols: Seq[StructField] = Seq(StructField("geometry", GeometryUDT)) + var cols: Seq[StructField] = Seq(StructField("geometry", GeometryUDT())) if (fieldNames != null && fieldNames.nonEmpty) { cols = cols ++ fieldNames.map(f => StructField(f, StringType)) } @@ -195,10 +195,10 @@ object Adapter { rightFieldnames = rightFieldNames) Row.fromSeq(stringRow) }) - var cols: Seq[StructField] = Seq(StructField("leftgeometry", GeometryUDT)) + var cols: Seq[StructField] = Seq(StructField("leftgeometry", GeometryUDT())) if (leftFieldnames != null && leftFieldnames.nonEmpty) cols = cols ++ leftFieldnames.map(fName => StructField(fName, StringType)) - cols = cols ++ Seq(StructField("rightgeometry", GeometryUDT)) + cols = cols ++ Seq(StructField("rightgeometry", GeometryUDT())) if (rightFieldNames != null && rightFieldNames.nonEmpty) cols = cols ++ rightFieldNames.map(fName => StructField(fName, StringType)) val schema = StructType(cols) diff --git a/spark/common/src/main/scala/org/apache/spark/sql/execution/datasources/geoparquet/GeoParquetFileFormat.scala b/spark/common/src/main/scala/org/apache/spark/sql/execution/datasources/geoparquet/GeoParquetFileFormat.scala index 04f4db8bf5a..8a5c2c19a0f 100644 --- a/spark/common/src/main/scala/org/apache/spark/sql/execution/datasources/geoparquet/GeoParquetFileFormat.scala +++ b/spark/common/src/main/scala/org/apache/spark/sql/execution/datasources/geoparquet/GeoParquetFileFormat.scala @@ -448,7 +448,7 @@ object GeoParquetFileFormat extends Logging { val fields = schema.fields.map { field => field.dataType match { case _: BinaryType if geoParquetMetaData.columns.contains(field.name) => - field.copy(dataType = GeometryUDT) + field.copy(dataType = GeometryUDT()) case _ => field } } diff --git a/spark/common/src/main/scala/org/apache/spark/sql/execution/datasources/geoparquet/GeoParquetSchemaConverter.scala 
b/spark/common/src/main/scala/org/apache/spark/sql/execution/datasources/geoparquet/GeoParquetSchemaConverter.scala index 749816ffcb5..34b9f44ec54 100644 --- a/spark/common/src/main/scala/org/apache/spark/sql/execution/datasources/geoparquet/GeoParquetSchemaConverter.scala +++ b/spark/common/src/main/scala/org/apache/spark/sql/execution/datasources/geoparquet/GeoParquetSchemaConverter.scala @@ -193,7 +193,7 @@ class GeoParquetToSparkSchemaConverter( case BINARY => originalType match { case UTF8 | ENUM | JSON => StringType - case null if isGeometryField(field.getName) => GeometryUDT + case null if isGeometryField(field.getName) => GeometryUDT() case null if assumeBinaryIsString => StringType case null => BinaryType case BSON => BinaryType diff --git a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/UDT/GeographyUDT.scala b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/UDT/GeographyUDT.scala index d4d4fb8898c..969672a2c02 100644 --- a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/UDT/GeographyUDT.scala +++ b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/UDT/GeographyUDT.scala @@ -52,8 +52,12 @@ class GeographyUDT extends UserDefinedType[Geography] { } override def hashCode(): Int = userClass.hashCode() + + override def toString: String = "GeographyUDT" } case object GeographyUDT extends org.apache.spark.sql.sedona_sql.UDT.GeographyUDT - with scala.Serializable + with scala.Serializable { + def apply(): GeographyUDT = new GeographyUDT() +} diff --git a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/UDT/GeometryUDT.scala b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/UDT/GeometryUDT.scala index 20ed85d52ec..ae1772e617a 100644 --- a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/UDT/GeometryUDT.scala +++ b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/UDT/GeometryUDT.scala @@ -53,8 +53,12 @@ class GeometryUDT extends UserDefinedType[Geometry] { } override def hashCode(): Int = userClass.hashCode() + + override def toString: String = "GeometryUDT" } case object GeometryUDT extends org.apache.spark.sql.sedona_sql.UDT.GeometryUDT - with scala.Serializable + with scala.Serializable { + def apply(): GeometryUDT = new GeometryUDT() +} diff --git a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/UDT/RasterUDT.scala b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/UDT/RasterUDT.scala index a34b372211e..67fe763129c 100644 --- a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/UDT/RasterUDT.scala +++ b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/UDT/RasterUDT.scala @@ -56,6 +56,10 @@ class RasterUDT extends UserDefinedType[GridCoverage2D] { } override def hashCode(): Int = userClass.hashCode() + + override def toString: String = "RasterUDT" } -case object RasterUDT extends RasterUDT with Serializable +case object RasterUDT extends RasterUDT with Serializable { + def apply(): RasterUDT = new RasterUDT() +} diff --git a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/Constructors.scala b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/Constructors.scala index d4b3d3efb1b..ce81764e7e7 100644 --- a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/Constructors.scala +++ b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/Constructors.scala @@ -177,7 +177,7 @@ private[apache] case class ST_GeomFromWKB(inputExpressions: Seq[Expression]) } } - override def 
dataType: DataType = GeometryUDT + override def dataType: DataType = GeometryUDT() override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection(StringType, BinaryType)) @@ -217,7 +217,7 @@ private[apache] case class ST_GeomFromEWKB(inputExpressions: Seq[Expression]) } } - override def dataType: DataType = GeometryUDT + override def dataType: DataType = GeometryUDT() override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection(StringType, BinaryType)) @@ -272,7 +272,7 @@ private[apache] case class ST_LineFromWKB(inputExpressions: Seq[Expression]) } } - override def dataType: DataType = GeometryUDT + override def dataType: DataType = GeometryUDT() override def inputTypes: Seq[AbstractDataType] = if (inputExpressions.length == 1) Seq(TypeCollection(StringType, BinaryType)) @@ -329,7 +329,7 @@ private[apache] case class ST_LinestringFromWKB(inputExpressions: Seq[Expression } } - override def dataType: DataType = GeometryUDT + override def dataType: DataType = GeometryUDT() override def inputTypes: Seq[AbstractDataType] = if (inputExpressions.length == 1) Seq(TypeCollection(StringType, BinaryType)) @@ -386,7 +386,7 @@ private[apache] case class ST_PointFromWKB(inputExpressions: Seq[Expression]) } } - override def dataType: DataType = GeometryUDT + override def dataType: DataType = GeometryUDT() override def inputTypes: Seq[AbstractDataType] = if (inputExpressions.length == 1) Seq(TypeCollection(StringType, BinaryType)) @@ -435,7 +435,7 @@ private[apache] case class ST_GeomFromGeoJSON(inputExpressions: Seq[Expression]) } } - override def dataType: DataType = GeometryUDT + override def dataType: DataType = GeometryUDT() override def children: Seq[Expression] = inputExpressions diff --git a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/Functions.scala b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/Functions.scala index 565a9e99574..b5f85b89682 100644 --- a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/Functions.scala +++ b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/Functions.scala @@ -32,6 +32,7 @@ import org.apache.spark.sql.sedona_sql.expressions.implicits._ import org.apache.spark.sql.types._ import org.locationtech.jts.algorithm.MinimumBoundingCircle import org.locationtech.jts.geom._ +import org.locationtech.jts.geom.Geometry import org.apache.spark.sql.sedona_sql.expressions.InferrableFunctionConverter._ import org.apache.spark.sql.sedona_sql.expressions.LibPostalUtils.{getExpanderFromConf, getParserFromConf} import org.apache.spark.unsafe.types.UTF8String @@ -393,9 +394,9 @@ private[apache] case class ST_IsValidDetail(children: Seq[Expression]) override def inputTypes: Seq[AbstractDataType] = { if (nArgs == 2) { - Seq(GeometryUDT, IntegerType) + Seq(GeometryUDT(), IntegerType) } else if (nArgs == 1) { - Seq(GeometryUDT) + Seq(GeometryUDT()) } else { throw new IllegalArgumentException(s"Invalid number of arguments: $nArgs") } @@ -439,7 +440,7 @@ private[apache] case class ST_IsValidDetail(children: Seq[Expression]) override def dataType: DataType = new StructType() .add("valid", BooleanType, nullable = false) .add("reason", StringType, nullable = true) - .add("location", GeometryUDT, nullable = true) + .add("location", GeometryUDT(), nullable = true) } private[apache] case class ST_IsValidTrajectory(inputExpressions: Seq[Expression]) @@ -735,7 +736,7 @@ private[apache] case class ST_MinimumBoundingRadius(inputExpressions: Seq[Expres override def dataType: DataType = 
DataTypes.createStructType( Array( - DataTypes.createStructField("center", GeometryUDT, false), + DataTypes.createStructField("center", GeometryUDT(), false), DataTypes.createStructField("radius", DataTypes.DoubleType, false))) override def children: Seq[Expression] = inputExpressions @@ -1068,7 +1069,7 @@ private[apache] case class ST_SubDivideExplode(children: Seq[Expression]) override def elementSchema: StructType = { new StructType() - .add("geom", GeometryUDT, true) + .add("geom", GeometryUDT(), true) } protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]) = { @@ -1187,8 +1188,8 @@ private[apache] case class ST_MaximumInscribedCircle(children: Seq[Expression]) override def nullable: Boolean = true override def dataType: DataType = new StructType() - .add("center", GeometryUDT, nullable = false) - .add("nearest", GeometryUDT, nullable = false) + .add("center", GeometryUDT(), nullable = false) + .add("nearest", GeometryUDT(), nullable = false) .add("radius", DoubleType, nullable = false) } @@ -1692,13 +1693,13 @@ private[apache] case class ST_GeneratePoints(inputExpressions: Seq[Expression], override def nullable: Boolean = true - override def dataType: DataType = GeometryUDT + override def dataType: DataType = GeometryUDT() override def inputTypes: Seq[AbstractDataType] = { if (nArgs == 3) { - Seq(GeometryUDT, IntegerType, IntegerType) + Seq(GeometryUDT(), IntegerType, IntegerType) } else if (nArgs == 2) { - Seq(GeometryUDT, IntegerType) + Seq(GeometryUDT(), IntegerType) } else { throw new IllegalArgumentException(s"Invalid number of arguments: $nArgs") } diff --git a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/GeoStatsFunctions.scala b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/GeoStatsFunctions.scala index 0d805fbedee..eb6d670d3f2 100644 --- a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/GeoStatsFunctions.scala +++ b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/GeoStatsFunctions.scala @@ -36,7 +36,7 @@ private[apache] case class ST_DBSCAN(children: Seq[Expression]) StructType(Seq(StructField("isCore", BooleanType), StructField("cluster", LongType))) override def inputTypes: Seq[AbstractDataType] = - Seq(GeometryUDT, DoubleType, IntegerType, BooleanType) + Seq(GeometryUDT(), DoubleType, IntegerType, BooleanType) protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): Expression = copy(children = newChildren) @@ -74,7 +74,7 @@ private[apache] case class ST_LocalOutlierFactor(children: Seq[Expression]) override def dataType: DataType = DoubleType override def inputTypes: Seq[AbstractDataType] = - Seq(GeometryUDT, IntegerType, BooleanType) + Seq(GeometryUDT(), IntegerType, BooleanType) protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): Expression = copy(children = newChildren) @@ -139,7 +139,7 @@ private[apache] case class ST_BinaryDistanceBandColumn(children: Seq[Expression] Seq(StructField("neighbor", children(5).dataType), StructField("value", DoubleType)))) override def inputTypes: Seq[AbstractDataType] = - Seq(GeometryUDT, DoubleType, BooleanType, BooleanType, BooleanType, children(5).dataType) + Seq(GeometryUDT(), DoubleType, BooleanType, BooleanType, BooleanType, children(5).dataType) protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): Expression = copy(children = newChildren) @@ -174,7 +174,7 @@ private[apache] case class ST_WeightedDistanceBandColumn(children: Seq[Expressio 
override def inputTypes: Seq[AbstractDataType] = Seq( - GeometryUDT, + GeometryUDT(), DoubleType, DoubleType, BooleanType, diff --git a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/InferrableRasterTypes.scala b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/InferrableRasterTypes.scala index 2d3349d4ad7..970f03d01a7 100644 --- a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/InferrableRasterTypes.scala +++ b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/InferrableRasterTypes.scala @@ -37,8 +37,8 @@ object InferrableRasterTypes { def isRasterType(t: Type): Boolean = t =:= typeOf[GridCoverage2D] def isRasterArrayType(t: Type): Boolean = t =:= typeOf[Array[GridCoverage2D]] - val rasterUDT: UserDefinedType[_] = RasterUDT - val rasterUDTArray: ArrayType = DataTypes.createArrayType(RasterUDT) + val rasterUDT: UserDefinedType[_] = RasterUDT() + val rasterUDTArray: ArrayType = DataTypes.createArrayType(RasterUDT()) def rasterExtractor(expr: Expression)(input: InternalRow): Any = expr.toRaster(input) diff --git a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/InferredExpression.scala b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/InferredExpression.scala index 757948216c3..4b94084f95a 100644 --- a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/InferredExpression.scala +++ b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/InferredExpression.scala @@ -325,13 +325,13 @@ object InferredTypes { def inferSparkType(t: Type): DataType = { if (t =:= typeOf[Geometry]) { - GeometryUDT + GeometryUDT() } else if (t =:= typeOf[Array[Geometry]] || t =:= typeOf[java.util.List[Geometry]]) { - DataTypes.createArrayType(GeometryUDT) + DataTypes.createArrayType(GeometryUDT()) } else if (t =:= typeOf[Geography]) { - GeographyUDT + GeographyUDT() } else if (t =:= typeOf[Array[Geography]] || t =:= typeOf[java.util.List[Geography]]) { - DataTypes.createArrayType(GeographyUDT) + DataTypes.createArrayType(GeographyUDT()) } else if (InferredRasterExpression.isRasterType(t)) { InferredRasterExpression.rasterUDT } else if (InferredRasterExpression.isRasterArrayType(t)) { diff --git a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/Predicates.scala b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/Predicates.scala index 9e69ed04382..fff9f6eeef5 100644 --- a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/Predicates.scala +++ b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/Predicates.scala @@ -40,7 +40,7 @@ abstract class ST_Predicate override def nullable: Boolean = children.exists(_.nullable) - override def inputTypes: Seq[AbstractDataType] = Seq(GeometryUDT, GeometryUDT) + override def inputTypes: Seq[AbstractDataType] = Seq(GeometryUDT(), GeometryUDT()) override def dataType: DataType = BooleanType diff --git a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/collect/ST_Collect.scala b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/collect/ST_Collect.scala index 3b430cc67df..4892fb5a8b3 100644 --- a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/collect/ST_Collect.scala +++ b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/collect/ST_Collect.scala @@ -82,7 +82,7 @@ private[apache] case class ST_Collect(inputExpressions: 
Seq[Expression]) } } - override def dataType: DataType = GeometryUDT + override def dataType: DataType = GeometryUDT() override def children: Seq[Expression] = inputExpressions diff --git a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/raster/PixelFunctions.scala b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/raster/PixelFunctions.scala index f72cd1e3946..13849f74f48 100644 --- a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/raster/PixelFunctions.scala +++ b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/raster/PixelFunctions.scala @@ -59,7 +59,7 @@ private[apache] case class RS_PixelAsPoints(inputExpressions: Seq[Expression]) override def dataType: DataType = ArrayType( new StructType() - .add("geom", GeometryUDT) + .add("geom", GeometryUDT()) .add("value", DoubleType) .add("x", IntegerType) .add("y", IntegerType)) @@ -88,7 +88,7 @@ private[apache] case class RS_PixelAsPoints(inputExpressions: Seq[Expression]) copy(inputExpressions = newChildren) } - override def inputTypes: Seq[AbstractDataType] = Seq(RasterUDT, IntegerType) + override def inputTypes: Seq[AbstractDataType] = Seq(RasterUDT(), IntegerType) } private[apache] case class RS_PixelAsPolygon(inputExpressions: Seq[Expression]) @@ -107,7 +107,7 @@ private[apache] case class RS_PixelAsPolygons(inputExpressions: Seq[Expression]) override def dataType: DataType = ArrayType( new StructType() - .add("geom", GeometryUDT) + .add("geom", GeometryUDT()) .add("value", DoubleType) .add("x", IntegerType) .add("y", IntegerType)) @@ -137,7 +137,7 @@ private[apache] case class RS_PixelAsPolygons(inputExpressions: Seq[Expression]) copy(inputExpressions = newChildren) } - override def inputTypes: Seq[AbstractDataType] = Seq(RasterUDT, IntegerType) + override def inputTypes: Seq[AbstractDataType] = Seq(RasterUDT(), IntegerType) } private[apache] case class RS_PixelAsCentroid(inputExpressions: Seq[Expression]) @@ -156,7 +156,7 @@ private[apache] case class RS_PixelAsCentroids(inputExpressions: Seq[Expression] override def dataType: DataType = ArrayType( new StructType() - .add("geom", GeometryUDT) + .add("geom", GeometryUDT()) .add("value", DoubleType) .add("x", IntegerType) .add("y", IntegerType)) @@ -186,7 +186,7 @@ private[apache] case class RS_PixelAsCentroids(inputExpressions: Seq[Expression] copy(inputExpressions = newChildren) } - override def inputTypes: Seq[AbstractDataType] = Seq(RasterUDT, IntegerType) + override def inputTypes: Seq[AbstractDataType] = Seq(RasterUDT(), IntegerType) } private[apache] case class RS_Values(inputExpressions: Seq[Expression]) diff --git a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/raster/RasterConstructors.scala b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/raster/RasterConstructors.scala index 77d4e700b04..2523e893437 100644 --- a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/raster/RasterConstructors.scala +++ b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/raster/RasterConstructors.scala @@ -147,7 +147,7 @@ private[apache] case class RS_TileExplode(children: Seq[Expression]) new StructType() .add("x", IntegerType, nullable = false) .add("y", IntegerType, nullable = false) - .add("tile", RasterUDT, nullable = false) + .add("tile", RasterUDT(), nullable = false) } protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]) = { diff --git 
a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/raster/RasterFunctions.scala b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/raster/RasterFunctions.scala index b1fa43b5292..67d439b7d71 100644 --- a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/raster/RasterFunctions.scala +++ b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/raster/RasterFunctions.scala @@ -83,7 +83,7 @@ private[apache] case class RS_Metadata(inputExpressions: Seq[Expression]) copy(inputExpressions = newChildren) } - override def inputTypes: Seq[AbstractDataType] = Seq(RasterUDT) + override def inputTypes: Seq[AbstractDataType] = Seq(RasterUDT()) } private[apache] case class RS_SummaryStatsAll(inputExpressions: Seq[Expression]) @@ -139,13 +139,13 @@ private[apache] case class RS_SummaryStatsAll(inputExpressions: Seq[Expression]) override def inputTypes: Seq[AbstractDataType] = { if (inputExpressions.length == 1) { - Seq(RasterUDT) + Seq(RasterUDT()) } else if (inputExpressions.length == 2) { - Seq(RasterUDT, IntegerType) + Seq(RasterUDT(), IntegerType) } else if (inputExpressions.length == 3) { - Seq(RasterUDT, IntegerType, BooleanType) + Seq(RasterUDT(), IntegerType, BooleanType) } else { - Seq(RasterUDT) + Seq(RasterUDT()) } } } @@ -220,17 +220,17 @@ private[apache] case class RS_ZonalStatsAll(inputExpressions: Seq[Expression]) override def inputTypes: Seq[AbstractDataType] = { if (inputExpressions.length == 2) { - Seq(RasterUDT, GeometryUDT) + Seq(RasterUDT(), GeometryUDT()) } else if (inputExpressions.length == 3) { - Seq(RasterUDT, GeometryUDT, IntegerType) + Seq(RasterUDT(), GeometryUDT(), IntegerType) } else if (inputExpressions.length == 4) { - Seq(RasterUDT, GeometryUDT, IntegerType, BooleanType) + Seq(RasterUDT(), GeometryUDT(), IntegerType, BooleanType) } else if (inputExpressions.length == 5) { - Seq(RasterUDT, GeometryUDT, IntegerType, BooleanType, BooleanType) + Seq(RasterUDT(), GeometryUDT(), IntegerType, BooleanType, BooleanType) } else if (inputExpressions.length >= 6) { - Seq(RasterUDT, GeometryUDT, IntegerType, BooleanType, BooleanType, BooleanType) + Seq(RasterUDT(), GeometryUDT(), IntegerType, BooleanType, BooleanType, BooleanType) } else { - Seq(RasterUDT, GeometryUDT) + Seq(RasterUDT(), GeometryUDT()) } } } @@ -276,5 +276,5 @@ private[apache] case class RS_GeoTransform(inputExpressions: Seq[Expression]) copy(inputExpressions = newChildren) } - override def inputTypes: Seq[AbstractDataType] = Seq(RasterUDT) + override def inputTypes: Seq[AbstractDataType] = Seq(RasterUDT()) } diff --git a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/raster/RasterPredicates.scala b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/raster/RasterPredicates.scala index e247a2825b5..35aaa276c04 100644 --- a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/raster/RasterPredicates.scala +++ b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/raster/RasterPredicates.scala @@ -47,9 +47,9 @@ abstract class RS_Predicate val leftType = inputExpressions.head.dataType val rightType = inputExpressions(1).dataType (leftType, rightType) match { - case (_: RasterUDT, _: GeometryUDT) => Seq(RasterUDT, GeometryUDT) - case (_: GeometryUDT, _: RasterUDT) => Seq(GeometryUDT, RasterUDT) - case (_: RasterUDT, _: RasterUDT) => Seq(RasterUDT, RasterUDT) + case (_: RasterUDT, _: GeometryUDT) => Seq(RasterUDT(), GeometryUDT()) + case (_: GeometryUDT, _: 
RasterUDT) => Seq(GeometryUDT(), RasterUDT()) + case (_: RasterUDT, _: RasterUDT) => Seq(RasterUDT(), RasterUDT()) case _ => throw new IllegalArgumentException(s"Unsupported input types: $leftType, $rightType") } diff --git a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/io/geojson/GeoJSONFileFormat.scala b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/io/geojson/GeoJSONFileFormat.scala index ca588c33bea..67dafb4e45b 100644 --- a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/io/geojson/GeoJSONFileFormat.scala +++ b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/io/geojson/GeoJSONFileFormat.scala @@ -69,7 +69,7 @@ class GeoJSONFileFormat extends TextBasedFileFormat with DataSourceRegister { fullSchemaOption.map { fullSchema => // Replace 'geometry' field type with GeometryUDT - val newFields = GeoJSONUtils.updateGeometrySchema(fullSchema, GeometryUDT) + val newFields = GeoJSONUtils.updateGeometrySchema(fullSchema, GeometryUDT()) StructType(newFields) } } @@ -131,7 +131,7 @@ class GeoJSONFileFormat extends TextBasedFileFormat with DataSourceRegister { geometryColumnName.split('.'), resolver = SQLConf.get.resolver) match { case Some(StructField(_, dataType, _, _)) => - if (!dataType.acceptsType(GeometryUDT)) { + if (!dataType.acceptsType(GeometryUDT())) { throw new IllegalArgumentException(s"$geometryColumnName is not a geometry column") } case None => diff --git a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/io/stac/StacDataSource.scala b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/io/stac/StacDataSource.scala index 154adf8eca4..0aad6e3d48b 100644 --- a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/io/stac/StacDataSource.scala +++ b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/io/stac/StacDataSource.scala @@ -75,7 +75,7 @@ class StacDataSource() extends TableProvider with DataSourceRegister { // Check if the schema is already cached val fullSchema = schemaCache.computeIfAbsent(optsMap, _ => inferStacSchema(optsMap)) - val updatedGeometrySchema = GeoJSONUtils.updateGeometrySchema(fullSchema, GeometryUDT) + val updatedGeometrySchema = GeoJSONUtils.updateGeometrySchema(fullSchema, GeometryUDT()) updatePropertiesPromotedSchema(updatedGeometrySchema) } diff --git a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/io/stac/StacTable.scala b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/io/stac/StacTable.scala index ca5f32663ca..1da2af3dd62 100644 --- a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/io/stac/StacTable.scala +++ b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/io/stac/StacTable.scala @@ -67,7 +67,7 @@ class StacTable( override def schema(): StructType = { // Check if the schema is already cached val fullSchema = schemaCache.computeIfAbsent(opts, _ => inferStacSchema(opts)) - val updatedGeometrySchema = GeoJSONUtils.updateGeometrySchema(fullSchema, GeometryUDT) + val updatedGeometrySchema = GeoJSONUtils.updateGeometrySchema(fullSchema, GeometryUDT()) updatePropertiesPromotedSchema(updatedGeometrySchema) } diff --git a/spark/common/src/test/java/org/apache/sedona/core/spatialPartitioning/GenericUniquePartitionerTest.java b/spark/common/src/test/java/org/apache/sedona/core/spatialPartitioning/GenericUniquePartitionerTest.java index 1df270c0a01..02b9069a13e 100644 --- a/spark/common/src/test/java/org/apache/sedona/core/spatialPartitioning/GenericUniquePartitionerTest.java +++ 
b/spark/common/src/test/java/org/apache/sedona/core/spatialPartitioning/GenericUniquePartitionerTest.java @@ -22,7 +22,11 @@ import java.util.ArrayList; import java.util.Iterator; -import org.apache.commons.collections.IteratorUtils; +import java.util.List; +import java.util.Spliterator; +import java.util.Spliterators; +import java.util.stream.Collectors; +import java.util.stream.StreamSupport; import org.junit.Test; import org.locationtech.jts.geom.Envelope; import org.locationtech.jts.geom.Geometry; @@ -52,10 +56,16 @@ public void testUniquePartition() throws Exception { partitioner.placeObject(factory.toGeometry(definitelyHasMultiplePartitions)); // Because the geometry is not completely contained by any of the partitions, // it also gets placed in the overflow partition (hence 5, not 4) - assertEquals(5, IteratorUtils.toList(placedWithDuplicates).size()); + assertEquals(5, iteratorToList(placedWithDuplicates).size()); Iterator> placedWithoutDuplicates = uniquePartitioner.placeObject(factory.toGeometry(definitelyHasMultiplePartitions)); - assertEquals(1, IteratorUtils.toList(placedWithoutDuplicates).size()); + assertEquals(1, iteratorToList(placedWithoutDuplicates).size()); + } + + private List iteratorToList(Iterator iterator) { + return StreamSupport.stream( + Spliterators.spliteratorUnknownSize(iterator, Spliterator.ORDERED), false) + .collect(Collectors.toList()); } } diff --git a/spark/common/src/test/java/org/apache/sedona/sql/adapterTestJava.java b/spark/common/src/test/java/org/apache/sedona/sql/adapterTestJava.java index 434a2783b63..b4405730c19 100644 --- a/spark/common/src/test/java/org/apache/sedona/sql/adapterTestJava.java +++ b/spark/common/src/test/java/org/apache/sedona/sql/adapterTestJava.java @@ -55,7 +55,11 @@ public class adapterTestJava implements Serializable { public static void onceExecutedBeforeAll() { sparkSession = SedonaContext.create( - SedonaContext.builder().master("local[*]").appName("adapterTestJava").getOrCreate()); + SedonaContext.builder() + .master("local[*]") + .appName("adapterTestJava") + .config("spark.sql.geospatial.enabled", "false") + .getOrCreate()); Logger.getLogger("org").setLevel(Level.WARN); Logger.getLogger("akka").setLevel(Level.WARN); } diff --git a/spark/common/src/test/scala/org/apache/sedona/sql/KnnJoinSuite.scala b/spark/common/src/test/scala/org/apache/sedona/sql/KnnJoinSuite.scala index 50696337f7f..4ca17598d30 100644 --- a/spark/common/src/test/scala/org/apache/sedona/sql/KnnJoinSuite.scala +++ b/spark/common/src/test/scala/org/apache/sedona/sql/KnnJoinSuite.scala @@ -550,7 +550,7 @@ class KnnJoinSuite extends TestBaseScala with TableDrivenPropertyChecks { val emptyRdd = sparkSession.sparkContext.emptyRDD[Row] val emptyDf = sparkSession.createDataFrame( emptyRdd, - StructType(Seq(StructField("id", IntegerType), StructField("geom", GeometryUDT)))) + StructType(Seq(StructField("id", IntegerType), StructField("geom", GeometryUDT())))) emptyDf.createOrReplaceTempView("EMPTYTABLE") df1.createOrReplaceTempView("df1") diff --git a/spark/common/src/test/scala/org/apache/sedona/sql/OsmReaderTest.scala b/spark/common/src/test/scala/org/apache/sedona/sql/OsmReaderTest.scala index 8c61ae0785d..817a5db0b91 100644 --- a/spark/common/src/test/scala/org/apache/sedona/sql/OsmReaderTest.scala +++ b/spark/common/src/test/scala/org/apache/sedona/sql/OsmReaderTest.scala @@ -176,28 +176,28 @@ class OsmReaderTest extends TestBaseScala with Matchers { "smoothness" -> "excellent")) // make sure the nodes match with refs - val nodes = osmData.where("kind 
== 'node'") - val ways = osmData.where("kind == 'way'") - val relations = osmData.where("kind == 'relation'") + val nodes = osmData.where("kind == 'node'").as("n") + val ways = osmData.where("kind == 'way'").as("w") + val relations = osmData.where("kind == 'relation'").as("rel") ways .selectExpr("explode(refs) AS ref") - .alias("w") - .join(nodes, col("w.ref") === nodes("id")) + .alias("w2") + .join(nodes, col("w2.ref") === col("n.id")) .count() shouldEqual (47812) ways .selectExpr("explode(refs) AS ref", "id") - .alias("w") - .join(nodes, col("w.ref") === nodes("id")) - .groupBy("w.id") + .alias("w2") + .join(nodes, col("w2.ref") === col("n.id")) + .groupBy("w2.id") .count() .count() shouldEqual (ways.count()) relations .selectExpr("explode(refs) AS ref", "id") .alias("r") - .join(nodes, col("r.ref") === nodes("id")) + .join(nodes, col("r.ref") === col("n.id")) .groupBy("r.id") .count() .count() shouldEqual (162) @@ -205,7 +205,7 @@ class OsmReaderTest extends TestBaseScala with Matchers { relations .selectExpr("explode(refs) AS ref", "id") .alias("r") - .join(ways, col("r.ref") === ways("id")) + .join(ways, col("r.ref") === col("w.id")) .groupBy("r.id") .count() .count() shouldEqual (261) diff --git a/spark/common/src/test/scala/org/apache/sedona/sql/PreserveSRIDSuite.scala b/spark/common/src/test/scala/org/apache/sedona/sql/PreserveSRIDSuite.scala index 69b1641cc40..749a1be48a2 100644 --- a/spark/common/src/test/scala/org/apache/sedona/sql/PreserveSRIDSuite.scala +++ b/spark/common/src/test/scala/org/apache/sedona/sql/PreserveSRIDSuite.scala @@ -144,12 +144,12 @@ class PreserveSRIDSuite extends TestBaseScala with TableDrivenPropertyChecks { val schema = StructType( Seq( - StructField("geom1", GeometryUDT), - StructField("geom2", GeometryUDT), - StructField("geom3", GeometryUDT), - StructField("geom4", GeometryUDT), - StructField("geom5", GeometryUDT), - StructField("geom6", GeometryUDT))) + StructField("geom1", GeometryUDT()), + StructField("geom2", GeometryUDT()), + StructField("geom3", GeometryUDT()), + StructField("geom4", GeometryUDT()), + StructField("geom5", GeometryUDT()), + StructField("geom6", GeometryUDT()))) val geom1 = Constructors.geomFromWKT("POLYGON ((0 0, 1 0, 0.5 0.5, 1 1, 0 1, 0 0))", 1000) val geom2 = Constructors.geomFromWKT("MULTILINESTRING ((0 0, 0 1), (0 1, 1 1), (1 1, 1 0))", 1000) diff --git a/spark/common/src/test/scala/org/apache/sedona/sql/SpatialJoinSuite.scala b/spark/common/src/test/scala/org/apache/sedona/sql/SpatialJoinSuite.scala index 9f4fd30f95b..56bd3bf23f1 100644 --- a/spark/common/src/test/scala/org/apache/sedona/sql/SpatialJoinSuite.scala +++ b/spark/common/src/test/scala/org/apache/sedona/sql/SpatialJoinSuite.scala @@ -308,7 +308,7 @@ class SpatialJoinSuite extends TestBaseScala with TableDrivenPropertyChecks { val emptyRdd = sparkSession.sparkContext.emptyRDD[Row] val emptyDf = sparkSession.createDataFrame( emptyRdd, - StructType(Seq(StructField("id", IntegerType), StructField("geom", GeometryUDT)))) + StructType(Seq(StructField("id", IntegerType), StructField("geom", GeometryUDT())))) df1.createOrReplaceTempView("df1") df2.createOrReplaceTempView("df2") emptyDf.createOrReplaceTempView("dfEmpty") diff --git a/spark/common/src/test/scala/org/apache/sedona/sql/TestBaseScala.scala b/spark/common/src/test/scala/org/apache/sedona/sql/TestBaseScala.scala index 2f57d310ae6..0068da6060f 100644 --- a/spark/common/src/test/scala/org/apache/sedona/sql/TestBaseScala.scala +++ b/spark/common/src/test/scala/org/apache/sedona/sql/TestBaseScala.scala @@ -48,7 
+48,10 @@ trait TestBaseScala extends FunSpec with BeforeAndAfterAll { def defaultSparkConfig: Map[String, String] = Map( "spark.sql.warehouse.dir" -> (System.getProperty("user.dir") + "/target/"), "sedona.join.autoBroadcastJoinThreshold" -> "-1", - "spark.kryoserializer.buffer.max" -> "64m") + "spark.kryoserializer.buffer.max" -> "64m", + // Disable Spark 4.1+ native geospatial functions that shadow Sedona's ST functions. + // This config is ignored on Spark versions that don't have it. + "spark.sql.geospatial.enabled" -> "false") // Method to be overridden by subclasses to provide additional configurations def sparkConfig: Map[String, String] = defaultSparkConfig diff --git a/spark/common/src/test/scala/org/apache/sedona/sql/adapterTestScala.scala b/spark/common/src/test/scala/org/apache/sedona/sql/adapterTestScala.scala index 0d267e62567..2ca9f359543 100644 --- a/spark/common/src/test/scala/org/apache/sedona/sql/adapterTestScala.scala +++ b/spark/common/src/test/scala/org/apache/sedona/sql/adapterTestScala.scala @@ -319,7 +319,7 @@ class adapterTestScala extends TestBaseScala with GivenWhenThen { // Tweak the column names to camelCase to ensure it also renames val schema = StructType( Array( - StructField("leftGeometry", GeometryUDT, nullable = true), + StructField("leftGeometry", GeometryUDT(), nullable = true), StructField("exampleText", StringType, nullable = true), StructField("exampleDouble", DoubleType, nullable = true), StructField("exampleInt", IntegerType, nullable = true))) @@ -408,7 +408,7 @@ class adapterTestScala extends TestBaseScala with GivenWhenThen { // Tweak the column names to camelCase to ensure it also renames val schema = StructType( Array( - StructField("leftGeometry", GeometryUDT, nullable = true), + StructField("leftGeometry", GeometryUDT(), nullable = true), StructField("exampleText", StringType, nullable = true), StructField("exampleFloat", FloatType, nullable = true), StructField("exampleDouble", DoubleType, nullable = true), @@ -424,7 +424,7 @@ class adapterTestScala extends TestBaseScala with GivenWhenThen { StructField("structText", StringType, nullable = true), StructField("structInt", IntegerType, nullable = true), StructField("structBool", BooleanType, nullable = true)))), - StructField("rightGeometry", GeometryUDT, nullable = true), + StructField("rightGeometry", GeometryUDT(), nullable = true), // We have to include a column for right user data (even though there is none) // since there is no way to distinguish between no data and nullable data StructField("rightUserData", StringType, nullable = true))) @@ -493,11 +493,11 @@ class adapterTestScala extends TestBaseScala with GivenWhenThen { // Convert to DataFrame val schema = StructType( Array( - StructField("leftgeometry", GeometryUDT, nullable = true), + StructField("leftgeometry", GeometryUDT(), nullable = true), StructField("exampletext", StringType, nullable = true), StructField("exampledouble", StringType, nullable = true), StructField("exampleint", StringType, nullable = true), - StructField("rightgeometry", GeometryUDT, nullable = true), + StructField("rightgeometry", GeometryUDT(), nullable = true), StructField("userdata", StringType, nullable = true))) val joinResultDf = Adapter.toDf(joinResultPairRDD, schema, sparkSession) val resultWithoutSchema = Adapter.toDf( diff --git a/spark/common/src/test/scala/org/apache/sedona/sql/geoparquetIOTests.scala b/spark/common/src/test/scala/org/apache/sedona/sql/geoparquetIOTests.scala index c3ef8dd89ef..7628c9a6daa 100644 --- 
a/spark/common/src/test/scala/org/apache/sedona/sql/geoparquetIOTests.scala +++ b/spark/common/src/test/scala/org/apache/sedona/sql/geoparquetIOTests.scala @@ -169,7 +169,7 @@ class geoparquetIOTests extends TestBaseScala with BeforeAndAfterAll { val geomTypes = (geo \ "columns" \ "geometry" \ "geometry_types").extract[Seq[String]] assert(geomTypes.nonEmpty) val sparkSqlRowMetadata = metadata.get(ParquetReadSupport.SPARK_METADATA_KEY) - assert(!sparkSqlRowMetadata.contains("GeometryUDT")) + assert(!sparkSqlRowMetadata.contains("GeometryUDT()")) } } it("GEOPARQUET Test example-1.1.0.parquet") { @@ -206,8 +206,8 @@ class geoparquetIOTests extends TestBaseScala with BeforeAndAfterAll { val schema = StructType( Seq( StructField("id", IntegerType, nullable = false), - StructField("g0", GeometryUDT, nullable = false), - StructField("g1", GeometryUDT, nullable = false))) + StructField("g0", GeometryUDT(), nullable = false), + StructField("g1", GeometryUDT(), nullable = false))) val df = sparkSession.createDataFrame(testData.asJava, schema).repartition(1) val geoParquetSavePath = geoparquetoutputlocation + "/multi_geoms.parquet" df.write.format("geoparquet").mode("overwrite").save(geoParquetSavePath) @@ -241,7 +241,7 @@ class geoparquetIOTests extends TestBaseScala with BeforeAndAfterAll { val schema = StructType( Seq( StructField("id", IntegerType, nullable = false), - StructField("g", GeometryUDT, nullable = false))) + StructField("g", GeometryUDT(), nullable = false))) val df = sparkSession.createDataFrame(Collections.emptyList[Row](), schema) val geoParquetSavePath = geoparquetoutputlocation + "/empty.parquet" df.write.format("geoparquet").mode("overwrite").save(geoParquetSavePath) @@ -262,7 +262,7 @@ class geoparquetIOTests extends TestBaseScala with BeforeAndAfterAll { val schema = StructType( Seq( StructField("id", IntegerType, nullable = false), - StructField("geom_column", GeometryUDT, nullable = false))) + StructField("geom_column", GeometryUDT(), nullable = false))) val df = sparkSession.createDataFrame(Collections.emptyList[Row](), schema) val geoParquetSavePath = geoparquetoutputlocation + "/snake_case_column_name.parquet" df.write.format("geoparquet").mode("overwrite").save(geoParquetSavePath) @@ -277,7 +277,7 @@ class geoparquetIOTests extends TestBaseScala with BeforeAndAfterAll { val schema = StructType( Seq( StructField("id", IntegerType, nullable = false), - StructField("geomColumn", GeometryUDT, nullable = false))) + StructField("geomColumn", GeometryUDT(), nullable = false))) val df = sparkSession.createDataFrame(Collections.emptyList[Row](), schema) val geoParquetSavePath = geoparquetoutputlocation + "/camel_case_column_name.parquet" df.write.format("geoparquet").mode("overwrite").save(geoParquetSavePath) @@ -400,8 +400,8 @@ class geoparquetIOTests extends TestBaseScala with BeforeAndAfterAll { val schema = StructType( Seq( StructField("id", IntegerType, nullable = false), - StructField("g0", GeometryUDT, nullable = false), - StructField("g1", GeometryUDT, nullable = false))) + StructField("g0", GeometryUDT(), nullable = false), + StructField("g1", GeometryUDT(), nullable = false))) val df = sparkSession.createDataFrame(testData.asJava, schema).repartition(1) val projjson0 = diff --git a/spark/pom.xml b/spark/pom.xml index fe35e36441f..f64cc6c74ae 100644 --- a/spark/pom.xml +++ b/spark/pom.xml @@ -50,6 +50,7 @@ spark-3.5 spark-4.0 + spark-4.1 diff --git a/spark/spark-3.4/src/main/scala/org/apache/sedona/sql/datasources/shapefile/ShapefileUtils.scala 
b/spark/spark-3.4/src/main/scala/org/apache/sedona/sql/datasources/shapefile/ShapefileUtils.scala index 12238870be1..efd6098d518 100644 --- a/spark/spark-3.4/src/main/scala/org/apache/sedona/sql/datasources/shapefile/ShapefileUtils.scala +++ b/spark/spark-3.4/src/main/scala/org/apache/sedona/sql/datasources/shapefile/ShapefileUtils.scala @@ -141,9 +141,9 @@ object ShapefileUtils { throw new IllegalArgumentException(s"geometry.name and key.name cannot be the same") } StructType( - Seq(StructField(options.geometryFieldName, GeometryUDT), StructField(name, LongType))) + Seq(StructField(options.geometryFieldName, GeometryUDT()), StructField(name, LongType))) case _ => - StructType(StructField(options.geometryFieldName, GeometryUDT) :: Nil) + StructType(StructField(options.geometryFieldName, GeometryUDT()) :: Nil) } } diff --git a/spark/spark-3.4/src/main/scala/org/apache/sedona/sql/parser/SedonaSqlAstBuilder.scala b/spark/spark-3.4/src/main/scala/org/apache/sedona/sql/parser/SedonaSqlAstBuilder.scala index 2bdd92bd64a..b56ed11c875 100644 --- a/spark/spark-3.4/src/main/scala/org/apache/sedona/sql/parser/SedonaSqlAstBuilder.scala +++ b/spark/spark-3.4/src/main/scala/org/apache/sedona/sql/parser/SedonaSqlAstBuilder.scala @@ -32,7 +32,7 @@ class SedonaSqlAstBuilder extends SparkSqlAstBuilder { */ override def visitPrimitiveDataType(ctx: PrimitiveDataTypeContext): DataType = { ctx.getText.toUpperCase() match { - case "GEOMETRY" => GeometryUDT + case "GEOMETRY" => GeometryUDT() case _ => super.visitPrimitiveDataType(ctx) } } diff --git a/spark/spark-3.4/src/test/scala/org/apache/sedona/sql/GeoPackageReaderTest.scala b/spark/spark-3.4/src/test/scala/org/apache/sedona/sql/GeoPackageReaderTest.scala index 1e4b071361b..166d8c48dba 100644 --- a/spark/spark-3.4/src/test/scala/org/apache/sedona/sql/GeoPackageReaderTest.scala +++ b/spark/spark-3.4/src/test/scala/org/apache/sedona/sql/GeoPackageReaderTest.scala @@ -44,7 +44,7 @@ class GeoPackageReaderTest extends TestBaseScala with Matchers { val expectedFeatureSchema = StructType( Seq( StructField("id", IntegerType, true), - StructField("geometry", GeometryUDT, true), + StructField("geometry", GeometryUDT(), true), StructField("text", StringType, true), StructField("real", DoubleType, true), StructField("boolean", BooleanType, true), diff --git a/spark/spark-3.4/src/test/scala/org/apache/sedona/sql/GeoParquetMetadataTests.scala b/spark/spark-3.4/src/test/scala/org/apache/sedona/sql/GeoParquetMetadataTests.scala index 421890c7001..01306c1b452 100644 --- a/spark/spark-3.4/src/test/scala/org/apache/sedona/sql/GeoParquetMetadataTests.scala +++ b/spark/spark-3.4/src/test/scala/org/apache/sedona/sql/GeoParquetMetadataTests.scala @@ -120,8 +120,8 @@ class GeoParquetMetadataTests extends TestBaseScala with BeforeAndAfterAll { val schema = StructType( Seq( StructField("id", IntegerType, nullable = false), - StructField("geom_column_1", GeometryUDT, nullable = false), - StructField("geomColumn2", GeometryUDT, nullable = false))) + StructField("geom_column_1", GeometryUDT(), nullable = false), + StructField("geomColumn2", GeometryUDT(), nullable = false))) val df = sparkSession.createDataFrame(Collections.emptyList[Row](), schema) val geoParquetSavePath = geoparquetoutputlocation + "/gp_column_name_styles.parquet" df.write.format("geoparquet").mode("overwrite").save(geoParquetSavePath) diff --git a/spark/spark-3.4/src/test/scala/org/apache/sedona/sql/ShapefileTests.scala b/spark/spark-3.4/src/test/scala/org/apache/sedona/sql/ShapefileTests.scala index 
bb53131475f..c2e88c469be 100644 --- a/spark/spark-3.4/src/test/scala/org/apache/sedona/sql/ShapefileTests.scala +++ b/spark/spark-3.4/src/test/scala/org/apache/sedona/sql/ShapefileTests.scala @@ -698,7 +698,7 @@ class ShapefileTests extends TestBaseScala with BeforeAndAfterAll { Seq( StructField("osm_id", StringType), StructField("code2", LongType), - StructField("geometry", GeometryUDT))) + StructField("geometry", GeometryUDT()))) val shapefileDf = sparkSession.read .format("shapefile") .schema(customSchema) diff --git a/spark/spark-3.5/src/main/scala/org/apache/sedona/sql/datasources/shapefile/ShapefileUtils.scala b/spark/spark-3.5/src/main/scala/org/apache/sedona/sql/datasources/shapefile/ShapefileUtils.scala index 04ac3bdff9b..fd6d1e83827 100644 --- a/spark/spark-3.5/src/main/scala/org/apache/sedona/sql/datasources/shapefile/ShapefileUtils.scala +++ b/spark/spark-3.5/src/main/scala/org/apache/sedona/sql/datasources/shapefile/ShapefileUtils.scala @@ -141,9 +141,9 @@ object ShapefileUtils { throw new IllegalArgumentException(s"geometry.name and key.name cannot be the same") } StructType( - Seq(StructField(options.geometryFieldName, GeometryUDT), StructField(name, LongType))) + Seq(StructField(options.geometryFieldName, GeometryUDT()), StructField(name, LongType))) case _ => - StructType(StructField(options.geometryFieldName, GeometryUDT) :: Nil) + StructType(StructField(options.geometryFieldName, GeometryUDT()) :: Nil) } } diff --git a/spark/spark-3.5/src/main/scala/org/apache/sedona/sql/parser/SedonaSqlAstBuilder.scala b/spark/spark-3.5/src/main/scala/org/apache/sedona/sql/parser/SedonaSqlAstBuilder.scala index 2bdd92bd64a..b56ed11c875 100644 --- a/spark/spark-3.5/src/main/scala/org/apache/sedona/sql/parser/SedonaSqlAstBuilder.scala +++ b/spark/spark-3.5/src/main/scala/org/apache/sedona/sql/parser/SedonaSqlAstBuilder.scala @@ -32,7 +32,7 @@ class SedonaSqlAstBuilder extends SparkSqlAstBuilder { */ override def visitPrimitiveDataType(ctx: PrimitiveDataTypeContext): DataType = { ctx.getText.toUpperCase() match { - case "GEOMETRY" => GeometryUDT + case "GEOMETRY" => GeometryUDT() case _ => super.visitPrimitiveDataType(ctx) } } diff --git a/spark/spark-3.5/src/test/scala/org/apache/sedona/sql/GeoPackageReaderTest.scala b/spark/spark-3.5/src/test/scala/org/apache/sedona/sql/GeoPackageReaderTest.scala index 6d9f41bf4e3..66fa147bc58 100644 --- a/spark/spark-3.5/src/test/scala/org/apache/sedona/sql/GeoPackageReaderTest.scala +++ b/spark/spark-3.5/src/test/scala/org/apache/sedona/sql/GeoPackageReaderTest.scala @@ -44,7 +44,7 @@ class GeoPackageReaderTest extends TestBaseScala with Matchers { val expectedFeatureSchema = StructType( Seq( StructField("id", IntegerType, true), - StructField("geometry", GeometryUDT, true), + StructField("geometry", GeometryUDT(), true), StructField("text", StringType, true), StructField("real", DoubleType, true), StructField("boolean", BooleanType, true), diff --git a/spark/spark-3.5/src/test/scala/org/apache/sedona/sql/GeoParquetMetadataTests.scala b/spark/spark-3.5/src/test/scala/org/apache/sedona/sql/GeoParquetMetadataTests.scala index 421890c7001..01306c1b452 100644 --- a/spark/spark-3.5/src/test/scala/org/apache/sedona/sql/GeoParquetMetadataTests.scala +++ b/spark/spark-3.5/src/test/scala/org/apache/sedona/sql/GeoParquetMetadataTests.scala @@ -120,8 +120,8 @@ class GeoParquetMetadataTests extends TestBaseScala with BeforeAndAfterAll { val schema = StructType( Seq( StructField("id", IntegerType, nullable = false), - StructField("geom_column_1", GeometryUDT, 
nullable = false), - StructField("geomColumn2", GeometryUDT, nullable = false))) + StructField("geom_column_1", GeometryUDT(), nullable = false), + StructField("geomColumn2", GeometryUDT(), nullable = false))) val df = sparkSession.createDataFrame(Collections.emptyList[Row](), schema) val geoParquetSavePath = geoparquetoutputlocation + "/gp_column_name_styles.parquet" df.write.format("geoparquet").mode("overwrite").save(geoParquetSavePath) diff --git a/spark/spark-3.5/src/test/scala/org/apache/sedona/sql/ShapefileTests.scala b/spark/spark-3.5/src/test/scala/org/apache/sedona/sql/ShapefileTests.scala index a0cadc89787..275bc3282f9 100644 --- a/spark/spark-3.5/src/test/scala/org/apache/sedona/sql/ShapefileTests.scala +++ b/spark/spark-3.5/src/test/scala/org/apache/sedona/sql/ShapefileTests.scala @@ -710,7 +710,7 @@ class ShapefileTests extends TestBaseScala with BeforeAndAfterAll { Seq( StructField("osm_id", StringType), StructField("code2", LongType), - StructField("geometry", GeometryUDT))) + StructField("geometry", GeometryUDT()))) val shapefileDf = sparkSession.read .format("shapefile") .schema(customSchema) diff --git a/spark/spark-3.5/src/test/scala/org/apache/spark/sql/udf/TestScalarPandasUDF.scala b/spark/spark-3.5/src/test/scala/org/apache/spark/sql/udf/TestScalarPandasUDF.scala index c0a2d8f260d..62733169288 100644 --- a/spark/spark-3.5/src/test/scala/org/apache/spark/sql/udf/TestScalarPandasUDF.scala +++ b/spark/spark-3.5/src/test/scala/org/apache/spark/sql/udf/TestScalarPandasUDF.scala @@ -116,7 +116,7 @@ object ScalarUDF { pythonVer = pythonVer, broadcastVars = List.empty[Broadcast[PythonBroadcast]].asJava, accumulator = null), - dataType = GeometryUDT, + dataType = GeometryUDT(), pythonEvalType = UDF.PythonEvalType.SQL_SCALAR_SEDONA_UDF, udfDeterministic = true) } diff --git a/spark/spark-4.0/src/main/scala/org/apache/sedona/sql/datasources/shapefile/ShapefileUtils.scala b/spark/spark-4.0/src/main/scala/org/apache/sedona/sql/datasources/shapefile/ShapefileUtils.scala index 04ac3bdff9b..fd6d1e83827 100644 --- a/spark/spark-4.0/src/main/scala/org/apache/sedona/sql/datasources/shapefile/ShapefileUtils.scala +++ b/spark/spark-4.0/src/main/scala/org/apache/sedona/sql/datasources/shapefile/ShapefileUtils.scala @@ -141,9 +141,9 @@ object ShapefileUtils { throw new IllegalArgumentException(s"geometry.name and key.name cannot be the same") } StructType( - Seq(StructField(options.geometryFieldName, GeometryUDT), StructField(name, LongType))) + Seq(StructField(options.geometryFieldName, GeometryUDT()), StructField(name, LongType))) case _ => - StructType(StructField(options.geometryFieldName, GeometryUDT) :: Nil) + StructType(StructField(options.geometryFieldName, GeometryUDT()) :: Nil) } } diff --git a/spark/spark-4.0/src/main/scala/org/apache/sedona/sql/parser/SedonaSqlAstBuilder.scala b/spark/spark-4.0/src/main/scala/org/apache/sedona/sql/parser/SedonaSqlAstBuilder.scala index 2bdd92bd64a..b56ed11c875 100644 --- a/spark/spark-4.0/src/main/scala/org/apache/sedona/sql/parser/SedonaSqlAstBuilder.scala +++ b/spark/spark-4.0/src/main/scala/org/apache/sedona/sql/parser/SedonaSqlAstBuilder.scala @@ -32,7 +32,7 @@ class SedonaSqlAstBuilder extends SparkSqlAstBuilder { */ override def visitPrimitiveDataType(ctx: PrimitiveDataTypeContext): DataType = { ctx.getText.toUpperCase() match { - case "GEOMETRY" => GeometryUDT + case "GEOMETRY" => GeometryUDT() case _ => super.visitPrimitiveDataType(ctx) } } diff --git a/spark/spark-4.0/src/test/scala/org/apache/sedona/sql/GeoPackageReaderTest.scala 
b/spark/spark-4.0/src/test/scala/org/apache/sedona/sql/GeoPackageReaderTest.scala index 6d9f41bf4e3..66fa147bc58 100644 --- a/spark/spark-4.0/src/test/scala/org/apache/sedona/sql/GeoPackageReaderTest.scala +++ b/spark/spark-4.0/src/test/scala/org/apache/sedona/sql/GeoPackageReaderTest.scala @@ -44,7 +44,7 @@ class GeoPackageReaderTest extends TestBaseScala with Matchers { val expectedFeatureSchema = StructType( Seq( StructField("id", IntegerType, true), - StructField("geometry", GeometryUDT, true), + StructField("geometry", GeometryUDT(), true), StructField("text", StringType, true), StructField("real", DoubleType, true), StructField("boolean", BooleanType, true), diff --git a/spark/spark-4.0/src/test/scala/org/apache/sedona/sql/GeoParquetMetadataTests.scala b/spark/spark-4.0/src/test/scala/org/apache/sedona/sql/GeoParquetMetadataTests.scala index 421890c7001..01306c1b452 100644 --- a/spark/spark-4.0/src/test/scala/org/apache/sedona/sql/GeoParquetMetadataTests.scala +++ b/spark/spark-4.0/src/test/scala/org/apache/sedona/sql/GeoParquetMetadataTests.scala @@ -120,8 +120,8 @@ class GeoParquetMetadataTests extends TestBaseScala with BeforeAndAfterAll { val schema = StructType( Seq( StructField("id", IntegerType, nullable = false), - StructField("geom_column_1", GeometryUDT, nullable = false), - StructField("geomColumn2", GeometryUDT, nullable = false))) + StructField("geom_column_1", GeometryUDT(), nullable = false), + StructField("geomColumn2", GeometryUDT(), nullable = false))) val df = sparkSession.createDataFrame(Collections.emptyList[Row](), schema) val geoParquetSavePath = geoparquetoutputlocation + "/gp_column_name_styles.parquet" df.write.format("geoparquet").mode("overwrite").save(geoParquetSavePath) diff --git a/spark/spark-4.0/src/test/scala/org/apache/sedona/sql/ShapefileTests.scala b/spark/spark-4.0/src/test/scala/org/apache/sedona/sql/ShapefileTests.scala index a0cadc89787..275bc3282f9 100644 --- a/spark/spark-4.0/src/test/scala/org/apache/sedona/sql/ShapefileTests.scala +++ b/spark/spark-4.0/src/test/scala/org/apache/sedona/sql/ShapefileTests.scala @@ -710,7 +710,7 @@ class ShapefileTests extends TestBaseScala with BeforeAndAfterAll { Seq( StructField("osm_id", StringType), StructField("code2", LongType), - StructField("geometry", GeometryUDT))) + StructField("geometry", GeometryUDT()))) val shapefileDf = sparkSession.read .format("shapefile") .schema(customSchema) diff --git a/spark/spark-4.0/src/test/scala/org/apache/spark/sql/udf/TestScalarPandasUDF.scala b/spark/spark-4.0/src/test/scala/org/apache/spark/sql/udf/TestScalarPandasUDF.scala index c0a2d8f260d..62733169288 100644 --- a/spark/spark-4.0/src/test/scala/org/apache/spark/sql/udf/TestScalarPandasUDF.scala +++ b/spark/spark-4.0/src/test/scala/org/apache/spark/sql/udf/TestScalarPandasUDF.scala @@ -116,7 +116,7 @@ object ScalarUDF { pythonVer = pythonVer, broadcastVars = List.empty[Broadcast[PythonBroadcast]].asJava, accumulator = null), - dataType = GeometryUDT, + dataType = GeometryUDT(), pythonEvalType = UDF.PythonEvalType.SQL_SCALAR_SEDONA_UDF, udfDeterministic = true) } diff --git a/spark/spark-4.1/.gitignore b/spark/spark-4.1/.gitignore new file mode 100644 index 00000000000..f34cc0c65b4 --- /dev/null +++ b/spark/spark-4.1/.gitignore @@ -0,0 +1,29 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. 
The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +/target/ +/.settings/ +/.classpath +/.project +/dependency-reduced-pom.xml +/doc/ +/.idea/ +*.iml +/latest/ +/spark-warehouse/ +/metastore_db/ +*.log diff --git a/spark/spark-4.1/pom.xml b/spark/spark-4.1/pom.xml new file mode 100644 index 00000000000..e4d4b51f0e9 --- /dev/null +++ b/spark/spark-4.1/pom.xml @@ -0,0 +1,185 @@ + + + + 4.0.0 + + org.apache.sedona + sedona-spark-parent-${spark.compat.version}_${scala.compat.version} + 1.9.0-SNAPSHOT + ../pom.xml + + sedona-spark-4.1_${scala.compat.version} + + ${project.groupId}:${project.artifactId} + A cluster computing system for processing large-scale spatial data: SQL API for Spark 4.1. + http://sedona.apache.org/ + jar + + + false + + + + + org.apache.sedona + sedona-common + ${project.version} + + + com.fasterxml.jackson.core + * + + + it.geosolutions.jaiext.jiffle + * + + + org.codehaus.janino + * + + + + + org.apache.sedona + sedona-spark-common-${spark.compat.version}_${scala.compat.version} + ${project.version} + + + + org.apache.spark + spark-core_${scala.compat.version} + + + org.apache.spark + spark-sql_${scala.compat.version} + + + org.apache.hadoop + hadoop-client + + + org.apache.logging.log4j + log4j-1.2-api + + + org.geotools + gt-main + + + org.geotools + gt-referencing + + + org.geotools + gt-epsg-hsql + + + org.geotools + gt-geotiff + + + org.geotools + gt-coverage + + + org.geotools + gt-arcgrid + + + org.locationtech.jts + jts-core + + + org.wololo + jts2geojson + + + com.fasterxml.jackson.core + * + + + + + org.scala-lang + scala-library + + + org.scala-lang.modules + scala-collection-compat_${scala.compat.version} + + + org.scalatest + scalatest_${scala.compat.version} + + + org.mockito + mockito-inline + + + org.testcontainers + testcontainers + 2.0.2 + test + + + org.testcontainers + testcontainers-minio + 2.0.2 + test + + + io.minio + minio + + + com.squareup.okhttp3 + okhttp + + + org.apache.hadoop + hadoop-aws + ${hadoop.version} + test + + + org.apache.hadoop + hadoop-client-api + ${hadoop.version} + test + + + + src/main/scala + + + net.alchim31.maven + scala-maven-plugin + + + org.scalatest + scalatest-maven-plugin + + + org.scalastyle + scalastyle-maven-plugin + + + + diff --git a/spark/spark-4.1/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister b/spark/spark-4.1/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister new file mode 100644 index 00000000000..ae1de3d8bd2 --- /dev/null +++ b/spark/spark-4.1/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister @@ -0,0 +1,3 @@ +org.apache.spark.sql.execution.datasources.v2.geoparquet.metadata.GeoParquetMetadataDataSource +org.apache.sedona.sql.datasources.shapefile.ShapefileDataSource +org.apache.sedona.sql.datasources.geopackage.GeoPackageDataSource diff --git a/spark/spark-4.1/src/main/scala/org/apache/sedona/sql/datasources/geopackage/GeoPackageDataSource.scala 
b/spark/spark-4.1/src/main/scala/org/apache/sedona/sql/datasources/geopackage/GeoPackageDataSource.scala new file mode 100644 index 00000000000..11f2db38e84 --- /dev/null +++ b/spark/spark-4.1/src/main/scala/org/apache/sedona/sql/datasources/geopackage/GeoPackageDataSource.scala @@ -0,0 +1,73 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.sedona.sql.datasources.geopackage + +import org.apache.hadoop.fs.Path +import org.apache.sedona.sql.datasources.geopackage.model.GeoPackageOptions +import org.apache.spark.sql.connector.catalog.Table +import org.apache.spark.sql.execution.datasources.FileFormat +import org.apache.spark.sql.execution.datasources.v2.FileDataSourceV2 +import org.apache.spark.sql.sources.DataSourceRegister +import org.apache.spark.sql.util.CaseInsensitiveStringMap + +import java.util.Locale +import scala.jdk.CollectionConverters._ +import scala.util.Try + +class GeoPackageDataSource extends FileDataSourceV2 with DataSourceRegister { + + override def fallbackFileFormat: Class[_ <: FileFormat] = { + null + } + + override protected def getTable(options: CaseInsensitiveStringMap): Table = { + GeoPackageTable( + "", + sparkSession, + options, + getPaths(options), + None, + fallbackFileFormat, + getLoadOptions(options)) + } + + private def getLoadOptions(options: CaseInsensitiveStringMap): GeoPackageOptions = { + val path = options.get("path") + if (path.isEmpty) { + throw new IllegalArgumentException("GeoPackage path is not specified") + } + + val showMetadata = options.getBoolean("showMetadata", false) + val maybeTableName = options.get("tableName") + + if (!showMetadata && maybeTableName == null) { + throw new IllegalArgumentException("Table name is not specified") + } + + val tableName = if (showMetadata) { + "gpkg_contents" + } else { + maybeTableName + } + + GeoPackageOptions(tableName = tableName, showMetadata = showMetadata) + } + + override def shortName(): String = "geopackage" +} diff --git a/spark/spark-4.1/src/main/scala/org/apache/sedona/sql/datasources/geopackage/GeoPackagePartitionReader.scala b/spark/spark-4.1/src/main/scala/org/apache/sedona/sql/datasources/geopackage/GeoPackagePartitionReader.scala new file mode 100644 index 00000000000..4e59163922d --- /dev/null +++ b/spark/spark-4.1/src/main/scala/org/apache/sedona/sql/datasources/geopackage/GeoPackagePartitionReader.scala @@ -0,0 +1,107 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.sedona.sql.datasources.geopackage + +import org.apache.hadoop.fs.Path +import org.apache.sedona.sql.datasources.geopackage.connection.{FileSystemUtils, GeoPackageConnectionManager} +import org.apache.sedona.sql.datasources.geopackage.model.TableType.{FEATURES, METADATA, TILES, UNKNOWN} +import org.apache.sedona.sql.datasources.geopackage.model.{GeoPackageReadOptions, PartitionOptions, TileRowMetadata} +import org.apache.sedona.sql.datasources.geopackage.transform.ValuesMapper +import org.apache.spark.broadcast.Broadcast +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.connector.read.PartitionReader +import org.apache.spark.util.SerializableConfiguration + +import java.io.File +import java.sql.ResultSet + +case class GeoPackagePartitionReader( + var rs: ResultSet, + options: GeoPackageReadOptions, + broadcastedConf: Broadcast[SerializableConfiguration], + var currentTempFile: File, + copying: Boolean = false) + extends PartitionReader[InternalRow] { + + private var values: Seq[Any] = Seq.empty + private var currentFile = options.currentFile + private val partitionedFiles = options.partitionedFiles + + override def next(): Boolean = { + if (rs.next()) { + values = ValuesMapper.mapValues(adjustPartitionOptions, rs) + return true + } + + partitionedFiles.remove(currentFile) + + if (partitionedFiles.isEmpty) { + return false + } + + rs.close() + + currentFile = partitionedFiles.head + val (tempFile, _) = FileSystemUtils.copyToLocal( + options = broadcastedConf.value.value, + file = new Path(currentFile.filePath.toString())) + + if (copying) { + currentTempFile.deleteOnExit() + } + + currentTempFile = tempFile + + rs = GeoPackageConnectionManager.getTableCursor(currentTempFile.getPath, options.tableName) + + if (!rs.next()) { + return false + } + + values = ValuesMapper.mapValues(adjustPartitionOptions, rs) + + true + } + + private def adjustPartitionOptions: PartitionOptions = { + options.partitionOptions.tableType match { + case FEATURES | METADATA => options.partitionOptions + case TILES => + val tileRowMetadata = TileRowMetadata( + zoomLevel = rs.getInt("zoom_level"), + tileColumn = rs.getInt("tile_column"), + tileRow = rs.getInt("tile_row")) + + options.partitionOptions.withTileRowMetadata(tileRowMetadata) + case UNKNOWN => options.partitionOptions + } + + } + + override def get(): InternalRow = { + InternalRow.fromSeq(values) + } + + override def close(): Unit = { + rs.close() + if (copying) { + options.tempFile.delete() + } + } +} diff --git a/spark/spark-4.1/src/main/scala/org/apache/sedona/sql/datasources/geopackage/GeoPackagePartitionReaderFactory.scala b/spark/spark-4.1/src/main/scala/org/apache/sedona/sql/datasources/geopackage/GeoPackagePartitionReaderFactory.scala new file mode 100644 index 00000000000..b1d38996b04 --- /dev/null +++ b/spark/spark-4.1/src/main/scala/org/apache/sedona/sql/datasources/geopackage/GeoPackagePartitionReaderFactory.scala @@ -0,0 +1,88 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.sedona.sql.datasources.geopackage + +import org.apache.hadoop.fs.Path +import org.apache.sedona.sql.datasources.geopackage.connection.{FileSystemUtils, GeoPackageConnectionManager} +import org.apache.sedona.sql.datasources.geopackage.model.TableType.TILES +import org.apache.sedona.sql.datasources.geopackage.model.{GeoPackageOptions, GeoPackageReadOptions, PartitionOptions, TableType} +import org.apache.spark.broadcast.Broadcast +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.connector.read.{InputPartition, PartitionReader, PartitionReaderFactory} +import org.apache.spark.sql.execution.datasources.FilePartition +import org.apache.spark.sql.types.StructType +import org.apache.spark.util.SerializableConfiguration + +case class GeoPackagePartitionReaderFactory( + sparkSession: SparkSession, + broadcastedConf: Broadcast[SerializableConfiguration], + loadOptions: GeoPackageOptions, + dataSchema: StructType) + extends PartitionReaderFactory { + + override def createReader(partition: InputPartition): PartitionReader[InternalRow] = { + val partitionFiles = partition match { + case filePartition: FilePartition => filePartition.files + case _ => + throw new IllegalArgumentException( + s"Unexpected partition type: ${partition.getClass.getCanonicalName}") + } + + val (tempFile, copied) = FileSystemUtils.copyToLocal( + options = broadcastedConf.value.value, + file = new Path(partitionFiles.head.filePath.toString())) + + val tableType = if (loadOptions.showMetadata) { + TableType.METADATA + } else { + GeoPackageConnectionManager.findFeatureMetadata(tempFile.getPath, loadOptions.tableName) + } + + val rs = + GeoPackageConnectionManager.getTableCursor(tempFile.getAbsolutePath, loadOptions.tableName) + + val schema = GeoPackageConnectionManager.getSchema(tempFile.getPath, loadOptions.tableName) + + if (StructType(schema.map(_.toStructField(tableType))) != dataSchema) { + throw new IllegalArgumentException( + s"Schema mismatch: expected $dataSchema, got ${StructType(schema.map(_.toStructField(tableType)))}") + } + + val tileMetadata = tableType match { + case TILES => + Some( + GeoPackageConnectionManager.findTilesMetadata(tempFile.getPath, loadOptions.tableName)) + case _ => None + } + + GeoPackagePartitionReader( + rs = rs, + options = GeoPackageReadOptions( + tableName = loadOptions.tableName, + tempFile = tempFile, + partitionOptions = + PartitionOptions(tableType = tableType, columns = schema, tile = tileMetadata), + partitionedFiles = scala.collection.mutable.HashSet(partitionFiles: _*), + currentFile = partitionFiles.head), + broadcastedConf = broadcastedConf, + currentTempFile = tempFile, + copying = copied) + } +} diff --git 
a/spark/spark-4.1/src/main/scala/org/apache/sedona/sql/datasources/geopackage/GeoPackageScan.scala b/spark/spark-4.1/src/main/scala/org/apache/sedona/sql/datasources/geopackage/GeoPackageScan.scala new file mode 100644 index 00000000000..768afd79173 --- /dev/null +++ b/spark/spark-4.1/src/main/scala/org/apache/sedona/sql/datasources/geopackage/GeoPackageScan.scala @@ -0,0 +1,59 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.sedona.sql.datasources.geopackage + +import org.apache.sedona.sql.datasources.geopackage.model.GeoPackageOptions +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.catalyst.expressions.Expression +import org.apache.spark.sql.connector.read.PartitionReaderFactory +import org.apache.spark.sql.execution.datasources.PartitioningAwareFileIndex +import org.apache.spark.sql.execution.datasources.v2.FileScan +import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.util.CaseInsensitiveStringMap +import org.apache.spark.util.SerializableConfiguration + +import scala.jdk.CollectionConverters._ + +case class GeoPackageScan( + dataSchema: StructType, + sparkSession: SparkSession, + fileIndex: PartitioningAwareFileIndex, + readDataSchema: StructType, + readPartitionSchema: StructType, + options: CaseInsensitiveStringMap, + loadOptions: GeoPackageOptions) + extends FileScan { + + override def partitionFilters: Seq[Expression] = { + Seq.empty + } + + override def dataFilters: Seq[Expression] = { + Seq.empty + } + + override def createReaderFactory(): PartitionReaderFactory = { + val caseSensitiveMap = options.asScala.toMap + val hadoopConf = sparkSession.sessionState.newHadoopConfWithOptions(caseSensitiveMap) + val broadcastedConf = + sparkSession.sparkContext.broadcast(new SerializableConfiguration(hadoopConf)) + + GeoPackagePartitionReaderFactory(sparkSession, broadcastedConf, loadOptions, dataSchema) + } +} diff --git a/spark/spark-4.1/src/main/scala/org/apache/sedona/sql/datasources/geopackage/GeoPackageScanBuilder.scala b/spark/spark-4.1/src/main/scala/org/apache/sedona/sql/datasources/geopackage/GeoPackageScanBuilder.scala new file mode 100644 index 00000000000..829bd9c2201 --- /dev/null +++ b/spark/spark-4.1/src/main/scala/org/apache/sedona/sql/datasources/geopackage/GeoPackageScanBuilder.scala @@ -0,0 +1,58 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.sedona.sql.datasources.geopackage + +import org.apache.sedona.sql.datasources.geopackage.model.GeoPackageOptions +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.connector.read.Scan +import org.apache.spark.sql.execution.datasources.{InMemoryFileIndex, PartitioningAwareFileIndex} +import org.apache.spark.sql.execution.datasources.v2.FileScanBuilder +import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.util.CaseInsensitiveStringMap +import scala.jdk.CollectionConverters._ + +class GeoPackageScanBuilder( + sparkSession: SparkSession, + fileIndex: PartitioningAwareFileIndex, + dataSchema: StructType, + options: CaseInsensitiveStringMap, + loadOptions: GeoPackageOptions, + userDefinedSchema: Option[StructType] = None) + extends FileScanBuilder(sparkSession, fileIndex, dataSchema) { + + override def build(): Scan = { + val fileIndexAdjusted = + if (loadOptions.showMetadata) + new InMemoryFileIndex( + sparkSession, + fileIndex.inputFiles.slice(0, 1).map(new org.apache.hadoop.fs.Path(_)), + options.asCaseSensitiveMap.asScala.toMap, + userDefinedSchema) + else fileIndex + + GeoPackageScan( + dataSchema, + sparkSession, + fileIndexAdjusted, + dataSchema, + readPartitionSchema(), + options, + loadOptions) + } +} diff --git a/spark/spark-4.1/src/main/scala/org/apache/sedona/sql/datasources/geopackage/GeoPackageTable.scala b/spark/spark-4.1/src/main/scala/org/apache/sedona/sql/datasources/geopackage/GeoPackageTable.scala new file mode 100644 index 00000000000..85dec8427e4 --- /dev/null +++ b/spark/spark-4.1/src/main/scala/org/apache/sedona/sql/datasources/geopackage/GeoPackageTable.scala @@ -0,0 +1,91 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.sedona.sql.datasources.geopackage + +import org.apache.hadoop.fs.FileStatus +import org.apache.sedona.sql.datasources.geopackage.connection.{FileSystemUtils, GeoPackageConnectionManager} +import org.apache.sedona.sql.datasources.geopackage.model.{GeoPackageOptions, MetadataSchema, TableType} +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.connector.read.ScanBuilder +import org.apache.spark.sql.connector.write.{LogicalWriteInfo, WriteBuilder} +import org.apache.spark.sql.execution.datasources.FileFormat +import org.apache.spark.sql.execution.datasources.v2.FileTable +import org.apache.spark.sql.types.{DoubleType, IntegerType, StringType, StructField, StructType, TimestampType} +import org.apache.spark.sql.util.CaseInsensitiveStringMap +import org.apache.spark.util.SerializableConfiguration + +import scala.jdk.CollectionConverters._ + +case class GeoPackageTable( + name: String, + sparkSession: SparkSession, + options: CaseInsensitiveStringMap, + paths: Seq[String], + userSpecifiedSchema: Option[StructType], + fallbackFileFormat: Class[_ <: FileFormat], + loadOptions: GeoPackageOptions) + extends FileTable(sparkSession, options, paths, userSpecifiedSchema) { + + override def inferSchema(files: Seq[FileStatus]): Option[StructType] = { + if (loadOptions.showMetadata) { + return MetadataSchema.schema + } + + val serializableConf = new SerializableConfiguration( + sparkSession.sessionState.newHadoopConfWithOptions(options.asScala.toMap)) + + val (tempFile, copied) = + FileSystemUtils.copyToLocal(serializableConf.value, files.head.getPath) + + if (copied) { + tempFile.deleteOnExit() + } + + val tableType = if (loadOptions.showMetadata) { + TableType.METADATA + } else { + GeoPackageConnectionManager.findFeatureMetadata(tempFile.getPath, loadOptions.tableName) + } + + Some( + StructType( + GeoPackageConnectionManager + .getSchema(tempFile.getPath, loadOptions.tableName) + .map(field => field.toStructField(tableType)))) + } + + override def formatName: String = { + "GeoPackage" + } + + override def newScanBuilder(options: CaseInsensitiveStringMap): ScanBuilder = { + new GeoPackageScanBuilder( + sparkSession, + fileIndex, + schema, + options, + loadOptions, + userSpecifiedSchema) + } + + override def newWriteBuilder(info: LogicalWriteInfo): WriteBuilder = { + null + } + +} diff --git a/spark/spark-4.1/src/main/scala/org/apache/sedona/sql/datasources/shapefile/ShapefileDataSource.scala b/spark/spark-4.1/src/main/scala/org/apache/sedona/sql/datasources/shapefile/ShapefileDataSource.scala new file mode 100644 index 00000000000..7cd6d03a6d9 --- /dev/null +++ b/spark/spark-4.1/src/main/scala/org/apache/sedona/sql/datasources/shapefile/ShapefileDataSource.scala @@ -0,0 +1,101 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. 
See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.sedona.sql.datasources.shapefile + +import org.apache.hadoop.fs.Path +import org.apache.spark.sql.connector.catalog.Table +import org.apache.spark.sql.execution.datasources.v2.FileDataSourceV2 +import org.apache.spark.sql.execution.datasources.FileFormat +import org.apache.spark.sql.sources.DataSourceRegister +import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.util.CaseInsensitiveStringMap + +import java.util.Locale +import scala.collection.JavaConverters._ +import scala.util.Try + +/** + * A Spark SQL data source for reading ESRI Shapefiles. This data source supports reading the + * following components of shapefiles: + * + *
<ul>
+ *   <li>.shp: the main file</li>
+ *   <li>.dbf: (optional) the attribute file</li>
+ *   <li>.shx: (optional) the index file</li>
+ *   <li>.cpg: (optional) the code page file</li>
+ *   <li>.prj: (optional) the projection file</li>
+ * </ul>
+ *
+ * <p>
The load path can be a directory containing the shapefiles, or a path to the .shp file. If + * the path refers to a .shp file, the data source will also read other components such as .dbf + * and .shx files in the same directory. + */ +class ShapefileDataSource extends FileDataSourceV2 with DataSourceRegister { + + override def shortName(): String = "shapefile" + + override def fallbackFileFormat: Class[_ <: FileFormat] = null + + override protected def getTable(options: CaseInsensitiveStringMap): Table = { + val paths = getTransformedPath(options) + val tableName = getTableName(options, paths) + val optionsWithoutPaths = getOptionsWithoutPaths(options) + ShapefileTable(tableName, sparkSession, optionsWithoutPaths, paths, None, fallbackFileFormat) + } + + override protected def getTable( + options: CaseInsensitiveStringMap, + schema: StructType): Table = { + val paths = getTransformedPath(options) + val tableName = getTableName(options, paths) + val optionsWithoutPaths = getOptionsWithoutPaths(options) + ShapefileTable( + tableName, + sparkSession, + optionsWithoutPaths, + paths, + Some(schema), + fallbackFileFormat) + } + + private def getTransformedPath(options: CaseInsensitiveStringMap): Seq[String] = { + val paths = getPaths(options) + transformPaths(paths, options) + } + + private def transformPaths( + paths: Seq[String], + options: CaseInsensitiveStringMap): Seq[String] = { + val caseSensitiveMap = options.asCaseSensitiveMap.asScala.toMap + val hadoopConf = sparkSession.sessionState.newHadoopConfWithOptions(caseSensitiveMap) + paths.map { pathString => + if (pathString.toLowerCase(Locale.ROOT).endsWith(".shp")) { + // If the path refers to a file, we need to change it to a glob path to support reading + // .dbf and .shx files as well. For example, if the path is /path/to/file.shp, we need to + // change it to /path/to/file.??? + val path = new Path(pathString) + val fs = path.getFileSystem(hadoopConf) + val isDirectory = Try(fs.getFileStatus(path).isDirectory).getOrElse(false) + if (isDirectory) { + pathString + } else { + pathString.substring(0, pathString.length - 3) + "???" + } + } else { + pathString + } + } + } +} diff --git a/spark/spark-4.1/src/main/scala/org/apache/sedona/sql/datasources/shapefile/ShapefilePartition.scala b/spark/spark-4.1/src/main/scala/org/apache/sedona/sql/datasources/shapefile/ShapefilePartition.scala new file mode 100644 index 00000000000..306b1df4f6c --- /dev/null +++ b/spark/spark-4.1/src/main/scala/org/apache/sedona/sql/datasources/shapefile/ShapefilePartition.scala @@ -0,0 +1,27 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.sedona.sql.datasources.shapefile + +import org.apache.spark.Partition +import org.apache.spark.sql.connector.read.InputPartition +import org.apache.spark.sql.execution.datasources.PartitionedFile + +case class ShapefilePartition(index: Int, files: Array[PartitionedFile]) + extends Partition + with InputPartition diff --git a/spark/spark-4.1/src/main/scala/org/apache/sedona/sql/datasources/shapefile/ShapefilePartitionReader.scala b/spark/spark-4.1/src/main/scala/org/apache/sedona/sql/datasources/shapefile/ShapefilePartitionReader.scala new file mode 100644 index 00000000000..301d63296fb --- /dev/null +++ b/spark/spark-4.1/src/main/scala/org/apache/sedona/sql/datasources/shapefile/ShapefilePartitionReader.scala @@ -0,0 +1,287 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.sedona.sql.datasources.shapefile + +import org.apache.commons.io.FilenameUtils +import org.apache.commons.io.IOUtils +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs.FSDataInputStream +import org.apache.hadoop.fs.Path +import org.apache.sedona.common.FunctionsGeoTools +import org.apache.sedona.core.formatMapper.shapefileParser.shapes.DbfFileReader +import org.apache.sedona.core.formatMapper.shapefileParser.shapes.PrimitiveShape +import org.apache.sedona.core.formatMapper.shapefileParser.shapes.ShapeFileReader +import org.apache.sedona.core.formatMapper.shapefileParser.shapes.ShxFileReader +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.BoundReference +import org.apache.spark.sql.catalyst.expressions.Cast +import org.apache.spark.sql.catalyst.expressions.Literal +import org.apache.spark.sql.catalyst.expressions.UnsafeProjection +import org.apache.spark.sql.connector.read.PartitionReader +import org.apache.spark.sql.execution.datasources.PartitionedFile +import org.apache.sedona.sql.datasources.shapefile.ShapefilePartitionReader.logger +import org.apache.sedona.sql.datasources.shapefile.ShapefilePartitionReader.openStream +import org.apache.sedona.sql.datasources.shapefile.ShapefilePartitionReader.tryOpenStream +import org.apache.sedona.sql.datasources.shapefile.ShapefileUtils.baseSchema +import org.apache.spark.sql.sedona_sql.UDT.GeometryUDT +import org.apache.spark.sql.types.StructType +import org.locationtech.jts.geom.GeometryFactory +import org.locationtech.jts.geom.PrecisionModel +import org.slf4j.Logger +import org.slf4j.LoggerFactory + +import java.nio.charset.StandardCharsets +import scala.collection.JavaConverters._ +import java.util.Locale +import scala.util.Try + +class ShapefilePartitionReader( + configuration: Configuration, + partitionedFiles: Array[PartitionedFile], + readDataSchema: StructType, + options: ShapefileReadOptions) + 
extends PartitionReader[InternalRow] { + + private val partitionedFilesMap: Map[String, Path] = partitionedFiles.map { file => + val fileName = file.filePath.toPath.getName + val extension = FilenameUtils.getExtension(fileName).toLowerCase(Locale.ROOT) + extension -> file.filePath.toPath + }.toMap + + private val cpg = options.charset.orElse { + // No charset option or sedona.global.charset system property specified, infer charset + // from the cpg file. + tryOpenStream(partitionedFilesMap, "cpg", configuration) + .flatMap { stream => + try { + val lineIter = IOUtils.lineIterator(stream, StandardCharsets.UTF_8) + if (lineIter.hasNext) { + Some(lineIter.next().trim()) + } else { + None + } + } finally { + stream.close() + } + } + .orElse { + // Cannot infer charset from cpg file. If sedona.global.charset is set to "utf8", use UTF-8 as + // the default charset. This is for compatibility with the behavior of the RDD API. + val charset = System.getProperty("sedona.global.charset", "default") + val utf8flag = charset.equalsIgnoreCase("utf8") + if (utf8flag) Some("UTF-8") else None + } + } + + private val prj = tryOpenStream(partitionedFilesMap, "prj", configuration).map { stream => + try { + IOUtils.toString(stream, StandardCharsets.UTF_8) + } finally { + stream.close() + } + } + + private val shpReader: ShapeFileReader = { + val reader = tryOpenStream(partitionedFilesMap, "shx", configuration) match { + case Some(shxStream) => + try { + val index = ShxFileReader.readAll(shxStream) + new ShapeFileReader(index) + } finally { + shxStream.close() + } + case None => new ShapeFileReader() + } + val stream = openStream(partitionedFilesMap, "shp", configuration) + reader.initialize(stream) + reader + } + + private val dbfReader = + tryOpenStream(partitionedFilesMap, "dbf", configuration).map { stream => + val reader = new DbfFileReader() + reader.initialize(stream) + reader + } + + private val geometryField = readDataSchema.filter(_.dataType.isInstanceOf[GeometryUDT]) match { + case Seq(geoField) => Some(geoField) + case Seq() => None + case _ => throw new IllegalArgumentException("Only one geometry field is allowed") + } + + private val shpSchema: StructType = { + val dbfFields = dbfReader + .map { reader => + ShapefileUtils.fieldDescriptorsToStructFields(reader.getFieldDescriptors.asScala.toSeq) + } + .getOrElse(Seq.empty) + StructType(baseSchema(options).fields ++ dbfFields) + } + + // projection from shpSchema to readDataSchema + private val projection = { + val expressions = readDataSchema.map { field => + val index = Try(shpSchema.fieldIndex(field.name)).getOrElse(-1) + if (index >= 0) { + val sourceField = shpSchema.fields(index) + val refExpr = BoundReference(index, sourceField.dataType, sourceField.nullable) + if (sourceField.dataType == field.dataType) refExpr + else { + Cast(refExpr, field.dataType) + } + } else { + if (field.nullable) { + Literal(null) + } else { + // This usually won't happen, since all fields of readDataSchema are nullable for most + // of the time. See org.apache.spark.sql.execution.datasources.v2.FileTable#dataSchema + // for more details. 
+ val dbfPath = partitionedFilesMap.get("dbf").orNull + throw new IllegalArgumentException( + s"Field ${field.name} not found in shapefile $dbfPath") + } + } + } + UnsafeProjection.create(expressions) + } + + // Convert DBF field values to SQL values + private val fieldValueConverters: Seq[Array[Byte] => Any] = dbfReader + .map { reader => + reader.getFieldDescriptors.asScala.map { field => + val index = Try(readDataSchema.fieldIndex(field.getFieldName)).getOrElse(-1) + if (index >= 0) { + ShapefileUtils.fieldValueConverter(field, cpg) + } else { (_: Array[Byte]) => + null + } + }.toSeq + } + .getOrElse(Seq.empty) + + private val geometryFactory = prj match { + case Some(wkt) => + val srid = + try { + FunctionsGeoTools.wktCRSToSRID(wkt) + } catch { + case e: Throwable => + val prjPath = partitionedFilesMap.get("prj").orNull + logger.warn(s"Failed to parse SRID from .prj file $prjPath", e) + 0 + } + new GeometryFactory(new PrecisionModel, srid) + case None => new GeometryFactory() + } + + private var currentRow: InternalRow = _ + + override def next(): Boolean = { + if (shpReader.nextKeyValue()) { + val key = shpReader.getCurrentKey + val id = key.getIndex + + val attributesOpt = dbfReader.flatMap { reader => + if (reader.nextKeyValue()) { + val value = reader.getCurrentFieldBytes + Option(value) + } else { + val dbfPath = partitionedFilesMap.get("dbf").orNull + logger.warn("Shape record loses attributes in .dbf file {} at ID={}", dbfPath, id) + None + } + } + + val value = shpReader.getCurrentValue + val geometry = geometryField.flatMap { _ => + if (value.getType.isSupported) { + val shape = new PrimitiveShape(value) + Some(shape.getShape(geometryFactory)) + } else { + logger.warn( + "Shape type {} is not supported, geometry value will be null", + value.getType.name()) + None + } + } + + val attrValues = attributesOpt match { + case Some(fieldBytesList) => + // Convert attributes to SQL values + fieldBytesList.asScala.zip(fieldValueConverters).map { case (fieldBytes, converter) => + converter(fieldBytes) + } + case None => + // No attributes, fill with nulls + Seq.fill(fieldValueConverters.length)(null) + } + + val serializedGeom = geometry.map(GeometryUDT.serialize).orNull + val shpRow = if (options.keyFieldName.isDefined) { + InternalRow.fromSeq(serializedGeom +: key.getIndex +: attrValues.toSeq) + } else { + InternalRow.fromSeq(serializedGeom +: attrValues.toSeq) + } + currentRow = projection(shpRow) + true + } else { + dbfReader.foreach { reader => + if (reader.nextKeyValue()) { + val dbfPath = partitionedFilesMap.get("dbf").orNull + logger.warn("Redundant attributes in {} exists", dbfPath) + } + } + false + } + } + + override def get(): InternalRow = currentRow + + override def close(): Unit = { + dbfReader.foreach(_.close()) + shpReader.close() + } +} + +object ShapefilePartitionReader { + val logger: Logger = LoggerFactory.getLogger(classOf[ShapefilePartitionReader]) + + private def openStream( + partitionedFilesMap: Map[String, Path], + extension: String, + configuration: Configuration): FSDataInputStream = { + tryOpenStream(partitionedFilesMap, extension, configuration).getOrElse { + val path = partitionedFilesMap.head._2 + val baseName = FilenameUtils.getBaseName(path.getName) + throw new IllegalArgumentException( + s"No $extension file found for shapefile $baseName in ${path.getParent}") + } + } + + private def tryOpenStream( + partitionedFilesMap: Map[String, Path], + extension: String, + configuration: Configuration): Option[FSDataInputStream] = { + 
partitionedFilesMap.get(extension).map { path => + val fs = path.getFileSystem(configuration) + fs.open(path) + } + } +} diff --git a/spark/spark-4.1/src/main/scala/org/apache/sedona/sql/datasources/shapefile/ShapefilePartitionReaderFactory.scala b/spark/spark-4.1/src/main/scala/org/apache/sedona/sql/datasources/shapefile/ShapefilePartitionReaderFactory.scala new file mode 100644 index 00000000000..5a28af6d66f --- /dev/null +++ b/spark/spark-4.1/src/main/scala/org/apache/sedona/sql/datasources/shapefile/ShapefilePartitionReaderFactory.scala @@ -0,0 +1,66 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.sedona.sql.datasources.shapefile + +import org.apache.spark.broadcast.Broadcast +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.connector.read.InputPartition +import org.apache.spark.sql.connector.read.PartitionReader +import org.apache.spark.sql.connector.read.PartitionReaderFactory +import org.apache.spark.sql.execution.datasources.PartitionedFile +import org.apache.spark.sql.execution.datasources.v2.PartitionReaderWithPartitionValues +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.sources.Filter +import org.apache.spark.sql.types.StructType +import org.apache.spark.util.SerializableConfiguration + +case class ShapefilePartitionReaderFactory( + sqlConf: SQLConf, + broadcastedConf: Broadcast[SerializableConfiguration], + dataSchema: StructType, + readDataSchema: StructType, + partitionSchema: StructType, + options: ShapefileReadOptions, + filters: Seq[Filter]) + extends PartitionReaderFactory { + + private def buildReader( + partitionedFiles: Array[PartitionedFile]): PartitionReader[InternalRow] = { + val fileReader = + new ShapefilePartitionReader( + broadcastedConf.value.value, + partitionedFiles, + readDataSchema, + options) + new PartitionReaderWithPartitionValues( + fileReader, + readDataSchema, + partitionSchema, + partitionedFiles.head.partitionValues) + } + + override def createReader(partition: InputPartition): PartitionReader[InternalRow] = { + partition match { + case filePartition: ShapefilePartition => buildReader(filePartition.files) + case _ => + throw new IllegalArgumentException( + s"Unexpected partition type: ${partition.getClass.getCanonicalName}") + } + } +} diff --git a/spark/spark-4.1/src/main/scala/org/apache/sedona/sql/datasources/shapefile/ShapefileReadOptions.scala b/spark/spark-4.1/src/main/scala/org/apache/sedona/sql/datasources/shapefile/ShapefileReadOptions.scala new file mode 100644 index 00000000000..ebc02fae85a --- /dev/null +++ b/spark/spark-4.1/src/main/scala/org/apache/sedona/sql/datasources/shapefile/ShapefileReadOptions.scala @@ -0,0 +1,45 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or 
more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.sedona.sql.datasources.shapefile + +import org.apache.spark.sql.util.CaseInsensitiveStringMap + +/** + * Options for reading Shapefiles. + * @param geometryFieldName + * The name of the geometry field. + * @param keyFieldName + * The name of the shape key field. + * @param charset + * The charset of non-spatial attributes. + */ +case class ShapefileReadOptions( + geometryFieldName: String, + keyFieldName: Option[String], + charset: Option[String]) + +object ShapefileReadOptions { + def parse(options: CaseInsensitiveStringMap): ShapefileReadOptions = { + val geometryFieldName = options.getOrDefault("geometry.name", "geometry") + val keyFieldName = + if (options.containsKey("key.name")) Some(options.get("key.name")) else None + val charset = if (options.containsKey("charset")) Some(options.get("charset")) else None + ShapefileReadOptions(geometryFieldName, keyFieldName, charset) + } +} diff --git a/spark/spark-4.1/src/main/scala/org/apache/sedona/sql/datasources/shapefile/ShapefileScan.scala b/spark/spark-4.1/src/main/scala/org/apache/sedona/sql/datasources/shapefile/ShapefileScan.scala new file mode 100644 index 00000000000..afdface8942 --- /dev/null +++ b/spark/spark-4.1/src/main/scala/org/apache/sedona/sql/datasources/shapefile/ShapefileScan.scala @@ -0,0 +1,118 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.sedona.sql.datasources.shapefile + +import org.apache.hadoop.fs.Path +import org.apache.spark.sql.catalyst.expressions.Expression +import org.apache.spark.sql.execution.datasources.PartitioningAwareFileIndex +import org.apache.spark.sql.sources.Filter +import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.util.CaseInsensitiveStringMap +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.connector.read.InputPartition +import org.apache.spark.sql.connector.read.PartitionReaderFactory +import org.apache.spark.sql.execution.datasources.v2.FileScan +import org.apache.spark.sql.execution.datasources.FilePartition +import org.apache.spark.sql.execution.datasources.PartitionedFile +import org.apache.sedona.sql.datasources.shapefile.ShapefileScan.logger +import org.apache.spark.util.SerializableConfiguration +import org.slf4j.{Logger, LoggerFactory} + +import java.util.Locale +import scala.collection.JavaConverters._ +import scala.collection.mutable + +case class ShapefileScan( + sparkSession: SparkSession, + fileIndex: PartitioningAwareFileIndex, + dataSchema: StructType, + readDataSchema: StructType, + readPartitionSchema: StructType, + options: CaseInsensitiveStringMap, + pushedFilters: Array[Filter], + partitionFilters: Seq[Expression] = Seq.empty, + dataFilters: Seq[Expression] = Seq.empty) + extends FileScan { + + override def createReaderFactory(): PartitionReaderFactory = { + val caseSensitiveMap = options.asScala.toMap + val hadoopConf = sparkSession.sessionState.newHadoopConfWithOptions(caseSensitiveMap) + val broadcastedConf = + sparkSession.sparkContext.broadcast(new SerializableConfiguration(hadoopConf)) + ShapefilePartitionReaderFactory( + sparkSession.sessionState.conf, + broadcastedConf, + dataSchema, + readDataSchema, + readPartitionSchema, + ShapefileReadOptions.parse(options), + pushedFilters) + } + + override def planInputPartitions(): Array[InputPartition] = { + // Simply use the default implementation to compute input partitions for all files + val allFilePartitions = super.planInputPartitions().flatMap { + case filePartition: FilePartition => + filePartition.files + case partition => + throw new IllegalArgumentException( + s"Unexpected partition type: ${partition.getClass.getCanonicalName}") + } + + // Group shapefiles by their main path (without the extension) + val shapefileGroups: mutable.Map[String, mutable.Map[String, PartitionedFile]] = + mutable.Map.empty + allFilePartitions.foreach { partitionedFile => + val path = partitionedFile.filePath.toPath + val fileName = path.getName + val pos = fileName.lastIndexOf('.') + if (pos == -1) None + else { + val mainName = fileName.substring(0, pos) + val extension = fileName.substring(pos + 1).toLowerCase(Locale.ROOT) + if (ShapefileUtils.shapeFileExtensions.contains(extension)) { + val key = new Path(path.getParent, mainName).toString + val group = shapefileGroups.getOrElseUpdate(key, mutable.Map.empty) + group += (extension -> partitionedFile) + } + } + } + + // Create a partition for each group + shapefileGroups.zipWithIndex.flatMap { case ((key, group), index) => + // Check if the group has all the necessary files + val suffixes = group.keys.toSet + val hasMissingFiles = ShapefileUtils.mandatoryFileExtensions.exists { suffix => + if (!suffixes.contains(suffix)) { + logger.warn(s"Shapefile $key is missing a $suffix file") + true + } else false + } + if (!hasMissingFiles) { + Some(ShapefilePartition(index, group.values.toArray)) + } else { + None + } + }.toArray + } 
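  // Usage sketch (paths and column names are illustrative): the grouping above pairs each .shp
  // file with any sibling .shx/.dbf/.cpg/.prj files sharing its base name, so one
  // ShapefilePartition corresponds to one logical shapefile. Assuming the data source is
  // registered under the short name "shapefile" (the registration itself lives outside this
  // file), a read exercising the options parsed by ShapefileReadOptions would look like:
  //
  //   val df = spark.read
  //     .format("shapefile")
  //     .option("geometry.name", "geom")   // geometry column name, defaults to "geometry"
  //     .option("key.name", "shape_key")   // optionally expose the shape key as a LongType column
  //     .option("charset", "UTF-8")        // charset of the non-spatial (DBF) attributes
  //     .load("/data/roads")               // directory containing roads.shp and its siblings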
+} + +object ShapefileScan { + val logger: Logger = LoggerFactory.getLogger(classOf[ShapefileScan]) +} diff --git a/spark/spark-4.1/src/main/scala/org/apache/sedona/sql/datasources/shapefile/ShapefileScanBuilder.scala b/spark/spark-4.1/src/main/scala/org/apache/sedona/sql/datasources/shapefile/ShapefileScanBuilder.scala new file mode 100644 index 00000000000..e5135e381df --- /dev/null +++ b/spark/spark-4.1/src/main/scala/org/apache/sedona/sql/datasources/shapefile/ShapefileScanBuilder.scala @@ -0,0 +1,48 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.sedona.sql.datasources.shapefile + +import org.apache.spark.sql.connector.read.Scan +import org.apache.spark.sql.execution.datasources.v2.FileScanBuilder +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.execution.datasources.PartitioningAwareFileIndex +import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.util.CaseInsensitiveStringMap + +case class ShapefileScanBuilder( + sparkSession: SparkSession, + fileIndex: PartitioningAwareFileIndex, + schema: StructType, + dataSchema: StructType, + options: CaseInsensitiveStringMap) + extends FileScanBuilder(sparkSession, fileIndex, dataSchema) { + + override def build(): Scan = { + ShapefileScan( + sparkSession, + fileIndex, + dataSchema, + readDataSchema(), + readPartitionSchema(), + options, + pushedDataFilters, + partitionFilters, + dataFilters) + } +} diff --git a/spark/spark-4.1/src/main/scala/org/apache/sedona/sql/datasources/shapefile/ShapefileTable.scala b/spark/spark-4.1/src/main/scala/org/apache/sedona/sql/datasources/shapefile/ShapefileTable.scala new file mode 100644 index 00000000000..7db6bb8d1f4 --- /dev/null +++ b/spark/spark-4.1/src/main/scala/org/apache/sedona/sql/datasources/shapefile/ShapefileTable.scala @@ -0,0 +1,103 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.sedona.sql.datasources.shapefile + +import org.apache.hadoop.fs.FileStatus +import org.apache.sedona.core.formatMapper.shapefileParser.parseUtils.dbf.DbfParseUtil +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.connector.catalog.TableCapability +import org.apache.spark.sql.connector.read.ScanBuilder +import org.apache.spark.sql.connector.write.{LogicalWriteInfo, WriteBuilder} +import org.apache.spark.sql.execution.datasources.FileFormat +import org.apache.sedona.sql.datasources.shapefile.ShapefileUtils.{baseSchema, fieldDescriptorsToSchema, mergeSchemas} +import org.apache.spark.sql.execution.datasources.v2.FileTable +import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.util.CaseInsensitiveStringMap +import org.apache.spark.util.SerializableConfiguration + +import java.util.Locale +import scala.collection.JavaConverters._ + +case class ShapefileTable( + name: String, + sparkSession: SparkSession, + options: CaseInsensitiveStringMap, + paths: Seq[String], + userSpecifiedSchema: Option[StructType], + fallbackFileFormat: Class[_ <: FileFormat]) + extends FileTable(sparkSession, options, paths, userSpecifiedSchema) { + + override def formatName: String = "Shapefile" + + override def capabilities: java.util.Set[TableCapability] = + java.util.EnumSet.of(TableCapability.BATCH_READ) + + override def inferSchema(files: Seq[FileStatus]): Option[StructType] = { + if (files.isEmpty) None + else { + def isDbfFile(file: FileStatus): Boolean = { + val name = file.getPath.getName.toLowerCase(Locale.ROOT) + name.endsWith(".dbf") + } + + def isShpFile(file: FileStatus): Boolean = { + val name = file.getPath.getName.toLowerCase(Locale.ROOT) + name.endsWith(".shp") + } + + if (!files.exists(isShpFile)) None + else { + val readOptions = ShapefileReadOptions.parse(options) + val resolver = sparkSession.sessionState.conf.resolver + val dbfFiles = files.filter(isDbfFile) + if (dbfFiles.isEmpty) { + Some(baseSchema(readOptions, Some(resolver))) + } else { + val serializableConf = new SerializableConfiguration( + sparkSession.sessionState.newHadoopConfWithOptions(options.asScala.toMap)) + val partiallyMergedSchemas = sparkSession.sparkContext + .parallelize(dbfFiles) + .mapPartitions { iter => + val schemas = iter.map { stat => + val fs = stat.getPath.getFileSystem(serializableConf.value) + val stream = fs.open(stat.getPath) + try { + val dbfParser = new DbfParseUtil() + dbfParser.parseFileHead(stream) + val fieldDescriptors = dbfParser.getFieldDescriptors + fieldDescriptorsToSchema(fieldDescriptors.asScala.toSeq, readOptions, resolver) + } finally { + stream.close() + } + }.toSeq + mergeSchemas(schemas).iterator + } + .collect() + mergeSchemas(partiallyMergedSchemas) + } + } + } + } + + override def newScanBuilder(options: CaseInsensitiveStringMap): ScanBuilder = { + ShapefileScanBuilder(sparkSession, fileIndex, schema, dataSchema, options) + } + + override def newWriteBuilder(info: LogicalWriteInfo): WriteBuilder = null +} diff --git a/spark/spark-4.1/src/main/scala/org/apache/sedona/sql/datasources/shapefile/ShapefileUtils.scala b/spark/spark-4.1/src/main/scala/org/apache/sedona/sql/datasources/shapefile/ShapefileUtils.scala new file mode 100644 index 00000000000..fd6d1e83827 --- /dev/null +++ b/spark/spark-4.1/src/main/scala/org/apache/sedona/sql/datasources/shapefile/ShapefileUtils.scala @@ -0,0 +1,202 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.sedona.sql.datasources.shapefile + +import org.apache.sedona.core.formatMapper.shapefileParser.parseUtils.dbf.FieldDescriptor +import org.apache.spark.sql.sedona_sql.UDT.GeometryUDT +import org.apache.spark.sql.catalyst.analysis.SqlApiAnalysis.Resolver +import org.apache.spark.sql.types.BooleanType +import org.apache.spark.sql.types.DateType +import org.apache.spark.sql.types.Decimal +import org.apache.spark.sql.types.DecimalType +import org.apache.spark.sql.types.LongType +import org.apache.spark.sql.types.StringType +import org.apache.spark.sql.types.StructField +import org.apache.spark.sql.types.StructType +import org.apache.spark.unsafe.types.UTF8String + +import java.nio.charset.StandardCharsets +import java.time.LocalDate +import java.time.format.DateTimeFormatter +import java.util.Locale + +object ShapefileUtils { + + /** + * shp: main file for storing shapes shx: index file for the main file dbf: attribute file cpg: + * code page file prj: projection file + */ + val shapeFileExtensions: Set[String] = Set("shp", "shx", "dbf", "cpg", "prj") + + /** + * The mandatory file extensions for a shapefile. 
We don't require the dbf file and shx file for + * being consistent with the behavior of the RDD API ShapefileReader.readToGeometryRDD + */ + val mandatoryFileExtensions: Set[String] = Set("shp") + + def mergeSchemas(schemas: Seq[StructType]): Option[StructType] = { + if (schemas.isEmpty) { + None + } else { + var mergedSchema = schemas.head + schemas.tail.foreach { schema => + try { + mergedSchema = mergeSchema(mergedSchema, schema) + } catch { + case cause: IllegalArgumentException => + throw new IllegalArgumentException( + s"Failed to merge schema $mergedSchema with $schema", + cause) + } + } + Some(mergedSchema) + } + } + + private def mergeSchema(schema1: StructType, schema2: StructType): StructType = { + // The field names are case insensitive when performing schema merging + val fieldMap = schema1.fields.map(f => f.name.toLowerCase(Locale.ROOT) -> f).toMap + var newFields = schema1.fields + schema2.fields.foreach { f => + fieldMap.get(f.name.toLowerCase(Locale.ROOT)) match { + case Some(existingField) => + if (existingField.dataType != f.dataType) { + throw new IllegalArgumentException( + s"Failed to merge fields ${existingField.name} and ${f.name} because they have different data types: ${existingField.dataType} and ${f.dataType}") + } + case _ => + newFields :+= f + } + } + StructType(newFields) + } + + def fieldDescriptorsToStructFields(fieldDescriptors: Seq[FieldDescriptor]): Seq[StructField] = { + fieldDescriptors.map { desc => + val name = desc.getFieldName + val dataType = desc.getFieldType match { + case 'C' => StringType + case 'N' | 'F' => + val scale = desc.getFieldDecimalCount + if (scale == 0) LongType + else { + val precision = desc.getFieldLength + DecimalType(precision, scale) + } + case 'L' => BooleanType + case 'D' => DateType + case _ => + throw new IllegalArgumentException(s"Unsupported field type ${desc.getFieldType}") + } + StructField(name, dataType, nullable = true) + } + } + + def fieldDescriptorsToSchema(fieldDescriptors: Seq[FieldDescriptor]): StructType = { + val structFields = fieldDescriptorsToStructFields(fieldDescriptors) + StructType(structFields) + } + + def fieldDescriptorsToSchema( + fieldDescriptors: Seq[FieldDescriptor], + options: ShapefileReadOptions, + resolver: Resolver): StructType = { + val structFields = fieldDescriptorsToStructFields(fieldDescriptors) + val geometryFieldName = options.geometryFieldName + if (structFields.exists(f => resolver(f.name, geometryFieldName))) { + throw new IllegalArgumentException( + s"Field name $geometryFieldName is reserved for geometry but appears in non-spatial attributes. " + + "Please specify a different field name for geometry using the 'geometry.name' option.") + } + options.keyFieldName.foreach { name => + if (structFields.exists(f => resolver(f.name, name))) { + throw new IllegalArgumentException( + s"Field name $name is reserved for shape key but appears in non-spatial attributes. 
" + + "Please specify a different field name for shape key using the 'key.name' option.") + } + } + StructType(baseSchema(options, Some(resolver)).fields ++ structFields) + } + + def baseSchema(options: ShapefileReadOptions, resolver: Option[Resolver] = None): StructType = { + options.keyFieldName match { + case Some(name) => + if (resolver.exists(_(name, options.geometryFieldName))) { + throw new IllegalArgumentException(s"geometry.name and key.name cannot be the same") + } + StructType( + Seq(StructField(options.geometryFieldName, GeometryUDT()), StructField(name, LongType))) + case _ => + StructType(StructField(options.geometryFieldName, GeometryUDT()) :: Nil) + } + } + + def fieldValueConverter(desc: FieldDescriptor, cpg: Option[String]): Array[Byte] => Any = { + desc.getFieldType match { + case 'C' => + val encoding = cpg.getOrElse("ISO-8859-1") + if (encoding.toLowerCase(Locale.ROOT) == "utf-8") { (bytes: Array[Byte]) => + UTF8String.fromBytes(bytes).trimRight() + } else { (bytes: Array[Byte]) => + { + val str = new String(bytes, encoding) + UTF8String.fromString(str).trimRight() + } + } + case 'N' | 'F' => + val scale = desc.getFieldDecimalCount + if (scale == 0) { (bytes: Array[Byte]) => + try { + new String(bytes, StandardCharsets.ISO_8859_1).trim.toLong + } catch { + case _: Exception => null + } + } else { (bytes: Array[Byte]) => + try { + Decimal.fromString(UTF8String.fromBytes(bytes)) + } catch { + case _: Exception => null + } + } + case 'L' => + (bytes: Array[Byte]) => + if (bytes.isEmpty) null + else { + bytes.head match { + case 'T' | 't' | 'Y' | 'y' => true + case 'F' | 'f' | 'N' | 'n' => false + case _ => null + } + } + case 'D' => + (bytes: Array[Byte]) => { + try { + val dateString = new String(bytes, StandardCharsets.ISO_8859_1) + val formatter = DateTimeFormatter.BASIC_ISO_DATE + val date = LocalDate.parse(dateString, formatter) + date.toEpochDay.toInt + } catch { + case _: Exception => null + } + } + case _ => + throw new IllegalArgumentException(s"Unsupported field type ${desc.getFieldType}") + } + } +} diff --git a/spark/spark-4.1/src/main/scala/org/apache/sedona/sql/parser/SedonaSqlAstBuilder.scala b/spark/spark-4.1/src/main/scala/org/apache/sedona/sql/parser/SedonaSqlAstBuilder.scala new file mode 100644 index 00000000000..b56ed11c875 --- /dev/null +++ b/spark/spark-4.1/src/main/scala/org/apache/sedona/sql/parser/SedonaSqlAstBuilder.scala @@ -0,0 +1,39 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.sedona.sql.parser + +import org.apache.spark.sql.catalyst.parser.SqlBaseParser._ +import org.apache.spark.sql.execution.SparkSqlAstBuilder +import org.apache.spark.sql.sedona_sql.UDT.GeometryUDT +import org.apache.spark.sql.types.DataType + +class SedonaSqlAstBuilder extends SparkSqlAstBuilder { + + /** + * Override the method to handle the geometry data type + * @param ctx + * @return + */ + override def visitPrimitiveDataType(ctx: PrimitiveDataTypeContext): DataType = { + ctx.getText.toUpperCase() match { + case "GEOMETRY" => GeometryUDT() + case _ => super.visitPrimitiveDataType(ctx) + } + } +} diff --git a/spark/spark-4.1/src/main/scala/org/apache/sedona/sql/parser/SedonaSqlParser.scala b/spark/spark-4.1/src/main/scala/org/apache/sedona/sql/parser/SedonaSqlParser.scala new file mode 100644 index 00000000000..cefd1487a47 --- /dev/null +++ b/spark/spark-4.1/src/main/scala/org/apache/sedona/sql/parser/SedonaSqlParser.scala @@ -0,0 +1,66 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.sedona.sql.parser + +import org.apache.spark.sql.catalyst.parser.{ParameterContext, ParserInterface} +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.execution.SparkSqlParser + +class SedonaSqlParser(delegate: ParserInterface) extends SparkSqlParser { + + // The parser builder for the Sedona SQL AST + val parserBuilder = new SedonaSqlAstBuilder + + private def sedonaFallback(sqlText: String): LogicalPlan = + parse(sqlText) { parser => + parserBuilder.visit(parser.singleStatement()) + }.asInstanceOf[LogicalPlan] + + /** + * Parse the SQL text with parameters and return the logical plan. In Spark 4.1+, SparkSqlParser + * overrides parsePlanWithParameters to bypass parsePlan, so we must override this method to + * intercept the parse flow. + * + * This method first attempts to use the delegate parser. If the delegate parser fails (throws + * an exception), it falls back to using the Sedona SQL parser. + */ + override def parsePlanWithParameters( + sqlText: String, + paramContext: ParameterContext): LogicalPlan = + try { + delegate.parsePlanWithParameters(sqlText, paramContext) + } catch { + case _: Exception => + sedonaFallback(sqlText) + } + + /** + * Parse the SQL text and return the logical plan. Note: in Spark 4.1+, SparkSession.sql() calls + * parsePlanWithParameters (overridden above), which no longer delegates to parsePlan. This + * override is kept as a defensive measure in case any third-party code or future Spark + * internals call parsePlan directly on the parser instance. 
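   *
   * A sketch of the fallback behavior (table and column names are illustrative): a statement
   * such as `CREATE TABLE pts (id BIGINT, geom GEOMETRY) USING parquet` is first offered to
   * the delegate parser; only if the delegate throws is the statement re-parsed with
   * SedonaSqlAstBuilder, which resolves the GEOMETRY keyword to GeometryUDT.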
+ */ + override def parsePlan(sqlText: String): LogicalPlan = + try { + delegate.parsePlan(sqlText) + } catch { + case _: Exception => + sedonaFallback(sqlText) + } +} diff --git a/spark/spark-4.1/src/main/scala/org/apache/spark/sql/execution/datasources/v2/geoparquet/metadata/GeoParquetMetadataDataSource.scala b/spark/spark-4.1/src/main/scala/org/apache/spark/sql/execution/datasources/v2/geoparquet/metadata/GeoParquetMetadataDataSource.scala new file mode 100644 index 00000000000..43e1ababb7d --- /dev/null +++ b/spark/spark-4.1/src/main/scala/org/apache/spark/sql/execution/datasources/v2/geoparquet/metadata/GeoParquetMetadataDataSource.scala @@ -0,0 +1,65 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.spark.sql.execution.datasources.v2.geoparquet.metadata + +import org.apache.spark.sql.connector.catalog.Table +import org.apache.spark.sql.execution.datasources.FileFormat +import org.apache.spark.sql.execution.datasources.v2.FileDataSourceV2 +import org.apache.spark.sql.sources.DataSourceRegister +import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.util.CaseInsensitiveStringMap + +/** + * Data source for reading GeoParquet metadata. 
This could be accessed using the `spark.read` + * interface: + * {{{ + * val df = spark.read.format("geoparquet.metadata").load("path/to/geoparquet") + * }}} + */ +class GeoParquetMetadataDataSource extends FileDataSourceV2 with DataSourceRegister { + override val shortName: String = "geoparquet.metadata" + + override def fallbackFileFormat: Class[_ <: FileFormat] = null + + override def getTable(options: CaseInsensitiveStringMap): Table = { + val paths = getPaths(options) + val tableName = getTableName(options, paths) + val optionsWithoutPaths = getOptionsWithoutPaths(options) + GeoParquetMetadataTable( + tableName, + sparkSession, + optionsWithoutPaths, + paths, + None, + fallbackFileFormat) + } + + override def getTable(options: CaseInsensitiveStringMap, schema: StructType): Table = { + val paths = getPaths(options) + val tableName = getTableName(options, paths) + val optionsWithoutPaths = getOptionsWithoutPaths(options) + GeoParquetMetadataTable( + tableName, + sparkSession, + optionsWithoutPaths, + paths, + Some(schema), + fallbackFileFormat) + } +} diff --git a/spark/spark-4.1/src/main/scala/org/apache/spark/sql/execution/datasources/v2/geoparquet/metadata/GeoParquetMetadataPartitionReaderFactory.scala b/spark/spark-4.1/src/main/scala/org/apache/spark/sql/execution/datasources/v2/geoparquet/metadata/GeoParquetMetadataPartitionReaderFactory.scala new file mode 100644 index 00000000000..683160e93b2 --- /dev/null +++ b/spark/spark-4.1/src/main/scala/org/apache/spark/sql/execution/datasources/v2/geoparquet/metadata/GeoParquetMetadataPartitionReaderFactory.scala @@ -0,0 +1,122 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.spark.sql.execution.datasources.v2.geoparquet.metadata + +import org.apache.hadoop.conf.Configuration +import org.apache.parquet.ParquetReadOptions +import org.apache.parquet.hadoop.ParquetFileReader +import org.apache.parquet.hadoop.util.HadoopInputFile +import org.apache.spark.broadcast.Broadcast +import org.apache.spark.sql.catalyst.{FileSourceOptions, InternalRow} +import org.apache.spark.sql.catalyst.expressions.GenericInternalRow +import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, GenericArrayData} +import org.apache.spark.sql.connector.read.PartitionReader +import org.apache.spark.sql.execution.datasources.PartitionedFile +import org.apache.spark.sql.execution.datasources.geoparquet.GeoParquetMetaData +import org.apache.spark.sql.execution.datasources.v2._ +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.sources.Filter +import org.apache.spark.sql.types.StructType +import org.apache.spark.unsafe.types.UTF8String +import org.apache.spark.util.SerializableConfiguration +import org.json4s.DefaultFormats +import org.json4s.jackson.JsonMethods.{compact, render} + +case class GeoParquetMetadataPartitionReaderFactory( + sqlConf: SQLConf, + broadcastedConf: Broadcast[SerializableConfiguration], + dataSchema: StructType, + readDataSchema: StructType, + partitionSchema: StructType, + options: FileSourceOptions, + filters: Seq[Filter]) + extends FilePartitionReaderFactory { + + override def buildReader(partitionedFile: PartitionedFile): PartitionReader[InternalRow] = { + val iter = GeoParquetMetadataPartitionReaderFactory.readFile( + broadcastedConf.value.value, + partitionedFile, + readDataSchema) + val fileReader = new PartitionReaderFromIterator[InternalRow](iter) + new PartitionReaderWithPartitionValues( + fileReader, + readDataSchema, + partitionSchema, + partitionedFile.partitionValues) + } +} + +object GeoParquetMetadataPartitionReaderFactory { + private def readFile( + configuration: Configuration, + partitionedFile: PartitionedFile, + readDataSchema: StructType): Iterator[InternalRow] = { + + val inputFile = HadoopInputFile.fromPath(partitionedFile.toPath, configuration) + val inputStream = inputFile.newStream() + + val footer = ParquetFileReader + .readFooter(inputFile, ParquetReadOptions.builder().build(), inputStream) + + val filePath = partitionedFile.toPath.toString + val metadata = footer.getFileMetaData.getKeyValueMetaData + val row = GeoParquetMetaData.parseKeyValueMetaData(metadata) match { + case Some(geo) => + val geoColumnsMap = geo.columns.map { case (columnName, columnMetadata) => + implicit val formats: org.json4s.Formats = DefaultFormats + import org.json4s.jackson.Serialization + val columnMetadataFields: Array[Any] = Array( + UTF8String.fromString(columnMetadata.encoding), + new GenericArrayData(columnMetadata.geometryTypes.map(UTF8String.fromString).toArray), + new GenericArrayData(columnMetadata.bbox.toArray), + columnMetadata.crs + .map(projjson => UTF8String.fromString(compact(render(projjson)))) + .getOrElse(UTF8String.fromString("")), + columnMetadata.covering + .map(covering => UTF8String.fromString(Serialization.write(covering))) + .orNull) + val columnMetadataStruct = new GenericInternalRow(columnMetadataFields) + UTF8String.fromString(columnName) -> columnMetadataStruct + } + val fields: Array[Any] = Array( + UTF8String.fromString(filePath), + UTF8String.fromString(geo.version.orNull), + UTF8String.fromString(geo.primaryColumn), + ArrayBasedMapData(geoColumnsMap)) + new GenericInternalRow(fields) 
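        // The four fields above line up with GeoParquetMetadataTable.schema
        // (path, version, primary_column, columns); pruneBySchema then projects this row
        // down to the columns that were actually requested in readDataSchema.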
+ case None => + // Not a GeoParquet file, return a row with null metadata values. + val fields: Array[Any] = Array(UTF8String.fromString(filePath), null, null, null) + new GenericInternalRow(fields) + } + Iterator(pruneBySchema(row, GeoParquetMetadataTable.schema, readDataSchema)) + } + + private def pruneBySchema( + row: InternalRow, + schema: StructType, + readDataSchema: StructType): InternalRow = { + // Projection push down for nested fields is not enabled, so this very simple implementation is enough. + val values: Array[Any] = readDataSchema.fields.map { field => + val index = schema.fieldIndex(field.name) + row.get(index, field.dataType) + } + new GenericInternalRow(values) + } +} diff --git a/spark/spark-4.1/src/main/scala/org/apache/spark/sql/execution/datasources/v2/geoparquet/metadata/GeoParquetMetadataScan.scala b/spark/spark-4.1/src/main/scala/org/apache/spark/sql/execution/datasources/v2/geoparquet/metadata/GeoParquetMetadataScan.scala new file mode 100644 index 00000000000..d7719d87dad --- /dev/null +++ b/spark/spark-4.1/src/main/scala/org/apache/spark/sql/execution/datasources/v2/geoparquet/metadata/GeoParquetMetadataScan.scala @@ -0,0 +1,69 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.spark.sql.execution.datasources.v2.geoparquet.metadata + +import org.apache.hadoop.fs.Path +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.catalyst.FileSourceOptions +import org.apache.spark.sql.catalyst.expressions.Expression +import org.apache.spark.sql.connector.read.PartitionReaderFactory +import org.apache.spark.sql.execution.datasources.PartitioningAwareFileIndex +import org.apache.spark.sql.execution.datasources.v2.FileScan +import org.apache.spark.sql.sources.Filter +import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.util.CaseInsensitiveStringMap +import org.apache.spark.util.SerializableConfiguration + +import scala.collection.JavaConverters._ + +case class GeoParquetMetadataScan( + sparkSession: SparkSession, + fileIndex: PartitioningAwareFileIndex, + dataSchema: StructType, + readDataSchema: StructType, + readPartitionSchema: StructType, + options: CaseInsensitiveStringMap, + pushedFilters: Array[Filter], + partitionFilters: Seq[Expression] = Seq.empty, + dataFilters: Seq[Expression] = Seq.empty) + extends FileScan { + override def createReaderFactory(): PartitionReaderFactory = { + val caseSensitiveMap = options.asCaseSensitiveMap.asScala.toMap + // Hadoop Configurations are case sensitive. + val hadoopConf = sparkSession.sessionState.newHadoopConfWithOptions(caseSensitiveMap) + val broadcastedConf = + sparkSession.sparkContext.broadcast(new SerializableConfiguration(hadoopConf)) + // The partition values are already truncated in `FileScan.partitions`. 
+ // We should use `readPartitionSchema` as the partition schema here. + val fileSourceOptions = new FileSourceOptions(caseSensitiveMap) + GeoParquetMetadataPartitionReaderFactory( + sparkSession.sessionState.conf, + broadcastedConf, + dataSchema, + readDataSchema, + readPartitionSchema, + fileSourceOptions, + pushedFilters) + } + + override def isSplitable(path: Path): Boolean = false + + override def getFileUnSplittableReason(path: Path): String = + "Reading parquet file metadata does not require splitting the file" +} diff --git a/spark/spark-4.1/src/main/scala/org/apache/spark/sql/execution/datasources/v2/geoparquet/metadata/GeoParquetMetadataScanBuilder.scala b/spark/spark-4.1/src/main/scala/org/apache/spark/sql/execution/datasources/v2/geoparquet/metadata/GeoParquetMetadataScanBuilder.scala new file mode 100644 index 00000000000..c60369e1087 --- /dev/null +++ b/spark/spark-4.1/src/main/scala/org/apache/spark/sql/execution/datasources/v2/geoparquet/metadata/GeoParquetMetadataScanBuilder.scala @@ -0,0 +1,47 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.spark.sql.execution.datasources.v2.geoparquet.metadata + +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.connector.read.Scan +import org.apache.spark.sql.execution.datasources.PartitioningAwareFileIndex +import org.apache.spark.sql.execution.datasources.v2.FileScanBuilder +import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.util.CaseInsensitiveStringMap + +class GeoParquetMetadataScanBuilder( + sparkSession: SparkSession, + fileIndex: PartitioningAwareFileIndex, + schema: StructType, + dataSchema: StructType, + options: CaseInsensitiveStringMap) + extends FileScanBuilder(sparkSession, fileIndex, dataSchema) { + override def build(): Scan = { + GeoParquetMetadataScan( + sparkSession, + fileIndex, + dataSchema, + readDataSchema(), + readPartitionSchema(), + options, + pushedDataFilters, + partitionFilters, + dataFilters) + } +} diff --git a/spark/spark-4.1/src/main/scala/org/apache/spark/sql/execution/datasources/v2/geoparquet/metadata/GeoParquetMetadataTable.scala b/spark/spark-4.1/src/main/scala/org/apache/spark/sql/execution/datasources/v2/geoparquet/metadata/GeoParquetMetadataTable.scala new file mode 100644 index 00000000000..845764fae55 --- /dev/null +++ b/spark/spark-4.1/src/main/scala/org/apache/spark/sql/execution/datasources/v2/geoparquet/metadata/GeoParquetMetadataTable.scala @@ -0,0 +1,70 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.spark.sql.execution.datasources.v2.geoparquet.metadata + +import org.apache.hadoop.fs.FileStatus +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.connector.catalog.TableCapability +import org.apache.spark.sql.connector.read.ScanBuilder +import org.apache.spark.sql.connector.write.{LogicalWriteInfo, WriteBuilder} +import org.apache.spark.sql.execution.datasources.FileFormat +import org.apache.spark.sql.execution.datasources.v2.FileTable +import org.apache.spark.sql.types._ +import org.apache.spark.sql.util.CaseInsensitiveStringMap + +case class GeoParquetMetadataTable( + name: String, + sparkSession: SparkSession, + options: CaseInsensitiveStringMap, + paths: Seq[String], + userSpecifiedSchema: Option[StructType], + fallbackFileFormat: Class[_ <: FileFormat]) + extends FileTable(sparkSession, options, paths, userSpecifiedSchema) { + override def formatName: String = "GeoParquet Metadata" + + override def inferSchema(files: Seq[FileStatus]): Option[StructType] = + Some(GeoParquetMetadataTable.schema) + + override def newScanBuilder(options: CaseInsensitiveStringMap): ScanBuilder = + new GeoParquetMetadataScanBuilder(sparkSession, fileIndex, schema, dataSchema, options) + + override def newWriteBuilder(info: LogicalWriteInfo): WriteBuilder = null + + override def capabilities: java.util.Set[TableCapability] = + java.util.EnumSet.of(TableCapability.BATCH_READ) +} + +object GeoParquetMetadataTable { + private val columnMetadataType = StructType( + Seq( + StructField("encoding", StringType, nullable = true), + StructField("geometry_types", ArrayType(StringType), nullable = true), + StructField("bbox", ArrayType(DoubleType), nullable = true), + StructField("crs", StringType, nullable = true), + StructField("covering", StringType, nullable = true))) + + private val columnsType = MapType(StringType, columnMetadataType, valueContainsNull = false) + + val schema: StructType = StructType( + Seq( + StructField("path", StringType, nullable = false), + StructField("version", StringType, nullable = true), + StructField("primary_column", StringType, nullable = true), + StructField("columns", columnsType, nullable = true))) +} diff --git a/spark/spark-4.1/src/main/scala/org/apache/spark/sql/execution/python/SedonaArrowEvalPythonExec.scala b/spark/spark-4.1/src/main/scala/org/apache/spark/sql/execution/python/SedonaArrowEvalPythonExec.scala new file mode 100644 index 00000000000..30a3e652cf2 --- /dev/null +++ b/spark/spark-4.1/src/main/scala/org/apache/spark/sql/execution/python/SedonaArrowEvalPythonExec.scala @@ -0,0 +1,114 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.spark.sql.execution.python + +import scala.jdk.CollectionConverters._ + +import org.apache.sedona.sql.UDF.PythonEvalType +import org.apache.spark.{JobArtifactSet, TaskContext} +import org.apache.spark.api.python.ChainedPythonFunctions +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.execution.SparkPlan +import org.apache.spark.sql.execution.metric.SQLMetric +import org.apache.spark.sql.execution.python.EvalPythonExec.ArgumentMetadata +import org.apache.spark.sql.types.StructType + +/** + * A physical plan that evaluates a [[PythonUDF]]. + */ +case class SedonaArrowEvalPythonExec( + udfs: Seq[PythonUDF], + resultAttrs: Seq[Attribute], + child: SparkPlan, + evalType: Int) + extends EvalPythonExec + with PythonSQLMetrics { + + private[this] val jobArtifactUUID = JobArtifactSet.getCurrentJobArtifactState.map(_.uuid) + + private[this] val sessionUUID = { + Option(session).collect { + case session if session.sessionState.conf.pythonWorkerLoggingEnabled => + session.sessionUUID + } + } + + override protected def evaluatorFactory: EvalPythonEvaluatorFactory = { + new SedonaArrowEvalPythonEvaluatorFactory( + child.output, + udfs, + output, + conf.arrowMaxRecordsPerBatch, + evalType, + conf.sessionLocalTimeZone, + conf.arrowUseLargeVarTypes, + ArrowPythonRunner.getPythonRunnerConfMap(conf), + pythonMetrics, + jobArtifactUUID, + sessionUUID, + conf.pythonUDFProfiler) + } + + override protected def withNewChildInternal(newChild: SparkPlan): SparkPlan = + copy(child = newChild) +} + +class SedonaArrowEvalPythonEvaluatorFactory( + childOutput: Seq[Attribute], + udfs: Seq[PythonUDF], + output: Seq[Attribute], + batchSize: Int, + evalType: Int, + sessionLocalTimeZone: String, + largeVarTypes: Boolean, + pythonRunnerConf: Map[String, String], + pythonMetrics: Map[String, SQLMetric], + jobArtifactUUID: Option[String], + sessionUUID: Option[String], + profiler: Option[String]) + extends EvalPythonEvaluatorFactory(childOutput, udfs, output) { + + override def evaluate( + funcs: Seq[(ChainedPythonFunctions, Long)], + argMetas: Array[Array[ArgumentMetadata]], + iter: Iterator[InternalRow], + schema: StructType, + context: TaskContext): Iterator[InternalRow] = { + val batchIter = Iterator(iter) + + val pyRunner = new ArrowPythonWithNamedArgumentRunner( + funcs, + evalType - PythonEvalType.SEDONA_UDF_TYPE_CONSTANT, + argMetas, + schema, + sessionLocalTimeZone, + largeVarTypes, + pythonRunnerConf, + pythonMetrics, + jobArtifactUUID, + sessionUUID, + profiler) with BatchedPythonArrowInput + val columnarBatchIter = pyRunner.compute(batchIter, context.partitionId(), context) + + columnarBatchIter.flatMap { batch => + batch.rowIterator.asScala + } + } +} diff --git a/spark/spark-4.1/src/main/scala/org/apache/spark/sql/udf/ExtractSedonaUDFRule.scala b/spark/spark-4.1/src/main/scala/org/apache/spark/sql/udf/ExtractSedonaUDFRule.scala 
new file mode 100644 index 00000000000..3d3301580cc --- /dev/null +++ b/spark/spark-4.1/src/main/scala/org/apache/spark/sql/udf/ExtractSedonaUDFRule.scala @@ -0,0 +1,185 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.spark.sql.udf + +import org.apache.sedona.sql.UDF.PythonEvalType +import org.apache.spark.internal.Logging +import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Expression, ExpressionSet, PythonUDF} +import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project, Subquery} +import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.catalyst.trees.TreePattern.PYTHON_UDF + +import scala.collection.mutable + +// That rule extracts scalar Python UDFs, currently Apache Spark has +// assert on types which blocks using the vectorized udfs with geometry type +class ExtractSedonaUDFRule extends Rule[LogicalPlan] with Logging { + + private def hasScalarPythonUDF(e: Expression): Boolean = { + e.exists(PythonUDF.isScalarPythonUDF) + } + + @scala.annotation.tailrec + private def canEvaluateInPython(e: PythonUDF): Boolean = { + e.children match { + case Seq(u: PythonUDF) => e.evalType == u.evalType && canEvaluateInPython(u) + case children => !children.exists(hasScalarPythonUDF) + } + } + + def isScalarPythonUDF(e: Expression): Boolean = { + e.isInstanceOf[PythonUDF] && e + .asInstanceOf[PythonUDF] + .evalType == PythonEvalType.SQL_SCALAR_SEDONA_UDF + } + + private def collectEvaluableUDFsFromExpressions( + expressions: Seq[Expression]): Seq[PythonUDF] = { + + var firstVisitedScalarUDFEvalType: Option[Int] = None + + def canChainUDF(evalType: Int): Boolean = { + evalType == firstVisitedScalarUDFEvalType.get + } + + def collectEvaluableUDFs(expr: Expression): Seq[PythonUDF] = expr match { + case udf: PythonUDF + if isScalarPythonUDF(udf) && canEvaluateInPython(udf) + && firstVisitedScalarUDFEvalType.isEmpty => + firstVisitedScalarUDFEvalType = Some(udf.evalType) + Seq(udf) + case udf: PythonUDF + if isScalarPythonUDF(udf) && canEvaluateInPython(udf) + && canChainUDF(udf.evalType) => + Seq(udf) + case e => e.children.flatMap(collectEvaluableUDFs) + } + + expressions.flatMap(collectEvaluableUDFs) + } + + private var hasFailedBefore: Boolean = false + + def apply(plan: LogicalPlan): LogicalPlan = plan match { + case s: Subquery if s.correlated => plan + + case _ => + try { + plan.transformUpWithPruning(_.containsPattern(PYTHON_UDF)) { + case p: SedonaArrowEvalPython => p + + case plan: LogicalPlan => extract(plan) + } + } catch { + case e: Throwable => + if (!hasFailedBefore) { + log.warn( + s"Vectorized UDF feature won't be available due to plan transformation error.") + log.warn( + s"Failed to extract Sedona UDFs from plan: ${plan.treeString}\n" + + s"Exception: 
${e.getMessage}", + e) + hasFailedBefore = true + } + plan + } + } + + private def canonicalizeDeterministic(u: PythonUDF) = { + if (u.deterministic) { + u.canonicalized.asInstanceOf[PythonUDF] + } else { + u + } + } + + private def extract(plan: LogicalPlan): LogicalPlan = { + val udfs = ExpressionSet(collectEvaluableUDFsFromExpressions(plan.expressions)) + .filter(udf => udf.references.subsetOf(plan.inputSet)) + .toSeq + .asInstanceOf[Seq[PythonUDF]] + + udfs match { + case Seq() => plan + case _ => resolveUDFs(plan, udfs) + } + } + + def resolveUDFs(plan: LogicalPlan, udfs: Seq[PythonUDF]): LogicalPlan = { + val attributeMap = mutable.HashMap[PythonUDF, Expression]() + + val newChildren = adjustAttributeMap(plan, udfs, attributeMap) + + udfs.map(canonicalizeDeterministic).filterNot(attributeMap.contains).foreach { udf => + throw new IllegalStateException( + s"Invalid PythonUDF $udf, requires attributes from more than one child.") + } + + val rewritten = plan.withNewChildren(newChildren).transformExpressions { case p: PythonUDF => + attributeMap.getOrElse(canonicalizeDeterministic(p), p) + } + + val newPlan = extract(rewritten) + if (newPlan.output != plan.output) { + Project(plan.output, newPlan) + } else { + newPlan + } + } + + def adjustAttributeMap( + plan: LogicalPlan, + udfs: Seq[PythonUDF], + attributeMap: mutable.HashMap[PythonUDF, Expression]): Seq[LogicalPlan] = { + plan.children.map { child => + val validUdfs = udfs.filter { udf => + udf.references.subsetOf(child.outputSet) + } + + if (validUdfs.nonEmpty) { + require( + validUdfs.forall(isScalarPythonUDF), + "Can only extract scalar vectorized udf or sql batch udf") + + val resultAttrs = validUdfs.zipWithIndex.map { case (u, i) => + AttributeReference(s"pythonUDF$i", u.dataType)() + } + + val evalTypes = validUdfs.map(_.evalType).toSet + if (evalTypes.size != 1) { + throw new IllegalStateException( + "Expected udfs have the same evalType but got different evalTypes: " + + evalTypes.mkString(",")) + } + val evalType = evalTypes.head + val evaluation = evalType match { + case PythonEvalType.SQL_SCALAR_SEDONA_UDF => + SedonaArrowEvalPython(validUdfs, resultAttrs, child, evalType) + case _ => + throw new IllegalStateException("Unexpected UDF evalType") + } + + attributeMap ++= validUdfs.map(canonicalizeDeterministic).zip(resultAttrs) + evaluation + } else { + child + } + } + } +} diff --git a/spark/spark-4.1/src/main/scala/org/apache/spark/sql/udf/SedonaArrowEvalPython.scala b/spark/spark-4.1/src/main/scala/org/apache/spark/sql/udf/SedonaArrowEvalPython.scala new file mode 100644 index 00000000000..7600ece5079 --- /dev/null +++ b/spark/spark-4.1/src/main/scala/org/apache/spark/sql/udf/SedonaArrowEvalPython.scala @@ -0,0 +1,32 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.spark.sql.udf + +import org.apache.spark.sql.catalyst.expressions.{Attribute, PythonUDF} +import org.apache.spark.sql.catalyst.plans.logical.{BaseEvalPython, LogicalPlan} + +case class SedonaArrowEvalPython( + udfs: Seq[PythonUDF], + resultAttrs: Seq[Attribute], + child: LogicalPlan, + evalType: Int) + extends BaseEvalPython { + override protected def withNewChildInternal(newChild: LogicalPlan): SedonaArrowEvalPython = + copy(child = newChild) +} diff --git a/spark/spark-4.1/src/main/scala/org/apache/spark/sql/udf/SedonaArrowStrategy.scala b/spark/spark-4.1/src/main/scala/org/apache/spark/sql/udf/SedonaArrowStrategy.scala new file mode 100644 index 00000000000..21e389ee567 --- /dev/null +++ b/spark/spark-4.1/src/main/scala/org/apache/spark/sql/udf/SedonaArrowStrategy.scala @@ -0,0 +1,43 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.spark.sql.udf + +import org.apache.sedona.sql.UDF.PythonEvalType +import org.apache.spark.api.python.ChainedPythonFunctions +import org.apache.spark.{JobArtifactSet, TaskContext} +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.{Attribute, PythonUDF} +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.execution.{SparkPlan, SparkStrategy} +import org.apache.spark.sql.execution.metric.SQLMetric +import org.apache.spark.sql.execution.python.SedonaArrowEvalPythonExec +import org.apache.spark.sql.types.StructType + +import scala.collection.JavaConverters.asScalaIteratorConverter + +// We use custom Strategy to avoid Apache Spark assert on types, we +// can consider extending this to support other engines working with +// arrow data +class SedonaArrowStrategy extends SparkStrategy { + override def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { + case SedonaArrowEvalPython(udfs, output, child, evalType) => + SedonaArrowEvalPythonExec(udfs, output, planLater(child), evalType) :: Nil + case _ => Nil + } +} diff --git a/spark/spark-4.1/src/test/resources/log4j2.properties b/spark/spark-4.1/src/test/resources/log4j2.properties new file mode 100644 index 00000000000..683ecd32f25 --- /dev/null +++ b/spark/spark-4.1/src/test/resources/log4j2.properties @@ -0,0 +1,31 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# Set everything to be logged to the file target/unit-tests.log +rootLogger.level = info +rootLogger.appenderRef.file.ref = File + +appender.file.type = File +appender.file.name = File +appender.file.fileName = target/unit-tests.log +appender.file.append = true +appender.file.layout.type = PatternLayout +appender.file.layout.pattern = %d{yy/MM/dd HH:mm:ss.SSS} %t %p %c{1}: %m%n%ex + +# Ignore messages below warning level from Jetty, because it's a bit verbose +logger.jetty.name = org.sparkproject.jetty +logger.jetty.level = warn diff --git a/spark/spark-4.1/src/test/scala/org/apache/sedona/sql/GeoPackageReaderTest.scala b/spark/spark-4.1/src/test/scala/org/apache/sedona/sql/GeoPackageReaderTest.scala new file mode 100644 index 00000000000..66fa147bc58 --- /dev/null +++ b/spark/spark-4.1/src/test/scala/org/apache/sedona/sql/GeoPackageReaderTest.scala @@ -0,0 +1,369 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.sedona.sql + +import io.minio.{MakeBucketArgs, MinioClient, PutObjectArgs} +import org.apache.spark.sql.{DataFrame, SparkSession} +import org.apache.spark.sql.functions.expr +import org.apache.spark.sql.sedona_sql.UDT.GeometryUDT +import org.apache.spark.sql.types.{BinaryType, BooleanType, DateType, DoubleType, IntegerType, StringType, StructField, StructType, TimestampType} +import org.scalatest.matchers.should.Matchers +import org.scalatest.prop.TableDrivenPropertyChecks._ +import org.testcontainers.containers.MinIOContainer + +import java.io.FileInputStream +import java.sql.{Date, Timestamp} +import java.util.TimeZone + +class GeoPackageReaderTest extends TestBaseScala with Matchers { + TimeZone.setDefault(TimeZone.getTimeZone("UTC")) + import sparkSession.implicits._ + + val path: String = resourceFolder + "geopackage/example.gpkg" + val polygonsPath: String = resourceFolder + "geopackage/features.gpkg" + val rasterPath: String = resourceFolder + "geopackage/raster.gpkg" + val wktReader = new org.locationtech.jts.io.WKTReader() + val wktWriter = new org.locationtech.jts.io.WKTWriter() + + val expectedFeatureSchema = StructType( + Seq( + StructField("id", IntegerType, true), + StructField("geometry", GeometryUDT(), true), + StructField("text", StringType, true), + StructField("real", DoubleType, true), + StructField("boolean", BooleanType, true), + StructField("blob", BinaryType, true), + StructField("integer", IntegerType, true), + StructField("text_limited", StringType, true), + StructField("blob_limited", BinaryType, true), + StructField("date", DateType, true), + StructField("datetime", TimestampType, true))) + + describe("Reading GeoPackage metadata") { + it("should read GeoPackage metadata") { + val df = sparkSession.read + .format("geopackage") + .option("showMetadata", "true") + .load(path) + + df.where("data_type = 'tiles'").show(false) + + df.count shouldEqual 34 + } + } + + describe("Reading Vector data") { + it("should read GeoPackage - point1") { + val df = readFeatureData("point1") + df.schema shouldEqual expectedFeatureSchema + + df.count() shouldEqual 4 + + val firstElement = df.collectAsList().get(0).toSeq + + val expectedValues = Seq( + 1, + wktReader.read(POINT_1), + "BIT Systems", + 4519.866024037493, + true, + Array(48, 99, 57, 54, 49, 56, 55, 54, 45, 98, 102, 100, 52, 45, 52, 102, 52, 48, 45, 97, + 49, 102, 101, 45, 55, 49, 55, 101, 57, 100, 50, 98, 48, 55, 98, 101), + 3, + "bcd5a36f-16dc-4385-87be-b40353848597", + Array(49, 50, 53, 50, 97, 99, 98, 52, 45, 57, 54, 54, 52, 45, 52, 101, 51, 50, 45, 57, 54, + 100, 101, 45, 56, 48, 54, 101, 101, 48, 101, 101, 49, 102, 57, 48), + Date.valueOf("2023-09-19"), + Timestamp.valueOf("2023-09-19 11:24:15.695")) + + firstElement should contain theSameElementsAs expectedValues + } + + it("should read GeoPackage - line1") { + val df = readFeatureData("line1") + .withColumn("datetime", expr("from_utc_timestamp(datetime, 'UTC')")) + + df.schema shouldEqual expectedFeatureSchema + + df.count() shouldEqual 3 + + val firstElement = df.collectAsList().get(0).toSeq + + firstElement should contain theSameElementsAs Seq( + 1, + wktReader.read(LINESTRING_1), + "East Lockheed Drive", + 1990.5159635296877, + false, + Array(54, 97, 98, 100, 98, 51, 97, 56, 45, 54, 53, 101, 48, 45, 52, 55, 48, 54, 45, 56, + 50, 52, 48, 45, 51, 57, 48, 55, 99, 50, 102, 102, 57, 48, 99, 55), + 1, + "13dd91dc-3b7d-4d8d-a0ca-b3afb8e31c3d", + Array(57, 54, 98, 102, 56, 99, 101, 56, 45, 102, 48, 54, 49, 45, 52, 55, 99, 48, 45, 97, + 98, 48, 
101, 45, 97, 99, 50, 52, 100, 98, 50, 97, 102, 50, 50, 54), + Date.valueOf("2023-09-19"), + Timestamp.valueOf("2023-09-19 11:24:15.716")) + } + + it("should read GeoPackage - polygon1") { + val df = readFeatureData("polygon1") + df.count shouldEqual 3 + df.schema shouldEqual expectedFeatureSchema + + df.select("geometry").collectAsList().get(0).toSeq should contain theSameElementsAs Seq( + wktReader.read(POLYGON_1)) + } + + it("should read GeoPackage - geometry1") { + val df = readFeatureData("geometry1") + df.count shouldEqual 10 + df.schema shouldEqual expectedFeatureSchema + + df.selectExpr("ST_ASTEXT(geometry)") + .as[String] + .collect() should contain theSameElementsAs Seq( + POINT_1, + POINT_2, + POINT_3, + POINT_4, + LINESTRING_1, + LINESTRING_2, + LINESTRING_3, + POLYGON_1, + POLYGON_2, + POLYGON_3) + } + + it("should read polygon with envelope data") { + val tables = Table( + ("tableName", "expectedCount"), + ("GB_Hex_5km_GS_CompressibleGround_v8", 4233), + ("GB_Hex_5km_GS_Landslides_v8", 4228), + ("GB_Hex_5km_GS_RunningSand_v8", 4233), + ("GB_Hex_5km_GS_ShrinkSwell_v8", 4233), + ("GB_Hex_5km_GS_SolubleRocks_v8", 4295)) + + forAll(tables) { (tableName: String, expectedCount: Int) => + val df = sparkSession.read + .format("geopackage") + .option("tableName", tableName) + .load(polygonsPath) + + df.count() shouldEqual expectedCount + } + } + + it("should handle datetime fields without timezone information") { + // This test verifies the fix for DateTimeParseException when reading + // GeoPackage files with datetime fields that don't include timezone info + val testFilePath = resourceFolder + "geopackage/test_datetime_issue.gpkg" + + // Test reading the test_features table with problematic datetime formats + val df = sparkSession.read + .format("geopackage") + .option("tableName", "test_features") + .load(testFilePath) + + // The test should not throw DateTimeParseException when reading datetime fields + noException should be thrownBy { + df.select("created_at", "updated_at").collect() + } + + // Verify that datetime fields are properly parsed as TimestampType + df.schema.fields.find(_.name == "created_at").get.dataType shouldEqual TimestampType + df.schema.fields.find(_.name == "updated_at").get.dataType shouldEqual TimestampType + + // Verify that we can read the datetime values + val datetimeValues = df.select("created_at", "updated_at").collect() + datetimeValues should not be empty + + // Verify that datetime values are valid timestamps + datetimeValues.foreach { row => + val createdTimestamp = row.getAs[Timestamp]("created_at") + val updatedTimestamp = row.getAs[Timestamp]("updated_at") + createdTimestamp should not be null + updatedTimestamp should not be null + createdTimestamp.getTime should be > 0L + updatedTimestamp.getTime should be > 0L + } + + // Test showMetadata option with the same file + noException should be thrownBy { + val metadataDf = sparkSession.read + .format("geopackage") + .option("showMetadata", "true") + .load(testFilePath) + metadataDf.select("last_change").collect() + } + } + } + + describe("GeoPackage Raster Data Test") { + it("should read") { + val fractions = + Table( + ("tableName", "channelNumber", "expectedSum"), + ("point1_tiles", 4, 466591.0), + ("line1_tiles", 4, 5775976.0), + ("polygon1_tiles", 4, 1.1269871e7), + ("geometry1_tiles", 4, 2.6328442e7), + ("point2_tiles", 4, 137456.0), + ("line2_tiles", 4, 6701101.0), + ("polygon2_tiles", 4, 5.1170714e7), + ("geometry2_tiles", 4, 1.6699823e7), + ("bit_systems", 1, 6.5561879e7), + ("nga", 1, 
6.8078856e7), + ("bit_systems_wgs84", 1, 7.7276934e7), + ("nga_pc", 1, 2.90590616e8), + ("bit_systems_world", 1, 7.7276934e7), + ("nga_pc_world", 1, 2.90590616e8)) + + forAll(fractions) { (tableName: String, channelNumber: Int, expectedSum: Double) => + { + val df = readFeatureData(tableName) + val calculatedSum = df + .selectExpr(s"RS_SummaryStats(tile_data, 'sum', ${channelNumber}) as stats") + .selectExpr("sum(stats)") + .as[Double] + + calculatedSum.collect().head shouldEqual expectedSum + } + } + } + + it("should be able to read complex raster data") { + val df = sparkSession.read + .format("geopackage") + .option("tableName", "AuroraAirportNoise") + .load(rasterPath) + + df.show(5) + + val calculatedSum = df + .selectExpr(s"RS_SummaryStats(tile_data, 'sum', ${1}) as stats") + .selectExpr("sum(stats)") + .as[Double] + + calculatedSum.first() shouldEqual 2.027126e7 + + val df2 = sparkSession.read + .format("geopackage") + .option("tableName", "LiquorLicenseDensity") + .load(rasterPath) + + val calculatedSum2 = df2 + .selectExpr(s"RS_SummaryStats(tile_data, 'sum', ${1}) as stats") + .selectExpr("sum(stats)") + .as[Double] + + calculatedSum2.first() shouldEqual 2.882028e7 + } + + } + + describe("Reading from S3") { + it("should be able to read files from S3") { + val container = new MinIOContainer("minio/minio:latest") + + container.start() + + val minioClient = createMinioClient(container) + val makeBucketRequest = MakeBucketArgs + .builder() + .bucket("sedona") + .build() + + minioClient.makeBucket(makeBucketRequest) + + adjustSparkSession(sparkSessionMinio, container) + + val inputPath: String = prepareFile("example.geopackage", path, minioClient) + + sparkSessionMinio.read + .format("geopackage") + .option("showMetadata", "true") + .load(inputPath) + .count shouldEqual 34 + + val df = sparkSession.read + .format("geopackage") + .option("tableName", "point1") + .load(inputPath) + + df.count shouldEqual 4 + + val inputPathLarger: String = prepareFiles((1 to 300).map(_ => path).toArray, minioClient) + + val dfLarger = sparkSessionMinio.read + .format("geopackage") + .option("tableName", "point1") + .load(inputPathLarger) + + dfLarger.count shouldEqual 300 * 4 + + container.stop() + } + } + + private def readFeatureData(tableName: String): DataFrame = { + sparkSession.read + .format("geopackage") + .option("tableName", tableName) + .load(path) + } + + private def prepareFiles(paths: Array[String], minioClient: MinioClient): String = { + val key = "geopackage" + + paths.foreach(path => { + val fis = new FileInputStream(path); + putFileIntoBucket( + "sedona", + s"${key}/${scala.util.Random.nextInt(1000000000)}.geopackage", + fis, + minioClient) + }) + + s"s3a://sedona/$key" + } + + private def prepareFile(name: String, path: String, minioClient: MinioClient): String = { + val fis = new FileInputStream(path); + putFileIntoBucket("sedona", name, fis, minioClient) + + s"s3a://sedona/$name" + } + + private val POINT_1 = "POINT (-104.801918 39.720014)" + private val POINT_2 = "POINT (-104.802987 39.717703)" + private val POINT_3 = "POINT (-104.807496 39.714085)" + private val POINT_4 = "POINT (-104.79948 39.714729)" + private val LINESTRING_1 = + "LINESTRING (-104.800614 39.720721, -104.802174 39.720726, -104.802584 39.72066, -104.803088 39.720477, -104.803474 39.720209)" + private val LINESTRING_2 = + "LINESTRING (-104.809612 39.718379, -104.806638 39.718372, -104.806236 39.718439, -104.805939 39.718536, -104.805654 39.718677, -104.803652 39.720095)" + private val LINESTRING_3 = + "LINESTRING 
(-104.806344 39.722425, -104.805854 39.722634, -104.805656 39.722647, -104.803749 39.722641, -104.803769 39.721849, -104.803806 39.721725, -104.804382 39.720865)" + private val POLYGON_1 = + "POLYGON ((-104.802246 39.720343, -104.802246 39.719753, -104.802183 39.719754, -104.802184 39.719719, -104.802138 39.719694, -104.802097 39.719691, -104.802096 39.719648, -104.801646 39.719648, -104.801644 39.719722, -104.80155 39.719723, -104.801549 39.720207, -104.801648 39.720207, -104.801648 39.720341, -104.802246 39.720343))" + private val POLYGON_2 = + "POLYGON ((-104.802259 39.719604, -104.80226 39.71955, -104.802281 39.719416, -104.802332 39.719372, -104.802081 39.71924, -104.802044 39.71929, -104.802027 39.719278, -104.802044 39.719229, -104.801785 39.719129, -104.801639 39.719413, -104.801649 39.719472, -104.801694 39.719524, -104.801753 39.71955, -104.80175 39.719606, -104.80194 39.719606, -104.801939 39.719555, -104.801977 39.719556, -104.801979 39.719606, -104.802259 39.719604), (-104.80213 39.71944, -104.802133 39.71949, -104.802148 39.71949, -104.80218 39.719473, -104.802187 39.719456, -104.802182 39.719439, -104.802088 39.719387, -104.802047 39.719427, -104.801858 39.719342, -104.801883 39.719294, -104.801832 39.719284, -104.801787 39.719298, -104.801763 39.719331, -104.801823 39.719352, -104.80179 39.71942, -104.801722 39.719404, -104.801715 39.719445, -104.801748 39.719484, -104.801809 39.719494, -104.801816 39.719439, -104.80213 39.71944))" + private val POLYGON_3 = + "POLYGON ((-104.802867 39.718122, -104.802369 39.717845, -104.802571 39.71763, -104.803066 39.717909, -104.802867 39.718122))" +} diff --git a/spark/spark-4.1/src/test/scala/org/apache/sedona/sql/GeoParquetMetadataTests.scala b/spark/spark-4.1/src/test/scala/org/apache/sedona/sql/GeoParquetMetadataTests.scala new file mode 100644 index 00000000000..01306c1b452 --- /dev/null +++ b/spark/spark-4.1/src/test/scala/org/apache/sedona/sql/GeoParquetMetadataTests.scala @@ -0,0 +1,152 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.sedona.sql + +import org.apache.spark.sql.Row +import org.apache.spark.sql.sedona_sql.UDT.GeometryUDT +import org.apache.spark.sql.types.{IntegerType, StructField, StructType} +import org.scalatest.BeforeAndAfterAll + +import java.util.Collections +import scala.collection.JavaConverters._ + +class GeoParquetMetadataTests extends TestBaseScala with BeforeAndAfterAll { + val geoparquetdatalocation: String = resourceFolder + "geoparquet/" + val geoparquetoutputlocation: String = resourceFolder + "geoparquet/geoparquet_output/" + + describe("GeoParquet Metadata tests") { + it("Reading GeoParquet Metadata") { + val df = sparkSession.read.format("geoparquet.metadata").load(geoparquetdatalocation) + val metadataArray = df.collect() + assert(metadataArray.length > 1) + assert(metadataArray.exists(_.getAs[String]("path").endsWith(".parquet"))) + assert(metadataArray.exists(_.getAs[String]("version") == "1.0.0-dev")) + assert(metadataArray.exists(_.getAs[String]("primary_column") == "geometry")) + assert(metadataArray.exists { row => + val columnsMap = row.getJavaMap(row.fieldIndex("columns")) + columnsMap != null && columnsMap + .containsKey("geometry") && columnsMap.get("geometry").isInstanceOf[Row] + }) + assert(metadataArray.forall { row => + val columnsMap = row.getJavaMap(row.fieldIndex("columns")) + if (columnsMap == null || !columnsMap.containsKey("geometry")) true + else { + val columnMetadata = columnsMap.get("geometry").asInstanceOf[Row] + columnMetadata.getAs[String]("encoding") == "WKB" && + columnMetadata + .getList[Any](columnMetadata.fieldIndex("bbox")) + .asScala + .forall(_.isInstanceOf[Double]) && + columnMetadata + .getList[Any](columnMetadata.fieldIndex("geometry_types")) + .asScala + .forall(_.isInstanceOf[String]) && + columnMetadata.getAs[String]("crs").nonEmpty && + columnMetadata.getAs[String]("crs") != "null" + } + }) + } + + it("Reading GeoParquet Metadata with column pruning") { + val df = sparkSession.read.format("geoparquet.metadata").load(geoparquetdatalocation) + val metadataArray = df + .selectExpr("path", "substring(primary_column, 1, 2) AS partial_primary_column") + .collect() + assert(metadataArray.length > 1) + assert(metadataArray.forall(_.length == 2)) + assert(metadataArray.exists(_.getAs[String]("path").endsWith(".parquet"))) + assert(metadataArray.exists(_.getAs[String]("partial_primary_column") == "ge")) + } + + it("Reading GeoParquet Metadata of plain parquet files") { + val df = sparkSession.read.format("geoparquet.metadata").load(geoparquetdatalocation) + val metadataArray = df.where("path LIKE '%plain.parquet'").collect() + assert(metadataArray.nonEmpty) + assert(metadataArray.forall(_.getAs[String]("path").endsWith("plain.parquet"))) + assert(metadataArray.forall(_.getAs[String]("version") == null)) + assert(metadataArray.forall(_.getAs[String]("primary_column") == null)) + assert(metadataArray.forall(_.getAs[String]("columns") == null)) + } + + it("Read GeoParquet without CRS") { + val df = sparkSession.read + .format("geoparquet") + .load(geoparquetdatalocation + "/example-1.0.0-beta.1.parquet") + val geoParquetSavePath = geoparquetoutputlocation + "/gp_crs_omit.parquet" + df.write + .format("geoparquet") + .option("geoparquet.crs", "") + .mode("overwrite") + .save(geoParquetSavePath) + val dfMeta = sparkSession.read.format("geoparquet.metadata").load(geoParquetSavePath) + val row = dfMeta.collect()(0) + val metadata = row.getJavaMap(row.fieldIndex("columns")).get("geometry").asInstanceOf[Row] + 
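// The empty "geoparquet.crs" option should be written through as an empty CRS string in the column metadata +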
assert(metadata.getAs[String]("crs") == "") + } + + it("Read GeoParquet with null CRS") { + val df = sparkSession.read + .format("geoparquet") + .load(geoparquetdatalocation + "/example-1.0.0-beta.1.parquet") + val geoParquetSavePath = geoparquetoutputlocation + "/gp_crs_null.parquet" + df.write + .format("geoparquet") + .option("geoparquet.crs", "null") + .mode("overwrite") + .save(geoParquetSavePath) + val dfMeta = sparkSession.read.format("geoparquet.metadata").load(geoParquetSavePath) + val row = dfMeta.collect()(0) + val metadata = row.getJavaMap(row.fieldIndex("columns")).get("geometry").asInstanceOf[Row] + assert(metadata.getAs[String]("crs") == "null") + } + + it("Read GeoParquet with snake_case geometry column name and camelCase column name") { + val schema = StructType( + Seq( + StructField("id", IntegerType, nullable = false), + StructField("geom_column_1", GeometryUDT(), nullable = false), + StructField("geomColumn2", GeometryUDT(), nullable = false))) + val df = sparkSession.createDataFrame(Collections.emptyList[Row](), schema) + val geoParquetSavePath = geoparquetoutputlocation + "/gp_column_name_styles.parquet" + df.write.format("geoparquet").mode("overwrite").save(geoParquetSavePath) + + val dfMeta = sparkSession.read.format("geoparquet.metadata").load(geoParquetSavePath) + val row = dfMeta.collect()(0) + val metadata = row.getJavaMap(row.fieldIndex("columns")) + assert(metadata.containsKey("geom_column_1")) + assert(!metadata.containsKey("geoColumn1")) + assert(metadata.containsKey("geomColumn2")) + assert(!metadata.containsKey("geom_column2")) + assert(!metadata.containsKey("geom_column_2")) + } + + it("Read GeoParquet with covering metadata") { + val dfMeta = sparkSession.read + .format("geoparquet.metadata") + .load(geoparquetdatalocation + "/example-1.1.0.parquet") + val row = dfMeta.collect()(0) + val metadata = row.getJavaMap(row.fieldIndex("columns")).get("geometry").asInstanceOf[Row] + val covering = metadata.getAs[String]("covering") + assert(covering.nonEmpty) + Seq("bbox", "xmin", "ymin", "xmax", "ymax").foreach { key => + assert(covering contains key) + } + } + } +} diff --git a/spark/spark-4.1/src/test/scala/org/apache/sedona/sql/SQLSyntaxTestScala.scala b/spark/spark-4.1/src/test/scala/org/apache/sedona/sql/SQLSyntaxTestScala.scala new file mode 100644 index 00000000000..6f873d0a087 --- /dev/null +++ b/spark/spark-4.1/src/test/scala/org/apache/sedona/sql/SQLSyntaxTestScala.scala @@ -0,0 +1,72 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.sedona.sql + +import org.scalatest.matchers.must.Matchers.be +import org.scalatest.matchers.should.Matchers.convertToAnyShouldWrapper +import org.scalatest.prop.TableDrivenPropertyChecks + +/** + * Test suite for testing Sedona SQL support. 
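+ * The parser extension flag is randomized in TestBaseScala, so the GEOMETRY-column DDL tests assert either a successful CREATE TABLE or a ParseException, whichever matches the flag.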
+ */ +class SQLSyntaxTestScala extends TestBaseScala with TableDrivenPropertyChecks { + + override def beforeAll(): Unit = { + super.beforeAll() + sparkSession.conf.set("spark.sql.legacy.createHiveTableByDefault", "false") + } + + describe("Table creation DDL tests") { + + it("should be able to create a regular table without a geometry column") { + sparkSession.sql("DROP TABLE IF EXISTS T_TEST_REGULAR") + sparkSession.sql("CREATE TABLE IF NOT EXISTS T_TEST_REGULAR (INT_COL INT)") + sparkSession.catalog.tableExists("T_TEST_REGULAR") should be(true) + sparkSession.sql("DROP TABLE IF EXISTS T_TEST_REGULAR") + sparkSession.catalog.tableExists("T_TEST_REGULAR") should be(false) + } + + it( + "should be able to create a regular table with a geometry column without a workaround") { + try { + sparkSession.sql("CREATE TABLE T_TEST_EXPLICIT_GEOMETRY (GEO_COL GEOMETRY)") + sparkSession.catalog.tableExists("T_TEST_EXPLICIT_GEOMETRY") should be(true) + sparkSession.sparkContext.getConf.get(keyParserExtension) should be("true") + } catch { + case ex: Exception => + ex.getClass.getName.endsWith("ParseException") should be(true) + sparkSession.sparkContext.getConf.get(keyParserExtension) should be("false") + } + } + + it( + "should be able to create a regular table with both regular and geometry columns without a workaround") { + try { + sparkSession.sql( + "CREATE TABLE T_TEST_EXPLICIT_GEOMETRY_2 (INT_COL INT, GEO_COL GEOMETRY)") + sparkSession.catalog.tableExists("T_TEST_EXPLICIT_GEOMETRY_2") should be(true) + sparkSession.sparkContext.getConf.get(keyParserExtension) should be("true") + } catch { + case ex: Exception => + ex.getClass.getName.endsWith("ParseException") should be(true) + sparkSession.sparkContext.getConf.get(keyParserExtension) should be("false") + } + } + } +} diff --git a/spark/spark-4.1/src/test/scala/org/apache/sedona/sql/ShapefileTests.scala b/spark/spark-4.1/src/test/scala/org/apache/sedona/sql/ShapefileTests.scala new file mode 100644 index 00000000000..275bc3282f9 --- /dev/null +++ b/spark/spark-4.1/src/test/scala/org/apache/sedona/sql/ShapefileTests.scala @@ -0,0 +1,784 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.sedona.sql + +import org.apache.commons.io.FileUtils +import org.apache.spark.sql.sedona_sql.UDT.GeometryUDT +import org.apache.spark.sql.types.{DateType, DecimalType, LongType, StringType, StructField, StructType} +import org.locationtech.jts.geom.{Geometry, MultiPolygon, Point, Polygon} +import org.locationtech.jts.io.{WKTReader, WKTWriter} +import org.scalatest.BeforeAndAfterAll + +import java.io.File +import java.nio.file.Files +import scala.collection.mutable + +class ShapefileTests extends TestBaseScala with BeforeAndAfterAll { + val temporaryLocation: String = resourceFolder + "shapefiles/tmp" + + override def beforeAll(): Unit = { + super.beforeAll() + FileUtils.deleteDirectory(new File(temporaryLocation)) + Files.createDirectory(new File(temporaryLocation).toPath) + } + + override def afterAll(): Unit = FileUtils.deleteDirectory(new File(temporaryLocation)) + + describe("Shapefile read tests") { + it("read gis_osm_pois_free_1") { + val shapefileDf = sparkSession.read + .format("shapefile") + .load(resourceFolder + "shapefiles/gis_osm_pois_free_1") + val schema = shapefileDf.schema + assert(schema.find(_.name == "geometry").get.dataType == GeometryUDT) + assert(schema.find(_.name == "osm_id").get.dataType == StringType) + assert(schema.find(_.name == "code").get.dataType == LongType) + assert(schema.find(_.name == "fclass").get.dataType == StringType) + assert(schema.find(_.name == "name").get.dataType == StringType) + assert(schema.length == 5) + assert(shapefileDf.count == 12873) + + shapefileDf.collect().foreach { row => + val geom = row.getAs[Geometry]("geometry") + assert(geom.isInstanceOf[Point]) + assert(geom.getSRID == 4326) + assert(row.getAs[String]("osm_id").nonEmpty) + assert(row.getAs[Long]("code") > 0) + assert(row.getAs[String]("fclass").nonEmpty) + assert(row.getAs[String]("name") != null) + } + + // with projection, selecting geometry and attribute fields + shapefileDf.select("geometry", "code").take(10).foreach { row => + assert(row.getAs[Geometry]("geometry").isInstanceOf[Point]) + assert(row.getAs[Long]("code") > 0) + } + + // with projection, selecting geometry fields + shapefileDf.select("geometry").take(10).foreach { row => + assert(row.getAs[Geometry]("geometry").isInstanceOf[Point]) + } + + // with projection, selecting attribute fields + shapefileDf.select("code", "osm_id").take(10).foreach { row => + assert(row.getAs[Long]("code") > 0) + assert(row.getAs[String]("osm_id").nonEmpty) + } + + // with transformation + shapefileDf + .selectExpr("ST_Buffer(geometry, 0.001) AS geom", "code", "osm_id as id") + .take(10) + .foreach { row => + assert(row.getAs[Geometry]("geom").isInstanceOf[Polygon]) + assert(row.getAs[Long]("code") > 0) + assert(row.getAs[String]("id").nonEmpty) + } + } + + it("read dbf") { + val shapefileDf = sparkSession.read + .format("shapefile") + .load(resourceFolder + "shapefiles/dbf") + val schema = shapefileDf.schema + assert(schema.find(_.name == "geometry").get.dataType == GeometryUDT) + assert(schema.find(_.name == "STATEFP").get.dataType == StringType) + assert(schema.find(_.name == "COUNTYFP").get.dataType == StringType) + assert(schema.find(_.name == "COUNTYNS").get.dataType == StringType) + assert(schema.find(_.name == "AFFGEOID").get.dataType == StringType) + assert(schema.find(_.name == "GEOID").get.dataType == StringType) + assert(schema.find(_.name == "NAME").get.dataType == StringType) + assert(schema.find(_.name == "LSAD").get.dataType == StringType) + assert(schema.find(_.name == 
"ALAND").get.dataType == LongType) + assert(schema.find(_.name == "AWATER").get.dataType == LongType) + assert(schema.length == 10) + assert(shapefileDf.count() == 3220) + + shapefileDf.collect().foreach { row => + val geom = row.getAs[Geometry]("geometry") + assert(geom.getSRID == 0) + assert(geom.isInstanceOf[Polygon] || geom.isInstanceOf[MultiPolygon]) + assert(row.getAs[String]("STATEFP").nonEmpty) + assert(row.getAs[String]("COUNTYFP").nonEmpty) + assert(row.getAs[String]("COUNTYNS").nonEmpty) + assert(row.getAs[String]("AFFGEOID").nonEmpty) + assert(row.getAs[String]("GEOID").nonEmpty) + assert(row.getAs[String]("NAME").nonEmpty) + assert(row.getAs[String]("LSAD").nonEmpty) + assert(row.getAs[Long]("ALAND") > 0) + assert(row.getAs[Long]("AWATER") >= 0) + } + } + + it("read multipleshapefiles") { + val shapefileDf = sparkSession.read + .format("shapefile") + .load(resourceFolder + "shapefiles/multipleshapefiles") + val schema = shapefileDf.schema + assert(schema.find(_.name == "geometry").get.dataType == GeometryUDT) + assert(schema.find(_.name == "STATEFP").get.dataType == StringType) + assert(schema.find(_.name == "COUNTYFP").get.dataType == StringType) + assert(schema.find(_.name == "COUNTYNS").get.dataType == StringType) + assert(schema.find(_.name == "AFFGEOID").get.dataType == StringType) + assert(schema.find(_.name == "GEOID").get.dataType == StringType) + assert(schema.find(_.name == "NAME").get.dataType == StringType) + assert(schema.find(_.name == "LSAD").get.dataType == StringType) + assert(schema.find(_.name == "ALAND").get.dataType == LongType) + assert(schema.find(_.name == "AWATER").get.dataType == LongType) + assert(schema.length == 10) + assert(shapefileDf.count() == 3220) + } + + it("read missing") { + val shapefileDf = sparkSession.read + .format("shapefile") + .load(resourceFolder + "shapefiles/missing") + val schema = shapefileDf.schema + assert(schema.find(_.name == "geometry").get.dataType == GeometryUDT) + assert(schema.find(_.name == "id").get.dataType == LongType) + assert(schema.find(_.name == "a").get.dataType == StringType) + assert(schema.find(_.name == "b").get.dataType == StringType) + assert(schema.find(_.name == "c").get.dataType == StringType) + assert(schema.find(_.name == "d").get.dataType == StringType) + assert(schema.find(_.name == "e").get.dataType == StringType) + assert(schema.length == 7) + val rows = shapefileDf.collect() + assert(rows.length == 3) + rows.foreach { row => + val a = row.getAs[String]("a") + val b = row.getAs[String]("b") + val c = row.getAs[String]("c") + val d = row.getAs[String]("d") + val e = row.getAs[String]("e") + if (a.isEmpty) { + assert(b == "First") + assert(c == "field") + assert(d == "is") + assert(e == "empty") + } else if (e.isEmpty) { + assert(a == "Last") + assert(b == "field") + assert(c == "is") + assert(d == "empty") + } else { + assert(a == "Are") + assert(b == "fields") + assert(c == "are") + assert(d == "not") + assert(e == "empty") + } + } + } + + it("read unsupported") { + val shapefileDf = sparkSession.read + .format("shapefile") + .load(resourceFolder + "shapefiles/unsupported") + val schema = shapefileDf.schema + assert(schema.find(_.name == "geometry").get.dataType == GeometryUDT) + val rows = shapefileDf.collect() + assert(rows.length == 10) + rows.foreach { row => + assert(row.getAs[Geometry]("geometry") == null) + assert(!row.isNullAt(row.fieldIndex("id"))) + } + } + + it("read bad_shx") { + var shapefileDf = sparkSession.read + .format("shapefile") + .load(resourceFolder + 
"shapefiles/bad_shx") + val schema = shapefileDf.schema + assert(schema.find(_.name == "geometry").get.dataType == GeometryUDT) + assert(schema.find(_.name == "field_1").get.dataType == LongType) + var rows = shapefileDf.collect() + assert(rows.length == 2) + rows.foreach { row => + val geom = row.getAs[Geometry]("geometry") + if (geom == null) { + assert(row.getAs[Long]("field_1") == 3) + } else { + assert(geom.isInstanceOf[Point]) + assert(row.getAs[Long]("field_1") == 2) + } + } + + // Copy the .shp and .dbf files to temporary location, and read the same shapefiles without .shx + FileUtils.cleanDirectory(new File(temporaryLocation)) + FileUtils.copyFile( + new File(resourceFolder + "shapefiles/bad_shx/bad_shx.shp"), + new File(temporaryLocation + "/bad_shx.shp")) + FileUtils.copyFile( + new File(resourceFolder + "shapefiles/bad_shx/bad_shx.dbf"), + new File(temporaryLocation + "/bad_shx.dbf")) + shapefileDf = sparkSession.read + .format("shapefile") + .load(temporaryLocation) + rows = shapefileDf.collect() + assert(rows.length == 2) + rows.foreach { row => + val geom = row.getAs[Geometry]("geometry") + if (geom == null) { + assert(row.getAs[Long]("field_1") == 3) + } else { + assert(geom.isInstanceOf[Point]) + assert(row.getAs[Long]("field_1") == 2) + } + } + } + + it("read contains_null_geom") { + val shapefileDf = sparkSession.read + .format("shapefile") + .load(resourceFolder + "shapefiles/contains_null_geom") + val schema = shapefileDf.schema + assert(schema.find(_.name == "geometry").get.dataType == GeometryUDT) + assert(schema.find(_.name == "fInt").get.dataType == LongType) + assert(schema.find(_.name == "fFloat").get.dataType.isInstanceOf[DecimalType]) + assert(schema.find(_.name == "fString").get.dataType == StringType) + assert(schema.length == 4) + val rows = shapefileDf.collect() + assert(rows.length == 10) + rows.foreach { row => + val fInt = row.getAs[Long]("fInt") + val fFloat = row.getAs[java.math.BigDecimal]("fFloat").doubleValue() + val fString = row.getAs[String]("fString") + val geom = row.getAs[Geometry]("geometry") + if (fInt == 2 || fInt == 5) { + assert(geom == null) + } else { + assert(geom.isInstanceOf[Point]) + assert(geom.getCoordinate.x == fInt) + assert(geom.getCoordinate.y == fInt) + } + assert(Math.abs(fFloat - 3.14159 * fInt) < 1e-4) + assert(fString == s"str_$fInt") + } + } + + it("read test_datatypes") { + val shapefileDf = sparkSession.read + .format("shapefile") + .load(resourceFolder + "shapefiles/datatypes") + val schema = shapefileDf.schema + assert(schema.find(_.name == "geometry").get.dataType == GeometryUDT) + assert(schema.find(_.name == "id").get.dataType == LongType) + assert(schema.find(_.name == "aInt").get.dataType == LongType) + assert(schema.find(_.name == "aUnicode").get.dataType == StringType) + assert(schema.find(_.name == "aDecimal").get.dataType.isInstanceOf[DecimalType]) + assert(schema.find(_.name == "aDecimal2").get.dataType.isInstanceOf[DecimalType]) + assert(schema.find(_.name == "aDate").get.dataType == DateType) + assert(schema.length == 7) + + val rows = shapefileDf.collect() + assert(rows.length == 9) + rows.foreach { row => + val geom = row.getAs[Geometry]("geometry") + assert(geom.isInstanceOf[Point]) + assert(geom.getSRID == 4269) + val idIndex = row.fieldIndex("id") + if (row.isNullAt(idIndex)) { + assert(row.isNullAt(row.fieldIndex("aInt"))) + assert(row.getAs[String]("aUnicode").isEmpty) + assert(row.isNullAt(row.fieldIndex("aDecimal"))) + assert(row.isNullAt(row.fieldIndex("aDecimal2"))) + 
assert(row.isNullAt(row.fieldIndex("aDate"))) + } else { + val id = row.getLong(idIndex) + assert(row.getAs[Long]("aInt") == id) + assert(row.getAs[String]("aUnicode") == s"测试$id") + if (id < 10) { + val decimal = row.getDecimal(row.fieldIndex("aDecimal")).doubleValue() + assert((decimal * 10).toInt == id * 10 + id) + assert(row.isNullAt(row.fieldIndex("aDecimal2"))) + assert(row.getAs[java.sql.Date]("aDate").toString == s"202$id-0$id-0$id") + } else { + assert(row.isNullAt(row.fieldIndex("aDecimal"))) + val decimal = row.getDecimal(row.fieldIndex("aDecimal2")).doubleValue() + assert((decimal * 100).toInt == id * 100 + id) + assert(row.isNullAt(row.fieldIndex("aDate"))) + } + } + } + } + + it("read with .shp path specified") { + val shapefileDf = sparkSession.read + .format("shapefile") + .load(resourceFolder + "shapefiles/datatypes/datatypes1.shp") + val schema = shapefileDf.schema + assert(schema.find(_.name == "geometry").get.dataType == GeometryUDT) + assert(schema.find(_.name == "id").get.dataType == LongType) + assert(schema.find(_.name == "aInt").get.dataType == LongType) + assert(schema.find(_.name == "aUnicode").get.dataType == StringType) + assert(schema.find(_.name == "aDecimal").get.dataType.isInstanceOf[DecimalType]) + assert(schema.find(_.name == "aDate").get.dataType == DateType) + assert(schema.length == 6) + + val rows = shapefileDf.collect() + assert(rows.length == 5) + rows.foreach { row => + assert(row.getAs[Geometry]("geometry").isInstanceOf[Point]) + val idIndex = row.fieldIndex("id") + if (row.isNullAt(idIndex)) { + assert(row.isNullAt(row.fieldIndex("aInt"))) + assert(row.getAs[String]("aUnicode").isEmpty) + assert(row.isNullAt(row.fieldIndex("aDecimal"))) + assert(row.isNullAt(row.fieldIndex("aDate"))) + } else { + val id = row.getLong(idIndex) + assert(row.getAs[Long]("aInt") == id) + assert(row.getAs[String]("aUnicode") == s"测试$id") + val decimal = row.getDecimal(row.fieldIndex("aDecimal")).doubleValue() + assert((decimal * 10).toInt == id * 10 + id) + assert(row.getAs[java.sql.Date]("aDate").toString == s"202$id-0$id-0$id") + } + } + } + + it("read with glob path specified") { + val shapefileDf = sparkSession.read + .format("shapefile") + .load(resourceFolder + "shapefiles/datatypes/datatypes2.*") + val schema = shapefileDf.schema + assert(schema.find(_.name == "geometry").get.dataType == GeometryUDT) + assert(schema.find(_.name == "id").get.dataType == LongType) + assert(schema.find(_.name == "aInt").get.dataType == LongType) + assert(schema.find(_.name == "aUnicode").get.dataType == StringType) + assert(schema.find(_.name == "aDecimal2").get.dataType.isInstanceOf[DecimalType]) + assert(schema.length == 5) + + val rows = shapefileDf.collect() + assert(rows.length == 4) + rows.foreach { row => + assert(row.getAs[Geometry]("geometry").isInstanceOf[Point]) + val id = row.getAs[Long]("id") + assert(row.getAs[Long]("aInt") == id) + assert(row.getAs[String]("aUnicode") == s"测试$id") + val decimal = row.getDecimal(row.fieldIndex("aDecimal2")).doubleValue() + assert((decimal * 100).toInt == id * 100 + id) + } + } + + it("read without shx") { + FileUtils.cleanDirectory(new File(temporaryLocation)) + FileUtils.copyFile( + new File(resourceFolder + "shapefiles/gis_osm_pois_free_1/gis_osm_pois_free_1.shp"), + new File(temporaryLocation + "/gis_osm_pois_free_1.shp")) + FileUtils.copyFile( + new File(resourceFolder + "shapefiles/gis_osm_pois_free_1/gis_osm_pois_free_1.dbf"), + new File(temporaryLocation + "/gis_osm_pois_free_1.dbf")) + + val shapefileDf = sparkSession.read + 
.format("shapefile") + .load(temporaryLocation) + val rows = shapefileDf.collect() + assert(rows.length == 12873) + rows.foreach { row => + val geom = row.getAs[Geometry]("geometry") + assert(geom.isInstanceOf[Point]) + assert(geom.getSRID == 0) + assert(row.getAs[String]("osm_id").nonEmpty) + assert(row.getAs[Long]("code") > 0) + assert(row.getAs[String]("fclass").nonEmpty) + assert(row.getAs[String]("name") != null) + } + } + + it("read without dbf") { + FileUtils.cleanDirectory(new File(temporaryLocation)) + FileUtils.copyFile( + new File(resourceFolder + "shapefiles/gis_osm_pois_free_1/gis_osm_pois_free_1.shp"), + new File(temporaryLocation + "/gis_osm_pois_free_1.shp")) + val shapefileDf = sparkSession.read + .format("shapefile") + .load(temporaryLocation) + val schema = shapefileDf.schema + assert(schema.find(_.name == "geometry").get.dataType == GeometryUDT) + assert(schema.length == 1) + + val rows = shapefileDf.collect() + assert(rows.length == 12873) + rows.foreach { row => + val geom = row.getAs[Geometry]("geometry") + assert(geom.isInstanceOf[Point]) + } + } + + it("read without shp") { + FileUtils.cleanDirectory(new File(temporaryLocation)) + FileUtils.copyFile( + new File(resourceFolder + "shapefiles/gis_osm_pois_free_1/gis_osm_pois_free_1.dbf"), + new File(temporaryLocation + "/gis_osm_pois_free_1.dbf")) + FileUtils.copyFile( + new File(resourceFolder + "shapefiles/gis_osm_pois_free_1/gis_osm_pois_free_1.shx"), + new File(temporaryLocation + "/gis_osm_pois_free_1.shx")) + intercept[Exception] { + sparkSession.read + .format("shapefile") + .load(temporaryLocation) + .count() + } + + intercept[Exception] { + sparkSession.read + .format("shapefile") + .load(resourceFolder + "shapefiles/gis_osm_pois_free_1/gis_osm_pois_free_1.shx") + .count() + } + } + + it("read directory containing missing .shp files") { + FileUtils.cleanDirectory(new File(temporaryLocation)) + // Missing .shp file for datatypes1 + FileUtils.copyFile( + new File(resourceFolder + "shapefiles/datatypes/datatypes1.dbf"), + new File(temporaryLocation + "/datatypes1.dbf")) + FileUtils.copyFile( + new File(resourceFolder + "shapefiles/datatypes/datatypes2.shp"), + new File(temporaryLocation + "/datatypes2.shp")) + FileUtils.copyFile( + new File(resourceFolder + "shapefiles/datatypes/datatypes2.dbf"), + new File(temporaryLocation + "/datatypes2.dbf")) + FileUtils.copyFile( + new File(resourceFolder + "shapefiles/datatypes/datatypes2.cpg"), + new File(temporaryLocation + "/datatypes2.cpg")) + + val shapefileDf = sparkSession.read + .format("shapefile") + .load(temporaryLocation) + val rows = shapefileDf.collect() + assert(rows.length == 4) + rows.foreach { row => + assert(row.getAs[Geometry]("geometry").isInstanceOf[Point]) + val id = row.getAs[Long]("id") + assert(row.getAs[Long]("aInt") == id) + assert(row.getAs[String]("aUnicode") == s"测试$id") + val decimal = row.getDecimal(row.fieldIndex("aDecimal2")).doubleValue() + assert((decimal * 100).toInt == id * 100 + id) + } + } + + it("read partitioned directory") { + FileUtils.cleanDirectory(new File(temporaryLocation)) + Files.createDirectory(new File(temporaryLocation + "/part=1").toPath) + Files.createDirectory(new File(temporaryLocation + "/part=2").toPath) + FileUtils.copyFile( + new File(resourceFolder + "shapefiles/datatypes/datatypes1.shp"), + new File(temporaryLocation + "/part=1/datatypes1.shp")) + FileUtils.copyFile( + new File(resourceFolder + "shapefiles/datatypes/datatypes1.dbf"), + new File(temporaryLocation + "/part=1/datatypes1.dbf")) + 
FileUtils.copyFile( + new File(resourceFolder + "shapefiles/datatypes/datatypes1.cpg"), + new File(temporaryLocation + "/part=1/datatypes1.cpg")) + FileUtils.copyFile( + new File(resourceFolder + "shapefiles/datatypes/datatypes2.shp"), + new File(temporaryLocation + "/part=2/datatypes2.shp")) + FileUtils.copyFile( + new File(resourceFolder + "shapefiles/datatypes/datatypes2.dbf"), + new File(temporaryLocation + "/part=2/datatypes2.dbf")) + FileUtils.copyFile( + new File(resourceFolder + "shapefiles/datatypes/datatypes2.cpg"), + new File(temporaryLocation + "/part=2/datatypes2.cpg")) + + val shapefileDf = sparkSession.read + .format("shapefile") + .load(temporaryLocation) + .select("part", "id", "aInt", "aUnicode", "geometry") + var rows = shapefileDf.collect() + assert(rows.length == 9) + rows.foreach { row => + assert(row.getAs[Geometry]("geometry").isInstanceOf[Point]) + val id = row.getAs[Long]("id") + assert(row.getAs[Long]("aInt") == id) + if (id < 10) { + assert(row.getAs[Int]("part") == 1) + } else { + assert(row.getAs[Int]("part") == 2) + } + if (id > 0) { + assert(row.getAs[String]("aUnicode") == s"测试$id") + } + } + + // Using partition filters + rows = shapefileDf.where("part = 2").collect() + assert(rows.length == 4) + rows.foreach { row => + assert(row.getAs[Geometry]("geometry").isInstanceOf[Point]) + assert(row.getAs[Int]("part") == 2) + val id = row.getAs[Long]("id") + assert(id > 10) + assert(row.getAs[Long]("aInt") == id) + assert(row.getAs[String]("aUnicode") == s"测试$id") + } + } + + it("read with recursiveFileLookup") { + FileUtils.cleanDirectory(new File(temporaryLocation)) + Files.createDirectory(new File(temporaryLocation + "/part1").toPath) + Files.createDirectory(new File(temporaryLocation + "/part2").toPath) + FileUtils.copyFile( + new File(resourceFolder + "shapefiles/datatypes/datatypes1.shp"), + new File(temporaryLocation + "/part1/datatypes1.shp")) + FileUtils.copyFile( + new File(resourceFolder + "shapefiles/datatypes/datatypes1.dbf"), + new File(temporaryLocation + "/part1/datatypes1.dbf")) + FileUtils.copyFile( + new File(resourceFolder + "shapefiles/datatypes/datatypes1.cpg"), + new File(temporaryLocation + "/part1/datatypes1.cpg")) + FileUtils.copyFile( + new File(resourceFolder + "shapefiles/datatypes/datatypes2.shp"), + new File(temporaryLocation + "/part2/datatypes2.shp")) + FileUtils.copyFile( + new File(resourceFolder + "shapefiles/datatypes/datatypes2.dbf"), + new File(temporaryLocation + "/part2/datatypes2.dbf")) + FileUtils.copyFile( + new File(resourceFolder + "shapefiles/datatypes/datatypes2.cpg"), + new File(temporaryLocation + "/part2/datatypes2.cpg")) + + val shapefileDf = sparkSession.read + .format("shapefile") + .option("recursiveFileLookup", "true") + .load(temporaryLocation) + .select("id", "aInt", "aUnicode", "geometry") + val rows = shapefileDf.collect() + assert(rows.length == 9) + rows.foreach { row => + assert(row.getAs[Geometry]("geometry").isInstanceOf[Point]) + val id = row.getAs[Long]("id") + assert(row.getAs[Long]("aInt") == id) + if (id > 0) { + assert(row.getAs[String]("aUnicode") == s"测试$id") + } + } + } + + it("read with custom geometry column name") { + val shapefileDf = sparkSession.read + .format("shapefile") + .option("geometry.name", "geom") + .load(resourceFolder + "shapefiles/gis_osm_pois_free_1") + val schema = shapefileDf.schema + assert(schema.find(_.name == "geom").get.dataType == GeometryUDT) + assert(schema.find(_.name == "osm_id").get.dataType == StringType) + assert(schema.find(_.name == "code").get.dataType 
== LongType) + assert(schema.find(_.name == "fclass").get.dataType == StringType) + assert(schema.find(_.name == "name").get.dataType == StringType) + assert(schema.length == 5) + val rows = shapefileDf.collect() + assert(rows.length == 12873) + rows.foreach { row => + val geom = row.getAs[Geometry]("geom") + assert(geom.isInstanceOf[Point]) + assert(row.getAs[String]("osm_id").nonEmpty) + assert(row.getAs[Long]("code") > 0) + assert(row.getAs[String]("fclass").nonEmpty) + assert(row.getAs[String]("name") != null) + } + + val exception = intercept[Exception] { + sparkSession.read + .format("shapefile") + .option("geometry.name", "osm_id") + .load(resourceFolder + "shapefiles/gis_osm_pois_free_1") + } + assert( + exception.getMessage.contains( + "osm_id is reserved for geometry but appears in non-spatial attributes")) + } + + it("read with shape key column") { + val shapefileDf = sparkSession.read + .format("shapefile") + .option("key.name", "fid") + .load(resourceFolder + "shapefiles/datatypes") + .select("id", "fid", "geometry", "aUnicode") + val schema = shapefileDf.schema + assert(schema.find(_.name == "geometry").get.dataType == GeometryUDT) + assert(schema.find(_.name == "id").get.dataType == LongType) + assert(schema.find(_.name == "fid").get.dataType == LongType) + assert(schema.find(_.name == "aUnicode").get.dataType == StringType) + val rows = shapefileDf.collect() + assert(rows.length == 9) + rows.foreach { row => + val geom = row.getAs[Geometry]("geometry") + assert(geom.isInstanceOf[Point]) + val id = row.getAs[Long]("id") + if (id > 0) { + assert(row.getAs[Long]("fid") == id % 10) + assert(row.getAs[String]("aUnicode") == s"测试$id") + } else { + assert(row.getAs[Long]("fid") == 5) + } + } + } + + it("read with both custom geometry column and shape key column") { + val shapefileDf = sparkSession.read + .format("shapefile") + .option("geometry.name", "g") + .option("key.name", "fid") + .load(resourceFolder + "shapefiles/datatypes") + .select("id", "fid", "g", "aUnicode") + val schema = shapefileDf.schema + assert(schema.find(_.name == "g").get.dataType == GeometryUDT) + assert(schema.find(_.name == "id").get.dataType == LongType) + assert(schema.find(_.name == "fid").get.dataType == LongType) + assert(schema.find(_.name == "aUnicode").get.dataType == StringType) + val rows = shapefileDf.collect() + assert(rows.length == 9) + rows.foreach { row => + val geom = row.getAs[Geometry]("g") + assert(geom.isInstanceOf[Point]) + val id = row.getAs[Long]("id") + if (id > 0) { + assert(row.getAs[Long]("fid") == id % 10) + assert(row.getAs[String]("aUnicode") == s"测试$id") + } else { + assert(row.getAs[Long]("fid") == 5) + } + } + } + + it("read with invalid shape key column") { + val exception = intercept[Exception] { + sparkSession.read + .format("shapefile") + .option("geometry.name", "g") + .option("key.name", "aDate") + .load(resourceFolder + "shapefiles/datatypes") + } + assert( + exception.getMessage.contains( + "aDate is reserved for shape key but appears in non-spatial attributes")) + + val exception2 = intercept[Exception] { + sparkSession.read + .format("shapefile") + .option("geometry.name", "g") + .option("key.name", "g") + .load(resourceFolder + "shapefiles/datatypes") + } + assert(exception2.getMessage.contains("geometry.name and key.name cannot be the same")) + } + + it("read with custom charset") { + FileUtils.cleanDirectory(new File(temporaryLocation)) + FileUtils.copyFile( + new File(resourceFolder + "shapefiles/datatypes/datatypes2.shp"), + new File(temporaryLocation + 
"/datatypes2.shp")) + FileUtils.copyFile( + new File(resourceFolder + "shapefiles/datatypes/datatypes2.dbf"), + new File(temporaryLocation + "/datatypes2.dbf")) + + val shapefileDf = sparkSession.read + .format("shapefile") + .option("charset", "GB2312") + .load(temporaryLocation) + val rows = shapefileDf.collect() + assert(rows.length == 4) + rows.foreach { row => + assert(row.getAs[Geometry]("geometry").isInstanceOf[Point]) + val id = row.getAs[Long]("id") + assert(row.getAs[Long]("aInt") == id) + assert(row.getAs[String]("aUnicode") == s"测试$id") + val decimal = row.getDecimal(row.fieldIndex("aDecimal2")).doubleValue() + assert((decimal * 100).toInt == id * 100 + id) + } + } + + it("read with custom schema") { + val customSchema = StructType( + Seq( + StructField("osm_id", StringType), + StructField("code2", LongType), + StructField("geometry", GeometryUDT()))) + val shapefileDf = sparkSession.read + .format("shapefile") + .schema(customSchema) + .load(resourceFolder + "shapefiles/gis_osm_pois_free_1") + assert(shapefileDf.schema == customSchema) + val rows = shapefileDf.collect() + assert(rows.length == 12873) + rows.foreach { row => + val geom = row.getAs[Geometry]("geometry") + assert(geom.isInstanceOf[Point]) + assert(row.getAs[String]("osm_id").nonEmpty) + assert(row.isNullAt(row.fieldIndex("code2"))) + } + } + + it("should read shapes of various types") { + // There are multiple directories under shapefiles/shapetypes, each containing a shapefile. + // We'll iterate over each directory and read the shapefile within it. + val shapeTypesDir = new File(resourceFolder + "shapefiles/shapetypes") + val shapeTypeDirs = shapeTypesDir.listFiles().filter(_.isDirectory) + shapeTypeDirs.foreach { shapeTypeDir => + val fileName = shapeTypeDir.getName + val hasZ = fileName.endsWith("zm") || fileName.endsWith("z") + val hasM = fileName.endsWith("zm") || fileName.endsWith("m") + val shapeType = + if (fileName.startsWith("point")) "POINT" + else if (fileName.startsWith("linestring")) "LINESTRING" + else if (fileName.startsWith("multipoint")) "MULTIPOINT" + else "POLYGON" + val expectedWktPrefix = + if (!hasZ && !hasM) shapeType + else { + shapeType + " " + (if (hasZ) "Z" else "") + (if (hasM) "M" else "") + } + + val shapefileDf = sparkSession.read + .format("shapefile") + .load(shapeTypeDir.getAbsolutePath) + val schema = shapefileDf.schema + assert(schema.find(_.name == "geometry").get.dataType == GeometryUDT) + val rows = shapefileDf.collect() + assert(rows.length > 0) + + // Validate the geometry type and WKT prefix + val wktWriter = new WKTWriter(4) + val rowsMap = mutable.Map[String, Geometry]() + rows.foreach { row => + val id = row.getAs[String]("id") + val geom = row.getAs[Geometry]("geometry") + val wkt = wktWriter.write(geom) + assert(wkt.startsWith(expectedWktPrefix)) + assert(geom != null) + rowsMap.put(id, geom) + } + + // Validate the geometry values by reading the CSV file containing the same data + val csvDf = sparkSession.read + .format("csv") + .option("header", "true") + .load(shapeTypeDir.getAbsolutePath + "/*.csv") + val wktReader = new WKTReader() + csvDf.collect().foreach { row => + val id = row.getAs[String]("id") + val wkt = row.getAs[String]("wkt") + val geom = wktReader.read(wkt) + assert(rowsMap(id).equals(geom)) + } + } + } + } +} diff --git a/spark/spark-4.1/src/test/scala/org/apache/sedona/sql/TestBaseScala.scala b/spark/spark-4.1/src/test/scala/org/apache/sedona/sql/TestBaseScala.scala new file mode 100644 index 00000000000..2e3e9742222 --- /dev/null +++ 
b/spark/spark-4.1/src/test/scala/org/apache/sedona/sql/TestBaseScala.scala @@ -0,0 +1,131 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.sedona.sql + +import io.minio.{MinioClient, PutObjectArgs} +import org.apache.log4j.{Level, Logger} +import org.apache.sedona.spark.SedonaContext +import org.apache.spark.sql.{DataFrame, SparkSession} +import org.scalatest.{BeforeAndAfterAll, FunSpec} +import org.testcontainers.containers.MinIOContainer + +import java.io.FileInputStream + +import java.util.concurrent.ThreadLocalRandom + +trait TestBaseScala extends FunSpec with BeforeAndAfterAll { + Logger.getRootLogger().setLevel(Level.WARN) + Logger.getLogger("org.apache").setLevel(Level.WARN) + Logger.getLogger("com").setLevel(Level.WARN) + Logger.getLogger("akka").setLevel(Level.WARN) + Logger.getLogger("org.apache.sedona.core").setLevel(Level.WARN) + + val keyParserExtension = "spark.sedona.enableParserExtension" + val warehouseLocation = System.getProperty("user.dir") + "/target/" + val sparkSession = SedonaContext + .builder() + .master("local[*]") + .appName("sedonasqlScalaTest") + .config("spark.sql.warehouse.dir", warehouseLocation) + // We need to be explicit about broadcasting in tests. 
+ .config("sedona.join.autoBroadcastJoinThreshold", "-1") + .config("spark.sql.extensions", "org.apache.sedona.sql.SedonaSqlExtensions") + .config(keyParserExtension, ThreadLocalRandom.current().nextBoolean()) + // Disable Spark 4.1+ native geospatial functions that shadow Sedona's ST functions + .config("spark.sql.geospatial.enabled", "false") + .getOrCreate() + + val sparkSessionMinio = SedonaContext + .builder() + .master("local[*]") + .appName("sedonasqlScalaTest") + .config("spark.sql.warehouse.dir", warehouseLocation) + .config("spark.jars.packages", "org.apache.hadoop:hadoop-aws:3.3.0") + .config( + "spark.hadoop.fs.s3a.aws.credentials.provider", + "org.apache.hadoop.fs.s3a.SimpleAWSCredentialsProvider") + .config("spark.hadoop.fs.s3a.impl", "org.apache.hadoop.fs.s3a.S3AFileSystem") + .config("sedona.join.autoBroadcastJoinThreshold", "-1") + .getOrCreate() + + val resourceFolder = System.getProperty("user.dir") + "/../common/src/test/resources/" + + override def beforeAll(): Unit = { + SedonaContext.create(sparkSession) + } + + override def afterAll(): Unit = { + // SedonaSQLRegistrator.dropAll(spark) + // spark.stop + } + + def loadCsv(path: String): DataFrame = { + sparkSession.read.format("csv").option("delimiter", ",").option("header", "false").load(path) + } + + def withConf[T](conf: Map[String, String])(f: => T): T = { + val oldConf = conf.keys.map(key => key -> sparkSession.conf.getOption(key)) + conf.foreach { case (key, value) => sparkSession.conf.set(key, value) } + try { + f + } finally { + oldConf.foreach { case (key, value) => + value match { + case Some(v) => sparkSession.conf.set(key, v) + case None => sparkSession.conf.unset(key) + } + } + } + } + + def putFileIntoBucket( + bucketName: String, + key: String, + stream: FileInputStream, + client: MinioClient): Unit = { + val objectArguments = PutObjectArgs + .builder() + .bucket(bucketName) + .`object`(key) + .stream(stream, stream.available(), -1) + .build() + + client.putObject(objectArguments) + } + + def createMinioClient(container: MinIOContainer): MinioClient = { + MinioClient + .builder() + .endpoint(container.getS3URL) + .credentials(container.getUserName, container.getPassword) + .build() + } + + def adjustSparkSession(sparkSession: SparkSession, container: MinIOContainer): Unit = { + sparkSession.sparkContext.hadoopConfiguration.set("fs.s3a.endpoint", container.getS3URL) + sparkSession.sparkContext.hadoopConfiguration.set("fs.s3a.access.key", container.getUserName) + sparkSession.sparkContext.hadoopConfiguration.set("fs.s3a.secret.key", container.getPassword) + sparkSession.sparkContext.hadoopConfiguration.set("fs.s3a.connection.timeout", "2000") + + sparkSession.sparkContext.hadoopConfiguration.set("spark.sql.debug.maxToStringFields", "100") + sparkSession.sparkContext.hadoopConfiguration.set("fs.s3a.path.style.access", "true") + sparkSession.sparkContext.hadoopConfiguration + .set("fs.s3a.impl", "org.apache.hadoop.fs.s3a.S3AFileSystem") + } +} diff --git a/spark/spark-4.1/src/test/scala/org/apache/spark/sql/udf/StrategySuite.scala b/spark/spark-4.1/src/test/scala/org/apache/spark/sql/udf/StrategySuite.scala new file mode 100644 index 00000000000..8d41848de98 --- /dev/null +++ b/spark/spark-4.1/src/test/scala/org/apache/spark/sql/udf/StrategySuite.scala @@ -0,0 +1,59 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.spark.sql.udf + +import org.apache.sedona.sql.TestBaseScala +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.functions.col +import org.apache.spark.sql.udf.ScalarUDF.geoPandasScalaFunction +import org.locationtech.jts.io.WKTReader +import org.scalatest.matchers.should.Matchers + +class StrategySuite extends TestBaseScala with Matchers { + val wktReader = new WKTReader() + + val spark: SparkSession = { + sparkSession.sparkContext.setLogLevel("ALL") + sparkSession + } + + import spark.implicits._ + + it("sedona geospatial UDF") { + val df = Seq( + (1, "value", wktReader.read("POINT(21 52)")), + (2, "value1", wktReader.read("POINT(20 50)")), + (3, "value2", wktReader.read("POINT(20 49)")), + (4, "value3", wktReader.read("POINT(20 48)")), + (5, "value4", wktReader.read("POINT(20 47)"))) + .toDF("id", "value", "geom") + .withColumn("geom_buffer", geoPandasScalaFunction(col("geom"))) + + df.count shouldEqual 5 + + df.selectExpr("ST_AsText(ST_ReducePrecision(geom_buffer, 2))") + .as[String] + .collect() should contain theSameElementsAs Seq( + "POLYGON ((20 51, 20 53, 22 53, 22 51, 20 51))", + "POLYGON ((19 49, 19 51, 21 51, 21 49, 19 49))", + "POLYGON ((19 48, 19 50, 21 50, 21 48, 19 48))", + "POLYGON ((19 47, 19 49, 21 49, 21 47, 19 47))", + "POLYGON ((19 46, 19 48, 21 48, 21 46, 19 46))") + } +} diff --git a/spark/spark-4.1/src/test/scala/org/apache/spark/sql/udf/TestScalarPandasUDF.scala b/spark/spark-4.1/src/test/scala/org/apache/spark/sql/udf/TestScalarPandasUDF.scala new file mode 100644 index 00000000000..62733169288 --- /dev/null +++ b/spark/spark-4.1/src/test/scala/org/apache/spark/sql/udf/TestScalarPandasUDF.scala @@ -0,0 +1,122 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.spark.sql.udf + +import org.apache.sedona.sql.UDF +import org.apache.spark.TestUtils +import org.apache.spark.api.python._ +import org.apache.spark.broadcast.Broadcast +import org.apache.spark.sql.execution.python.UserDefinedPythonFunction +import org.apache.spark.sql.sedona_sql.UDT.GeometryUDT +import org.apache.spark.util.Utils + +import java.io.File +import java.nio.file.{Files, Paths} +import scala.sys.process.Process +import scala.jdk.CollectionConverters._ + +object ScalarUDF { + + val pythonExec: String = { + val pythonExec = + sys.env.getOrElse("PYSPARK_DRIVER_PYTHON", sys.env.getOrElse("PYSPARK_PYTHON", "python3")) + if (TestUtils.testCommandAvailable(pythonExec)) { + pythonExec + } else { + "python" + } + } + + private[spark] lazy val pythonPath = sys.env.getOrElse("PYTHONPATH", "") + protected lazy val sparkHome: String = { + sys.props.getOrElse("spark.test.home", sys.env("SPARK_HOME")) + } + + private lazy val py4jPath = + Paths.get(sparkHome, "python", "lib", PythonUtils.PY4J_ZIP_NAME).toAbsolutePath + private[spark] lazy val pysparkPythonPath = s"$py4jPath" + + private lazy val isPythonAvailable: Boolean = TestUtils.testCommandAvailable(pythonExec) + + lazy val pythonVer: String = if (isPythonAvailable) { + Process( + Seq(pythonExec, "-c", "import sys; print('%d.%d' % sys.version_info[:2])"), + None, + "PYTHONPATH" -> s"$pysparkPythonPath:$pythonPath").!!.trim() + } else { + throw new RuntimeException(s"Python executable [$pythonExec] is unavailable.") + } + + protected def withTempPath(f: File => Unit): Unit = { + val path = Utils.createTempDir() + path.delete() + try f(path) + finally Utils.deleteRecursively(path) + } + + val pandasFunc: Array[Byte] = { + var binaryPandasFunc: Array[Byte] = null + withTempPath { path => + println(path) + Process( + Seq( + pythonExec, + "-c", + f""" + |from pyspark.sql.types import IntegerType + |from shapely.geometry import Point + |from sedona.sql.types import GeometryType + |from pyspark.serializers import CloudPickleSerializer + |from sedona.utils import geometry_serde + |from shapely import box + |f = open('$path', 'wb'); + |def w(x): + | def apply_function(w): + | geom, offset = geometry_serde.deserialize(w) + | bounds = geom.buffer(1).bounds + | x = box(*bounds) + | return geometry_serde.serialize(x) + | return x.apply(apply_function) + |f.write(CloudPickleSerializer().dumps((w, GeometryType()))) + |""".stripMargin), + None, + "PYTHONPATH" -> s"$pysparkPythonPath:$pythonPath").!! + binaryPandasFunc = Files.readAllBytes(path.toPath) + } + assert(binaryPandasFunc != null) + binaryPandasFunc + } + + private val workerEnv = new java.util.HashMap[String, String]() + workerEnv.put("PYTHONPATH", s"$pysparkPythonPath:$pythonPath") + + val geoPandasScalaFunction: UserDefinedPythonFunction = UserDefinedPythonFunction( + name = "geospatial_udf", + func = SimplePythonFunction( + command = pandasFunc, + envVars = workerEnv.clone().asInstanceOf[java.util.Map[String, String]], + pythonIncludes = List.empty[String].asJava, + pythonExec = pythonExec, + pythonVer = pythonVer, + broadcastVars = List.empty[Broadcast[PythonBroadcast]].asJava, + accumulator = null), + dataType = GeometryUDT(), + pythonEvalType = UDF.PythonEvalType.SQL_SCALAR_SEDONA_UDF, + udfDeterministic = true) +}