From b54b976e1df57b5165da04b5da0066b0ea4e1f2b Mon Sep 17 00:00:00 2001 From: Jia Yu Date: Thu, 12 Feb 2026 11:38:57 -0800 Subject: [PATCH 01/14] [GH-2609] Support Spark 4.1 - Add sedona-spark-4.1 Maven profile (Spark 4.1.0, Scala 2.13.17, Hadoop 3.4.1) - Create spark/spark-4.1 module based on spark-4.0 - Fix Geometry import ambiguity (Spark 4.1 adds o.a.s.sql.types.Geometry) - Fix WritableColumnVector.setAllNull() removal (replaced by setMissing() in 4.1) - Add sessionUUID parameter to ArrowPythonWithNamedArgumentRunner (new in 4.1) - Update docs (maven-coordinates, platform, publish) - Update CI workflows (java, example, python, docker-build) --- .github/workflows/docker-build.yml | 5 +- .github/workflows/example.yml | 4 + .github/workflows/java.yml | 3 + .github/workflows/python.yml | 12 +- docs/community/publish.md | 19 +- docs/setup/maven-coordinates.md | 58 ++ docs/setup/platform.md | 28 +- pom.xml | 23 + spark/common/pom.xml | 17 + .../internal/ParquetColumnVector.java | 23 +- .../sedona_sql/expressions/Functions.scala | 1 + spark/pom.xml | 1 + spark/spark-4.1/.gitignore | 29 + spark/spark-4.1/pom.xml | 185 +++++ ...pache.spark.sql.sources.DataSourceRegister | 3 + .../geopackage/GeoPackageDataSource.scala | 73 ++ .../GeoPackagePartitionReader.scala | 107 +++ .../GeoPackagePartitionReaderFactory.scala | 88 ++ .../geopackage/GeoPackageScan.scala | 59 ++ .../geopackage/GeoPackageScanBuilder.scala | 58 ++ .../geopackage/GeoPackageTable.scala | 91 ++ .../shapefile/ShapefileDataSource.scala | 101 +++ .../shapefile/ShapefilePartition.scala | 27 + .../shapefile/ShapefilePartitionReader.scala | 287 +++++++ .../ShapefilePartitionReaderFactory.scala | 66 ++ .../shapefile/ShapefileReadOptions.scala | 45 + .../datasources/shapefile/ShapefileScan.scala | 118 +++ .../shapefile/ShapefileScanBuilder.scala | 48 ++ .../shapefile/ShapefileTable.scala | 103 +++ .../shapefile/ShapefileUtils.scala | 202 +++++ .../sql/parser/SedonaSqlAstBuilder.scala | 39 + .../sedona/sql/parser/SedonaSqlParser.scala | 49 ++ .../GeoParquetMetadataDataSource.scala | 65 ++ ...arquetMetadataPartitionReaderFactory.scala | 122 +++ .../metadata/GeoParquetMetadataScan.scala | 69 ++ .../GeoParquetMetadataScanBuilder.scala | 47 ++ .../metadata/GeoParquetMetadataTable.scala | 70 ++ .../python/SedonaArrowEvalPythonExec.scala | 114 +++ .../spark/sql/udf/ExtractSedonaUDFRule.scala | 185 +++++ .../spark/sql/udf/SedonaArrowEvalPython.scala | 32 + .../spark/sql/udf/SedonaArrowStrategy.scala | 43 + .../src/test/resources/log4j2.properties | 31 + .../sedona/sql/GeoPackageReaderTest.scala | 369 +++++++++ .../sedona/sql/GeoParquetMetadataTests.scala | 152 ++++ .../sedona/sql/SQLSyntaxTestScala.scala | 72 ++ .../apache/sedona/sql/ShapefileTests.scala | 784 ++++++++++++++++++ .../org/apache/sedona/sql/TestBaseScala.scala | 129 +++ .../apache/spark/sql/udf/StrategySuite.scala | 59 ++ .../spark/sql/udf/TestScalarPandasUDF.scala | 122 +++ 49 files changed, 4413 insertions(+), 24 deletions(-) create mode 100644 spark/spark-4.1/.gitignore create mode 100644 spark/spark-4.1/pom.xml create mode 100644 spark/spark-4.1/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister create mode 100644 spark/spark-4.1/src/main/scala/org/apache/sedona/sql/datasources/geopackage/GeoPackageDataSource.scala create mode 100644 spark/spark-4.1/src/main/scala/org/apache/sedona/sql/datasources/geopackage/GeoPackagePartitionReader.scala create mode 100644 
spark/spark-4.1/src/main/scala/org/apache/sedona/sql/datasources/geopackage/GeoPackagePartitionReaderFactory.scala create mode 100644 spark/spark-4.1/src/main/scala/org/apache/sedona/sql/datasources/geopackage/GeoPackageScan.scala create mode 100644 spark/spark-4.1/src/main/scala/org/apache/sedona/sql/datasources/geopackage/GeoPackageScanBuilder.scala create mode 100644 spark/spark-4.1/src/main/scala/org/apache/sedona/sql/datasources/geopackage/GeoPackageTable.scala create mode 100644 spark/spark-4.1/src/main/scala/org/apache/sedona/sql/datasources/shapefile/ShapefileDataSource.scala create mode 100644 spark/spark-4.1/src/main/scala/org/apache/sedona/sql/datasources/shapefile/ShapefilePartition.scala create mode 100644 spark/spark-4.1/src/main/scala/org/apache/sedona/sql/datasources/shapefile/ShapefilePartitionReader.scala create mode 100644 spark/spark-4.1/src/main/scala/org/apache/sedona/sql/datasources/shapefile/ShapefilePartitionReaderFactory.scala create mode 100644 spark/spark-4.1/src/main/scala/org/apache/sedona/sql/datasources/shapefile/ShapefileReadOptions.scala create mode 100644 spark/spark-4.1/src/main/scala/org/apache/sedona/sql/datasources/shapefile/ShapefileScan.scala create mode 100644 spark/spark-4.1/src/main/scala/org/apache/sedona/sql/datasources/shapefile/ShapefileScanBuilder.scala create mode 100644 spark/spark-4.1/src/main/scala/org/apache/sedona/sql/datasources/shapefile/ShapefileTable.scala create mode 100644 spark/spark-4.1/src/main/scala/org/apache/sedona/sql/datasources/shapefile/ShapefileUtils.scala create mode 100644 spark/spark-4.1/src/main/scala/org/apache/sedona/sql/parser/SedonaSqlAstBuilder.scala create mode 100644 spark/spark-4.1/src/main/scala/org/apache/sedona/sql/parser/SedonaSqlParser.scala create mode 100644 spark/spark-4.1/src/main/scala/org/apache/spark/sql/execution/datasources/v2/geoparquet/metadata/GeoParquetMetadataDataSource.scala create mode 100644 spark/spark-4.1/src/main/scala/org/apache/spark/sql/execution/datasources/v2/geoparquet/metadata/GeoParquetMetadataPartitionReaderFactory.scala create mode 100644 spark/spark-4.1/src/main/scala/org/apache/spark/sql/execution/datasources/v2/geoparquet/metadata/GeoParquetMetadataScan.scala create mode 100644 spark/spark-4.1/src/main/scala/org/apache/spark/sql/execution/datasources/v2/geoparquet/metadata/GeoParquetMetadataScanBuilder.scala create mode 100644 spark/spark-4.1/src/main/scala/org/apache/spark/sql/execution/datasources/v2/geoparquet/metadata/GeoParquetMetadataTable.scala create mode 100644 spark/spark-4.1/src/main/scala/org/apache/spark/sql/execution/python/SedonaArrowEvalPythonExec.scala create mode 100644 spark/spark-4.1/src/main/scala/org/apache/spark/sql/udf/ExtractSedonaUDFRule.scala create mode 100644 spark/spark-4.1/src/main/scala/org/apache/spark/sql/udf/SedonaArrowEvalPython.scala create mode 100644 spark/spark-4.1/src/main/scala/org/apache/spark/sql/udf/SedonaArrowStrategy.scala create mode 100644 spark/spark-4.1/src/test/resources/log4j2.properties create mode 100644 spark/spark-4.1/src/test/scala/org/apache/sedona/sql/GeoPackageReaderTest.scala create mode 100644 spark/spark-4.1/src/test/scala/org/apache/sedona/sql/GeoParquetMetadataTests.scala create mode 100644 spark/spark-4.1/src/test/scala/org/apache/sedona/sql/SQLSyntaxTestScala.scala create mode 100644 spark/spark-4.1/src/test/scala/org/apache/sedona/sql/ShapefileTests.scala create mode 100644 spark/spark-4.1/src/test/scala/org/apache/sedona/sql/TestBaseScala.scala create mode 100644 
spark/spark-4.1/src/test/scala/org/apache/spark/sql/udf/StrategySuite.scala create mode 100644 spark/spark-4.1/src/test/scala/org/apache/spark/sql/udf/TestScalarPandasUDF.scala diff --git a/.github/workflows/docker-build.yml b/.github/workflows/docker-build.yml index 42306f9ee1d..a7988b56aa5 100644 --- a/.github/workflows/docker-build.yml +++ b/.github/workflows/docker-build.yml @@ -46,7 +46,7 @@ jobs: fail-fast: true matrix: os: ['ubuntu-latest', 'ubuntu-24.04-arm'] - spark: ['4.0.1'] + spark: ['4.0.1', '4.1.0'] include: - spark: 4.0.1 sedona: 'latest' @@ -54,6 +54,9 @@ jobs: - spark: 4.0.1 sedona: 1.8.0 geotools: '33.1' + - spark: 4.1.0 + sedona: 'latest' + geotools: '33.1' runs-on: ${{ matrix.os }} defaults: run: diff --git a/.github/workflows/example.yml b/.github/workflows/example.yml index 6d16137a839..4a93b05cc73 100644 --- a/.github/workflows/example.yml +++ b/.github/workflows/example.yml @@ -45,6 +45,10 @@ jobs: fail-fast: false matrix: include: + - spark: 4.1.0 + spark-compat: '4.1' + sedona: 1.9.0-SNAPSHOT + hadoop: 3.4.2 - spark: 4.0.1 spark-compat: '4.0' sedona: 1.8.0 diff --git a/.github/workflows/java.yml b/.github/workflows/java.yml index 76834c9d63e..32d3d095daf 100644 --- a/.github/workflows/java.yml +++ b/.github/workflows/java.yml @@ -62,6 +62,9 @@ jobs: fail-fast: true matrix: include: + - spark: 4.1.0 + scala: 2.13.8 + jdk: '17' - spark: 4.0.0 scala: 2.13.8 jdk: '17' diff --git a/.github/workflows/python.yml b/.github/workflows/python.yml index 05782bbae9e..ff3810d72b2 100644 --- a/.github/workflows/python.yml +++ b/.github/workflows/python.yml @@ -60,6 +60,14 @@ jobs: strategy: matrix: include: + - spark: '4.1.0' + scala: '2.13.8' + java: '17' + python: '3.11' + - spark: '4.1.0' + scala: '2.13.8' + java: '17' + python: '3.10' - spark: '4.0.0' scala: '2.13.8' java: '17' @@ -149,9 +157,9 @@ jobs: fi if [ "${SPARK_VERSION:0:1}" == "4" ]; then - # Spark 4.0 requires Python 3.9+, and we remove flink since it conflicts with pyspark 4.0 + # Spark 4.x requires Python 3.9+, and we remove flink since it conflicts with pyspark 4.x uv remove apache-flink --optional flink - uv add "pyspark==4.0.0; python_version >= '3.9'" + uv add "pyspark==${SPARK_VERSION}; python_version >= '3.9'" else # Install specific pyspark version matching matrix uv add pyspark==${SPARK_VERSION} diff --git a/docs/community/publish.md b/docs/community/publish.md index 4880e9e40df..348590f4207 100644 --- a/docs/community/publish.md +++ b/docs/community/publish.md @@ -119,13 +119,13 @@ rm -rf $LOCAL_DIR && git clone --depth 1 --branch $TAG $REPO_URL $LOCAL_DIR && c MAVEN_PLUGIN_VERSION="2.3.2" # Define Spark and Scala versions -declare -a SPARK_VERSIONS=("3.4" "3.5" "4.0") +declare -a SPARK_VERSIONS=("3.4" "3.5" "4.0" "4.1") declare -a SCALA_VERSIONS=("2.12" "2.13") # Function to get Java version for Spark version get_java_version() { local spark_version=$1 - if [[ "$spark_version" == "4.0" ]]; then + if [[ "$spark_version" == "4."* ]]; then echo "17" else echo "11" @@ -217,8 +217,8 @@ verify_java_version() { # Iterate through Spark and Scala versions for SPARK in "${SPARK_VERSIONS[@]}"; do for SCALA in "${SCALA_VERSIONS[@]}"; do - # Skip Spark 4.0 + Scala 2.12 combination as it's not supported - if [[ "$SPARK" == "4.0" && "$SCALA" == "2.12" ]]; then + # Skip Spark 4.0+ + Scala 2.12 combination as it's not supported + if [[ "$SPARK" == "4."* && "$SCALA" == "2.12" ]]; then echo "Skipping Spark $SPARK with Scala $SCALA (not supported)" continue fi @@ -286,7 +286,7 @@ mkdir apache-sedona-${SEDONA_VERSION}-bin 
# Function to get Java version for Spark version get_java_version() { local spark_version=$1 - if [[ "$spark_version" == "4.0" ]]; then + if [[ "$spark_version" == "4."* ]]; then echo "17" else echo "11" @@ -410,6 +410,15 @@ echo "Compiling for Spark 4.0 with Scala 2.13 using Java $JAVA_VERSION..." cd apache-sedona-${SEDONA_VERSION}-src && $MVN_WRAPPER clean && $MVN_WRAPPER install -DskipTests -Dspark=4.0 -Dscala=2.13 && cd .. cp apache-sedona-${SEDONA_VERSION}-src/spark-shaded/target/sedona-*${SEDONA_VERSION}.jar apache-sedona-${SEDONA_VERSION}-bin/ +# Compile for Spark 4.1 with Java 17 +JAVA_VERSION=$(get_java_version "4.1") +MVN_WRAPPER=$(create_mvn_wrapper $JAVA_VERSION) +verify_java_version $MVN_WRAPPER $JAVA_VERSION + +echo "Compiling for Spark 4.1 with Scala 2.13 using Java $JAVA_VERSION..." +cd apache-sedona-${SEDONA_VERSION}-src && $MVN_WRAPPER clean && $MVN_WRAPPER install -DskipTests -Dspark=4.1 -Dscala=2.13 && cd .. +cp apache-sedona-${SEDONA_VERSION}-src/spark-shaded/target/sedona-*${SEDONA_VERSION}.jar apache-sedona-${SEDONA_VERSION}-bin/ + # Clean up Maven wrappers rm -f /tmp/mvn-java11 /tmp/mvn-java17 diff --git a/docs/setup/maven-coordinates.md b/docs/setup/maven-coordinates.md index d57f73eebb4..b655ee0ff73 100644 --- a/docs/setup/maven-coordinates.md +++ b/docs/setup/maven-coordinates.md @@ -84,6 +84,22 @@ The optional GeoTools library is required if you want to use raster operators. V ``` + === "Spark 4.1 and Scala 2.12" + + ```xml + + org.apache.sedona + sedona-spark-shaded-4.1_2.12 + {{ sedona.current_version }} + + + + org.datasyslab + geotools-wrapper + {{ sedona.current_geotools }} + + ``` + !!! abstract "Sedona with Apache Spark and Scala 2.13" === "Spark 3.4 and Scala 2.13" @@ -133,6 +149,22 @@ The optional GeoTools library is required if you want to use raster operators. V ``` + === "Spark 4.1 and Scala 2.13" + + ```xml + + org.apache.sedona + sedona-spark-shaded-4.1_2.13 + {{ sedona.current_version }} + + + + org.datasyslab + geotools-wrapper + {{ sedona.current_geotools }} + + ``` + !!! abstract "Sedona with Apache Flink" === "Flink 1.12+ and Scala 2.12" @@ -223,6 +255,19 @@ The optional GeoTools library is required if you want to use raster operators. V {{ sedona.current_geotools }} ``` + === "Spark 4.1 and Scala 2.12" + ```xml + + org.apache.sedona + sedona-spark-4.1_2.12 + {{ sedona.current_version }} + + + org.datasyslab + geotools-wrapper + {{ sedona.current_geotools }} + + ``` !!! abstract "Sedona with Apache Spark and Scala 2.13" @@ -265,6 +310,19 @@ The optional GeoTools library is required if you want to use raster operators. V {{ sedona.current_geotools }} ``` + === "Spark 4.1 and Scala 2.13" + ```xml + + org.apache.sedona + sedona-spark-4.1_2.13 + {{ sedona.current_version }} + + + org.datasyslab + geotools-wrapper + {{ sedona.current_geotools }} + + ``` !!! abstract "Sedona with Apache Flink" diff --git a/docs/setup/platform.md b/docs/setup/platform.md index 9ea2abe0915..e3113ad9b10 100644 --- a/docs/setup/platform.md +++ b/docs/setup/platform.md @@ -22,28 +22,28 @@ Sedona binary releases are compiled by Java 11/17 and Scala 2.12/2.13 and tested **Java Requirements:** - Spark 3.4 & 3.5: Java 11 -- Spark 4.0: Java 17 +- Spark 4.0 & 4.1: Java 17 **Note:** Java 8 support is dropped since Sedona 1.8.0. Spark 3.3 support is dropped since Sedona 1.8.0. 
=== "Sedona Scala/Java" - | | Spark 3.4| Spark 3.5 | Spark 4.0 | - |:---------:|:---------:|:---------:|:---------:| - | Scala 2.12 |✅ |✅ |✅ | - | Scala 2.13 |✅ |✅ |✅ | + | | Spark 3.4| Spark 3.5 | Spark 4.0 | Spark 4.1 | + |:---------:|:---------:|:---------:|:---------:|:---------:| + | Scala 2.12 |✅ |✅ |✅ | | + | Scala 2.13 |✅ |✅ |✅ |✅ | === "Sedona Python" - | | Spark 3.4 (Scala 2.12)|Spark 3.5 (Scala 2.12)| Spark 4.0 (Scala 2.12)| - |:---------:|:---------:|:---------:|:---------:| - | Python 3.7 | ✅ | ✅ | ✅ | - | Python 3.8 | ✅ | ✅ | ✅ | - | Python 3.9 | ✅ | ✅ | ✅ | - | Python 3.10 | ✅ | ✅ | ✅ | + | | Spark 3.4 (Scala 2.12)|Spark 3.5 (Scala 2.12)| Spark 4.0 (Scala 2.12)| Spark 4.1 (Scala 2.13)| + |:---------:|:---------:|:---------:|:---------:|:---------:| + | Python 3.7 | ✅ | ✅ | ✅ | ✅ | + | Python 3.8 | ✅ | ✅ | ✅ | ✅ | + | Python 3.9 | ✅ | ✅ | ✅ | ✅ | + | Python 3.10 | ✅ | ✅ | ✅ | ✅ | === "Sedona R" - | | Spark 3.4 | Spark 3.5 | Spark 4.0 | - |:---------:|:---------:|:---------:|:---------:| - | Scala 2.12 | ✅ | ✅ | ✅ | + | | Spark 3.4 | Spark 3.5 | Spark 4.0 | Spark 4.1 | + |:---------:|:---------:|:---------:|:---------:|:---------:| + | Scala 2.12 | ✅ | ✅ | ✅ | | diff --git a/pom.xml b/pom.xml index 542c6c94659..d4fe817809f 100644 --- a/pom.xml +++ b/pom.xml @@ -765,6 +765,29 @@ true + + sedona-spark-4.1 + + + spark + 4.1 + + + + 4.1.0 + 4.1 + 4 + 3.4.1 + 2.24.3 + 2.0.16 + + 2.13.17 + 2.13 + + + true + + scala2.13 diff --git a/spark/common/pom.xml b/spark/common/pom.xml index 109da51ead2..0210dd825b4 100644 --- a/spark/common/pom.xml +++ b/spark/common/pom.xml @@ -355,5 +355,22 @@ + + sedona-spark-4.1 + + + spark + 4.1 + + + + + org.apache.spark + spark-sql-api_${scala.compat.version} + ${spark.version} + provided + + + diff --git a/spark/common/src/main/java/org/apache/spark/sql/execution/datasources/geoparquet/internal/ParquetColumnVector.java b/spark/common/src/main/java/org/apache/spark/sql/execution/datasources/geoparquet/internal/ParquetColumnVector.java index 870a07b7f6a..a945f72d80e 100644 --- a/spark/common/src/main/java/org/apache/spark/sql/execution/datasources/geoparquet/internal/ParquetColumnVector.java +++ b/spark/common/src/main/java/org/apache/spark/sql/execution/datasources/geoparquet/internal/ParquetColumnVector.java @@ -19,6 +19,7 @@ package org.apache.spark.sql.execution.datasources.geoparquet.internal; import com.google.common.base.Preconditions; +import java.lang.reflect.Method; import java.util.ArrayList; import java.util.List; import java.util.Set; @@ -40,6 +41,24 @@ final class ParquetColumnVector { private final List children; private final WritableColumnVector vector; + /** + * Mark the given column vector as all-null / missing. In Spark <= 4.0 this is {@code + * setAllNull()}, in Spark >= 4.1 it was renamed to {@code setMissing()}. + */ + private static void markAllNull(WritableColumnVector v) { + try { + Method m; + try { + m = v.getClass().getMethod("setAllNull"); + } catch (NoSuchMethodException e) { + m = v.getClass().getMethod("setMissing"); + } + m.invoke(v); + } catch (Exception e) { + throw new RuntimeException("Cannot mark column vector as all null", e); + } + } + /** * Repetition & Definition levels These are allocated only for leaf columns; for non-leaf columns, * they simply maintain references to that of the former. 
@@ -84,7 +103,7 @@ final class ParquetColumnVector { } if (defaultValue == null) { - vector.setAllNull(); + markAllNull(vector); return; } // For Parquet tables whose columns have associated DEFAULT values, this reader must return @@ -139,7 +158,7 @@ final class ParquetColumnVector { // This can happen if all the fields of a struct are missing, in which case we should mark // the struct itself as a missing column if (allChildrenAreMissing) { - vector.setAllNull(); + markAllNull(vector); } } } diff --git a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/Functions.scala b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/Functions.scala index 565a9e99574..76d0ec57602 100644 --- a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/Functions.scala +++ b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/Functions.scala @@ -32,6 +32,7 @@ import org.apache.spark.sql.sedona_sql.expressions.implicits._ import org.apache.spark.sql.types._ import org.locationtech.jts.algorithm.MinimumBoundingCircle import org.locationtech.jts.geom._ +import org.locationtech.jts.geom.Geometry import org.apache.spark.sql.sedona_sql.expressions.InferrableFunctionConverter._ import org.apache.spark.sql.sedona_sql.expressions.LibPostalUtils.{getExpanderFromConf, getParserFromConf} import org.apache.spark.unsafe.types.UTF8String diff --git a/spark/pom.xml b/spark/pom.xml index fe35e36441f..f64cc6c74ae 100644 --- a/spark/pom.xml +++ b/spark/pom.xml @@ -50,6 +50,7 @@ spark-3.5 spark-4.0 + spark-4.1 diff --git a/spark/spark-4.1/.gitignore b/spark/spark-4.1/.gitignore new file mode 100644 index 00000000000..f34cc0c65b4 --- /dev/null +++ b/spark/spark-4.1/.gitignore @@ -0,0 +1,29 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +/target/ +/.settings/ +/.classpath +/.project +/dependency-reduced-pom.xml +/doc/ +/.idea/ +*.iml +/latest/ +/spark-warehouse/ +/metastore_db/ +*.log diff --git a/spark/spark-4.1/pom.xml b/spark/spark-4.1/pom.xml new file mode 100644 index 00000000000..e4d4b51f0e9 --- /dev/null +++ b/spark/spark-4.1/pom.xml @@ -0,0 +1,185 @@ + + + + 4.0.0 + + org.apache.sedona + sedona-spark-parent-${spark.compat.version}_${scala.compat.version} + 1.9.0-SNAPSHOT + ../pom.xml + + sedona-spark-4.1_${scala.compat.version} + + ${project.groupId}:${project.artifactId} + A cluster computing system for processing large-scale spatial data: SQL API for Spark 4.1. 
+ http://sedona.apache.org/ + jar + + + false + + + + + org.apache.sedona + sedona-common + ${project.version} + + + com.fasterxml.jackson.core + * + + + it.geosolutions.jaiext.jiffle + * + + + org.codehaus.janino + * + + + + + org.apache.sedona + sedona-spark-common-${spark.compat.version}_${scala.compat.version} + ${project.version} + + + + org.apache.spark + spark-core_${scala.compat.version} + + + org.apache.spark + spark-sql_${scala.compat.version} + + + org.apache.hadoop + hadoop-client + + + org.apache.logging.log4j + log4j-1.2-api + + + org.geotools + gt-main + + + org.geotools + gt-referencing + + + org.geotools + gt-epsg-hsql + + + org.geotools + gt-geotiff + + + org.geotools + gt-coverage + + + org.geotools + gt-arcgrid + + + org.locationtech.jts + jts-core + + + org.wololo + jts2geojson + + + com.fasterxml.jackson.core + * + + + + + org.scala-lang + scala-library + + + org.scala-lang.modules + scala-collection-compat_${scala.compat.version} + + + org.scalatest + scalatest_${scala.compat.version} + + + org.mockito + mockito-inline + + + org.testcontainers + testcontainers + 2.0.2 + test + + + org.testcontainers + testcontainers-minio + 2.0.2 + test + + + io.minio + minio + + + com.squareup.okhttp3 + okhttp + + + org.apache.hadoop + hadoop-aws + ${hadoop.version} + test + + + org.apache.hadoop + hadoop-client-api + ${hadoop.version} + test + + + + src/main/scala + + + net.alchim31.maven + scala-maven-plugin + + + org.scalatest + scalatest-maven-plugin + + + org.scalastyle + scalastyle-maven-plugin + + + + diff --git a/spark/spark-4.1/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister b/spark/spark-4.1/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister new file mode 100644 index 00000000000..ae1de3d8bd2 --- /dev/null +++ b/spark/spark-4.1/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister @@ -0,0 +1,3 @@ +org.apache.spark.sql.execution.datasources.v2.geoparquet.metadata.GeoParquetMetadataDataSource +org.apache.sedona.sql.datasources.shapefile.ShapefileDataSource +org.apache.sedona.sql.datasources.geopackage.GeoPackageDataSource diff --git a/spark/spark-4.1/src/main/scala/org/apache/sedona/sql/datasources/geopackage/GeoPackageDataSource.scala b/spark/spark-4.1/src/main/scala/org/apache/sedona/sql/datasources/geopackage/GeoPackageDataSource.scala new file mode 100644 index 00000000000..11f2db38e84 --- /dev/null +++ b/spark/spark-4.1/src/main/scala/org/apache/sedona/sql/datasources/geopackage/GeoPackageDataSource.scala @@ -0,0 +1,73 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.sedona.sql.datasources.geopackage + +import org.apache.hadoop.fs.Path +import org.apache.sedona.sql.datasources.geopackage.model.GeoPackageOptions +import org.apache.spark.sql.connector.catalog.Table +import org.apache.spark.sql.execution.datasources.FileFormat +import org.apache.spark.sql.execution.datasources.v2.FileDataSourceV2 +import org.apache.spark.sql.sources.DataSourceRegister +import org.apache.spark.sql.util.CaseInsensitiveStringMap + +import java.util.Locale +import scala.jdk.CollectionConverters._ +import scala.util.Try + +class GeoPackageDataSource extends FileDataSourceV2 with DataSourceRegister { + + override def fallbackFileFormat: Class[_ <: FileFormat] = { + null + } + + override protected def getTable(options: CaseInsensitiveStringMap): Table = { + GeoPackageTable( + "", + sparkSession, + options, + getPaths(options), + None, + fallbackFileFormat, + getLoadOptions(options)) + } + + private def getLoadOptions(options: CaseInsensitiveStringMap): GeoPackageOptions = { + val path = options.get("path") + if (path.isEmpty) { + throw new IllegalArgumentException("GeoPackage path is not specified") + } + + val showMetadata = options.getBoolean("showMetadata", false) + val maybeTableName = options.get("tableName") + + if (!showMetadata && maybeTableName == null) { + throw new IllegalArgumentException("Table name is not specified") + } + + val tableName = if (showMetadata) { + "gpkg_contents" + } else { + maybeTableName + } + + GeoPackageOptions(tableName = tableName, showMetadata = showMetadata) + } + + override def shortName(): String = "geopackage" +} diff --git a/spark/spark-4.1/src/main/scala/org/apache/sedona/sql/datasources/geopackage/GeoPackagePartitionReader.scala b/spark/spark-4.1/src/main/scala/org/apache/sedona/sql/datasources/geopackage/GeoPackagePartitionReader.scala new file mode 100644 index 00000000000..4e59163922d --- /dev/null +++ b/spark/spark-4.1/src/main/scala/org/apache/sedona/sql/datasources/geopackage/GeoPackagePartitionReader.scala @@ -0,0 +1,107 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.sedona.sql.datasources.geopackage + +import org.apache.hadoop.fs.Path +import org.apache.sedona.sql.datasources.geopackage.connection.{FileSystemUtils, GeoPackageConnectionManager} +import org.apache.sedona.sql.datasources.geopackage.model.TableType.{FEATURES, METADATA, TILES, UNKNOWN} +import org.apache.sedona.sql.datasources.geopackage.model.{GeoPackageReadOptions, PartitionOptions, TileRowMetadata} +import org.apache.sedona.sql.datasources.geopackage.transform.ValuesMapper +import org.apache.spark.broadcast.Broadcast +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.connector.read.PartitionReader +import org.apache.spark.util.SerializableConfiguration + +import java.io.File +import java.sql.ResultSet + +case class GeoPackagePartitionReader( + var rs: ResultSet, + options: GeoPackageReadOptions, + broadcastedConf: Broadcast[SerializableConfiguration], + var currentTempFile: File, + copying: Boolean = false) + extends PartitionReader[InternalRow] { + + private var values: Seq[Any] = Seq.empty + private var currentFile = options.currentFile + private val partitionedFiles = options.partitionedFiles + + override def next(): Boolean = { + if (rs.next()) { + values = ValuesMapper.mapValues(adjustPartitionOptions, rs) + return true + } + + partitionedFiles.remove(currentFile) + + if (partitionedFiles.isEmpty) { + return false + } + + rs.close() + + currentFile = partitionedFiles.head + val (tempFile, _) = FileSystemUtils.copyToLocal( + options = broadcastedConf.value.value, + file = new Path(currentFile.filePath.toString())) + + if (copying) { + currentTempFile.deleteOnExit() + } + + currentTempFile = tempFile + + rs = GeoPackageConnectionManager.getTableCursor(currentTempFile.getPath, options.tableName) + + if (!rs.next()) { + return false + } + + values = ValuesMapper.mapValues(adjustPartitionOptions, rs) + + true + } + + private def adjustPartitionOptions: PartitionOptions = { + options.partitionOptions.tableType match { + case FEATURES | METADATA => options.partitionOptions + case TILES => + val tileRowMetadata = TileRowMetadata( + zoomLevel = rs.getInt("zoom_level"), + tileColumn = rs.getInt("tile_column"), + tileRow = rs.getInt("tile_row")) + + options.partitionOptions.withTileRowMetadata(tileRowMetadata) + case UNKNOWN => options.partitionOptions + } + + } + + override def get(): InternalRow = { + InternalRow.fromSeq(values) + } + + override def close(): Unit = { + rs.close() + if (copying) { + options.tempFile.delete() + } + } +} diff --git a/spark/spark-4.1/src/main/scala/org/apache/sedona/sql/datasources/geopackage/GeoPackagePartitionReaderFactory.scala b/spark/spark-4.1/src/main/scala/org/apache/sedona/sql/datasources/geopackage/GeoPackagePartitionReaderFactory.scala new file mode 100644 index 00000000000..b1d38996b04 --- /dev/null +++ b/spark/spark-4.1/src/main/scala/org/apache/sedona/sql/datasources/geopackage/GeoPackagePartitionReaderFactory.scala @@ -0,0 +1,88 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.sedona.sql.datasources.geopackage + +import org.apache.hadoop.fs.Path +import org.apache.sedona.sql.datasources.geopackage.connection.{FileSystemUtils, GeoPackageConnectionManager} +import org.apache.sedona.sql.datasources.geopackage.model.TableType.TILES +import org.apache.sedona.sql.datasources.geopackage.model.{GeoPackageOptions, GeoPackageReadOptions, PartitionOptions, TableType} +import org.apache.spark.broadcast.Broadcast +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.connector.read.{InputPartition, PartitionReader, PartitionReaderFactory} +import org.apache.spark.sql.execution.datasources.FilePartition +import org.apache.spark.sql.types.StructType +import org.apache.spark.util.SerializableConfiguration + +case class GeoPackagePartitionReaderFactory( + sparkSession: SparkSession, + broadcastedConf: Broadcast[SerializableConfiguration], + loadOptions: GeoPackageOptions, + dataSchema: StructType) + extends PartitionReaderFactory { + + override def createReader(partition: InputPartition): PartitionReader[InternalRow] = { + val partitionFiles = partition match { + case filePartition: FilePartition => filePartition.files + case _ => + throw new IllegalArgumentException( + s"Unexpected partition type: ${partition.getClass.getCanonicalName}") + } + + val (tempFile, copied) = FileSystemUtils.copyToLocal( + options = broadcastedConf.value.value, + file = new Path(partitionFiles.head.filePath.toString())) + + val tableType = if (loadOptions.showMetadata) { + TableType.METADATA + } else { + GeoPackageConnectionManager.findFeatureMetadata(tempFile.getPath, loadOptions.tableName) + } + + val rs = + GeoPackageConnectionManager.getTableCursor(tempFile.getAbsolutePath, loadOptions.tableName) + + val schema = GeoPackageConnectionManager.getSchema(tempFile.getPath, loadOptions.tableName) + + if (StructType(schema.map(_.toStructField(tableType))) != dataSchema) { + throw new IllegalArgumentException( + s"Schema mismatch: expected $dataSchema, got ${StructType(schema.map(_.toStructField(tableType)))}") + } + + val tileMetadata = tableType match { + case TILES => + Some( + GeoPackageConnectionManager.findTilesMetadata(tempFile.getPath, loadOptions.tableName)) + case _ => None + } + + GeoPackagePartitionReader( + rs = rs, + options = GeoPackageReadOptions( + tableName = loadOptions.tableName, + tempFile = tempFile, + partitionOptions = + PartitionOptions(tableType = tableType, columns = schema, tile = tileMetadata), + partitionedFiles = scala.collection.mutable.HashSet(partitionFiles: _*), + currentFile = partitionFiles.head), + broadcastedConf = broadcastedConf, + currentTempFile = tempFile, + copying = copied) + } +} diff --git a/spark/spark-4.1/src/main/scala/org/apache/sedona/sql/datasources/geopackage/GeoPackageScan.scala b/spark/spark-4.1/src/main/scala/org/apache/sedona/sql/datasources/geopackage/GeoPackageScan.scala new file mode 100644 index 00000000000..768afd79173 --- /dev/null +++ b/spark/spark-4.1/src/main/scala/org/apache/sedona/sql/datasources/geopackage/GeoPackageScan.scala 
@@ -0,0 +1,59 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.sedona.sql.datasources.geopackage + +import org.apache.sedona.sql.datasources.geopackage.model.GeoPackageOptions +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.catalyst.expressions.Expression +import org.apache.spark.sql.connector.read.PartitionReaderFactory +import org.apache.spark.sql.execution.datasources.PartitioningAwareFileIndex +import org.apache.spark.sql.execution.datasources.v2.FileScan +import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.util.CaseInsensitiveStringMap +import org.apache.spark.util.SerializableConfiguration + +import scala.jdk.CollectionConverters._ + +case class GeoPackageScan( + dataSchema: StructType, + sparkSession: SparkSession, + fileIndex: PartitioningAwareFileIndex, + readDataSchema: StructType, + readPartitionSchema: StructType, + options: CaseInsensitiveStringMap, + loadOptions: GeoPackageOptions) + extends FileScan { + + override def partitionFilters: Seq[Expression] = { + Seq.empty + } + + override def dataFilters: Seq[Expression] = { + Seq.empty + } + + override def createReaderFactory(): PartitionReaderFactory = { + val caseSensitiveMap = options.asScala.toMap + val hadoopConf = sparkSession.sessionState.newHadoopConfWithOptions(caseSensitiveMap) + val broadcastedConf = + sparkSession.sparkContext.broadcast(new SerializableConfiguration(hadoopConf)) + + GeoPackagePartitionReaderFactory(sparkSession, broadcastedConf, loadOptions, dataSchema) + } +} diff --git a/spark/spark-4.1/src/main/scala/org/apache/sedona/sql/datasources/geopackage/GeoPackageScanBuilder.scala b/spark/spark-4.1/src/main/scala/org/apache/sedona/sql/datasources/geopackage/GeoPackageScanBuilder.scala new file mode 100644 index 00000000000..829bd9c2201 --- /dev/null +++ b/spark/spark-4.1/src/main/scala/org/apache/sedona/sql/datasources/geopackage/GeoPackageScanBuilder.scala @@ -0,0 +1,58 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.sedona.sql.datasources.geopackage + +import org.apache.sedona.sql.datasources.geopackage.model.GeoPackageOptions +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.connector.read.Scan +import org.apache.spark.sql.execution.datasources.{InMemoryFileIndex, PartitioningAwareFileIndex} +import org.apache.spark.sql.execution.datasources.v2.FileScanBuilder +import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.util.CaseInsensitiveStringMap +import scala.jdk.CollectionConverters._ + +class GeoPackageScanBuilder( + sparkSession: SparkSession, + fileIndex: PartitioningAwareFileIndex, + dataSchema: StructType, + options: CaseInsensitiveStringMap, + loadOptions: GeoPackageOptions, + userDefinedSchema: Option[StructType] = None) + extends FileScanBuilder(sparkSession, fileIndex, dataSchema) { + + override def build(): Scan = { + val fileIndexAdjusted = + if (loadOptions.showMetadata) + new InMemoryFileIndex( + sparkSession, + fileIndex.inputFiles.slice(0, 1).map(new org.apache.hadoop.fs.Path(_)), + options.asCaseSensitiveMap.asScala.toMap, + userDefinedSchema) + else fileIndex + + GeoPackageScan( + dataSchema, + sparkSession, + fileIndexAdjusted, + dataSchema, + readPartitionSchema(), + options, + loadOptions) + } +} diff --git a/spark/spark-4.1/src/main/scala/org/apache/sedona/sql/datasources/geopackage/GeoPackageTable.scala b/spark/spark-4.1/src/main/scala/org/apache/sedona/sql/datasources/geopackage/GeoPackageTable.scala new file mode 100644 index 00000000000..85dec8427e4 --- /dev/null +++ b/spark/spark-4.1/src/main/scala/org/apache/sedona/sql/datasources/geopackage/GeoPackageTable.scala @@ -0,0 +1,91 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.sedona.sql.datasources.geopackage + +import org.apache.hadoop.fs.FileStatus +import org.apache.sedona.sql.datasources.geopackage.connection.{FileSystemUtils, GeoPackageConnectionManager} +import org.apache.sedona.sql.datasources.geopackage.model.{GeoPackageOptions, MetadataSchema, TableType} +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.connector.read.ScanBuilder +import org.apache.spark.sql.connector.write.{LogicalWriteInfo, WriteBuilder} +import org.apache.spark.sql.execution.datasources.FileFormat +import org.apache.spark.sql.execution.datasources.v2.FileTable +import org.apache.spark.sql.types.{DoubleType, IntegerType, StringType, StructField, StructType, TimestampType} +import org.apache.spark.sql.util.CaseInsensitiveStringMap +import org.apache.spark.util.SerializableConfiguration + +import scala.jdk.CollectionConverters._ + +case class GeoPackageTable( + name: String, + sparkSession: SparkSession, + options: CaseInsensitiveStringMap, + paths: Seq[String], + userSpecifiedSchema: Option[StructType], + fallbackFileFormat: Class[_ <: FileFormat], + loadOptions: GeoPackageOptions) + extends FileTable(sparkSession, options, paths, userSpecifiedSchema) { + + override def inferSchema(files: Seq[FileStatus]): Option[StructType] = { + if (loadOptions.showMetadata) { + return MetadataSchema.schema + } + + val serializableConf = new SerializableConfiguration( + sparkSession.sessionState.newHadoopConfWithOptions(options.asScala.toMap)) + + val (tempFile, copied) = + FileSystemUtils.copyToLocal(serializableConf.value, files.head.getPath) + + if (copied) { + tempFile.deleteOnExit() + } + + val tableType = if (loadOptions.showMetadata) { + TableType.METADATA + } else { + GeoPackageConnectionManager.findFeatureMetadata(tempFile.getPath, loadOptions.tableName) + } + + Some( + StructType( + GeoPackageConnectionManager + .getSchema(tempFile.getPath, loadOptions.tableName) + .map(field => field.toStructField(tableType)))) + } + + override def formatName: String = { + "GeoPackage" + } + + override def newScanBuilder(options: CaseInsensitiveStringMap): ScanBuilder = { + new GeoPackageScanBuilder( + sparkSession, + fileIndex, + schema, + options, + loadOptions, + userSpecifiedSchema) + } + + override def newWriteBuilder(info: LogicalWriteInfo): WriteBuilder = { + null + } + +} diff --git a/spark/spark-4.1/src/main/scala/org/apache/sedona/sql/datasources/shapefile/ShapefileDataSource.scala b/spark/spark-4.1/src/main/scala/org/apache/sedona/sql/datasources/shapefile/ShapefileDataSource.scala new file mode 100644 index 00000000000..7cd6d03a6d9 --- /dev/null +++ b/spark/spark-4.1/src/main/scala/org/apache/sedona/sql/datasources/shapefile/ShapefileDataSource.scala @@ -0,0 +1,101 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. 
See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.sedona.sql.datasources.shapefile + +import org.apache.hadoop.fs.Path +import org.apache.spark.sql.connector.catalog.Table +import org.apache.spark.sql.execution.datasources.v2.FileDataSourceV2 +import org.apache.spark.sql.execution.datasources.FileFormat +import org.apache.spark.sql.sources.DataSourceRegister +import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.util.CaseInsensitiveStringMap + +import java.util.Locale +import scala.collection.JavaConverters._ +import scala.util.Try + +/** + * A Spark SQL data source for reading ESRI Shapefiles. This data source supports reading the + * following components of shapefiles: + * + * + * + *
+ * <ul>
+ *   <li>.shp: the main file</li>
+ *   <li>.shx: (optional) the index file</li>
+ *   <li>.dbf: (optional) the attribute file</li>
+ *   <li>.cpg: (optional) the code page file specifying the charset of the attribute file</li>
+ *   <li>.prj: (optional) the projection file</li>
+ * </ul>
+ *
+ *
The load path can be a directory containing the shapefiles, or a path to the .shp file. If + * the path refers to a .shp file, the data source will also read other components such as .dbf + * and .shx files in the same directory. + */ +class ShapefileDataSource extends FileDataSourceV2 with DataSourceRegister { + + override def shortName(): String = "shapefile" + + override def fallbackFileFormat: Class[_ <: FileFormat] = null + + override protected def getTable(options: CaseInsensitiveStringMap): Table = { + val paths = getTransformedPath(options) + val tableName = getTableName(options, paths) + val optionsWithoutPaths = getOptionsWithoutPaths(options) + ShapefileTable(tableName, sparkSession, optionsWithoutPaths, paths, None, fallbackFileFormat) + } + + override protected def getTable( + options: CaseInsensitiveStringMap, + schema: StructType): Table = { + val paths = getTransformedPath(options) + val tableName = getTableName(options, paths) + val optionsWithoutPaths = getOptionsWithoutPaths(options) + ShapefileTable( + tableName, + sparkSession, + optionsWithoutPaths, + paths, + Some(schema), + fallbackFileFormat) + } + + private def getTransformedPath(options: CaseInsensitiveStringMap): Seq[String] = { + val paths = getPaths(options) + transformPaths(paths, options) + } + + private def transformPaths( + paths: Seq[String], + options: CaseInsensitiveStringMap): Seq[String] = { + val caseSensitiveMap = options.asCaseSensitiveMap.asScala.toMap + val hadoopConf = sparkSession.sessionState.newHadoopConfWithOptions(caseSensitiveMap) + paths.map { pathString => + if (pathString.toLowerCase(Locale.ROOT).endsWith(".shp")) { + // If the path refers to a file, we need to change it to a glob path to support reading + // .dbf and .shx files as well. For example, if the path is /path/to/file.shp, we need to + // change it to /path/to/file.??? + val path = new Path(pathString) + val fs = path.getFileSystem(hadoopConf) + val isDirectory = Try(fs.getFileStatus(path).isDirectory).getOrElse(false) + if (isDirectory) { + pathString + } else { + pathString.substring(0, pathString.length - 3) + "???" + } + } else { + pathString + } + } + } +} diff --git a/spark/spark-4.1/src/main/scala/org/apache/sedona/sql/datasources/shapefile/ShapefilePartition.scala b/spark/spark-4.1/src/main/scala/org/apache/sedona/sql/datasources/shapefile/ShapefilePartition.scala new file mode 100644 index 00000000000..306b1df4f6c --- /dev/null +++ b/spark/spark-4.1/src/main/scala/org/apache/sedona/sql/datasources/shapefile/ShapefilePartition.scala @@ -0,0 +1,27 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.sedona.sql.datasources.shapefile + +import org.apache.spark.Partition +import org.apache.spark.sql.connector.read.InputPartition +import org.apache.spark.sql.execution.datasources.PartitionedFile + +case class ShapefilePartition(index: Int, files: Array[PartitionedFile]) + extends Partition + with InputPartition diff --git a/spark/spark-4.1/src/main/scala/org/apache/sedona/sql/datasources/shapefile/ShapefilePartitionReader.scala b/spark/spark-4.1/src/main/scala/org/apache/sedona/sql/datasources/shapefile/ShapefilePartitionReader.scala new file mode 100644 index 00000000000..301d63296fb --- /dev/null +++ b/spark/spark-4.1/src/main/scala/org/apache/sedona/sql/datasources/shapefile/ShapefilePartitionReader.scala @@ -0,0 +1,287 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.sedona.sql.datasources.shapefile + +import org.apache.commons.io.FilenameUtils +import org.apache.commons.io.IOUtils +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs.FSDataInputStream +import org.apache.hadoop.fs.Path +import org.apache.sedona.common.FunctionsGeoTools +import org.apache.sedona.core.formatMapper.shapefileParser.shapes.DbfFileReader +import org.apache.sedona.core.formatMapper.shapefileParser.shapes.PrimitiveShape +import org.apache.sedona.core.formatMapper.shapefileParser.shapes.ShapeFileReader +import org.apache.sedona.core.formatMapper.shapefileParser.shapes.ShxFileReader +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.BoundReference +import org.apache.spark.sql.catalyst.expressions.Cast +import org.apache.spark.sql.catalyst.expressions.Literal +import org.apache.spark.sql.catalyst.expressions.UnsafeProjection +import org.apache.spark.sql.connector.read.PartitionReader +import org.apache.spark.sql.execution.datasources.PartitionedFile +import org.apache.sedona.sql.datasources.shapefile.ShapefilePartitionReader.logger +import org.apache.sedona.sql.datasources.shapefile.ShapefilePartitionReader.openStream +import org.apache.sedona.sql.datasources.shapefile.ShapefilePartitionReader.tryOpenStream +import org.apache.sedona.sql.datasources.shapefile.ShapefileUtils.baseSchema +import org.apache.spark.sql.sedona_sql.UDT.GeometryUDT +import org.apache.spark.sql.types.StructType +import org.locationtech.jts.geom.GeometryFactory +import org.locationtech.jts.geom.PrecisionModel +import org.slf4j.Logger +import org.slf4j.LoggerFactory + +import java.nio.charset.StandardCharsets +import scala.collection.JavaConverters._ +import java.util.Locale +import scala.util.Try + +class ShapefilePartitionReader( + configuration: Configuration, + partitionedFiles: Array[PartitionedFile], + readDataSchema: StructType, + options: ShapefileReadOptions) + 
extends PartitionReader[InternalRow] { + + private val partitionedFilesMap: Map[String, Path] = partitionedFiles.map { file => + val fileName = file.filePath.toPath.getName + val extension = FilenameUtils.getExtension(fileName).toLowerCase(Locale.ROOT) + extension -> file.filePath.toPath + }.toMap + + private val cpg = options.charset.orElse { + // No charset option or sedona.global.charset system property specified, infer charset + // from the cpg file. + tryOpenStream(partitionedFilesMap, "cpg", configuration) + .flatMap { stream => + try { + val lineIter = IOUtils.lineIterator(stream, StandardCharsets.UTF_8) + if (lineIter.hasNext) { + Some(lineIter.next().trim()) + } else { + None + } + } finally { + stream.close() + } + } + .orElse { + // Cannot infer charset from cpg file. If sedona.global.charset is set to "utf8", use UTF-8 as + // the default charset. This is for compatibility with the behavior of the RDD API. + val charset = System.getProperty("sedona.global.charset", "default") + val utf8flag = charset.equalsIgnoreCase("utf8") + if (utf8flag) Some("UTF-8") else None + } + } + + private val prj = tryOpenStream(partitionedFilesMap, "prj", configuration).map { stream => + try { + IOUtils.toString(stream, StandardCharsets.UTF_8) + } finally { + stream.close() + } + } + + private val shpReader: ShapeFileReader = { + val reader = tryOpenStream(partitionedFilesMap, "shx", configuration) match { + case Some(shxStream) => + try { + val index = ShxFileReader.readAll(shxStream) + new ShapeFileReader(index) + } finally { + shxStream.close() + } + case None => new ShapeFileReader() + } + val stream = openStream(partitionedFilesMap, "shp", configuration) + reader.initialize(stream) + reader + } + + private val dbfReader = + tryOpenStream(partitionedFilesMap, "dbf", configuration).map { stream => + val reader = new DbfFileReader() + reader.initialize(stream) + reader + } + + private val geometryField = readDataSchema.filter(_.dataType.isInstanceOf[GeometryUDT]) match { + case Seq(geoField) => Some(geoField) + case Seq() => None + case _ => throw new IllegalArgumentException("Only one geometry field is allowed") + } + + private val shpSchema: StructType = { + val dbfFields = dbfReader + .map { reader => + ShapefileUtils.fieldDescriptorsToStructFields(reader.getFieldDescriptors.asScala.toSeq) + } + .getOrElse(Seq.empty) + StructType(baseSchema(options).fields ++ dbfFields) + } + + // projection from shpSchema to readDataSchema + private val projection = { + val expressions = readDataSchema.map { field => + val index = Try(shpSchema.fieldIndex(field.name)).getOrElse(-1) + if (index >= 0) { + val sourceField = shpSchema.fields(index) + val refExpr = BoundReference(index, sourceField.dataType, sourceField.nullable) + if (sourceField.dataType == field.dataType) refExpr + else { + Cast(refExpr, field.dataType) + } + } else { + if (field.nullable) { + Literal(null) + } else { + // This usually won't happen, since all fields of readDataSchema are nullable for most + // of the time. See org.apache.spark.sql.execution.datasources.v2.FileTable#dataSchema + // for more details. 
+ val dbfPath = partitionedFilesMap.get("dbf").orNull + throw new IllegalArgumentException( + s"Field ${field.name} not found in shapefile $dbfPath") + } + } + } + UnsafeProjection.create(expressions) + } + + // Convert DBF field values to SQL values + private val fieldValueConverters: Seq[Array[Byte] => Any] = dbfReader + .map { reader => + reader.getFieldDescriptors.asScala.map { field => + val index = Try(readDataSchema.fieldIndex(field.getFieldName)).getOrElse(-1) + if (index >= 0) { + ShapefileUtils.fieldValueConverter(field, cpg) + } else { (_: Array[Byte]) => + null + } + }.toSeq + } + .getOrElse(Seq.empty) + + private val geometryFactory = prj match { + case Some(wkt) => + val srid = + try { + FunctionsGeoTools.wktCRSToSRID(wkt) + } catch { + case e: Throwable => + val prjPath = partitionedFilesMap.get("prj").orNull + logger.warn(s"Failed to parse SRID from .prj file $prjPath", e) + 0 + } + new GeometryFactory(new PrecisionModel, srid) + case None => new GeometryFactory() + } + + private var currentRow: InternalRow = _ + + override def next(): Boolean = { + if (shpReader.nextKeyValue()) { + val key = shpReader.getCurrentKey + val id = key.getIndex + + val attributesOpt = dbfReader.flatMap { reader => + if (reader.nextKeyValue()) { + val value = reader.getCurrentFieldBytes + Option(value) + } else { + val dbfPath = partitionedFilesMap.get("dbf").orNull + logger.warn("Shape record loses attributes in .dbf file {} at ID={}", dbfPath, id) + None + } + } + + val value = shpReader.getCurrentValue + val geometry = geometryField.flatMap { _ => + if (value.getType.isSupported) { + val shape = new PrimitiveShape(value) + Some(shape.getShape(geometryFactory)) + } else { + logger.warn( + "Shape type {} is not supported, geometry value will be null", + value.getType.name()) + None + } + } + + val attrValues = attributesOpt match { + case Some(fieldBytesList) => + // Convert attributes to SQL values + fieldBytesList.asScala.zip(fieldValueConverters).map { case (fieldBytes, converter) => + converter(fieldBytes) + } + case None => + // No attributes, fill with nulls + Seq.fill(fieldValueConverters.length)(null) + } + + val serializedGeom = geometry.map(GeometryUDT.serialize).orNull + val shpRow = if (options.keyFieldName.isDefined) { + InternalRow.fromSeq(serializedGeom +: key.getIndex +: attrValues.toSeq) + } else { + InternalRow.fromSeq(serializedGeom +: attrValues.toSeq) + } + currentRow = projection(shpRow) + true + } else { + dbfReader.foreach { reader => + if (reader.nextKeyValue()) { + val dbfPath = partitionedFilesMap.get("dbf").orNull + logger.warn("Redundant attributes in {} exists", dbfPath) + } + } + false + } + } + + override def get(): InternalRow = currentRow + + override def close(): Unit = { + dbfReader.foreach(_.close()) + shpReader.close() + } +} + +object ShapefilePartitionReader { + val logger: Logger = LoggerFactory.getLogger(classOf[ShapefilePartitionReader]) + + private def openStream( + partitionedFilesMap: Map[String, Path], + extension: String, + configuration: Configuration): FSDataInputStream = { + tryOpenStream(partitionedFilesMap, extension, configuration).getOrElse { + val path = partitionedFilesMap.head._2 + val baseName = FilenameUtils.getBaseName(path.getName) + throw new IllegalArgumentException( + s"No $extension file found for shapefile $baseName in ${path.getParent}") + } + } + + private def tryOpenStream( + partitionedFilesMap: Map[String, Path], + extension: String, + configuration: Configuration): Option[FSDataInputStream] = { + 
partitionedFilesMap.get(extension).map { path => + val fs = path.getFileSystem(configuration) + fs.open(path) + } + } +} diff --git a/spark/spark-4.1/src/main/scala/org/apache/sedona/sql/datasources/shapefile/ShapefilePartitionReaderFactory.scala b/spark/spark-4.1/src/main/scala/org/apache/sedona/sql/datasources/shapefile/ShapefilePartitionReaderFactory.scala new file mode 100644 index 00000000000..5a28af6d66f --- /dev/null +++ b/spark/spark-4.1/src/main/scala/org/apache/sedona/sql/datasources/shapefile/ShapefilePartitionReaderFactory.scala @@ -0,0 +1,66 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.sedona.sql.datasources.shapefile + +import org.apache.spark.broadcast.Broadcast +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.connector.read.InputPartition +import org.apache.spark.sql.connector.read.PartitionReader +import org.apache.spark.sql.connector.read.PartitionReaderFactory +import org.apache.spark.sql.execution.datasources.PartitionedFile +import org.apache.spark.sql.execution.datasources.v2.PartitionReaderWithPartitionValues +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.sources.Filter +import org.apache.spark.sql.types.StructType +import org.apache.spark.util.SerializableConfiguration + +case class ShapefilePartitionReaderFactory( + sqlConf: SQLConf, + broadcastedConf: Broadcast[SerializableConfiguration], + dataSchema: StructType, + readDataSchema: StructType, + partitionSchema: StructType, + options: ShapefileReadOptions, + filters: Seq[Filter]) + extends PartitionReaderFactory { + + private def buildReader( + partitionedFiles: Array[PartitionedFile]): PartitionReader[InternalRow] = { + val fileReader = + new ShapefilePartitionReader( + broadcastedConf.value.value, + partitionedFiles, + readDataSchema, + options) + new PartitionReaderWithPartitionValues( + fileReader, + readDataSchema, + partitionSchema, + partitionedFiles.head.partitionValues) + } + + override def createReader(partition: InputPartition): PartitionReader[InternalRow] = { + partition match { + case filePartition: ShapefilePartition => buildReader(filePartition.files) + case _ => + throw new IllegalArgumentException( + s"Unexpected partition type: ${partition.getClass.getCanonicalName}") + } + } +} diff --git a/spark/spark-4.1/src/main/scala/org/apache/sedona/sql/datasources/shapefile/ShapefileReadOptions.scala b/spark/spark-4.1/src/main/scala/org/apache/sedona/sql/datasources/shapefile/ShapefileReadOptions.scala new file mode 100644 index 00000000000..ebc02fae85a --- /dev/null +++ b/spark/spark-4.1/src/main/scala/org/apache/sedona/sql/datasources/shapefile/ShapefileReadOptions.scala @@ -0,0 +1,45 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or 
more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.sedona.sql.datasources.shapefile + +import org.apache.spark.sql.util.CaseInsensitiveStringMap + +/** + * Options for reading Shapefiles. + * @param geometryFieldName + * The name of the geometry field. + * @param keyFieldName + * The name of the shape key field. + * @param charset + * The charset of non-spatial attributes. + */ +case class ShapefileReadOptions( + geometryFieldName: String, + keyFieldName: Option[String], + charset: Option[String]) + +object ShapefileReadOptions { + def parse(options: CaseInsensitiveStringMap): ShapefileReadOptions = { + val geometryFieldName = options.getOrDefault("geometry.name", "geometry") + val keyFieldName = + if (options.containsKey("key.name")) Some(options.get("key.name")) else None + val charset = if (options.containsKey("charset")) Some(options.get("charset")) else None + ShapefileReadOptions(geometryFieldName, keyFieldName, charset) + } +} diff --git a/spark/spark-4.1/src/main/scala/org/apache/sedona/sql/datasources/shapefile/ShapefileScan.scala b/spark/spark-4.1/src/main/scala/org/apache/sedona/sql/datasources/shapefile/ShapefileScan.scala new file mode 100644 index 00000000000..afdface8942 --- /dev/null +++ b/spark/spark-4.1/src/main/scala/org/apache/sedona/sql/datasources/shapefile/ShapefileScan.scala @@ -0,0 +1,118 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.sedona.sql.datasources.shapefile + +import org.apache.hadoop.fs.Path +import org.apache.spark.sql.catalyst.expressions.Expression +import org.apache.spark.sql.execution.datasources.PartitioningAwareFileIndex +import org.apache.spark.sql.sources.Filter +import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.util.CaseInsensitiveStringMap +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.connector.read.InputPartition +import org.apache.spark.sql.connector.read.PartitionReaderFactory +import org.apache.spark.sql.execution.datasources.v2.FileScan +import org.apache.spark.sql.execution.datasources.FilePartition +import org.apache.spark.sql.execution.datasources.PartitionedFile +import org.apache.sedona.sql.datasources.shapefile.ShapefileScan.logger +import org.apache.spark.util.SerializableConfiguration +import org.slf4j.{Logger, LoggerFactory} + +import java.util.Locale +import scala.collection.JavaConverters._ +import scala.collection.mutable + +case class ShapefileScan( + sparkSession: SparkSession, + fileIndex: PartitioningAwareFileIndex, + dataSchema: StructType, + readDataSchema: StructType, + readPartitionSchema: StructType, + options: CaseInsensitiveStringMap, + pushedFilters: Array[Filter], + partitionFilters: Seq[Expression] = Seq.empty, + dataFilters: Seq[Expression] = Seq.empty) + extends FileScan { + + override def createReaderFactory(): PartitionReaderFactory = { + val caseSensitiveMap = options.asScala.toMap + val hadoopConf = sparkSession.sessionState.newHadoopConfWithOptions(caseSensitiveMap) + val broadcastedConf = + sparkSession.sparkContext.broadcast(new SerializableConfiguration(hadoopConf)) + ShapefilePartitionReaderFactory( + sparkSession.sessionState.conf, + broadcastedConf, + dataSchema, + readDataSchema, + readPartitionSchema, + ShapefileReadOptions.parse(options), + pushedFilters) + } + + override def planInputPartitions(): Array[InputPartition] = { + // Simply use the default implementation to compute input partitions for all files + val allFilePartitions = super.planInputPartitions().flatMap { + case filePartition: FilePartition => + filePartition.files + case partition => + throw new IllegalArgumentException( + s"Unexpected partition type: ${partition.getClass.getCanonicalName}") + } + + // Group shapefiles by their main path (without the extension) + val shapefileGroups: mutable.Map[String, mutable.Map[String, PartitionedFile]] = + mutable.Map.empty + allFilePartitions.foreach { partitionedFile => + val path = partitionedFile.filePath.toPath + val fileName = path.getName + val pos = fileName.lastIndexOf('.') + if (pos == -1) None + else { + val mainName = fileName.substring(0, pos) + val extension = fileName.substring(pos + 1).toLowerCase(Locale.ROOT) + if (ShapefileUtils.shapeFileExtensions.contains(extension)) { + val key = new Path(path.getParent, mainName).toString + val group = shapefileGroups.getOrElseUpdate(key, mutable.Map.empty) + group += (extension -> partitionedFile) + } + } + } + + // Create a partition for each group + shapefileGroups.zipWithIndex.flatMap { case ((key, group), index) => + // Check if the group has all the necessary files + val suffixes = group.keys.toSet + val hasMissingFiles = ShapefileUtils.mandatoryFileExtensions.exists { suffix => + if (!suffixes.contains(suffix)) { + logger.warn(s"Shapefile $key is missing a $suffix file") + true + } else false + } + if (!hasMissingFiles) { + Some(ShapefilePartition(index, group.values.toArray)) + } else { + None + } + }.toArray + } 
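
planInputPartitions above groups sibling files (.shp, .shx, .dbf, .cpg, .prj) by their base name, so each ShapefilePartition carries one complete shapefile. A minimal usage sketch against this reader follows; it assumes the short name registered by ShapefileDataSource is "shapefile" (the registration is not part of this hunk) and that /data/parcels is a hypothetical directory of shapefile parts. The geometry.name, key.name and charset options mirror what ShapefileReadOptions.parse reads above.

    import org.apache.spark.sql.SparkSession

    val spark = SparkSession.builder()
      .master("local[*]")
      .appName("shapefile-read-sketch")
      .getOrCreate()

    // geometry.name renames the geometry column (default "geometry"),
    // key.name exposes the shape record index as an extra LongType column,
    // charset overrides the value otherwise inferred from the .cpg sidecar file.
    val parcels = spark.read
      .format("shapefile")
      .option("geometry.name", "geom")
      .option("key.name", "fid")
      .option("charset", "UTF-8")
      .load("/data/parcels") // hypothetical path

    parcels.printSchema()
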
+} + +object ShapefileScan { + val logger: Logger = LoggerFactory.getLogger(classOf[ShapefileScan]) +} diff --git a/spark/spark-4.1/src/main/scala/org/apache/sedona/sql/datasources/shapefile/ShapefileScanBuilder.scala b/spark/spark-4.1/src/main/scala/org/apache/sedona/sql/datasources/shapefile/ShapefileScanBuilder.scala new file mode 100644 index 00000000000..e5135e381df --- /dev/null +++ b/spark/spark-4.1/src/main/scala/org/apache/sedona/sql/datasources/shapefile/ShapefileScanBuilder.scala @@ -0,0 +1,48 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.sedona.sql.datasources.shapefile + +import org.apache.spark.sql.connector.read.Scan +import org.apache.spark.sql.execution.datasources.v2.FileScanBuilder +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.execution.datasources.PartitioningAwareFileIndex +import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.util.CaseInsensitiveStringMap + +case class ShapefileScanBuilder( + sparkSession: SparkSession, + fileIndex: PartitioningAwareFileIndex, + schema: StructType, + dataSchema: StructType, + options: CaseInsensitiveStringMap) + extends FileScanBuilder(sparkSession, fileIndex, dataSchema) { + + override def build(): Scan = { + ShapefileScan( + sparkSession, + fileIndex, + dataSchema, + readDataSchema(), + readPartitionSchema(), + options, + pushedDataFilters, + partitionFilters, + dataFilters) + } +} diff --git a/spark/spark-4.1/src/main/scala/org/apache/sedona/sql/datasources/shapefile/ShapefileTable.scala b/spark/spark-4.1/src/main/scala/org/apache/sedona/sql/datasources/shapefile/ShapefileTable.scala new file mode 100644 index 00000000000..7db6bb8d1f4 --- /dev/null +++ b/spark/spark-4.1/src/main/scala/org/apache/sedona/sql/datasources/shapefile/ShapefileTable.scala @@ -0,0 +1,103 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.sedona.sql.datasources.shapefile + +import org.apache.hadoop.fs.FileStatus +import org.apache.sedona.core.formatMapper.shapefileParser.parseUtils.dbf.DbfParseUtil +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.connector.catalog.TableCapability +import org.apache.spark.sql.connector.read.ScanBuilder +import org.apache.spark.sql.connector.write.{LogicalWriteInfo, WriteBuilder} +import org.apache.spark.sql.execution.datasources.FileFormat +import org.apache.sedona.sql.datasources.shapefile.ShapefileUtils.{baseSchema, fieldDescriptorsToSchema, mergeSchemas} +import org.apache.spark.sql.execution.datasources.v2.FileTable +import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.util.CaseInsensitiveStringMap +import org.apache.spark.util.SerializableConfiguration + +import java.util.Locale +import scala.collection.JavaConverters._ + +case class ShapefileTable( + name: String, + sparkSession: SparkSession, + options: CaseInsensitiveStringMap, + paths: Seq[String], + userSpecifiedSchema: Option[StructType], + fallbackFileFormat: Class[_ <: FileFormat]) + extends FileTable(sparkSession, options, paths, userSpecifiedSchema) { + + override def formatName: String = "Shapefile" + + override def capabilities: java.util.Set[TableCapability] = + java.util.EnumSet.of(TableCapability.BATCH_READ) + + override def inferSchema(files: Seq[FileStatus]): Option[StructType] = { + if (files.isEmpty) None + else { + def isDbfFile(file: FileStatus): Boolean = { + val name = file.getPath.getName.toLowerCase(Locale.ROOT) + name.endsWith(".dbf") + } + + def isShpFile(file: FileStatus): Boolean = { + val name = file.getPath.getName.toLowerCase(Locale.ROOT) + name.endsWith(".shp") + } + + if (!files.exists(isShpFile)) None + else { + val readOptions = ShapefileReadOptions.parse(options) + val resolver = sparkSession.sessionState.conf.resolver + val dbfFiles = files.filter(isDbfFile) + if (dbfFiles.isEmpty) { + Some(baseSchema(readOptions, Some(resolver))) + } else { + val serializableConf = new SerializableConfiguration( + sparkSession.sessionState.newHadoopConfWithOptions(options.asScala.toMap)) + val partiallyMergedSchemas = sparkSession.sparkContext + .parallelize(dbfFiles) + .mapPartitions { iter => + val schemas = iter.map { stat => + val fs = stat.getPath.getFileSystem(serializableConf.value) + val stream = fs.open(stat.getPath) + try { + val dbfParser = new DbfParseUtil() + dbfParser.parseFileHead(stream) + val fieldDescriptors = dbfParser.getFieldDescriptors + fieldDescriptorsToSchema(fieldDescriptors.asScala.toSeq, readOptions, resolver) + } finally { + stream.close() + } + }.toSeq + mergeSchemas(schemas).iterator + } + .collect() + mergeSchemas(partiallyMergedSchemas) + } + } + } + } + + override def newScanBuilder(options: CaseInsensitiveStringMap): ScanBuilder = { + ShapefileScanBuilder(sparkSession, fileIndex, schema, dataSchema, options) + } + + override def newWriteBuilder(info: LogicalWriteInfo): WriteBuilder = null +} diff --git a/spark/spark-4.1/src/main/scala/org/apache/sedona/sql/datasources/shapefile/ShapefileUtils.scala b/spark/spark-4.1/src/main/scala/org/apache/sedona/sql/datasources/shapefile/ShapefileUtils.scala new file mode 100644 index 00000000000..04ac3bdff9b --- /dev/null +++ b/spark/spark-4.1/src/main/scala/org/apache/sedona/sql/datasources/shapefile/ShapefileUtils.scala @@ -0,0 +1,202 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.sedona.sql.datasources.shapefile + +import org.apache.sedona.core.formatMapper.shapefileParser.parseUtils.dbf.FieldDescriptor +import org.apache.spark.sql.sedona_sql.UDT.GeometryUDT +import org.apache.spark.sql.catalyst.analysis.SqlApiAnalysis.Resolver +import org.apache.spark.sql.types.BooleanType +import org.apache.spark.sql.types.DateType +import org.apache.spark.sql.types.Decimal +import org.apache.spark.sql.types.DecimalType +import org.apache.spark.sql.types.LongType +import org.apache.spark.sql.types.StringType +import org.apache.spark.sql.types.StructField +import org.apache.spark.sql.types.StructType +import org.apache.spark.unsafe.types.UTF8String + +import java.nio.charset.StandardCharsets +import java.time.LocalDate +import java.time.format.DateTimeFormatter +import java.util.Locale + +object ShapefileUtils { + + /** + * shp: main file for storing shapes shx: index file for the main file dbf: attribute file cpg: + * code page file prj: projection file + */ + val shapeFileExtensions: Set[String] = Set("shp", "shx", "dbf", "cpg", "prj") + + /** + * The mandatory file extensions for a shapefile. 
We don't require the dbf file and shx file for + * being consistent with the behavior of the RDD API ShapefileReader.readToGeometryRDD + */ + val mandatoryFileExtensions: Set[String] = Set("shp") + + def mergeSchemas(schemas: Seq[StructType]): Option[StructType] = { + if (schemas.isEmpty) { + None + } else { + var mergedSchema = schemas.head + schemas.tail.foreach { schema => + try { + mergedSchema = mergeSchema(mergedSchema, schema) + } catch { + case cause: IllegalArgumentException => + throw new IllegalArgumentException( + s"Failed to merge schema $mergedSchema with $schema", + cause) + } + } + Some(mergedSchema) + } + } + + private def mergeSchema(schema1: StructType, schema2: StructType): StructType = { + // The field names are case insensitive when performing schema merging + val fieldMap = schema1.fields.map(f => f.name.toLowerCase(Locale.ROOT) -> f).toMap + var newFields = schema1.fields + schema2.fields.foreach { f => + fieldMap.get(f.name.toLowerCase(Locale.ROOT)) match { + case Some(existingField) => + if (existingField.dataType != f.dataType) { + throw new IllegalArgumentException( + s"Failed to merge fields ${existingField.name} and ${f.name} because they have different data types: ${existingField.dataType} and ${f.dataType}") + } + case _ => + newFields :+= f + } + } + StructType(newFields) + } + + def fieldDescriptorsToStructFields(fieldDescriptors: Seq[FieldDescriptor]): Seq[StructField] = { + fieldDescriptors.map { desc => + val name = desc.getFieldName + val dataType = desc.getFieldType match { + case 'C' => StringType + case 'N' | 'F' => + val scale = desc.getFieldDecimalCount + if (scale == 0) LongType + else { + val precision = desc.getFieldLength + DecimalType(precision, scale) + } + case 'L' => BooleanType + case 'D' => DateType + case _ => + throw new IllegalArgumentException(s"Unsupported field type ${desc.getFieldType}") + } + StructField(name, dataType, nullable = true) + } + } + + def fieldDescriptorsToSchema(fieldDescriptors: Seq[FieldDescriptor]): StructType = { + val structFields = fieldDescriptorsToStructFields(fieldDescriptors) + StructType(structFields) + } + + def fieldDescriptorsToSchema( + fieldDescriptors: Seq[FieldDescriptor], + options: ShapefileReadOptions, + resolver: Resolver): StructType = { + val structFields = fieldDescriptorsToStructFields(fieldDescriptors) + val geometryFieldName = options.geometryFieldName + if (structFields.exists(f => resolver(f.name, geometryFieldName))) { + throw new IllegalArgumentException( + s"Field name $geometryFieldName is reserved for geometry but appears in non-spatial attributes. " + + "Please specify a different field name for geometry using the 'geometry.name' option.") + } + options.keyFieldName.foreach { name => + if (structFields.exists(f => resolver(f.name, name))) { + throw new IllegalArgumentException( + s"Field name $name is reserved for shape key but appears in non-spatial attributes. 
" + + "Please specify a different field name for shape key using the 'key.name' option.") + } + } + StructType(baseSchema(options, Some(resolver)).fields ++ structFields) + } + + def baseSchema(options: ShapefileReadOptions, resolver: Option[Resolver] = None): StructType = { + options.keyFieldName match { + case Some(name) => + if (resolver.exists(_(name, options.geometryFieldName))) { + throw new IllegalArgumentException(s"geometry.name and key.name cannot be the same") + } + StructType( + Seq(StructField(options.geometryFieldName, GeometryUDT), StructField(name, LongType))) + case _ => + StructType(StructField(options.geometryFieldName, GeometryUDT) :: Nil) + } + } + + def fieldValueConverter(desc: FieldDescriptor, cpg: Option[String]): Array[Byte] => Any = { + desc.getFieldType match { + case 'C' => + val encoding = cpg.getOrElse("ISO-8859-1") + if (encoding.toLowerCase(Locale.ROOT) == "utf-8") { (bytes: Array[Byte]) => + UTF8String.fromBytes(bytes).trimRight() + } else { (bytes: Array[Byte]) => + { + val str = new String(bytes, encoding) + UTF8String.fromString(str).trimRight() + } + } + case 'N' | 'F' => + val scale = desc.getFieldDecimalCount + if (scale == 0) { (bytes: Array[Byte]) => + try { + new String(bytes, StandardCharsets.ISO_8859_1).trim.toLong + } catch { + case _: Exception => null + } + } else { (bytes: Array[Byte]) => + try { + Decimal.fromString(UTF8String.fromBytes(bytes)) + } catch { + case _: Exception => null + } + } + case 'L' => + (bytes: Array[Byte]) => + if (bytes.isEmpty) null + else { + bytes.head match { + case 'T' | 't' | 'Y' | 'y' => true + case 'F' | 'f' | 'N' | 'n' => false + case _ => null + } + } + case 'D' => + (bytes: Array[Byte]) => { + try { + val dateString = new String(bytes, StandardCharsets.ISO_8859_1) + val formatter = DateTimeFormatter.BASIC_ISO_DATE + val date = LocalDate.parse(dateString, formatter) + date.toEpochDay.toInt + } catch { + case _: Exception => null + } + } + case _ => + throw new IllegalArgumentException(s"Unsupported field type ${desc.getFieldType}") + } + } +} diff --git a/spark/spark-4.1/src/main/scala/org/apache/sedona/sql/parser/SedonaSqlAstBuilder.scala b/spark/spark-4.1/src/main/scala/org/apache/sedona/sql/parser/SedonaSqlAstBuilder.scala new file mode 100644 index 00000000000..2bdd92bd64a --- /dev/null +++ b/spark/spark-4.1/src/main/scala/org/apache/sedona/sql/parser/SedonaSqlAstBuilder.scala @@ -0,0 +1,39 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.sedona.sql.parser + +import org.apache.spark.sql.catalyst.parser.SqlBaseParser._ +import org.apache.spark.sql.execution.SparkSqlAstBuilder +import org.apache.spark.sql.sedona_sql.UDT.GeometryUDT +import org.apache.spark.sql.types.DataType + +class SedonaSqlAstBuilder extends SparkSqlAstBuilder { + + /** + * Override the method to handle the geometry data type + * @param ctx + * @return + */ + override def visitPrimitiveDataType(ctx: PrimitiveDataTypeContext): DataType = { + ctx.getText.toUpperCase() match { + case "GEOMETRY" => GeometryUDT + case _ => super.visitPrimitiveDataType(ctx) + } + } +} diff --git a/spark/spark-4.1/src/main/scala/org/apache/sedona/sql/parser/SedonaSqlParser.scala b/spark/spark-4.1/src/main/scala/org/apache/sedona/sql/parser/SedonaSqlParser.scala new file mode 100644 index 00000000000..54fb074eb05 --- /dev/null +++ b/spark/spark-4.1/src/main/scala/org/apache/sedona/sql/parser/SedonaSqlParser.scala @@ -0,0 +1,49 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.sedona.sql.parser + +import org.apache.spark.sql.catalyst.parser.ParserInterface +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.execution.SparkSqlParser + +class SedonaSqlParser(delegate: ParserInterface) extends SparkSqlParser { + + // The parser builder for the Sedona SQL AST + val parserBuilder = new SedonaSqlAstBuilder + + /** + * Parse the SQL text and return the logical plan. This method first attempts to use the + * delegate parser to parse the SQL text. If the delegate parser fails (throws an exception), it + * falls back to using the Sedona SQL parser. + * + * @param sqlText + * The SQL text to be parsed. + * @return + * The parsed logical plan. + */ + override def parsePlan(sqlText: String): LogicalPlan = + try { + delegate.parsePlan(sqlText) + } catch { + case _: Exception => + parse(sqlText) { parser => + parserBuilder.visit(parser.singleStatement()) + }.asInstanceOf[LogicalPlan] + } +} diff --git a/spark/spark-4.1/src/main/scala/org/apache/spark/sql/execution/datasources/v2/geoparquet/metadata/GeoParquetMetadataDataSource.scala b/spark/spark-4.1/src/main/scala/org/apache/spark/sql/execution/datasources/v2/geoparquet/metadata/GeoParquetMetadataDataSource.scala new file mode 100644 index 00000000000..43e1ababb7d --- /dev/null +++ b/spark/spark-4.1/src/main/scala/org/apache/spark/sql/execution/datasources/v2/geoparquet/metadata/GeoParquetMetadataDataSource.scala @@ -0,0 +1,65 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.spark.sql.execution.datasources.v2.geoparquet.metadata + +import org.apache.spark.sql.connector.catalog.Table +import org.apache.spark.sql.execution.datasources.FileFormat +import org.apache.spark.sql.execution.datasources.v2.FileDataSourceV2 +import org.apache.spark.sql.sources.DataSourceRegister +import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.util.CaseInsensitiveStringMap + +/** + * Data source for reading GeoParquet metadata. This could be accessed using the `spark.read` + * interface: + * {{{ + * val df = spark.read.format("geoparquet.metadata").load("path/to/geoparquet") + * }}} + */ +class GeoParquetMetadataDataSource extends FileDataSourceV2 with DataSourceRegister { + override val shortName: String = "geoparquet.metadata" + + override def fallbackFileFormat: Class[_ <: FileFormat] = null + + override def getTable(options: CaseInsensitiveStringMap): Table = { + val paths = getPaths(options) + val tableName = getTableName(options, paths) + val optionsWithoutPaths = getOptionsWithoutPaths(options) + GeoParquetMetadataTable( + tableName, + sparkSession, + optionsWithoutPaths, + paths, + None, + fallbackFileFormat) + } + + override def getTable(options: CaseInsensitiveStringMap, schema: StructType): Table = { + val paths = getPaths(options) + val tableName = getTableName(options, paths) + val optionsWithoutPaths = getOptionsWithoutPaths(options) + GeoParquetMetadataTable( + tableName, + sparkSession, + optionsWithoutPaths, + paths, + Some(schema), + fallbackFileFormat) + } +} diff --git a/spark/spark-4.1/src/main/scala/org/apache/spark/sql/execution/datasources/v2/geoparquet/metadata/GeoParquetMetadataPartitionReaderFactory.scala b/spark/spark-4.1/src/main/scala/org/apache/spark/sql/execution/datasources/v2/geoparquet/metadata/GeoParquetMetadataPartitionReaderFactory.scala new file mode 100644 index 00000000000..683160e93b2 --- /dev/null +++ b/spark/spark-4.1/src/main/scala/org/apache/spark/sql/execution/datasources/v2/geoparquet/metadata/GeoParquetMetadataPartitionReaderFactory.scala @@ -0,0 +1,122 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.spark.sql.execution.datasources.v2.geoparquet.metadata + +import org.apache.hadoop.conf.Configuration +import org.apache.parquet.ParquetReadOptions +import org.apache.parquet.hadoop.ParquetFileReader +import org.apache.parquet.hadoop.util.HadoopInputFile +import org.apache.spark.broadcast.Broadcast +import org.apache.spark.sql.catalyst.{FileSourceOptions, InternalRow} +import org.apache.spark.sql.catalyst.expressions.GenericInternalRow +import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, GenericArrayData} +import org.apache.spark.sql.connector.read.PartitionReader +import org.apache.spark.sql.execution.datasources.PartitionedFile +import org.apache.spark.sql.execution.datasources.geoparquet.GeoParquetMetaData +import org.apache.spark.sql.execution.datasources.v2._ +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.sources.Filter +import org.apache.spark.sql.types.StructType +import org.apache.spark.unsafe.types.UTF8String +import org.apache.spark.util.SerializableConfiguration +import org.json4s.DefaultFormats +import org.json4s.jackson.JsonMethods.{compact, render} + +case class GeoParquetMetadataPartitionReaderFactory( + sqlConf: SQLConf, + broadcastedConf: Broadcast[SerializableConfiguration], + dataSchema: StructType, + readDataSchema: StructType, + partitionSchema: StructType, + options: FileSourceOptions, + filters: Seq[Filter]) + extends FilePartitionReaderFactory { + + override def buildReader(partitionedFile: PartitionedFile): PartitionReader[InternalRow] = { + val iter = GeoParquetMetadataPartitionReaderFactory.readFile( + broadcastedConf.value.value, + partitionedFile, + readDataSchema) + val fileReader = new PartitionReaderFromIterator[InternalRow](iter) + new PartitionReaderWithPartitionValues( + fileReader, + readDataSchema, + partitionSchema, + partitionedFile.partitionValues) + } +} + +object GeoParquetMetadataPartitionReaderFactory { + private def readFile( + configuration: Configuration, + partitionedFile: PartitionedFile, + readDataSchema: StructType): Iterator[InternalRow] = { + + val inputFile = HadoopInputFile.fromPath(partitionedFile.toPath, configuration) + val inputStream = inputFile.newStream() + + val footer = ParquetFileReader + .readFooter(inputFile, ParquetReadOptions.builder().build(), inputStream) + + val filePath = partitionedFile.toPath.toString + val metadata = footer.getFileMetaData.getKeyValueMetaData + val row = GeoParquetMetaData.parseKeyValueMetaData(metadata) match { + case Some(geo) => + val geoColumnsMap = geo.columns.map { case (columnName, columnMetadata) => + implicit val formats: org.json4s.Formats = DefaultFormats + import org.json4s.jackson.Serialization + val columnMetadataFields: Array[Any] = Array( + UTF8String.fromString(columnMetadata.encoding), + new GenericArrayData(columnMetadata.geometryTypes.map(UTF8String.fromString).toArray), + new GenericArrayData(columnMetadata.bbox.toArray), + columnMetadata.crs + .map(projjson => UTF8String.fromString(compact(render(projjson)))) + .getOrElse(UTF8String.fromString("")), + columnMetadata.covering + .map(covering => UTF8String.fromString(Serialization.write(covering))) + .orNull) + val columnMetadataStruct = new GenericInternalRow(columnMetadataFields) + UTF8String.fromString(columnName) -> columnMetadataStruct + } + val fields: Array[Any] = Array( + UTF8String.fromString(filePath), + UTF8String.fromString(geo.version.orNull), + UTF8String.fromString(geo.primaryColumn), + ArrayBasedMapData(geoColumnsMap)) + new GenericInternalRow(fields) 
+ case None => + // Not a GeoParquet file, return a row with null metadata values. + val fields: Array[Any] = Array(UTF8String.fromString(filePath), null, null, null) + new GenericInternalRow(fields) + } + Iterator(pruneBySchema(row, GeoParquetMetadataTable.schema, readDataSchema)) + } + + private def pruneBySchema( + row: InternalRow, + schema: StructType, + readDataSchema: StructType): InternalRow = { + // Projection push down for nested fields is not enabled, so this very simple implementation is enough. + val values: Array[Any] = readDataSchema.fields.map { field => + val index = schema.fieldIndex(field.name) + row.get(index, field.dataType) + } + new GenericInternalRow(values) + } +} diff --git a/spark/spark-4.1/src/main/scala/org/apache/spark/sql/execution/datasources/v2/geoparquet/metadata/GeoParquetMetadataScan.scala b/spark/spark-4.1/src/main/scala/org/apache/spark/sql/execution/datasources/v2/geoparquet/metadata/GeoParquetMetadataScan.scala new file mode 100644 index 00000000000..d7719d87dad --- /dev/null +++ b/spark/spark-4.1/src/main/scala/org/apache/spark/sql/execution/datasources/v2/geoparquet/metadata/GeoParquetMetadataScan.scala @@ -0,0 +1,69 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.spark.sql.execution.datasources.v2.geoparquet.metadata + +import org.apache.hadoop.fs.Path +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.catalyst.FileSourceOptions +import org.apache.spark.sql.catalyst.expressions.Expression +import org.apache.spark.sql.connector.read.PartitionReaderFactory +import org.apache.spark.sql.execution.datasources.PartitioningAwareFileIndex +import org.apache.spark.sql.execution.datasources.v2.FileScan +import org.apache.spark.sql.sources.Filter +import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.util.CaseInsensitiveStringMap +import org.apache.spark.util.SerializableConfiguration + +import scala.collection.JavaConverters._ + +case class GeoParquetMetadataScan( + sparkSession: SparkSession, + fileIndex: PartitioningAwareFileIndex, + dataSchema: StructType, + readDataSchema: StructType, + readPartitionSchema: StructType, + options: CaseInsensitiveStringMap, + pushedFilters: Array[Filter], + partitionFilters: Seq[Expression] = Seq.empty, + dataFilters: Seq[Expression] = Seq.empty) + extends FileScan { + override def createReaderFactory(): PartitionReaderFactory = { + val caseSensitiveMap = options.asCaseSensitiveMap.asScala.toMap + // Hadoop Configurations are case sensitive. + val hadoopConf = sparkSession.sessionState.newHadoopConfWithOptions(caseSensitiveMap) + val broadcastedConf = + sparkSession.sparkContext.broadcast(new SerializableConfiguration(hadoopConf)) + // The partition values are already truncated in `FileScan.partitions`. 
+ // We should use `readPartitionSchema` as the partition schema here. + val fileSourceOptions = new FileSourceOptions(caseSensitiveMap) + GeoParquetMetadataPartitionReaderFactory( + sparkSession.sessionState.conf, + broadcastedConf, + dataSchema, + readDataSchema, + readPartitionSchema, + fileSourceOptions, + pushedFilters) + } + + override def isSplitable(path: Path): Boolean = false + + override def getFileUnSplittableReason(path: Path): String = + "Reading parquet file metadata does not require splitting the file" +} diff --git a/spark/spark-4.1/src/main/scala/org/apache/spark/sql/execution/datasources/v2/geoparquet/metadata/GeoParquetMetadataScanBuilder.scala b/spark/spark-4.1/src/main/scala/org/apache/spark/sql/execution/datasources/v2/geoparquet/metadata/GeoParquetMetadataScanBuilder.scala new file mode 100644 index 00000000000..c60369e1087 --- /dev/null +++ b/spark/spark-4.1/src/main/scala/org/apache/spark/sql/execution/datasources/v2/geoparquet/metadata/GeoParquetMetadataScanBuilder.scala @@ -0,0 +1,47 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.spark.sql.execution.datasources.v2.geoparquet.metadata + +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.connector.read.Scan +import org.apache.spark.sql.execution.datasources.PartitioningAwareFileIndex +import org.apache.spark.sql.execution.datasources.v2.FileScanBuilder +import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.util.CaseInsensitiveStringMap + +class GeoParquetMetadataScanBuilder( + sparkSession: SparkSession, + fileIndex: PartitioningAwareFileIndex, + schema: StructType, + dataSchema: StructType, + options: CaseInsensitiveStringMap) + extends FileScanBuilder(sparkSession, fileIndex, dataSchema) { + override def build(): Scan = { + GeoParquetMetadataScan( + sparkSession, + fileIndex, + dataSchema, + readDataSchema(), + readPartitionSchema(), + options, + pushedDataFilters, + partitionFilters, + dataFilters) + } +} diff --git a/spark/spark-4.1/src/main/scala/org/apache/spark/sql/execution/datasources/v2/geoparquet/metadata/GeoParquetMetadataTable.scala b/spark/spark-4.1/src/main/scala/org/apache/spark/sql/execution/datasources/v2/geoparquet/metadata/GeoParquetMetadataTable.scala new file mode 100644 index 00000000000..845764fae55 --- /dev/null +++ b/spark/spark-4.1/src/main/scala/org/apache/spark/sql/execution/datasources/v2/geoparquet/metadata/GeoParquetMetadataTable.scala @@ -0,0 +1,70 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.spark.sql.execution.datasources.v2.geoparquet.metadata + +import org.apache.hadoop.fs.FileStatus +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.connector.catalog.TableCapability +import org.apache.spark.sql.connector.read.ScanBuilder +import org.apache.spark.sql.connector.write.{LogicalWriteInfo, WriteBuilder} +import org.apache.spark.sql.execution.datasources.FileFormat +import org.apache.spark.sql.execution.datasources.v2.FileTable +import org.apache.spark.sql.types._ +import org.apache.spark.sql.util.CaseInsensitiveStringMap + +case class GeoParquetMetadataTable( + name: String, + sparkSession: SparkSession, + options: CaseInsensitiveStringMap, + paths: Seq[String], + userSpecifiedSchema: Option[StructType], + fallbackFileFormat: Class[_ <: FileFormat]) + extends FileTable(sparkSession, options, paths, userSpecifiedSchema) { + override def formatName: String = "GeoParquet Metadata" + + override def inferSchema(files: Seq[FileStatus]): Option[StructType] = + Some(GeoParquetMetadataTable.schema) + + override def newScanBuilder(options: CaseInsensitiveStringMap): ScanBuilder = + new GeoParquetMetadataScanBuilder(sparkSession, fileIndex, schema, dataSchema, options) + + override def newWriteBuilder(info: LogicalWriteInfo): WriteBuilder = null + + override def capabilities: java.util.Set[TableCapability] = + java.util.EnumSet.of(TableCapability.BATCH_READ) +} + +object GeoParquetMetadataTable { + private val columnMetadataType = StructType( + Seq( + StructField("encoding", StringType, nullable = true), + StructField("geometry_types", ArrayType(StringType), nullable = true), + StructField("bbox", ArrayType(DoubleType), nullable = true), + StructField("crs", StringType, nullable = true), + StructField("covering", StringType, nullable = true))) + + private val columnsType = MapType(StringType, columnMetadataType, valueContainsNull = false) + + val schema: StructType = StructType( + Seq( + StructField("path", StringType, nullable = false), + StructField("version", StringType, nullable = true), + StructField("primary_column", StringType, nullable = true), + StructField("columns", columnsType, nullable = true))) +} diff --git a/spark/spark-4.1/src/main/scala/org/apache/spark/sql/execution/python/SedonaArrowEvalPythonExec.scala b/spark/spark-4.1/src/main/scala/org/apache/spark/sql/execution/python/SedonaArrowEvalPythonExec.scala new file mode 100644 index 00000000000..30a3e652cf2 --- /dev/null +++ b/spark/spark-4.1/src/main/scala/org/apache/spark/sql/execution/python/SedonaArrowEvalPythonExec.scala @@ -0,0 +1,114 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.spark.sql.execution.python + +import scala.jdk.CollectionConverters._ + +import org.apache.sedona.sql.UDF.PythonEvalType +import org.apache.spark.{JobArtifactSet, TaskContext} +import org.apache.spark.api.python.ChainedPythonFunctions +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.execution.SparkPlan +import org.apache.spark.sql.execution.metric.SQLMetric +import org.apache.spark.sql.execution.python.EvalPythonExec.ArgumentMetadata +import org.apache.spark.sql.types.StructType + +/** + * A physical plan that evaluates a [[PythonUDF]]. + */ +case class SedonaArrowEvalPythonExec( + udfs: Seq[PythonUDF], + resultAttrs: Seq[Attribute], + child: SparkPlan, + evalType: Int) + extends EvalPythonExec + with PythonSQLMetrics { + + private[this] val jobArtifactUUID = JobArtifactSet.getCurrentJobArtifactState.map(_.uuid) + + private[this] val sessionUUID = { + Option(session).collect { + case session if session.sessionState.conf.pythonWorkerLoggingEnabled => + session.sessionUUID + } + } + + override protected def evaluatorFactory: EvalPythonEvaluatorFactory = { + new SedonaArrowEvalPythonEvaluatorFactory( + child.output, + udfs, + output, + conf.arrowMaxRecordsPerBatch, + evalType, + conf.sessionLocalTimeZone, + conf.arrowUseLargeVarTypes, + ArrowPythonRunner.getPythonRunnerConfMap(conf), + pythonMetrics, + jobArtifactUUID, + sessionUUID, + conf.pythonUDFProfiler) + } + + override protected def withNewChildInternal(newChild: SparkPlan): SparkPlan = + copy(child = newChild) +} + +class SedonaArrowEvalPythonEvaluatorFactory( + childOutput: Seq[Attribute], + udfs: Seq[PythonUDF], + output: Seq[Attribute], + batchSize: Int, + evalType: Int, + sessionLocalTimeZone: String, + largeVarTypes: Boolean, + pythonRunnerConf: Map[String, String], + pythonMetrics: Map[String, SQLMetric], + jobArtifactUUID: Option[String], + sessionUUID: Option[String], + profiler: Option[String]) + extends EvalPythonEvaluatorFactory(childOutput, udfs, output) { + + override def evaluate( + funcs: Seq[(ChainedPythonFunctions, Long)], + argMetas: Array[Array[ArgumentMetadata]], + iter: Iterator[InternalRow], + schema: StructType, + context: TaskContext): Iterator[InternalRow] = { + val batchIter = Iterator(iter) + + val pyRunner = new ArrowPythonWithNamedArgumentRunner( + funcs, + evalType - PythonEvalType.SEDONA_UDF_TYPE_CONSTANT, + argMetas, + schema, + sessionLocalTimeZone, + largeVarTypes, + pythonRunnerConf, + pythonMetrics, + jobArtifactUUID, + sessionUUID, + profiler) with BatchedPythonArrowInput + val columnarBatchIter = pyRunner.compute(batchIter, context.partitionId(), context) + + columnarBatchIter.flatMap { batch => + batch.rowIterator.asScala + } + } +} diff --git a/spark/spark-4.1/src/main/scala/org/apache/spark/sql/udf/ExtractSedonaUDFRule.scala b/spark/spark-4.1/src/main/scala/org/apache/spark/sql/udf/ExtractSedonaUDFRule.scala 
new file mode 100644 index 00000000000..3d3301580cc --- /dev/null +++ b/spark/spark-4.1/src/main/scala/org/apache/spark/sql/udf/ExtractSedonaUDFRule.scala @@ -0,0 +1,185 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.spark.sql.udf + +import org.apache.sedona.sql.UDF.PythonEvalType +import org.apache.spark.internal.Logging +import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Expression, ExpressionSet, PythonUDF} +import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project, Subquery} +import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.catalyst.trees.TreePattern.PYTHON_UDF + +import scala.collection.mutable + +// That rule extracts scalar Python UDFs, currently Apache Spark has +// assert on types which blocks using the vectorized udfs with geometry type +class ExtractSedonaUDFRule extends Rule[LogicalPlan] with Logging { + + private def hasScalarPythonUDF(e: Expression): Boolean = { + e.exists(PythonUDF.isScalarPythonUDF) + } + + @scala.annotation.tailrec + private def canEvaluateInPython(e: PythonUDF): Boolean = { + e.children match { + case Seq(u: PythonUDF) => e.evalType == u.evalType && canEvaluateInPython(u) + case children => !children.exists(hasScalarPythonUDF) + } + } + + def isScalarPythonUDF(e: Expression): Boolean = { + e.isInstanceOf[PythonUDF] && e + .asInstanceOf[PythonUDF] + .evalType == PythonEvalType.SQL_SCALAR_SEDONA_UDF + } + + private def collectEvaluableUDFsFromExpressions( + expressions: Seq[Expression]): Seq[PythonUDF] = { + + var firstVisitedScalarUDFEvalType: Option[Int] = None + + def canChainUDF(evalType: Int): Boolean = { + evalType == firstVisitedScalarUDFEvalType.get + } + + def collectEvaluableUDFs(expr: Expression): Seq[PythonUDF] = expr match { + case udf: PythonUDF + if isScalarPythonUDF(udf) && canEvaluateInPython(udf) + && firstVisitedScalarUDFEvalType.isEmpty => + firstVisitedScalarUDFEvalType = Some(udf.evalType) + Seq(udf) + case udf: PythonUDF + if isScalarPythonUDF(udf) && canEvaluateInPython(udf) + && canChainUDF(udf.evalType) => + Seq(udf) + case e => e.children.flatMap(collectEvaluableUDFs) + } + + expressions.flatMap(collectEvaluableUDFs) + } + + private var hasFailedBefore: Boolean = false + + def apply(plan: LogicalPlan): LogicalPlan = plan match { + case s: Subquery if s.correlated => plan + + case _ => + try { + plan.transformUpWithPruning(_.containsPattern(PYTHON_UDF)) { + case p: SedonaArrowEvalPython => p + + case plan: LogicalPlan => extract(plan) + } + } catch { + case e: Throwable => + if (!hasFailedBefore) { + log.warn( + s"Vectorized UDF feature won't be available due to plan transformation error.") + log.warn( + s"Failed to extract Sedona UDFs from plan: ${plan.treeString}\n" + + s"Exception: 
${e.getMessage}", + e) + hasFailedBefore = true + } + plan + } + } + + private def canonicalizeDeterministic(u: PythonUDF) = { + if (u.deterministic) { + u.canonicalized.asInstanceOf[PythonUDF] + } else { + u + } + } + + private def extract(plan: LogicalPlan): LogicalPlan = { + val udfs = ExpressionSet(collectEvaluableUDFsFromExpressions(plan.expressions)) + .filter(udf => udf.references.subsetOf(plan.inputSet)) + .toSeq + .asInstanceOf[Seq[PythonUDF]] + + udfs match { + case Seq() => plan + case _ => resolveUDFs(plan, udfs) + } + } + + def resolveUDFs(plan: LogicalPlan, udfs: Seq[PythonUDF]): LogicalPlan = { + val attributeMap = mutable.HashMap[PythonUDF, Expression]() + + val newChildren = adjustAttributeMap(plan, udfs, attributeMap) + + udfs.map(canonicalizeDeterministic).filterNot(attributeMap.contains).foreach { udf => + throw new IllegalStateException( + s"Invalid PythonUDF $udf, requires attributes from more than one child.") + } + + val rewritten = plan.withNewChildren(newChildren).transformExpressions { case p: PythonUDF => + attributeMap.getOrElse(canonicalizeDeterministic(p), p) + } + + val newPlan = extract(rewritten) + if (newPlan.output != plan.output) { + Project(plan.output, newPlan) + } else { + newPlan + } + } + + def adjustAttributeMap( + plan: LogicalPlan, + udfs: Seq[PythonUDF], + attributeMap: mutable.HashMap[PythonUDF, Expression]): Seq[LogicalPlan] = { + plan.children.map { child => + val validUdfs = udfs.filter { udf => + udf.references.subsetOf(child.outputSet) + } + + if (validUdfs.nonEmpty) { + require( + validUdfs.forall(isScalarPythonUDF), + "Can only extract scalar vectorized udf or sql batch udf") + + val resultAttrs = validUdfs.zipWithIndex.map { case (u, i) => + AttributeReference(s"pythonUDF$i", u.dataType)() + } + + val evalTypes = validUdfs.map(_.evalType).toSet + if (evalTypes.size != 1) { + throw new IllegalStateException( + "Expected udfs have the same evalType but got different evalTypes: " + + evalTypes.mkString(",")) + } + val evalType = evalTypes.head + val evaluation = evalType match { + case PythonEvalType.SQL_SCALAR_SEDONA_UDF => + SedonaArrowEvalPython(validUdfs, resultAttrs, child, evalType) + case _ => + throw new IllegalStateException("Unexpected UDF evalType") + } + + attributeMap ++= validUdfs.map(canonicalizeDeterministic).zip(resultAttrs) + evaluation + } else { + child + } + } + } +} diff --git a/spark/spark-4.1/src/main/scala/org/apache/spark/sql/udf/SedonaArrowEvalPython.scala b/spark/spark-4.1/src/main/scala/org/apache/spark/sql/udf/SedonaArrowEvalPython.scala new file mode 100644 index 00000000000..7600ece5079 --- /dev/null +++ b/spark/spark-4.1/src/main/scala/org/apache/spark/sql/udf/SedonaArrowEvalPython.scala @@ -0,0 +1,32 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.spark.sql.udf + +import org.apache.spark.sql.catalyst.expressions.{Attribute, PythonUDF} +import org.apache.spark.sql.catalyst.plans.logical.{BaseEvalPython, LogicalPlan} + +case class SedonaArrowEvalPython( + udfs: Seq[PythonUDF], + resultAttrs: Seq[Attribute], + child: LogicalPlan, + evalType: Int) + extends BaseEvalPython { + override protected def withNewChildInternal(newChild: LogicalPlan): SedonaArrowEvalPython = + copy(child = newChild) +} diff --git a/spark/spark-4.1/src/main/scala/org/apache/spark/sql/udf/SedonaArrowStrategy.scala b/spark/spark-4.1/src/main/scala/org/apache/spark/sql/udf/SedonaArrowStrategy.scala new file mode 100644 index 00000000000..21e389ee567 --- /dev/null +++ b/spark/spark-4.1/src/main/scala/org/apache/spark/sql/udf/SedonaArrowStrategy.scala @@ -0,0 +1,43 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.spark.sql.udf + +import org.apache.sedona.sql.UDF.PythonEvalType +import org.apache.spark.api.python.ChainedPythonFunctions +import org.apache.spark.{JobArtifactSet, TaskContext} +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.{Attribute, PythonUDF} +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.execution.{SparkPlan, SparkStrategy} +import org.apache.spark.sql.execution.metric.SQLMetric +import org.apache.spark.sql.execution.python.SedonaArrowEvalPythonExec +import org.apache.spark.sql.types.StructType + +import scala.collection.JavaConverters.asScalaIteratorConverter + +// We use custom Strategy to avoid Apache Spark assert on types, we +// can consider extending this to support other engines working with +// arrow data +class SedonaArrowStrategy extends SparkStrategy { + override def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { + case SedonaArrowEvalPython(udfs, output, child, evalType) => + SedonaArrowEvalPythonExec(udfs, output, planLater(child), evalType) :: Nil + case _ => Nil + } +} diff --git a/spark/spark-4.1/src/test/resources/log4j2.properties b/spark/spark-4.1/src/test/resources/log4j2.properties new file mode 100644 index 00000000000..683ecd32f25 --- /dev/null +++ b/spark/spark-4.1/src/test/resources/log4j2.properties @@ -0,0 +1,31 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# Set everything to be logged to the file target/unit-tests.log +rootLogger.level = info +rootLogger.appenderRef.file.ref = File + +appender.file.type = File +appender.file.name = File +appender.file.fileName = target/unit-tests.log +appender.file.append = true +appender.file.layout.type = PatternLayout +appender.file.layout.pattern = %d{yy/MM/dd HH:mm:ss.SSS} %t %p %c{1}: %m%n%ex + +# Ignore messages below warning level from Jetty, because it's a bit verbose +logger.jetty.name = org.sparkproject.jetty +logger.jetty.level = warn diff --git a/spark/spark-4.1/src/test/scala/org/apache/sedona/sql/GeoPackageReaderTest.scala b/spark/spark-4.1/src/test/scala/org/apache/sedona/sql/GeoPackageReaderTest.scala new file mode 100644 index 00000000000..6d9f41bf4e3 --- /dev/null +++ b/spark/spark-4.1/src/test/scala/org/apache/sedona/sql/GeoPackageReaderTest.scala @@ -0,0 +1,369 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.sedona.sql + +import io.minio.{MakeBucketArgs, MinioClient, PutObjectArgs} +import org.apache.spark.sql.{DataFrame, SparkSession} +import org.apache.spark.sql.functions.expr +import org.apache.spark.sql.sedona_sql.UDT.GeometryUDT +import org.apache.spark.sql.types.{BinaryType, BooleanType, DateType, DoubleType, IntegerType, StringType, StructField, StructType, TimestampType} +import org.scalatest.matchers.should.Matchers +import org.scalatest.prop.TableDrivenPropertyChecks._ +import org.testcontainers.containers.MinIOContainer + +import java.io.FileInputStream +import java.sql.{Date, Timestamp} +import java.util.TimeZone + +class GeoPackageReaderTest extends TestBaseScala with Matchers { + TimeZone.setDefault(TimeZone.getTimeZone("UTC")) + import sparkSession.implicits._ + + val path: String = resourceFolder + "geopackage/example.gpkg" + val polygonsPath: String = resourceFolder + "geopackage/features.gpkg" + val rasterPath: String = resourceFolder + "geopackage/raster.gpkg" + val wktReader = new org.locationtech.jts.io.WKTReader() + val wktWriter = new org.locationtech.jts.io.WKTWriter() + + val expectedFeatureSchema = StructType( + Seq( + StructField("id", IntegerType, true), + StructField("geometry", GeometryUDT, true), + StructField("text", StringType, true), + StructField("real", DoubleType, true), + StructField("boolean", BooleanType, true), + StructField("blob", BinaryType, true), + StructField("integer", IntegerType, true), + StructField("text_limited", StringType, true), + StructField("blob_limited", BinaryType, true), + StructField("date", DateType, true), + StructField("datetime", TimestampType, true))) + + describe("Reading GeoPackage metadata") { + it("should read GeoPackage metadata") { + val df = sparkSession.read + .format("geopackage") + .option("showMetadata", "true") + .load(path) + + df.where("data_type = 'tiles'").show(false) + + df.count shouldEqual 34 + } + } + + describe("Reading Vector data") { + it("should read GeoPackage - point1") { + val df = readFeatureData("point1") + df.schema shouldEqual expectedFeatureSchema + + df.count() shouldEqual 4 + + val firstElement = df.collectAsList().get(0).toSeq + + val expectedValues = Seq( + 1, + wktReader.read(POINT_1), + "BIT Systems", + 4519.866024037493, + true, + Array(48, 99, 57, 54, 49, 56, 55, 54, 45, 98, 102, 100, 52, 45, 52, 102, 52, 48, 45, 97, + 49, 102, 101, 45, 55, 49, 55, 101, 57, 100, 50, 98, 48, 55, 98, 101), + 3, + "bcd5a36f-16dc-4385-87be-b40353848597", + Array(49, 50, 53, 50, 97, 99, 98, 52, 45, 57, 54, 54, 52, 45, 52, 101, 51, 50, 45, 57, 54, + 100, 101, 45, 56, 48, 54, 101, 101, 48, 101, 101, 49, 102, 57, 48), + Date.valueOf("2023-09-19"), + Timestamp.valueOf("2023-09-19 11:24:15.695")) + + firstElement should contain theSameElementsAs expectedValues + } + + it("should read GeoPackage - line1") { + val df = readFeatureData("line1") + .withColumn("datetime", expr("from_utc_timestamp(datetime, 'UTC')")) + + df.schema shouldEqual expectedFeatureSchema + + df.count() shouldEqual 3 + + val firstElement = df.collectAsList().get(0).toSeq + + firstElement should contain theSameElementsAs Seq( + 1, + wktReader.read(LINESTRING_1), + "East Lockheed Drive", + 1990.5159635296877, + false, + Array(54, 97, 98, 100, 98, 51, 97, 56, 45, 54, 53, 101, 48, 45, 52, 55, 48, 54, 45, 56, + 50, 52, 48, 45, 51, 57, 48, 55, 99, 50, 102, 102, 57, 48, 99, 55), + 1, + "13dd91dc-3b7d-4d8d-a0ca-b3afb8e31c3d", + Array(57, 54, 98, 102, 56, 99, 101, 56, 45, 102, 48, 54, 49, 45, 52, 55, 99, 48, 45, 97, + 98, 48, 
101, 45, 97, 99, 50, 52, 100, 98, 50, 97, 102, 50, 50, 54), + Date.valueOf("2023-09-19"), + Timestamp.valueOf("2023-09-19 11:24:15.716")) + } + + it("should read GeoPackage - polygon1") { + val df = readFeatureData("polygon1") + df.count shouldEqual 3 + df.schema shouldEqual expectedFeatureSchema + + df.select("geometry").collectAsList().get(0).toSeq should contain theSameElementsAs Seq( + wktReader.read(POLYGON_1)) + } + + it("should read GeoPackage - geometry1") { + val df = readFeatureData("geometry1") + df.count shouldEqual 10 + df.schema shouldEqual expectedFeatureSchema + + df.selectExpr("ST_ASTEXT(geometry)") + .as[String] + .collect() should contain theSameElementsAs Seq( + POINT_1, + POINT_2, + POINT_3, + POINT_4, + LINESTRING_1, + LINESTRING_2, + LINESTRING_3, + POLYGON_1, + POLYGON_2, + POLYGON_3) + } + + it("should read polygon with envelope data") { + val tables = Table( + ("tableName", "expectedCount"), + ("GB_Hex_5km_GS_CompressibleGround_v8", 4233), + ("GB_Hex_5km_GS_Landslides_v8", 4228), + ("GB_Hex_5km_GS_RunningSand_v8", 4233), + ("GB_Hex_5km_GS_ShrinkSwell_v8", 4233), + ("GB_Hex_5km_GS_SolubleRocks_v8", 4295)) + + forAll(tables) { (tableName: String, expectedCount: Int) => + val df = sparkSession.read + .format("geopackage") + .option("tableName", tableName) + .load(polygonsPath) + + df.count() shouldEqual expectedCount + } + } + + it("should handle datetime fields without timezone information") { + // This test verifies the fix for DateTimeParseException when reading + // GeoPackage files with datetime fields that don't include timezone info + val testFilePath = resourceFolder + "geopackage/test_datetime_issue.gpkg" + + // Test reading the test_features table with problematic datetime formats + val df = sparkSession.read + .format("geopackage") + .option("tableName", "test_features") + .load(testFilePath) + + // The test should not throw DateTimeParseException when reading datetime fields + noException should be thrownBy { + df.select("created_at", "updated_at").collect() + } + + // Verify that datetime fields are properly parsed as TimestampType + df.schema.fields.find(_.name == "created_at").get.dataType shouldEqual TimestampType + df.schema.fields.find(_.name == "updated_at").get.dataType shouldEqual TimestampType + + // Verify that we can read the datetime values + val datetimeValues = df.select("created_at", "updated_at").collect() + datetimeValues should not be empty + + // Verify that datetime values are valid timestamps + datetimeValues.foreach { row => + val createdTimestamp = row.getAs[Timestamp]("created_at") + val updatedTimestamp = row.getAs[Timestamp]("updated_at") + createdTimestamp should not be null + updatedTimestamp should not be null + createdTimestamp.getTime should be > 0L + updatedTimestamp.getTime should be > 0L + } + + // Test showMetadata option with the same file + noException should be thrownBy { + val metadataDf = sparkSession.read + .format("geopackage") + .option("showMetadata", "true") + .load(testFilePath) + metadataDf.select("last_change").collect() + } + } + } + + describe("GeoPackage Raster Data Test") { + it("should read") { + val fractions = + Table( + ("tableName", "channelNumber", "expectedSum"), + ("point1_tiles", 4, 466591.0), + ("line1_tiles", 4, 5775976.0), + ("polygon1_tiles", 4, 1.1269871e7), + ("geometry1_tiles", 4, 2.6328442e7), + ("point2_tiles", 4, 137456.0), + ("line2_tiles", 4, 6701101.0), + ("polygon2_tiles", 4, 5.1170714e7), + ("geometry2_tiles", 4, 1.6699823e7), + ("bit_systems", 1, 6.5561879e7), + ("nga", 1, 
6.8078856e7), + ("bit_systems_wgs84", 1, 7.7276934e7), + ("nga_pc", 1, 2.90590616e8), + ("bit_systems_world", 1, 7.7276934e7), + ("nga_pc_world", 1, 2.90590616e8)) + + forAll(fractions) { (tableName: String, channelNumber: Int, expectedSum: Double) => + { + val df = readFeatureData(tableName) + val calculatedSum = df + .selectExpr(s"RS_SummaryStats(tile_data, 'sum', ${channelNumber}) as stats") + .selectExpr("sum(stats)") + .as[Double] + + calculatedSum.collect().head shouldEqual expectedSum + } + } + } + + it("should be able to read complex raster data") { + val df = sparkSession.read + .format("geopackage") + .option("tableName", "AuroraAirportNoise") + .load(rasterPath) + + df.show(5) + + val calculatedSum = df + .selectExpr(s"RS_SummaryStats(tile_data, 'sum', ${1}) as stats") + .selectExpr("sum(stats)") + .as[Double] + + calculatedSum.first() shouldEqual 2.027126e7 + + val df2 = sparkSession.read + .format("geopackage") + .option("tableName", "LiquorLicenseDensity") + .load(rasterPath) + + val calculatedSum2 = df2 + .selectExpr(s"RS_SummaryStats(tile_data, 'sum', ${1}) as stats") + .selectExpr("sum(stats)") + .as[Double] + + calculatedSum2.first() shouldEqual 2.882028e7 + } + + } + + describe("Reading from S3") { + it("should be able to read files from S3") { + val container = new MinIOContainer("minio/minio:latest") + + container.start() + + val minioClient = createMinioClient(container) + val makeBucketRequest = MakeBucketArgs + .builder() + .bucket("sedona") + .build() + + minioClient.makeBucket(makeBucketRequest) + + adjustSparkSession(sparkSessionMinio, container) + + val inputPath: String = prepareFile("example.geopackage", path, minioClient) + + sparkSessionMinio.read + .format("geopackage") + .option("showMetadata", "true") + .load(inputPath) + .count shouldEqual 34 + + val df = sparkSession.read + .format("geopackage") + .option("tableName", "point1") + .load(inputPath) + + df.count shouldEqual 4 + + val inputPathLarger: String = prepareFiles((1 to 300).map(_ => path).toArray, minioClient) + + val dfLarger = sparkSessionMinio.read + .format("geopackage") + .option("tableName", "point1") + .load(inputPathLarger) + + dfLarger.count shouldEqual 300 * 4 + + container.stop() + } + } + + private def readFeatureData(tableName: String): DataFrame = { + sparkSession.read + .format("geopackage") + .option("tableName", tableName) + .load(path) + } + + private def prepareFiles(paths: Array[String], minioClient: MinioClient): String = { + val key = "geopackage" + + paths.foreach(path => { + val fis = new FileInputStream(path); + putFileIntoBucket( + "sedona", + s"${key}/${scala.util.Random.nextInt(1000000000)}.geopackage", + fis, + minioClient) + }) + + s"s3a://sedona/$key" + } + + private def prepareFile(name: String, path: String, minioClient: MinioClient): String = { + val fis = new FileInputStream(path); + putFileIntoBucket("sedona", name, fis, minioClient) + + s"s3a://sedona/$name" + } + + private val POINT_1 = "POINT (-104.801918 39.720014)" + private val POINT_2 = "POINT (-104.802987 39.717703)" + private val POINT_3 = "POINT (-104.807496 39.714085)" + private val POINT_4 = "POINT (-104.79948 39.714729)" + private val LINESTRING_1 = + "LINESTRING (-104.800614 39.720721, -104.802174 39.720726, -104.802584 39.72066, -104.803088 39.720477, -104.803474 39.720209)" + private val LINESTRING_2 = + "LINESTRING (-104.809612 39.718379, -104.806638 39.718372, -104.806236 39.718439, -104.805939 39.718536, -104.805654 39.718677, -104.803652 39.720095)" + private val LINESTRING_3 = + "LINESTRING 
(-104.806344 39.722425, -104.805854 39.722634, -104.805656 39.722647, -104.803749 39.722641, -104.803769 39.721849, -104.803806 39.721725, -104.804382 39.720865)" + private val POLYGON_1 = + "POLYGON ((-104.802246 39.720343, -104.802246 39.719753, -104.802183 39.719754, -104.802184 39.719719, -104.802138 39.719694, -104.802097 39.719691, -104.802096 39.719648, -104.801646 39.719648, -104.801644 39.719722, -104.80155 39.719723, -104.801549 39.720207, -104.801648 39.720207, -104.801648 39.720341, -104.802246 39.720343))" + private val POLYGON_2 = + "POLYGON ((-104.802259 39.719604, -104.80226 39.71955, -104.802281 39.719416, -104.802332 39.719372, -104.802081 39.71924, -104.802044 39.71929, -104.802027 39.719278, -104.802044 39.719229, -104.801785 39.719129, -104.801639 39.719413, -104.801649 39.719472, -104.801694 39.719524, -104.801753 39.71955, -104.80175 39.719606, -104.80194 39.719606, -104.801939 39.719555, -104.801977 39.719556, -104.801979 39.719606, -104.802259 39.719604), (-104.80213 39.71944, -104.802133 39.71949, -104.802148 39.71949, -104.80218 39.719473, -104.802187 39.719456, -104.802182 39.719439, -104.802088 39.719387, -104.802047 39.719427, -104.801858 39.719342, -104.801883 39.719294, -104.801832 39.719284, -104.801787 39.719298, -104.801763 39.719331, -104.801823 39.719352, -104.80179 39.71942, -104.801722 39.719404, -104.801715 39.719445, -104.801748 39.719484, -104.801809 39.719494, -104.801816 39.719439, -104.80213 39.71944))" + private val POLYGON_3 = + "POLYGON ((-104.802867 39.718122, -104.802369 39.717845, -104.802571 39.71763, -104.803066 39.717909, -104.802867 39.718122))" +} diff --git a/spark/spark-4.1/src/test/scala/org/apache/sedona/sql/GeoParquetMetadataTests.scala b/spark/spark-4.1/src/test/scala/org/apache/sedona/sql/GeoParquetMetadataTests.scala new file mode 100644 index 00000000000..421890c7001 --- /dev/null +++ b/spark/spark-4.1/src/test/scala/org/apache/sedona/sql/GeoParquetMetadataTests.scala @@ -0,0 +1,152 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.sedona.sql + +import org.apache.spark.sql.Row +import org.apache.spark.sql.sedona_sql.UDT.GeometryUDT +import org.apache.spark.sql.types.{IntegerType, StructField, StructType} +import org.scalatest.BeforeAndAfterAll + +import java.util.Collections +import scala.collection.JavaConverters._ + +class GeoParquetMetadataTests extends TestBaseScala with BeforeAndAfterAll { + val geoparquetdatalocation: String = resourceFolder + "geoparquet/" + val geoparquetoutputlocation: String = resourceFolder + "geoparquet/geoparquet_output/" + + describe("GeoParquet Metadata tests") { + it("Reading GeoParquet Metadata") { + val df = sparkSession.read.format("geoparquet.metadata").load(geoparquetdatalocation) + val metadataArray = df.collect() + assert(metadataArray.length > 1) + assert(metadataArray.exists(_.getAs[String]("path").endsWith(".parquet"))) + assert(metadataArray.exists(_.getAs[String]("version") == "1.0.0-dev")) + assert(metadataArray.exists(_.getAs[String]("primary_column") == "geometry")) + assert(metadataArray.exists { row => + val columnsMap = row.getJavaMap(row.fieldIndex("columns")) + columnsMap != null && columnsMap + .containsKey("geometry") && columnsMap.get("geometry").isInstanceOf[Row] + }) + assert(metadataArray.forall { row => + val columnsMap = row.getJavaMap(row.fieldIndex("columns")) + if (columnsMap == null || !columnsMap.containsKey("geometry")) true + else { + val columnMetadata = columnsMap.get("geometry").asInstanceOf[Row] + columnMetadata.getAs[String]("encoding") == "WKB" && + columnMetadata + .getList[Any](columnMetadata.fieldIndex("bbox")) + .asScala + .forall(_.isInstanceOf[Double]) && + columnMetadata + .getList[Any](columnMetadata.fieldIndex("geometry_types")) + .asScala + .forall(_.isInstanceOf[String]) && + columnMetadata.getAs[String]("crs").nonEmpty && + columnMetadata.getAs[String]("crs") != "null" + } + }) + } + + it("Reading GeoParquet Metadata with column pruning") { + val df = sparkSession.read.format("geoparquet.metadata").load(geoparquetdatalocation) + val metadataArray = df + .selectExpr("path", "substring(primary_column, 1, 2) AS partial_primary_column") + .collect() + assert(metadataArray.length > 1) + assert(metadataArray.forall(_.length == 2)) + assert(metadataArray.exists(_.getAs[String]("path").endsWith(".parquet"))) + assert(metadataArray.exists(_.getAs[String]("partial_primary_column") == "ge")) + } + + it("Reading GeoParquet Metadata of plain parquet files") { + val df = sparkSession.read.format("geoparquet.metadata").load(geoparquetdatalocation) + val metadataArray = df.where("path LIKE '%plain.parquet'").collect() + assert(metadataArray.nonEmpty) + assert(metadataArray.forall(_.getAs[String]("path").endsWith("plain.parquet"))) + assert(metadataArray.forall(_.getAs[String]("version") == null)) + assert(metadataArray.forall(_.getAs[String]("primary_column") == null)) + assert(metadataArray.forall(_.getAs[String]("columns") == null)) + } + + it("Read GeoParquet without CRS") { + val df = sparkSession.read + .format("geoparquet") + .load(geoparquetdatalocation + "/example-1.0.0-beta.1.parquet") + val geoParquetSavePath = geoparquetoutputlocation + "/gp_crs_omit.parquet" + df.write + .format("geoparquet") + .option("geoparquet.crs", "") + .mode("overwrite") + .save(geoParquetSavePath) + val dfMeta = sparkSession.read.format("geoparquet.metadata").load(geoParquetSavePath) + val row = dfMeta.collect()(0) + val metadata = row.getJavaMap(row.fieldIndex("columns")).get("geometry").asInstanceOf[Row] + 
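// Writing with an empty "geoparquet.crs" option is expected to round-trip as an empty CRS string rather than a missing field. +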
assert(metadata.getAs[String]("crs") == "") + } + + it("Read GeoParquet with null CRS") { + val df = sparkSession.read + .format("geoparquet") + .load(geoparquetdatalocation + "/example-1.0.0-beta.1.parquet") + val geoParquetSavePath = geoparquetoutputlocation + "/gp_crs_null.parquet" + df.write + .format("geoparquet") + .option("geoparquet.crs", "null") + .mode("overwrite") + .save(geoParquetSavePath) + val dfMeta = sparkSession.read.format("geoparquet.metadata").load(geoParquetSavePath) + val row = dfMeta.collect()(0) + val metadata = row.getJavaMap(row.fieldIndex("columns")).get("geometry").asInstanceOf[Row] + assert(metadata.getAs[String]("crs") == "null") + } + + it("Read GeoParquet with snake_case geometry column name and camelCase column name") { + val schema = StructType( + Seq( + StructField("id", IntegerType, nullable = false), + StructField("geom_column_1", GeometryUDT, nullable = false), + StructField("geomColumn2", GeometryUDT, nullable = false))) + val df = sparkSession.createDataFrame(Collections.emptyList[Row](), schema) + val geoParquetSavePath = geoparquetoutputlocation + "/gp_column_name_styles.parquet" + df.write.format("geoparquet").mode("overwrite").save(geoParquetSavePath) + + val dfMeta = sparkSession.read.format("geoparquet.metadata").load(geoParquetSavePath) + val row = dfMeta.collect()(0) + val metadata = row.getJavaMap(row.fieldIndex("columns")) + assert(metadata.containsKey("geom_column_1")) + assert(!metadata.containsKey("geoColumn1")) + assert(metadata.containsKey("geomColumn2")) + assert(!metadata.containsKey("geom_column2")) + assert(!metadata.containsKey("geom_column_2")) + } + + it("Read GeoParquet with covering metadata") { + val dfMeta = sparkSession.read + .format("geoparquet.metadata") + .load(geoparquetdatalocation + "/example-1.1.0.parquet") + val row = dfMeta.collect()(0) + val metadata = row.getJavaMap(row.fieldIndex("columns")).get("geometry").asInstanceOf[Row] + val covering = metadata.getAs[String]("covering") + assert(covering.nonEmpty) + Seq("bbox", "xmin", "ymin", "xmax", "ymax").foreach { key => + assert(covering contains key) + } + } + } +} diff --git a/spark/spark-4.1/src/test/scala/org/apache/sedona/sql/SQLSyntaxTestScala.scala b/spark/spark-4.1/src/test/scala/org/apache/sedona/sql/SQLSyntaxTestScala.scala new file mode 100644 index 00000000000..6f873d0a087 --- /dev/null +++ b/spark/spark-4.1/src/test/scala/org/apache/sedona/sql/SQLSyntaxTestScala.scala @@ -0,0 +1,72 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.sedona.sql + +import org.scalatest.matchers.must.Matchers.be +import org.scalatest.matchers.should.Matchers.convertToAnyShouldWrapper +import org.scalatest.prop.TableDrivenPropertyChecks + +/** + * Test suite for testing Sedona SQL support. 
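+ * The parser extension flag (spark.sedona.enableParserExtension) is randomized in TestBaseScala, so each DDL test accepts either a successful CREATE TABLE (flag enabled) or a ParseException (flag disabled).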
+ */ +class SQLSyntaxTestScala extends TestBaseScala with TableDrivenPropertyChecks { + + override def beforeAll(): Unit = { + super.beforeAll() + sparkSession.conf.set("spark.sql.legacy.createHiveTableByDefault", "false") + } + + describe("Table creation DDL tests") { + + it("should be able to create a regular table without geometry column should work") { + sparkSession.sql("DROP TABLE IF EXISTS T_TEST_REGULAR") + sparkSession.sql("CREATE TABLE IF NOT EXISTS T_TEST_REGULAR (INT_COL INT)") + sparkSession.catalog.tableExists("T_TEST_REGULAR") should be(true) + sparkSession.sql("DROP TABLE IF EXISTS T_TEST_REGULAR") + sparkSession.catalog.tableExists("T_TEST_REGULAR") should be(false) + } + + it( + "should be able to create a regular table with geometry column should work without a workaround") { + try { + sparkSession.sql("CREATE TABLE T_TEST_EXPLICIT_GEOMETRY (GEO_COL GEOMETRY)") + sparkSession.catalog.tableExists("T_TEST_EXPLICIT_GEOMETRY") should be(true) + sparkSession.sparkContext.getConf.get(keyParserExtension) should be("true") + } catch { + case ex: Exception => + ex.getClass.getName.endsWith("ParseException") should be(true) + sparkSession.sparkContext.getConf.get(keyParserExtension) should be("false") + } + } + + it( + "should be able to create a regular table with regular and geometry column should work without a workaround") { + try { + sparkSession.sql( + "CREATE TABLE T_TEST_EXPLICIT_GEOMETRY_2 (INT_COL INT, GEO_COL GEOMETRY)") + sparkSession.catalog.tableExists("T_TEST_EXPLICIT_GEOMETRY_2") should be(true) + sparkSession.sparkContext.getConf.get(keyParserExtension) should be("true") + } catch { + case ex: Exception => + ex.getClass.getName.endsWith("ParseException") should be(true) + sparkSession.sparkContext.getConf.get(keyParserExtension) should be("false") + } + } + } +} diff --git a/spark/spark-4.1/src/test/scala/org/apache/sedona/sql/ShapefileTests.scala b/spark/spark-4.1/src/test/scala/org/apache/sedona/sql/ShapefileTests.scala new file mode 100644 index 00000000000..a0cadc89787 --- /dev/null +++ b/spark/spark-4.1/src/test/scala/org/apache/sedona/sql/ShapefileTests.scala @@ -0,0 +1,784 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.sedona.sql + +import org.apache.commons.io.FileUtils +import org.apache.spark.sql.sedona_sql.UDT.GeometryUDT +import org.apache.spark.sql.types.{DateType, DecimalType, LongType, StringType, StructField, StructType} +import org.locationtech.jts.geom.{Geometry, MultiPolygon, Point, Polygon} +import org.locationtech.jts.io.{WKTReader, WKTWriter} +import org.scalatest.BeforeAndAfterAll + +import java.io.File +import java.nio.file.Files +import scala.collection.mutable + +class ShapefileTests extends TestBaseScala with BeforeAndAfterAll { + val temporaryLocation: String = resourceFolder + "shapefiles/tmp" + + override def beforeAll(): Unit = { + super.beforeAll() + FileUtils.deleteDirectory(new File(temporaryLocation)) + Files.createDirectory(new File(temporaryLocation).toPath) + } + + override def afterAll(): Unit = FileUtils.deleteDirectory(new File(temporaryLocation)) + + describe("Shapefile read tests") { + it("read gis_osm_pois_free_1") { + val shapefileDf = sparkSession.read + .format("shapefile") + .load(resourceFolder + "shapefiles/gis_osm_pois_free_1") + val schema = shapefileDf.schema + assert(schema.find(_.name == "geometry").get.dataType == GeometryUDT) + assert(schema.find(_.name == "osm_id").get.dataType == StringType) + assert(schema.find(_.name == "code").get.dataType == LongType) + assert(schema.find(_.name == "fclass").get.dataType == StringType) + assert(schema.find(_.name == "name").get.dataType == StringType) + assert(schema.length == 5) + assert(shapefileDf.count == 12873) + + shapefileDf.collect().foreach { row => + val geom = row.getAs[Geometry]("geometry") + assert(geom.isInstanceOf[Point]) + assert(geom.getSRID == 4326) + assert(row.getAs[String]("osm_id").nonEmpty) + assert(row.getAs[Long]("code") > 0) + assert(row.getAs[String]("fclass").nonEmpty) + assert(row.getAs[String]("name") != null) + } + + // with projection, selecting geometry and attribute fields + shapefileDf.select("geometry", "code").take(10).foreach { row => + assert(row.getAs[Geometry]("geometry").isInstanceOf[Point]) + assert(row.getAs[Long]("code") > 0) + } + + // with projection, selecting geometry fields + shapefileDf.select("geometry").take(10).foreach { row => + assert(row.getAs[Geometry]("geometry").isInstanceOf[Point]) + } + + // with projection, selecting attribute fields + shapefileDf.select("code", "osm_id").take(10).foreach { row => + assert(row.getAs[Long]("code") > 0) + assert(row.getAs[String]("osm_id").nonEmpty) + } + + // with transformation + shapefileDf + .selectExpr("ST_Buffer(geometry, 0.001) AS geom", "code", "osm_id as id") + .take(10) + .foreach { row => + assert(row.getAs[Geometry]("geom").isInstanceOf[Polygon]) + assert(row.getAs[Long]("code") > 0) + assert(row.getAs[String]("id").nonEmpty) + } + } + + it("read dbf") { + val shapefileDf = sparkSession.read + .format("shapefile") + .load(resourceFolder + "shapefiles/dbf") + val schema = shapefileDf.schema + assert(schema.find(_.name == "geometry").get.dataType == GeometryUDT) + assert(schema.find(_.name == "STATEFP").get.dataType == StringType) + assert(schema.find(_.name == "COUNTYFP").get.dataType == StringType) + assert(schema.find(_.name == "COUNTYNS").get.dataType == StringType) + assert(schema.find(_.name == "AFFGEOID").get.dataType == StringType) + assert(schema.find(_.name == "GEOID").get.dataType == StringType) + assert(schema.find(_.name == "NAME").get.dataType == StringType) + assert(schema.find(_.name == "LSAD").get.dataType == StringType) + assert(schema.find(_.name == 
"ALAND").get.dataType == LongType) + assert(schema.find(_.name == "AWATER").get.dataType == LongType) + assert(schema.length == 10) + assert(shapefileDf.count() == 3220) + + shapefileDf.collect().foreach { row => + val geom = row.getAs[Geometry]("geometry") + assert(geom.getSRID == 0) + assert(geom.isInstanceOf[Polygon] || geom.isInstanceOf[MultiPolygon]) + assert(row.getAs[String]("STATEFP").nonEmpty) + assert(row.getAs[String]("COUNTYFP").nonEmpty) + assert(row.getAs[String]("COUNTYNS").nonEmpty) + assert(row.getAs[String]("AFFGEOID").nonEmpty) + assert(row.getAs[String]("GEOID").nonEmpty) + assert(row.getAs[String]("NAME").nonEmpty) + assert(row.getAs[String]("LSAD").nonEmpty) + assert(row.getAs[Long]("ALAND") > 0) + assert(row.getAs[Long]("AWATER") >= 0) + } + } + + it("read multipleshapefiles") { + val shapefileDf = sparkSession.read + .format("shapefile") + .load(resourceFolder + "shapefiles/multipleshapefiles") + val schema = shapefileDf.schema + assert(schema.find(_.name == "geometry").get.dataType == GeometryUDT) + assert(schema.find(_.name == "STATEFP").get.dataType == StringType) + assert(schema.find(_.name == "COUNTYFP").get.dataType == StringType) + assert(schema.find(_.name == "COUNTYNS").get.dataType == StringType) + assert(schema.find(_.name == "AFFGEOID").get.dataType == StringType) + assert(schema.find(_.name == "GEOID").get.dataType == StringType) + assert(schema.find(_.name == "NAME").get.dataType == StringType) + assert(schema.find(_.name == "LSAD").get.dataType == StringType) + assert(schema.find(_.name == "ALAND").get.dataType == LongType) + assert(schema.find(_.name == "AWATER").get.dataType == LongType) + assert(schema.length == 10) + assert(shapefileDf.count() == 3220) + } + + it("read missing") { + val shapefileDf = sparkSession.read + .format("shapefile") + .load(resourceFolder + "shapefiles/missing") + val schema = shapefileDf.schema + assert(schema.find(_.name == "geometry").get.dataType == GeometryUDT) + assert(schema.find(_.name == "id").get.dataType == LongType) + assert(schema.find(_.name == "a").get.dataType == StringType) + assert(schema.find(_.name == "b").get.dataType == StringType) + assert(schema.find(_.name == "c").get.dataType == StringType) + assert(schema.find(_.name == "d").get.dataType == StringType) + assert(schema.find(_.name == "e").get.dataType == StringType) + assert(schema.length == 7) + val rows = shapefileDf.collect() + assert(rows.length == 3) + rows.foreach { row => + val a = row.getAs[String]("a") + val b = row.getAs[String]("b") + val c = row.getAs[String]("c") + val d = row.getAs[String]("d") + val e = row.getAs[String]("e") + if (a.isEmpty) { + assert(b == "First") + assert(c == "field") + assert(d == "is") + assert(e == "empty") + } else if (e.isEmpty) { + assert(a == "Last") + assert(b == "field") + assert(c == "is") + assert(d == "empty") + } else { + assert(a == "Are") + assert(b == "fields") + assert(c == "are") + assert(d == "not") + assert(e == "empty") + } + } + } + + it("read unsupported") { + val shapefileDf = sparkSession.read + .format("shapefile") + .load(resourceFolder + "shapefiles/unsupported") + val schema = shapefileDf.schema + assert(schema.find(_.name == "geometry").get.dataType == GeometryUDT) + val rows = shapefileDf.collect() + assert(rows.length == 10) + rows.foreach { row => + assert(row.getAs[Geometry]("geometry") == null) + assert(!row.isNullAt(row.fieldIndex("id"))) + } + } + + it("read bad_shx") { + var shapefileDf = sparkSession.read + .format("shapefile") + .load(resourceFolder + 
"shapefiles/bad_shx") + val schema = shapefileDf.schema + assert(schema.find(_.name == "geometry").get.dataType == GeometryUDT) + assert(schema.find(_.name == "field_1").get.dataType == LongType) + var rows = shapefileDf.collect() + assert(rows.length == 2) + rows.foreach { row => + val geom = row.getAs[Geometry]("geometry") + if (geom == null) { + assert(row.getAs[Long]("field_1") == 3) + } else { + assert(geom.isInstanceOf[Point]) + assert(row.getAs[Long]("field_1") == 2) + } + } + + // Copy the .shp and .dbf files to temporary location, and read the same shapefiles without .shx + FileUtils.cleanDirectory(new File(temporaryLocation)) + FileUtils.copyFile( + new File(resourceFolder + "shapefiles/bad_shx/bad_shx.shp"), + new File(temporaryLocation + "/bad_shx.shp")) + FileUtils.copyFile( + new File(resourceFolder + "shapefiles/bad_shx/bad_shx.dbf"), + new File(temporaryLocation + "/bad_shx.dbf")) + shapefileDf = sparkSession.read + .format("shapefile") + .load(temporaryLocation) + rows = shapefileDf.collect() + assert(rows.length == 2) + rows.foreach { row => + val geom = row.getAs[Geometry]("geometry") + if (geom == null) { + assert(row.getAs[Long]("field_1") == 3) + } else { + assert(geom.isInstanceOf[Point]) + assert(row.getAs[Long]("field_1") == 2) + } + } + } + + it("read contains_null_geom") { + val shapefileDf = sparkSession.read + .format("shapefile") + .load(resourceFolder + "shapefiles/contains_null_geom") + val schema = shapefileDf.schema + assert(schema.find(_.name == "geometry").get.dataType == GeometryUDT) + assert(schema.find(_.name == "fInt").get.dataType == LongType) + assert(schema.find(_.name == "fFloat").get.dataType.isInstanceOf[DecimalType]) + assert(schema.find(_.name == "fString").get.dataType == StringType) + assert(schema.length == 4) + val rows = shapefileDf.collect() + assert(rows.length == 10) + rows.foreach { row => + val fInt = row.getAs[Long]("fInt") + val fFloat = row.getAs[java.math.BigDecimal]("fFloat").doubleValue() + val fString = row.getAs[String]("fString") + val geom = row.getAs[Geometry]("geometry") + if (fInt == 2 || fInt == 5) { + assert(geom == null) + } else { + assert(geom.isInstanceOf[Point]) + assert(geom.getCoordinate.x == fInt) + assert(geom.getCoordinate.y == fInt) + } + assert(Math.abs(fFloat - 3.14159 * fInt) < 1e-4) + assert(fString == s"str_$fInt") + } + } + + it("read test_datatypes") { + val shapefileDf = sparkSession.read + .format("shapefile") + .load(resourceFolder + "shapefiles/datatypes") + val schema = shapefileDf.schema + assert(schema.find(_.name == "geometry").get.dataType == GeometryUDT) + assert(schema.find(_.name == "id").get.dataType == LongType) + assert(schema.find(_.name == "aInt").get.dataType == LongType) + assert(schema.find(_.name == "aUnicode").get.dataType == StringType) + assert(schema.find(_.name == "aDecimal").get.dataType.isInstanceOf[DecimalType]) + assert(schema.find(_.name == "aDecimal2").get.dataType.isInstanceOf[DecimalType]) + assert(schema.find(_.name == "aDate").get.dataType == DateType) + assert(schema.length == 7) + + val rows = shapefileDf.collect() + assert(rows.length == 9) + rows.foreach { row => + val geom = row.getAs[Geometry]("geometry") + assert(geom.isInstanceOf[Point]) + assert(geom.getSRID == 4269) + val idIndex = row.fieldIndex("id") + if (row.isNullAt(idIndex)) { + assert(row.isNullAt(row.fieldIndex("aInt"))) + assert(row.getAs[String]("aUnicode").isEmpty) + assert(row.isNullAt(row.fieldIndex("aDecimal"))) + assert(row.isNullAt(row.fieldIndex("aDecimal2"))) + 
assert(row.isNullAt(row.fieldIndex("aDate"))) + } else { + val id = row.getLong(idIndex) + assert(row.getAs[Long]("aInt") == id) + assert(row.getAs[String]("aUnicode") == s"测试$id") + if (id < 10) { + val decimal = row.getDecimal(row.fieldIndex("aDecimal")).doubleValue() + assert((decimal * 10).toInt == id * 10 + id) + assert(row.isNullAt(row.fieldIndex("aDecimal2"))) + assert(row.getAs[java.sql.Date]("aDate").toString == s"202$id-0$id-0$id") + } else { + assert(row.isNullAt(row.fieldIndex("aDecimal"))) + val decimal = row.getDecimal(row.fieldIndex("aDecimal2")).doubleValue() + assert((decimal * 100).toInt == id * 100 + id) + assert(row.isNullAt(row.fieldIndex("aDate"))) + } + } + } + } + + it("read with .shp path specified") { + val shapefileDf = sparkSession.read + .format("shapefile") + .load(resourceFolder + "shapefiles/datatypes/datatypes1.shp") + val schema = shapefileDf.schema + assert(schema.find(_.name == "geometry").get.dataType == GeometryUDT) + assert(schema.find(_.name == "id").get.dataType == LongType) + assert(schema.find(_.name == "aInt").get.dataType == LongType) + assert(schema.find(_.name == "aUnicode").get.dataType == StringType) + assert(schema.find(_.name == "aDecimal").get.dataType.isInstanceOf[DecimalType]) + assert(schema.find(_.name == "aDate").get.dataType == DateType) + assert(schema.length == 6) + + val rows = shapefileDf.collect() + assert(rows.length == 5) + rows.foreach { row => + assert(row.getAs[Geometry]("geometry").isInstanceOf[Point]) + val idIndex = row.fieldIndex("id") + if (row.isNullAt(idIndex)) { + assert(row.isNullAt(row.fieldIndex("aInt"))) + assert(row.getAs[String]("aUnicode").isEmpty) + assert(row.isNullAt(row.fieldIndex("aDecimal"))) + assert(row.isNullAt(row.fieldIndex("aDate"))) + } else { + val id = row.getLong(idIndex) + assert(row.getAs[Long]("aInt") == id) + assert(row.getAs[String]("aUnicode") == s"测试$id") + val decimal = row.getDecimal(row.fieldIndex("aDecimal")).doubleValue() + assert((decimal * 10).toInt == id * 10 + id) + assert(row.getAs[java.sql.Date]("aDate").toString == s"202$id-0$id-0$id") + } + } + } + + it("read with glob path specified") { + val shapefileDf = sparkSession.read + .format("shapefile") + .load(resourceFolder + "shapefiles/datatypes/datatypes2.*") + val schema = shapefileDf.schema + assert(schema.find(_.name == "geometry").get.dataType == GeometryUDT) + assert(schema.find(_.name == "id").get.dataType == LongType) + assert(schema.find(_.name == "aInt").get.dataType == LongType) + assert(schema.find(_.name == "aUnicode").get.dataType == StringType) + assert(schema.find(_.name == "aDecimal2").get.dataType.isInstanceOf[DecimalType]) + assert(schema.length == 5) + + val rows = shapefileDf.collect() + assert(rows.length == 4) + rows.foreach { row => + assert(row.getAs[Geometry]("geometry").isInstanceOf[Point]) + val id = row.getAs[Long]("id") + assert(row.getAs[Long]("aInt") == id) + assert(row.getAs[String]("aUnicode") == s"测试$id") + val decimal = row.getDecimal(row.fieldIndex("aDecimal2")).doubleValue() + assert((decimal * 100).toInt == id * 100 + id) + } + } + + it("read without shx") { + FileUtils.cleanDirectory(new File(temporaryLocation)) + FileUtils.copyFile( + new File(resourceFolder + "shapefiles/gis_osm_pois_free_1/gis_osm_pois_free_1.shp"), + new File(temporaryLocation + "/gis_osm_pois_free_1.shp")) + FileUtils.copyFile( + new File(resourceFolder + "shapefiles/gis_osm_pois_free_1/gis_osm_pois_free_1.dbf"), + new File(temporaryLocation + "/gis_osm_pois_free_1.dbf")) + + val shapefileDf = sparkSession.read + 
.format("shapefile") + .load(temporaryLocation) + val rows = shapefileDf.collect() + assert(rows.length == 12873) + rows.foreach { row => + val geom = row.getAs[Geometry]("geometry") + assert(geom.isInstanceOf[Point]) + assert(geom.getSRID == 0) + assert(row.getAs[String]("osm_id").nonEmpty) + assert(row.getAs[Long]("code") > 0) + assert(row.getAs[String]("fclass").nonEmpty) + assert(row.getAs[String]("name") != null) + } + } + + it("read without dbf") { + FileUtils.cleanDirectory(new File(temporaryLocation)) + FileUtils.copyFile( + new File(resourceFolder + "shapefiles/gis_osm_pois_free_1/gis_osm_pois_free_1.shp"), + new File(temporaryLocation + "/gis_osm_pois_free_1.shp")) + val shapefileDf = sparkSession.read + .format("shapefile") + .load(temporaryLocation) + val schema = shapefileDf.schema + assert(schema.find(_.name == "geometry").get.dataType == GeometryUDT) + assert(schema.length == 1) + + val rows = shapefileDf.collect() + assert(rows.length == 12873) + rows.foreach { row => + val geom = row.getAs[Geometry]("geometry") + assert(geom.isInstanceOf[Point]) + } + } + + it("read without shp") { + FileUtils.cleanDirectory(new File(temporaryLocation)) + FileUtils.copyFile( + new File(resourceFolder + "shapefiles/gis_osm_pois_free_1/gis_osm_pois_free_1.dbf"), + new File(temporaryLocation + "/gis_osm_pois_free_1.dbf")) + FileUtils.copyFile( + new File(resourceFolder + "shapefiles/gis_osm_pois_free_1/gis_osm_pois_free_1.shx"), + new File(temporaryLocation + "/gis_osm_pois_free_1.shx")) + intercept[Exception] { + sparkSession.read + .format("shapefile") + .load(temporaryLocation) + .count() + } + + intercept[Exception] { + sparkSession.read + .format("shapefile") + .load(resourceFolder + "shapefiles/gis_osm_pois_free_1/gis_osm_pois_free_1.shx") + .count() + } + } + + it("read directory containing missing .shp files") { + FileUtils.cleanDirectory(new File(temporaryLocation)) + // Missing .shp file for datatypes1 + FileUtils.copyFile( + new File(resourceFolder + "shapefiles/datatypes/datatypes1.dbf"), + new File(temporaryLocation + "/datatypes1.dbf")) + FileUtils.copyFile( + new File(resourceFolder + "shapefiles/datatypes/datatypes2.shp"), + new File(temporaryLocation + "/datatypes2.shp")) + FileUtils.copyFile( + new File(resourceFolder + "shapefiles/datatypes/datatypes2.dbf"), + new File(temporaryLocation + "/datatypes2.dbf")) + FileUtils.copyFile( + new File(resourceFolder + "shapefiles/datatypes/datatypes2.cpg"), + new File(temporaryLocation + "/datatypes2.cpg")) + + val shapefileDf = sparkSession.read + .format("shapefile") + .load(temporaryLocation) + val rows = shapefileDf.collect() + assert(rows.length == 4) + rows.foreach { row => + assert(row.getAs[Geometry]("geometry").isInstanceOf[Point]) + val id = row.getAs[Long]("id") + assert(row.getAs[Long]("aInt") == id) + assert(row.getAs[String]("aUnicode") == s"测试$id") + val decimal = row.getDecimal(row.fieldIndex("aDecimal2")).doubleValue() + assert((decimal * 100).toInt == id * 100 + id) + } + } + + it("read partitioned directory") { + FileUtils.cleanDirectory(new File(temporaryLocation)) + Files.createDirectory(new File(temporaryLocation + "/part=1").toPath) + Files.createDirectory(new File(temporaryLocation + "/part=2").toPath) + FileUtils.copyFile( + new File(resourceFolder + "shapefiles/datatypes/datatypes1.shp"), + new File(temporaryLocation + "/part=1/datatypes1.shp")) + FileUtils.copyFile( + new File(resourceFolder + "shapefiles/datatypes/datatypes1.dbf"), + new File(temporaryLocation + "/part=1/datatypes1.dbf")) + 
FileUtils.copyFile( + new File(resourceFolder + "shapefiles/datatypes/datatypes1.cpg"), + new File(temporaryLocation + "/part=1/datatypes1.cpg")) + FileUtils.copyFile( + new File(resourceFolder + "shapefiles/datatypes/datatypes2.shp"), + new File(temporaryLocation + "/part=2/datatypes2.shp")) + FileUtils.copyFile( + new File(resourceFolder + "shapefiles/datatypes/datatypes2.dbf"), + new File(temporaryLocation + "/part=2/datatypes2.dbf")) + FileUtils.copyFile( + new File(resourceFolder + "shapefiles/datatypes/datatypes2.cpg"), + new File(temporaryLocation + "/part=2/datatypes2.cpg")) + + val shapefileDf = sparkSession.read + .format("shapefile") + .load(temporaryLocation) + .select("part", "id", "aInt", "aUnicode", "geometry") + var rows = shapefileDf.collect() + assert(rows.length == 9) + rows.foreach { row => + assert(row.getAs[Geometry]("geometry").isInstanceOf[Point]) + val id = row.getAs[Long]("id") + assert(row.getAs[Long]("aInt") == id) + if (id < 10) { + assert(row.getAs[Int]("part") == 1) + } else { + assert(row.getAs[Int]("part") == 2) + } + if (id > 0) { + assert(row.getAs[String]("aUnicode") == s"测试$id") + } + } + + // Using partition filters + rows = shapefileDf.where("part = 2").collect() + assert(rows.length == 4) + rows.foreach { row => + assert(row.getAs[Geometry]("geometry").isInstanceOf[Point]) + assert(row.getAs[Int]("part") == 2) + val id = row.getAs[Long]("id") + assert(id > 10) + assert(row.getAs[Long]("aInt") == id) + assert(row.getAs[String]("aUnicode") == s"测试$id") + } + } + + it("read with recursiveFileLookup") { + FileUtils.cleanDirectory(new File(temporaryLocation)) + Files.createDirectory(new File(temporaryLocation + "/part1").toPath) + Files.createDirectory(new File(temporaryLocation + "/part2").toPath) + FileUtils.copyFile( + new File(resourceFolder + "shapefiles/datatypes/datatypes1.shp"), + new File(temporaryLocation + "/part1/datatypes1.shp")) + FileUtils.copyFile( + new File(resourceFolder + "shapefiles/datatypes/datatypes1.dbf"), + new File(temporaryLocation + "/part1/datatypes1.dbf")) + FileUtils.copyFile( + new File(resourceFolder + "shapefiles/datatypes/datatypes1.cpg"), + new File(temporaryLocation + "/part1/datatypes1.cpg")) + FileUtils.copyFile( + new File(resourceFolder + "shapefiles/datatypes/datatypes2.shp"), + new File(temporaryLocation + "/part2/datatypes2.shp")) + FileUtils.copyFile( + new File(resourceFolder + "shapefiles/datatypes/datatypes2.dbf"), + new File(temporaryLocation + "/part2/datatypes2.dbf")) + FileUtils.copyFile( + new File(resourceFolder + "shapefiles/datatypes/datatypes2.cpg"), + new File(temporaryLocation + "/part2/datatypes2.cpg")) + + val shapefileDf = sparkSession.read + .format("shapefile") + .option("recursiveFileLookup", "true") + .load(temporaryLocation) + .select("id", "aInt", "aUnicode", "geometry") + val rows = shapefileDf.collect() + assert(rows.length == 9) + rows.foreach { row => + assert(row.getAs[Geometry]("geometry").isInstanceOf[Point]) + val id = row.getAs[Long]("id") + assert(row.getAs[Long]("aInt") == id) + if (id > 0) { + assert(row.getAs[String]("aUnicode") == s"测试$id") + } + } + } + + it("read with custom geometry column name") { + val shapefileDf = sparkSession.read + .format("shapefile") + .option("geometry.name", "geom") + .load(resourceFolder + "shapefiles/gis_osm_pois_free_1") + val schema = shapefileDf.schema + assert(schema.find(_.name == "geom").get.dataType == GeometryUDT) + assert(schema.find(_.name == "osm_id").get.dataType == StringType) + assert(schema.find(_.name == "code").get.dataType 
== LongType) + assert(schema.find(_.name == "fclass").get.dataType == StringType) + assert(schema.find(_.name == "name").get.dataType == StringType) + assert(schema.length == 5) + val rows = shapefileDf.collect() + assert(rows.length == 12873) + rows.foreach { row => + val geom = row.getAs[Geometry]("geom") + assert(geom.isInstanceOf[Point]) + assert(row.getAs[String]("osm_id").nonEmpty) + assert(row.getAs[Long]("code") > 0) + assert(row.getAs[String]("fclass").nonEmpty) + assert(row.getAs[String]("name") != null) + } + + val exception = intercept[Exception] { + sparkSession.read + .format("shapefile") + .option("geometry.name", "osm_id") + .load(resourceFolder + "shapefiles/gis_osm_pois_free_1") + } + assert( + exception.getMessage.contains( + "osm_id is reserved for geometry but appears in non-spatial attributes")) + } + + it("read with shape key column") { + val shapefileDf = sparkSession.read + .format("shapefile") + .option("key.name", "fid") + .load(resourceFolder + "shapefiles/datatypes") + .select("id", "fid", "geometry", "aUnicode") + val schema = shapefileDf.schema + assert(schema.find(_.name == "geometry").get.dataType == GeometryUDT) + assert(schema.find(_.name == "id").get.dataType == LongType) + assert(schema.find(_.name == "fid").get.dataType == LongType) + assert(schema.find(_.name == "aUnicode").get.dataType == StringType) + val rows = shapefileDf.collect() + assert(rows.length == 9) + rows.foreach { row => + val geom = row.getAs[Geometry]("geometry") + assert(geom.isInstanceOf[Point]) + val id = row.getAs[Long]("id") + if (id > 0) { + assert(row.getAs[Long]("fid") == id % 10) + assert(row.getAs[String]("aUnicode") == s"测试$id") + } else { + assert(row.getAs[Long]("fid") == 5) + } + } + } + + it("read with both custom geometry column and shape key column") { + val shapefileDf = sparkSession.read + .format("shapefile") + .option("geometry.name", "g") + .option("key.name", "fid") + .load(resourceFolder + "shapefiles/datatypes") + .select("id", "fid", "g", "aUnicode") + val schema = shapefileDf.schema + assert(schema.find(_.name == "g").get.dataType == GeometryUDT) + assert(schema.find(_.name == "id").get.dataType == LongType) + assert(schema.find(_.name == "fid").get.dataType == LongType) + assert(schema.find(_.name == "aUnicode").get.dataType == StringType) + val rows = shapefileDf.collect() + assert(rows.length == 9) + rows.foreach { row => + val geom = row.getAs[Geometry]("g") + assert(geom.isInstanceOf[Point]) + val id = row.getAs[Long]("id") + if (id > 0) { + assert(row.getAs[Long]("fid") == id % 10) + assert(row.getAs[String]("aUnicode") == s"测试$id") + } else { + assert(row.getAs[Long]("fid") == 5) + } + } + } + + it("read with invalid shape key column") { + val exception = intercept[Exception] { + sparkSession.read + .format("shapefile") + .option("geometry.name", "g") + .option("key.name", "aDate") + .load(resourceFolder + "shapefiles/datatypes") + } + assert( + exception.getMessage.contains( + "aDate is reserved for shape key but appears in non-spatial attributes")) + + val exception2 = intercept[Exception] { + sparkSession.read + .format("shapefile") + .option("geometry.name", "g") + .option("key.name", "g") + .load(resourceFolder + "shapefiles/datatypes") + } + assert(exception2.getMessage.contains("geometry.name and key.name cannot be the same")) + } + + it("read with custom charset") { + FileUtils.cleanDirectory(new File(temporaryLocation)) + FileUtils.copyFile( + new File(resourceFolder + "shapefiles/datatypes/datatypes2.shp"), + new File(temporaryLocation + 
"/datatypes2.shp")) + FileUtils.copyFile( + new File(resourceFolder + "shapefiles/datatypes/datatypes2.dbf"), + new File(temporaryLocation + "/datatypes2.dbf")) + + val shapefileDf = sparkSession.read + .format("shapefile") + .option("charset", "GB2312") + .load(temporaryLocation) + val rows = shapefileDf.collect() + assert(rows.length == 4) + rows.foreach { row => + assert(row.getAs[Geometry]("geometry").isInstanceOf[Point]) + val id = row.getAs[Long]("id") + assert(row.getAs[Long]("aInt") == id) + assert(row.getAs[String]("aUnicode") == s"测试$id") + val decimal = row.getDecimal(row.fieldIndex("aDecimal2")).doubleValue() + assert((decimal * 100).toInt == id * 100 + id) + } + } + + it("read with custom schema") { + val customSchema = StructType( + Seq( + StructField("osm_id", StringType), + StructField("code2", LongType), + StructField("geometry", GeometryUDT))) + val shapefileDf = sparkSession.read + .format("shapefile") + .schema(customSchema) + .load(resourceFolder + "shapefiles/gis_osm_pois_free_1") + assert(shapefileDf.schema == customSchema) + val rows = shapefileDf.collect() + assert(rows.length == 12873) + rows.foreach { row => + val geom = row.getAs[Geometry]("geometry") + assert(geom.isInstanceOf[Point]) + assert(row.getAs[String]("osm_id").nonEmpty) + assert(row.isNullAt(row.fieldIndex("code2"))) + } + } + + it("should read shapes of various types") { + // There are multiple directories under shapefiles/shapetypes, each containing a shapefile. + // We'll iterate over each directory and read the shapefile within it. + val shapeTypesDir = new File(resourceFolder + "shapefiles/shapetypes") + val shapeTypeDirs = shapeTypesDir.listFiles().filter(_.isDirectory) + shapeTypeDirs.foreach { shapeTypeDir => + val fileName = shapeTypeDir.getName + val hasZ = fileName.endsWith("zm") || fileName.endsWith("z") + val hasM = fileName.endsWith("zm") || fileName.endsWith("m") + val shapeType = + if (fileName.startsWith("point")) "POINT" + else if (fileName.startsWith("linestring")) "LINESTRING" + else if (fileName.startsWith("multipoint")) "MULTIPOINT" + else "POLYGON" + val expectedWktPrefix = + if (!hasZ && !hasM) shapeType + else { + shapeType + " " + (if (hasZ) "Z" else "") + (if (hasM) "M" else "") + } + + val shapefileDf = sparkSession.read + .format("shapefile") + .load(shapeTypeDir.getAbsolutePath) + val schema = shapefileDf.schema + assert(schema.find(_.name == "geometry").get.dataType == GeometryUDT) + val rows = shapefileDf.collect() + assert(rows.length > 0) + + // Validate the geometry type and WKT prefix + val wktWriter = new WKTWriter(4) + val rowsMap = mutable.Map[String, Geometry]() + rows.foreach { row => + val id = row.getAs[String]("id") + val geom = row.getAs[Geometry]("geometry") + val wkt = wktWriter.write(geom) + assert(wkt.startsWith(expectedWktPrefix)) + assert(geom != null) + rowsMap.put(id, geom) + } + + // Validate the geometry values by reading the CSV file containing the same data + val csvDf = sparkSession.read + .format("csv") + .option("header", "true") + .load(shapeTypeDir.getAbsolutePath + "/*.csv") + val wktReader = new WKTReader() + csvDf.collect().foreach { row => + val id = row.getAs[String]("id") + val wkt = row.getAs[String]("wkt") + val geom = wktReader.read(wkt) + assert(rowsMap(id).equals(geom)) + } + } + } + } +} diff --git a/spark/spark-4.1/src/test/scala/org/apache/sedona/sql/TestBaseScala.scala b/spark/spark-4.1/src/test/scala/org/apache/sedona/sql/TestBaseScala.scala new file mode 100644 index 00000000000..28943ff11da --- /dev/null +++ 
b/spark/spark-4.1/src/test/scala/org/apache/sedona/sql/TestBaseScala.scala @@ -0,0 +1,129 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.sedona.sql + +import io.minio.{MinioClient, PutObjectArgs} +import org.apache.log4j.{Level, Logger} +import org.apache.sedona.spark.SedonaContext +import org.apache.spark.sql.{DataFrame, SparkSession} +import org.scalatest.{BeforeAndAfterAll, FunSpec} +import org.testcontainers.containers.MinIOContainer + +import java.io.FileInputStream + +import java.util.concurrent.ThreadLocalRandom + +trait TestBaseScala extends FunSpec with BeforeAndAfterAll { + Logger.getRootLogger().setLevel(Level.WARN) + Logger.getLogger("org.apache").setLevel(Level.WARN) + Logger.getLogger("com").setLevel(Level.WARN) + Logger.getLogger("akka").setLevel(Level.WARN) + Logger.getLogger("org.apache.sedona.core").setLevel(Level.WARN) + + val keyParserExtension = "spark.sedona.enableParserExtension" + val warehouseLocation = System.getProperty("user.dir") + "/target/" + val sparkSession = SedonaContext + .builder() + .master("local[*]") + .appName("sedonasqlScalaTest") + .config("spark.sql.warehouse.dir", warehouseLocation) + // We need to be explicit about broadcasting in tests. 
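+ // A threshold of -1 disables Sedona's automatic broadcast joins, so each test opts into broadcasting explicitly.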
+ .config("sedona.join.autoBroadcastJoinThreshold", "-1") + .config("spark.sql.extensions", "org.apache.sedona.sql.SedonaSqlExtensions") + .config(keyParserExtension, ThreadLocalRandom.current().nextBoolean()) + .getOrCreate() + + val sparkSessionMinio = SedonaContext + .builder() + .master("local[*]") + .appName("sedonasqlScalaTest") + .config("spark.sql.warehouse.dir", warehouseLocation) + .config("spark.jars.packages", "org.apache.hadoop:hadoop-aws:3.3.0") + .config( + "spark.hadoop.fs.s3a.aws.credentials.provider", + "org.apache.hadoop.fs.s3a.SimpleAWSCredentialsProvider") + .config("spark.hadoop.fs.s3a.impl", "org.apache.hadoop.fs.s3a.S3AFileSystem") + .config("sedona.join.autoBroadcastJoinThreshold", "-1") + .getOrCreate() + + val resourceFolder = System.getProperty("user.dir") + "/../common/src/test/resources/" + + override def beforeAll(): Unit = { + SedonaContext.create(sparkSession) + } + + override def afterAll(): Unit = { + // SedonaSQLRegistrator.dropAll(spark) + // spark.stop + } + + def loadCsv(path: String): DataFrame = { + sparkSession.read.format("csv").option("delimiter", ",").option("header", "false").load(path) + } + + def withConf[T](conf: Map[String, String])(f: => T): T = { + val oldConf = conf.keys.map(key => key -> sparkSession.conf.getOption(key)) + conf.foreach { case (key, value) => sparkSession.conf.set(key, value) } + try { + f + } finally { + oldConf.foreach { case (key, value) => + value match { + case Some(v) => sparkSession.conf.set(key, v) + case None => sparkSession.conf.unset(key) + } + } + } + } + + def putFileIntoBucket( + bucketName: String, + key: String, + stream: FileInputStream, + client: MinioClient): Unit = { + val objectArguments = PutObjectArgs + .builder() + .bucket(bucketName) + .`object`(key) + .stream(stream, stream.available(), -1) + .build() + + client.putObject(objectArguments) + } + + def createMinioClient(container: MinIOContainer): MinioClient = { + MinioClient + .builder() + .endpoint(container.getS3URL) + .credentials(container.getUserName, container.getPassword) + .build() + } + + def adjustSparkSession(sparkSession: SparkSession, container: MinIOContainer): Unit = { + sparkSession.sparkContext.hadoopConfiguration.set("fs.s3a.endpoint", container.getS3URL) + sparkSession.sparkContext.hadoopConfiguration.set("fs.s3a.access.key", container.getUserName) + sparkSession.sparkContext.hadoopConfiguration.set("fs.s3a.secret.key", container.getPassword) + sparkSession.sparkContext.hadoopConfiguration.set("fs.s3a.connection.timeout", "2000") + + sparkSession.sparkContext.hadoopConfiguration.set("spark.sql.debug.maxToStringFields", "100") + sparkSession.sparkContext.hadoopConfiguration.set("fs.s3a.path.style.access", "true") + sparkSession.sparkContext.hadoopConfiguration + .set("fs.s3a.impl", "org.apache.hadoop.fs.s3a.S3AFileSystem") + } +} diff --git a/spark/spark-4.1/src/test/scala/org/apache/spark/sql/udf/StrategySuite.scala b/spark/spark-4.1/src/test/scala/org/apache/spark/sql/udf/StrategySuite.scala new file mode 100644 index 00000000000..8d41848de98 --- /dev/null +++ b/spark/spark-4.1/src/test/scala/org/apache/spark/sql/udf/StrategySuite.scala @@ -0,0 +1,59 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.spark.sql.udf + +import org.apache.sedona.sql.TestBaseScala +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.functions.col +import org.apache.spark.sql.udf.ScalarUDF.geoPandasScalaFunction +import org.locationtech.jts.io.WKTReader +import org.scalatest.matchers.should.Matchers + +class StrategySuite extends TestBaseScala with Matchers { + val wktReader = new WKTReader() + + val spark: SparkSession = { + sparkSession.sparkContext.setLogLevel("ALL") + sparkSession + } + + import spark.implicits._ + + it("sedona geospatial UDF") { + val df = Seq( + (1, "value", wktReader.read("POINT(21 52)")), + (2, "value1", wktReader.read("POINT(20 50)")), + (3, "value2", wktReader.read("POINT(20 49)")), + (4, "value3", wktReader.read("POINT(20 48)")), + (5, "value4", wktReader.read("POINT(20 47)"))) + .toDF("id", "value", "geom") + .withColumn("geom_buffer", geoPandasScalaFunction(col("geom"))) + + df.count shouldEqual 5 + + df.selectExpr("ST_AsText(ST_ReducePrecision(geom_buffer, 2))") + .as[String] + .collect() should contain theSameElementsAs Seq( + "POLYGON ((20 51, 20 53, 22 53, 22 51, 20 51))", + "POLYGON ((19 49, 19 51, 21 51, 21 49, 19 49))", + "POLYGON ((19 48, 19 50, 21 50, 21 48, 19 48))", + "POLYGON ((19 47, 19 49, 21 49, 21 47, 19 47))", + "POLYGON ((19 46, 19 48, 21 48, 21 46, 19 46))") + } +} diff --git a/spark/spark-4.1/src/test/scala/org/apache/spark/sql/udf/TestScalarPandasUDF.scala b/spark/spark-4.1/src/test/scala/org/apache/spark/sql/udf/TestScalarPandasUDF.scala new file mode 100644 index 00000000000..c0a2d8f260d --- /dev/null +++ b/spark/spark-4.1/src/test/scala/org/apache/spark/sql/udf/TestScalarPandasUDF.scala @@ -0,0 +1,122 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.spark.sql.udf + +import org.apache.sedona.sql.UDF +import org.apache.spark.TestUtils +import org.apache.spark.api.python._ +import org.apache.spark.broadcast.Broadcast +import org.apache.spark.sql.execution.python.UserDefinedPythonFunction +import org.apache.spark.sql.sedona_sql.UDT.GeometryUDT +import org.apache.spark.util.Utils + +import java.io.File +import java.nio.file.{Files, Paths} +import scala.sys.process.Process +import scala.jdk.CollectionConverters._ + +object ScalarUDF { + + val pythonExec: String = { + val pythonExec = + sys.env.getOrElse("PYSPARK_DRIVER_PYTHON", sys.env.getOrElse("PYSPARK_PYTHON", "python3")) + if (TestUtils.testCommandAvailable(pythonExec)) { + pythonExec + } else { + "python" + } + } + + private[spark] lazy val pythonPath = sys.env.getOrElse("PYTHONPATH", "") + protected lazy val sparkHome: String = { + sys.props.getOrElse("spark.test.home", sys.env("SPARK_HOME")) + } + + private lazy val py4jPath = + Paths.get(sparkHome, "python", "lib", PythonUtils.PY4J_ZIP_NAME).toAbsolutePath + private[spark] lazy val pysparkPythonPath = s"$py4jPath" + + private lazy val isPythonAvailable: Boolean = TestUtils.testCommandAvailable(pythonExec) + + lazy val pythonVer: String = if (isPythonAvailable) { + Process( + Seq(pythonExec, "-c", "import sys; print('%d.%d' % sys.version_info[:2])"), + None, + "PYTHONPATH" -> s"$pysparkPythonPath:$pythonPath").!!.trim() + } else { + throw new RuntimeException(s"Python executable [$pythonExec] is unavailable.") + } + + protected def withTempPath(f: File => Unit): Unit = { + val path = Utils.createTempDir() + path.delete() + try f(path) + finally Utils.deleteRecursively(path) + } + + val pandasFunc: Array[Byte] = { + var binaryPandasFunc: Array[Byte] = null + withTempPath { path => + println(path) + Process( + Seq( + pythonExec, + "-c", + f""" + |from pyspark.sql.types import IntegerType + |from shapely.geometry import Point + |from sedona.sql.types import GeometryType + |from pyspark.serializers import CloudPickleSerializer + |from sedona.utils import geometry_serde + |from shapely import box + |f = open('$path', 'wb'); + |def w(x): + | def apply_function(w): + | geom, offset = geometry_serde.deserialize(w) + | bounds = geom.buffer(1).bounds + | x = box(*bounds) + | return geometry_serde.serialize(x) + | return x.apply(apply_function) + |f.write(CloudPickleSerializer().dumps((w, GeometryType()))) + |""".stripMargin), + None, + "PYTHONPATH" -> s"$pysparkPythonPath:$pythonPath").!! 
+ binaryPandasFunc = Files.readAllBytes(path.toPath) + } + assert(binaryPandasFunc != null) + binaryPandasFunc + } + + private val workerEnv = new java.util.HashMap[String, String]() + workerEnv.put("PYTHONPATH", s"$pysparkPythonPath:$pythonPath") + + val geoPandasScalaFunction: UserDefinedPythonFunction = UserDefinedPythonFunction( + name = "geospatial_udf", + func = SimplePythonFunction( + command = pandasFunc, + envVars = workerEnv.clone().asInstanceOf[java.util.Map[String, String]], + pythonIncludes = List.empty[String].asJava, + pythonExec = pythonExec, + pythonVer = pythonVer, + broadcastVars = List.empty[Broadcast[PythonBroadcast]].asJava, + accumulator = null), + dataType = GeometryUDT, + pythonEvalType = UDF.PythonEvalType.SQL_SCALAR_SEDONA_UDF, + udfDeterministic = true) +} From 2a23a993814633f62b9fbe35ea3d767b235ce0e5 Mon Sep 17 00:00:00 2001 From: Jia Yu Date: Thu, 12 Feb 2026 12:02:19 -0800 Subject: [PATCH 02/14] Remove Spark 4.1 from example CI --- .github/workflows/example.yml | 4 ---- .pre-commit-config.yaml | 2 +- 2 files changed, 1 insertion(+), 5 deletions(-) diff --git a/.github/workflows/example.yml b/.github/workflows/example.yml index 4a93b05cc73..6d16137a839 100644 --- a/.github/workflows/example.yml +++ b/.github/workflows/example.yml @@ -45,10 +45,6 @@ jobs: fail-fast: false matrix: include: - - spark: 4.1.0 - spark-compat: '4.1' - sedona: 1.9.0-SNAPSHOT - hadoop: 3.4.2 - spark: 4.0.1 spark-compat: '4.0' sedona: 1.8.0 diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 3be967720ed..51d69a6e040 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -288,7 +288,7 @@ repos: - id: clang-format name: run clang-format description: format C files with clang-format - args: [--style=file:.github/linters/.clang-format] + args: ['--style=file:.github/linters/.clang-format'] types_or: [c] - repo: https://github.com/PyCQA/bandit rev: 1.9.3 From 3c621e5712be377387d4f7d383a652be6bc411d4 Mon Sep 17 00:00:00 2001 From: Jia Yu Date: Fri, 13 Feb 2026 00:27:58 -0800 Subject: [PATCH 03/14] Fix docs and CI: remove Scala 2.12 tabs for Spark 4.1, fix Python compatibility table, refine CI matrices --- .github/workflows/docker-build.yml | 5 +--- .github/workflows/java.yml | 13 ++++------ .github/workflows/python.yml | 41 ------------------------------ docs/setup/maven-coordinates.md | 29 --------------------- docs/setup/platform.md | 6 ++--- 5 files changed, 9 insertions(+), 85 deletions(-) diff --git a/.github/workflows/docker-build.yml b/.github/workflows/docker-build.yml index a7988b56aa5..86db6074c04 100644 --- a/.github/workflows/docker-build.yml +++ b/.github/workflows/docker-build.yml @@ -49,10 +49,7 @@ jobs: spark: ['4.0.1', '4.1.0'] include: - spark: 4.0.1 - sedona: 'latest' - geotools: '33.1' - - spark: 4.0.1 - sedona: 1.8.0 + sedona: 1.8.1 geotools: '33.1' - spark: 4.1.0 sedona: 'latest' diff --git a/.github/workflows/java.yml b/.github/workflows/java.yml index 32d3d095daf..30437e052be 100644 --- a/.github/workflows/java.yml +++ b/.github/workflows/java.yml @@ -63,19 +63,16 @@ jobs: matrix: include: - spark: 4.1.0 - scala: 2.13.8 - jdk: '17' - - spark: 4.0.0 - scala: 2.13.8 + scala: 2.13.17 jdk: '17' - - spark: 3.5.4 - scala: 2.12.18 + - spark: 4.0.2 + scala: 2.13.17 jdk: '17' - - spark: 3.5.0 + - spark: 3.5.8 scala: 2.13.8 jdk: '11' skipTests: '' - - spark: 3.5.0 + - spark: 3.5.8 scala: 2.12.15 jdk: '11' skipTests: '' diff --git a/.github/workflows/python.yml b/.github/workflows/python.yml index ff3810d72b2..93f4d3f85cf 100644 --- 
a/.github/workflows/python.yml +++ b/.github/workflows/python.yml @@ -64,55 +64,14 @@ jobs: scala: '2.13.8' java: '17' python: '3.11' - - spark: '4.1.0' - scala: '2.13.8' - java: '17' - python: '3.10' - spark: '4.0.0' scala: '2.13.8' java: '17' - python: '3.11' - - spark: '4.0.0' - scala: '2.13.8' - java: '17' - python: '3.10' - - spark: '3.5.0' - scala: '2.12.8' - java: '11' - python: '3.11' - - spark: '3.5.0' - scala: '2.12.8' - java: '11' - python: '3.10' - shapely: '1' - - spark: '3.5.0' - scala: '2.12.8' - java: '11' python: '3.10' - spark: '3.5.0' scala: '2.12.8' java: '11' python: '3.9' - - spark: '3.5.0' - scala: '2.12.8' - java: '11' - python: '3.8' - - spark: '3.4.0' - scala: '2.12.8' - java: '11' - python: '3.11' - - spark: '3.4.0' - scala: '2.12.8' - java: '11' - python: '3.10' - - spark: '3.4.0' - scala: '2.12.8' - java: '11' - python: '3.9' - - spark: '3.4.0' - scala: '2.12.8' - java: '11' - python: '3.8' - spark: '3.4.0' scala: '2.12.8' java: '11' diff --git a/docs/setup/maven-coordinates.md b/docs/setup/maven-coordinates.md index b655ee0ff73..928200e6984 100644 --- a/docs/setup/maven-coordinates.md +++ b/docs/setup/maven-coordinates.md @@ -84,22 +84,6 @@ The optional GeoTools library is required if you want to use raster operators. V ``` - === "Spark 4.1 and Scala 2.12" - - ```xml - - org.apache.sedona - sedona-spark-shaded-4.1_2.12 - {{ sedona.current_version }} - - - - org.datasyslab - geotools-wrapper - {{ sedona.current_geotools }} - - ``` - !!! abstract "Sedona with Apache Spark and Scala 2.13" === "Spark 3.4 and Scala 2.13" @@ -255,19 +239,6 @@ The optional GeoTools library is required if you want to use raster operators. V {{ sedona.current_geotools }} ``` - === "Spark 4.1 and Scala 2.12" - ```xml - - org.apache.sedona - sedona-spark-4.1_2.12 - {{ sedona.current_version }} - - - org.datasyslab - geotools-wrapper - {{ sedona.current_geotools }} - - ``` !!! abstract "Sedona with Apache Spark and Scala 2.13" diff --git a/docs/setup/platform.md b/docs/setup/platform.md index e3113ad9b10..c788981ceda 100644 --- a/docs/setup/platform.md +++ b/docs/setup/platform.md @@ -35,10 +35,10 @@ Sedona binary releases are compiled by Java 11/17 and Scala 2.12/2.13 and tested === "Sedona Python" - | | Spark 3.4 (Scala 2.12)|Spark 3.5 (Scala 2.12)| Spark 4.0 (Scala 2.12)| Spark 4.1 (Scala 2.13)| + | | Spark 3.4 (Scala 2.12)|Spark 3.5 (Scala 2.12)| Spark 4.0 (Scala 2.13)| Spark 4.1 (Scala 2.13)| |:---------:|:---------:|:---------:|:---------:|:---------:| - | Python 3.7 | ✅ | ✅ | ✅ | ✅ | - | Python 3.8 | ✅ | ✅ | ✅ | ✅ | + | Python 3.7 | ✅ | ✅ | | | + | Python 3.8 | ✅ | ✅ | | | | Python 3.9 | ✅ | ✅ | ✅ | ✅ | | Python 3.10 | ✅ | ✅ | ✅ | ✅ | From 2e72d0fb787b4e9c0297526f49d3529de5fbebfa Mon Sep 17 00:00:00 2001 From: Jia Yu Date: Fri, 13 Feb 2026 02:34:30 -0800 Subject: [PATCH 04/14] Work around Spark 4.1 UDT case object bug (SPARK-52671) Spark 4.1's RowEncoder calls udt.getClass directly, which returns the Scala module class (e.g. GeometryUDT$) with a private constructor for case objects, causing EXPRESSION_DECODING_FAILED errors. Fix: Add apply() method to GeometryUDT, GeographyUDT, and RasterUDT case objects that return new class instances, and use UDT() instead of the bare singleton throughout schema construction code. This ensures getClass returns the public class with an accessible constructor. 
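For illustration, here is a minimal self-contained sketch of the pattern this commit applies. The `MyUDT` and `Demo` names below are hypothetical stand-ins rather than real Sedona or Spark classes; the actual change adds the same `apply()` to GeometryUDT, GeographyUDT, and RasterUDT, as shown in the hunks further down.

```scala
// Hypothetical stand-in for a UDT published as a case object (not a real Sedona/Spark class).
class MyUDT extends Serializable

// The companion case object keeps the old "bare singleton" spelling compiling,
// while apply() hands out instances of the public class.
case object MyUDT extends MyUDT {
  def apply(): MyUDT = new MyUDT()
}

object Demo extends App {
  // Bare singleton: getClass is the module class (MyUDT$), whose constructor is private;
  // that is what Spark 4.1's RowEncoder trips over.
  println(MyUDT.getClass.getSimpleName)   // MyUDT$
  // apply() instance: getClass is the public class with an accessible no-arg constructor.
  println(MyUDT().getClass.getSimpleName) // MyUDT
}
```

Schema construction then uses the instance form, e.g. StructField("geom", GeometryUDT()) instead of StructField("geom", GeometryUDT), which is the mechanical change repeated throughout the diffs below.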
Also: - Revert docker-build.yml (no Spark 4.1 in Docker builds) - Bump pyspark upper bound from <4.1.0 to <4.2.0 - Bump Spark 4.1.0 to 4.1.1 in CI and POM - Fix Scala 2.13.12 vs 2.13.17 mismatch in scala2.13 profile --- .github/workflows/docker-build.yml | 8 +++---- .github/workflows/java.yml | 2 +- .github/workflows/python.yml | 2 +- pom.xml | 6 ++--- python/pyproject.toml | 6 ++--- .../geopackage/model/GeoPackageField.scala | 18 +++++++------- .../sql/datasources/spider/SpiderTable.scala | 2 +- .../org/apache/sedona/sql/utils/Adapter.scala | 6 ++--- .../geoparquet/GeoParquetFileFormat.scala | 2 +- .../GeoParquetSchemaConverter.scala | 2 +- .../sql/sedona_sql/UDT/GeographyUDT.scala | 4 +++- .../sql/sedona_sql/UDT/GeometryUDT.scala | 4 +++- .../spark/sql/sedona_sql/UDT/RasterUDT.scala | 4 +++- .../sedona_sql/expressions/Constructors.scala | 12 +++++----- .../sedona_sql/expressions/Functions.scala | 20 ++++++++-------- .../expressions/GeoStatsFunctions.scala | 8 +++---- .../expressions/InferrableRasterTypes.scala | 4 ++-- .../expressions/InferredExpression.scala | 8 +++---- .../sedona_sql/expressions/Predicates.scala | 2 +- .../expressions/collect/ST_Collect.scala | 2 +- .../expressions/raster/PixelFunctions.scala | 12 +++++----- .../raster/RasterConstructors.scala | 2 +- .../expressions/raster/RasterFunctions.scala | 24 +++++++++---------- .../expressions/raster/RasterPredicates.scala | 6 ++--- .../io/geojson/GeoJSONFileFormat.scala | 4 ++-- .../sedona_sql/io/stac/StacDataSource.scala | 2 +- .../sql/sedona_sql/io/stac/StacTable.scala | 2 +- .../org/apache/sedona/sql/KnnJoinSuite.scala | 2 +- .../apache/sedona/sql/PreserveSRIDSuite.scala | 12 +++++----- .../apache/sedona/sql/SpatialJoinSuite.scala | 2 +- .../apache/sedona/sql/adapterTestScala.scala | 10 ++++---- .../apache/sedona/sql/geoparquetIOTests.scala | 16 ++++++------- .../shapefile/ShapefileUtils.scala | 4 ++-- .../sql/parser/SedonaSqlAstBuilder.scala | 2 +- .../sedona/sql/GeoPackageReaderTest.scala | 2 +- .../sedona/sql/GeoParquetMetadataTests.scala | 4 ++-- .../apache/sedona/sql/ShapefileTests.scala | 2 +- .../shapefile/ShapefileUtils.scala | 4 ++-- .../sql/parser/SedonaSqlAstBuilder.scala | 2 +- .../sedona/sql/GeoPackageReaderTest.scala | 2 +- .../sedona/sql/GeoParquetMetadataTests.scala | 4 ++-- .../apache/sedona/sql/ShapefileTests.scala | 2 +- .../spark/sql/udf/TestScalarPandasUDF.scala | 2 +- .../shapefile/ShapefileUtils.scala | 4 ++-- .../sql/parser/SedonaSqlAstBuilder.scala | 2 +- .../sedona/sql/GeoPackageReaderTest.scala | 2 +- .../sedona/sql/GeoParquetMetadataTests.scala | 4 ++-- .../apache/sedona/sql/ShapefileTests.scala | 2 +- .../spark/sql/udf/TestScalarPandasUDF.scala | 2 +- .../shapefile/ShapefileUtils.scala | 4 ++-- .../sql/parser/SedonaSqlAstBuilder.scala | 2 +- .../sedona/sql/GeoPackageReaderTest.scala | 2 +- .../sedona/sql/GeoParquetMetadataTests.scala | 4 ++-- .../apache/sedona/sql/ShapefileTests.scala | 2 +- .../spark/sql/udf/TestScalarPandasUDF.scala | 2 +- 55 files changed, 142 insertions(+), 136 deletions(-) diff --git a/.github/workflows/docker-build.yml b/.github/workflows/docker-build.yml index 86db6074c04..42306f9ee1d 100644 --- a/.github/workflows/docker-build.yml +++ b/.github/workflows/docker-build.yml @@ -46,14 +46,14 @@ jobs: fail-fast: true matrix: os: ['ubuntu-latest', 'ubuntu-24.04-arm'] - spark: ['4.0.1', '4.1.0'] + spark: ['4.0.1'] include: - spark: 4.0.1 - sedona: 1.8.1 - geotools: '33.1' - - spark: 4.1.0 sedona: 'latest' geotools: '33.1' + - spark: 4.0.1 + sedona: 1.8.0 + geotools: '33.1' 
runs-on: ${{ matrix.os }} defaults: run: diff --git a/.github/workflows/java.yml b/.github/workflows/java.yml index 30437e052be..027308dde73 100644 --- a/.github/workflows/java.yml +++ b/.github/workflows/java.yml @@ -62,7 +62,7 @@ jobs: fail-fast: true matrix: include: - - spark: 4.1.0 + - spark: 4.1.1 scala: 2.13.17 jdk: '17' - spark: 4.0.2 diff --git a/.github/workflows/python.yml b/.github/workflows/python.yml index 93f4d3f85cf..0b7d524e0c5 100644 --- a/.github/workflows/python.yml +++ b/.github/workflows/python.yml @@ -60,7 +60,7 @@ jobs: strategy: matrix: include: - - spark: '4.1.0' + - spark: '4.1.1' scala: '2.13.8' java: '17' python: '3.11' diff --git a/pom.xml b/pom.xml index d4fe817809f..05ca1cde9c0 100644 --- a/pom.xml +++ b/pom.xml @@ -758,7 +758,7 @@ 2.24.3 2.0.16 - 2.13.12 + 2.13.17 2.13 @@ -774,7 +774,7 @@ - 4.1.0 + 4.1.1 4.1 4 3.4.1 @@ -798,7 +798,7 @@ false - 2.13.12 + 2.13.17 2.13 -no-java-comments diff --git a/python/pyproject.toml b/python/pyproject.toml index 4d7eda43466..440b42642e3 100644 --- a/python/pyproject.toml +++ b/python/pyproject.toml @@ -37,13 +37,13 @@ dependencies = [ ] [project.optional-dependencies] -spark = ["pyspark>=3.4.0,<4.1.0"] +spark = ["pyspark>=3.4.0,<4.2.0"] pydeck-map = ["geopandas", "pydeck==0.8.0"] kepler-map = ["geopandas", "keplergl==0.3.2"] flink = ["apache-flink>=1.19.0"] db = ["sedonadb[geopandas]; python_version >= '3.9'"] all = [ - "pyspark>=3.4.0,<4.1.0", + "pyspark>=3.4.0,<4.2.0", "geopandas", "pydeck==0.8.0", "keplergl==0.3.2", @@ -71,7 +71,7 @@ dev = [ # cannot set geopandas>=0.14.4 since it doesn't support python 3.8, so we pin fiona to <1.10.0 "fiona<1.10.0", "pyarrow", - "pyspark>=3.4.0,<4.1.0", + "pyspark>=3.4.0,<4.2.0", "keplergl==0.3.2", "pydeck==0.8.0", "pystac==1.5.0", diff --git a/spark/common/src/main/scala/org/apache/sedona/sql/datasources/geopackage/model/GeoPackageField.scala b/spark/common/src/main/scala/org/apache/sedona/sql/datasources/geopackage/model/GeoPackageField.scala index 127c56ca52f..7e3dba3ad32 100644 --- a/spark/common/src/main/scala/org/apache/sedona/sql/datasources/geopackage/model/GeoPackageField.scala +++ b/spark/common/src/main/scala/org/apache/sedona/sql/datasources/geopackage/model/GeoPackageField.scala @@ -33,7 +33,7 @@ case class GeoPackageField(name: String, dataType: String, isNullable: Boolean) StructField(name, StringType) case startsWith: String if startsWith.startsWith(GeoPackageType.BLOB) => { if (tableType == TableType.TILES) { - return StructField(name, RasterUDT) + return StructField(name, RasterUDT()) } StructField(name, BinaryType) @@ -41,14 +41,14 @@ case class GeoPackageField(name: String, dataType: String, isNullable: Boolean) case GeoPackageType.INTEGER | GeoPackageType.INT | GeoPackageType.SMALLINT | GeoPackageType.TINY_INT | GeoPackageType.MEDIUMINT => StructField(name, IntegerType) - case GeoPackageType.POINT => StructField(name, GeometryUDT) - case GeoPackageType.LINESTRING => StructField(name, GeometryUDT) - case GeoPackageType.POLYGON => StructField(name, GeometryUDT) - case GeoPackageType.GEOMETRY => StructField(name, GeometryUDT) - case GeoPackageType.MULTIPOINT => StructField(name, GeometryUDT) - case GeoPackageType.MULTILINESTRING => StructField(name, GeometryUDT) - case GeoPackageType.MULTIPOLYGON => StructField(name, GeometryUDT) - case GeoPackageType.GEOMETRYCOLLECTION => StructField(name, GeometryUDT) + case GeoPackageType.POINT => StructField(name, GeometryUDT()) + case GeoPackageType.LINESTRING => StructField(name, GeometryUDT()) + case GeoPackageType.POLYGON => 
StructField(name, GeometryUDT()) + case GeoPackageType.GEOMETRY => StructField(name, GeometryUDT()) + case GeoPackageType.MULTIPOINT => StructField(name, GeometryUDT()) + case GeoPackageType.MULTILINESTRING => StructField(name, GeometryUDT()) + case GeoPackageType.MULTIPOLYGON => StructField(name, GeometryUDT()) + case GeoPackageType.GEOMETRYCOLLECTION => StructField(name, GeometryUDT()) case GeoPackageType.REAL => StructField(name, DoubleType) case GeoPackageType.BOOLEAN => StructField(name, BooleanType) case GeoPackageType.DATE => StructField(name, DateType) diff --git a/spark/common/src/main/scala/org/apache/sedona/sql/datasources/spider/SpiderTable.scala b/spark/common/src/main/scala/org/apache/sedona/sql/datasources/spider/SpiderTable.scala index 81cb5bffecd..8acbef285b4 100644 --- a/spark/common/src/main/scala/org/apache/sedona/sql/datasources/spider/SpiderTable.scala +++ b/spark/common/src/main/scala/org/apache/sedona/sql/datasources/spider/SpiderTable.scala @@ -50,5 +50,5 @@ class SpiderTable( object SpiderTable { val SCHEMA: StructType = StructType( - Seq(StructField("id", LongType), StructField("geometry", GeometryUDT))) + Seq(StructField("id", LongType), StructField("geometry", GeometryUDT()))) } diff --git a/spark/common/src/main/scala/org/apache/sedona/sql/utils/Adapter.scala b/spark/common/src/main/scala/org/apache/sedona/sql/utils/Adapter.scala index 96aab1287ee..732fb1272d7 100644 --- a/spark/common/src/main/scala/org/apache/sedona/sql/utils/Adapter.scala +++ b/spark/common/src/main/scala/org/apache/sedona/sql/utils/Adapter.scala @@ -143,7 +143,7 @@ object Adapter { val stringRow = extractUserData(geom) Row.fromSeq(stringRow) }) - var cols: Seq[StructField] = Seq(StructField("geometry", GeometryUDT)) + var cols: Seq[StructField] = Seq(StructField("geometry", GeometryUDT())) if (fieldNames != null && fieldNames.nonEmpty) { cols = cols ++ fieldNames.map(f => StructField(f, StringType)) } @@ -195,10 +195,10 @@ object Adapter { rightFieldnames = rightFieldNames) Row.fromSeq(stringRow) }) - var cols: Seq[StructField] = Seq(StructField("leftgeometry", GeometryUDT)) + var cols: Seq[StructField] = Seq(StructField("leftgeometry", GeometryUDT())) if (leftFieldnames != null && leftFieldnames.nonEmpty) cols = cols ++ leftFieldnames.map(fName => StructField(fName, StringType)) - cols = cols ++ Seq(StructField("rightgeometry", GeometryUDT)) + cols = cols ++ Seq(StructField("rightgeometry", GeometryUDT())) if (rightFieldNames != null && rightFieldNames.nonEmpty) cols = cols ++ rightFieldNames.map(fName => StructField(fName, StringType)) val schema = StructType(cols) diff --git a/spark/common/src/main/scala/org/apache/spark/sql/execution/datasources/geoparquet/GeoParquetFileFormat.scala b/spark/common/src/main/scala/org/apache/spark/sql/execution/datasources/geoparquet/GeoParquetFileFormat.scala index 04f4db8bf5a..8a5c2c19a0f 100644 --- a/spark/common/src/main/scala/org/apache/spark/sql/execution/datasources/geoparquet/GeoParquetFileFormat.scala +++ b/spark/common/src/main/scala/org/apache/spark/sql/execution/datasources/geoparquet/GeoParquetFileFormat.scala @@ -448,7 +448,7 @@ object GeoParquetFileFormat extends Logging { val fields = schema.fields.map { field => field.dataType match { case _: BinaryType if geoParquetMetaData.columns.contains(field.name) => - field.copy(dataType = GeometryUDT) + field.copy(dataType = GeometryUDT()) case _ => field } } diff --git a/spark/common/src/main/scala/org/apache/spark/sql/execution/datasources/geoparquet/GeoParquetSchemaConverter.scala 
b/spark/common/src/main/scala/org/apache/spark/sql/execution/datasources/geoparquet/GeoParquetSchemaConverter.scala index 749816ffcb5..34b9f44ec54 100644 --- a/spark/common/src/main/scala/org/apache/spark/sql/execution/datasources/geoparquet/GeoParquetSchemaConverter.scala +++ b/spark/common/src/main/scala/org/apache/spark/sql/execution/datasources/geoparquet/GeoParquetSchemaConverter.scala @@ -193,7 +193,7 @@ class GeoParquetToSparkSchemaConverter( case BINARY => originalType match { case UTF8 | ENUM | JSON => StringType - case null if isGeometryField(field.getName) => GeometryUDT + case null if isGeometryField(field.getName) => GeometryUDT() case null if assumeBinaryIsString => StringType case null => BinaryType case BSON => BinaryType diff --git a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/UDT/GeographyUDT.scala b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/UDT/GeographyUDT.scala index d4d4fb8898c..987ef9c229d 100644 --- a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/UDT/GeographyUDT.scala +++ b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/UDT/GeographyUDT.scala @@ -56,4 +56,6 @@ class GeographyUDT extends UserDefinedType[Geography] { case object GeographyUDT extends org.apache.spark.sql.sedona_sql.UDT.GeographyUDT - with scala.Serializable + with scala.Serializable { + def apply(): GeographyUDT = new GeographyUDT() +} diff --git a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/UDT/GeometryUDT.scala b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/UDT/GeometryUDT.scala index 20ed85d52ec..59b5b82d6a5 100644 --- a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/UDT/GeometryUDT.scala +++ b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/UDT/GeometryUDT.scala @@ -57,4 +57,6 @@ class GeometryUDT extends UserDefinedType[Geometry] { case object GeometryUDT extends org.apache.spark.sql.sedona_sql.UDT.GeometryUDT - with scala.Serializable + with scala.Serializable { + def apply(): GeometryUDT = new GeometryUDT() +} diff --git a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/UDT/RasterUDT.scala b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/UDT/RasterUDT.scala index a34b372211e..77c3254eb5a 100644 --- a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/UDT/RasterUDT.scala +++ b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/UDT/RasterUDT.scala @@ -58,4 +58,6 @@ class RasterUDT extends UserDefinedType[GridCoverage2D] { override def hashCode(): Int = userClass.hashCode() } -case object RasterUDT extends RasterUDT with Serializable +case object RasterUDT extends RasterUDT with Serializable { + def apply(): RasterUDT = new RasterUDT() +} diff --git a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/Constructors.scala b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/Constructors.scala index d4b3d3efb1b..ce81764e7e7 100644 --- a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/Constructors.scala +++ b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/Constructors.scala @@ -177,7 +177,7 @@ private[apache] case class ST_GeomFromWKB(inputExpressions: Seq[Expression]) } } - override def dataType: DataType = GeometryUDT + override def dataType: DataType = GeometryUDT() override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection(StringType, BinaryType)) @@ -217,7 +217,7 @@ private[apache] case class ST_GeomFromEWKB(inputExpressions: 
Seq[Expression]) } } - override def dataType: DataType = GeometryUDT + override def dataType: DataType = GeometryUDT() override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection(StringType, BinaryType)) @@ -272,7 +272,7 @@ private[apache] case class ST_LineFromWKB(inputExpressions: Seq[Expression]) } } - override def dataType: DataType = GeometryUDT + override def dataType: DataType = GeometryUDT() override def inputTypes: Seq[AbstractDataType] = if (inputExpressions.length == 1) Seq(TypeCollection(StringType, BinaryType)) @@ -329,7 +329,7 @@ private[apache] case class ST_LinestringFromWKB(inputExpressions: Seq[Expression } } - override def dataType: DataType = GeometryUDT + override def dataType: DataType = GeometryUDT() override def inputTypes: Seq[AbstractDataType] = if (inputExpressions.length == 1) Seq(TypeCollection(StringType, BinaryType)) @@ -386,7 +386,7 @@ private[apache] case class ST_PointFromWKB(inputExpressions: Seq[Expression]) } } - override def dataType: DataType = GeometryUDT + override def dataType: DataType = GeometryUDT() override def inputTypes: Seq[AbstractDataType] = if (inputExpressions.length == 1) Seq(TypeCollection(StringType, BinaryType)) @@ -435,7 +435,7 @@ private[apache] case class ST_GeomFromGeoJSON(inputExpressions: Seq[Expression]) } } - override def dataType: DataType = GeometryUDT + override def dataType: DataType = GeometryUDT() override def children: Seq[Expression] = inputExpressions diff --git a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/Functions.scala b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/Functions.scala index 76d0ec57602..b5f85b89682 100644 --- a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/Functions.scala +++ b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/Functions.scala @@ -394,9 +394,9 @@ private[apache] case class ST_IsValidDetail(children: Seq[Expression]) override def inputTypes: Seq[AbstractDataType] = { if (nArgs == 2) { - Seq(GeometryUDT, IntegerType) + Seq(GeometryUDT(), IntegerType) } else if (nArgs == 1) { - Seq(GeometryUDT) + Seq(GeometryUDT()) } else { throw new IllegalArgumentException(s"Invalid number of arguments: $nArgs") } @@ -440,7 +440,7 @@ private[apache] case class ST_IsValidDetail(children: Seq[Expression]) override def dataType: DataType = new StructType() .add("valid", BooleanType, nullable = false) .add("reason", StringType, nullable = true) - .add("location", GeometryUDT, nullable = true) + .add("location", GeometryUDT(), nullable = true) } private[apache] case class ST_IsValidTrajectory(inputExpressions: Seq[Expression]) @@ -736,7 +736,7 @@ private[apache] case class ST_MinimumBoundingRadius(inputExpressions: Seq[Expres override def dataType: DataType = DataTypes.createStructType( Array( - DataTypes.createStructField("center", GeometryUDT, false), + DataTypes.createStructField("center", GeometryUDT(), false), DataTypes.createStructField("radius", DataTypes.DoubleType, false))) override def children: Seq[Expression] = inputExpressions @@ -1069,7 +1069,7 @@ private[apache] case class ST_SubDivideExplode(children: Seq[Expression]) override def elementSchema: StructType = { new StructType() - .add("geom", GeometryUDT, true) + .add("geom", GeometryUDT(), true) } protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]) = { @@ -1188,8 +1188,8 @@ private[apache] case class ST_MaximumInscribedCircle(children: Seq[Expression]) override def nullable: Boolean = true override def 
dataType: DataType = new StructType() - .add("center", GeometryUDT, nullable = false) - .add("nearest", GeometryUDT, nullable = false) + .add("center", GeometryUDT(), nullable = false) + .add("nearest", GeometryUDT(), nullable = false) .add("radius", DoubleType, nullable = false) } @@ -1693,13 +1693,13 @@ private[apache] case class ST_GeneratePoints(inputExpressions: Seq[Expression], override def nullable: Boolean = true - override def dataType: DataType = GeometryUDT + override def dataType: DataType = GeometryUDT() override def inputTypes: Seq[AbstractDataType] = { if (nArgs == 3) { - Seq(GeometryUDT, IntegerType, IntegerType) + Seq(GeometryUDT(), IntegerType, IntegerType) } else if (nArgs == 2) { - Seq(GeometryUDT, IntegerType) + Seq(GeometryUDT(), IntegerType) } else { throw new IllegalArgumentException(s"Invalid number of arguments: $nArgs") } diff --git a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/GeoStatsFunctions.scala b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/GeoStatsFunctions.scala index 0d805fbedee..eb6d670d3f2 100644 --- a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/GeoStatsFunctions.scala +++ b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/GeoStatsFunctions.scala @@ -36,7 +36,7 @@ private[apache] case class ST_DBSCAN(children: Seq[Expression]) StructType(Seq(StructField("isCore", BooleanType), StructField("cluster", LongType))) override def inputTypes: Seq[AbstractDataType] = - Seq(GeometryUDT, DoubleType, IntegerType, BooleanType) + Seq(GeometryUDT(), DoubleType, IntegerType, BooleanType) protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): Expression = copy(children = newChildren) @@ -74,7 +74,7 @@ private[apache] case class ST_LocalOutlierFactor(children: Seq[Expression]) override def dataType: DataType = DoubleType override def inputTypes: Seq[AbstractDataType] = - Seq(GeometryUDT, IntegerType, BooleanType) + Seq(GeometryUDT(), IntegerType, BooleanType) protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): Expression = copy(children = newChildren) @@ -139,7 +139,7 @@ private[apache] case class ST_BinaryDistanceBandColumn(children: Seq[Expression] Seq(StructField("neighbor", children(5).dataType), StructField("value", DoubleType)))) override def inputTypes: Seq[AbstractDataType] = - Seq(GeometryUDT, DoubleType, BooleanType, BooleanType, BooleanType, children(5).dataType) + Seq(GeometryUDT(), DoubleType, BooleanType, BooleanType, BooleanType, children(5).dataType) protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): Expression = copy(children = newChildren) @@ -174,7 +174,7 @@ private[apache] case class ST_WeightedDistanceBandColumn(children: Seq[Expressio override def inputTypes: Seq[AbstractDataType] = Seq( - GeometryUDT, + GeometryUDT(), DoubleType, DoubleType, BooleanType, diff --git a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/InferrableRasterTypes.scala b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/InferrableRasterTypes.scala index 2d3349d4ad7..970f03d01a7 100644 --- a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/InferrableRasterTypes.scala +++ b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/InferrableRasterTypes.scala @@ -37,8 +37,8 @@ object InferrableRasterTypes { def isRasterType(t: Type): Boolean = t =:= typeOf[GridCoverage2D] def isRasterArrayType(t: Type): Boolean = t =:= 
typeOf[Array[GridCoverage2D]] - val rasterUDT: UserDefinedType[_] = RasterUDT - val rasterUDTArray: ArrayType = DataTypes.createArrayType(RasterUDT) + val rasterUDT: UserDefinedType[_] = RasterUDT() + val rasterUDTArray: ArrayType = DataTypes.createArrayType(RasterUDT()) def rasterExtractor(expr: Expression)(input: InternalRow): Any = expr.toRaster(input) diff --git a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/InferredExpression.scala b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/InferredExpression.scala index 757948216c3..4b94084f95a 100644 --- a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/InferredExpression.scala +++ b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/InferredExpression.scala @@ -325,13 +325,13 @@ object InferredTypes { def inferSparkType(t: Type): DataType = { if (t =:= typeOf[Geometry]) { - GeometryUDT + GeometryUDT() } else if (t =:= typeOf[Array[Geometry]] || t =:= typeOf[java.util.List[Geometry]]) { - DataTypes.createArrayType(GeometryUDT) + DataTypes.createArrayType(GeometryUDT()) } else if (t =:= typeOf[Geography]) { - GeographyUDT + GeographyUDT() } else if (t =:= typeOf[Array[Geography]] || t =:= typeOf[java.util.List[Geography]]) { - DataTypes.createArrayType(GeographyUDT) + DataTypes.createArrayType(GeographyUDT()) } else if (InferredRasterExpression.isRasterType(t)) { InferredRasterExpression.rasterUDT } else if (InferredRasterExpression.isRasterArrayType(t)) { diff --git a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/Predicates.scala b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/Predicates.scala index 9e69ed04382..fff9f6eeef5 100644 --- a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/Predicates.scala +++ b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/Predicates.scala @@ -40,7 +40,7 @@ abstract class ST_Predicate override def nullable: Boolean = children.exists(_.nullable) - override def inputTypes: Seq[AbstractDataType] = Seq(GeometryUDT, GeometryUDT) + override def inputTypes: Seq[AbstractDataType] = Seq(GeometryUDT(), GeometryUDT()) override def dataType: DataType = BooleanType diff --git a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/collect/ST_Collect.scala b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/collect/ST_Collect.scala index 3b430cc67df..4892fb5a8b3 100644 --- a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/collect/ST_Collect.scala +++ b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/collect/ST_Collect.scala @@ -82,7 +82,7 @@ private[apache] case class ST_Collect(inputExpressions: Seq[Expression]) } } - override def dataType: DataType = GeometryUDT + override def dataType: DataType = GeometryUDT() override def children: Seq[Expression] = inputExpressions diff --git a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/raster/PixelFunctions.scala b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/raster/PixelFunctions.scala index f72cd1e3946..13849f74f48 100644 --- a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/raster/PixelFunctions.scala +++ b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/raster/PixelFunctions.scala @@ -59,7 +59,7 @@ private[apache] case class RS_PixelAsPoints(inputExpressions: Seq[Expression]) override def dataType: 
DataType = ArrayType( new StructType() - .add("geom", GeometryUDT) + .add("geom", GeometryUDT()) .add("value", DoubleType) .add("x", IntegerType) .add("y", IntegerType)) @@ -88,7 +88,7 @@ private[apache] case class RS_PixelAsPoints(inputExpressions: Seq[Expression]) copy(inputExpressions = newChildren) } - override def inputTypes: Seq[AbstractDataType] = Seq(RasterUDT, IntegerType) + override def inputTypes: Seq[AbstractDataType] = Seq(RasterUDT(), IntegerType) } private[apache] case class RS_PixelAsPolygon(inputExpressions: Seq[Expression]) @@ -107,7 +107,7 @@ private[apache] case class RS_PixelAsPolygons(inputExpressions: Seq[Expression]) override def dataType: DataType = ArrayType( new StructType() - .add("geom", GeometryUDT) + .add("geom", GeometryUDT()) .add("value", DoubleType) .add("x", IntegerType) .add("y", IntegerType)) @@ -137,7 +137,7 @@ private[apache] case class RS_PixelAsPolygons(inputExpressions: Seq[Expression]) copy(inputExpressions = newChildren) } - override def inputTypes: Seq[AbstractDataType] = Seq(RasterUDT, IntegerType) + override def inputTypes: Seq[AbstractDataType] = Seq(RasterUDT(), IntegerType) } private[apache] case class RS_PixelAsCentroid(inputExpressions: Seq[Expression]) @@ -156,7 +156,7 @@ private[apache] case class RS_PixelAsCentroids(inputExpressions: Seq[Expression] override def dataType: DataType = ArrayType( new StructType() - .add("geom", GeometryUDT) + .add("geom", GeometryUDT()) .add("value", DoubleType) .add("x", IntegerType) .add("y", IntegerType)) @@ -186,7 +186,7 @@ private[apache] case class RS_PixelAsCentroids(inputExpressions: Seq[Expression] copy(inputExpressions = newChildren) } - override def inputTypes: Seq[AbstractDataType] = Seq(RasterUDT, IntegerType) + override def inputTypes: Seq[AbstractDataType] = Seq(RasterUDT(), IntegerType) } private[apache] case class RS_Values(inputExpressions: Seq[Expression]) diff --git a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/raster/RasterConstructors.scala b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/raster/RasterConstructors.scala index 77d4e700b04..2523e893437 100644 --- a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/raster/RasterConstructors.scala +++ b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/raster/RasterConstructors.scala @@ -147,7 +147,7 @@ private[apache] case class RS_TileExplode(children: Seq[Expression]) new StructType() .add("x", IntegerType, nullable = false) .add("y", IntegerType, nullable = false) - .add("tile", RasterUDT, nullable = false) + .add("tile", RasterUDT(), nullable = false) } protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]) = { diff --git a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/raster/RasterFunctions.scala b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/raster/RasterFunctions.scala index b1fa43b5292..67d439b7d71 100644 --- a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/raster/RasterFunctions.scala +++ b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/raster/RasterFunctions.scala @@ -83,7 +83,7 @@ private[apache] case class RS_Metadata(inputExpressions: Seq[Expression]) copy(inputExpressions = newChildren) } - override def inputTypes: Seq[AbstractDataType] = Seq(RasterUDT) + override def inputTypes: Seq[AbstractDataType] = Seq(RasterUDT()) } private[apache] case class RS_SummaryStatsAll(inputExpressions: Seq[Expression]) @@ 
-139,13 +139,13 @@ private[apache] case class RS_SummaryStatsAll(inputExpressions: Seq[Expression]) override def inputTypes: Seq[AbstractDataType] = { if (inputExpressions.length == 1) { - Seq(RasterUDT) + Seq(RasterUDT()) } else if (inputExpressions.length == 2) { - Seq(RasterUDT, IntegerType) + Seq(RasterUDT(), IntegerType) } else if (inputExpressions.length == 3) { - Seq(RasterUDT, IntegerType, BooleanType) + Seq(RasterUDT(), IntegerType, BooleanType) } else { - Seq(RasterUDT) + Seq(RasterUDT()) } } } @@ -220,17 +220,17 @@ private[apache] case class RS_ZonalStatsAll(inputExpressions: Seq[Expression]) override def inputTypes: Seq[AbstractDataType] = { if (inputExpressions.length == 2) { - Seq(RasterUDT, GeometryUDT) + Seq(RasterUDT(), GeometryUDT()) } else if (inputExpressions.length == 3) { - Seq(RasterUDT, GeometryUDT, IntegerType) + Seq(RasterUDT(), GeometryUDT(), IntegerType) } else if (inputExpressions.length == 4) { - Seq(RasterUDT, GeometryUDT, IntegerType, BooleanType) + Seq(RasterUDT(), GeometryUDT(), IntegerType, BooleanType) } else if (inputExpressions.length == 5) { - Seq(RasterUDT, GeometryUDT, IntegerType, BooleanType, BooleanType) + Seq(RasterUDT(), GeometryUDT(), IntegerType, BooleanType, BooleanType) } else if (inputExpressions.length >= 6) { - Seq(RasterUDT, GeometryUDT, IntegerType, BooleanType, BooleanType, BooleanType) + Seq(RasterUDT(), GeometryUDT(), IntegerType, BooleanType, BooleanType, BooleanType) } else { - Seq(RasterUDT, GeometryUDT) + Seq(RasterUDT(), GeometryUDT()) } } } @@ -276,5 +276,5 @@ private[apache] case class RS_GeoTransform(inputExpressions: Seq[Expression]) copy(inputExpressions = newChildren) } - override def inputTypes: Seq[AbstractDataType] = Seq(RasterUDT) + override def inputTypes: Seq[AbstractDataType] = Seq(RasterUDT()) } diff --git a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/raster/RasterPredicates.scala b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/raster/RasterPredicates.scala index e247a2825b5..35aaa276c04 100644 --- a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/raster/RasterPredicates.scala +++ b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/raster/RasterPredicates.scala @@ -47,9 +47,9 @@ abstract class RS_Predicate val leftType = inputExpressions.head.dataType val rightType = inputExpressions(1).dataType (leftType, rightType) match { - case (_: RasterUDT, _: GeometryUDT) => Seq(RasterUDT, GeometryUDT) - case (_: GeometryUDT, _: RasterUDT) => Seq(GeometryUDT, RasterUDT) - case (_: RasterUDT, _: RasterUDT) => Seq(RasterUDT, RasterUDT) + case (_: RasterUDT, _: GeometryUDT) => Seq(RasterUDT(), GeometryUDT()) + case (_: GeometryUDT, _: RasterUDT) => Seq(GeometryUDT(), RasterUDT()) + case (_: RasterUDT, _: RasterUDT) => Seq(RasterUDT(), RasterUDT()) case _ => throw new IllegalArgumentException(s"Unsupported input types: $leftType, $rightType") } diff --git a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/io/geojson/GeoJSONFileFormat.scala b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/io/geojson/GeoJSONFileFormat.scala index ca588c33bea..67dafb4e45b 100644 --- a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/io/geojson/GeoJSONFileFormat.scala +++ b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/io/geojson/GeoJSONFileFormat.scala @@ -69,7 +69,7 @@ class GeoJSONFileFormat extends TextBasedFileFormat with DataSourceRegister { fullSchemaOption.map { fullSchema => // Replace 
'geometry' field type with GeometryUDT - val newFields = GeoJSONUtils.updateGeometrySchema(fullSchema, GeometryUDT) + val newFields = GeoJSONUtils.updateGeometrySchema(fullSchema, GeometryUDT()) StructType(newFields) } } @@ -131,7 +131,7 @@ class GeoJSONFileFormat extends TextBasedFileFormat with DataSourceRegister { geometryColumnName.split('.'), resolver = SQLConf.get.resolver) match { case Some(StructField(_, dataType, _, _)) => - if (!dataType.acceptsType(GeometryUDT)) { + if (!dataType.acceptsType(GeometryUDT())) { throw new IllegalArgumentException(s"$geometryColumnName is not a geometry column") } case None => diff --git a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/io/stac/StacDataSource.scala b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/io/stac/StacDataSource.scala index 154adf8eca4..0aad6e3d48b 100644 --- a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/io/stac/StacDataSource.scala +++ b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/io/stac/StacDataSource.scala @@ -75,7 +75,7 @@ class StacDataSource() extends TableProvider with DataSourceRegister { // Check if the schema is already cached val fullSchema = schemaCache.computeIfAbsent(optsMap, _ => inferStacSchema(optsMap)) - val updatedGeometrySchema = GeoJSONUtils.updateGeometrySchema(fullSchema, GeometryUDT) + val updatedGeometrySchema = GeoJSONUtils.updateGeometrySchema(fullSchema, GeometryUDT()) updatePropertiesPromotedSchema(updatedGeometrySchema) } diff --git a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/io/stac/StacTable.scala b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/io/stac/StacTable.scala index ca5f32663ca..1da2af3dd62 100644 --- a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/io/stac/StacTable.scala +++ b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/io/stac/StacTable.scala @@ -67,7 +67,7 @@ class StacTable( override def schema(): StructType = { // Check if the schema is already cached val fullSchema = schemaCache.computeIfAbsent(opts, _ => inferStacSchema(opts)) - val updatedGeometrySchema = GeoJSONUtils.updateGeometrySchema(fullSchema, GeometryUDT) + val updatedGeometrySchema = GeoJSONUtils.updateGeometrySchema(fullSchema, GeometryUDT()) updatePropertiesPromotedSchema(updatedGeometrySchema) } diff --git a/spark/common/src/test/scala/org/apache/sedona/sql/KnnJoinSuite.scala b/spark/common/src/test/scala/org/apache/sedona/sql/KnnJoinSuite.scala index 50696337f7f..4ca17598d30 100644 --- a/spark/common/src/test/scala/org/apache/sedona/sql/KnnJoinSuite.scala +++ b/spark/common/src/test/scala/org/apache/sedona/sql/KnnJoinSuite.scala @@ -550,7 +550,7 @@ class KnnJoinSuite extends TestBaseScala with TableDrivenPropertyChecks { val emptyRdd = sparkSession.sparkContext.emptyRDD[Row] val emptyDf = sparkSession.createDataFrame( emptyRdd, - StructType(Seq(StructField("id", IntegerType), StructField("geom", GeometryUDT)))) + StructType(Seq(StructField("id", IntegerType), StructField("geom", GeometryUDT())))) emptyDf.createOrReplaceTempView("EMPTYTABLE") df1.createOrReplaceTempView("df1") diff --git a/spark/common/src/test/scala/org/apache/sedona/sql/PreserveSRIDSuite.scala b/spark/common/src/test/scala/org/apache/sedona/sql/PreserveSRIDSuite.scala index 69b1641cc40..749a1be48a2 100644 --- a/spark/common/src/test/scala/org/apache/sedona/sql/PreserveSRIDSuite.scala +++ b/spark/common/src/test/scala/org/apache/sedona/sql/PreserveSRIDSuite.scala @@ -144,12 +144,12 @@ class PreserveSRIDSuite extends TestBaseScala 
with TableDrivenPropertyChecks { val schema = StructType( Seq( - StructField("geom1", GeometryUDT), - StructField("geom2", GeometryUDT), - StructField("geom3", GeometryUDT), - StructField("geom4", GeometryUDT), - StructField("geom5", GeometryUDT), - StructField("geom6", GeometryUDT))) + StructField("geom1", GeometryUDT()), + StructField("geom2", GeometryUDT()), + StructField("geom3", GeometryUDT()), + StructField("geom4", GeometryUDT()), + StructField("geom5", GeometryUDT()), + StructField("geom6", GeometryUDT()))) val geom1 = Constructors.geomFromWKT("POLYGON ((0 0, 1 0, 0.5 0.5, 1 1, 0 1, 0 0))", 1000) val geom2 = Constructors.geomFromWKT("MULTILINESTRING ((0 0, 0 1), (0 1, 1 1), (1 1, 1 0))", 1000) diff --git a/spark/common/src/test/scala/org/apache/sedona/sql/SpatialJoinSuite.scala b/spark/common/src/test/scala/org/apache/sedona/sql/SpatialJoinSuite.scala index 9f4fd30f95b..56bd3bf23f1 100644 --- a/spark/common/src/test/scala/org/apache/sedona/sql/SpatialJoinSuite.scala +++ b/spark/common/src/test/scala/org/apache/sedona/sql/SpatialJoinSuite.scala @@ -308,7 +308,7 @@ class SpatialJoinSuite extends TestBaseScala with TableDrivenPropertyChecks { val emptyRdd = sparkSession.sparkContext.emptyRDD[Row] val emptyDf = sparkSession.createDataFrame( emptyRdd, - StructType(Seq(StructField("id", IntegerType), StructField("geom", GeometryUDT)))) + StructType(Seq(StructField("id", IntegerType), StructField("geom", GeometryUDT())))) df1.createOrReplaceTempView("df1") df2.createOrReplaceTempView("df2") emptyDf.createOrReplaceTempView("dfEmpty") diff --git a/spark/common/src/test/scala/org/apache/sedona/sql/adapterTestScala.scala b/spark/common/src/test/scala/org/apache/sedona/sql/adapterTestScala.scala index 0d267e62567..2ca9f359543 100644 --- a/spark/common/src/test/scala/org/apache/sedona/sql/adapterTestScala.scala +++ b/spark/common/src/test/scala/org/apache/sedona/sql/adapterTestScala.scala @@ -319,7 +319,7 @@ class adapterTestScala extends TestBaseScala with GivenWhenThen { // Tweak the column names to camelCase to ensure it also renames val schema = StructType( Array( - StructField("leftGeometry", GeometryUDT, nullable = true), + StructField("leftGeometry", GeometryUDT(), nullable = true), StructField("exampleText", StringType, nullable = true), StructField("exampleDouble", DoubleType, nullable = true), StructField("exampleInt", IntegerType, nullable = true))) @@ -408,7 +408,7 @@ class adapterTestScala extends TestBaseScala with GivenWhenThen { // Tweak the column names to camelCase to ensure it also renames val schema = StructType( Array( - StructField("leftGeometry", GeometryUDT, nullable = true), + StructField("leftGeometry", GeometryUDT(), nullable = true), StructField("exampleText", StringType, nullable = true), StructField("exampleFloat", FloatType, nullable = true), StructField("exampleDouble", DoubleType, nullable = true), @@ -424,7 +424,7 @@ class adapterTestScala extends TestBaseScala with GivenWhenThen { StructField("structText", StringType, nullable = true), StructField("structInt", IntegerType, nullable = true), StructField("structBool", BooleanType, nullable = true)))), - StructField("rightGeometry", GeometryUDT, nullable = true), + StructField("rightGeometry", GeometryUDT(), nullable = true), // We have to include a column for right user data (even though there is none) // since there is no way to distinguish between no data and nullable data StructField("rightUserData", StringType, nullable = true))) @@ -493,11 +493,11 @@ class adapterTestScala extends TestBaseScala with 
GivenWhenThen { // Convert to DataFrame val schema = StructType( Array( - StructField("leftgeometry", GeometryUDT, nullable = true), + StructField("leftgeometry", GeometryUDT(), nullable = true), StructField("exampletext", StringType, nullable = true), StructField("exampledouble", StringType, nullable = true), StructField("exampleint", StringType, nullable = true), - StructField("rightgeometry", GeometryUDT, nullable = true), + StructField("rightgeometry", GeometryUDT(), nullable = true), StructField("userdata", StringType, nullable = true))) val joinResultDf = Adapter.toDf(joinResultPairRDD, schema, sparkSession) val resultWithoutSchema = Adapter.toDf( diff --git a/spark/common/src/test/scala/org/apache/sedona/sql/geoparquetIOTests.scala b/spark/common/src/test/scala/org/apache/sedona/sql/geoparquetIOTests.scala index c3ef8dd89ef..7628c9a6daa 100644 --- a/spark/common/src/test/scala/org/apache/sedona/sql/geoparquetIOTests.scala +++ b/spark/common/src/test/scala/org/apache/sedona/sql/geoparquetIOTests.scala @@ -169,7 +169,7 @@ class geoparquetIOTests extends TestBaseScala with BeforeAndAfterAll { val geomTypes = (geo \ "columns" \ "geometry" \ "geometry_types").extract[Seq[String]] assert(geomTypes.nonEmpty) val sparkSqlRowMetadata = metadata.get(ParquetReadSupport.SPARK_METADATA_KEY) - assert(!sparkSqlRowMetadata.contains("GeometryUDT")) + assert(!sparkSqlRowMetadata.contains("GeometryUDT()")) } } it("GEOPARQUET Test example-1.1.0.parquet") { @@ -206,8 +206,8 @@ class geoparquetIOTests extends TestBaseScala with BeforeAndAfterAll { val schema = StructType( Seq( StructField("id", IntegerType, nullable = false), - StructField("g0", GeometryUDT, nullable = false), - StructField("g1", GeometryUDT, nullable = false))) + StructField("g0", GeometryUDT(), nullable = false), + StructField("g1", GeometryUDT(), nullable = false))) val df = sparkSession.createDataFrame(testData.asJava, schema).repartition(1) val geoParquetSavePath = geoparquetoutputlocation + "/multi_geoms.parquet" df.write.format("geoparquet").mode("overwrite").save(geoParquetSavePath) @@ -241,7 +241,7 @@ class geoparquetIOTests extends TestBaseScala with BeforeAndAfterAll { val schema = StructType( Seq( StructField("id", IntegerType, nullable = false), - StructField("g", GeometryUDT, nullable = false))) + StructField("g", GeometryUDT(), nullable = false))) val df = sparkSession.createDataFrame(Collections.emptyList[Row](), schema) val geoParquetSavePath = geoparquetoutputlocation + "/empty.parquet" df.write.format("geoparquet").mode("overwrite").save(geoParquetSavePath) @@ -262,7 +262,7 @@ class geoparquetIOTests extends TestBaseScala with BeforeAndAfterAll { val schema = StructType( Seq( StructField("id", IntegerType, nullable = false), - StructField("geom_column", GeometryUDT, nullable = false))) + StructField("geom_column", GeometryUDT(), nullable = false))) val df = sparkSession.createDataFrame(Collections.emptyList[Row](), schema) val geoParquetSavePath = geoparquetoutputlocation + "/snake_case_column_name.parquet" df.write.format("geoparquet").mode("overwrite").save(geoParquetSavePath) @@ -277,7 +277,7 @@ class geoparquetIOTests extends TestBaseScala with BeforeAndAfterAll { val schema = StructType( Seq( StructField("id", IntegerType, nullable = false), - StructField("geomColumn", GeometryUDT, nullable = false))) + StructField("geomColumn", GeometryUDT(), nullable = false))) val df = sparkSession.createDataFrame(Collections.emptyList[Row](), schema) val geoParquetSavePath = geoparquetoutputlocation + 
"/camel_case_column_name.parquet" df.write.format("geoparquet").mode("overwrite").save(geoParquetSavePath) @@ -400,8 +400,8 @@ class geoparquetIOTests extends TestBaseScala with BeforeAndAfterAll { val schema = StructType( Seq( StructField("id", IntegerType, nullable = false), - StructField("g0", GeometryUDT, nullable = false), - StructField("g1", GeometryUDT, nullable = false))) + StructField("g0", GeometryUDT(), nullable = false), + StructField("g1", GeometryUDT(), nullable = false))) val df = sparkSession.createDataFrame(testData.asJava, schema).repartition(1) val projjson0 = diff --git a/spark/spark-3.4/src/main/scala/org/apache/sedona/sql/datasources/shapefile/ShapefileUtils.scala b/spark/spark-3.4/src/main/scala/org/apache/sedona/sql/datasources/shapefile/ShapefileUtils.scala index 12238870be1..efd6098d518 100644 --- a/spark/spark-3.4/src/main/scala/org/apache/sedona/sql/datasources/shapefile/ShapefileUtils.scala +++ b/spark/spark-3.4/src/main/scala/org/apache/sedona/sql/datasources/shapefile/ShapefileUtils.scala @@ -141,9 +141,9 @@ object ShapefileUtils { throw new IllegalArgumentException(s"geometry.name and key.name cannot be the same") } StructType( - Seq(StructField(options.geometryFieldName, GeometryUDT), StructField(name, LongType))) + Seq(StructField(options.geometryFieldName, GeometryUDT()), StructField(name, LongType))) case _ => - StructType(StructField(options.geometryFieldName, GeometryUDT) :: Nil) + StructType(StructField(options.geometryFieldName, GeometryUDT()) :: Nil) } } diff --git a/spark/spark-3.4/src/main/scala/org/apache/sedona/sql/parser/SedonaSqlAstBuilder.scala b/spark/spark-3.4/src/main/scala/org/apache/sedona/sql/parser/SedonaSqlAstBuilder.scala index 2bdd92bd64a..b56ed11c875 100644 --- a/spark/spark-3.4/src/main/scala/org/apache/sedona/sql/parser/SedonaSqlAstBuilder.scala +++ b/spark/spark-3.4/src/main/scala/org/apache/sedona/sql/parser/SedonaSqlAstBuilder.scala @@ -32,7 +32,7 @@ class SedonaSqlAstBuilder extends SparkSqlAstBuilder { */ override def visitPrimitiveDataType(ctx: PrimitiveDataTypeContext): DataType = { ctx.getText.toUpperCase() match { - case "GEOMETRY" => GeometryUDT + case "GEOMETRY" => GeometryUDT() case _ => super.visitPrimitiveDataType(ctx) } } diff --git a/spark/spark-3.4/src/test/scala/org/apache/sedona/sql/GeoPackageReaderTest.scala b/spark/spark-3.4/src/test/scala/org/apache/sedona/sql/GeoPackageReaderTest.scala index 1e4b071361b..166d8c48dba 100644 --- a/spark/spark-3.4/src/test/scala/org/apache/sedona/sql/GeoPackageReaderTest.scala +++ b/spark/spark-3.4/src/test/scala/org/apache/sedona/sql/GeoPackageReaderTest.scala @@ -44,7 +44,7 @@ class GeoPackageReaderTest extends TestBaseScala with Matchers { val expectedFeatureSchema = StructType( Seq( StructField("id", IntegerType, true), - StructField("geometry", GeometryUDT, true), + StructField("geometry", GeometryUDT(), true), StructField("text", StringType, true), StructField("real", DoubleType, true), StructField("boolean", BooleanType, true), diff --git a/spark/spark-3.4/src/test/scala/org/apache/sedona/sql/GeoParquetMetadataTests.scala b/spark/spark-3.4/src/test/scala/org/apache/sedona/sql/GeoParquetMetadataTests.scala index 421890c7001..01306c1b452 100644 --- a/spark/spark-3.4/src/test/scala/org/apache/sedona/sql/GeoParquetMetadataTests.scala +++ b/spark/spark-3.4/src/test/scala/org/apache/sedona/sql/GeoParquetMetadataTests.scala @@ -120,8 +120,8 @@ class GeoParquetMetadataTests extends TestBaseScala with BeforeAndAfterAll { val schema = StructType( Seq( StructField("id", 
IntegerType, nullable = false), - StructField("geom_column_1", GeometryUDT, nullable = false), - StructField("geomColumn2", GeometryUDT, nullable = false))) + StructField("geom_column_1", GeometryUDT(), nullable = false), + StructField("geomColumn2", GeometryUDT(), nullable = false))) val df = sparkSession.createDataFrame(Collections.emptyList[Row](), schema) val geoParquetSavePath = geoparquetoutputlocation + "/gp_column_name_styles.parquet" df.write.format("geoparquet").mode("overwrite").save(geoParquetSavePath) diff --git a/spark/spark-3.4/src/test/scala/org/apache/sedona/sql/ShapefileTests.scala b/spark/spark-3.4/src/test/scala/org/apache/sedona/sql/ShapefileTests.scala index bb53131475f..c2e88c469be 100644 --- a/spark/spark-3.4/src/test/scala/org/apache/sedona/sql/ShapefileTests.scala +++ b/spark/spark-3.4/src/test/scala/org/apache/sedona/sql/ShapefileTests.scala @@ -698,7 +698,7 @@ class ShapefileTests extends TestBaseScala with BeforeAndAfterAll { Seq( StructField("osm_id", StringType), StructField("code2", LongType), - StructField("geometry", GeometryUDT))) + StructField("geometry", GeometryUDT()))) val shapefileDf = sparkSession.read .format("shapefile") .schema(customSchema) diff --git a/spark/spark-3.5/src/main/scala/org/apache/sedona/sql/datasources/shapefile/ShapefileUtils.scala b/spark/spark-3.5/src/main/scala/org/apache/sedona/sql/datasources/shapefile/ShapefileUtils.scala index 04ac3bdff9b..fd6d1e83827 100644 --- a/spark/spark-3.5/src/main/scala/org/apache/sedona/sql/datasources/shapefile/ShapefileUtils.scala +++ b/spark/spark-3.5/src/main/scala/org/apache/sedona/sql/datasources/shapefile/ShapefileUtils.scala @@ -141,9 +141,9 @@ object ShapefileUtils { throw new IllegalArgumentException(s"geometry.name and key.name cannot be the same") } StructType( - Seq(StructField(options.geometryFieldName, GeometryUDT), StructField(name, LongType))) + Seq(StructField(options.geometryFieldName, GeometryUDT()), StructField(name, LongType))) case _ => - StructType(StructField(options.geometryFieldName, GeometryUDT) :: Nil) + StructType(StructField(options.geometryFieldName, GeometryUDT()) :: Nil) } } diff --git a/spark/spark-3.5/src/main/scala/org/apache/sedona/sql/parser/SedonaSqlAstBuilder.scala b/spark/spark-3.5/src/main/scala/org/apache/sedona/sql/parser/SedonaSqlAstBuilder.scala index 2bdd92bd64a..b56ed11c875 100644 --- a/spark/spark-3.5/src/main/scala/org/apache/sedona/sql/parser/SedonaSqlAstBuilder.scala +++ b/spark/spark-3.5/src/main/scala/org/apache/sedona/sql/parser/SedonaSqlAstBuilder.scala @@ -32,7 +32,7 @@ class SedonaSqlAstBuilder extends SparkSqlAstBuilder { */ override def visitPrimitiveDataType(ctx: PrimitiveDataTypeContext): DataType = { ctx.getText.toUpperCase() match { - case "GEOMETRY" => GeometryUDT + case "GEOMETRY" => GeometryUDT() case _ => super.visitPrimitiveDataType(ctx) } } diff --git a/spark/spark-3.5/src/test/scala/org/apache/sedona/sql/GeoPackageReaderTest.scala b/spark/spark-3.5/src/test/scala/org/apache/sedona/sql/GeoPackageReaderTest.scala index 6d9f41bf4e3..66fa147bc58 100644 --- a/spark/spark-3.5/src/test/scala/org/apache/sedona/sql/GeoPackageReaderTest.scala +++ b/spark/spark-3.5/src/test/scala/org/apache/sedona/sql/GeoPackageReaderTest.scala @@ -44,7 +44,7 @@ class GeoPackageReaderTest extends TestBaseScala with Matchers { val expectedFeatureSchema = StructType( Seq( StructField("id", IntegerType, true), - StructField("geometry", GeometryUDT, true), + StructField("geometry", GeometryUDT(), true), StructField("text", StringType, true), 
StructField("real", DoubleType, true), StructField("boolean", BooleanType, true), diff --git a/spark/spark-3.5/src/test/scala/org/apache/sedona/sql/GeoParquetMetadataTests.scala b/spark/spark-3.5/src/test/scala/org/apache/sedona/sql/GeoParquetMetadataTests.scala index 421890c7001..01306c1b452 100644 --- a/spark/spark-3.5/src/test/scala/org/apache/sedona/sql/GeoParquetMetadataTests.scala +++ b/spark/spark-3.5/src/test/scala/org/apache/sedona/sql/GeoParquetMetadataTests.scala @@ -120,8 +120,8 @@ class GeoParquetMetadataTests extends TestBaseScala with BeforeAndAfterAll { val schema = StructType( Seq( StructField("id", IntegerType, nullable = false), - StructField("geom_column_1", GeometryUDT, nullable = false), - StructField("geomColumn2", GeometryUDT, nullable = false))) + StructField("geom_column_1", GeometryUDT(), nullable = false), + StructField("geomColumn2", GeometryUDT(), nullable = false))) val df = sparkSession.createDataFrame(Collections.emptyList[Row](), schema) val geoParquetSavePath = geoparquetoutputlocation + "/gp_column_name_styles.parquet" df.write.format("geoparquet").mode("overwrite").save(geoParquetSavePath) diff --git a/spark/spark-3.5/src/test/scala/org/apache/sedona/sql/ShapefileTests.scala b/spark/spark-3.5/src/test/scala/org/apache/sedona/sql/ShapefileTests.scala index a0cadc89787..275bc3282f9 100644 --- a/spark/spark-3.5/src/test/scala/org/apache/sedona/sql/ShapefileTests.scala +++ b/spark/spark-3.5/src/test/scala/org/apache/sedona/sql/ShapefileTests.scala @@ -710,7 +710,7 @@ class ShapefileTests extends TestBaseScala with BeforeAndAfterAll { Seq( StructField("osm_id", StringType), StructField("code2", LongType), - StructField("geometry", GeometryUDT))) + StructField("geometry", GeometryUDT()))) val shapefileDf = sparkSession.read .format("shapefile") .schema(customSchema) diff --git a/spark/spark-3.5/src/test/scala/org/apache/spark/sql/udf/TestScalarPandasUDF.scala b/spark/spark-3.5/src/test/scala/org/apache/spark/sql/udf/TestScalarPandasUDF.scala index c0a2d8f260d..62733169288 100644 --- a/spark/spark-3.5/src/test/scala/org/apache/spark/sql/udf/TestScalarPandasUDF.scala +++ b/spark/spark-3.5/src/test/scala/org/apache/spark/sql/udf/TestScalarPandasUDF.scala @@ -116,7 +116,7 @@ object ScalarUDF { pythonVer = pythonVer, broadcastVars = List.empty[Broadcast[PythonBroadcast]].asJava, accumulator = null), - dataType = GeometryUDT, + dataType = GeometryUDT(), pythonEvalType = UDF.PythonEvalType.SQL_SCALAR_SEDONA_UDF, udfDeterministic = true) } diff --git a/spark/spark-4.0/src/main/scala/org/apache/sedona/sql/datasources/shapefile/ShapefileUtils.scala b/spark/spark-4.0/src/main/scala/org/apache/sedona/sql/datasources/shapefile/ShapefileUtils.scala index 04ac3bdff9b..fd6d1e83827 100644 --- a/spark/spark-4.0/src/main/scala/org/apache/sedona/sql/datasources/shapefile/ShapefileUtils.scala +++ b/spark/spark-4.0/src/main/scala/org/apache/sedona/sql/datasources/shapefile/ShapefileUtils.scala @@ -141,9 +141,9 @@ object ShapefileUtils { throw new IllegalArgumentException(s"geometry.name and key.name cannot be the same") } StructType( - Seq(StructField(options.geometryFieldName, GeometryUDT), StructField(name, LongType))) + Seq(StructField(options.geometryFieldName, GeometryUDT()), StructField(name, LongType))) case _ => - StructType(StructField(options.geometryFieldName, GeometryUDT) :: Nil) + StructType(StructField(options.geometryFieldName, GeometryUDT()) :: Nil) } } diff --git a/spark/spark-4.0/src/main/scala/org/apache/sedona/sql/parser/SedonaSqlAstBuilder.scala 
b/spark/spark-4.0/src/main/scala/org/apache/sedona/sql/parser/SedonaSqlAstBuilder.scala index 2bdd92bd64a..b56ed11c875 100644 --- a/spark/spark-4.0/src/main/scala/org/apache/sedona/sql/parser/SedonaSqlAstBuilder.scala +++ b/spark/spark-4.0/src/main/scala/org/apache/sedona/sql/parser/SedonaSqlAstBuilder.scala @@ -32,7 +32,7 @@ class SedonaSqlAstBuilder extends SparkSqlAstBuilder { */ override def visitPrimitiveDataType(ctx: PrimitiveDataTypeContext): DataType = { ctx.getText.toUpperCase() match { - case "GEOMETRY" => GeometryUDT + case "GEOMETRY" => GeometryUDT() case _ => super.visitPrimitiveDataType(ctx) } } diff --git a/spark/spark-4.0/src/test/scala/org/apache/sedona/sql/GeoPackageReaderTest.scala b/spark/spark-4.0/src/test/scala/org/apache/sedona/sql/GeoPackageReaderTest.scala index 6d9f41bf4e3..66fa147bc58 100644 --- a/spark/spark-4.0/src/test/scala/org/apache/sedona/sql/GeoPackageReaderTest.scala +++ b/spark/spark-4.0/src/test/scala/org/apache/sedona/sql/GeoPackageReaderTest.scala @@ -44,7 +44,7 @@ class GeoPackageReaderTest extends TestBaseScala with Matchers { val expectedFeatureSchema = StructType( Seq( StructField("id", IntegerType, true), - StructField("geometry", GeometryUDT, true), + StructField("geometry", GeometryUDT(), true), StructField("text", StringType, true), StructField("real", DoubleType, true), StructField("boolean", BooleanType, true), diff --git a/spark/spark-4.0/src/test/scala/org/apache/sedona/sql/GeoParquetMetadataTests.scala b/spark/spark-4.0/src/test/scala/org/apache/sedona/sql/GeoParquetMetadataTests.scala index 421890c7001..01306c1b452 100644 --- a/spark/spark-4.0/src/test/scala/org/apache/sedona/sql/GeoParquetMetadataTests.scala +++ b/spark/spark-4.0/src/test/scala/org/apache/sedona/sql/GeoParquetMetadataTests.scala @@ -120,8 +120,8 @@ class GeoParquetMetadataTests extends TestBaseScala with BeforeAndAfterAll { val schema = StructType( Seq( StructField("id", IntegerType, nullable = false), - StructField("geom_column_1", GeometryUDT, nullable = false), - StructField("geomColumn2", GeometryUDT, nullable = false))) + StructField("geom_column_1", GeometryUDT(), nullable = false), + StructField("geomColumn2", GeometryUDT(), nullable = false))) val df = sparkSession.createDataFrame(Collections.emptyList[Row](), schema) val geoParquetSavePath = geoparquetoutputlocation + "/gp_column_name_styles.parquet" df.write.format("geoparquet").mode("overwrite").save(geoParquetSavePath) diff --git a/spark/spark-4.0/src/test/scala/org/apache/sedona/sql/ShapefileTests.scala b/spark/spark-4.0/src/test/scala/org/apache/sedona/sql/ShapefileTests.scala index a0cadc89787..275bc3282f9 100644 --- a/spark/spark-4.0/src/test/scala/org/apache/sedona/sql/ShapefileTests.scala +++ b/spark/spark-4.0/src/test/scala/org/apache/sedona/sql/ShapefileTests.scala @@ -710,7 +710,7 @@ class ShapefileTests extends TestBaseScala with BeforeAndAfterAll { Seq( StructField("osm_id", StringType), StructField("code2", LongType), - StructField("geometry", GeometryUDT))) + StructField("geometry", GeometryUDT()))) val shapefileDf = sparkSession.read .format("shapefile") .schema(customSchema) diff --git a/spark/spark-4.0/src/test/scala/org/apache/spark/sql/udf/TestScalarPandasUDF.scala b/spark/spark-4.0/src/test/scala/org/apache/spark/sql/udf/TestScalarPandasUDF.scala index c0a2d8f260d..62733169288 100644 --- a/spark/spark-4.0/src/test/scala/org/apache/spark/sql/udf/TestScalarPandasUDF.scala +++ b/spark/spark-4.0/src/test/scala/org/apache/spark/sql/udf/TestScalarPandasUDF.scala @@ -116,7 +116,7 @@ object 
ScalarUDF { pythonVer = pythonVer, broadcastVars = List.empty[Broadcast[PythonBroadcast]].asJava, accumulator = null), - dataType = GeometryUDT, + dataType = GeometryUDT(), pythonEvalType = UDF.PythonEvalType.SQL_SCALAR_SEDONA_UDF, udfDeterministic = true) } diff --git a/spark/spark-4.1/src/main/scala/org/apache/sedona/sql/datasources/shapefile/ShapefileUtils.scala b/spark/spark-4.1/src/main/scala/org/apache/sedona/sql/datasources/shapefile/ShapefileUtils.scala index 04ac3bdff9b..fd6d1e83827 100644 --- a/spark/spark-4.1/src/main/scala/org/apache/sedona/sql/datasources/shapefile/ShapefileUtils.scala +++ b/spark/spark-4.1/src/main/scala/org/apache/sedona/sql/datasources/shapefile/ShapefileUtils.scala @@ -141,9 +141,9 @@ object ShapefileUtils { throw new IllegalArgumentException(s"geometry.name and key.name cannot be the same") } StructType( - Seq(StructField(options.geometryFieldName, GeometryUDT), StructField(name, LongType))) + Seq(StructField(options.geometryFieldName, GeometryUDT()), StructField(name, LongType))) case _ => - StructType(StructField(options.geometryFieldName, GeometryUDT) :: Nil) + StructType(StructField(options.geometryFieldName, GeometryUDT()) :: Nil) } } diff --git a/spark/spark-4.1/src/main/scala/org/apache/sedona/sql/parser/SedonaSqlAstBuilder.scala b/spark/spark-4.1/src/main/scala/org/apache/sedona/sql/parser/SedonaSqlAstBuilder.scala index 2bdd92bd64a..b56ed11c875 100644 --- a/spark/spark-4.1/src/main/scala/org/apache/sedona/sql/parser/SedonaSqlAstBuilder.scala +++ b/spark/spark-4.1/src/main/scala/org/apache/sedona/sql/parser/SedonaSqlAstBuilder.scala @@ -32,7 +32,7 @@ class SedonaSqlAstBuilder extends SparkSqlAstBuilder { */ override def visitPrimitiveDataType(ctx: PrimitiveDataTypeContext): DataType = { ctx.getText.toUpperCase() match { - case "GEOMETRY" => GeometryUDT + case "GEOMETRY" => GeometryUDT() case _ => super.visitPrimitiveDataType(ctx) } } diff --git a/spark/spark-4.1/src/test/scala/org/apache/sedona/sql/GeoPackageReaderTest.scala b/spark/spark-4.1/src/test/scala/org/apache/sedona/sql/GeoPackageReaderTest.scala index 6d9f41bf4e3..66fa147bc58 100644 --- a/spark/spark-4.1/src/test/scala/org/apache/sedona/sql/GeoPackageReaderTest.scala +++ b/spark/spark-4.1/src/test/scala/org/apache/sedona/sql/GeoPackageReaderTest.scala @@ -44,7 +44,7 @@ class GeoPackageReaderTest extends TestBaseScala with Matchers { val expectedFeatureSchema = StructType( Seq( StructField("id", IntegerType, true), - StructField("geometry", GeometryUDT, true), + StructField("geometry", GeometryUDT(), true), StructField("text", StringType, true), StructField("real", DoubleType, true), StructField("boolean", BooleanType, true), diff --git a/spark/spark-4.1/src/test/scala/org/apache/sedona/sql/GeoParquetMetadataTests.scala b/spark/spark-4.1/src/test/scala/org/apache/sedona/sql/GeoParquetMetadataTests.scala index 421890c7001..01306c1b452 100644 --- a/spark/spark-4.1/src/test/scala/org/apache/sedona/sql/GeoParquetMetadataTests.scala +++ b/spark/spark-4.1/src/test/scala/org/apache/sedona/sql/GeoParquetMetadataTests.scala @@ -120,8 +120,8 @@ class GeoParquetMetadataTests extends TestBaseScala with BeforeAndAfterAll { val schema = StructType( Seq( StructField("id", IntegerType, nullable = false), - StructField("geom_column_1", GeometryUDT, nullable = false), - StructField("geomColumn2", GeometryUDT, nullable = false))) + StructField("geom_column_1", GeometryUDT(), nullable = false), + StructField("geomColumn2", GeometryUDT(), nullable = false))) val df = 
sparkSession.createDataFrame(Collections.emptyList[Row](), schema) val geoParquetSavePath = geoparquetoutputlocation + "/gp_column_name_styles.parquet" df.write.format("geoparquet").mode("overwrite").save(geoParquetSavePath) diff --git a/spark/spark-4.1/src/test/scala/org/apache/sedona/sql/ShapefileTests.scala b/spark/spark-4.1/src/test/scala/org/apache/sedona/sql/ShapefileTests.scala index a0cadc89787..275bc3282f9 100644 --- a/spark/spark-4.1/src/test/scala/org/apache/sedona/sql/ShapefileTests.scala +++ b/spark/spark-4.1/src/test/scala/org/apache/sedona/sql/ShapefileTests.scala @@ -710,7 +710,7 @@ class ShapefileTests extends TestBaseScala with BeforeAndAfterAll { Seq( StructField("osm_id", StringType), StructField("code2", LongType), - StructField("geometry", GeometryUDT))) + StructField("geometry", GeometryUDT()))) val shapefileDf = sparkSession.read .format("shapefile") .schema(customSchema) diff --git a/spark/spark-4.1/src/test/scala/org/apache/spark/sql/udf/TestScalarPandasUDF.scala b/spark/spark-4.1/src/test/scala/org/apache/spark/sql/udf/TestScalarPandasUDF.scala index c0a2d8f260d..62733169288 100644 --- a/spark/spark-4.1/src/test/scala/org/apache/spark/sql/udf/TestScalarPandasUDF.scala +++ b/spark/spark-4.1/src/test/scala/org/apache/spark/sql/udf/TestScalarPandasUDF.scala @@ -116,7 +116,7 @@ object ScalarUDF { pythonVer = pythonVer, broadcastVars = List.empty[Broadcast[PythonBroadcast]].asJava, accumulator = null), - dataType = GeometryUDT, + dataType = GeometryUDT(), pythonEvalType = UDF.PythonEvalType.SQL_SCALAR_SEDONA_UDF, udfDeterministic = true) } From 07ee40d9adb8200fbaf39a0f20b21ec3c13764a8 Mon Sep 17 00:00:00 2001 From: Jia Yu Date: Fri, 13 Feb 2026 09:08:44 -0800 Subject: [PATCH 05/14] fix: split pyspark dep with python_version markers to avoid resolver failure on Python <3.10 --- python/pyproject.toml | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/python/pyproject.toml b/python/pyproject.toml index 440b42642e3..5d2237991e9 100644 --- a/python/pyproject.toml +++ b/python/pyproject.toml @@ -37,13 +37,17 @@ dependencies = [ ] [project.optional-dependencies] -spark = ["pyspark>=3.4.0,<4.2.0"] +spark = [ + "pyspark>=3.4.0,<4.1.0; python_version < '3.10'", + "pyspark>=3.4.0,<4.2.0; python_version >= '3.10'", +] pydeck-map = ["geopandas", "pydeck==0.8.0"] kepler-map = ["geopandas", "keplergl==0.3.2"] flink = ["apache-flink>=1.19.0"] db = ["sedonadb[geopandas]; python_version >= '3.9'"] all = [ - "pyspark>=3.4.0,<4.2.0", + "pyspark>=3.4.0,<4.1.0; python_version < '3.10'", + "pyspark>=3.4.0,<4.2.0; python_version >= '3.10'", "geopandas", "pydeck==0.8.0", "keplergl==0.3.2", @@ -71,7 +75,8 @@ dev = [ # cannot set geopandas>=0.14.4 since it doesn't support python 3.8, so we pin fiona to <1.10.0 "fiona<1.10.0", "pyarrow", - "pyspark>=3.4.0,<4.2.0", + "pyspark>=3.4.0,<4.1.0; python_version < '3.10'", + "pyspark>=3.4.0,<4.2.0; python_version >= '3.10'", "keplergl==0.3.2", "pydeck==0.8.0", "pystac==1.5.0", From 542d447012d37a7d626e93b4448269cdddf6f35d Mon Sep 17 00:00:00 2001 From: Jia Yu Date: Fri, 13 Feb 2026 09:16:58 -0800 Subject: [PATCH 06/14] fix: add toString override to UDT classes and disable fail-fast in CI --- .github/workflows/java.yml | 2 +- .github/workflows/python.yml | 1 + .github/workflows/r.yml | 2 +- .../org/apache/spark/sql/sedona_sql/UDT/GeographyUDT.scala | 2 ++ .../scala/org/apache/spark/sql/sedona_sql/UDT/GeometryUDT.scala | 2 ++ .../scala/org/apache/spark/sql/sedona_sql/UDT/RasterUDT.scala | 2 ++ 6 files changed, 9 insertions(+), 2 
deletions(-) diff --git a/.github/workflows/java.yml b/.github/workflows/java.yml index 027308dde73..15a15983684 100644 --- a/.github/workflows/java.yml +++ b/.github/workflows/java.yml @@ -59,7 +59,7 @@ jobs: build: runs-on: ubuntu-22.04 strategy: - fail-fast: true + fail-fast: false matrix: include: - spark: 4.1.1 diff --git a/.github/workflows/python.yml b/.github/workflows/python.yml index 0b7d524e0c5..988e5bcaec6 100644 --- a/.github/workflows/python.yml +++ b/.github/workflows/python.yml @@ -58,6 +58,7 @@ jobs: build: runs-on: ubuntu-22.04 strategy: + fail-fast: false matrix: include: - spark: '4.1.1' diff --git a/.github/workflows/r.yml b/.github/workflows/r.yml index 14565278fc0..12f27c0bf59 100644 --- a/.github/workflows/r.yml +++ b/.github/workflows/r.yml @@ -57,7 +57,7 @@ jobs: build: runs-on: ubuntu-22.04 strategy: - fail-fast: true + fail-fast: false matrix: spark: [3.4.0, 3.5.0] hadoop: [3] diff --git a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/UDT/GeographyUDT.scala b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/UDT/GeographyUDT.scala index 987ef9c229d..969672a2c02 100644 --- a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/UDT/GeographyUDT.scala +++ b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/UDT/GeographyUDT.scala @@ -52,6 +52,8 @@ class GeographyUDT extends UserDefinedType[Geography] { } override def hashCode(): Int = userClass.hashCode() + + override def toString: String = "GeographyUDT" } case object GeographyUDT diff --git a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/UDT/GeometryUDT.scala b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/UDT/GeometryUDT.scala index 59b5b82d6a5..ae1772e617a 100644 --- a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/UDT/GeometryUDT.scala +++ b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/UDT/GeometryUDT.scala @@ -53,6 +53,8 @@ class GeometryUDT extends UserDefinedType[Geometry] { } override def hashCode(): Int = userClass.hashCode() + + override def toString: String = "GeometryUDT" } case object GeometryUDT diff --git a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/UDT/RasterUDT.scala b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/UDT/RasterUDT.scala index 77c3254eb5a..67fe763129c 100644 --- a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/UDT/RasterUDT.scala +++ b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/UDT/RasterUDT.scala @@ -56,6 +56,8 @@ class RasterUDT extends UserDefinedType[GridCoverage2D] { } override def hashCode(): Int = userClass.hashCode() + + override def toString: String = "RasterUDT" } case object RasterUDT extends RasterUDT with Serializable { From e3eef2d9c6108686c7ea2f5fabf37822f595053f Mon Sep 17 00:00:00 2001 From: Jia Yu Date: Fri, 13 Feb 2026 10:10:36 -0800 Subject: [PATCH 07/14] fix: force-overwrite Spark 4.1 native ST functions with Sedona's and disable geospatial in tests --- .../apache/sedona/sql/UDF/AbstractCatalog.scala | 14 +++++++------- .../org/apache/sedona/sql/adapterTestJava.java | 6 +++++- .../org/apache/sedona/sql/TestBaseScala.scala | 5 ++++- 3 files changed, 16 insertions(+), 9 deletions(-) diff --git a/spark/common/src/main/scala/org/apache/sedona/sql/UDF/AbstractCatalog.scala b/spark/common/src/main/scala/org/apache/sedona/sql/UDF/AbstractCatalog.scala index 72dad004672..9b4f7dda5cf 100644 --- a/spark/common/src/main/scala/org/apache/sedona/sql/UDF/AbstractCatalog.scala +++ 
b/spark/common/src/main/scala/org/apache/sedona/sql/UDF/AbstractCatalog.scala @@ -85,13 +85,13 @@ abstract class AbstractCatalog { def registerAll(sparkSession: SparkSession): Unit = { val registry = sparkSession.sessionState.functionRegistry expressions.foreach { case (functionIdentifier, expressionInfo, functionBuilder) => - if (!registry.functionExists(functionIdentifier)) { - registry.registerFunction(functionIdentifier, expressionInfo, functionBuilder) - FunctionRegistry.builtin.registerFunction( - functionIdentifier, - expressionInfo, - functionBuilder) - } + // Always register Sedona functions, overwriting any Spark built-in functions + // with the same name (e.g., Spark 4.1's native ST_GeomFromWKB). + registry.registerFunction(functionIdentifier, expressionInfo, functionBuilder) + FunctionRegistry.builtin.registerFunction( + functionIdentifier, + expressionInfo, + functionBuilder) } aggregateExpressions.foreach { f => registerAggregateFunction(sparkSession, f.getClass.getSimpleName, f) diff --git a/spark/common/src/test/java/org/apache/sedona/sql/adapterTestJava.java b/spark/common/src/test/java/org/apache/sedona/sql/adapterTestJava.java index 434a2783b63..b4405730c19 100644 --- a/spark/common/src/test/java/org/apache/sedona/sql/adapterTestJava.java +++ b/spark/common/src/test/java/org/apache/sedona/sql/adapterTestJava.java @@ -55,7 +55,11 @@ public class adapterTestJava implements Serializable { public static void onceExecutedBeforeAll() { sparkSession = SedonaContext.create( - SedonaContext.builder().master("local[*]").appName("adapterTestJava").getOrCreate()); + SedonaContext.builder() + .master("local[*]") + .appName("adapterTestJava") + .config("spark.sql.geospatial.enabled", "false") + .getOrCreate()); Logger.getLogger("org").setLevel(Level.WARN); Logger.getLogger("akka").setLevel(Level.WARN); } diff --git a/spark/common/src/test/scala/org/apache/sedona/sql/TestBaseScala.scala b/spark/common/src/test/scala/org/apache/sedona/sql/TestBaseScala.scala index 2f57d310ae6..0068da6060f 100644 --- a/spark/common/src/test/scala/org/apache/sedona/sql/TestBaseScala.scala +++ b/spark/common/src/test/scala/org/apache/sedona/sql/TestBaseScala.scala @@ -48,7 +48,10 @@ trait TestBaseScala extends FunSpec with BeforeAndAfterAll { def defaultSparkConfig: Map[String, String] = Map( "spark.sql.warehouse.dir" -> (System.getProperty("user.dir") + "/target/"), "sedona.join.autoBroadcastJoinThreshold" -> "-1", - "spark.kryoserializer.buffer.max" -> "64m") + "spark.kryoserializer.buffer.max" -> "64m", + // Disable Spark 4.1+ native geospatial functions that shadow Sedona's ST functions. + // This config is ignored on Spark versions that don't have it. 
+ "spark.sql.geospatial.enabled" -> "false") // Method to be overridden by subclasses to provide additional configurations def sparkConfig: Map[String, String] = defaultSparkConfig From 78682c690c5ab052a7fedf4d8312f35549350824 Mon Sep 17 00:00:00 2001 From: Jia Yu Date: Fri, 13 Feb 2026 10:38:49 -0800 Subject: [PATCH 08/14] fix: smart guard for function registration - skip Sedona re-registration, overwrite Spark native --- .../sedona/sql/UDF/AbstractCatalog.scala | 21 ++++++++++++------- 1 file changed, 14 insertions(+), 7 deletions(-) diff --git a/spark/common/src/main/scala/org/apache/sedona/sql/UDF/AbstractCatalog.scala b/spark/common/src/main/scala/org/apache/sedona/sql/UDF/AbstractCatalog.scala index 9b4f7dda5cf..ee62dfaa813 100644 --- a/spark/common/src/main/scala/org/apache/sedona/sql/UDF/AbstractCatalog.scala +++ b/spark/common/src/main/scala/org/apache/sedona/sql/UDF/AbstractCatalog.scala @@ -85,13 +85,20 @@ abstract class AbstractCatalog { def registerAll(sparkSession: SparkSession): Unit = { val registry = sparkSession.sessionState.functionRegistry expressions.foreach { case (functionIdentifier, expressionInfo, functionBuilder) => - // Always register Sedona functions, overwriting any Spark built-in functions - // with the same name (e.g., Spark 4.1's native ST_GeomFromWKB). - registry.registerFunction(functionIdentifier, expressionInfo, functionBuilder) - FunctionRegistry.builtin.registerFunction( - functionIdentifier, - expressionInfo, - functionBuilder) + val shouldRegister = registry.lookupFunction(functionIdentifier) match { + case Some(existingInfo) => + // Skip if Sedona already registered this function (e.g., SedonaContext.create called + // twice). Overwrite if it's a Spark native function (e.g., Spark 4.1's ST_GeomFromWKB). + !existingInfo.getClassName.startsWith("org.apache.sedona.") + case None => true + } + if (shouldRegister) { + registry.registerFunction(functionIdentifier, expressionInfo, functionBuilder) + FunctionRegistry.builtin.registerFunction( + functionIdentifier, + expressionInfo, + functionBuilder) + } } aggregateExpressions.foreach { f => registerAggregateFunction(sparkSession, f.getClass.getSimpleName, f) From 660154af76572f3348712feac73a7088f29455c5 Mon Sep 17 00:00:00 2001 From: Jia Yu Date: Fri, 13 Feb 2026 11:00:24 -0800 Subject: [PATCH 09/14] fix: use python_version >= 3.10 for all Spark 4.x in Python CI --- .github/workflows/python.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/python.yml b/.github/workflows/python.yml index 988e5bcaec6..2dd5414e9c2 100644 --- a/.github/workflows/python.yml +++ b/.github/workflows/python.yml @@ -117,9 +117,9 @@ jobs: fi if [ "${SPARK_VERSION:0:1}" == "4" ]; then - # Spark 4.x requires Python 3.9+, and we remove flink since it conflicts with pyspark 4.x + # Spark 4.x requires Python 3.10+, and we remove flink since it conflicts with pyspark 4.x uv remove apache-flink --optional flink - uv add "pyspark==${SPARK_VERSION}; python_version >= '3.9'" + uv add "pyspark==${SPARK_VERSION}; python_version >= '3.10'" else # Install specific pyspark version matching matrix uv add pyspark==${SPARK_VERSION} From 7d65d1b21422dd3ec2c2d4d0a61953dddbfe62d1 Mon Sep 17 00:00:00 2001 From: Jia Yu Date: Fri, 13 Feb 2026 11:27:58 -0800 Subject: [PATCH 10/14] fix: alias DataFrames in OsmReaderTest to resolve ambiguous self-join on Spark 4.1 --- .../org/apache/sedona/sql/OsmReaderTest.scala | 20 +++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git 
a/spark/common/src/test/scala/org/apache/sedona/sql/OsmReaderTest.scala b/spark/common/src/test/scala/org/apache/sedona/sql/OsmReaderTest.scala index 8c61ae0785d..817a5db0b91 100644 --- a/spark/common/src/test/scala/org/apache/sedona/sql/OsmReaderTest.scala +++ b/spark/common/src/test/scala/org/apache/sedona/sql/OsmReaderTest.scala @@ -176,28 +176,28 @@ class OsmReaderTest extends TestBaseScala with Matchers { "smoothness" -> "excellent")) // make sure the nodes match with refs - val nodes = osmData.where("kind == 'node'") - val ways = osmData.where("kind == 'way'") - val relations = osmData.where("kind == 'relation'") + val nodes = osmData.where("kind == 'node'").as("n") + val ways = osmData.where("kind == 'way'").as("w") + val relations = osmData.where("kind == 'relation'").as("rel") ways .selectExpr("explode(refs) AS ref") - .alias("w") - .join(nodes, col("w.ref") === nodes("id")) + .alias("w2") + .join(nodes, col("w2.ref") === col("n.id")) .count() shouldEqual (47812) ways .selectExpr("explode(refs) AS ref", "id") - .alias("w") - .join(nodes, col("w.ref") === nodes("id")) - .groupBy("w.id") + .alias("w2") + .join(nodes, col("w2.ref") === col("n.id")) + .groupBy("w2.id") .count() .count() shouldEqual (ways.count()) relations .selectExpr("explode(refs) AS ref", "id") .alias("r") - .join(nodes, col("r.ref") === nodes("id")) + .join(nodes, col("r.ref") === col("n.id")) .groupBy("r.id") .count() .count() shouldEqual (162) @@ -205,7 +205,7 @@ class OsmReaderTest extends TestBaseScala with Matchers { relations .selectExpr("explode(refs) AS ref", "id") .alias("r") - .join(ways, col("r.ref") === ways("id")) + .join(ways, col("r.ref") === col("w.id")) .groupBy("r.id") .count() .count() shouldEqual (261) From 391ad36284678b591963a61e92538db96811264b Mon Sep 17 00:00:00 2001 From: Jia Yu Date: Fri, 13 Feb 2026 11:34:11 -0800 Subject: [PATCH 11/14] fix: replace commons-collections 3.x with Java stdlib Spark 4.1 no longer provides commons-collections 3.x transitively. Replace FilterIterator with Java 8 stream filtering in DuplicatesFilter, and IteratorUtils.toList with StreamSupport in the test. 
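For reference, a minimal JDK-only sketch of the pattern this commit adopts (class and method names here are illustrative, not taken from the Sedona sources): wrap the Iterator in a Stream via Spliterators, then filter lazily or collect to a List.

    import java.util.Arrays;
    import java.util.Iterator;
    import java.util.List;
    import java.util.Spliterator;
    import java.util.Spliterators;
    import java.util.function.Predicate;
    import java.util.stream.Collectors;
    import java.util.stream.StreamSupport;

    public class IteratorStreamSketch {
      // Lazily filter an Iterator, standing in for commons-collections' FilterIterator.
      static <T> Iterator<T> filtered(Iterator<T> source, Predicate<T> keep) {
        return StreamSupport.stream(
                Spliterators.spliteratorUnknownSize(source, Spliterator.ORDERED), false)
            .filter(keep)
            .iterator();
      }

      // Materialize an Iterator into a List, standing in for IteratorUtils.toList.
      static <T> List<T> toList(Iterator<T> source) {
        return StreamSupport.stream(
                Spliterators.spliteratorUnknownSize(source, Spliterator.ORDERED), false)
            .collect(Collectors.toList());
      }

      public static void main(String[] args) {
        Iterator<Integer> numbers = Arrays.asList(1, 2, 3, 4, 5).iterator();
        System.out.println(toList(filtered(numbers, n -> n % 2 == 0))); // prints [2, 4]
      }
    }

The stream is created non-parallel (the boolean false argument), so element order and single-pass iterator semantics are preserved.
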
--- .../core/joinJudgement/DuplicatesFilter.java | 13 +++++++------ .../GenericUniquePartitionerTest.java | 16 +++++++++++++--- 2 files changed, 20 insertions(+), 9 deletions(-) diff --git a/spark/common/src/main/java/org/apache/sedona/core/joinJudgement/DuplicatesFilter.java b/spark/common/src/main/java/org/apache/sedona/core/joinJudgement/DuplicatesFilter.java index 7b28ec2fbed..1e719a5e13a 100644 --- a/spark/common/src/main/java/org/apache/sedona/core/joinJudgement/DuplicatesFilter.java +++ b/spark/common/src/main/java/org/apache/sedona/core/joinJudgement/DuplicatesFilter.java @@ -20,7 +20,9 @@ import java.util.Iterator; import java.util.List; -import org.apache.commons.collections.iterators.FilterIterator; +import java.util.Spliterator; +import java.util.Spliterators; +import java.util.stream.StreamSupport; import org.apache.commons.lang3.tuple.Pair; import org.apache.log4j.LogManager; import org.apache.log4j.Logger; @@ -57,11 +59,10 @@ public Iterator> call(Integer partitionId, Iterator> geome final List partitionExtents = dedupParamsBroadcast.getValue().getPartitionExtents(); if (partitionId < partitionExtents.size()) { HalfOpenRectangle extent = new HalfOpenRectangle(partitionExtents.get(partitionId)); - return new FilterIterator( - geometryPair, - p -> - !GeomUtils.isDuplicate( - ((Pair) p).getLeft(), ((Pair) p).getRight(), extent)); + return StreamSupport.stream( + Spliterators.spliteratorUnknownSize(geometryPair, Spliterator.ORDERED), false) + .filter(p -> !GeomUtils.isDuplicate(p.getLeft(), p.getRight(), extent)) + .iterator(); } else { log.warn("Didn't find partition extent for this partition: " + partitionId); return geometryPair; diff --git a/spark/common/src/test/java/org/apache/sedona/core/spatialPartitioning/GenericUniquePartitionerTest.java b/spark/common/src/test/java/org/apache/sedona/core/spatialPartitioning/GenericUniquePartitionerTest.java index 1df270c0a01..02b9069a13e 100644 --- a/spark/common/src/test/java/org/apache/sedona/core/spatialPartitioning/GenericUniquePartitionerTest.java +++ b/spark/common/src/test/java/org/apache/sedona/core/spatialPartitioning/GenericUniquePartitionerTest.java @@ -22,7 +22,11 @@ import java.util.ArrayList; import java.util.Iterator; -import org.apache.commons.collections.IteratorUtils; +import java.util.List; +import java.util.Spliterator; +import java.util.Spliterators; +import java.util.stream.Collectors; +import java.util.stream.StreamSupport; import org.junit.Test; import org.locationtech.jts.geom.Envelope; import org.locationtech.jts.geom.Geometry; @@ -52,10 +56,16 @@ public void testUniquePartition() throws Exception { partitioner.placeObject(factory.toGeometry(definitelyHasMultiplePartitions)); // Because the geometry is not completely contained by any of the partitions, // it also gets placed in the overflow partition (hence 5, not 4) - assertEquals(5, IteratorUtils.toList(placedWithDuplicates).size()); + assertEquals(5, iteratorToList(placedWithDuplicates).size()); Iterator> placedWithoutDuplicates = uniquePartitioner.placeObject(factory.toGeometry(definitelyHasMultiplePartitions)); - assertEquals(1, IteratorUtils.toList(placedWithoutDuplicates).size()); + assertEquals(1, iteratorToList(placedWithoutDuplicates).size()); + } + + private List iteratorToList(Iterator iterator) { + return StreamSupport.stream( + Spliterators.spliteratorUnknownSize(iterator, Spliterator.ORDERED), false) + .collect(Collectors.toList()); } } From 9efa26f5816513967887836f768d4ec903201457 Mon Sep 17 00:00:00 2001 From: Jia Yu Date: Fri, 13 Feb 2026 
17:32:26 -0800 Subject: [PATCH 12/14] fix: override parsePlanWithParameters for Spark 4.1 parser compatibility In Spark 4.1, SparkSqlParser introduced an override for parsePlanWithParameters that bypasses parsePlan entirely. SedonaSqlParser only overrode parsePlan, so its SQL parsing interception was never invoked on Spark 4.1. Fix by also overriding parsePlanWithParameters in SedonaSqlParser to use the same delegate-first, Sedona-fallback pattern. Also disable spark.sql.geospatial.enabled in tests to prevent Spark 4.1+ native geospatial functions from shadowing Sedona's ST functions. --- .../sedona/sql/parser/SedonaSqlParser.scala | 39 +++++++++++++------ .../org/apache/sedona/sql/TestBaseScala.scala | 2 + 2 files changed, 30 insertions(+), 11 deletions(-) diff --git a/spark/spark-4.1/src/main/scala/org/apache/sedona/sql/parser/SedonaSqlParser.scala b/spark/spark-4.1/src/main/scala/org/apache/sedona/sql/parser/SedonaSqlParser.scala index 54fb074eb05..31fbfd4932f 100644 --- a/spark/spark-4.1/src/main/scala/org/apache/sedona/sql/parser/SedonaSqlParser.scala +++ b/spark/spark-4.1/src/main/scala/org/apache/sedona/sql/parser/SedonaSqlParser.scala @@ -18,7 +18,7 @@ */ package org.apache.sedona.sql.parser -import org.apache.spark.sql.catalyst.parser.ParserInterface +import org.apache.spark.sql.catalyst.parser.{ParameterContext, ParserInterface} import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.execution.SparkSqlParser @@ -27,23 +27,40 @@ class SedonaSqlParser(delegate: ParserInterface) extends SparkSqlParser { // The parser builder for the Sedona SQL AST val parserBuilder = new SedonaSqlAstBuilder + private def sedonaFallback(sqlText: String): LogicalPlan = + parse(sqlText) { parser => + parserBuilder.visit(parser.singleStatement()) + }.asInstanceOf[LogicalPlan] + /** - * Parse the SQL text and return the logical plan. This method first attempts to use the - * delegate parser to parse the SQL text. If the delegate parser fails (throws an exception), it - * falls back to using the Sedona SQL parser. + * Parse the SQL text with parameters and return the logical plan. In Spark 4.1+, + * SparkSqlParser overrides parsePlanWithParameters to bypass parsePlan, so we must override + * this method to intercept the parse flow. * - * @param sqlText - * The SQL text to be parsed. - * @return - * The parsed logical plan. + * This method first attempts to use the delegate parser. If the delegate parser fails (throws + * an exception), it falls back to using the Sedona SQL parser. + */ + override def parsePlanWithParameters( + sqlText: String, + paramContext: ParameterContext): LogicalPlan = + try { + delegate.parsePlanWithParameters(sqlText, paramContext) + } catch { + case _: Exception => + sedonaFallback(sqlText) + } + + /** + * Parse the SQL text and return the logical plan. Note: in Spark 4.1+, SparkSession.sql() + * calls parsePlanWithParameters (overridden above), which no longer delegates to parsePlan. + * This override is kept as a defensive measure in case any third-party code or future Spark + * internals call parsePlan directly on the parser instance. 
*/ override def parsePlan(sqlText: String): LogicalPlan = try { delegate.parsePlan(sqlText) } catch { case _: Exception => - parse(sqlText) { parser => - parserBuilder.visit(parser.singleStatement()) - }.asInstanceOf[LogicalPlan] + sedonaFallback(sqlText) } } diff --git a/spark/spark-4.1/src/test/scala/org/apache/sedona/sql/TestBaseScala.scala b/spark/spark-4.1/src/test/scala/org/apache/sedona/sql/TestBaseScala.scala index 28943ff11da..2e3e9742222 100644 --- a/spark/spark-4.1/src/test/scala/org/apache/sedona/sql/TestBaseScala.scala +++ b/spark/spark-4.1/src/test/scala/org/apache/sedona/sql/TestBaseScala.scala @@ -47,6 +47,8 @@ trait TestBaseScala extends FunSpec with BeforeAndAfterAll { .config("sedona.join.autoBroadcastJoinThreshold", "-1") .config("spark.sql.extensions", "org.apache.sedona.sql.SedonaSqlExtensions") .config(keyParserExtension, ThreadLocalRandom.current().nextBoolean()) + // Disable Spark 4.1+ native geospatial functions that shadow Sedona's ST functions + .config("spark.sql.geospatial.enabled", "false") .getOrCreate() val sparkSessionMinio = SedonaContext From b4c3c1730fca5169de310bab7830cbfc5600ae2e Mon Sep 17 00:00:00 2001 From: Jia Yu Date: Fri, 13 Feb 2026 21:59:04 -0800 Subject: [PATCH 13/14] fix: handle bytes type for BinaryType columns in PySpark 4.1 tests PySpark 4.1 returns BinaryType columns as bytes instead of bytearray. Update the isinstance check to handle both types so ST_AsBinary and ST_AsEWKB test results are properly hex-encoded before comparison. --- python/tests/sql/test_dataframe_api.py | 2 +- .../apache/sedona/sql/parser/SedonaSqlParser.scala | 12 ++++++------ 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/python/tests/sql/test_dataframe_api.py b/python/tests/sql/test_dataframe_api.py index 4eb8ee02888..81dcd5055a9 100644 --- a/python/tests/sql/test_dataframe_api.py +++ b/python/tests/sql/test_dataframe_api.py @@ -1790,7 +1790,7 @@ def test_dataframe_function( elif isinstance(actual_result, Geography): # self.assert_geometry_almost_equal(expected_result, actual_result.geometry) return - elif isinstance(actual_result, bytearray): + elif isinstance(actual_result, (bytes, bytearray)): actual_result = actual_result.hex() elif isinstance(actual_result, Row): actual_result = { diff --git a/spark/spark-4.1/src/main/scala/org/apache/sedona/sql/parser/SedonaSqlParser.scala b/spark/spark-4.1/src/main/scala/org/apache/sedona/sql/parser/SedonaSqlParser.scala index 31fbfd4932f..cefd1487a47 100644 --- a/spark/spark-4.1/src/main/scala/org/apache/sedona/sql/parser/SedonaSqlParser.scala +++ b/spark/spark-4.1/src/main/scala/org/apache/sedona/sql/parser/SedonaSqlParser.scala @@ -33,9 +33,9 @@ class SedonaSqlParser(delegate: ParserInterface) extends SparkSqlParser { }.asInstanceOf[LogicalPlan] /** - * Parse the SQL text with parameters and return the logical plan. In Spark 4.1+, - * SparkSqlParser overrides parsePlanWithParameters to bypass parsePlan, so we must override - * this method to intercept the parse flow. + * Parse the SQL text with parameters and return the logical plan. In Spark 4.1+, SparkSqlParser + * overrides parsePlanWithParameters to bypass parsePlan, so we must override this method to + * intercept the parse flow. * * This method first attempts to use the delegate parser. If the delegate parser fails (throws * an exception), it falls back to using the Sedona SQL parser. @@ -51,9 +51,9 @@ class SedonaSqlParser(delegate: ParserInterface) extends SparkSqlParser { } /** - * Parse the SQL text and return the logical plan. 
Note: in Spark 4.1+, SparkSession.sql() - * calls parsePlanWithParameters (overridden above), which no longer delegates to parsePlan. - * This override is kept as a defensive measure in case any third-party code or future Spark + * Parse the SQL text and return the logical plan. Note: in Spark 4.1+, SparkSession.sql() calls + * parsePlanWithParameters (overridden above), which no longer delegates to parsePlan. This + * override is kept as a defensive measure in case any third-party code or future Spark * internals call parsePlan directly on the parser instance. */ override def parsePlan(sqlText: String): LogicalPlan = From f81d5fec00b60231392cece31921fb870078cc6e Mon Sep 17 00:00:00 2001 From: Jia Yu Date: Sat, 14 Feb 2026 01:03:46 -0800 Subject: [PATCH 14/14] fix: re-enable fail-fast in CI workflows --- .github/workflows/java.yml | 2 +- .github/workflows/python.yml | 1 - .github/workflows/r.yml | 2 +- 3 files changed, 2 insertions(+), 3 deletions(-) diff --git a/.github/workflows/java.yml b/.github/workflows/java.yml index 15a15983684..027308dde73 100644 --- a/.github/workflows/java.yml +++ b/.github/workflows/java.yml @@ -59,7 +59,7 @@ jobs: build: runs-on: ubuntu-22.04 strategy: - fail-fast: false + fail-fast: true matrix: include: - spark: 4.1.1 diff --git a/.github/workflows/python.yml b/.github/workflows/python.yml index 2dd5414e9c2..10989cd364a 100644 --- a/.github/workflows/python.yml +++ b/.github/workflows/python.yml @@ -58,7 +58,6 @@ jobs: build: runs-on: ubuntu-22.04 strategy: - fail-fast: false matrix: include: - spark: '4.1.1' diff --git a/.github/workflows/r.yml b/.github/workflows/r.yml index 12f27c0bf59..14565278fc0 100644 --- a/.github/workflows/r.yml +++ b/.github/workflows/r.yml @@ -57,7 +57,7 @@ jobs: build: runs-on: ubuntu-22.04 strategy: - fail-fast: false + fail-fast: true matrix: spark: [3.4.0, 3.5.0] hadoop: [3]