The load path can be a directory containing the shapefiles, or a path to the .shp file. If + * the path refers to a .shp file, the data source will also read other components such as .dbf + * and .shx files in the same directory. + */ +class ShapefileDataSource extends FileDataSourceV2 with DataSourceRegister { + + override def shortName(): String = "shapefile" + + override def fallbackFileFormat: Class[_ <: FileFormat] = null + + override protected def getTable(options: CaseInsensitiveStringMap): Table = { + val paths = getTransformedPath(options) + val tableName = getTableName(options, paths) + val optionsWithoutPaths = getOptionsWithoutPaths(options) + ShapefileTable(tableName, sparkSession, optionsWithoutPaths, paths, None, fallbackFileFormat) + } + + override protected def getTable( + options: CaseInsensitiveStringMap, + schema: StructType): Table = { + val paths = getTransformedPath(options) + val tableName = getTableName(options, paths) + val optionsWithoutPaths = getOptionsWithoutPaths(options) + ShapefileTable( + tableName, + sparkSession, + optionsWithoutPaths, + paths, + Some(schema), + fallbackFileFormat) + } + + private def getTransformedPath(options: CaseInsensitiveStringMap): Seq[String] = { + val paths = getPaths(options) + transformPaths(paths, options) + } + + private def transformPaths( + paths: Seq[String], + options: CaseInsensitiveStringMap): Seq[String] = { + val caseSensitiveMap = options.asCaseSensitiveMap.asScala.toMap + val hadoopConf = sparkSession.sessionState.newHadoopConfWithOptions(caseSensitiveMap) + paths.map { pathString => + if (pathString.toLowerCase(Locale.ROOT).endsWith(".shp")) { + // If the path refers to a file, we need to change it to a glob path to support reading + // .dbf and .shx files as well. For example, if the path is /path/to/file.shp, we need to + // change it to /path/to/file.??? + val path = new Path(pathString) + val fs = path.getFileSystem(hadoopConf) + val isDirectory = Try(fs.getFileStatus(path).isDirectory).getOrElse(false) + if (isDirectory) { + pathString + } else { + pathString.substring(0, pathString.length - 3) + "???" + } + } else { + pathString + } + } + } +} diff --git a/spark/spark-4.1/src/main/scala/org/apache/sedona/sql/datasources/shapefile/ShapefilePartition.scala b/spark/spark-4.1/src/main/scala/org/apache/sedona/sql/datasources/shapefile/ShapefilePartition.scala new file mode 100644 index 00000000000..306b1df4f6c --- /dev/null +++ b/spark/spark-4.1/src/main/scala/org/apache/sedona/sql/datasources/shapefile/ShapefilePartition.scala @@ -0,0 +1,27 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.sedona.sql.datasources.shapefile + +import org.apache.spark.Partition +import org.apache.spark.sql.connector.read.InputPartition +import org.apache.spark.sql.execution.datasources.PartitionedFile + +case class ShapefilePartition(index: Int, files: Array[PartitionedFile]) + extends Partition + with InputPartition diff --git a/spark/spark-4.1/src/main/scala/org/apache/sedona/sql/datasources/shapefile/ShapefilePartitionReader.scala b/spark/spark-4.1/src/main/scala/org/apache/sedona/sql/datasources/shapefile/ShapefilePartitionReader.scala new file mode 100644 index 00000000000..301d63296fb --- /dev/null +++ b/spark/spark-4.1/src/main/scala/org/apache/sedona/sql/datasources/shapefile/ShapefilePartitionReader.scala @@ -0,0 +1,287 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.sedona.sql.datasources.shapefile + +import org.apache.commons.io.FilenameUtils +import org.apache.commons.io.IOUtils +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs.FSDataInputStream +import org.apache.hadoop.fs.Path +import org.apache.sedona.common.FunctionsGeoTools +import org.apache.sedona.core.formatMapper.shapefileParser.shapes.DbfFileReader +import org.apache.sedona.core.formatMapper.shapefileParser.shapes.PrimitiveShape +import org.apache.sedona.core.formatMapper.shapefileParser.shapes.ShapeFileReader +import org.apache.sedona.core.formatMapper.shapefileParser.shapes.ShxFileReader +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.BoundReference +import org.apache.spark.sql.catalyst.expressions.Cast +import org.apache.spark.sql.catalyst.expressions.Literal +import org.apache.spark.sql.catalyst.expressions.UnsafeProjection +import org.apache.spark.sql.connector.read.PartitionReader +import org.apache.spark.sql.execution.datasources.PartitionedFile +import org.apache.sedona.sql.datasources.shapefile.ShapefilePartitionReader.logger +import org.apache.sedona.sql.datasources.shapefile.ShapefilePartitionReader.openStream +import org.apache.sedona.sql.datasources.shapefile.ShapefilePartitionReader.tryOpenStream +import org.apache.sedona.sql.datasources.shapefile.ShapefileUtils.baseSchema +import org.apache.spark.sql.sedona_sql.UDT.GeometryUDT +import org.apache.spark.sql.types.StructType +import org.locationtech.jts.geom.GeometryFactory +import org.locationtech.jts.geom.PrecisionModel +import org.slf4j.Logger +import org.slf4j.LoggerFactory + +import java.nio.charset.StandardCharsets +import scala.collection.JavaConverters._ +import java.util.Locale +import scala.util.Try + +class ShapefilePartitionReader( + configuration: Configuration, + partitionedFiles: Array[PartitionedFile], + readDataSchema: StructType, + options: ShapefileReadOptions) + 
extends PartitionReader[InternalRow] { + + private val partitionedFilesMap: Map[String, Path] = partitionedFiles.map { file => + val fileName = file.filePath.toPath.getName + val extension = FilenameUtils.getExtension(fileName).toLowerCase(Locale.ROOT) + extension -> file.filePath.toPath + }.toMap + + private val cpg = options.charset.orElse { + // No charset option or sedona.global.charset system property specified, infer charset + // from the cpg file. + tryOpenStream(partitionedFilesMap, "cpg", configuration) + .flatMap { stream => + try { + val lineIter = IOUtils.lineIterator(stream, StandardCharsets.UTF_8) + if (lineIter.hasNext) { + Some(lineIter.next().trim()) + } else { + None + } + } finally { + stream.close() + } + } + .orElse { + // Cannot infer charset from cpg file. If sedona.global.charset is set to "utf8", use UTF-8 as + // the default charset. This is for compatibility with the behavior of the RDD API. + val charset = System.getProperty("sedona.global.charset", "default") + val utf8flag = charset.equalsIgnoreCase("utf8") + if (utf8flag) Some("UTF-8") else None + } + } + + private val prj = tryOpenStream(partitionedFilesMap, "prj", configuration).map { stream => + try { + IOUtils.toString(stream, StandardCharsets.UTF_8) + } finally { + stream.close() + } + } + + private val shpReader: ShapeFileReader = { + val reader = tryOpenStream(partitionedFilesMap, "shx", configuration) match { + case Some(shxStream) => + try { + val index = ShxFileReader.readAll(shxStream) + new ShapeFileReader(index) + } finally { + shxStream.close() + } + case None => new ShapeFileReader() + } + val stream = openStream(partitionedFilesMap, "shp", configuration) + reader.initialize(stream) + reader + } + + private val dbfReader = + tryOpenStream(partitionedFilesMap, "dbf", configuration).map { stream => + val reader = new DbfFileReader() + reader.initialize(stream) + reader + } + + private val geometryField = readDataSchema.filter(_.dataType.isInstanceOf[GeometryUDT]) match { + case Seq(geoField) => Some(geoField) + case Seq() => None + case _ => throw new IllegalArgumentException("Only one geometry field is allowed") + } + + private val shpSchema: StructType = { + val dbfFields = dbfReader + .map { reader => + ShapefileUtils.fieldDescriptorsToStructFields(reader.getFieldDescriptors.asScala.toSeq) + } + .getOrElse(Seq.empty) + StructType(baseSchema(options).fields ++ dbfFields) + } + + // projection from shpSchema to readDataSchema + private val projection = { + val expressions = readDataSchema.map { field => + val index = Try(shpSchema.fieldIndex(field.name)).getOrElse(-1) + if (index >= 0) { + val sourceField = shpSchema.fields(index) + val refExpr = BoundReference(index, sourceField.dataType, sourceField.nullable) + if (sourceField.dataType == field.dataType) refExpr + else { + Cast(refExpr, field.dataType) + } + } else { + if (field.nullable) { + Literal(null) + } else { + // This usually won't happen, since all fields of readDataSchema are nullable for most + // of the time. See org.apache.spark.sql.execution.datasources.v2.FileTable#dataSchema + // for more details. 
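+ // When a required field is genuinely missing, fail fast and report it together with the .dbf path + // so the mismatch between the requested read schema and the attribute file is easy to diagnose.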
+ val dbfPath = partitionedFilesMap.get("dbf").orNull + throw new IllegalArgumentException( + s"Field ${field.name} not found in shapefile $dbfPath") + } + } + } + UnsafeProjection.create(expressions) + } + + // Convert DBF field values to SQL values + private val fieldValueConverters: Seq[Array[Byte] => Any] = dbfReader + .map { reader => + reader.getFieldDescriptors.asScala.map { field => + val index = Try(readDataSchema.fieldIndex(field.getFieldName)).getOrElse(-1) + if (index >= 0) { + ShapefileUtils.fieldValueConverter(field, cpg) + } else { (_: Array[Byte]) => + null + } + }.toSeq + } + .getOrElse(Seq.empty) + + private val geometryFactory = prj match { + case Some(wkt) => + val srid = + try { + FunctionsGeoTools.wktCRSToSRID(wkt) + } catch { + case e: Throwable => + val prjPath = partitionedFilesMap.get("prj").orNull + logger.warn(s"Failed to parse SRID from .prj file $prjPath", e) + 0 + } + new GeometryFactory(new PrecisionModel, srid) + case None => new GeometryFactory() + } + + private var currentRow: InternalRow = _ + + override def next(): Boolean = { + if (shpReader.nextKeyValue()) { + val key = shpReader.getCurrentKey + val id = key.getIndex + + val attributesOpt = dbfReader.flatMap { reader => + if (reader.nextKeyValue()) { + val value = reader.getCurrentFieldBytes + Option(value) + } else { + val dbfPath = partitionedFilesMap.get("dbf").orNull + logger.warn("Shape record loses attributes in .dbf file {} at ID={}", dbfPath, id) + None + } + } + + val value = shpReader.getCurrentValue + val geometry = geometryField.flatMap { _ => + if (value.getType.isSupported) { + val shape = new PrimitiveShape(value) + Some(shape.getShape(geometryFactory)) + } else { + logger.warn( + "Shape type {} is not supported, geometry value will be null", + value.getType.name()) + None + } + } + + val attrValues = attributesOpt match { + case Some(fieldBytesList) => + // Convert attributes to SQL values + fieldBytesList.asScala.zip(fieldValueConverters).map { case (fieldBytes, converter) => + converter(fieldBytes) + } + case None => + // No attributes, fill with nulls + Seq.fill(fieldValueConverters.length)(null) + } + + val serializedGeom = geometry.map(GeometryUDT.serialize).orNull + val shpRow = if (options.keyFieldName.isDefined) { + InternalRow.fromSeq(serializedGeom +: key.getIndex +: attrValues.toSeq) + } else { + InternalRow.fromSeq(serializedGeom +: attrValues.toSeq) + } + currentRow = projection(shpRow) + true + } else { + dbfReader.foreach { reader => + if (reader.nextKeyValue()) { + val dbfPath = partitionedFilesMap.get("dbf").orNull + logger.warn("Redundant attributes in {} exists", dbfPath) + } + } + false + } + } + + override def get(): InternalRow = currentRow + + override def close(): Unit = { + dbfReader.foreach(_.close()) + shpReader.close() + } +} + +object ShapefilePartitionReader { + val logger: Logger = LoggerFactory.getLogger(classOf[ShapefilePartitionReader]) + + private def openStream( + partitionedFilesMap: Map[String, Path], + extension: String, + configuration: Configuration): FSDataInputStream = { + tryOpenStream(partitionedFilesMap, extension, configuration).getOrElse { + val path = partitionedFilesMap.head._2 + val baseName = FilenameUtils.getBaseName(path.getName) + throw new IllegalArgumentException( + s"No $extension file found for shapefile $baseName in ${path.getParent}") + } + } + + private def tryOpenStream( + partitionedFilesMap: Map[String, Path], + extension: String, + configuration: Configuration): Option[FSDataInputStream] = { + 
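+ // Look up the companion file with the given extension and open it if present. Optional + // components such as .cpg, .prj, .shx and .dbf may be absent, in which case None is returned.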
partitionedFilesMap.get(extension).map { path => + val fs = path.getFileSystem(configuration) + fs.open(path) + } + } +} diff --git a/spark/spark-4.1/src/main/scala/org/apache/sedona/sql/datasources/shapefile/ShapefilePartitionReaderFactory.scala b/spark/spark-4.1/src/main/scala/org/apache/sedona/sql/datasources/shapefile/ShapefilePartitionReaderFactory.scala new file mode 100644 index 00000000000..5a28af6d66f --- /dev/null +++ b/spark/spark-4.1/src/main/scala/org/apache/sedona/sql/datasources/shapefile/ShapefilePartitionReaderFactory.scala @@ -0,0 +1,66 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.sedona.sql.datasources.shapefile + +import org.apache.spark.broadcast.Broadcast +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.connector.read.InputPartition +import org.apache.spark.sql.connector.read.PartitionReader +import org.apache.spark.sql.connector.read.PartitionReaderFactory +import org.apache.spark.sql.execution.datasources.PartitionedFile +import org.apache.spark.sql.execution.datasources.v2.PartitionReaderWithPartitionValues +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.sources.Filter +import org.apache.spark.sql.types.StructType +import org.apache.spark.util.SerializableConfiguration + +case class ShapefilePartitionReaderFactory( + sqlConf: SQLConf, + broadcastedConf: Broadcast[SerializableConfiguration], + dataSchema: StructType, + readDataSchema: StructType, + partitionSchema: StructType, + options: ShapefileReadOptions, + filters: Seq[Filter]) + extends PartitionReaderFactory { + + private def buildReader( + partitionedFiles: Array[PartitionedFile]): PartitionReader[InternalRow] = { + val fileReader = + new ShapefilePartitionReader( + broadcastedConf.value.value, + partitionedFiles, + readDataSchema, + options) + new PartitionReaderWithPartitionValues( + fileReader, + readDataSchema, + partitionSchema, + partitionedFiles.head.partitionValues) + } + + override def createReader(partition: InputPartition): PartitionReader[InternalRow] = { + partition match { + case filePartition: ShapefilePartition => buildReader(filePartition.files) + case _ => + throw new IllegalArgumentException( + s"Unexpected partition type: ${partition.getClass.getCanonicalName}") + } + } +} diff --git a/spark/spark-4.1/src/main/scala/org/apache/sedona/sql/datasources/shapefile/ShapefileReadOptions.scala b/spark/spark-4.1/src/main/scala/org/apache/sedona/sql/datasources/shapefile/ShapefileReadOptions.scala new file mode 100644 index 00000000000..ebc02fae85a --- /dev/null +++ b/spark/spark-4.1/src/main/scala/org/apache/sedona/sql/datasources/shapefile/ShapefileReadOptions.scala @@ -0,0 +1,45 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or 
more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.sedona.sql.datasources.shapefile + +import org.apache.spark.sql.util.CaseInsensitiveStringMap + +/** + * Options for reading Shapefiles. + * @param geometryFieldName + * The name of the geometry field. + * @param keyFieldName + * The name of the shape key field. + * @param charset + * The charset of non-spatial attributes. + */ +case class ShapefileReadOptions( + geometryFieldName: String, + keyFieldName: Option[String], + charset: Option[String]) + +object ShapefileReadOptions { + def parse(options: CaseInsensitiveStringMap): ShapefileReadOptions = { + val geometryFieldName = options.getOrDefault("geometry.name", "geometry") + val keyFieldName = + if (options.containsKey("key.name")) Some(options.get("key.name")) else None + val charset = if (options.containsKey("charset")) Some(options.get("charset")) else None + ShapefileReadOptions(geometryFieldName, keyFieldName, charset) + } +} diff --git a/spark/spark-4.1/src/main/scala/org/apache/sedona/sql/datasources/shapefile/ShapefileScan.scala b/spark/spark-4.1/src/main/scala/org/apache/sedona/sql/datasources/shapefile/ShapefileScan.scala new file mode 100644 index 00000000000..afdface8942 --- /dev/null +++ b/spark/spark-4.1/src/main/scala/org/apache/sedona/sql/datasources/shapefile/ShapefileScan.scala @@ -0,0 +1,118 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.sedona.sql.datasources.shapefile + +import org.apache.hadoop.fs.Path +import org.apache.spark.sql.catalyst.expressions.Expression +import org.apache.spark.sql.execution.datasources.PartitioningAwareFileIndex +import org.apache.spark.sql.sources.Filter +import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.util.CaseInsensitiveStringMap +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.connector.read.InputPartition +import org.apache.spark.sql.connector.read.PartitionReaderFactory +import org.apache.spark.sql.execution.datasources.v2.FileScan +import org.apache.spark.sql.execution.datasources.FilePartition +import org.apache.spark.sql.execution.datasources.PartitionedFile +import org.apache.sedona.sql.datasources.shapefile.ShapefileScan.logger +import org.apache.spark.util.SerializableConfiguration +import org.slf4j.{Logger, LoggerFactory} + +import java.util.Locale +import scala.collection.JavaConverters._ +import scala.collection.mutable + +case class ShapefileScan( + sparkSession: SparkSession, + fileIndex: PartitioningAwareFileIndex, + dataSchema: StructType, + readDataSchema: StructType, + readPartitionSchema: StructType, + options: CaseInsensitiveStringMap, + pushedFilters: Array[Filter], + partitionFilters: Seq[Expression] = Seq.empty, + dataFilters: Seq[Expression] = Seq.empty) + extends FileScan { + + override def createReaderFactory(): PartitionReaderFactory = { + val caseSensitiveMap = options.asScala.toMap + val hadoopConf = sparkSession.sessionState.newHadoopConfWithOptions(caseSensitiveMap) + val broadcastedConf = + sparkSession.sparkContext.broadcast(new SerializableConfiguration(hadoopConf)) + ShapefilePartitionReaderFactory( + sparkSession.sessionState.conf, + broadcastedConf, + dataSchema, + readDataSchema, + readPartitionSchema, + ShapefileReadOptions.parse(options), + pushedFilters) + } + + override def planInputPartitions(): Array[InputPartition] = { + // Simply use the default implementation to compute input partitions for all files + val allFilePartitions = super.planInputPartitions().flatMap { + case filePartition: FilePartition => + filePartition.files + case partition => + throw new IllegalArgumentException( + s"Unexpected partition type: ${partition.getClass.getCanonicalName}") + } + + // Group shapefiles by their main path (without the extension) + val shapefileGroups: mutable.Map[String, mutable.Map[String, PartitionedFile]] = + mutable.Map.empty + allFilePartitions.foreach { partitionedFile => + val path = partitionedFile.filePath.toPath + val fileName = path.getName + val pos = fileName.lastIndexOf('.') + if (pos == -1) None + else { + val mainName = fileName.substring(0, pos) + val extension = fileName.substring(pos + 1).toLowerCase(Locale.ROOT) + if (ShapefileUtils.shapeFileExtensions.contains(extension)) { + val key = new Path(path.getParent, mainName).toString + val group = shapefileGroups.getOrElseUpdate(key, mutable.Map.empty) + group += (extension -> partitionedFile) + } + } + } + + // Create a partition for each group + shapefileGroups.zipWithIndex.flatMap { case ((key, group), index) => + // Check if the group has all the necessary files + val suffixes = group.keys.toSet + val hasMissingFiles = ShapefileUtils.mandatoryFileExtensions.exists { suffix => + if (!suffixes.contains(suffix)) { + logger.warn(s"Shapefile $key is missing a $suffix file") + true + } else false + } + if (!hasMissingFiles) { + Some(ShapefilePartition(index, group.values.toArray)) + } else { + None + } + }.toArray + } 
+} + +object ShapefileScan { + val logger: Logger = LoggerFactory.getLogger(classOf[ShapefileScan]) +} diff --git a/spark/spark-4.1/src/main/scala/org/apache/sedona/sql/datasources/shapefile/ShapefileScanBuilder.scala b/spark/spark-4.1/src/main/scala/org/apache/sedona/sql/datasources/shapefile/ShapefileScanBuilder.scala new file mode 100644 index 00000000000..e5135e381df --- /dev/null +++ b/spark/spark-4.1/src/main/scala/org/apache/sedona/sql/datasources/shapefile/ShapefileScanBuilder.scala @@ -0,0 +1,48 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.sedona.sql.datasources.shapefile + +import org.apache.spark.sql.connector.read.Scan +import org.apache.spark.sql.execution.datasources.v2.FileScanBuilder +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.execution.datasources.PartitioningAwareFileIndex +import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.util.CaseInsensitiveStringMap + +case class ShapefileScanBuilder( + sparkSession: SparkSession, + fileIndex: PartitioningAwareFileIndex, + schema: StructType, + dataSchema: StructType, + options: CaseInsensitiveStringMap) + extends FileScanBuilder(sparkSession, fileIndex, dataSchema) { + + override def build(): Scan = { + ShapefileScan( + sparkSession, + fileIndex, + dataSchema, + readDataSchema(), + readPartitionSchema(), + options, + pushedDataFilters, + partitionFilters, + dataFilters) + } +} diff --git a/spark/spark-4.1/src/main/scala/org/apache/sedona/sql/datasources/shapefile/ShapefileTable.scala b/spark/spark-4.1/src/main/scala/org/apache/sedona/sql/datasources/shapefile/ShapefileTable.scala new file mode 100644 index 00000000000..7db6bb8d1f4 --- /dev/null +++ b/spark/spark-4.1/src/main/scala/org/apache/sedona/sql/datasources/shapefile/ShapefileTable.scala @@ -0,0 +1,103 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.sedona.sql.datasources.shapefile + +import org.apache.hadoop.fs.FileStatus +import org.apache.sedona.core.formatMapper.shapefileParser.parseUtils.dbf.DbfParseUtil +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.connector.catalog.TableCapability +import org.apache.spark.sql.connector.read.ScanBuilder +import org.apache.spark.sql.connector.write.{LogicalWriteInfo, WriteBuilder} +import org.apache.spark.sql.execution.datasources.FileFormat +import org.apache.sedona.sql.datasources.shapefile.ShapefileUtils.{baseSchema, fieldDescriptorsToSchema, mergeSchemas} +import org.apache.spark.sql.execution.datasources.v2.FileTable +import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.util.CaseInsensitiveStringMap +import org.apache.spark.util.SerializableConfiguration + +import java.util.Locale +import scala.collection.JavaConverters._ + +case class ShapefileTable( + name: String, + sparkSession: SparkSession, + options: CaseInsensitiveStringMap, + paths: Seq[String], + userSpecifiedSchema: Option[StructType], + fallbackFileFormat: Class[_ <: FileFormat]) + extends FileTable(sparkSession, options, paths, userSpecifiedSchema) { + + override def formatName: String = "Shapefile" + + override def capabilities: java.util.Set[TableCapability] = + java.util.EnumSet.of(TableCapability.BATCH_READ) + + override def inferSchema(files: Seq[FileStatus]): Option[StructType] = { + if (files.isEmpty) None + else { + def isDbfFile(file: FileStatus): Boolean = { + val name = file.getPath.getName.toLowerCase(Locale.ROOT) + name.endsWith(".dbf") + } + + def isShpFile(file: FileStatus): Boolean = { + val name = file.getPath.getName.toLowerCase(Locale.ROOT) + name.endsWith(".shp") + } + + if (!files.exists(isShpFile)) None + else { + val readOptions = ShapefileReadOptions.parse(options) + val resolver = sparkSession.sessionState.conf.resolver + val dbfFiles = files.filter(isDbfFile) + if (dbfFiles.isEmpty) { + Some(baseSchema(readOptions, Some(resolver))) + } else { + val serializableConf = new SerializableConfiguration( + sparkSession.sessionState.newHadoopConfWithOptions(options.asScala.toMap)) + val partiallyMergedSchemas = sparkSession.sparkContext + .parallelize(dbfFiles) + .mapPartitions { iter => + val schemas = iter.map { stat => + val fs = stat.getPath.getFileSystem(serializableConf.value) + val stream = fs.open(stat.getPath) + try { + val dbfParser = new DbfParseUtil() + dbfParser.parseFileHead(stream) + val fieldDescriptors = dbfParser.getFieldDescriptors + fieldDescriptorsToSchema(fieldDescriptors.asScala.toSeq, readOptions, resolver) + } finally { + stream.close() + } + }.toSeq + mergeSchemas(schemas).iterator + } + .collect() + mergeSchemas(partiallyMergedSchemas) + } + } + } + } + + override def newScanBuilder(options: CaseInsensitiveStringMap): ScanBuilder = { + ShapefileScanBuilder(sparkSession, fileIndex, schema, dataSchema, options) + } + + override def newWriteBuilder(info: LogicalWriteInfo): WriteBuilder = null +} diff --git a/spark/spark-4.1/src/main/scala/org/apache/sedona/sql/datasources/shapefile/ShapefileUtils.scala b/spark/spark-4.1/src/main/scala/org/apache/sedona/sql/datasources/shapefile/ShapefileUtils.scala new file mode 100644 index 00000000000..fd6d1e83827 --- /dev/null +++ b/spark/spark-4.1/src/main/scala/org/apache/sedona/sql/datasources/shapefile/ShapefileUtils.scala @@ -0,0 +1,202 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.sedona.sql.datasources.shapefile + +import org.apache.sedona.core.formatMapper.shapefileParser.parseUtils.dbf.FieldDescriptor +import org.apache.spark.sql.sedona_sql.UDT.GeometryUDT +import org.apache.spark.sql.catalyst.analysis.SqlApiAnalysis.Resolver +import org.apache.spark.sql.types.BooleanType +import org.apache.spark.sql.types.DateType +import org.apache.spark.sql.types.Decimal +import org.apache.spark.sql.types.DecimalType +import org.apache.spark.sql.types.LongType +import org.apache.spark.sql.types.StringType +import org.apache.spark.sql.types.StructField +import org.apache.spark.sql.types.StructType +import org.apache.spark.unsafe.types.UTF8String + +import java.nio.charset.StandardCharsets +import java.time.LocalDate +import java.time.format.DateTimeFormatter +import java.util.Locale + +object ShapefileUtils { + + /** + * shp: main file for storing shapes shx: index file for the main file dbf: attribute file cpg: + * code page file prj: projection file + */ + val shapeFileExtensions: Set[String] = Set("shp", "shx", "dbf", "cpg", "prj") + + /** + * The mandatory file extensions for a shapefile. 
We don't require the dbf file and shx file for + * being consistent with the behavior of the RDD API ShapefileReader.readToGeometryRDD + */ + val mandatoryFileExtensions: Set[String] = Set("shp") + + def mergeSchemas(schemas: Seq[StructType]): Option[StructType] = { + if (schemas.isEmpty) { + None + } else { + var mergedSchema = schemas.head + schemas.tail.foreach { schema => + try { + mergedSchema = mergeSchema(mergedSchema, schema) + } catch { + case cause: IllegalArgumentException => + throw new IllegalArgumentException( + s"Failed to merge schema $mergedSchema with $schema", + cause) + } + } + Some(mergedSchema) + } + } + + private def mergeSchema(schema1: StructType, schema2: StructType): StructType = { + // The field names are case insensitive when performing schema merging + val fieldMap = schema1.fields.map(f => f.name.toLowerCase(Locale.ROOT) -> f).toMap + var newFields = schema1.fields + schema2.fields.foreach { f => + fieldMap.get(f.name.toLowerCase(Locale.ROOT)) match { + case Some(existingField) => + if (existingField.dataType != f.dataType) { + throw new IllegalArgumentException( + s"Failed to merge fields ${existingField.name} and ${f.name} because they have different data types: ${existingField.dataType} and ${f.dataType}") + } + case _ => + newFields :+= f + } + } + StructType(newFields) + } + + def fieldDescriptorsToStructFields(fieldDescriptors: Seq[FieldDescriptor]): Seq[StructField] = { + fieldDescriptors.map { desc => + val name = desc.getFieldName + val dataType = desc.getFieldType match { + case 'C' => StringType + case 'N' | 'F' => + val scale = desc.getFieldDecimalCount + if (scale == 0) LongType + else { + val precision = desc.getFieldLength + DecimalType(precision, scale) + } + case 'L' => BooleanType + case 'D' => DateType + case _ => + throw new IllegalArgumentException(s"Unsupported field type ${desc.getFieldType}") + } + StructField(name, dataType, nullable = true) + } + } + + def fieldDescriptorsToSchema(fieldDescriptors: Seq[FieldDescriptor]): StructType = { + val structFields = fieldDescriptorsToStructFields(fieldDescriptors) + StructType(structFields) + } + + def fieldDescriptorsToSchema( + fieldDescriptors: Seq[FieldDescriptor], + options: ShapefileReadOptions, + resolver: Resolver): StructType = { + val structFields = fieldDescriptorsToStructFields(fieldDescriptors) + val geometryFieldName = options.geometryFieldName + if (structFields.exists(f => resolver(f.name, geometryFieldName))) { + throw new IllegalArgumentException( + s"Field name $geometryFieldName is reserved for geometry but appears in non-spatial attributes. " + + "Please specify a different field name for geometry using the 'geometry.name' option.") + } + options.keyFieldName.foreach { name => + if (structFields.exists(f => resolver(f.name, name))) { + throw new IllegalArgumentException( + s"Field name $name is reserved for shape key but appears in non-spatial attributes. 
" + + "Please specify a different field name for shape key using the 'key.name' option.") + } + } + StructType(baseSchema(options, Some(resolver)).fields ++ structFields) + } + + def baseSchema(options: ShapefileReadOptions, resolver: Option[Resolver] = None): StructType = { + options.keyFieldName match { + case Some(name) => + if (resolver.exists(_(name, options.geometryFieldName))) { + throw new IllegalArgumentException(s"geometry.name and key.name cannot be the same") + } + StructType( + Seq(StructField(options.geometryFieldName, GeometryUDT()), StructField(name, LongType))) + case _ => + StructType(StructField(options.geometryFieldName, GeometryUDT()) :: Nil) + } + } + + def fieldValueConverter(desc: FieldDescriptor, cpg: Option[String]): Array[Byte] => Any = { + desc.getFieldType match { + case 'C' => + val encoding = cpg.getOrElse("ISO-8859-1") + if (encoding.toLowerCase(Locale.ROOT) == "utf-8") { (bytes: Array[Byte]) => + UTF8String.fromBytes(bytes).trimRight() + } else { (bytes: Array[Byte]) => + { + val str = new String(bytes, encoding) + UTF8String.fromString(str).trimRight() + } + } + case 'N' | 'F' => + val scale = desc.getFieldDecimalCount + if (scale == 0) { (bytes: Array[Byte]) => + try { + new String(bytes, StandardCharsets.ISO_8859_1).trim.toLong + } catch { + case _: Exception => null + } + } else { (bytes: Array[Byte]) => + try { + Decimal.fromString(UTF8String.fromBytes(bytes)) + } catch { + case _: Exception => null + } + } + case 'L' => + (bytes: Array[Byte]) => + if (bytes.isEmpty) null + else { + bytes.head match { + case 'T' | 't' | 'Y' | 'y' => true + case 'F' | 'f' | 'N' | 'n' => false + case _ => null + } + } + case 'D' => + (bytes: Array[Byte]) => { + try { + val dateString = new String(bytes, StandardCharsets.ISO_8859_1) + val formatter = DateTimeFormatter.BASIC_ISO_DATE + val date = LocalDate.parse(dateString, formatter) + date.toEpochDay.toInt + } catch { + case _: Exception => null + } + } + case _ => + throw new IllegalArgumentException(s"Unsupported field type ${desc.getFieldType}") + } + } +} diff --git a/spark/spark-4.1/src/main/scala/org/apache/sedona/sql/parser/SedonaSqlAstBuilder.scala b/spark/spark-4.1/src/main/scala/org/apache/sedona/sql/parser/SedonaSqlAstBuilder.scala new file mode 100644 index 00000000000..b56ed11c875 --- /dev/null +++ b/spark/spark-4.1/src/main/scala/org/apache/sedona/sql/parser/SedonaSqlAstBuilder.scala @@ -0,0 +1,39 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.sedona.sql.parser + +import org.apache.spark.sql.catalyst.parser.SqlBaseParser._ +import org.apache.spark.sql.execution.SparkSqlAstBuilder +import org.apache.spark.sql.sedona_sql.UDT.GeometryUDT +import org.apache.spark.sql.types.DataType + +class SedonaSqlAstBuilder extends SparkSqlAstBuilder { + + /** + * Override the method to handle the geometry data type + * @param ctx + * @return + */ + override def visitPrimitiveDataType(ctx: PrimitiveDataTypeContext): DataType = { + ctx.getText.toUpperCase() match { + case "GEOMETRY" => GeometryUDT() + case _ => super.visitPrimitiveDataType(ctx) + } + } +} diff --git a/spark/spark-4.1/src/main/scala/org/apache/sedona/sql/parser/SedonaSqlParser.scala b/spark/spark-4.1/src/main/scala/org/apache/sedona/sql/parser/SedonaSqlParser.scala new file mode 100644 index 00000000000..cefd1487a47 --- /dev/null +++ b/spark/spark-4.1/src/main/scala/org/apache/sedona/sql/parser/SedonaSqlParser.scala @@ -0,0 +1,66 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.sedona.sql.parser + +import org.apache.spark.sql.catalyst.parser.{ParameterContext, ParserInterface} +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.execution.SparkSqlParser + +class SedonaSqlParser(delegate: ParserInterface) extends SparkSqlParser { + + // The parser builder for the Sedona SQL AST + val parserBuilder = new SedonaSqlAstBuilder + + private def sedonaFallback(sqlText: String): LogicalPlan = + parse(sqlText) { parser => + parserBuilder.visit(parser.singleStatement()) + }.asInstanceOf[LogicalPlan] + + /** + * Parse the SQL text with parameters and return the logical plan. In Spark 4.1+, SparkSqlParser + * overrides parsePlanWithParameters to bypass parsePlan, so we must override this method to + * intercept the parse flow. + * + * This method first attempts to use the delegate parser. If the delegate parser fails (throws + * an exception), it falls back to using the Sedona SQL parser. + */ + override def parsePlanWithParameters( + sqlText: String, + paramContext: ParameterContext): LogicalPlan = + try { + delegate.parsePlanWithParameters(sqlText, paramContext) + } catch { + case _: Exception => + sedonaFallback(sqlText) + } + + /** + * Parse the SQL text and return the logical plan. Note: in Spark 4.1+, SparkSession.sql() calls + * parsePlanWithParameters (overridden above), which no longer delegates to parsePlan. This + * override is kept as a defensive measure in case any third-party code or future Spark + * internals call parsePlan directly on the parser instance. 
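+ * + * In both code paths the fallback re-parses the statement with [[SedonaSqlAstBuilder]], so + * Sedona-specific syntax (for example, a GEOMETRY column type in DDL) can still be handled + * when the delegate parser rejects it.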
+ */ + override def parsePlan(sqlText: String): LogicalPlan = + try { + delegate.parsePlan(sqlText) + } catch { + case _: Exception => + sedonaFallback(sqlText) + } +} diff --git a/spark/spark-4.1/src/main/scala/org/apache/spark/sql/execution/datasources/v2/geoparquet/metadata/GeoParquetMetadataDataSource.scala b/spark/spark-4.1/src/main/scala/org/apache/spark/sql/execution/datasources/v2/geoparquet/metadata/GeoParquetMetadataDataSource.scala new file mode 100644 index 00000000000..43e1ababb7d --- /dev/null +++ b/spark/spark-4.1/src/main/scala/org/apache/spark/sql/execution/datasources/v2/geoparquet/metadata/GeoParquetMetadataDataSource.scala @@ -0,0 +1,65 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.spark.sql.execution.datasources.v2.geoparquet.metadata + +import org.apache.spark.sql.connector.catalog.Table +import org.apache.spark.sql.execution.datasources.FileFormat +import org.apache.spark.sql.execution.datasources.v2.FileDataSourceV2 +import org.apache.spark.sql.sources.DataSourceRegister +import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.util.CaseInsensitiveStringMap + +/** + * Data source for reading GeoParquet metadata. 
This could be accessed using the `spark.read` + * interface: + * {{{ + * val df = spark.read.format("geoparquet.metadata").load("path/to/geoparquet") + * }}} + */ +class GeoParquetMetadataDataSource extends FileDataSourceV2 with DataSourceRegister { + override val shortName: String = "geoparquet.metadata" + + override def fallbackFileFormat: Class[_ <: FileFormat] = null + + override def getTable(options: CaseInsensitiveStringMap): Table = { + val paths = getPaths(options) + val tableName = getTableName(options, paths) + val optionsWithoutPaths = getOptionsWithoutPaths(options) + GeoParquetMetadataTable( + tableName, + sparkSession, + optionsWithoutPaths, + paths, + None, + fallbackFileFormat) + } + + override def getTable(options: CaseInsensitiveStringMap, schema: StructType): Table = { + val paths = getPaths(options) + val tableName = getTableName(options, paths) + val optionsWithoutPaths = getOptionsWithoutPaths(options) + GeoParquetMetadataTable( + tableName, + sparkSession, + optionsWithoutPaths, + paths, + Some(schema), + fallbackFileFormat) + } +} diff --git a/spark/spark-4.1/src/main/scala/org/apache/spark/sql/execution/datasources/v2/geoparquet/metadata/GeoParquetMetadataPartitionReaderFactory.scala b/spark/spark-4.1/src/main/scala/org/apache/spark/sql/execution/datasources/v2/geoparquet/metadata/GeoParquetMetadataPartitionReaderFactory.scala new file mode 100644 index 00000000000..683160e93b2 --- /dev/null +++ b/spark/spark-4.1/src/main/scala/org/apache/spark/sql/execution/datasources/v2/geoparquet/metadata/GeoParquetMetadataPartitionReaderFactory.scala @@ -0,0 +1,122 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.spark.sql.execution.datasources.v2.geoparquet.metadata + +import org.apache.hadoop.conf.Configuration +import org.apache.parquet.ParquetReadOptions +import org.apache.parquet.hadoop.ParquetFileReader +import org.apache.parquet.hadoop.util.HadoopInputFile +import org.apache.spark.broadcast.Broadcast +import org.apache.spark.sql.catalyst.{FileSourceOptions, InternalRow} +import org.apache.spark.sql.catalyst.expressions.GenericInternalRow +import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, GenericArrayData} +import org.apache.spark.sql.connector.read.PartitionReader +import org.apache.spark.sql.execution.datasources.PartitionedFile +import org.apache.spark.sql.execution.datasources.geoparquet.GeoParquetMetaData +import org.apache.spark.sql.execution.datasources.v2._ +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.sources.Filter +import org.apache.spark.sql.types.StructType +import org.apache.spark.unsafe.types.UTF8String +import org.apache.spark.util.SerializableConfiguration +import org.json4s.DefaultFormats +import org.json4s.jackson.JsonMethods.{compact, render} + +case class GeoParquetMetadataPartitionReaderFactory( + sqlConf: SQLConf, + broadcastedConf: Broadcast[SerializableConfiguration], + dataSchema: StructType, + readDataSchema: StructType, + partitionSchema: StructType, + options: FileSourceOptions, + filters: Seq[Filter]) + extends FilePartitionReaderFactory { + + override def buildReader(partitionedFile: PartitionedFile): PartitionReader[InternalRow] = { + val iter = GeoParquetMetadataPartitionReaderFactory.readFile( + broadcastedConf.value.value, + partitionedFile, + readDataSchema) + val fileReader = new PartitionReaderFromIterator[InternalRow](iter) + new PartitionReaderWithPartitionValues( + fileReader, + readDataSchema, + partitionSchema, + partitionedFile.partitionValues) + } +} + +object GeoParquetMetadataPartitionReaderFactory { + private def readFile( + configuration: Configuration, + partitionedFile: PartitionedFile, + readDataSchema: StructType): Iterator[InternalRow] = { + + val inputFile = HadoopInputFile.fromPath(partitionedFile.toPath, configuration) + val inputStream = inputFile.newStream() + + val footer = ParquetFileReader + .readFooter(inputFile, ParquetReadOptions.builder().build(), inputStream) + + val filePath = partitionedFile.toPath.toString + val metadata = footer.getFileMetaData.getKeyValueMetaData + val row = GeoParquetMetaData.parseKeyValueMetaData(metadata) match { + case Some(geo) => + val geoColumnsMap = geo.columns.map { case (columnName, columnMetadata) => + implicit val formats: org.json4s.Formats = DefaultFormats + import org.json4s.jackson.Serialization + val columnMetadataFields: Array[Any] = Array( + UTF8String.fromString(columnMetadata.encoding), + new GenericArrayData(columnMetadata.geometryTypes.map(UTF8String.fromString).toArray), + new GenericArrayData(columnMetadata.bbox.toArray), + columnMetadata.crs + .map(projjson => UTF8String.fromString(compact(render(projjson)))) + .getOrElse(UTF8String.fromString("")), + columnMetadata.covering + .map(covering => UTF8String.fromString(Serialization.write(covering))) + .orNull) + val columnMetadataStruct = new GenericInternalRow(columnMetadataFields) + UTF8String.fromString(columnName) -> columnMetadataStruct + } + val fields: Array[Any] = Array( + UTF8String.fromString(filePath), + UTF8String.fromString(geo.version.orNull), + UTF8String.fromString(geo.primaryColumn), + ArrayBasedMapData(geoColumnsMap)) + new GenericInternalRow(fields) 
+ case None => + // Not a GeoParquet file, return a row with null metadata values. + val fields: Array[Any] = Array(UTF8String.fromString(filePath), null, null, null) + new GenericInternalRow(fields) + } + Iterator(pruneBySchema(row, GeoParquetMetadataTable.schema, readDataSchema)) + } + + private def pruneBySchema( + row: InternalRow, + schema: StructType, + readDataSchema: StructType): InternalRow = { + // Projection push down for nested fields is not enabled, so this very simple implementation is enough. + val values: Array[Any] = readDataSchema.fields.map { field => + val index = schema.fieldIndex(field.name) + row.get(index, field.dataType) + } + new GenericInternalRow(values) + } +} diff --git a/spark/spark-4.1/src/main/scala/org/apache/spark/sql/execution/datasources/v2/geoparquet/metadata/GeoParquetMetadataScan.scala b/spark/spark-4.1/src/main/scala/org/apache/spark/sql/execution/datasources/v2/geoparquet/metadata/GeoParquetMetadataScan.scala new file mode 100644 index 00000000000..d7719d87dad --- /dev/null +++ b/spark/spark-4.1/src/main/scala/org/apache/spark/sql/execution/datasources/v2/geoparquet/metadata/GeoParquetMetadataScan.scala @@ -0,0 +1,69 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.spark.sql.execution.datasources.v2.geoparquet.metadata + +import org.apache.hadoop.fs.Path +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.catalyst.FileSourceOptions +import org.apache.spark.sql.catalyst.expressions.Expression +import org.apache.spark.sql.connector.read.PartitionReaderFactory +import org.apache.spark.sql.execution.datasources.PartitioningAwareFileIndex +import org.apache.spark.sql.execution.datasources.v2.FileScan +import org.apache.spark.sql.sources.Filter +import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.util.CaseInsensitiveStringMap +import org.apache.spark.util.SerializableConfiguration + +import scala.collection.JavaConverters._ + +case class GeoParquetMetadataScan( + sparkSession: SparkSession, + fileIndex: PartitioningAwareFileIndex, + dataSchema: StructType, + readDataSchema: StructType, + readPartitionSchema: StructType, + options: CaseInsensitiveStringMap, + pushedFilters: Array[Filter], + partitionFilters: Seq[Expression] = Seq.empty, + dataFilters: Seq[Expression] = Seq.empty) + extends FileScan { + override def createReaderFactory(): PartitionReaderFactory = { + val caseSensitiveMap = options.asCaseSensitiveMap.asScala.toMap + // Hadoop Configurations are case sensitive. + val hadoopConf = sparkSession.sessionState.newHadoopConfWithOptions(caseSensitiveMap) + val broadcastedConf = + sparkSession.sparkContext.broadcast(new SerializableConfiguration(hadoopConf)) + // The partition values are already truncated in `FileScan.partitions`. 
+ // We should use `readPartitionSchema` as the partition schema here. + val fileSourceOptions = new FileSourceOptions(caseSensitiveMap) + GeoParquetMetadataPartitionReaderFactory( + sparkSession.sessionState.conf, + broadcastedConf, + dataSchema, + readDataSchema, + readPartitionSchema, + fileSourceOptions, + pushedFilters) + } + + override def isSplitable(path: Path): Boolean = false + + override def getFileUnSplittableReason(path: Path): String = + "Reading parquet file metadata does not require splitting the file" +} diff --git a/spark/spark-4.1/src/main/scala/org/apache/spark/sql/execution/datasources/v2/geoparquet/metadata/GeoParquetMetadataScanBuilder.scala b/spark/spark-4.1/src/main/scala/org/apache/spark/sql/execution/datasources/v2/geoparquet/metadata/GeoParquetMetadataScanBuilder.scala new file mode 100644 index 00000000000..c60369e1087 --- /dev/null +++ b/spark/spark-4.1/src/main/scala/org/apache/spark/sql/execution/datasources/v2/geoparquet/metadata/GeoParquetMetadataScanBuilder.scala @@ -0,0 +1,47 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.spark.sql.execution.datasources.v2.geoparquet.metadata + +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.connector.read.Scan +import org.apache.spark.sql.execution.datasources.PartitioningAwareFileIndex +import org.apache.spark.sql.execution.datasources.v2.FileScanBuilder +import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.util.CaseInsensitiveStringMap + +class GeoParquetMetadataScanBuilder( + sparkSession: SparkSession, + fileIndex: PartitioningAwareFileIndex, + schema: StructType, + dataSchema: StructType, + options: CaseInsensitiveStringMap) + extends FileScanBuilder(sparkSession, fileIndex, dataSchema) { + override def build(): Scan = { + GeoParquetMetadataScan( + sparkSession, + fileIndex, + dataSchema, + readDataSchema(), + readPartitionSchema(), + options, + pushedDataFilters, + partitionFilters, + dataFilters) + } +} diff --git a/spark/spark-4.1/src/main/scala/org/apache/spark/sql/execution/datasources/v2/geoparquet/metadata/GeoParquetMetadataTable.scala b/spark/spark-4.1/src/main/scala/org/apache/spark/sql/execution/datasources/v2/geoparquet/metadata/GeoParquetMetadataTable.scala new file mode 100644 index 00000000000..845764fae55 --- /dev/null +++ b/spark/spark-4.1/src/main/scala/org/apache/spark/sql/execution/datasources/v2/geoparquet/metadata/GeoParquetMetadataTable.scala @@ -0,0 +1,70 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.spark.sql.execution.datasources.v2.geoparquet.metadata + +import org.apache.hadoop.fs.FileStatus +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.connector.catalog.TableCapability +import org.apache.spark.sql.connector.read.ScanBuilder +import org.apache.spark.sql.connector.write.{LogicalWriteInfo, WriteBuilder} +import org.apache.spark.sql.execution.datasources.FileFormat +import org.apache.spark.sql.execution.datasources.v2.FileTable +import org.apache.spark.sql.types._ +import org.apache.spark.sql.util.CaseInsensitiveStringMap + +case class GeoParquetMetadataTable( + name: String, + sparkSession: SparkSession, + options: CaseInsensitiveStringMap, + paths: Seq[String], + userSpecifiedSchema: Option[StructType], + fallbackFileFormat: Class[_ <: FileFormat]) + extends FileTable(sparkSession, options, paths, userSpecifiedSchema) { + override def formatName: String = "GeoParquet Metadata" + + override def inferSchema(files: Seq[FileStatus]): Option[StructType] = + Some(GeoParquetMetadataTable.schema) + + override def newScanBuilder(options: CaseInsensitiveStringMap): ScanBuilder = + new GeoParquetMetadataScanBuilder(sparkSession, fileIndex, schema, dataSchema, options) + + override def newWriteBuilder(info: LogicalWriteInfo): WriteBuilder = null + + override def capabilities: java.util.Set[TableCapability] = + java.util.EnumSet.of(TableCapability.BATCH_READ) +} + +object GeoParquetMetadataTable { + private val columnMetadataType = StructType( + Seq( + StructField("encoding", StringType, nullable = true), + StructField("geometry_types", ArrayType(StringType), nullable = true), + StructField("bbox", ArrayType(DoubleType), nullable = true), + StructField("crs", StringType, nullable = true), + StructField("covering", StringType, nullable = true))) + + private val columnsType = MapType(StringType, columnMetadataType, valueContainsNull = false) + + val schema: StructType = StructType( + Seq( + StructField("path", StringType, nullable = false), + StructField("version", StringType, nullable = true), + StructField("primary_column", StringType, nullable = true), + StructField("columns", columnsType, nullable = true))) +} diff --git a/spark/spark-4.1/src/main/scala/org/apache/spark/sql/execution/python/SedonaArrowEvalPythonExec.scala b/spark/spark-4.1/src/main/scala/org/apache/spark/sql/execution/python/SedonaArrowEvalPythonExec.scala new file mode 100644 index 00000000000..30a3e652cf2 --- /dev/null +++ b/spark/spark-4.1/src/main/scala/org/apache/spark/sql/execution/python/SedonaArrowEvalPythonExec.scala @@ -0,0 +1,114 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.spark.sql.execution.python + +import scala.jdk.CollectionConverters._ + +import org.apache.sedona.sql.UDF.PythonEvalType +import org.apache.spark.{JobArtifactSet, TaskContext} +import org.apache.spark.api.python.ChainedPythonFunctions +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.execution.SparkPlan +import org.apache.spark.sql.execution.metric.SQLMetric +import org.apache.spark.sql.execution.python.EvalPythonExec.ArgumentMetadata +import org.apache.spark.sql.types.StructType + +/** + * A physical plan that evaluates a [[PythonUDF]]. + */ +case class SedonaArrowEvalPythonExec( + udfs: Seq[PythonUDF], + resultAttrs: Seq[Attribute], + child: SparkPlan, + evalType: Int) + extends EvalPythonExec + with PythonSQLMetrics { + + private[this] val jobArtifactUUID = JobArtifactSet.getCurrentJobArtifactState.map(_.uuid) + + private[this] val sessionUUID = { + Option(session).collect { + case session if session.sessionState.conf.pythonWorkerLoggingEnabled => + session.sessionUUID + } + } + + override protected def evaluatorFactory: EvalPythonEvaluatorFactory = { + new SedonaArrowEvalPythonEvaluatorFactory( + child.output, + udfs, + output, + conf.arrowMaxRecordsPerBatch, + evalType, + conf.sessionLocalTimeZone, + conf.arrowUseLargeVarTypes, + ArrowPythonRunner.getPythonRunnerConfMap(conf), + pythonMetrics, + jobArtifactUUID, + sessionUUID, + conf.pythonUDFProfiler) + } + + override protected def withNewChildInternal(newChild: SparkPlan): SparkPlan = + copy(child = newChild) +} + +class SedonaArrowEvalPythonEvaluatorFactory( + childOutput: Seq[Attribute], + udfs: Seq[PythonUDF], + output: Seq[Attribute], + batchSize: Int, + evalType: Int, + sessionLocalTimeZone: String, + largeVarTypes: Boolean, + pythonRunnerConf: Map[String, String], + pythonMetrics: Map[String, SQLMetric], + jobArtifactUUID: Option[String], + sessionUUID: Option[String], + profiler: Option[String]) + extends EvalPythonEvaluatorFactory(childOutput, udfs, output) { + + override def evaluate( + funcs: Seq[(ChainedPythonFunctions, Long)], + argMetas: Array[Array[ArgumentMetadata]], + iter: Iterator[InternalRow], + schema: StructType, + context: TaskContext): Iterator[InternalRow] = { + val batchIter = Iterator(iter) + + val pyRunner = new ArrowPythonWithNamedArgumentRunner( + funcs, + evalType - PythonEvalType.SEDONA_UDF_TYPE_CONSTANT, + argMetas, + schema, + sessionLocalTimeZone, + largeVarTypes, + pythonRunnerConf, + pythonMetrics, + jobArtifactUUID, + sessionUUID, + profiler) with BatchedPythonArrowInput + val columnarBatchIter = pyRunner.compute(batchIter, context.partitionId(), context) + + columnarBatchIter.flatMap { batch => + batch.rowIterator.asScala + } + } +} diff --git a/spark/spark-4.1/src/main/scala/org/apache/spark/sql/udf/ExtractSedonaUDFRule.scala b/spark/spark-4.1/src/main/scala/org/apache/spark/sql/udf/ExtractSedonaUDFRule.scala 
new file mode 100644 index 00000000000..3d3301580cc --- /dev/null +++ b/spark/spark-4.1/src/main/scala/org/apache/spark/sql/udf/ExtractSedonaUDFRule.scala @@ -0,0 +1,185 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.spark.sql.udf + +import org.apache.sedona.sql.UDF.PythonEvalType +import org.apache.spark.internal.Logging +import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Expression, ExpressionSet, PythonUDF} +import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project, Subquery} +import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.catalyst.trees.TreePattern.PYTHON_UDF + +import scala.collection.mutable + +// That rule extracts scalar Python UDFs, currently Apache Spark has +// assert on types which blocks using the vectorized udfs with geometry type +class ExtractSedonaUDFRule extends Rule[LogicalPlan] with Logging { + + private def hasScalarPythonUDF(e: Expression): Boolean = { + e.exists(PythonUDF.isScalarPythonUDF) + } + + @scala.annotation.tailrec + private def canEvaluateInPython(e: PythonUDF): Boolean = { + e.children match { + case Seq(u: PythonUDF) => e.evalType == u.evalType && canEvaluateInPython(u) + case children => !children.exists(hasScalarPythonUDF) + } + } + + def isScalarPythonUDF(e: Expression): Boolean = { + e.isInstanceOf[PythonUDF] && e + .asInstanceOf[PythonUDF] + .evalType == PythonEvalType.SQL_SCALAR_SEDONA_UDF + } + + private def collectEvaluableUDFsFromExpressions( + expressions: Seq[Expression]): Seq[PythonUDF] = { + + var firstVisitedScalarUDFEvalType: Option[Int] = None + + def canChainUDF(evalType: Int): Boolean = { + evalType == firstVisitedScalarUDFEvalType.get + } + + def collectEvaluableUDFs(expr: Expression): Seq[PythonUDF] = expr match { + case udf: PythonUDF + if isScalarPythonUDF(udf) && canEvaluateInPython(udf) + && firstVisitedScalarUDFEvalType.isEmpty => + firstVisitedScalarUDFEvalType = Some(udf.evalType) + Seq(udf) + case udf: PythonUDF + if isScalarPythonUDF(udf) && canEvaluateInPython(udf) + && canChainUDF(udf.evalType) => + Seq(udf) + case e => e.children.flatMap(collectEvaluableUDFs) + } + + expressions.flatMap(collectEvaluableUDFs) + } + + private var hasFailedBefore: Boolean = false + + def apply(plan: LogicalPlan): LogicalPlan = plan match { + case s: Subquery if s.correlated => plan + + case _ => + try { + plan.transformUpWithPruning(_.containsPattern(PYTHON_UDF)) { + case p: SedonaArrowEvalPython => p + + case plan: LogicalPlan => extract(plan) + } + } catch { + case e: Throwable => + if (!hasFailedBefore) { + log.warn( + s"Vectorized UDF feature won't be available due to plan transformation error.") + log.warn( + s"Failed to extract Sedona UDFs from plan: ${plan.treeString}\n" + + s"Exception: 
${e.getMessage}", + e) + hasFailedBefore = true + } + plan + } + } + + private def canonicalizeDeterministic(u: PythonUDF) = { + if (u.deterministic) { + u.canonicalized.asInstanceOf[PythonUDF] + } else { + u + } + } + + private def extract(plan: LogicalPlan): LogicalPlan = { + val udfs = ExpressionSet(collectEvaluableUDFsFromExpressions(plan.expressions)) + .filter(udf => udf.references.subsetOf(plan.inputSet)) + .toSeq + .asInstanceOf[Seq[PythonUDF]] + + udfs match { + case Seq() => plan + case _ => resolveUDFs(plan, udfs) + } + } + + def resolveUDFs(plan: LogicalPlan, udfs: Seq[PythonUDF]): LogicalPlan = { + val attributeMap = mutable.HashMap[PythonUDF, Expression]() + + val newChildren = adjustAttributeMap(plan, udfs, attributeMap) + + udfs.map(canonicalizeDeterministic).filterNot(attributeMap.contains).foreach { udf => + throw new IllegalStateException( + s"Invalid PythonUDF $udf, requires attributes from more than one child.") + } + + val rewritten = plan.withNewChildren(newChildren).transformExpressions { case p: PythonUDF => + attributeMap.getOrElse(canonicalizeDeterministic(p), p) + } + + val newPlan = extract(rewritten) + if (newPlan.output != plan.output) { + Project(plan.output, newPlan) + } else { + newPlan + } + } + + def adjustAttributeMap( + plan: LogicalPlan, + udfs: Seq[PythonUDF], + attributeMap: mutable.HashMap[PythonUDF, Expression]): Seq[LogicalPlan] = { + plan.children.map { child => + val validUdfs = udfs.filter { udf => + udf.references.subsetOf(child.outputSet) + } + + if (validUdfs.nonEmpty) { + require( + validUdfs.forall(isScalarPythonUDF), + "Can only extract scalar vectorized udf or sql batch udf") + + val resultAttrs = validUdfs.zipWithIndex.map { case (u, i) => + AttributeReference(s"pythonUDF$i", u.dataType)() + } + + val evalTypes = validUdfs.map(_.evalType).toSet + if (evalTypes.size != 1) { + throw new IllegalStateException( + "Expected udfs have the same evalType but got different evalTypes: " + + evalTypes.mkString(",")) + } + val evalType = evalTypes.head + val evaluation = evalType match { + case PythonEvalType.SQL_SCALAR_SEDONA_UDF => + SedonaArrowEvalPython(validUdfs, resultAttrs, child, evalType) + case _ => + throw new IllegalStateException("Unexpected UDF evalType") + } + + attributeMap ++= validUdfs.map(canonicalizeDeterministic).zip(resultAttrs) + evaluation + } else { + child + } + } + } +} diff --git a/spark/spark-4.1/src/main/scala/org/apache/spark/sql/udf/SedonaArrowEvalPython.scala b/spark/spark-4.1/src/main/scala/org/apache/spark/sql/udf/SedonaArrowEvalPython.scala new file mode 100644 index 00000000000..7600ece5079 --- /dev/null +++ b/spark/spark-4.1/src/main/scala/org/apache/spark/sql/udf/SedonaArrowEvalPython.scala @@ -0,0 +1,32 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.spark.sql.udf + +import org.apache.spark.sql.catalyst.expressions.{Attribute, PythonUDF} +import org.apache.spark.sql.catalyst.plans.logical.{BaseEvalPython, LogicalPlan} + +case class SedonaArrowEvalPython( + udfs: Seq[PythonUDF], + resultAttrs: Seq[Attribute], + child: LogicalPlan, + evalType: Int) + extends BaseEvalPython { + override protected def withNewChildInternal(newChild: LogicalPlan): SedonaArrowEvalPython = + copy(child = newChild) +} diff --git a/spark/spark-4.1/src/main/scala/org/apache/spark/sql/udf/SedonaArrowStrategy.scala b/spark/spark-4.1/src/main/scala/org/apache/spark/sql/udf/SedonaArrowStrategy.scala new file mode 100644 index 00000000000..21e389ee567 --- /dev/null +++ b/spark/spark-4.1/src/main/scala/org/apache/spark/sql/udf/SedonaArrowStrategy.scala @@ -0,0 +1,43 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.spark.sql.udf + +import org.apache.sedona.sql.UDF.PythonEvalType +import org.apache.spark.api.python.ChainedPythonFunctions +import org.apache.spark.{JobArtifactSet, TaskContext} +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.{Attribute, PythonUDF} +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.execution.{SparkPlan, SparkStrategy} +import org.apache.spark.sql.execution.metric.SQLMetric +import org.apache.spark.sql.execution.python.SedonaArrowEvalPythonExec +import org.apache.spark.sql.types.StructType + +import scala.collection.JavaConverters.asScalaIteratorConverter + +// We use a custom Strategy to avoid Apache Spark's type assertion; this +// could be extended to support other engines that work with +// Arrow data. +class SedonaArrowStrategy extends SparkStrategy { + override def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { + case SedonaArrowEvalPython(udfs, output, child, evalType) => + SedonaArrowEvalPythonExec(udfs, output, planLater(child), evalType) :: Nil + case _ => Nil + } +} diff --git a/spark/spark-4.1/src/test/resources/log4j2.properties b/spark/spark-4.1/src/test/resources/log4j2.properties new file mode 100644 index 00000000000..683ecd32f25 --- /dev/null +++ b/spark/spark-4.1/src/test/resources/log4j2.properties @@ -0,0 +1,31 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License.
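To show how the rule and strategy above fit together, here is a minimal wiring sketch assuming registration through Spark's SparkSessionExtensions. The extension class name and the choice of injection hooks are illustrative only; Sedona's actual registration point is not shown in this diff and may differ. ExtractSedonaUDFRule rewrites eligible scalar Python UDF expressions into SedonaArrowEvalPython nodes, and SedonaArrowStrategy plans those nodes into SedonaArrowEvalPythonExec.

import org.apache.spark.sql.{SparkSession, SparkSessionExtensions}
import org.apache.spark.sql.udf.{ExtractSedonaUDFRule, SedonaArrowStrategy}

// Hypothetical extension class; Sedona's real registration lives elsewhere.
class SedonaArrowUdfExtensions extends (SparkSessionExtensions => Unit) {
  override def apply(extensions: SparkSessionExtensions): Unit = {
    // Rewrites eligible scalar Python UDFs into SedonaArrowEvalPython nodes.
    extensions.injectOptimizerRule(_ => new ExtractSedonaUDFRule())
    // Plans SedonaArrowEvalPython into the Arrow-based physical operator.
    extensions.injectPlannerStrategy(_ => new SedonaArrowStrategy())
  }
}

object SedonaArrowUdfExtensionsExample {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession
      .builder()
      .master("local[*]")
      .config("spark.sql.extensions", classOf[SedonaArrowUdfExtensions].getName)
      .getOrCreate()
    // A Python UDF registered with the Sedona scalar eval type would now run
    // through SedonaArrowEvalPythonExec instead of hitting Spark's type assert.
    spark.stop()
  }
}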
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# Set everything to be logged to the file target/unit-tests.log +rootLogger.level = info +rootLogger.appenderRef.file.ref = File + +appender.file.type = File +appender.file.name = File +appender.file.fileName = target/unit-tests.log +appender.file.append = true +appender.file.layout.type = PatternLayout +appender.file.layout.pattern = %d{yy/MM/dd HH:mm:ss.SSS} %t %p %c{1}: %m%n%ex + +# Ignore messages below warning level from Jetty, because it's a bit verbose +logger.jetty.name = org.sparkproject.jetty +logger.jetty.level = warn diff --git a/spark/spark-4.1/src/test/scala/org/apache/sedona/sql/GeoPackageReaderTest.scala b/spark/spark-4.1/src/test/scala/org/apache/sedona/sql/GeoPackageReaderTest.scala new file mode 100644 index 00000000000..66fa147bc58 --- /dev/null +++ b/spark/spark-4.1/src/test/scala/org/apache/sedona/sql/GeoPackageReaderTest.scala @@ -0,0 +1,369 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.sedona.sql + +import io.minio.{MakeBucketArgs, MinioClient, PutObjectArgs} +import org.apache.spark.sql.{DataFrame, SparkSession} +import org.apache.spark.sql.functions.expr +import org.apache.spark.sql.sedona_sql.UDT.GeometryUDT +import org.apache.spark.sql.types.{BinaryType, BooleanType, DateType, DoubleType, IntegerType, StringType, StructField, StructType, TimestampType} +import org.scalatest.matchers.should.Matchers +import org.scalatest.prop.TableDrivenPropertyChecks._ +import org.testcontainers.containers.MinIOContainer + +import java.io.FileInputStream +import java.sql.{Date, Timestamp} +import java.util.TimeZone + +class GeoPackageReaderTest extends TestBaseScala with Matchers { + TimeZone.setDefault(TimeZone.getTimeZone("UTC")) + import sparkSession.implicits._ + + val path: String = resourceFolder + "geopackage/example.gpkg" + val polygonsPath: String = resourceFolder + "geopackage/features.gpkg" + val rasterPath: String = resourceFolder + "geopackage/raster.gpkg" + val wktReader = new org.locationtech.jts.io.WKTReader() + val wktWriter = new org.locationtech.jts.io.WKTWriter() + + val expectedFeatureSchema = StructType( + Seq( + StructField("id", IntegerType, true), + StructField("geometry", GeometryUDT(), true), + StructField("text", StringType, true), + StructField("real", DoubleType, true), + StructField("boolean", BooleanType, true), + StructField("blob", BinaryType, true), + StructField("integer", IntegerType, true), + StructField("text_limited", StringType, true), + StructField("blob_limited", BinaryType, true), + StructField("date", DateType, true), + StructField("datetime", TimestampType, true))) + + describe("Reading GeoPackage metadata") { + it("should read GeoPackage metadata") { + val df = sparkSession.read + .format("geopackage") + .option("showMetadata", "true") + .load(path) + + df.where("data_type = 'tiles'").show(false) + + df.count shouldEqual 34 + } + } + + describe("Reading Vector data") { + it("should read GeoPackage - point1") { + val df = readFeatureData("point1") + df.schema shouldEqual expectedFeatureSchema + + df.count() shouldEqual 4 + + val firstElement = df.collectAsList().get(0).toSeq + + val expectedValues = Seq( + 1, + wktReader.read(POINT_1), + "BIT Systems", + 4519.866024037493, + true, + Array(48, 99, 57, 54, 49, 56, 55, 54, 45, 98, 102, 100, 52, 45, 52, 102, 52, 48, 45, 97, + 49, 102, 101, 45, 55, 49, 55, 101, 57, 100, 50, 98, 48, 55, 98, 101), + 3, + "bcd5a36f-16dc-4385-87be-b40353848597", + Array(49, 50, 53, 50, 97, 99, 98, 52, 45, 57, 54, 54, 52, 45, 52, 101, 51, 50, 45, 57, 54, + 100, 101, 45, 56, 48, 54, 101, 101, 48, 101, 101, 49, 102, 57, 48), + Date.valueOf("2023-09-19"), + Timestamp.valueOf("2023-09-19 11:24:15.695")) + + firstElement should contain theSameElementsAs expectedValues + } + + it("should read GeoPackage - line1") { + val df = readFeatureData("line1") + .withColumn("datetime", expr("from_utc_timestamp(datetime, 'UTC')")) + + df.schema shouldEqual expectedFeatureSchema + + df.count() shouldEqual 3 + + val firstElement = df.collectAsList().get(0).toSeq + + firstElement should contain theSameElementsAs Seq( + 1, + wktReader.read(LINESTRING_1), + "East Lockheed Drive", + 1990.5159635296877, + false, + Array(54, 97, 98, 100, 98, 51, 97, 56, 45, 54, 53, 101, 48, 45, 52, 55, 48, 54, 45, 56, + 50, 52, 48, 45, 51, 57, 48, 55, 99, 50, 102, 102, 57, 48, 99, 55), + 1, + "13dd91dc-3b7d-4d8d-a0ca-b3afb8e31c3d", + Array(57, 54, 98, 102, 56, 99, 101, 56, 45, 102, 48, 54, 49, 45, 52, 55, 99, 48, 45, 97, + 98, 48, 
101, 45, 97, 99, 50, 52, 100, 98, 50, 97, 102, 50, 50, 54), + Date.valueOf("2023-09-19"), + Timestamp.valueOf("2023-09-19 11:24:15.716")) + } + + it("should read GeoPackage - polygon1") { + val df = readFeatureData("polygon1") + df.count shouldEqual 3 + df.schema shouldEqual expectedFeatureSchema + + df.select("geometry").collectAsList().get(0).toSeq should contain theSameElementsAs Seq( + wktReader.read(POLYGON_1)) + } + + it("should read GeoPackage - geometry1") { + val df = readFeatureData("geometry1") + df.count shouldEqual 10 + df.schema shouldEqual expectedFeatureSchema + + df.selectExpr("ST_ASTEXT(geometry)") + .as[String] + .collect() should contain theSameElementsAs Seq( + POINT_1, + POINT_2, + POINT_3, + POINT_4, + LINESTRING_1, + LINESTRING_2, + LINESTRING_3, + POLYGON_1, + POLYGON_2, + POLYGON_3) + } + + it("should read polygon with envelope data") { + val tables = Table( + ("tableName", "expectedCount"), + ("GB_Hex_5km_GS_CompressibleGround_v8", 4233), + ("GB_Hex_5km_GS_Landslides_v8", 4228), + ("GB_Hex_5km_GS_RunningSand_v8", 4233), + ("GB_Hex_5km_GS_ShrinkSwell_v8", 4233), + ("GB_Hex_5km_GS_SolubleRocks_v8", 4295)) + + forAll(tables) { (tableName: String, expectedCount: Int) => + val df = sparkSession.read + .format("geopackage") + .option("tableName", tableName) + .load(polygonsPath) + + df.count() shouldEqual expectedCount + } + } + + it("should handle datetime fields without timezone information") { + // This test verifies the fix for DateTimeParseException when reading + // GeoPackage files with datetime fields that don't include timezone info + val testFilePath = resourceFolder + "geopackage/test_datetime_issue.gpkg" + + // Test reading the test_features table with problematic datetime formats + val df = sparkSession.read + .format("geopackage") + .option("tableName", "test_features") + .load(testFilePath) + + // The test should not throw DateTimeParseException when reading datetime fields + noException should be thrownBy { + df.select("created_at", "updated_at").collect() + } + + // Verify that datetime fields are properly parsed as TimestampType + df.schema.fields.find(_.name == "created_at").get.dataType shouldEqual TimestampType + df.schema.fields.find(_.name == "updated_at").get.dataType shouldEqual TimestampType + + // Verify that we can read the datetime values + val datetimeValues = df.select("created_at", "updated_at").collect() + datetimeValues should not be empty + + // Verify that datetime values are valid timestamps + datetimeValues.foreach { row => + val createdTimestamp = row.getAs[Timestamp]("created_at") + val updatedTimestamp = row.getAs[Timestamp]("updated_at") + createdTimestamp should not be null + updatedTimestamp should not be null + createdTimestamp.getTime should be > 0L + updatedTimestamp.getTime should be > 0L + } + + // Test showMetadata option with the same file + noException should be thrownBy { + val metadataDf = sparkSession.read + .format("geopackage") + .option("showMetadata", "true") + .load(testFilePath) + metadataDf.select("last_change").collect() + } + } + } + + describe("GeoPackage Raster Data Test") { + it("should read") { + val fractions = + Table( + ("tableName", "channelNumber", "expectedSum"), + ("point1_tiles", 4, 466591.0), + ("line1_tiles", 4, 5775976.0), + ("polygon1_tiles", 4, 1.1269871e7), + ("geometry1_tiles", 4, 2.6328442e7), + ("point2_tiles", 4, 137456.0), + ("line2_tiles", 4, 6701101.0), + ("polygon2_tiles", 4, 5.1170714e7), + ("geometry2_tiles", 4, 1.6699823e7), + ("bit_systems", 1, 6.5561879e7), + ("nga", 1, 
6.8078856e7), + ("bit_systems_wgs84", 1, 7.7276934e7), + ("nga_pc", 1, 2.90590616e8), + ("bit_systems_world", 1, 7.7276934e7), + ("nga_pc_world", 1, 2.90590616e8)) + + forAll(fractions) { (tableName: String, channelNumber: Int, expectedSum: Double) => + { + val df = readFeatureData(tableName) + val calculatedSum = df + .selectExpr(s"RS_SummaryStats(tile_data, 'sum', ${channelNumber}) as stats") + .selectExpr("sum(stats)") + .as[Double] + + calculatedSum.collect().head shouldEqual expectedSum + } + } + } + + it("should be able to read complex raster data") { + val df = sparkSession.read + .format("geopackage") + .option("tableName", "AuroraAirportNoise") + .load(rasterPath) + + df.show(5) + + val calculatedSum = df + .selectExpr(s"RS_SummaryStats(tile_data, 'sum', ${1}) as stats") + .selectExpr("sum(stats)") + .as[Double] + + calculatedSum.first() shouldEqual 2.027126e7 + + val df2 = sparkSession.read + .format("geopackage") + .option("tableName", "LiquorLicenseDensity") + .load(rasterPath) + + val calculatedSum2 = df2 + .selectExpr(s"RS_SummaryStats(tile_data, 'sum', ${1}) as stats") + .selectExpr("sum(stats)") + .as[Double] + + calculatedSum2.first() shouldEqual 2.882028e7 + } + + } + + describe("Reading from S3") { + it("should be able to read files from S3") { + val container = new MinIOContainer("minio/minio:latest") + + container.start() + + val minioClient = createMinioClient(container) + val makeBucketRequest = MakeBucketArgs + .builder() + .bucket("sedona") + .build() + + minioClient.makeBucket(makeBucketRequest) + + adjustSparkSession(sparkSessionMinio, container) + + val inputPath: String = prepareFile("example.geopackage", path, minioClient) + + sparkSessionMinio.read + .format("geopackage") + .option("showMetadata", "true") + .load(inputPath) + .count shouldEqual 34 + + val df = sparkSession.read + .format("geopackage") + .option("tableName", "point1") + .load(inputPath) + + df.count shouldEqual 4 + + val inputPathLarger: String = prepareFiles((1 to 300).map(_ => path).toArray, minioClient) + + val dfLarger = sparkSessionMinio.read + .format("geopackage") + .option("tableName", "point1") + .load(inputPathLarger) + + dfLarger.count shouldEqual 300 * 4 + + container.stop() + } + } + + private def readFeatureData(tableName: String): DataFrame = { + sparkSession.read + .format("geopackage") + .option("tableName", tableName) + .load(path) + } + + private def prepareFiles(paths: Array[String], minioClient: MinioClient): String = { + val key = "geopackage" + + paths.foreach(path => { + val fis = new FileInputStream(path); + putFileIntoBucket( + "sedona", + s"${key}/${scala.util.Random.nextInt(1000000000)}.geopackage", + fis, + minioClient) + }) + + s"s3a://sedona/$key" + } + + private def prepareFile(name: String, path: String, minioClient: MinioClient): String = { + val fis = new FileInputStream(path); + putFileIntoBucket("sedona", name, fis, minioClient) + + s"s3a://sedona/$name" + } + + private val POINT_1 = "POINT (-104.801918 39.720014)" + private val POINT_2 = "POINT (-104.802987 39.717703)" + private val POINT_3 = "POINT (-104.807496 39.714085)" + private val POINT_4 = "POINT (-104.79948 39.714729)" + private val LINESTRING_1 = + "LINESTRING (-104.800614 39.720721, -104.802174 39.720726, -104.802584 39.72066, -104.803088 39.720477, -104.803474 39.720209)" + private val LINESTRING_2 = + "LINESTRING (-104.809612 39.718379, -104.806638 39.718372, -104.806236 39.718439, -104.805939 39.718536, -104.805654 39.718677, -104.803652 39.720095)" + private val LINESTRING_3 = + "LINESTRING 
(-104.806344 39.722425, -104.805854 39.722634, -104.805656 39.722647, -104.803749 39.722641, -104.803769 39.721849, -104.803806 39.721725, -104.804382 39.720865)" + private val POLYGON_1 = + "POLYGON ((-104.802246 39.720343, -104.802246 39.719753, -104.802183 39.719754, -104.802184 39.719719, -104.802138 39.719694, -104.802097 39.719691, -104.802096 39.719648, -104.801646 39.719648, -104.801644 39.719722, -104.80155 39.719723, -104.801549 39.720207, -104.801648 39.720207, -104.801648 39.720341, -104.802246 39.720343))" + private val POLYGON_2 = + "POLYGON ((-104.802259 39.719604, -104.80226 39.71955, -104.802281 39.719416, -104.802332 39.719372, -104.802081 39.71924, -104.802044 39.71929, -104.802027 39.719278, -104.802044 39.719229, -104.801785 39.719129, -104.801639 39.719413, -104.801649 39.719472, -104.801694 39.719524, -104.801753 39.71955, -104.80175 39.719606, -104.80194 39.719606, -104.801939 39.719555, -104.801977 39.719556, -104.801979 39.719606, -104.802259 39.719604), (-104.80213 39.71944, -104.802133 39.71949, -104.802148 39.71949, -104.80218 39.719473, -104.802187 39.719456, -104.802182 39.719439, -104.802088 39.719387, -104.802047 39.719427, -104.801858 39.719342, -104.801883 39.719294, -104.801832 39.719284, -104.801787 39.719298, -104.801763 39.719331, -104.801823 39.719352, -104.80179 39.71942, -104.801722 39.719404, -104.801715 39.719445, -104.801748 39.719484, -104.801809 39.719494, -104.801816 39.719439, -104.80213 39.71944))" + private val POLYGON_3 = + "POLYGON ((-104.802867 39.718122, -104.802369 39.717845, -104.802571 39.71763, -104.803066 39.717909, -104.802867 39.718122))" +} diff --git a/spark/spark-4.1/src/test/scala/org/apache/sedona/sql/GeoParquetMetadataTests.scala b/spark/spark-4.1/src/test/scala/org/apache/sedona/sql/GeoParquetMetadataTests.scala new file mode 100644 index 00000000000..01306c1b452 --- /dev/null +++ b/spark/spark-4.1/src/test/scala/org/apache/sedona/sql/GeoParquetMetadataTests.scala @@ -0,0 +1,152 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.sedona.sql + +import org.apache.spark.sql.Row +import org.apache.spark.sql.sedona_sql.UDT.GeometryUDT +import org.apache.spark.sql.types.{IntegerType, StructField, StructType} +import org.scalatest.BeforeAndAfterAll + +import java.util.Collections +import scala.collection.JavaConverters._ + +class GeoParquetMetadataTests extends TestBaseScala with BeforeAndAfterAll { + val geoparquetdatalocation: String = resourceFolder + "geoparquet/" + val geoparquetoutputlocation: String = resourceFolder + "geoparquet/geoparquet_output/" + + describe("GeoParquet Metadata tests") { + it("Reading GeoParquet Metadata") { + val df = sparkSession.read.format("geoparquet.metadata").load(geoparquetdatalocation) + val metadataArray = df.collect() + assert(metadataArray.length > 1) + assert(metadataArray.exists(_.getAs[String]("path").endsWith(".parquet"))) + assert(metadataArray.exists(_.getAs[String]("version") == "1.0.0-dev")) + assert(metadataArray.exists(_.getAs[String]("primary_column") == "geometry")) + assert(metadataArray.exists { row => + val columnsMap = row.getJavaMap(row.fieldIndex("columns")) + columnsMap != null && columnsMap + .containsKey("geometry") && columnsMap.get("geometry").isInstanceOf[Row] + }) + assert(metadataArray.forall { row => + val columnsMap = row.getJavaMap(row.fieldIndex("columns")) + if (columnsMap == null || !columnsMap.containsKey("geometry")) true + else { + val columnMetadata = columnsMap.get("geometry").asInstanceOf[Row] + columnMetadata.getAs[String]("encoding") == "WKB" && + columnMetadata + .getList[Any](columnMetadata.fieldIndex("bbox")) + .asScala + .forall(_.isInstanceOf[Double]) && + columnMetadata + .getList[Any](columnMetadata.fieldIndex("geometry_types")) + .asScala + .forall(_.isInstanceOf[String]) && + columnMetadata.getAs[String]("crs").nonEmpty && + columnMetadata.getAs[String]("crs") != "null" + } + }) + } + + it("Reading GeoParquet Metadata with column pruning") { + val df = sparkSession.read.format("geoparquet.metadata").load(geoparquetdatalocation) + val metadataArray = df + .selectExpr("path", "substring(primary_column, 1, 2) AS partial_primary_column") + .collect() + assert(metadataArray.length > 1) + assert(metadataArray.forall(_.length == 2)) + assert(metadataArray.exists(_.getAs[String]("path").endsWith(".parquet"))) + assert(metadataArray.exists(_.getAs[String]("partial_primary_column") == "ge")) + } + + it("Reading GeoParquet Metadata of plain parquet files") { + val df = sparkSession.read.format("geoparquet.metadata").load(geoparquetdatalocation) + val metadataArray = df.where("path LIKE '%plain.parquet'").collect() + assert(metadataArray.nonEmpty) + assert(metadataArray.forall(_.getAs[String]("path").endsWith("plain.parquet"))) + assert(metadataArray.forall(_.getAs[String]("version") == null)) + assert(metadataArray.forall(_.getAs[String]("primary_column") == null)) + assert(metadataArray.forall(_.getAs[String]("columns") == null)) + } + + it("Read GeoParquet without CRS") { + val df = sparkSession.read + .format("geoparquet") + .load(geoparquetdatalocation + "/example-1.0.0-beta.1.parquet") + val geoParquetSavePath = geoparquetoutputlocation + "/gp_crs_omit.parquet" + df.write + .format("geoparquet") + .option("geoparquet.crs", "") + .mode("overwrite") + .save(geoParquetSavePath) + val dfMeta = sparkSession.read.format("geoparquet.metadata").load(geoParquetSavePath) + val row = dfMeta.collect()(0) + val metadata = row.getJavaMap(row.fieldIndex("columns")).get("geometry").asInstanceOf[Row] + 
assert(metadata.getAs[String]("crs") == "") + } + + it("Read GeoParquet with null CRS") { + val df = sparkSession.read + .format("geoparquet") + .load(geoparquetdatalocation + "/example-1.0.0-beta.1.parquet") + val geoParquetSavePath = geoparquetoutputlocation + "/gp_crs_null.parquet" + df.write + .format("geoparquet") + .option("geoparquet.crs", "null") + .mode("overwrite") + .save(geoParquetSavePath) + val dfMeta = sparkSession.read.format("geoparquet.metadata").load(geoParquetSavePath) + val row = dfMeta.collect()(0) + val metadata = row.getJavaMap(row.fieldIndex("columns")).get("geometry").asInstanceOf[Row] + assert(metadata.getAs[String]("crs") == "null") + } + + it("Read GeoParquet with snake_case geometry column name and camelCase column name") { + val schema = StructType( + Seq( + StructField("id", IntegerType, nullable = false), + StructField("geom_column_1", GeometryUDT(), nullable = false), + StructField("geomColumn2", GeometryUDT(), nullable = false))) + val df = sparkSession.createDataFrame(Collections.emptyList[Row](), schema) + val geoParquetSavePath = geoparquetoutputlocation + "/gp_column_name_styles.parquet" + df.write.format("geoparquet").mode("overwrite").save(geoParquetSavePath) + + val dfMeta = sparkSession.read.format("geoparquet.metadata").load(geoParquetSavePath) + val row = dfMeta.collect()(0) + val metadata = row.getJavaMap(row.fieldIndex("columns")) + assert(metadata.containsKey("geom_column_1")) + assert(!metadata.containsKey("geoColumn1")) + assert(metadata.containsKey("geomColumn2")) + assert(!metadata.containsKey("geom_column2")) + assert(!metadata.containsKey("geom_column_2")) + } + + it("Read GeoParquet with covering metadata") { + val dfMeta = sparkSession.read + .format("geoparquet.metadata") + .load(geoparquetdatalocation + "/example-1.1.0.parquet") + val row = dfMeta.collect()(0) + val metadata = row.getJavaMap(row.fieldIndex("columns")).get("geometry").asInstanceOf[Row] + val covering = metadata.getAs[String]("covering") + assert(covering.nonEmpty) + Seq("bbox", "xmin", "ymin", "xmax", "ymax").foreach { key => + assert(covering contains key) + } + } + } +} diff --git a/spark/spark-4.1/src/test/scala/org/apache/sedona/sql/SQLSyntaxTestScala.scala b/spark/spark-4.1/src/test/scala/org/apache/sedona/sql/SQLSyntaxTestScala.scala new file mode 100644 index 00000000000..6f873d0a087 --- /dev/null +++ b/spark/spark-4.1/src/test/scala/org/apache/sedona/sql/SQLSyntaxTestScala.scala @@ -0,0 +1,72 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.sedona.sql + +import org.scalatest.matchers.must.Matchers.be +import org.scalatest.matchers.should.Matchers.convertToAnyShouldWrapper +import org.scalatest.prop.TableDrivenPropertyChecks + +/** + * Test suite for testing Sedona SQL support. 
+ */ +class SQLSyntaxTestScala extends TestBaseScala with TableDrivenPropertyChecks { + + override def beforeAll(): Unit = { + super.beforeAll() + sparkSession.conf.set("spark.sql.legacy.createHiveTableByDefault", "false") + } + + describe("Table creation DDL tests") { + + it("should be able to create a regular table without a geometry column") { + sparkSession.sql("DROP TABLE IF EXISTS T_TEST_REGULAR") + sparkSession.sql("CREATE TABLE IF NOT EXISTS T_TEST_REGULAR (INT_COL INT)") + sparkSession.catalog.tableExists("T_TEST_REGULAR") should be(true) + sparkSession.sql("DROP TABLE IF EXISTS T_TEST_REGULAR") + sparkSession.catalog.tableExists("T_TEST_REGULAR") should be(false) + } + + it( + "should be able to create a regular table with a geometry column without a workaround") { + try { + sparkSession.sql("CREATE TABLE T_TEST_EXPLICIT_GEOMETRY (GEO_COL GEOMETRY)") + sparkSession.catalog.tableExists("T_TEST_EXPLICIT_GEOMETRY") should be(true) + sparkSession.sparkContext.getConf.get(keyParserExtension) should be("true") + } catch { + case ex: Exception => + ex.getClass.getName.endsWith("ParseException") should be(true) + sparkSession.sparkContext.getConf.get(keyParserExtension) should be("false") + } + } + + it( + "should be able to create a regular table with both regular and geometry columns without a workaround") { + try { + sparkSession.sql( + "CREATE TABLE T_TEST_EXPLICIT_GEOMETRY_2 (INT_COL INT, GEO_COL GEOMETRY)") + sparkSession.catalog.tableExists("T_TEST_EXPLICIT_GEOMETRY_2") should be(true) + sparkSession.sparkContext.getConf.get(keyParserExtension) should be("true") + } catch { + case ex: Exception => + ex.getClass.getName.endsWith("ParseException") should be(true) + sparkSession.sparkContext.getConf.get(keyParserExtension) should be("false") + } + } + } +} diff --git a/spark/spark-4.1/src/test/scala/org/apache/sedona/sql/ShapefileTests.scala b/spark/spark-4.1/src/test/scala/org/apache/sedona/sql/ShapefileTests.scala new file mode 100644 index 00000000000..275bc3282f9 --- /dev/null +++ b/spark/spark-4.1/src/test/scala/org/apache/sedona/sql/ShapefileTests.scala @@ -0,0 +1,784 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License.
+ */ +package org.apache.sedona.sql + +import org.apache.commons.io.FileUtils +import org.apache.spark.sql.sedona_sql.UDT.GeometryUDT +import org.apache.spark.sql.types.{DateType, DecimalType, LongType, StringType, StructField, StructType} +import org.locationtech.jts.geom.{Geometry, MultiPolygon, Point, Polygon} +import org.locationtech.jts.io.{WKTReader, WKTWriter} +import org.scalatest.BeforeAndAfterAll + +import java.io.File +import java.nio.file.Files +import scala.collection.mutable + +class ShapefileTests extends TestBaseScala with BeforeAndAfterAll { + val temporaryLocation: String = resourceFolder + "shapefiles/tmp" + + override def beforeAll(): Unit = { + super.beforeAll() + FileUtils.deleteDirectory(new File(temporaryLocation)) + Files.createDirectory(new File(temporaryLocation).toPath) + } + + override def afterAll(): Unit = FileUtils.deleteDirectory(new File(temporaryLocation)) + + describe("Shapefile read tests") { + it("read gis_osm_pois_free_1") { + val shapefileDf = sparkSession.read + .format("shapefile") + .load(resourceFolder + "shapefiles/gis_osm_pois_free_1") + val schema = shapefileDf.schema + assert(schema.find(_.name == "geometry").get.dataType == GeometryUDT) + assert(schema.find(_.name == "osm_id").get.dataType == StringType) + assert(schema.find(_.name == "code").get.dataType == LongType) + assert(schema.find(_.name == "fclass").get.dataType == StringType) + assert(schema.find(_.name == "name").get.dataType == StringType) + assert(schema.length == 5) + assert(shapefileDf.count == 12873) + + shapefileDf.collect().foreach { row => + val geom = row.getAs[Geometry]("geometry") + assert(geom.isInstanceOf[Point]) + assert(geom.getSRID == 4326) + assert(row.getAs[String]("osm_id").nonEmpty) + assert(row.getAs[Long]("code") > 0) + assert(row.getAs[String]("fclass").nonEmpty) + assert(row.getAs[String]("name") != null) + } + + // with projection, selecting geometry and attribute fields + shapefileDf.select("geometry", "code").take(10).foreach { row => + assert(row.getAs[Geometry]("geometry").isInstanceOf[Point]) + assert(row.getAs[Long]("code") > 0) + } + + // with projection, selecting geometry fields + shapefileDf.select("geometry").take(10).foreach { row => + assert(row.getAs[Geometry]("geometry").isInstanceOf[Point]) + } + + // with projection, selecting attribute fields + shapefileDf.select("code", "osm_id").take(10).foreach { row => + assert(row.getAs[Long]("code") > 0) + assert(row.getAs[String]("osm_id").nonEmpty) + } + + // with transformation + shapefileDf + .selectExpr("ST_Buffer(geometry, 0.001) AS geom", "code", "osm_id as id") + .take(10) + .foreach { row => + assert(row.getAs[Geometry]("geom").isInstanceOf[Polygon]) + assert(row.getAs[Long]("code") > 0) + assert(row.getAs[String]("id").nonEmpty) + } + } + + it("read dbf") { + val shapefileDf = sparkSession.read + .format("shapefile") + .load(resourceFolder + "shapefiles/dbf") + val schema = shapefileDf.schema + assert(schema.find(_.name == "geometry").get.dataType == GeometryUDT) + assert(schema.find(_.name == "STATEFP").get.dataType == StringType) + assert(schema.find(_.name == "COUNTYFP").get.dataType == StringType) + assert(schema.find(_.name == "COUNTYNS").get.dataType == StringType) + assert(schema.find(_.name == "AFFGEOID").get.dataType == StringType) + assert(schema.find(_.name == "GEOID").get.dataType == StringType) + assert(schema.find(_.name == "NAME").get.dataType == StringType) + assert(schema.find(_.name == "LSAD").get.dataType == StringType) + assert(schema.find(_.name == 
"ALAND").get.dataType == LongType) + assert(schema.find(_.name == "AWATER").get.dataType == LongType) + assert(schema.length == 10) + assert(shapefileDf.count() == 3220) + + shapefileDf.collect().foreach { row => + val geom = row.getAs[Geometry]("geometry") + assert(geom.getSRID == 0) + assert(geom.isInstanceOf[Polygon] || geom.isInstanceOf[MultiPolygon]) + assert(row.getAs[String]("STATEFP").nonEmpty) + assert(row.getAs[String]("COUNTYFP").nonEmpty) + assert(row.getAs[String]("COUNTYNS").nonEmpty) + assert(row.getAs[String]("AFFGEOID").nonEmpty) + assert(row.getAs[String]("GEOID").nonEmpty) + assert(row.getAs[String]("NAME").nonEmpty) + assert(row.getAs[String]("LSAD").nonEmpty) + assert(row.getAs[Long]("ALAND") > 0) + assert(row.getAs[Long]("AWATER") >= 0) + } + } + + it("read multipleshapefiles") { + val shapefileDf = sparkSession.read + .format("shapefile") + .load(resourceFolder + "shapefiles/multipleshapefiles") + val schema = shapefileDf.schema + assert(schema.find(_.name == "geometry").get.dataType == GeometryUDT) + assert(schema.find(_.name == "STATEFP").get.dataType == StringType) + assert(schema.find(_.name == "COUNTYFP").get.dataType == StringType) + assert(schema.find(_.name == "COUNTYNS").get.dataType == StringType) + assert(schema.find(_.name == "AFFGEOID").get.dataType == StringType) + assert(schema.find(_.name == "GEOID").get.dataType == StringType) + assert(schema.find(_.name == "NAME").get.dataType == StringType) + assert(schema.find(_.name == "LSAD").get.dataType == StringType) + assert(schema.find(_.name == "ALAND").get.dataType == LongType) + assert(schema.find(_.name == "AWATER").get.dataType == LongType) + assert(schema.length == 10) + assert(shapefileDf.count() == 3220) + } + + it("read missing") { + val shapefileDf = sparkSession.read + .format("shapefile") + .load(resourceFolder + "shapefiles/missing") + val schema = shapefileDf.schema + assert(schema.find(_.name == "geometry").get.dataType == GeometryUDT) + assert(schema.find(_.name == "id").get.dataType == LongType) + assert(schema.find(_.name == "a").get.dataType == StringType) + assert(schema.find(_.name == "b").get.dataType == StringType) + assert(schema.find(_.name == "c").get.dataType == StringType) + assert(schema.find(_.name == "d").get.dataType == StringType) + assert(schema.find(_.name == "e").get.dataType == StringType) + assert(schema.length == 7) + val rows = shapefileDf.collect() + assert(rows.length == 3) + rows.foreach { row => + val a = row.getAs[String]("a") + val b = row.getAs[String]("b") + val c = row.getAs[String]("c") + val d = row.getAs[String]("d") + val e = row.getAs[String]("e") + if (a.isEmpty) { + assert(b == "First") + assert(c == "field") + assert(d == "is") + assert(e == "empty") + } else if (e.isEmpty) { + assert(a == "Last") + assert(b == "field") + assert(c == "is") + assert(d == "empty") + } else { + assert(a == "Are") + assert(b == "fields") + assert(c == "are") + assert(d == "not") + assert(e == "empty") + } + } + } + + it("read unsupported") { + val shapefileDf = sparkSession.read + .format("shapefile") + .load(resourceFolder + "shapefiles/unsupported") + val schema = shapefileDf.schema + assert(schema.find(_.name == "geometry").get.dataType == GeometryUDT) + val rows = shapefileDf.collect() + assert(rows.length == 10) + rows.foreach { row => + assert(row.getAs[Geometry]("geometry") == null) + assert(!row.isNullAt(row.fieldIndex("id"))) + } + } + + it("read bad_shx") { + var shapefileDf = sparkSession.read + .format("shapefile") + .load(resourceFolder + 
"shapefiles/bad_shx") + val schema = shapefileDf.schema + assert(schema.find(_.name == "geometry").get.dataType == GeometryUDT) + assert(schema.find(_.name == "field_1").get.dataType == LongType) + var rows = shapefileDf.collect() + assert(rows.length == 2) + rows.foreach { row => + val geom = row.getAs[Geometry]("geometry") + if (geom == null) { + assert(row.getAs[Long]("field_1") == 3) + } else { + assert(geom.isInstanceOf[Point]) + assert(row.getAs[Long]("field_1") == 2) + } + } + + // Copy the .shp and .dbf files to temporary location, and read the same shapefiles without .shx + FileUtils.cleanDirectory(new File(temporaryLocation)) + FileUtils.copyFile( + new File(resourceFolder + "shapefiles/bad_shx/bad_shx.shp"), + new File(temporaryLocation + "/bad_shx.shp")) + FileUtils.copyFile( + new File(resourceFolder + "shapefiles/bad_shx/bad_shx.dbf"), + new File(temporaryLocation + "/bad_shx.dbf")) + shapefileDf = sparkSession.read + .format("shapefile") + .load(temporaryLocation) + rows = shapefileDf.collect() + assert(rows.length == 2) + rows.foreach { row => + val geom = row.getAs[Geometry]("geometry") + if (geom == null) { + assert(row.getAs[Long]("field_1") == 3) + } else { + assert(geom.isInstanceOf[Point]) + assert(row.getAs[Long]("field_1") == 2) + } + } + } + + it("read contains_null_geom") { + val shapefileDf = sparkSession.read + .format("shapefile") + .load(resourceFolder + "shapefiles/contains_null_geom") + val schema = shapefileDf.schema + assert(schema.find(_.name == "geometry").get.dataType == GeometryUDT) + assert(schema.find(_.name == "fInt").get.dataType == LongType) + assert(schema.find(_.name == "fFloat").get.dataType.isInstanceOf[DecimalType]) + assert(schema.find(_.name == "fString").get.dataType == StringType) + assert(schema.length == 4) + val rows = shapefileDf.collect() + assert(rows.length == 10) + rows.foreach { row => + val fInt = row.getAs[Long]("fInt") + val fFloat = row.getAs[java.math.BigDecimal]("fFloat").doubleValue() + val fString = row.getAs[String]("fString") + val geom = row.getAs[Geometry]("geometry") + if (fInt == 2 || fInt == 5) { + assert(geom == null) + } else { + assert(geom.isInstanceOf[Point]) + assert(geom.getCoordinate.x == fInt) + assert(geom.getCoordinate.y == fInt) + } + assert(Math.abs(fFloat - 3.14159 * fInt) < 1e-4) + assert(fString == s"str_$fInt") + } + } + + it("read test_datatypes") { + val shapefileDf = sparkSession.read + .format("shapefile") + .load(resourceFolder + "shapefiles/datatypes") + val schema = shapefileDf.schema + assert(schema.find(_.name == "geometry").get.dataType == GeometryUDT) + assert(schema.find(_.name == "id").get.dataType == LongType) + assert(schema.find(_.name == "aInt").get.dataType == LongType) + assert(schema.find(_.name == "aUnicode").get.dataType == StringType) + assert(schema.find(_.name == "aDecimal").get.dataType.isInstanceOf[DecimalType]) + assert(schema.find(_.name == "aDecimal2").get.dataType.isInstanceOf[DecimalType]) + assert(schema.find(_.name == "aDate").get.dataType == DateType) + assert(schema.length == 7) + + val rows = shapefileDf.collect() + assert(rows.length == 9) + rows.foreach { row => + val geom = row.getAs[Geometry]("geometry") + assert(geom.isInstanceOf[Point]) + assert(geom.getSRID == 4269) + val idIndex = row.fieldIndex("id") + if (row.isNullAt(idIndex)) { + assert(row.isNullAt(row.fieldIndex("aInt"))) + assert(row.getAs[String]("aUnicode").isEmpty) + assert(row.isNullAt(row.fieldIndex("aDecimal"))) + assert(row.isNullAt(row.fieldIndex("aDecimal2"))) + 
assert(row.isNullAt(row.fieldIndex("aDate"))) + } else { + val id = row.getLong(idIndex) + assert(row.getAs[Long]("aInt") == id) + assert(row.getAs[String]("aUnicode") == s"测试$id") + if (id < 10) { + val decimal = row.getDecimal(row.fieldIndex("aDecimal")).doubleValue() + assert((decimal * 10).toInt == id * 10 + id) + assert(row.isNullAt(row.fieldIndex("aDecimal2"))) + assert(row.getAs[java.sql.Date]("aDate").toString == s"202$id-0$id-0$id") + } else { + assert(row.isNullAt(row.fieldIndex("aDecimal"))) + val decimal = row.getDecimal(row.fieldIndex("aDecimal2")).doubleValue() + assert((decimal * 100).toInt == id * 100 + id) + assert(row.isNullAt(row.fieldIndex("aDate"))) + } + } + } + } + + it("read with .shp path specified") { + val shapefileDf = sparkSession.read + .format("shapefile") + .load(resourceFolder + "shapefiles/datatypes/datatypes1.shp") + val schema = shapefileDf.schema + assert(schema.find(_.name == "geometry").get.dataType == GeometryUDT) + assert(schema.find(_.name == "id").get.dataType == LongType) + assert(schema.find(_.name == "aInt").get.dataType == LongType) + assert(schema.find(_.name == "aUnicode").get.dataType == StringType) + assert(schema.find(_.name == "aDecimal").get.dataType.isInstanceOf[DecimalType]) + assert(schema.find(_.name == "aDate").get.dataType == DateType) + assert(schema.length == 6) + + val rows = shapefileDf.collect() + assert(rows.length == 5) + rows.foreach { row => + assert(row.getAs[Geometry]("geometry").isInstanceOf[Point]) + val idIndex = row.fieldIndex("id") + if (row.isNullAt(idIndex)) { + assert(row.isNullAt(row.fieldIndex("aInt"))) + assert(row.getAs[String]("aUnicode").isEmpty) + assert(row.isNullAt(row.fieldIndex("aDecimal"))) + assert(row.isNullAt(row.fieldIndex("aDate"))) + } else { + val id = row.getLong(idIndex) + assert(row.getAs[Long]("aInt") == id) + assert(row.getAs[String]("aUnicode") == s"测试$id") + val decimal = row.getDecimal(row.fieldIndex("aDecimal")).doubleValue() + assert((decimal * 10).toInt == id * 10 + id) + assert(row.getAs[java.sql.Date]("aDate").toString == s"202$id-0$id-0$id") + } + } + } + + it("read with glob path specified") { + val shapefileDf = sparkSession.read + .format("shapefile") + .load(resourceFolder + "shapefiles/datatypes/datatypes2.*") + val schema = shapefileDf.schema + assert(schema.find(_.name == "geometry").get.dataType == GeometryUDT) + assert(schema.find(_.name == "id").get.dataType == LongType) + assert(schema.find(_.name == "aInt").get.dataType == LongType) + assert(schema.find(_.name == "aUnicode").get.dataType == StringType) + assert(schema.find(_.name == "aDecimal2").get.dataType.isInstanceOf[DecimalType]) + assert(schema.length == 5) + + val rows = shapefileDf.collect() + assert(rows.length == 4) + rows.foreach { row => + assert(row.getAs[Geometry]("geometry").isInstanceOf[Point]) + val id = row.getAs[Long]("id") + assert(row.getAs[Long]("aInt") == id) + assert(row.getAs[String]("aUnicode") == s"测试$id") + val decimal = row.getDecimal(row.fieldIndex("aDecimal2")).doubleValue() + assert((decimal * 100).toInt == id * 100 + id) + } + } + + it("read without shx") { + FileUtils.cleanDirectory(new File(temporaryLocation)) + FileUtils.copyFile( + new File(resourceFolder + "shapefiles/gis_osm_pois_free_1/gis_osm_pois_free_1.shp"), + new File(temporaryLocation + "/gis_osm_pois_free_1.shp")) + FileUtils.copyFile( + new File(resourceFolder + "shapefiles/gis_osm_pois_free_1/gis_osm_pois_free_1.dbf"), + new File(temporaryLocation + "/gis_osm_pois_free_1.dbf")) + + val shapefileDf = sparkSession.read + 
.format("shapefile") + .load(temporaryLocation) + val rows = shapefileDf.collect() + assert(rows.length == 12873) + rows.foreach { row => + val geom = row.getAs[Geometry]("geometry") + assert(geom.isInstanceOf[Point]) + assert(geom.getSRID == 0) + assert(row.getAs[String]("osm_id").nonEmpty) + assert(row.getAs[Long]("code") > 0) + assert(row.getAs[String]("fclass").nonEmpty) + assert(row.getAs[String]("name") != null) + } + } + + it("read without dbf") { + FileUtils.cleanDirectory(new File(temporaryLocation)) + FileUtils.copyFile( + new File(resourceFolder + "shapefiles/gis_osm_pois_free_1/gis_osm_pois_free_1.shp"), + new File(temporaryLocation + "/gis_osm_pois_free_1.shp")) + val shapefileDf = sparkSession.read + .format("shapefile") + .load(temporaryLocation) + val schema = shapefileDf.schema + assert(schema.find(_.name == "geometry").get.dataType == GeometryUDT) + assert(schema.length == 1) + + val rows = shapefileDf.collect() + assert(rows.length == 12873) + rows.foreach { row => + val geom = row.getAs[Geometry]("geometry") + assert(geom.isInstanceOf[Point]) + } + } + + it("read without shp") { + FileUtils.cleanDirectory(new File(temporaryLocation)) + FileUtils.copyFile( + new File(resourceFolder + "shapefiles/gis_osm_pois_free_1/gis_osm_pois_free_1.dbf"), + new File(temporaryLocation + "/gis_osm_pois_free_1.dbf")) + FileUtils.copyFile( + new File(resourceFolder + "shapefiles/gis_osm_pois_free_1/gis_osm_pois_free_1.shx"), + new File(temporaryLocation + "/gis_osm_pois_free_1.shx")) + intercept[Exception] { + sparkSession.read + .format("shapefile") + .load(temporaryLocation) + .count() + } + + intercept[Exception] { + sparkSession.read + .format("shapefile") + .load(resourceFolder + "shapefiles/gis_osm_pois_free_1/gis_osm_pois_free_1.shx") + .count() + } + } + + it("read directory containing missing .shp files") { + FileUtils.cleanDirectory(new File(temporaryLocation)) + // Missing .shp file for datatypes1 + FileUtils.copyFile( + new File(resourceFolder + "shapefiles/datatypes/datatypes1.dbf"), + new File(temporaryLocation + "/datatypes1.dbf")) + FileUtils.copyFile( + new File(resourceFolder + "shapefiles/datatypes/datatypes2.shp"), + new File(temporaryLocation + "/datatypes2.shp")) + FileUtils.copyFile( + new File(resourceFolder + "shapefiles/datatypes/datatypes2.dbf"), + new File(temporaryLocation + "/datatypes2.dbf")) + FileUtils.copyFile( + new File(resourceFolder + "shapefiles/datatypes/datatypes2.cpg"), + new File(temporaryLocation + "/datatypes2.cpg")) + + val shapefileDf = sparkSession.read + .format("shapefile") + .load(temporaryLocation) + val rows = shapefileDf.collect() + assert(rows.length == 4) + rows.foreach { row => + assert(row.getAs[Geometry]("geometry").isInstanceOf[Point]) + val id = row.getAs[Long]("id") + assert(row.getAs[Long]("aInt") == id) + assert(row.getAs[String]("aUnicode") == s"测试$id") + val decimal = row.getDecimal(row.fieldIndex("aDecimal2")).doubleValue() + assert((decimal * 100).toInt == id * 100 + id) + } + } + + it("read partitioned directory") { + FileUtils.cleanDirectory(new File(temporaryLocation)) + Files.createDirectory(new File(temporaryLocation + "/part=1").toPath) + Files.createDirectory(new File(temporaryLocation + "/part=2").toPath) + FileUtils.copyFile( + new File(resourceFolder + "shapefiles/datatypes/datatypes1.shp"), + new File(temporaryLocation + "/part=1/datatypes1.shp")) + FileUtils.copyFile( + new File(resourceFolder + "shapefiles/datatypes/datatypes1.dbf"), + new File(temporaryLocation + "/part=1/datatypes1.dbf")) + 
FileUtils.copyFile( + new File(resourceFolder + "shapefiles/datatypes/datatypes1.cpg"), + new File(temporaryLocation + "/part=1/datatypes1.cpg")) + FileUtils.copyFile( + new File(resourceFolder + "shapefiles/datatypes/datatypes2.shp"), + new File(temporaryLocation + "/part=2/datatypes2.shp")) + FileUtils.copyFile( + new File(resourceFolder + "shapefiles/datatypes/datatypes2.dbf"), + new File(temporaryLocation + "/part=2/datatypes2.dbf")) + FileUtils.copyFile( + new File(resourceFolder + "shapefiles/datatypes/datatypes2.cpg"), + new File(temporaryLocation + "/part=2/datatypes2.cpg")) + + val shapefileDf = sparkSession.read + .format("shapefile") + .load(temporaryLocation) + .select("part", "id", "aInt", "aUnicode", "geometry") + var rows = shapefileDf.collect() + assert(rows.length == 9) + rows.foreach { row => + assert(row.getAs[Geometry]("geometry").isInstanceOf[Point]) + val id = row.getAs[Long]("id") + assert(row.getAs[Long]("aInt") == id) + if (id < 10) { + assert(row.getAs[Int]("part") == 1) + } else { + assert(row.getAs[Int]("part") == 2) + } + if (id > 0) { + assert(row.getAs[String]("aUnicode") == s"测试$id") + } + } + + // Using partition filters + rows = shapefileDf.where("part = 2").collect() + assert(rows.length == 4) + rows.foreach { row => + assert(row.getAs[Geometry]("geometry").isInstanceOf[Point]) + assert(row.getAs[Int]("part") == 2) + val id = row.getAs[Long]("id") + assert(id > 10) + assert(row.getAs[Long]("aInt") == id) + assert(row.getAs[String]("aUnicode") == s"测试$id") + } + } + + it("read with recursiveFileLookup") { + FileUtils.cleanDirectory(new File(temporaryLocation)) + Files.createDirectory(new File(temporaryLocation + "/part1").toPath) + Files.createDirectory(new File(temporaryLocation + "/part2").toPath) + FileUtils.copyFile( + new File(resourceFolder + "shapefiles/datatypes/datatypes1.shp"), + new File(temporaryLocation + "/part1/datatypes1.shp")) + FileUtils.copyFile( + new File(resourceFolder + "shapefiles/datatypes/datatypes1.dbf"), + new File(temporaryLocation + "/part1/datatypes1.dbf")) + FileUtils.copyFile( + new File(resourceFolder + "shapefiles/datatypes/datatypes1.cpg"), + new File(temporaryLocation + "/part1/datatypes1.cpg")) + FileUtils.copyFile( + new File(resourceFolder + "shapefiles/datatypes/datatypes2.shp"), + new File(temporaryLocation + "/part2/datatypes2.shp")) + FileUtils.copyFile( + new File(resourceFolder + "shapefiles/datatypes/datatypes2.dbf"), + new File(temporaryLocation + "/part2/datatypes2.dbf")) + FileUtils.copyFile( + new File(resourceFolder + "shapefiles/datatypes/datatypes2.cpg"), + new File(temporaryLocation + "/part2/datatypes2.cpg")) + + val shapefileDf = sparkSession.read + .format("shapefile") + .option("recursiveFileLookup", "true") + .load(temporaryLocation) + .select("id", "aInt", "aUnicode", "geometry") + val rows = shapefileDf.collect() + assert(rows.length == 9) + rows.foreach { row => + assert(row.getAs[Geometry]("geometry").isInstanceOf[Point]) + val id = row.getAs[Long]("id") + assert(row.getAs[Long]("aInt") == id) + if (id > 0) { + assert(row.getAs[String]("aUnicode") == s"测试$id") + } + } + } + + it("read with custom geometry column name") { + val shapefileDf = sparkSession.read + .format("shapefile") + .option("geometry.name", "geom") + .load(resourceFolder + "shapefiles/gis_osm_pois_free_1") + val schema = shapefileDf.schema + assert(schema.find(_.name == "geom").get.dataType == GeometryUDT) + assert(schema.find(_.name == "osm_id").get.dataType == StringType) + assert(schema.find(_.name == "code").get.dataType 
== LongType) + assert(schema.find(_.name == "fclass").get.dataType == StringType) + assert(schema.find(_.name == "name").get.dataType == StringType) + assert(schema.length == 5) + val rows = shapefileDf.collect() + assert(rows.length == 12873) + rows.foreach { row => + val geom = row.getAs[Geometry]("geom") + assert(geom.isInstanceOf[Point]) + assert(row.getAs[String]("osm_id").nonEmpty) + assert(row.getAs[Long]("code") > 0) + assert(row.getAs[String]("fclass").nonEmpty) + assert(row.getAs[String]("name") != null) + } + + val exception = intercept[Exception] { + sparkSession.read + .format("shapefile") + .option("geometry.name", "osm_id") + .load(resourceFolder + "shapefiles/gis_osm_pois_free_1") + } + assert( + exception.getMessage.contains( + "osm_id is reserved for geometry but appears in non-spatial attributes")) + } + + it("read with shape key column") { + val shapefileDf = sparkSession.read + .format("shapefile") + .option("key.name", "fid") + .load(resourceFolder + "shapefiles/datatypes") + .select("id", "fid", "geometry", "aUnicode") + val schema = shapefileDf.schema + assert(schema.find(_.name == "geometry").get.dataType == GeometryUDT) + assert(schema.find(_.name == "id").get.dataType == LongType) + assert(schema.find(_.name == "fid").get.dataType == LongType) + assert(schema.find(_.name == "aUnicode").get.dataType == StringType) + val rows = shapefileDf.collect() + assert(rows.length == 9) + rows.foreach { row => + val geom = row.getAs[Geometry]("geometry") + assert(geom.isInstanceOf[Point]) + val id = row.getAs[Long]("id") + if (id > 0) { + assert(row.getAs[Long]("fid") == id % 10) + assert(row.getAs[String]("aUnicode") == s"测试$id") + } else { + assert(row.getAs[Long]("fid") == 5) + } + } + } + + it("read with both custom geometry column and shape key column") { + val shapefileDf = sparkSession.read + .format("shapefile") + .option("geometry.name", "g") + .option("key.name", "fid") + .load(resourceFolder + "shapefiles/datatypes") + .select("id", "fid", "g", "aUnicode") + val schema = shapefileDf.schema + assert(schema.find(_.name == "g").get.dataType == GeometryUDT) + assert(schema.find(_.name == "id").get.dataType == LongType) + assert(schema.find(_.name == "fid").get.dataType == LongType) + assert(schema.find(_.name == "aUnicode").get.dataType == StringType) + val rows = shapefileDf.collect() + assert(rows.length == 9) + rows.foreach { row => + val geom = row.getAs[Geometry]("g") + assert(geom.isInstanceOf[Point]) + val id = row.getAs[Long]("id") + if (id > 0) { + assert(row.getAs[Long]("fid") == id % 10) + assert(row.getAs[String]("aUnicode") == s"测试$id") + } else { + assert(row.getAs[Long]("fid") == 5) + } + } + } + + it("read with invalid shape key column") { + val exception = intercept[Exception] { + sparkSession.read + .format("shapefile") + .option("geometry.name", "g") + .option("key.name", "aDate") + .load(resourceFolder + "shapefiles/datatypes") + } + assert( + exception.getMessage.contains( + "aDate is reserved for shape key but appears in non-spatial attributes")) + + val exception2 = intercept[Exception] { + sparkSession.read + .format("shapefile") + .option("geometry.name", "g") + .option("key.name", "g") + .load(resourceFolder + "shapefiles/datatypes") + } + assert(exception2.getMessage.contains("geometry.name and key.name cannot be the same")) + } + + it("read with custom charset") { + FileUtils.cleanDirectory(new File(temporaryLocation)) + FileUtils.copyFile( + new File(resourceFolder + "shapefiles/datatypes/datatypes2.shp"), + new File(temporaryLocation + 
"/datatypes2.shp")) + FileUtils.copyFile( + new File(resourceFolder + "shapefiles/datatypes/datatypes2.dbf"), + new File(temporaryLocation + "/datatypes2.dbf")) + + val shapefileDf = sparkSession.read + .format("shapefile") + .option("charset", "GB2312") + .load(temporaryLocation) + val rows = shapefileDf.collect() + assert(rows.length == 4) + rows.foreach { row => + assert(row.getAs[Geometry]("geometry").isInstanceOf[Point]) + val id = row.getAs[Long]("id") + assert(row.getAs[Long]("aInt") == id) + assert(row.getAs[String]("aUnicode") == s"测试$id") + val decimal = row.getDecimal(row.fieldIndex("aDecimal2")).doubleValue() + assert((decimal * 100).toInt == id * 100 + id) + } + } + + it("read with custom schema") { + val customSchema = StructType( + Seq( + StructField("osm_id", StringType), + StructField("code2", LongType), + StructField("geometry", GeometryUDT()))) + val shapefileDf = sparkSession.read + .format("shapefile") + .schema(customSchema) + .load(resourceFolder + "shapefiles/gis_osm_pois_free_1") + assert(shapefileDf.schema == customSchema) + val rows = shapefileDf.collect() + assert(rows.length == 12873) + rows.foreach { row => + val geom = row.getAs[Geometry]("geometry") + assert(geom.isInstanceOf[Point]) + assert(row.getAs[String]("osm_id").nonEmpty) + assert(row.isNullAt(row.fieldIndex("code2"))) + } + } + + it("should read shapes of various types") { + // There are multiple directories under shapefiles/shapetypes, each containing a shapefile. + // We'll iterate over each directory and read the shapefile within it. + val shapeTypesDir = new File(resourceFolder + "shapefiles/shapetypes") + val shapeTypeDirs = shapeTypesDir.listFiles().filter(_.isDirectory) + shapeTypeDirs.foreach { shapeTypeDir => + val fileName = shapeTypeDir.getName + val hasZ = fileName.endsWith("zm") || fileName.endsWith("z") + val hasM = fileName.endsWith("zm") || fileName.endsWith("m") + val shapeType = + if (fileName.startsWith("point")) "POINT" + else if (fileName.startsWith("linestring")) "LINESTRING" + else if (fileName.startsWith("multipoint")) "MULTIPOINT" + else "POLYGON" + val expectedWktPrefix = + if (!hasZ && !hasM) shapeType + else { + shapeType + " " + (if (hasZ) "Z" else "") + (if (hasM) "M" else "") + } + + val shapefileDf = sparkSession.read + .format("shapefile") + .load(shapeTypeDir.getAbsolutePath) + val schema = shapefileDf.schema + assert(schema.find(_.name == "geometry").get.dataType == GeometryUDT) + val rows = shapefileDf.collect() + assert(rows.length > 0) + + // Validate the geometry type and WKT prefix + val wktWriter = new WKTWriter(4) + val rowsMap = mutable.Map[String, Geometry]() + rows.foreach { row => + val id = row.getAs[String]("id") + val geom = row.getAs[Geometry]("geometry") + val wkt = wktWriter.write(geom) + assert(wkt.startsWith(expectedWktPrefix)) + assert(geom != null) + rowsMap.put(id, geom) + } + + // Validate the geometry values by reading the CSV file containing the same data + val csvDf = sparkSession.read + .format("csv") + .option("header", "true") + .load(shapeTypeDir.getAbsolutePath + "/*.csv") + val wktReader = new WKTReader() + csvDf.collect().foreach { row => + val id = row.getAs[String]("id") + val wkt = row.getAs[String]("wkt") + val geom = wktReader.read(wkt) + assert(rowsMap(id).equals(geom)) + } + } + } + } +} diff --git a/spark/spark-4.1/src/test/scala/org/apache/sedona/sql/TestBaseScala.scala b/spark/spark-4.1/src/test/scala/org/apache/sedona/sql/TestBaseScala.scala new file mode 100644 index 00000000000..2e3e9742222 --- /dev/null +++ 
b/spark/spark-4.1/src/test/scala/org/apache/sedona/sql/TestBaseScala.scala @@ -0,0 +1,131 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.sedona.sql + +import io.minio.{MinioClient, PutObjectArgs} +import org.apache.log4j.{Level, Logger} +import org.apache.sedona.spark.SedonaContext +import org.apache.spark.sql.{DataFrame, SparkSession} +import org.scalatest.{BeforeAndAfterAll, FunSpec} +import org.testcontainers.containers.MinIOContainer + +import java.io.FileInputStream + +import java.util.concurrent.ThreadLocalRandom + +trait TestBaseScala extends FunSpec with BeforeAndAfterAll { + Logger.getRootLogger().setLevel(Level.WARN) + Logger.getLogger("org.apache").setLevel(Level.WARN) + Logger.getLogger("com").setLevel(Level.WARN) + Logger.getLogger("akka").setLevel(Level.WARN) + Logger.getLogger("org.apache.sedona.core").setLevel(Level.WARN) + + val keyParserExtension = "spark.sedona.enableParserExtension" + val warehouseLocation = System.getProperty("user.dir") + "/target/" + val sparkSession = SedonaContext + .builder() + .master("local[*]") + .appName("sedonasqlScalaTest") + .config("spark.sql.warehouse.dir", warehouseLocation) + // We need to be explicit about broadcasting in tests. 
+ .config("sedona.join.autoBroadcastJoinThreshold", "-1") + .config("spark.sql.extensions", "org.apache.sedona.sql.SedonaSqlExtensions") + .config(keyParserExtension, ThreadLocalRandom.current().nextBoolean()) + // Disable Spark 4.1+ native geospatial functions that shadow Sedona's ST functions + .config("spark.sql.geospatial.enabled", "false") + .getOrCreate() + + val sparkSessionMinio = SedonaContext + .builder() + .master("local[*]") + .appName("sedonasqlScalaTest") + .config("spark.sql.warehouse.dir", warehouseLocation) + .config("spark.jars.packages", "org.apache.hadoop:hadoop-aws:3.3.0") + .config( + "spark.hadoop.fs.s3a.aws.credentials.provider", + "org.apache.hadoop.fs.s3a.SimpleAWSCredentialsProvider") + .config("spark.hadoop.fs.s3a.impl", "org.apache.hadoop.fs.s3a.S3AFileSystem") + .config("sedona.join.autoBroadcastJoinThreshold", "-1") + .getOrCreate() + + val resourceFolder = System.getProperty("user.dir") + "/../common/src/test/resources/" + + override def beforeAll(): Unit = { + SedonaContext.create(sparkSession) + } + + override def afterAll(): Unit = { + // SedonaSQLRegistrator.dropAll(spark) + // spark.stop + } + + def loadCsv(path: String): DataFrame = { + sparkSession.read.format("csv").option("delimiter", ",").option("header", "false").load(path) + } + + def withConf[T](conf: Map[String, String])(f: => T): T = { + val oldConf = conf.keys.map(key => key -> sparkSession.conf.getOption(key)) + conf.foreach { case (key, value) => sparkSession.conf.set(key, value) } + try { + f + } finally { + oldConf.foreach { case (key, value) => + value match { + case Some(v) => sparkSession.conf.set(key, v) + case None => sparkSession.conf.unset(key) + } + } + } + } + + def putFileIntoBucket( + bucketName: String, + key: String, + stream: FileInputStream, + client: MinioClient): Unit = { + val objectArguments = PutObjectArgs + .builder() + .bucket(bucketName) + .`object`(key) + .stream(stream, stream.available(), -1) + .build() + + client.putObject(objectArguments) + } + + def createMinioClient(container: MinIOContainer): MinioClient = { + MinioClient + .builder() + .endpoint(container.getS3URL) + .credentials(container.getUserName, container.getPassword) + .build() + } + + def adjustSparkSession(sparkSession: SparkSession, container: MinIOContainer): Unit = { + sparkSession.sparkContext.hadoopConfiguration.set("fs.s3a.endpoint", container.getS3URL) + sparkSession.sparkContext.hadoopConfiguration.set("fs.s3a.access.key", container.getUserName) + sparkSession.sparkContext.hadoopConfiguration.set("fs.s3a.secret.key", container.getPassword) + sparkSession.sparkContext.hadoopConfiguration.set("fs.s3a.connection.timeout", "2000") + + sparkSession.sparkContext.hadoopConfiguration.set("spark.sql.debug.maxToStringFields", "100") + sparkSession.sparkContext.hadoopConfiguration.set("fs.s3a.path.style.access", "true") + sparkSession.sparkContext.hadoopConfiguration + .set("fs.s3a.impl", "org.apache.hadoop.fs.s3a.S3AFileSystem") + } +} diff --git a/spark/spark-4.1/src/test/scala/org/apache/spark/sql/udf/StrategySuite.scala b/spark/spark-4.1/src/test/scala/org/apache/spark/sql/udf/StrategySuite.scala new file mode 100644 index 00000000000..8d41848de98 --- /dev/null +++ b/spark/spark-4.1/src/test/scala/org/apache/spark/sql/udf/StrategySuite.scala @@ -0,0 +1,59 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.spark.sql.udf + +import org.apache.sedona.sql.TestBaseScala +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.functions.col +import org.apache.spark.sql.udf.ScalarUDF.geoPandasScalaFunction +import org.locationtech.jts.io.WKTReader +import org.scalatest.matchers.should.Matchers + +class StrategySuite extends TestBaseScala with Matchers { + val wktReader = new WKTReader() + + val spark: SparkSession = { + sparkSession.sparkContext.setLogLevel("ALL") + sparkSession + } + + import spark.implicits._ + + it("sedona geospatial UDF") { + val df = Seq( + (1, "value", wktReader.read("POINT(21 52)")), + (2, "value1", wktReader.read("POINT(20 50)")), + (3, "value2", wktReader.read("POINT(20 49)")), + (4, "value3", wktReader.read("POINT(20 48)")), + (5, "value4", wktReader.read("POINT(20 47)"))) + .toDF("id", "value", "geom") + .withColumn("geom_buffer", geoPandasScalaFunction(col("geom"))) + + df.count shouldEqual 5 + + df.selectExpr("ST_AsText(ST_ReducePrecision(geom_buffer, 2))") + .as[String] + .collect() should contain theSameElementsAs Seq( + "POLYGON ((20 51, 20 53, 22 53, 22 51, 20 51))", + "POLYGON ((19 49, 19 51, 21 51, 21 49, 19 49))", + "POLYGON ((19 48, 19 50, 21 50, 21 48, 19 48))", + "POLYGON ((19 47, 19 49, 21 49, 21 47, 19 47))", + "POLYGON ((19 46, 19 48, 21 48, 21 46, 19 46))") + } +} diff --git a/spark/spark-4.1/src/test/scala/org/apache/spark/sql/udf/TestScalarPandasUDF.scala b/spark/spark-4.1/src/test/scala/org/apache/spark/sql/udf/TestScalarPandasUDF.scala new file mode 100644 index 00000000000..62733169288 --- /dev/null +++ b/spark/spark-4.1/src/test/scala/org/apache/spark/sql/udf/TestScalarPandasUDF.scala @@ -0,0 +1,122 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.spark.sql.udf + +import org.apache.sedona.sql.UDF +import org.apache.spark.TestUtils +import org.apache.spark.api.python._ +import org.apache.spark.broadcast.Broadcast +import org.apache.spark.sql.execution.python.UserDefinedPythonFunction +import org.apache.spark.sql.sedona_sql.UDT.GeometryUDT +import org.apache.spark.util.Utils + +import java.io.File +import java.nio.file.{Files, Paths} +import scala.sys.process.Process +import scala.jdk.CollectionConverters._ + +object ScalarUDF { + + val pythonExec: String = { + val pythonExec = + sys.env.getOrElse("PYSPARK_DRIVER_PYTHON", sys.env.getOrElse("PYSPARK_PYTHON", "python3")) + if (TestUtils.testCommandAvailable(pythonExec)) { + pythonExec + } else { + "python" + } + } + + private[spark] lazy val pythonPath = sys.env.getOrElse("PYTHONPATH", "") + protected lazy val sparkHome: String = { + sys.props.getOrElse("spark.test.home", sys.env("SPARK_HOME")) + } + + private lazy val py4jPath = + Paths.get(sparkHome, "python", "lib", PythonUtils.PY4J_ZIP_NAME).toAbsolutePath + private[spark] lazy val pysparkPythonPath = s"$py4jPath" + + private lazy val isPythonAvailable: Boolean = TestUtils.testCommandAvailable(pythonExec) + + lazy val pythonVer: String = if (isPythonAvailable) { + Process( + Seq(pythonExec, "-c", "import sys; print('%d.%d' % sys.version_info[:2])"), + None, + "PYTHONPATH" -> s"$pysparkPythonPath:$pythonPath").!!.trim() + } else { + throw new RuntimeException(s"Python executable [$pythonExec] is unavailable.") + } + + protected def withTempPath(f: File => Unit): Unit = { + val path = Utils.createTempDir() + path.delete() + try f(path) + finally Utils.deleteRecursively(path) + } + + val pandasFunc: Array[Byte] = { + var binaryPandasFunc: Array[Byte] = null + withTempPath { path => + println(path) + Process( + Seq( + pythonExec, + "-c", + f""" + |from pyspark.sql.types import IntegerType + |from shapely.geometry import Point + |from sedona.sql.types import GeometryType + |from pyspark.serializers import CloudPickleSerializer + |from sedona.utils import geometry_serde + |from shapely import box + |f = open('$path', 'wb'); + |def w(x): + | def apply_function(w): + | geom, offset = geometry_serde.deserialize(w) + | bounds = geom.buffer(1).bounds + | x = box(*bounds) + | return geometry_serde.serialize(x) + | return x.apply(apply_function) + |f.write(CloudPickleSerializer().dumps((w, GeometryType()))) + |""".stripMargin), + None, + "PYTHONPATH" -> s"$pysparkPythonPath:$pythonPath").!! + binaryPandasFunc = Files.readAllBytes(path.toPath) + } + assert(binaryPandasFunc != null) + binaryPandasFunc + } + + private val workerEnv = new java.util.HashMap[String, String]() + workerEnv.put("PYTHONPATH", s"$pysparkPythonPath:$pythonPath") + + val geoPandasScalaFunction: UserDefinedPythonFunction = UserDefinedPythonFunction( + name = "geospatial_udf", + func = SimplePythonFunction( + command = pandasFunc, + envVars = workerEnv.clone().asInstanceOf[java.util.Map[String, String]], + pythonIncludes = List.empty[String].asJava, + pythonExec = pythonExec, + pythonVer = pythonVer, + broadcastVars = List.empty[Broadcast[PythonBroadcast]].asJava, + accumulator = null), + dataType = GeometryUDT(), + pythonEvalType = UDF.PythonEvalType.SQL_SCALAR_SEDONA_UDF, + udfDeterministic = true) +}
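A minimal usage sketch for the ScalarUDF helper above, assuming the SedonaContext-enabled `sparkSession` from TestBaseScala and the `geoPandasScalaFunction` value defined in this file. It mirrors the column-level application exercised by StrategySuite; the column name `geom_box` and the sample points are illustrative only, not part of the patch:

    // Sketch: apply the pickled pandas UDF from Scala and inspect how it is planned.
    import org.apache.spark.sql.functions.col
    import org.apache.spark.sql.udf.ScalarUDF.geoPandasScalaFunction
    import org.locationtech.jts.io.WKTReader

    val spark = sparkSession // SedonaContext-enabled session from TestBaseScala
    import spark.implicits._

    val reader = new WKTReader()
    val df = Seq((1, reader.read("POINT (21 52)")), (2, reader.read("POINT (20 50)")))
      .toDF("id", "geom")
      // UserDefinedPythonFunction.apply returns a Column; the buffer-and-bounding-box
      // logic runs inside the pickled pandas function on a Python worker.
      .withColumn("geom_box", geoPandasScalaFunction(col("geom")))

    df.explain() // inspect how the registered Sedona extensions plan the Python evaluation
    df.selectExpr("id", "ST_AsText(geom_box)").show(false)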