From 004406949ca145f87e42aac2dee4fcd81da2669b Mon Sep 17 00:00:00 2001
From: Jia Yu
Date: Sat, 14 Feb 2026 23:08:59 -0800
Subject: [PATCH 1/2] [SEDONA-729] Add _metadata hidden column support for
 shapefile DataSource V2 reader

Implement SupportsMetadataColumns on ShapefileTable so that reading
shapefiles into a DataFrame exposes the standard _metadata hidden struct
containing file_path, file_name, file_size, file_block_start,
file_block_length, and file_modification_time.

Changes across all four Spark version modules (3.4, 3.5, 4.0, 4.1):

- ShapefileTable: mix in SupportsMetadataColumns, define the _metadata
  MetadataColumn with the standard six-field struct type
- ShapefileScanBuilder: override pruneColumns() to capture the pruned
  metadata schema requested by Spark's column pruning optimizer
- ShapefileScan: accept metadataSchema, override readSchema() to append
  metadata fields, pass schema to partition reader factory
- ShapefilePartitionReaderFactory: construct metadata values from the .shp
  PartitionedFile, wrap the base reader in PartitionReaderWithMetadata that
  joins data rows with the constant metadata values via an unsafe projection.
---
 .../ShapefilePartitionReaderFactory.scala     |  90 +++++++-
 .../datasources/shapefile/ShapefileScan.scala |  11 +
 .../shapefile/ShapefileScanBuilder.scala      |  23 +++
 .../shapefile/ShapefileTable.scala            |  55 ++++-
 .../apache/sedona/sql/ShapefileTests.scala    | 178 +++++++++++++++-
 .../ShapefilePartitionReaderFactory.scala     |  90 +++++++-
 .../datasources/shapefile/ShapefileScan.scala |  11 +
 .../shapefile/ShapefileScanBuilder.scala      |  23 +++
 .../shapefile/ShapefileTable.scala            |  55 ++++-
 .../apache/sedona/sql/ShapefileTests.scala    | 192 ++++++++++++++++--
 .../ShapefilePartitionReaderFactory.scala     |  90 +++++++-
 .../datasources/shapefile/ShapefileScan.scala |  11 +
 .../shapefile/ShapefileScanBuilder.scala      |  23 +++
 .../shapefile/ShapefileTable.scala            |  55 ++++-
 .../apache/sedona/sql/ShapefileTests.scala    | 192 ++++++++++++++++--
 .../ShapefilePartitionReaderFactory.scala     |  90 +++++++-
 .../datasources/shapefile/ShapefileScan.scala |  11 +
 .../shapefile/ShapefileScanBuilder.scala      |  23 +++
 .../shapefile/ShapefileTable.scala            |  55 ++++-
 .../apache/sedona/sql/ShapefileTests.scala    | 192 ++++++++++++++++--
 20 files changed, 1411 insertions(+), 59 deletions(-)

diff --git a/spark/spark-3.4/src/main/scala/org/apache/sedona/sql/datasources/shapefile/ShapefilePartitionReaderFactory.scala b/spark/spark-3.4/src/main/scala/org/apache/sedona/sql/datasources/shapefile/ShapefilePartitionReaderFactory.scala
index 5a28af6d66..79c0638bd6 100644
--- a/spark/spark-3.4/src/main/scala/org/apache/sedona/sql/datasources/shapefile/ShapefilePartitionReaderFactory.scala
+++ b/spark/spark-3.4/src/main/scala/org/apache/sedona/sql/datasources/shapefile/ShapefilePartitionReaderFactory.scala
@@ -18,8 +18,11 @@
  */
 package org.apache.sedona.sql.datasources.shapefile
 
+import org.apache.hadoop.fs.Path
 import org.apache.spark.broadcast.Broadcast
 import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.{BoundReference, JoinedRow}
+import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection
 import org.apache.spark.sql.connector.read.InputPartition
 import org.apache.spark.sql.connector.read.PartitionReader
 import org.apache.spark.sql.connector.read.PartitionReaderFactory
@@ -28,14 +31,19 @@ import 
org.apache.spark.sql.execution.datasources.v2.PartitionReaderWithPartitio import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.sources.Filter import org.apache.spark.sql.types.StructType +import org.apache.spark.unsafe.types.UTF8String import org.apache.spark.util.SerializableConfiguration +import java.util.Locale + case class ShapefilePartitionReaderFactory( sqlConf: SQLConf, broadcastedConf: Broadcast[SerializableConfiguration], dataSchema: StructType, readDataSchema: StructType, partitionSchema: StructType, + /** The metadata fields requested by the query (e.g., fields from `_metadata`). */ + metadataSchema: StructType, options: ShapefileReadOptions, filters: Seq[Filter]) extends PartitionReaderFactory { @@ -48,11 +56,51 @@ case class ShapefilePartitionReaderFactory( partitionedFiles, readDataSchema, options) - new PartitionReaderWithPartitionValues( + val withPartitionValues = new PartitionReaderWithPartitionValues( fileReader, readDataSchema, partitionSchema, partitionedFiles.head.partitionValues) + + if (metadataSchema.nonEmpty) { + // Build metadata values from the .shp file's partition information. + // We use the .shp file because it is the primary shapefile component and its path + // is what users would expect to see in _metadata.file_path / _metadata.file_name. + val shpFile = partitionedFiles + .find(_.filePath.toPath.getName.toLowerCase(Locale.ROOT).endsWith(".shp")) + .getOrElse(partitionedFiles.head) + val filePath = shpFile.filePath.toString + val fileName = new Path(filePath).getName + + // Complete map of all metadata field values keyed by field name. + // The modificationTime from PartitionedFile is in milliseconds but Spark's + // TimestampType uses microseconds, so we multiply by 1000. + val allMetadataValues: Map[String, Any] = Map( + "file_path" -> UTF8String.fromString(filePath), + "file_name" -> UTF8String.fromString(fileName), + "file_size" -> shpFile.fileSize, + "file_block_start" -> shpFile.start, + "file_block_length" -> shpFile.length, + "file_modification_time" -> (shpFile.modificationTime * 1000L)) + + // The metadataSchema may be pruned by Spark's column pruning (e.g., when the query + // only selects `_metadata.file_name`). We must construct the inner struct to match + // the pruned schema exactly, otherwise field ordinals will be misaligned. + val innerStructType = metadataSchema.fields.head.dataType.asInstanceOf[StructType] + val prunedValues = innerStructType.fields.map(f => allMetadataValues(f.name)) + val metadataStruct = InternalRow.fromSeq(prunedValues.toSeq) + + // Wrap the struct in an outer row since _metadata is a single StructType column + val metadataRow = InternalRow.fromSeq(Seq(metadataStruct)) + val baseSchema = StructType(readDataSchema.fields ++ partitionSchema.fields) + new PartitionReaderWithMetadata( + withPartitionValues, + baseSchema, + metadataSchema, + metadataRow) + } else { + withPartitionValues + } } override def createReader(partition: InputPartition): PartitionReader[InternalRow] = { @@ -64,3 +112,43 @@ case class ShapefilePartitionReaderFactory( } } } + +/** + * Wraps a partition reader to append metadata column values to each row. This follows the same + * pattern as [[PartitionReaderWithPartitionValues]] but for metadata columns: it uses a + * [[JoinedRow]] to concatenate the base row (data + partition values) with the metadata row, then + * projects the combined row through an + * [[org.apache.spark.sql.catalyst.expressions.UnsafeProjection]] to produce a compact unsafe row. 
+ * + * @param reader + * the underlying reader that produces data + partition value rows + * @param baseSchema + * the combined schema of data columns and partition columns + * @param metadataSchema + * the schema of the metadata columns being appended + * @param metadataValues + * the constant metadata values to append to every row + */ +private[shapefile] class PartitionReaderWithMetadata( + reader: PartitionReader[InternalRow], + baseSchema: StructType, + metadataSchema: StructType, + metadataValues: InternalRow) + extends PartitionReader[InternalRow] { + + private val joinedRow = new JoinedRow() + private val unsafeProjection = + GenerateUnsafeProjection.generate(baseSchema.fields.zipWithIndex.map { case (f, i) => + BoundReference(i, f.dataType, f.nullable) + } ++ metadataSchema.fields.zipWithIndex.map { case (f, i) => + BoundReference(baseSchema.length + i, f.dataType, f.nullable) + }) + + override def next(): Boolean = reader.next() + + override def get(): InternalRow = { + unsafeProjection(joinedRow(reader.get(), metadataValues)) + } + + override def close(): Unit = reader.close() +} diff --git a/spark/spark-3.4/src/main/scala/org/apache/sedona/sql/datasources/shapefile/ShapefileScan.scala b/spark/spark-3.4/src/main/scala/org/apache/sedona/sql/datasources/shapefile/ShapefileScan.scala index afdface894..3f6a9224aa 100644 --- a/spark/spark-3.4/src/main/scala/org/apache/sedona/sql/datasources/shapefile/ShapefileScan.scala +++ b/spark/spark-3.4/src/main/scala/org/apache/sedona/sql/datasources/shapefile/ShapefileScan.scala @@ -44,12 +44,22 @@ case class ShapefileScan( dataSchema: StructType, readDataSchema: StructType, readPartitionSchema: StructType, + /** The metadata fields requested by the query (e.g., fields from `_metadata`). */ + metadataSchema: StructType, options: CaseInsensitiveStringMap, pushedFilters: Array[Filter], partitionFilters: Seq[Expression] = Seq.empty, dataFilters: Seq[Expression] = Seq.empty) extends FileScan { + /** + * Returns the complete read schema including data columns, partition columns, and any requested + * metadata columns. Metadata columns are appended last so the reader factory can construct a + * [[JoinedRow]] that appends metadata values after data and partition values. + */ + override def readSchema(): StructType = + StructType(readDataSchema.fields ++ readPartitionSchema.fields ++ metadataSchema.fields) + override def createReaderFactory(): PartitionReaderFactory = { val caseSensitiveMap = options.asScala.toMap val hadoopConf = sparkSession.sessionState.newHadoopConfWithOptions(caseSensitiveMap) @@ -61,6 +71,7 @@ case class ShapefileScan( dataSchema, readDataSchema, readPartitionSchema, + metadataSchema, ShapefileReadOptions.parse(options), pushedFilters) } diff --git a/spark/spark-3.4/src/main/scala/org/apache/sedona/sql/datasources/shapefile/ShapefileScanBuilder.scala b/spark/spark-3.4/src/main/scala/org/apache/sedona/sql/datasources/shapefile/ShapefileScanBuilder.scala index e5135e381d..48b5e45d53 100644 --- a/spark/spark-3.4/src/main/scala/org/apache/sedona/sql/datasources/shapefile/ShapefileScanBuilder.scala +++ b/spark/spark-3.4/src/main/scala/org/apache/sedona/sql/datasources/shapefile/ShapefileScanBuilder.scala @@ -33,6 +33,28 @@ case class ShapefileScanBuilder( options: CaseInsensitiveStringMap) extends FileScanBuilder(sparkSession, fileIndex, dataSchema) { + /** + * Tracks any metadata fields (e.g., from `_metadata`) requested in the query. Populated by + * [[pruneColumns]] when Spark pushes down column projections. 
+ */ + private var _requiredMetadataSchema: StructType = StructType(Seq.empty) + + /** + * Intercepts Spark's column pruning to separate metadata columns from data/partition columns. + * Fields in [[requiredSchema]] that do not belong to the data schema or partition schema are + * assumed to be metadata fields (e.g., `_metadata`). These are captured in + * [[_requiredMetadataSchema]] so the scan can include them in the output. + */ + override def pruneColumns(requiredSchema: StructType): Unit = { + val resolver = sparkSession.sessionState.conf.resolver + val metaFields = requiredSchema.fields.filter { field => + !dataSchema.fields.exists(df => resolver(df.name, field.name)) && + !fileIndex.partitionSchema.fields.exists(pf => resolver(pf.name, field.name)) + } + _requiredMetadataSchema = StructType(metaFields) + super.pruneColumns(requiredSchema) + } + override def build(): Scan = { ShapefileScan( sparkSession, @@ -40,6 +62,7 @@ case class ShapefileScanBuilder( dataSchema, readDataSchema(), readPartitionSchema(), + _requiredMetadataSchema, options, pushedDataFilters, partitionFilters, diff --git a/spark/spark-3.4/src/main/scala/org/apache/sedona/sql/datasources/shapefile/ShapefileTable.scala b/spark/spark-3.4/src/main/scala/org/apache/sedona/sql/datasources/shapefile/ShapefileTable.scala index 7db6bb8d1f..1f4bd093b9 100644 --- a/spark/spark-3.4/src/main/scala/org/apache/sedona/sql/datasources/shapefile/ShapefileTable.scala +++ b/spark/spark-3.4/src/main/scala/org/apache/sedona/sql/datasources/shapefile/ShapefileTable.scala @@ -21,19 +21,27 @@ package org.apache.sedona.sql.datasources.shapefile import org.apache.hadoop.fs.FileStatus import org.apache.sedona.core.formatMapper.shapefileParser.parseUtils.dbf.DbfParseUtil import org.apache.spark.sql.SparkSession -import org.apache.spark.sql.connector.catalog.TableCapability +import org.apache.spark.sql.connector.catalog.{MetadataColumn, SupportsMetadataColumns, TableCapability} import org.apache.spark.sql.connector.read.ScanBuilder import org.apache.spark.sql.connector.write.{LogicalWriteInfo, WriteBuilder} import org.apache.spark.sql.execution.datasources.FileFormat import org.apache.sedona.sql.datasources.shapefile.ShapefileUtils.{baseSchema, fieldDescriptorsToSchema, mergeSchemas} import org.apache.spark.sql.execution.datasources.v2.FileTable -import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.types.{DataType, LongType, StringType, StructField, StructType, TimestampType} import org.apache.spark.sql.util.CaseInsensitiveStringMap import org.apache.spark.util.SerializableConfiguration import java.util.Locale import scala.collection.JavaConverters._ +/** + * A Spark DataSource V2 table implementation for reading Shapefiles. + * + * Extends [[FileTable]] to leverage Spark's file-based scan infrastructure and implements + * [[SupportsMetadataColumns]] to expose hidden metadata columns (e.g., `_metadata`) that provide + * file-level information such as path, name, size, and modification time. These metadata columns + * are not part of the user-visible schema but can be explicitly selected in queries. 
+ */ case class ShapefileTable( name: String, sparkSession: SparkSession, @@ -41,7 +49,8 @@ case class ShapefileTable( paths: Seq[String], userSpecifiedSchema: Option[StructType], fallbackFileFormat: Class[_ <: FileFormat]) - extends FileTable(sparkSession, options, paths, userSpecifiedSchema) { + extends FileTable(sparkSession, options, paths, userSpecifiedSchema) + with SupportsMetadataColumns { override def formatName: String = "Shapefile" @@ -95,9 +104,49 @@ case class ShapefileTable( } } + /** Returns the metadata columns that this table exposes as hidden columns. */ + override def metadataColumns(): Array[MetadataColumn] = ShapefileTable.fileMetadataColumns + override def newScanBuilder(options: CaseInsensitiveStringMap): ScanBuilder = { ShapefileScanBuilder(sparkSession, fileIndex, schema, dataSchema, options) } override def newWriteBuilder(info: LogicalWriteInfo): WriteBuilder = null } + +object ShapefileTable { + + /** + * Schema of the `_metadata` struct column exposed by [[SupportsMetadataColumns]]. Each field + * provides file-level information about the source shapefile: + * + * - `file_path`: The fully qualified path of the `.shp` file (e.g., + * `hdfs://host/data/file.shp`). + * - `file_name`: The name of the `.shp` file without directory components (e.g., `file.shp`). + * - `file_size`: The total size of the `.shp` file in bytes. + * - `file_block_start`: The byte offset within the file where this partition's data begins. + * For non-splittable formats this is typically 0. + * - `file_block_length`: The number of bytes in this partition's data block. For + * non-splittable formats this equals the file size. + * - `file_modification_time`: The last modification timestamp of the `.shp` file. + */ + private val FILE_METADATA_STRUCT_TYPE: StructType = StructType( + Seq( + StructField("file_path", StringType, nullable = false), + StructField("file_name", StringType, nullable = false), + StructField("file_size", LongType, nullable = false), + StructField("file_block_start", LongType, nullable = false), + StructField("file_block_length", LongType, nullable = false), + StructField("file_modification_time", TimestampType, nullable = false))) + + /** + * The single metadata column `_metadata` exposed to Spark's catalog. This hidden column can be + * selected in queries (e.g., `SELECT _metadata.file_name FROM shapefile.`...``) but does not + * appear in `SELECT *`. 
+ */ + private[shapefile] val fileMetadataColumns: Array[MetadataColumn] = Array(new MetadataColumn { + override def name: String = "_metadata" + override def dataType: DataType = FILE_METADATA_STRUCT_TYPE + override def isNullable: Boolean = false + }) +} diff --git a/spark/spark-3.4/src/test/scala/org/apache/sedona/sql/ShapefileTests.scala b/spark/spark-3.4/src/test/scala/org/apache/sedona/sql/ShapefileTests.scala index c2e88c469b..531eaf65f0 100644 --- a/spark/spark-3.4/src/test/scala/org/apache/sedona/sql/ShapefileTests.scala +++ b/spark/spark-3.4/src/test/scala/org/apache/sedona/sql/ShapefileTests.scala @@ -20,7 +20,8 @@ package org.apache.sedona.sql import org.apache.commons.io.FileUtils import org.apache.spark.sql.sedona_sql.UDT.GeometryUDT -import org.apache.spark.sql.types.{DateType, DecimalType, LongType, StringType, StructField, StructType} +import org.apache.spark.sql.Row +import org.apache.spark.sql.types.{DateType, DecimalType, LongType, StringType, StructField, StructType, TimestampType} import org.locationtech.jts.geom.{Geometry, MultiPolygon, Point, Polygon} import org.locationtech.jts.io.{WKTReader, WKTWriter} import org.scalatest.BeforeAndAfterAll @@ -768,5 +769,180 @@ class ShapefileTests extends TestBaseScala with BeforeAndAfterAll { } } } + + it("should expose _metadata struct with all expected fields") { + val df = sparkSession.read + .format("shapefile") + .load(resourceFolder + "shapefiles/gis_osm_pois_free_1") + val metaDf = df.select("_metadata") + val metaSchema = metaDf.schema("_metadata").dataType.asInstanceOf[StructType] + val expectedFields = + Seq( + "file_path", + "file_name", + "file_size", + "file_block_start", + "file_block_length", + "file_modification_time") + assert(metaSchema.fieldNames.toSeq == expectedFields) + assert(metaSchema("file_path").dataType == StringType) + assert(metaSchema("file_name").dataType == StringType) + assert(metaSchema("file_size").dataType == LongType) + assert(metaSchema("file_block_start").dataType == LongType) + assert(metaSchema("file_block_length").dataType == LongType) + assert(metaSchema("file_modification_time").dataType == TimestampType) + } + + it("should not include _metadata in select(*)") { + val df = sparkSession.read + .format("shapefile") + .load(resourceFolder + "shapefiles/gis_osm_pois_free_1") + val columns = df.columns + assert(!columns.contains("_metadata")) + } + + it("should return correct file_path and file_name in _metadata") { + val df = sparkSession.read + .format("shapefile") + .load(resourceFolder + "shapefiles/gis_osm_pois_free_1") + val metaRows = df.select("_metadata.file_path", "_metadata.file_name").distinct().collect() + assert(metaRows.length == 1) + val filePath = metaRows.head.getString(0) + val fileName = metaRows.head.getString(1) + assert(filePath.endsWith("gis_osm_pois_free_1.shp")) + assert(fileName == "gis_osm_pois_free_1.shp") + } + + it("should return actual file_size matching the .shp file on disk") { + val shpFile = + new File(resourceFolder + "shapefiles/gis_osm_pois_free_1/gis_osm_pois_free_1.shp") + val expectedSize = shpFile.length() + val df = sparkSession.read + .format("shapefile") + .load(resourceFolder + "shapefiles/gis_osm_pois_free_1") + val sizes = df.select("_metadata.file_size").distinct().collect() + assert(sizes.length == 1) + assert(sizes.head.getLong(0) == expectedSize) + } + + it( + "should return file_block_start=0 and file_block_length=file_size for non-splittable shapefiles") { + val shpFile = + new File(resourceFolder + 
"shapefiles/gis_osm_pois_free_1/gis_osm_pois_free_1.shp") + val expectedSize = shpFile.length() + val df = sparkSession.read + .format("shapefile") + .load(resourceFolder + "shapefiles/gis_osm_pois_free_1") + val rows = df + .select("_metadata.file_block_start", "_metadata.file_block_length") + .distinct() + .collect() + assert(rows.length == 1) + assert(rows.head.getLong(0) == 0L) // file_block_start + assert(rows.head.getLong(1) == expectedSize) // file_block_length + } + + it("should return file_modification_time matching the .shp file on disk") { + val shpFile = + new File(resourceFolder + "shapefiles/gis_osm_pois_free_1/gis_osm_pois_free_1.shp") + // File.lastModified() returns milliseconds, Spark TimestampType stores microseconds + val expectedModTimeMs = shpFile.lastModified() + val df = sparkSession.read + .format("shapefile") + .load(resourceFolder + "shapefiles/gis_osm_pois_free_1") + val times = + df.select("_metadata.file_modification_time").distinct().collect() + assert(times.length == 1) + val modTime = times.head.getTimestamp(0) + assert(modTime != null) + // Timestamp.getTime() returns milliseconds + assert(modTime.getTime == expectedModTimeMs) + } + + it("should return correct metadata values per file when reading multiple shapefiles") { + val map1Shp = + new File(resourceFolder + "shapefiles/multipleshapefiles/map1.shp") + val map2Shp = + new File(resourceFolder + "shapefiles/multipleshapefiles/map2.shp") + val df = sparkSession.read + .format("shapefile") + .load(resourceFolder + "shapefiles/multipleshapefiles") + val metaRows = df + .select( + "_metadata.file_name", + "_metadata.file_size", + "_metadata.file_block_start", + "_metadata.file_block_length") + .distinct() + .collect() + assert(metaRows.length == 2) + val byName = metaRows.map(r => r.getString(0) -> r).toMap + // map1.shp + assert(byName("map1.shp").getLong(1) == map1Shp.length()) + assert(byName("map1.shp").getLong(2) == 0L) + assert(byName("map1.shp").getLong(3) == map1Shp.length()) + // map2.shp + assert(byName("map2.shp").getLong(1) == map2Shp.length()) + assert(byName("map2.shp").getLong(2) == 0L) + assert(byName("map2.shp").getLong(3) == map2Shp.length()) + } + + it("should allow filtering on _metadata fields") { + val df = sparkSession.read + .format("shapefile") + .load(resourceFolder + "shapefiles/multipleshapefiles") + val totalCount = df.count() + val map1Df = df.filter(df("_metadata.file_name") === "map1.shp") + val map2Df = df.filter(df("_metadata.file_name") === "map2.shp") + assert(map1Df.count() > 0) + assert(map2Df.count() > 0) + assert(map1Df.count() + map2Df.count() == totalCount) + } + + it("should select _metadata along with data columns") { + val df = sparkSession.read + .format("shapefile") + .load(resourceFolder + "shapefiles/gis_osm_pois_free_1") + val result = df.select("osm_id", "_metadata.file_name").collect() + assert(result.length == 12873) + result.foreach { row => + assert(row.getString(0).nonEmpty) + assert(row.getString(1) == "gis_osm_pois_free_1.shp") + } + } + + it("should return correct metadata for each file in multi-shapefile directory") { + val dt1Shp = new File(resourceFolder + "shapefiles/datatypes/datatypes1.shp") + val dt2Shp = new File(resourceFolder + "shapefiles/datatypes/datatypes2.shp") + val df = sparkSession.read + .format("shapefile") + .load(resourceFolder + "shapefiles/datatypes") + val result = df + .select( + "_metadata.file_path", + "_metadata.file_name", + "_metadata.file_size", + "_metadata.file_block_start", + "_metadata.file_block_length", + 
"_metadata.file_modification_time") + .distinct() + .collect() + assert(result.length == 2) + val byName = result.map(r => r.getString(1) -> r).toMap + // datatypes1.shp + val r1 = byName("datatypes1.shp") + assert(r1.getString(0).endsWith("datatypes1.shp")) + assert(r1.getLong(2) == dt1Shp.length()) + assert(r1.getLong(3) == 0L) + assert(r1.getLong(4) == dt1Shp.length()) + assert(r1.getTimestamp(5).getTime == dt1Shp.lastModified()) + // datatypes2.shp + val r2 = byName("datatypes2.shp") + assert(r2.getString(0).endsWith("datatypes2.shp")) + assert(r2.getLong(2) == dt2Shp.length()) + assert(r2.getLong(3) == 0L) + assert(r2.getLong(4) == dt2Shp.length()) + assert(r2.getTimestamp(5).getTime == dt2Shp.lastModified()) + } } } diff --git a/spark/spark-3.5/src/main/scala/org/apache/sedona/sql/datasources/shapefile/ShapefilePartitionReaderFactory.scala b/spark/spark-3.5/src/main/scala/org/apache/sedona/sql/datasources/shapefile/ShapefilePartitionReaderFactory.scala index 5a28af6d66..79c0638bd6 100644 --- a/spark/spark-3.5/src/main/scala/org/apache/sedona/sql/datasources/shapefile/ShapefilePartitionReaderFactory.scala +++ b/spark/spark-3.5/src/main/scala/org/apache/sedona/sql/datasources/shapefile/ShapefilePartitionReaderFactory.scala @@ -18,8 +18,11 @@ */ package org.apache.sedona.sql.datasources.shapefile +import org.apache.hadoop.fs.Path import org.apache.spark.broadcast.Broadcast import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.{BoundReference, JoinedRow} +import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection import org.apache.spark.sql.connector.read.InputPartition import org.apache.spark.sql.connector.read.PartitionReader import org.apache.spark.sql.connector.read.PartitionReaderFactory @@ -28,14 +31,19 @@ import org.apache.spark.sql.execution.datasources.v2.PartitionReaderWithPartitio import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.sources.Filter import org.apache.spark.sql.types.StructType +import org.apache.spark.unsafe.types.UTF8String import org.apache.spark.util.SerializableConfiguration +import java.util.Locale + case class ShapefilePartitionReaderFactory( sqlConf: SQLConf, broadcastedConf: Broadcast[SerializableConfiguration], dataSchema: StructType, readDataSchema: StructType, partitionSchema: StructType, + /** The metadata fields requested by the query (e.g., fields from `_metadata`). */ + metadataSchema: StructType, options: ShapefileReadOptions, filters: Seq[Filter]) extends PartitionReaderFactory { @@ -48,11 +56,51 @@ case class ShapefilePartitionReaderFactory( partitionedFiles, readDataSchema, options) - new PartitionReaderWithPartitionValues( + val withPartitionValues = new PartitionReaderWithPartitionValues( fileReader, readDataSchema, partitionSchema, partitionedFiles.head.partitionValues) + + if (metadataSchema.nonEmpty) { + // Build metadata values from the .shp file's partition information. + // We use the .shp file because it is the primary shapefile component and its path + // is what users would expect to see in _metadata.file_path / _metadata.file_name. + val shpFile = partitionedFiles + .find(_.filePath.toPath.getName.toLowerCase(Locale.ROOT).endsWith(".shp")) + .getOrElse(partitionedFiles.head) + val filePath = shpFile.filePath.toString + val fileName = new Path(filePath).getName + + // Complete map of all metadata field values keyed by field name. 
+ // The modificationTime from PartitionedFile is in milliseconds but Spark's + // TimestampType uses microseconds, so we multiply by 1000. + val allMetadataValues: Map[String, Any] = Map( + "file_path" -> UTF8String.fromString(filePath), + "file_name" -> UTF8String.fromString(fileName), + "file_size" -> shpFile.fileSize, + "file_block_start" -> shpFile.start, + "file_block_length" -> shpFile.length, + "file_modification_time" -> (shpFile.modificationTime * 1000L)) + + // The metadataSchema may be pruned by Spark's column pruning (e.g., when the query + // only selects `_metadata.file_name`). We must construct the inner struct to match + // the pruned schema exactly, otherwise field ordinals will be misaligned. + val innerStructType = metadataSchema.fields.head.dataType.asInstanceOf[StructType] + val prunedValues = innerStructType.fields.map(f => allMetadataValues(f.name)) + val metadataStruct = InternalRow.fromSeq(prunedValues.toSeq) + + // Wrap the struct in an outer row since _metadata is a single StructType column + val metadataRow = InternalRow.fromSeq(Seq(metadataStruct)) + val baseSchema = StructType(readDataSchema.fields ++ partitionSchema.fields) + new PartitionReaderWithMetadata( + withPartitionValues, + baseSchema, + metadataSchema, + metadataRow) + } else { + withPartitionValues + } } override def createReader(partition: InputPartition): PartitionReader[InternalRow] = { @@ -64,3 +112,43 @@ case class ShapefilePartitionReaderFactory( } } } + +/** + * Wraps a partition reader to append metadata column values to each row. This follows the same + * pattern as [[PartitionReaderWithPartitionValues]] but for metadata columns: it uses a + * [[JoinedRow]] to concatenate the base row (data + partition values) with the metadata row, then + * projects the combined row through an + * [[org.apache.spark.sql.catalyst.expressions.UnsafeProjection]] to produce a compact unsafe row. 
+ * + * @param reader + * the underlying reader that produces data + partition value rows + * @param baseSchema + * the combined schema of data columns and partition columns + * @param metadataSchema + * the schema of the metadata columns being appended + * @param metadataValues + * the constant metadata values to append to every row + */ +private[shapefile] class PartitionReaderWithMetadata( + reader: PartitionReader[InternalRow], + baseSchema: StructType, + metadataSchema: StructType, + metadataValues: InternalRow) + extends PartitionReader[InternalRow] { + + private val joinedRow = new JoinedRow() + private val unsafeProjection = + GenerateUnsafeProjection.generate(baseSchema.fields.zipWithIndex.map { case (f, i) => + BoundReference(i, f.dataType, f.nullable) + } ++ metadataSchema.fields.zipWithIndex.map { case (f, i) => + BoundReference(baseSchema.length + i, f.dataType, f.nullable) + }) + + override def next(): Boolean = reader.next() + + override def get(): InternalRow = { + unsafeProjection(joinedRow(reader.get(), metadataValues)) + } + + override def close(): Unit = reader.close() +} diff --git a/spark/spark-3.5/src/main/scala/org/apache/sedona/sql/datasources/shapefile/ShapefileScan.scala b/spark/spark-3.5/src/main/scala/org/apache/sedona/sql/datasources/shapefile/ShapefileScan.scala index afdface894..3f6a9224aa 100644 --- a/spark/spark-3.5/src/main/scala/org/apache/sedona/sql/datasources/shapefile/ShapefileScan.scala +++ b/spark/spark-3.5/src/main/scala/org/apache/sedona/sql/datasources/shapefile/ShapefileScan.scala @@ -44,12 +44,22 @@ case class ShapefileScan( dataSchema: StructType, readDataSchema: StructType, readPartitionSchema: StructType, + /** The metadata fields requested by the query (e.g., fields from `_metadata`). */ + metadataSchema: StructType, options: CaseInsensitiveStringMap, pushedFilters: Array[Filter], partitionFilters: Seq[Expression] = Seq.empty, dataFilters: Seq[Expression] = Seq.empty) extends FileScan { + /** + * Returns the complete read schema including data columns, partition columns, and any requested + * metadata columns. Metadata columns are appended last so the reader factory can construct a + * [[JoinedRow]] that appends metadata values after data and partition values. + */ + override def readSchema(): StructType = + StructType(readDataSchema.fields ++ readPartitionSchema.fields ++ metadataSchema.fields) + override def createReaderFactory(): PartitionReaderFactory = { val caseSensitiveMap = options.asScala.toMap val hadoopConf = sparkSession.sessionState.newHadoopConfWithOptions(caseSensitiveMap) @@ -61,6 +71,7 @@ case class ShapefileScan( dataSchema, readDataSchema, readPartitionSchema, + metadataSchema, ShapefileReadOptions.parse(options), pushedFilters) } diff --git a/spark/spark-3.5/src/main/scala/org/apache/sedona/sql/datasources/shapefile/ShapefileScanBuilder.scala b/spark/spark-3.5/src/main/scala/org/apache/sedona/sql/datasources/shapefile/ShapefileScanBuilder.scala index e5135e381d..48b5e45d53 100644 --- a/spark/spark-3.5/src/main/scala/org/apache/sedona/sql/datasources/shapefile/ShapefileScanBuilder.scala +++ b/spark/spark-3.5/src/main/scala/org/apache/sedona/sql/datasources/shapefile/ShapefileScanBuilder.scala @@ -33,6 +33,28 @@ case class ShapefileScanBuilder( options: CaseInsensitiveStringMap) extends FileScanBuilder(sparkSession, fileIndex, dataSchema) { + /** + * Tracks any metadata fields (e.g., from `_metadata`) requested in the query. Populated by + * [[pruneColumns]] when Spark pushes down column projections. 
+ */ + private var _requiredMetadataSchema: StructType = StructType(Seq.empty) + + /** + * Intercepts Spark's column pruning to separate metadata columns from data/partition columns. + * Fields in [[requiredSchema]] that do not belong to the data schema or partition schema are + * assumed to be metadata fields (e.g., `_metadata`). These are captured in + * [[_requiredMetadataSchema]] so the scan can include them in the output. + */ + override def pruneColumns(requiredSchema: StructType): Unit = { + val resolver = sparkSession.sessionState.conf.resolver + val metaFields = requiredSchema.fields.filter { field => + !dataSchema.fields.exists(df => resolver(df.name, field.name)) && + !fileIndex.partitionSchema.fields.exists(pf => resolver(pf.name, field.name)) + } + _requiredMetadataSchema = StructType(metaFields) + super.pruneColumns(requiredSchema) + } + override def build(): Scan = { ShapefileScan( sparkSession, @@ -40,6 +62,7 @@ case class ShapefileScanBuilder( dataSchema, readDataSchema(), readPartitionSchema(), + _requiredMetadataSchema, options, pushedDataFilters, partitionFilters, diff --git a/spark/spark-3.5/src/main/scala/org/apache/sedona/sql/datasources/shapefile/ShapefileTable.scala b/spark/spark-3.5/src/main/scala/org/apache/sedona/sql/datasources/shapefile/ShapefileTable.scala index 7db6bb8d1f..1f4bd093b9 100644 --- a/spark/spark-3.5/src/main/scala/org/apache/sedona/sql/datasources/shapefile/ShapefileTable.scala +++ b/spark/spark-3.5/src/main/scala/org/apache/sedona/sql/datasources/shapefile/ShapefileTable.scala @@ -21,19 +21,27 @@ package org.apache.sedona.sql.datasources.shapefile import org.apache.hadoop.fs.FileStatus import org.apache.sedona.core.formatMapper.shapefileParser.parseUtils.dbf.DbfParseUtil import org.apache.spark.sql.SparkSession -import org.apache.spark.sql.connector.catalog.TableCapability +import org.apache.spark.sql.connector.catalog.{MetadataColumn, SupportsMetadataColumns, TableCapability} import org.apache.spark.sql.connector.read.ScanBuilder import org.apache.spark.sql.connector.write.{LogicalWriteInfo, WriteBuilder} import org.apache.spark.sql.execution.datasources.FileFormat import org.apache.sedona.sql.datasources.shapefile.ShapefileUtils.{baseSchema, fieldDescriptorsToSchema, mergeSchemas} import org.apache.spark.sql.execution.datasources.v2.FileTable -import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.types.{DataType, LongType, StringType, StructField, StructType, TimestampType} import org.apache.spark.sql.util.CaseInsensitiveStringMap import org.apache.spark.util.SerializableConfiguration import java.util.Locale import scala.collection.JavaConverters._ +/** + * A Spark DataSource V2 table implementation for reading Shapefiles. + * + * Extends [[FileTable]] to leverage Spark's file-based scan infrastructure and implements + * [[SupportsMetadataColumns]] to expose hidden metadata columns (e.g., `_metadata`) that provide + * file-level information such as path, name, size, and modification time. These metadata columns + * are not part of the user-visible schema but can be explicitly selected in queries. 
+ */ case class ShapefileTable( name: String, sparkSession: SparkSession, @@ -41,7 +49,8 @@ case class ShapefileTable( paths: Seq[String], userSpecifiedSchema: Option[StructType], fallbackFileFormat: Class[_ <: FileFormat]) - extends FileTable(sparkSession, options, paths, userSpecifiedSchema) { + extends FileTable(sparkSession, options, paths, userSpecifiedSchema) + with SupportsMetadataColumns { override def formatName: String = "Shapefile" @@ -95,9 +104,49 @@ case class ShapefileTable( } } + /** Returns the metadata columns that this table exposes as hidden columns. */ + override def metadataColumns(): Array[MetadataColumn] = ShapefileTable.fileMetadataColumns + override def newScanBuilder(options: CaseInsensitiveStringMap): ScanBuilder = { ShapefileScanBuilder(sparkSession, fileIndex, schema, dataSchema, options) } override def newWriteBuilder(info: LogicalWriteInfo): WriteBuilder = null } + +object ShapefileTable { + + /** + * Schema of the `_metadata` struct column exposed by [[SupportsMetadataColumns]]. Each field + * provides file-level information about the source shapefile: + * + * - `file_path`: The fully qualified path of the `.shp` file (e.g., + * `hdfs://host/data/file.shp`). + * - `file_name`: The name of the `.shp` file without directory components (e.g., `file.shp`). + * - `file_size`: The total size of the `.shp` file in bytes. + * - `file_block_start`: The byte offset within the file where this partition's data begins. + * For non-splittable formats this is typically 0. + * - `file_block_length`: The number of bytes in this partition's data block. For + * non-splittable formats this equals the file size. + * - `file_modification_time`: The last modification timestamp of the `.shp` file. + */ + private val FILE_METADATA_STRUCT_TYPE: StructType = StructType( + Seq( + StructField("file_path", StringType, nullable = false), + StructField("file_name", StringType, nullable = false), + StructField("file_size", LongType, nullable = false), + StructField("file_block_start", LongType, nullable = false), + StructField("file_block_length", LongType, nullable = false), + StructField("file_modification_time", TimestampType, nullable = false))) + + /** + * The single metadata column `_metadata` exposed to Spark's catalog. This hidden column can be + * selected in queries (e.g., `SELECT _metadata.file_name FROM shapefile.`...``) but does not + * appear in `SELECT *`. 
+ */ + private[shapefile] val fileMetadataColumns: Array[MetadataColumn] = Array(new MetadataColumn { + override def name: String = "_metadata" + override def dataType: DataType = FILE_METADATA_STRUCT_TYPE + override def isNullable: Boolean = false + }) +} diff --git a/spark/spark-3.5/src/test/scala/org/apache/sedona/sql/ShapefileTests.scala b/spark/spark-3.5/src/test/scala/org/apache/sedona/sql/ShapefileTests.scala index 275bc3282f..531eaf65f0 100644 --- a/spark/spark-3.5/src/test/scala/org/apache/sedona/sql/ShapefileTests.scala +++ b/spark/spark-3.5/src/test/scala/org/apache/sedona/sql/ShapefileTests.scala @@ -20,7 +20,8 @@ package org.apache.sedona.sql import org.apache.commons.io.FileUtils import org.apache.spark.sql.sedona_sql.UDT.GeometryUDT -import org.apache.spark.sql.types.{DateType, DecimalType, LongType, StringType, StructField, StructType} +import org.apache.spark.sql.Row +import org.apache.spark.sql.types.{DateType, DecimalType, LongType, StringType, StructField, StructType, TimestampType} import org.locationtech.jts.geom.{Geometry, MultiPolygon, Point, Polygon} import org.locationtech.jts.io.{WKTReader, WKTWriter} import org.scalatest.BeforeAndAfterAll @@ -502,7 +503,7 @@ class ShapefileTests extends TestBaseScala with BeforeAndAfterAll { .format("shapefile") .load(temporaryLocation) .select("part", "id", "aInt", "aUnicode", "geometry") - var rows = shapefileDf.collect() + val rows = shapefileDf.collect() assert(rows.length == 9) rows.foreach { row => assert(row.getAs[Geometry]("geometry").isInstanceOf[Point]) @@ -517,18 +518,6 @@ class ShapefileTests extends TestBaseScala with BeforeAndAfterAll { assert(row.getAs[String]("aUnicode") == s"测试$id") } } - - // Using partition filters - rows = shapefileDf.where("part = 2").collect() - assert(rows.length == 4) - rows.foreach { row => - assert(row.getAs[Geometry]("geometry").isInstanceOf[Point]) - assert(row.getAs[Int]("part") == 2) - val id = row.getAs[Long]("id") - assert(id > 10) - assert(row.getAs[Long]("aInt") == id) - assert(row.getAs[String]("aUnicode") == s"测试$id") - } } it("read with recursiveFileLookup") { @@ -780,5 +769,180 @@ class ShapefileTests extends TestBaseScala with BeforeAndAfterAll { } } } + + it("should expose _metadata struct with all expected fields") { + val df = sparkSession.read + .format("shapefile") + .load(resourceFolder + "shapefiles/gis_osm_pois_free_1") + val metaDf = df.select("_metadata") + val metaSchema = metaDf.schema("_metadata").dataType.asInstanceOf[StructType] + val expectedFields = + Seq( + "file_path", + "file_name", + "file_size", + "file_block_start", + "file_block_length", + "file_modification_time") + assert(metaSchema.fieldNames.toSeq == expectedFields) + assert(metaSchema("file_path").dataType == StringType) + assert(metaSchema("file_name").dataType == StringType) + assert(metaSchema("file_size").dataType == LongType) + assert(metaSchema("file_block_start").dataType == LongType) + assert(metaSchema("file_block_length").dataType == LongType) + assert(metaSchema("file_modification_time").dataType == TimestampType) + } + + it("should not include _metadata in select(*)") { + val df = sparkSession.read + .format("shapefile") + .load(resourceFolder + "shapefiles/gis_osm_pois_free_1") + val columns = df.columns + assert(!columns.contains("_metadata")) + } + + it("should return correct file_path and file_name in _metadata") { + val df = sparkSession.read + .format("shapefile") + .load(resourceFolder + "shapefiles/gis_osm_pois_free_1") + val metaRows = df.select("_metadata.file_path", 
"_metadata.file_name").distinct().collect() + assert(metaRows.length == 1) + val filePath = metaRows.head.getString(0) + val fileName = metaRows.head.getString(1) + assert(filePath.endsWith("gis_osm_pois_free_1.shp")) + assert(fileName == "gis_osm_pois_free_1.shp") + } + + it("should return actual file_size matching the .shp file on disk") { + val shpFile = + new File(resourceFolder + "shapefiles/gis_osm_pois_free_1/gis_osm_pois_free_1.shp") + val expectedSize = shpFile.length() + val df = sparkSession.read + .format("shapefile") + .load(resourceFolder + "shapefiles/gis_osm_pois_free_1") + val sizes = df.select("_metadata.file_size").distinct().collect() + assert(sizes.length == 1) + assert(sizes.head.getLong(0) == expectedSize) + } + + it( + "should return file_block_start=0 and file_block_length=file_size for non-splittable shapefiles") { + val shpFile = + new File(resourceFolder + "shapefiles/gis_osm_pois_free_1/gis_osm_pois_free_1.shp") + val expectedSize = shpFile.length() + val df = sparkSession.read + .format("shapefile") + .load(resourceFolder + "shapefiles/gis_osm_pois_free_1") + val rows = df + .select("_metadata.file_block_start", "_metadata.file_block_length") + .distinct() + .collect() + assert(rows.length == 1) + assert(rows.head.getLong(0) == 0L) // file_block_start + assert(rows.head.getLong(1) == expectedSize) // file_block_length + } + + it("should return file_modification_time matching the .shp file on disk") { + val shpFile = + new File(resourceFolder + "shapefiles/gis_osm_pois_free_1/gis_osm_pois_free_1.shp") + // File.lastModified() returns milliseconds, Spark TimestampType stores microseconds + val expectedModTimeMs = shpFile.lastModified() + val df = sparkSession.read + .format("shapefile") + .load(resourceFolder + "shapefiles/gis_osm_pois_free_1") + val times = + df.select("_metadata.file_modification_time").distinct().collect() + assert(times.length == 1) + val modTime = times.head.getTimestamp(0) + assert(modTime != null) + // Timestamp.getTime() returns milliseconds + assert(modTime.getTime == expectedModTimeMs) + } + + it("should return correct metadata values per file when reading multiple shapefiles") { + val map1Shp = + new File(resourceFolder + "shapefiles/multipleshapefiles/map1.shp") + val map2Shp = + new File(resourceFolder + "shapefiles/multipleshapefiles/map2.shp") + val df = sparkSession.read + .format("shapefile") + .load(resourceFolder + "shapefiles/multipleshapefiles") + val metaRows = df + .select( + "_metadata.file_name", + "_metadata.file_size", + "_metadata.file_block_start", + "_metadata.file_block_length") + .distinct() + .collect() + assert(metaRows.length == 2) + val byName = metaRows.map(r => r.getString(0) -> r).toMap + // map1.shp + assert(byName("map1.shp").getLong(1) == map1Shp.length()) + assert(byName("map1.shp").getLong(2) == 0L) + assert(byName("map1.shp").getLong(3) == map1Shp.length()) + // map2.shp + assert(byName("map2.shp").getLong(1) == map2Shp.length()) + assert(byName("map2.shp").getLong(2) == 0L) + assert(byName("map2.shp").getLong(3) == map2Shp.length()) + } + + it("should allow filtering on _metadata fields") { + val df = sparkSession.read + .format("shapefile") + .load(resourceFolder + "shapefiles/multipleshapefiles") + val totalCount = df.count() + val map1Df = df.filter(df("_metadata.file_name") === "map1.shp") + val map2Df = df.filter(df("_metadata.file_name") === "map2.shp") + assert(map1Df.count() > 0) + assert(map2Df.count() > 0) + assert(map1Df.count() + map2Df.count() == totalCount) + } + + it("should select 
_metadata along with data columns") { + val df = sparkSession.read + .format("shapefile") + .load(resourceFolder + "shapefiles/gis_osm_pois_free_1") + val result = df.select("osm_id", "_metadata.file_name").collect() + assert(result.length == 12873) + result.foreach { row => + assert(row.getString(0).nonEmpty) + assert(row.getString(1) == "gis_osm_pois_free_1.shp") + } + } + + it("should return correct metadata for each file in multi-shapefile directory") { + val dt1Shp = new File(resourceFolder + "shapefiles/datatypes/datatypes1.shp") + val dt2Shp = new File(resourceFolder + "shapefiles/datatypes/datatypes2.shp") + val df = sparkSession.read + .format("shapefile") + .load(resourceFolder + "shapefiles/datatypes") + val result = df + .select( + "_metadata.file_path", + "_metadata.file_name", + "_metadata.file_size", + "_metadata.file_block_start", + "_metadata.file_block_length", + "_metadata.file_modification_time") + .distinct() + .collect() + assert(result.length == 2) + val byName = result.map(r => r.getString(1) -> r).toMap + // datatypes1.shp + val r1 = byName("datatypes1.shp") + assert(r1.getString(0).endsWith("datatypes1.shp")) + assert(r1.getLong(2) == dt1Shp.length()) + assert(r1.getLong(3) == 0L) + assert(r1.getLong(4) == dt1Shp.length()) + assert(r1.getTimestamp(5).getTime == dt1Shp.lastModified()) + // datatypes2.shp + val r2 = byName("datatypes2.shp") + assert(r2.getString(0).endsWith("datatypes2.shp")) + assert(r2.getLong(2) == dt2Shp.length()) + assert(r2.getLong(3) == 0L) + assert(r2.getLong(4) == dt2Shp.length()) + assert(r2.getTimestamp(5).getTime == dt2Shp.lastModified()) + } } } diff --git a/spark/spark-4.0/src/main/scala/org/apache/sedona/sql/datasources/shapefile/ShapefilePartitionReaderFactory.scala b/spark/spark-4.0/src/main/scala/org/apache/sedona/sql/datasources/shapefile/ShapefilePartitionReaderFactory.scala index 5a28af6d66..79c0638bd6 100644 --- a/spark/spark-4.0/src/main/scala/org/apache/sedona/sql/datasources/shapefile/ShapefilePartitionReaderFactory.scala +++ b/spark/spark-4.0/src/main/scala/org/apache/sedona/sql/datasources/shapefile/ShapefilePartitionReaderFactory.scala @@ -18,8 +18,11 @@ */ package org.apache.sedona.sql.datasources.shapefile +import org.apache.hadoop.fs.Path import org.apache.spark.broadcast.Broadcast import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.{BoundReference, JoinedRow} +import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection import org.apache.spark.sql.connector.read.InputPartition import org.apache.spark.sql.connector.read.PartitionReader import org.apache.spark.sql.connector.read.PartitionReaderFactory @@ -28,14 +31,19 @@ import org.apache.spark.sql.execution.datasources.v2.PartitionReaderWithPartitio import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.sources.Filter import org.apache.spark.sql.types.StructType +import org.apache.spark.unsafe.types.UTF8String import org.apache.spark.util.SerializableConfiguration +import java.util.Locale + case class ShapefilePartitionReaderFactory( sqlConf: SQLConf, broadcastedConf: Broadcast[SerializableConfiguration], dataSchema: StructType, readDataSchema: StructType, partitionSchema: StructType, + /** The metadata fields requested by the query (e.g., fields from `_metadata`). 
*/ + metadataSchema: StructType, options: ShapefileReadOptions, filters: Seq[Filter]) extends PartitionReaderFactory { @@ -48,11 +56,51 @@ case class ShapefilePartitionReaderFactory( partitionedFiles, readDataSchema, options) - new PartitionReaderWithPartitionValues( + val withPartitionValues = new PartitionReaderWithPartitionValues( fileReader, readDataSchema, partitionSchema, partitionedFiles.head.partitionValues) + + if (metadataSchema.nonEmpty) { + // Build metadata values from the .shp file's partition information. + // We use the .shp file because it is the primary shapefile component and its path + // is what users would expect to see in _metadata.file_path / _metadata.file_name. + val shpFile = partitionedFiles + .find(_.filePath.toPath.getName.toLowerCase(Locale.ROOT).endsWith(".shp")) + .getOrElse(partitionedFiles.head) + val filePath = shpFile.filePath.toString + val fileName = new Path(filePath).getName + + // Complete map of all metadata field values keyed by field name. + // The modificationTime from PartitionedFile is in milliseconds but Spark's + // TimestampType uses microseconds, so we multiply by 1000. + val allMetadataValues: Map[String, Any] = Map( + "file_path" -> UTF8String.fromString(filePath), + "file_name" -> UTF8String.fromString(fileName), + "file_size" -> shpFile.fileSize, + "file_block_start" -> shpFile.start, + "file_block_length" -> shpFile.length, + "file_modification_time" -> (shpFile.modificationTime * 1000L)) + + // The metadataSchema may be pruned by Spark's column pruning (e.g., when the query + // only selects `_metadata.file_name`). We must construct the inner struct to match + // the pruned schema exactly, otherwise field ordinals will be misaligned. + val innerStructType = metadataSchema.fields.head.dataType.asInstanceOf[StructType] + val prunedValues = innerStructType.fields.map(f => allMetadataValues(f.name)) + val metadataStruct = InternalRow.fromSeq(prunedValues.toSeq) + + // Wrap the struct in an outer row since _metadata is a single StructType column + val metadataRow = InternalRow.fromSeq(Seq(metadataStruct)) + val baseSchema = StructType(readDataSchema.fields ++ partitionSchema.fields) + new PartitionReaderWithMetadata( + withPartitionValues, + baseSchema, + metadataSchema, + metadataRow) + } else { + withPartitionValues + } } override def createReader(partition: InputPartition): PartitionReader[InternalRow] = { @@ -64,3 +112,43 @@ case class ShapefilePartitionReaderFactory( } } } + +/** + * Wraps a partition reader to append metadata column values to each row. This follows the same + * pattern as [[PartitionReaderWithPartitionValues]] but for metadata columns: it uses a + * [[JoinedRow]] to concatenate the base row (data + partition values) with the metadata row, then + * projects the combined row through an + * [[org.apache.spark.sql.catalyst.expressions.UnsafeProjection]] to produce a compact unsafe row. 
+ * + * @param reader + * the underlying reader that produces data + partition value rows + * @param baseSchema + * the combined schema of data columns and partition columns + * @param metadataSchema + * the schema of the metadata columns being appended + * @param metadataValues + * the constant metadata values to append to every row + */ +private[shapefile] class PartitionReaderWithMetadata( + reader: PartitionReader[InternalRow], + baseSchema: StructType, + metadataSchema: StructType, + metadataValues: InternalRow) + extends PartitionReader[InternalRow] { + + private val joinedRow = new JoinedRow() + private val unsafeProjection = + GenerateUnsafeProjection.generate(baseSchema.fields.zipWithIndex.map { case (f, i) => + BoundReference(i, f.dataType, f.nullable) + } ++ metadataSchema.fields.zipWithIndex.map { case (f, i) => + BoundReference(baseSchema.length + i, f.dataType, f.nullable) + }) + + override def next(): Boolean = reader.next() + + override def get(): InternalRow = { + unsafeProjection(joinedRow(reader.get(), metadataValues)) + } + + override def close(): Unit = reader.close() +} diff --git a/spark/spark-4.0/src/main/scala/org/apache/sedona/sql/datasources/shapefile/ShapefileScan.scala b/spark/spark-4.0/src/main/scala/org/apache/sedona/sql/datasources/shapefile/ShapefileScan.scala index afdface894..3f6a9224aa 100644 --- a/spark/spark-4.0/src/main/scala/org/apache/sedona/sql/datasources/shapefile/ShapefileScan.scala +++ b/spark/spark-4.0/src/main/scala/org/apache/sedona/sql/datasources/shapefile/ShapefileScan.scala @@ -44,12 +44,22 @@ case class ShapefileScan( dataSchema: StructType, readDataSchema: StructType, readPartitionSchema: StructType, + /** The metadata fields requested by the query (e.g., fields from `_metadata`). */ + metadataSchema: StructType, options: CaseInsensitiveStringMap, pushedFilters: Array[Filter], partitionFilters: Seq[Expression] = Seq.empty, dataFilters: Seq[Expression] = Seq.empty) extends FileScan { + /** + * Returns the complete read schema including data columns, partition columns, and any requested + * metadata columns. Metadata columns are appended last so the reader factory can construct a + * [[JoinedRow]] that appends metadata values after data and partition values. + */ + override def readSchema(): StructType = + StructType(readDataSchema.fields ++ readPartitionSchema.fields ++ metadataSchema.fields) + override def createReaderFactory(): PartitionReaderFactory = { val caseSensitiveMap = options.asScala.toMap val hadoopConf = sparkSession.sessionState.newHadoopConfWithOptions(caseSensitiveMap) @@ -61,6 +71,7 @@ case class ShapefileScan( dataSchema, readDataSchema, readPartitionSchema, + metadataSchema, ShapefileReadOptions.parse(options), pushedFilters) } diff --git a/spark/spark-4.0/src/main/scala/org/apache/sedona/sql/datasources/shapefile/ShapefileScanBuilder.scala b/spark/spark-4.0/src/main/scala/org/apache/sedona/sql/datasources/shapefile/ShapefileScanBuilder.scala index e5135e381d..48b5e45d53 100644 --- a/spark/spark-4.0/src/main/scala/org/apache/sedona/sql/datasources/shapefile/ShapefileScanBuilder.scala +++ b/spark/spark-4.0/src/main/scala/org/apache/sedona/sql/datasources/shapefile/ShapefileScanBuilder.scala @@ -33,6 +33,28 @@ case class ShapefileScanBuilder( options: CaseInsensitiveStringMap) extends FileScanBuilder(sparkSession, fileIndex, dataSchema) { + /** + * Tracks any metadata fields (e.g., from `_metadata`) requested in the query. Populated by + * [[pruneColumns]] when Spark pushes down column projections. 
+ */ + private var _requiredMetadataSchema: StructType = StructType(Seq.empty) + + /** + * Intercepts Spark's column pruning to separate metadata columns from data/partition columns. + * Fields in [[requiredSchema]] that do not belong to the data schema or partition schema are + * assumed to be metadata fields (e.g., `_metadata`). These are captured in + * [[_requiredMetadataSchema]] so the scan can include them in the output. + */ + override def pruneColumns(requiredSchema: StructType): Unit = { + val resolver = sparkSession.sessionState.conf.resolver + val metaFields = requiredSchema.fields.filter { field => + !dataSchema.fields.exists(df => resolver(df.name, field.name)) && + !fileIndex.partitionSchema.fields.exists(pf => resolver(pf.name, field.name)) + } + _requiredMetadataSchema = StructType(metaFields) + super.pruneColumns(requiredSchema) + } + override def build(): Scan = { ShapefileScan( sparkSession, @@ -40,6 +62,7 @@ case class ShapefileScanBuilder( dataSchema, readDataSchema(), readPartitionSchema(), + _requiredMetadataSchema, options, pushedDataFilters, partitionFilters, diff --git a/spark/spark-4.0/src/main/scala/org/apache/sedona/sql/datasources/shapefile/ShapefileTable.scala b/spark/spark-4.0/src/main/scala/org/apache/sedona/sql/datasources/shapefile/ShapefileTable.scala index 7db6bb8d1f..1f4bd093b9 100644 --- a/spark/spark-4.0/src/main/scala/org/apache/sedona/sql/datasources/shapefile/ShapefileTable.scala +++ b/spark/spark-4.0/src/main/scala/org/apache/sedona/sql/datasources/shapefile/ShapefileTable.scala @@ -21,19 +21,27 @@ package org.apache.sedona.sql.datasources.shapefile import org.apache.hadoop.fs.FileStatus import org.apache.sedona.core.formatMapper.shapefileParser.parseUtils.dbf.DbfParseUtil import org.apache.spark.sql.SparkSession -import org.apache.spark.sql.connector.catalog.TableCapability +import org.apache.spark.sql.connector.catalog.{MetadataColumn, SupportsMetadataColumns, TableCapability} import org.apache.spark.sql.connector.read.ScanBuilder import org.apache.spark.sql.connector.write.{LogicalWriteInfo, WriteBuilder} import org.apache.spark.sql.execution.datasources.FileFormat import org.apache.sedona.sql.datasources.shapefile.ShapefileUtils.{baseSchema, fieldDescriptorsToSchema, mergeSchemas} import org.apache.spark.sql.execution.datasources.v2.FileTable -import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.types.{DataType, LongType, StringType, StructField, StructType, TimestampType} import org.apache.spark.sql.util.CaseInsensitiveStringMap import org.apache.spark.util.SerializableConfiguration import java.util.Locale import scala.collection.JavaConverters._ +/** + * A Spark DataSource V2 table implementation for reading Shapefiles. + * + * Extends [[FileTable]] to leverage Spark's file-based scan infrastructure and implements + * [[SupportsMetadataColumns]] to expose hidden metadata columns (e.g., `_metadata`) that provide + * file-level information such as path, name, size, and modification time. These metadata columns + * are not part of the user-visible schema but can be explicitly selected in queries. 
+ */ case class ShapefileTable( name: String, sparkSession: SparkSession, @@ -41,7 +49,8 @@ case class ShapefileTable( paths: Seq[String], userSpecifiedSchema: Option[StructType], fallbackFileFormat: Class[_ <: FileFormat]) - extends FileTable(sparkSession, options, paths, userSpecifiedSchema) { + extends FileTable(sparkSession, options, paths, userSpecifiedSchema) + with SupportsMetadataColumns { override def formatName: String = "Shapefile" @@ -95,9 +104,49 @@ case class ShapefileTable( } } + /** Returns the metadata columns that this table exposes as hidden columns. */ + override def metadataColumns(): Array[MetadataColumn] = ShapefileTable.fileMetadataColumns + override def newScanBuilder(options: CaseInsensitiveStringMap): ScanBuilder = { ShapefileScanBuilder(sparkSession, fileIndex, schema, dataSchema, options) } override def newWriteBuilder(info: LogicalWriteInfo): WriteBuilder = null } + +object ShapefileTable { + + /** + * Schema of the `_metadata` struct column exposed by [[SupportsMetadataColumns]]. Each field + * provides file-level information about the source shapefile: + * + * - `file_path`: The fully qualified path of the `.shp` file (e.g., + * `hdfs://host/data/file.shp`). + * - `file_name`: The name of the `.shp` file without directory components (e.g., `file.shp`). + * - `file_size`: The total size of the `.shp` file in bytes. + * - `file_block_start`: The byte offset within the file where this partition's data begins. + * For non-splittable formats this is typically 0. + * - `file_block_length`: The number of bytes in this partition's data block. For + * non-splittable formats this equals the file size. + * - `file_modification_time`: The last modification timestamp of the `.shp` file. + */ + private val FILE_METADATA_STRUCT_TYPE: StructType = StructType( + Seq( + StructField("file_path", StringType, nullable = false), + StructField("file_name", StringType, nullable = false), + StructField("file_size", LongType, nullable = false), + StructField("file_block_start", LongType, nullable = false), + StructField("file_block_length", LongType, nullable = false), + StructField("file_modification_time", TimestampType, nullable = false))) + + /** + * The single metadata column `_metadata` exposed to Spark's catalog. This hidden column can be + * selected in queries (e.g., `SELECT _metadata.file_name FROM shapefile.`...``) but does not + * appear in `SELECT *`. 
+ */ + private[shapefile] val fileMetadataColumns: Array[MetadataColumn] = Array(new MetadataColumn { + override def name: String = "_metadata" + override def dataType: DataType = FILE_METADATA_STRUCT_TYPE + override def isNullable: Boolean = false + }) +} diff --git a/spark/spark-4.0/src/test/scala/org/apache/sedona/sql/ShapefileTests.scala b/spark/spark-4.0/src/test/scala/org/apache/sedona/sql/ShapefileTests.scala index 275bc3282f..531eaf65f0 100644 --- a/spark/spark-4.0/src/test/scala/org/apache/sedona/sql/ShapefileTests.scala +++ b/spark/spark-4.0/src/test/scala/org/apache/sedona/sql/ShapefileTests.scala @@ -20,7 +20,8 @@ package org.apache.sedona.sql import org.apache.commons.io.FileUtils import org.apache.spark.sql.sedona_sql.UDT.GeometryUDT -import org.apache.spark.sql.types.{DateType, DecimalType, LongType, StringType, StructField, StructType} +import org.apache.spark.sql.Row +import org.apache.spark.sql.types.{DateType, DecimalType, LongType, StringType, StructField, StructType, TimestampType} import org.locationtech.jts.geom.{Geometry, MultiPolygon, Point, Polygon} import org.locationtech.jts.io.{WKTReader, WKTWriter} import org.scalatest.BeforeAndAfterAll @@ -502,7 +503,7 @@ class ShapefileTests extends TestBaseScala with BeforeAndAfterAll { .format("shapefile") .load(temporaryLocation) .select("part", "id", "aInt", "aUnicode", "geometry") - var rows = shapefileDf.collect() + val rows = shapefileDf.collect() assert(rows.length == 9) rows.foreach { row => assert(row.getAs[Geometry]("geometry").isInstanceOf[Point]) @@ -517,18 +518,6 @@ class ShapefileTests extends TestBaseScala with BeforeAndAfterAll { assert(row.getAs[String]("aUnicode") == s"测试$id") } } - - // Using partition filters - rows = shapefileDf.where("part = 2").collect() - assert(rows.length == 4) - rows.foreach { row => - assert(row.getAs[Geometry]("geometry").isInstanceOf[Point]) - assert(row.getAs[Int]("part") == 2) - val id = row.getAs[Long]("id") - assert(id > 10) - assert(row.getAs[Long]("aInt") == id) - assert(row.getAs[String]("aUnicode") == s"测试$id") - } } it("read with recursiveFileLookup") { @@ -780,5 +769,180 @@ class ShapefileTests extends TestBaseScala with BeforeAndAfterAll { } } } + + it("should expose _metadata struct with all expected fields") { + val df = sparkSession.read + .format("shapefile") + .load(resourceFolder + "shapefiles/gis_osm_pois_free_1") + val metaDf = df.select("_metadata") + val metaSchema = metaDf.schema("_metadata").dataType.asInstanceOf[StructType] + val expectedFields = + Seq( + "file_path", + "file_name", + "file_size", + "file_block_start", + "file_block_length", + "file_modification_time") + assert(metaSchema.fieldNames.toSeq == expectedFields) + assert(metaSchema("file_path").dataType == StringType) + assert(metaSchema("file_name").dataType == StringType) + assert(metaSchema("file_size").dataType == LongType) + assert(metaSchema("file_block_start").dataType == LongType) + assert(metaSchema("file_block_length").dataType == LongType) + assert(metaSchema("file_modification_time").dataType == TimestampType) + } + + it("should not include _metadata in select(*)") { + val df = sparkSession.read + .format("shapefile") + .load(resourceFolder + "shapefiles/gis_osm_pois_free_1") + val columns = df.columns + assert(!columns.contains("_metadata")) + } + + it("should return correct file_path and file_name in _metadata") { + val df = sparkSession.read + .format("shapefile") + .load(resourceFolder + "shapefiles/gis_osm_pois_free_1") + val metaRows = df.select("_metadata.file_path", 
"_metadata.file_name").distinct().collect() + assert(metaRows.length == 1) + val filePath = metaRows.head.getString(0) + val fileName = metaRows.head.getString(1) + assert(filePath.endsWith("gis_osm_pois_free_1.shp")) + assert(fileName == "gis_osm_pois_free_1.shp") + } + + it("should return actual file_size matching the .shp file on disk") { + val shpFile = + new File(resourceFolder + "shapefiles/gis_osm_pois_free_1/gis_osm_pois_free_1.shp") + val expectedSize = shpFile.length() + val df = sparkSession.read + .format("shapefile") + .load(resourceFolder + "shapefiles/gis_osm_pois_free_1") + val sizes = df.select("_metadata.file_size").distinct().collect() + assert(sizes.length == 1) + assert(sizes.head.getLong(0) == expectedSize) + } + + it( + "should return file_block_start=0 and file_block_length=file_size for non-splittable shapefiles") { + val shpFile = + new File(resourceFolder + "shapefiles/gis_osm_pois_free_1/gis_osm_pois_free_1.shp") + val expectedSize = shpFile.length() + val df = sparkSession.read + .format("shapefile") + .load(resourceFolder + "shapefiles/gis_osm_pois_free_1") + val rows = df + .select("_metadata.file_block_start", "_metadata.file_block_length") + .distinct() + .collect() + assert(rows.length == 1) + assert(rows.head.getLong(0) == 0L) // file_block_start + assert(rows.head.getLong(1) == expectedSize) // file_block_length + } + + it("should return file_modification_time matching the .shp file on disk") { + val shpFile = + new File(resourceFolder + "shapefiles/gis_osm_pois_free_1/gis_osm_pois_free_1.shp") + // File.lastModified() returns milliseconds, Spark TimestampType stores microseconds + val expectedModTimeMs = shpFile.lastModified() + val df = sparkSession.read + .format("shapefile") + .load(resourceFolder + "shapefiles/gis_osm_pois_free_1") + val times = + df.select("_metadata.file_modification_time").distinct().collect() + assert(times.length == 1) + val modTime = times.head.getTimestamp(0) + assert(modTime != null) + // Timestamp.getTime() returns milliseconds + assert(modTime.getTime == expectedModTimeMs) + } + + it("should return correct metadata values per file when reading multiple shapefiles") { + val map1Shp = + new File(resourceFolder + "shapefiles/multipleshapefiles/map1.shp") + val map2Shp = + new File(resourceFolder + "shapefiles/multipleshapefiles/map2.shp") + val df = sparkSession.read + .format("shapefile") + .load(resourceFolder + "shapefiles/multipleshapefiles") + val metaRows = df + .select( + "_metadata.file_name", + "_metadata.file_size", + "_metadata.file_block_start", + "_metadata.file_block_length") + .distinct() + .collect() + assert(metaRows.length == 2) + val byName = metaRows.map(r => r.getString(0) -> r).toMap + // map1.shp + assert(byName("map1.shp").getLong(1) == map1Shp.length()) + assert(byName("map1.shp").getLong(2) == 0L) + assert(byName("map1.shp").getLong(3) == map1Shp.length()) + // map2.shp + assert(byName("map2.shp").getLong(1) == map2Shp.length()) + assert(byName("map2.shp").getLong(2) == 0L) + assert(byName("map2.shp").getLong(3) == map2Shp.length()) + } + + it("should allow filtering on _metadata fields") { + val df = sparkSession.read + .format("shapefile") + .load(resourceFolder + "shapefiles/multipleshapefiles") + val totalCount = df.count() + val map1Df = df.filter(df("_metadata.file_name") === "map1.shp") + val map2Df = df.filter(df("_metadata.file_name") === "map2.shp") + assert(map1Df.count() > 0) + assert(map2Df.count() > 0) + assert(map1Df.count() + map2Df.count() == totalCount) + } + + it("should select 
_metadata along with data columns") { + val df = sparkSession.read + .format("shapefile") + .load(resourceFolder + "shapefiles/gis_osm_pois_free_1") + val result = df.select("osm_id", "_metadata.file_name").collect() + assert(result.length == 12873) + result.foreach { row => + assert(row.getString(0).nonEmpty) + assert(row.getString(1) == "gis_osm_pois_free_1.shp") + } + } + + it("should return correct metadata for each file in multi-shapefile directory") { + val dt1Shp = new File(resourceFolder + "shapefiles/datatypes/datatypes1.shp") + val dt2Shp = new File(resourceFolder + "shapefiles/datatypes/datatypes2.shp") + val df = sparkSession.read + .format("shapefile") + .load(resourceFolder + "shapefiles/datatypes") + val result = df + .select( + "_metadata.file_path", + "_metadata.file_name", + "_metadata.file_size", + "_metadata.file_block_start", + "_metadata.file_block_length", + "_metadata.file_modification_time") + .distinct() + .collect() + assert(result.length == 2) + val byName = result.map(r => r.getString(1) -> r).toMap + // datatypes1.shp + val r1 = byName("datatypes1.shp") + assert(r1.getString(0).endsWith("datatypes1.shp")) + assert(r1.getLong(2) == dt1Shp.length()) + assert(r1.getLong(3) == 0L) + assert(r1.getLong(4) == dt1Shp.length()) + assert(r1.getTimestamp(5).getTime == dt1Shp.lastModified()) + // datatypes2.shp + val r2 = byName("datatypes2.shp") + assert(r2.getString(0).endsWith("datatypes2.shp")) + assert(r2.getLong(2) == dt2Shp.length()) + assert(r2.getLong(3) == 0L) + assert(r2.getLong(4) == dt2Shp.length()) + assert(r2.getTimestamp(5).getTime == dt2Shp.lastModified()) + } } } diff --git a/spark/spark-4.1/src/main/scala/org/apache/sedona/sql/datasources/shapefile/ShapefilePartitionReaderFactory.scala b/spark/spark-4.1/src/main/scala/org/apache/sedona/sql/datasources/shapefile/ShapefilePartitionReaderFactory.scala index 5a28af6d66..79c0638bd6 100644 --- a/spark/spark-4.1/src/main/scala/org/apache/sedona/sql/datasources/shapefile/ShapefilePartitionReaderFactory.scala +++ b/spark/spark-4.1/src/main/scala/org/apache/sedona/sql/datasources/shapefile/ShapefilePartitionReaderFactory.scala @@ -18,8 +18,11 @@ */ package org.apache.sedona.sql.datasources.shapefile +import org.apache.hadoop.fs.Path import org.apache.spark.broadcast.Broadcast import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.{BoundReference, JoinedRow} +import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection import org.apache.spark.sql.connector.read.InputPartition import org.apache.spark.sql.connector.read.PartitionReader import org.apache.spark.sql.connector.read.PartitionReaderFactory @@ -28,14 +31,19 @@ import org.apache.spark.sql.execution.datasources.v2.PartitionReaderWithPartitio import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.sources.Filter import org.apache.spark.sql.types.StructType +import org.apache.spark.unsafe.types.UTF8String import org.apache.spark.util.SerializableConfiguration +import java.util.Locale + case class ShapefilePartitionReaderFactory( sqlConf: SQLConf, broadcastedConf: Broadcast[SerializableConfiguration], dataSchema: StructType, readDataSchema: StructType, partitionSchema: StructType, + /** The metadata fields requested by the query (e.g., fields from `_metadata`). 
*/ + metadataSchema: StructType, options: ShapefileReadOptions, filters: Seq[Filter]) extends PartitionReaderFactory { @@ -48,11 +56,51 @@ case class ShapefilePartitionReaderFactory( partitionedFiles, readDataSchema, options) - new PartitionReaderWithPartitionValues( + val withPartitionValues = new PartitionReaderWithPartitionValues( fileReader, readDataSchema, partitionSchema, partitionedFiles.head.partitionValues) + + if (metadataSchema.nonEmpty) { + // Build metadata values from the .shp file's partition information. + // We use the .shp file because it is the primary shapefile component and its path + // is what users would expect to see in _metadata.file_path / _metadata.file_name. + val shpFile = partitionedFiles + .find(_.filePath.toPath.getName.toLowerCase(Locale.ROOT).endsWith(".shp")) + .getOrElse(partitionedFiles.head) + val filePath = shpFile.filePath.toString + val fileName = new Path(filePath).getName + + // Complete map of all metadata field values keyed by field name. + // The modificationTime from PartitionedFile is in milliseconds but Spark's + // TimestampType uses microseconds, so we multiply by 1000. + val allMetadataValues: Map[String, Any] = Map( + "file_path" -> UTF8String.fromString(filePath), + "file_name" -> UTF8String.fromString(fileName), + "file_size" -> shpFile.fileSize, + "file_block_start" -> shpFile.start, + "file_block_length" -> shpFile.length, + "file_modification_time" -> (shpFile.modificationTime * 1000L)) + + // The metadataSchema may be pruned by Spark's column pruning (e.g., when the query + // only selects `_metadata.file_name`). We must construct the inner struct to match + // the pruned schema exactly, otherwise field ordinals will be misaligned. + val innerStructType = metadataSchema.fields.head.dataType.asInstanceOf[StructType] + val prunedValues = innerStructType.fields.map(f => allMetadataValues(f.name)) + val metadataStruct = InternalRow.fromSeq(prunedValues.toSeq) + + // Wrap the struct in an outer row since _metadata is a single StructType column + val metadataRow = InternalRow.fromSeq(Seq(metadataStruct)) + val baseSchema = StructType(readDataSchema.fields ++ partitionSchema.fields) + new PartitionReaderWithMetadata( + withPartitionValues, + baseSchema, + metadataSchema, + metadataRow) + } else { + withPartitionValues + } } override def createReader(partition: InputPartition): PartitionReader[InternalRow] = { @@ -64,3 +112,43 @@ case class ShapefilePartitionReaderFactory( } } } + +/** + * Wraps a partition reader to append metadata column values to each row. This follows the same + * pattern as [[PartitionReaderWithPartitionValues]] but for metadata columns: it uses a + * [[JoinedRow]] to concatenate the base row (data + partition values) with the metadata row, then + * projects the combined row through an + * [[org.apache.spark.sql.catalyst.expressions.UnsafeProjection]] to produce a compact unsafe row. 
+ * + * @param reader + * the underlying reader that produces data + partition value rows + * @param baseSchema + * the combined schema of data columns and partition columns + * @param metadataSchema + * the schema of the metadata columns being appended + * @param metadataValues + * the constant metadata values to append to every row + */ +private[shapefile] class PartitionReaderWithMetadata( + reader: PartitionReader[InternalRow], + baseSchema: StructType, + metadataSchema: StructType, + metadataValues: InternalRow) + extends PartitionReader[InternalRow] { + + private val joinedRow = new JoinedRow() + private val unsafeProjection = + GenerateUnsafeProjection.generate(baseSchema.fields.zipWithIndex.map { case (f, i) => + BoundReference(i, f.dataType, f.nullable) + } ++ metadataSchema.fields.zipWithIndex.map { case (f, i) => + BoundReference(baseSchema.length + i, f.dataType, f.nullable) + }) + + override def next(): Boolean = reader.next() + + override def get(): InternalRow = { + unsafeProjection(joinedRow(reader.get(), metadataValues)) + } + + override def close(): Unit = reader.close() +} diff --git a/spark/spark-4.1/src/main/scala/org/apache/sedona/sql/datasources/shapefile/ShapefileScan.scala b/spark/spark-4.1/src/main/scala/org/apache/sedona/sql/datasources/shapefile/ShapefileScan.scala index afdface894..3f6a9224aa 100644 --- a/spark/spark-4.1/src/main/scala/org/apache/sedona/sql/datasources/shapefile/ShapefileScan.scala +++ b/spark/spark-4.1/src/main/scala/org/apache/sedona/sql/datasources/shapefile/ShapefileScan.scala @@ -44,12 +44,22 @@ case class ShapefileScan( dataSchema: StructType, readDataSchema: StructType, readPartitionSchema: StructType, + /** The metadata fields requested by the query (e.g., fields from `_metadata`). */ + metadataSchema: StructType, options: CaseInsensitiveStringMap, pushedFilters: Array[Filter], partitionFilters: Seq[Expression] = Seq.empty, dataFilters: Seq[Expression] = Seq.empty) extends FileScan { + /** + * Returns the complete read schema including data columns, partition columns, and any requested + * metadata columns. Metadata columns are appended last so the reader factory can construct a + * [[JoinedRow]] that appends metadata values after data and partition values. + */ + override def readSchema(): StructType = + StructType(readDataSchema.fields ++ readPartitionSchema.fields ++ metadataSchema.fields) + override def createReaderFactory(): PartitionReaderFactory = { val caseSensitiveMap = options.asScala.toMap val hadoopConf = sparkSession.sessionState.newHadoopConfWithOptions(caseSensitiveMap) @@ -61,6 +71,7 @@ case class ShapefileScan( dataSchema, readDataSchema, readPartitionSchema, + metadataSchema, ShapefileReadOptions.parse(options), pushedFilters) } diff --git a/spark/spark-4.1/src/main/scala/org/apache/sedona/sql/datasources/shapefile/ShapefileScanBuilder.scala b/spark/spark-4.1/src/main/scala/org/apache/sedona/sql/datasources/shapefile/ShapefileScanBuilder.scala index e5135e381d..48b5e45d53 100644 --- a/spark/spark-4.1/src/main/scala/org/apache/sedona/sql/datasources/shapefile/ShapefileScanBuilder.scala +++ b/spark/spark-4.1/src/main/scala/org/apache/sedona/sql/datasources/shapefile/ShapefileScanBuilder.scala @@ -33,6 +33,28 @@ case class ShapefileScanBuilder( options: CaseInsensitiveStringMap) extends FileScanBuilder(sparkSession, fileIndex, dataSchema) { + /** + * Tracks any metadata fields (e.g., from `_metadata`) requested in the query. Populated by + * [[pruneColumns]] when Spark pushes down column projections. 
+ */ + private var _requiredMetadataSchema: StructType = StructType(Seq.empty) + + /** + * Intercepts Spark's column pruning to separate metadata columns from data/partition columns. + * Fields in [[requiredSchema]] that do not belong to the data schema or partition schema are + * assumed to be metadata fields (e.g., `_metadata`). These are captured in + * [[_requiredMetadataSchema]] so the scan can include them in the output. + */ + override def pruneColumns(requiredSchema: StructType): Unit = { + val resolver = sparkSession.sessionState.conf.resolver + val metaFields = requiredSchema.fields.filter { field => + !dataSchema.fields.exists(df => resolver(df.name, field.name)) && + !fileIndex.partitionSchema.fields.exists(pf => resolver(pf.name, field.name)) + } + _requiredMetadataSchema = StructType(metaFields) + super.pruneColumns(requiredSchema) + } + override def build(): Scan = { ShapefileScan( sparkSession, @@ -40,6 +62,7 @@ case class ShapefileScanBuilder( dataSchema, readDataSchema(), readPartitionSchema(), + _requiredMetadataSchema, options, pushedDataFilters, partitionFilters, diff --git a/spark/spark-4.1/src/main/scala/org/apache/sedona/sql/datasources/shapefile/ShapefileTable.scala b/spark/spark-4.1/src/main/scala/org/apache/sedona/sql/datasources/shapefile/ShapefileTable.scala index 7db6bb8d1f..1f4bd093b9 100644 --- a/spark/spark-4.1/src/main/scala/org/apache/sedona/sql/datasources/shapefile/ShapefileTable.scala +++ b/spark/spark-4.1/src/main/scala/org/apache/sedona/sql/datasources/shapefile/ShapefileTable.scala @@ -21,19 +21,27 @@ package org.apache.sedona.sql.datasources.shapefile import org.apache.hadoop.fs.FileStatus import org.apache.sedona.core.formatMapper.shapefileParser.parseUtils.dbf.DbfParseUtil import org.apache.spark.sql.SparkSession -import org.apache.spark.sql.connector.catalog.TableCapability +import org.apache.spark.sql.connector.catalog.{MetadataColumn, SupportsMetadataColumns, TableCapability} import org.apache.spark.sql.connector.read.ScanBuilder import org.apache.spark.sql.connector.write.{LogicalWriteInfo, WriteBuilder} import org.apache.spark.sql.execution.datasources.FileFormat import org.apache.sedona.sql.datasources.shapefile.ShapefileUtils.{baseSchema, fieldDescriptorsToSchema, mergeSchemas} import org.apache.spark.sql.execution.datasources.v2.FileTable -import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.types.{DataType, LongType, StringType, StructField, StructType, TimestampType} import org.apache.spark.sql.util.CaseInsensitiveStringMap import org.apache.spark.util.SerializableConfiguration import java.util.Locale import scala.collection.JavaConverters._ +/** + * A Spark DataSource V2 table implementation for reading Shapefiles. + * + * Extends [[FileTable]] to leverage Spark's file-based scan infrastructure and implements + * [[SupportsMetadataColumns]] to expose hidden metadata columns (e.g., `_metadata`) that provide + * file-level information such as path, name, size, and modification time. These metadata columns + * are not part of the user-visible schema but can be explicitly selected in queries. 
+ */ case class ShapefileTable( name: String, sparkSession: SparkSession, @@ -41,7 +49,8 @@ case class ShapefileTable( paths: Seq[String], userSpecifiedSchema: Option[StructType], fallbackFileFormat: Class[_ <: FileFormat]) - extends FileTable(sparkSession, options, paths, userSpecifiedSchema) { + extends FileTable(sparkSession, options, paths, userSpecifiedSchema) + with SupportsMetadataColumns { override def formatName: String = "Shapefile" @@ -95,9 +104,49 @@ case class ShapefileTable( } } + /** Returns the metadata columns that this table exposes as hidden columns. */ + override def metadataColumns(): Array[MetadataColumn] = ShapefileTable.fileMetadataColumns + override def newScanBuilder(options: CaseInsensitiveStringMap): ScanBuilder = { ShapefileScanBuilder(sparkSession, fileIndex, schema, dataSchema, options) } override def newWriteBuilder(info: LogicalWriteInfo): WriteBuilder = null } + +object ShapefileTable { + + /** + * Schema of the `_metadata` struct column exposed by [[SupportsMetadataColumns]]. Each field + * provides file-level information about the source shapefile: + * + * - `file_path`: The fully qualified path of the `.shp` file (e.g., + * `hdfs://host/data/file.shp`). + * - `file_name`: The name of the `.shp` file without directory components (e.g., `file.shp`). + * - `file_size`: The total size of the `.shp` file in bytes. + * - `file_block_start`: The byte offset within the file where this partition's data begins. + * For non-splittable formats this is typically 0. + * - `file_block_length`: The number of bytes in this partition's data block. For + * non-splittable formats this equals the file size. + * - `file_modification_time`: The last modification timestamp of the `.shp` file. + */ + private val FILE_METADATA_STRUCT_TYPE: StructType = StructType( + Seq( + StructField("file_path", StringType, nullable = false), + StructField("file_name", StringType, nullable = false), + StructField("file_size", LongType, nullable = false), + StructField("file_block_start", LongType, nullable = false), + StructField("file_block_length", LongType, nullable = false), + StructField("file_modification_time", TimestampType, nullable = false))) + + /** + * The single metadata column `_metadata` exposed to Spark's catalog. This hidden column can be + * selected in queries (e.g., `SELECT _metadata.file_name FROM shapefile.`...``) but does not + * appear in `SELECT *`. 
+ */ + private[shapefile] val fileMetadataColumns: Array[MetadataColumn] = Array(new MetadataColumn { + override def name: String = "_metadata" + override def dataType: DataType = FILE_METADATA_STRUCT_TYPE + override def isNullable: Boolean = false + }) +} diff --git a/spark/spark-4.1/src/test/scala/org/apache/sedona/sql/ShapefileTests.scala b/spark/spark-4.1/src/test/scala/org/apache/sedona/sql/ShapefileTests.scala index 275bc3282f..531eaf65f0 100644 --- a/spark/spark-4.1/src/test/scala/org/apache/sedona/sql/ShapefileTests.scala +++ b/spark/spark-4.1/src/test/scala/org/apache/sedona/sql/ShapefileTests.scala @@ -20,7 +20,8 @@ package org.apache.sedona.sql import org.apache.commons.io.FileUtils import org.apache.spark.sql.sedona_sql.UDT.GeometryUDT -import org.apache.spark.sql.types.{DateType, DecimalType, LongType, StringType, StructField, StructType} +import org.apache.spark.sql.Row +import org.apache.spark.sql.types.{DateType, DecimalType, LongType, StringType, StructField, StructType, TimestampType} import org.locationtech.jts.geom.{Geometry, MultiPolygon, Point, Polygon} import org.locationtech.jts.io.{WKTReader, WKTWriter} import org.scalatest.BeforeAndAfterAll @@ -502,7 +503,7 @@ class ShapefileTests extends TestBaseScala with BeforeAndAfterAll { .format("shapefile") .load(temporaryLocation) .select("part", "id", "aInt", "aUnicode", "geometry") - var rows = shapefileDf.collect() + val rows = shapefileDf.collect() assert(rows.length == 9) rows.foreach { row => assert(row.getAs[Geometry]("geometry").isInstanceOf[Point]) @@ -517,18 +518,6 @@ class ShapefileTests extends TestBaseScala with BeforeAndAfterAll { assert(row.getAs[String]("aUnicode") == s"测试$id") } } - - // Using partition filters - rows = shapefileDf.where("part = 2").collect() - assert(rows.length == 4) - rows.foreach { row => - assert(row.getAs[Geometry]("geometry").isInstanceOf[Point]) - assert(row.getAs[Int]("part") == 2) - val id = row.getAs[Long]("id") - assert(id > 10) - assert(row.getAs[Long]("aInt") == id) - assert(row.getAs[String]("aUnicode") == s"测试$id") - } } it("read with recursiveFileLookup") { @@ -780,5 +769,180 @@ class ShapefileTests extends TestBaseScala with BeforeAndAfterAll { } } } + + it("should expose _metadata struct with all expected fields") { + val df = sparkSession.read + .format("shapefile") + .load(resourceFolder + "shapefiles/gis_osm_pois_free_1") + val metaDf = df.select("_metadata") + val metaSchema = metaDf.schema("_metadata").dataType.asInstanceOf[StructType] + val expectedFields = + Seq( + "file_path", + "file_name", + "file_size", + "file_block_start", + "file_block_length", + "file_modification_time") + assert(metaSchema.fieldNames.toSeq == expectedFields) + assert(metaSchema("file_path").dataType == StringType) + assert(metaSchema("file_name").dataType == StringType) + assert(metaSchema("file_size").dataType == LongType) + assert(metaSchema("file_block_start").dataType == LongType) + assert(metaSchema("file_block_length").dataType == LongType) + assert(metaSchema("file_modification_time").dataType == TimestampType) + } + + it("should not include _metadata in select(*)") { + val df = sparkSession.read + .format("shapefile") + .load(resourceFolder + "shapefiles/gis_osm_pois_free_1") + val columns = df.columns + assert(!columns.contains("_metadata")) + } + + it("should return correct file_path and file_name in _metadata") { + val df = sparkSession.read + .format("shapefile") + .load(resourceFolder + "shapefiles/gis_osm_pois_free_1") + val metaRows = df.select("_metadata.file_path", 
"_metadata.file_name").distinct().collect() + assert(metaRows.length == 1) + val filePath = metaRows.head.getString(0) + val fileName = metaRows.head.getString(1) + assert(filePath.endsWith("gis_osm_pois_free_1.shp")) + assert(fileName == "gis_osm_pois_free_1.shp") + } + + it("should return actual file_size matching the .shp file on disk") { + val shpFile = + new File(resourceFolder + "shapefiles/gis_osm_pois_free_1/gis_osm_pois_free_1.shp") + val expectedSize = shpFile.length() + val df = sparkSession.read + .format("shapefile") + .load(resourceFolder + "shapefiles/gis_osm_pois_free_1") + val sizes = df.select("_metadata.file_size").distinct().collect() + assert(sizes.length == 1) + assert(sizes.head.getLong(0) == expectedSize) + } + + it( + "should return file_block_start=0 and file_block_length=file_size for non-splittable shapefiles") { + val shpFile = + new File(resourceFolder + "shapefiles/gis_osm_pois_free_1/gis_osm_pois_free_1.shp") + val expectedSize = shpFile.length() + val df = sparkSession.read + .format("shapefile") + .load(resourceFolder + "shapefiles/gis_osm_pois_free_1") + val rows = df + .select("_metadata.file_block_start", "_metadata.file_block_length") + .distinct() + .collect() + assert(rows.length == 1) + assert(rows.head.getLong(0) == 0L) // file_block_start + assert(rows.head.getLong(1) == expectedSize) // file_block_length + } + + it("should return file_modification_time matching the .shp file on disk") { + val shpFile = + new File(resourceFolder + "shapefiles/gis_osm_pois_free_1/gis_osm_pois_free_1.shp") + // File.lastModified() returns milliseconds, Spark TimestampType stores microseconds + val expectedModTimeMs = shpFile.lastModified() + val df = sparkSession.read + .format("shapefile") + .load(resourceFolder + "shapefiles/gis_osm_pois_free_1") + val times = + df.select("_metadata.file_modification_time").distinct().collect() + assert(times.length == 1) + val modTime = times.head.getTimestamp(0) + assert(modTime != null) + // Timestamp.getTime() returns milliseconds + assert(modTime.getTime == expectedModTimeMs) + } + + it("should return correct metadata values per file when reading multiple shapefiles") { + val map1Shp = + new File(resourceFolder + "shapefiles/multipleshapefiles/map1.shp") + val map2Shp = + new File(resourceFolder + "shapefiles/multipleshapefiles/map2.shp") + val df = sparkSession.read + .format("shapefile") + .load(resourceFolder + "shapefiles/multipleshapefiles") + val metaRows = df + .select( + "_metadata.file_name", + "_metadata.file_size", + "_metadata.file_block_start", + "_metadata.file_block_length") + .distinct() + .collect() + assert(metaRows.length == 2) + val byName = metaRows.map(r => r.getString(0) -> r).toMap + // map1.shp + assert(byName("map1.shp").getLong(1) == map1Shp.length()) + assert(byName("map1.shp").getLong(2) == 0L) + assert(byName("map1.shp").getLong(3) == map1Shp.length()) + // map2.shp + assert(byName("map2.shp").getLong(1) == map2Shp.length()) + assert(byName("map2.shp").getLong(2) == 0L) + assert(byName("map2.shp").getLong(3) == map2Shp.length()) + } + + it("should allow filtering on _metadata fields") { + val df = sparkSession.read + .format("shapefile") + .load(resourceFolder + "shapefiles/multipleshapefiles") + val totalCount = df.count() + val map1Df = df.filter(df("_metadata.file_name") === "map1.shp") + val map2Df = df.filter(df("_metadata.file_name") === "map2.shp") + assert(map1Df.count() > 0) + assert(map2Df.count() > 0) + assert(map1Df.count() + map2Df.count() == totalCount) + } + + it("should select 
_metadata along with data columns") { + val df = sparkSession.read + .format("shapefile") + .load(resourceFolder + "shapefiles/gis_osm_pois_free_1") + val result = df.select("osm_id", "_metadata.file_name").collect() + assert(result.length == 12873) + result.foreach { row => + assert(row.getString(0).nonEmpty) + assert(row.getString(1) == "gis_osm_pois_free_1.shp") + } + } + + it("should return correct metadata for each file in multi-shapefile directory") { + val dt1Shp = new File(resourceFolder + "shapefiles/datatypes/datatypes1.shp") + val dt2Shp = new File(resourceFolder + "shapefiles/datatypes/datatypes2.shp") + val df = sparkSession.read + .format("shapefile") + .load(resourceFolder + "shapefiles/datatypes") + val result = df + .select( + "_metadata.file_path", + "_metadata.file_name", + "_metadata.file_size", + "_metadata.file_block_start", + "_metadata.file_block_length", + "_metadata.file_modification_time") + .distinct() + .collect() + assert(result.length == 2) + val byName = result.map(r => r.getString(1) -> r).toMap + // datatypes1.shp + val r1 = byName("datatypes1.shp") + assert(r1.getString(0).endsWith("datatypes1.shp")) + assert(r1.getLong(2) == dt1Shp.length()) + assert(r1.getLong(3) == 0L) + assert(r1.getLong(4) == dt1Shp.length()) + assert(r1.getTimestamp(5).getTime == dt1Shp.lastModified()) + // datatypes2.shp + val r2 = byName("datatypes2.shp") + assert(r2.getString(0).endsWith("datatypes2.shp")) + assert(r2.getLong(2) == dt2Shp.length()) + assert(r2.getLong(3) == 0L) + assert(r2.getLong(4) == dt2Shp.length()) + assert(r2.getTimestamp(5).getTime == dt2Shp.lastModified()) + } } } From 0307ce89ec837ac5b6c4f55a116c906b62fc983c Mon Sep 17 00:00:00 2001 From: Jia Yu Date: Sat, 14 Feb 2026 23:31:36 -0800 Subject: [PATCH 2/2] Address PR review comments: remove unused Row import, restore partition filter tests - Remove unused org.apache.spark.sql.Row import from ShapefileTests in all 4 Spark versions (3.4, 3.5, 4.0, 4.1) - Restore accidentally removed partition filter test code in spark-3.5, 4.0, and 4.1 (use val filteredRows instead of var rows reassignment) --- .../org/apache/sedona/sql/ShapefileTests.scala | 1 - .../org/apache/sedona/sql/ShapefileTests.scala | 13 ++++++++++++- .../org/apache/sedona/sql/ShapefileTests.scala | 13 ++++++++++++- .../org/apache/sedona/sql/ShapefileTests.scala | 13 ++++++++++++- 4 files changed, 36 insertions(+), 4 deletions(-) diff --git a/spark/spark-3.4/src/test/scala/org/apache/sedona/sql/ShapefileTests.scala b/spark/spark-3.4/src/test/scala/org/apache/sedona/sql/ShapefileTests.scala index 531eaf65f0..96071b8036 100644 --- a/spark/spark-3.4/src/test/scala/org/apache/sedona/sql/ShapefileTests.scala +++ b/spark/spark-3.4/src/test/scala/org/apache/sedona/sql/ShapefileTests.scala @@ -20,7 +20,6 @@ package org.apache.sedona.sql import org.apache.commons.io.FileUtils import org.apache.spark.sql.sedona_sql.UDT.GeometryUDT -import org.apache.spark.sql.Row import org.apache.spark.sql.types.{DateType, DecimalType, LongType, StringType, StructField, StructType, TimestampType} import org.locationtech.jts.geom.{Geometry, MultiPolygon, Point, Polygon} import org.locationtech.jts.io.{WKTReader, WKTWriter} diff --git a/spark/spark-3.5/src/test/scala/org/apache/sedona/sql/ShapefileTests.scala b/spark/spark-3.5/src/test/scala/org/apache/sedona/sql/ShapefileTests.scala index 531eaf65f0..4b4f218b36 100644 --- a/spark/spark-3.5/src/test/scala/org/apache/sedona/sql/ShapefileTests.scala +++ b/spark/spark-3.5/src/test/scala/org/apache/sedona/sql/ShapefileTests.scala 
@@ -20,7 +20,6 @@ package org.apache.sedona.sql import org.apache.commons.io.FileUtils import org.apache.spark.sql.sedona_sql.UDT.GeometryUDT -import org.apache.spark.sql.Row import org.apache.spark.sql.types.{DateType, DecimalType, LongType, StringType, StructField, StructType, TimestampType} import org.locationtech.jts.geom.{Geometry, MultiPolygon, Point, Polygon} import org.locationtech.jts.io.{WKTReader, WKTWriter} @@ -518,6 +517,18 @@ class ShapefileTests extends TestBaseScala with BeforeAndAfterAll { assert(row.getAs[String]("aUnicode") == s"测试$id") } } + + // Using partition filters + val filteredRows = shapefileDf.where("part = 2").collect() + assert(filteredRows.length == 4) + filteredRows.foreach { row => + assert(row.getAs[Geometry]("geometry").isInstanceOf[Point]) + assert(row.getAs[Int]("part") == 2) + val id = row.getAs[Long]("id") + assert(id > 10) + assert(row.getAs[Long]("aInt") == id) + assert(row.getAs[String]("aUnicode") == s"测试$id") + } } it("read with recursiveFileLookup") { diff --git a/spark/spark-4.0/src/test/scala/org/apache/sedona/sql/ShapefileTests.scala b/spark/spark-4.0/src/test/scala/org/apache/sedona/sql/ShapefileTests.scala index 531eaf65f0..4b4f218b36 100644 --- a/spark/spark-4.0/src/test/scala/org/apache/sedona/sql/ShapefileTests.scala +++ b/spark/spark-4.0/src/test/scala/org/apache/sedona/sql/ShapefileTests.scala @@ -20,7 +20,6 @@ package org.apache.sedona.sql import org.apache.commons.io.FileUtils import org.apache.spark.sql.sedona_sql.UDT.GeometryUDT -import org.apache.spark.sql.Row import org.apache.spark.sql.types.{DateType, DecimalType, LongType, StringType, StructField, StructType, TimestampType} import org.locationtech.jts.geom.{Geometry, MultiPolygon, Point, Polygon} import org.locationtech.jts.io.{WKTReader, WKTWriter} @@ -518,6 +517,18 @@ class ShapefileTests extends TestBaseScala with BeforeAndAfterAll { assert(row.getAs[String]("aUnicode") == s"测试$id") } } + + // Using partition filters + val filteredRows = shapefileDf.where("part = 2").collect() + assert(filteredRows.length == 4) + filteredRows.foreach { row => + assert(row.getAs[Geometry]("geometry").isInstanceOf[Point]) + assert(row.getAs[Int]("part") == 2) + val id = row.getAs[Long]("id") + assert(id > 10) + assert(row.getAs[Long]("aInt") == id) + assert(row.getAs[String]("aUnicode") == s"测试$id") + } } it("read with recursiveFileLookup") { diff --git a/spark/spark-4.1/src/test/scala/org/apache/sedona/sql/ShapefileTests.scala b/spark/spark-4.1/src/test/scala/org/apache/sedona/sql/ShapefileTests.scala index 531eaf65f0..4b4f218b36 100644 --- a/spark/spark-4.1/src/test/scala/org/apache/sedona/sql/ShapefileTests.scala +++ b/spark/spark-4.1/src/test/scala/org/apache/sedona/sql/ShapefileTests.scala @@ -20,7 +20,6 @@ package org.apache.sedona.sql import org.apache.commons.io.FileUtils import org.apache.spark.sql.sedona_sql.UDT.GeometryUDT -import org.apache.spark.sql.Row import org.apache.spark.sql.types.{DateType, DecimalType, LongType, StringType, StructField, StructType, TimestampType} import org.locationtech.jts.geom.{Geometry, MultiPolygon, Point, Polygon} import org.locationtech.jts.io.{WKTReader, WKTWriter} @@ -518,6 +517,18 @@ class ShapefileTests extends TestBaseScala with BeforeAndAfterAll { assert(row.getAs[String]("aUnicode") == s"测试$id") } } + + // Using partition filters + val filteredRows = shapefileDf.where("part = 2").collect() + assert(filteredRows.length == 4) + filteredRows.foreach { row => + assert(row.getAs[Geometry]("geometry").isInstanceOf[Point]) + 
assert(row.getAs[Int]("part") == 2) + val id = row.getAs[Long]("id") + assert(id > 10) + assert(row.getAs[Long]("aInt") == id) + assert(row.getAs[String]("aUnicode") == s"测试$id") + } } it("read with recursiveFileLookup") {
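
A few self-contained sketches follow for reviewers who want to see the new `_metadata` plumbing in isolation. First, the user-facing behavior this patch enables. This is a minimal sketch: the local-mode session, the input path /tmp/shapefiles, and the file name map1.shp are placeholder assumptions, not values from the test suite.

import org.apache.spark.sql.SparkSession

object MetadataColumnUsage {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession
      .builder()
      .master("local[*]")
      .appName("shapefile-metadata")
      .getOrCreate()

    // Read a directory of shapefiles with the DataSource V2 reader.
    val df = spark.read.format("shapefile").load("/tmp/shapefiles")

    // _metadata is hidden: absent from df.columns and from SELECT *.
    assert(!df.columns.contains("_metadata"))

    // It can be selected explicitly, whole or field by field.
    df.select("_metadata").printSchema()
    df.select("geometry", "_metadata.file_name", "_metadata.file_size").show(5)

    // Metadata fields also work in filters, e.g. to restrict to one .shp file.
    val n = df.filter(df("_metadata.file_name") === "map1.shp").count()
    println(s"rows read from map1.shp: $n")

    spark.stop()
  }
}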
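
Next, the scan-builder side. ShapefileScanBuilder.pruneColumns classifies a required field as metadata when it resolves to neither a data column nor a partition column. Below is a minimal sketch of that rule only, with equalsIgnoreCase standing in for the session resolver (the real code uses sparkSession.sessionState.conf.resolver) and StringType standing in for the geometry UDT; the schemas are invented for illustration.

import org.apache.spark.sql.types.{LongType, StringType, StructField, StructType}

object PruneColumnsRule {
  def main(args: Array[String]): Unit = {
    // Stand-in for sparkSession.sessionState.conf.resolver (case-insensitive default).
    val resolver: (String, String) => Boolean = _.equalsIgnoreCase(_)

    val dataSchema = StructType(
      Seq(StructField("geometry", StringType), StructField("osm_id", StringType)))
    val partitionSchema = StructType(Seq(StructField("part", LongType)))

    // What Spark might push down for: SELECT osm_id, part, _metadata.file_name ...
    val requiredSchema = StructType(
      Seq(
        StructField("osm_id", StringType),
        StructField("part", LongType),
        StructField(
          "_metadata",
          StructType(Seq(StructField("file_name", StringType, nullable = false))))))

    // Neither a data column nor a partition column => treated as a metadata column.
    val metaFields = requiredSchema.fields.filter { field =>
      !dataSchema.fields.exists(df => resolver(df.name, field.name)) &&
      !partitionSchema.fields.exists(pf => resolver(pf.name, field.name))
    }
    println(metaFields.map(_.name).mkString(", ")) // prints: _metadata
  }
}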
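
On the executor side, the reader factory must lay metadata values out in exactly the order of the pruned `_metadata` struct, or field ordinals drift. A sketch of that alignment with invented file values (path, name, sizes, timestamp); the millisecond-to-microsecond conversion mirrors the factory's comment.

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.types.{LongType, StringType, StructField, StructType}
import org.apache.spark.unsafe.types.UTF8String

object PrunedMetadataAlignment {
  def main(args: Array[String]): Unit = {
    // All six values are computed up front, keyed by field name.
    val allMetadataValues: Map[String, Any] = Map(
      "file_path" -> UTF8String.fromString("/tmp/shapefiles/map1.shp"),
      "file_name" -> UTF8String.fromString("map1.shp"),
      "file_size" -> 8192L,
      "file_block_start" -> 0L,
      "file_block_length" -> 8192L,
      // modificationTime is in milliseconds; TimestampType stores microseconds.
      "file_modification_time" -> (1700000000000L * 1000L))

    // Suppose the query selected only _metadata.file_name and _metadata.file_size.
    // The pruned struct dictates the order in which values must be laid out.
    val prunedStruct = StructType(
      Seq(
        StructField("file_name", StringType, nullable = false),
        StructField("file_size", LongType, nullable = false)))
    val metadataStruct =
      InternalRow.fromSeq(prunedStruct.fields.map(f => allMetadataValues(f.name)).toSeq)
    println(metadataStruct) // [map1.shp,8192]
  }
}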
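
Finally, the row-assembly pattern behind PartitionReaderWithMetadata: bind one ordinal per output column across a JoinedRow, then compact the result with a generated unsafe projection. The two single-column schemas below are invented, and spark-catalyst must be on the classpath; this is a sketch of the pattern, not the patched class itself.

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{BoundReference, JoinedRow}
import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection
import org.apache.spark.sql.types.{IntegerType, LongType, StructField, StructType}

object JoinedRowProjection {
  def main(args: Array[String]): Unit = {
    val baseSchema = StructType(Seq(StructField("id", IntegerType, nullable = false)))
    val metaSchema = StructType(Seq(StructField("file_size", LongType, nullable = false)))

    // Bind one ordinal per output column: base columns first, metadata columns after.
    val projection = GenerateUnsafeProjection.generate(
      (baseSchema.fields ++ metaSchema.fields).zipWithIndex.map { case (f, i) =>
        BoundReference(i, f.dataType, f.nullable)
      }.toSeq)

    // JoinedRow concatenates the two rows without copying; the projection compacts
    // the combined row into a single UnsafeRow.
    val joined = new JoinedRow(InternalRow(42), InternalRow(8192L))
    val unsafeRow = projection(joined)
    println(s"${unsafeRow.getInt(0)} ${unsafeRow.getLong(1)}") // 42 8192
  }
}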