diff --git a/spark/spark-3.4/src/main/scala/org/apache/sedona/sql/datasources/shapefile/ShapefilePartitionReaderFactory.scala b/spark/spark-3.4/src/main/scala/org/apache/sedona/sql/datasources/shapefile/ShapefilePartitionReaderFactory.scala index 5a28af6d66..79c0638bd6 100644 --- a/spark/spark-3.4/src/main/scala/org/apache/sedona/sql/datasources/shapefile/ShapefilePartitionReaderFactory.scala +++ b/spark/spark-3.4/src/main/scala/org/apache/sedona/sql/datasources/shapefile/ShapefilePartitionReaderFactory.scala @@ -18,8 +18,11 @@ */ package org.apache.sedona.sql.datasources.shapefile +import org.apache.hadoop.fs.Path import org.apache.spark.broadcast.Broadcast import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.{BoundReference, JoinedRow} +import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection import org.apache.spark.sql.connector.read.InputPartition import org.apache.spark.sql.connector.read.PartitionReader import org.apache.spark.sql.connector.read.PartitionReaderFactory @@ -28,14 +31,19 @@ import org.apache.spark.sql.execution.datasources.v2.PartitionReaderWithPartitio import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.sources.Filter import org.apache.spark.sql.types.StructType +import org.apache.spark.unsafe.types.UTF8String import org.apache.spark.util.SerializableConfiguration +import java.util.Locale + case class ShapefilePartitionReaderFactory( sqlConf: SQLConf, broadcastedConf: Broadcast[SerializableConfiguration], dataSchema: StructType, readDataSchema: StructType, partitionSchema: StructType, + /** The metadata fields requested by the query (e.g., fields from `_metadata`). */ + metadataSchema: StructType, options: ShapefileReadOptions, filters: Seq[Filter]) extends PartitionReaderFactory { @@ -48,11 +56,51 @@ case class ShapefilePartitionReaderFactory( partitionedFiles, readDataSchema, options) - new PartitionReaderWithPartitionValues( + val withPartitionValues = new PartitionReaderWithPartitionValues( fileReader, readDataSchema, partitionSchema, partitionedFiles.head.partitionValues) + + if (metadataSchema.nonEmpty) { + // Build metadata values from the .shp file's partition information. + // We use the .shp file because it is the primary shapefile component and its path + // is what users would expect to see in _metadata.file_path / _metadata.file_name. + val shpFile = partitionedFiles + .find(_.filePath.toPath.getName.toLowerCase(Locale.ROOT).endsWith(".shp")) + .getOrElse(partitionedFiles.head) + val filePath = shpFile.filePath.toString + val fileName = new Path(filePath).getName + + // Complete map of all metadata field values keyed by field name. + // The modificationTime from PartitionedFile is in milliseconds but Spark's + // TimestampType uses microseconds, so we multiply by 1000. + val allMetadataValues: Map[String, Any] = Map( + "file_path" -> UTF8String.fromString(filePath), + "file_name" -> UTF8String.fromString(fileName), + "file_size" -> shpFile.fileSize, + "file_block_start" -> shpFile.start, + "file_block_length" -> shpFile.length, + "file_modification_time" -> (shpFile.modificationTime * 1000L)) + + // The metadataSchema may be pruned by Spark's column pruning (e.g., when the query + // only selects `_metadata.file_name`). We must construct the inner struct to match + // the pruned schema exactly, otherwise field ordinals will be misaligned. 
+ val innerStructType = metadataSchema.fields.head.dataType.asInstanceOf[StructType] + val prunedValues = innerStructType.fields.map(f => allMetadataValues(f.name)) + val metadataStruct = InternalRow.fromSeq(prunedValues.toSeq) + + // Wrap the struct in an outer row since _metadata is a single StructType column + val metadataRow = InternalRow.fromSeq(Seq(metadataStruct)) + val baseSchema = StructType(readDataSchema.fields ++ partitionSchema.fields) + new PartitionReaderWithMetadata( + withPartitionValues, + baseSchema, + metadataSchema, + metadataRow) + } else { + withPartitionValues + } } override def createReader(partition: InputPartition): PartitionReader[InternalRow] = { @@ -64,3 +112,43 @@ case class ShapefilePartitionReaderFactory( } } } + +/** + * Wraps a partition reader to append metadata column values to each row. This follows the same + * pattern as [[PartitionReaderWithPartitionValues]] but for metadata columns: it uses a + * [[JoinedRow]] to concatenate the base row (data + partition values) with the metadata row, then + * projects the combined row through an + * [[org.apache.spark.sql.catalyst.expressions.UnsafeProjection]] to produce a compact unsafe row. + * + * @param reader + * the underlying reader that produces data + partition value rows + * @param baseSchema + * the combined schema of data columns and partition columns + * @param metadataSchema + * the schema of the metadata columns being appended + * @param metadataValues + * the constant metadata values to append to every row + */ +private[shapefile] class PartitionReaderWithMetadata( + reader: PartitionReader[InternalRow], + baseSchema: StructType, + metadataSchema: StructType, + metadataValues: InternalRow) + extends PartitionReader[InternalRow] { + + private val joinedRow = new JoinedRow() + private val unsafeProjection = + GenerateUnsafeProjection.generate(baseSchema.fields.zipWithIndex.map { case (f, i) => + BoundReference(i, f.dataType, f.nullable) + } ++ metadataSchema.fields.zipWithIndex.map { case (f, i) => + BoundReference(baseSchema.length + i, f.dataType, f.nullable) + }) + + override def next(): Boolean = reader.next() + + override def get(): InternalRow = { + unsafeProjection(joinedRow(reader.get(), metadataValues)) + } + + override def close(): Unit = reader.close() +} diff --git a/spark/spark-3.4/src/main/scala/org/apache/sedona/sql/datasources/shapefile/ShapefileScan.scala b/spark/spark-3.4/src/main/scala/org/apache/sedona/sql/datasources/shapefile/ShapefileScan.scala index afdface894..3f6a9224aa 100644 --- a/spark/spark-3.4/src/main/scala/org/apache/sedona/sql/datasources/shapefile/ShapefileScan.scala +++ b/spark/spark-3.4/src/main/scala/org/apache/sedona/sql/datasources/shapefile/ShapefileScan.scala @@ -44,12 +44,22 @@ case class ShapefileScan( dataSchema: StructType, readDataSchema: StructType, readPartitionSchema: StructType, + /** The metadata fields requested by the query (e.g., fields from `_metadata`). */ + metadataSchema: StructType, options: CaseInsensitiveStringMap, pushedFilters: Array[Filter], partitionFilters: Seq[Expression] = Seq.empty, dataFilters: Seq[Expression] = Seq.empty) extends FileScan { + /** + * Returns the complete read schema including data columns, partition columns, and any requested + * metadata columns. Metadata columns are appended last so the reader factory can construct a + * [[JoinedRow]] that appends metadata values after data and partition values. 
+ */ + override def readSchema(): StructType = + StructType(readDataSchema.fields ++ readPartitionSchema.fields ++ metadataSchema.fields) + override def createReaderFactory(): PartitionReaderFactory = { val caseSensitiveMap = options.asScala.toMap val hadoopConf = sparkSession.sessionState.newHadoopConfWithOptions(caseSensitiveMap) @@ -61,6 +71,7 @@ case class ShapefileScan( dataSchema, readDataSchema, readPartitionSchema, + metadataSchema, ShapefileReadOptions.parse(options), pushedFilters) } diff --git a/spark/spark-3.4/src/main/scala/org/apache/sedona/sql/datasources/shapefile/ShapefileScanBuilder.scala b/spark/spark-3.4/src/main/scala/org/apache/sedona/sql/datasources/shapefile/ShapefileScanBuilder.scala index e5135e381d..48b5e45d53 100644 --- a/spark/spark-3.4/src/main/scala/org/apache/sedona/sql/datasources/shapefile/ShapefileScanBuilder.scala +++ b/spark/spark-3.4/src/main/scala/org/apache/sedona/sql/datasources/shapefile/ShapefileScanBuilder.scala @@ -33,6 +33,28 @@ case class ShapefileScanBuilder( options: CaseInsensitiveStringMap) extends FileScanBuilder(sparkSession, fileIndex, dataSchema) { + /** + * Tracks any metadata fields (e.g., from `_metadata`) requested in the query. Populated by + * [[pruneColumns]] when Spark pushes down column projections. + */ + private var _requiredMetadataSchema: StructType = StructType(Seq.empty) + + /** + * Intercepts Spark's column pruning to separate metadata columns from data/partition columns. + * Fields in [[requiredSchema]] that do not belong to the data schema or partition schema are + * assumed to be metadata fields (e.g., `_metadata`). These are captured in + * [[_requiredMetadataSchema]] so the scan can include them in the output. + */ + override def pruneColumns(requiredSchema: StructType): Unit = { + val resolver = sparkSession.sessionState.conf.resolver + val metaFields = requiredSchema.fields.filter { field => + !dataSchema.fields.exists(df => resolver(df.name, field.name)) && + !fileIndex.partitionSchema.fields.exists(pf => resolver(pf.name, field.name)) + } + _requiredMetadataSchema = StructType(metaFields) + super.pruneColumns(requiredSchema) + } + override def build(): Scan = { ShapefileScan( sparkSession, @@ -40,6 +62,7 @@ case class ShapefileScanBuilder( dataSchema, readDataSchema(), readPartitionSchema(), + _requiredMetadataSchema, options, pushedDataFilters, partitionFilters, diff --git a/spark/spark-3.4/src/main/scala/org/apache/sedona/sql/datasources/shapefile/ShapefileTable.scala b/spark/spark-3.4/src/main/scala/org/apache/sedona/sql/datasources/shapefile/ShapefileTable.scala index 7db6bb8d1f..1f4bd093b9 100644 --- a/spark/spark-3.4/src/main/scala/org/apache/sedona/sql/datasources/shapefile/ShapefileTable.scala +++ b/spark/spark-3.4/src/main/scala/org/apache/sedona/sql/datasources/shapefile/ShapefileTable.scala @@ -21,19 +21,27 @@ package org.apache.sedona.sql.datasources.shapefile import org.apache.hadoop.fs.FileStatus import org.apache.sedona.core.formatMapper.shapefileParser.parseUtils.dbf.DbfParseUtil import org.apache.spark.sql.SparkSession -import org.apache.spark.sql.connector.catalog.TableCapability +import org.apache.spark.sql.connector.catalog.{MetadataColumn, SupportsMetadataColumns, TableCapability} import org.apache.spark.sql.connector.read.ScanBuilder import org.apache.spark.sql.connector.write.{LogicalWriteInfo, WriteBuilder} import org.apache.spark.sql.execution.datasources.FileFormat import org.apache.sedona.sql.datasources.shapefile.ShapefileUtils.{baseSchema, fieldDescriptorsToSchema, mergeSchemas} 
import org.apache.spark.sql.execution.datasources.v2.FileTable -import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.types.{DataType, LongType, StringType, StructField, StructType, TimestampType} import org.apache.spark.sql.util.CaseInsensitiveStringMap import org.apache.spark.util.SerializableConfiguration import java.util.Locale import scala.collection.JavaConverters._ +/** + * A Spark DataSource V2 table implementation for reading Shapefiles. + * + * Extends [[FileTable]] to leverage Spark's file-based scan infrastructure and implements + * [[SupportsMetadataColumns]] to expose hidden metadata columns (e.g., `_metadata`) that provide + * file-level information such as path, name, size, and modification time. These metadata columns + * are not part of the user-visible schema but can be explicitly selected in queries. + */ case class ShapefileTable( name: String, sparkSession: SparkSession, @@ -41,7 +49,8 @@ case class ShapefileTable( paths: Seq[String], userSpecifiedSchema: Option[StructType], fallbackFileFormat: Class[_ <: FileFormat]) - extends FileTable(sparkSession, options, paths, userSpecifiedSchema) { + extends FileTable(sparkSession, options, paths, userSpecifiedSchema) + with SupportsMetadataColumns { override def formatName: String = "Shapefile" @@ -95,9 +104,49 @@ case class ShapefileTable( } } + /** Returns the metadata columns that this table exposes as hidden columns. */ + override def metadataColumns(): Array[MetadataColumn] = ShapefileTable.fileMetadataColumns + override def newScanBuilder(options: CaseInsensitiveStringMap): ScanBuilder = { ShapefileScanBuilder(sparkSession, fileIndex, schema, dataSchema, options) } override def newWriteBuilder(info: LogicalWriteInfo): WriteBuilder = null } + +object ShapefileTable { + + /** + * Schema of the `_metadata` struct column exposed by [[SupportsMetadataColumns]]. Each field + * provides file-level information about the source shapefile: + * + * - `file_path`: The fully qualified path of the `.shp` file (e.g., + * `hdfs://host/data/file.shp`). + * - `file_name`: The name of the `.shp` file without directory components (e.g., `file.shp`). + * - `file_size`: The total size of the `.shp` file in bytes. + * - `file_block_start`: The byte offset within the file where this partition's data begins. + * For non-splittable formats this is typically 0. + * - `file_block_length`: The number of bytes in this partition's data block. For + * non-splittable formats this equals the file size. + * - `file_modification_time`: The last modification timestamp of the `.shp` file. + */ + private val FILE_METADATA_STRUCT_TYPE: StructType = StructType( + Seq( + StructField("file_path", StringType, nullable = false), + StructField("file_name", StringType, nullable = false), + StructField("file_size", LongType, nullable = false), + StructField("file_block_start", LongType, nullable = false), + StructField("file_block_length", LongType, nullable = false), + StructField("file_modification_time", TimestampType, nullable = false))) + + /** + * The single metadata column `_metadata` exposed to Spark's catalog. This hidden column can be + * selected in queries (e.g., `SELECT _metadata.file_name FROM shapefile.`...``) but does not + * appear in `SELECT *`. 
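+   *
+   * A minimal usage sketch, assuming `spark` is an active `SparkSession`; the load path is illustrative:
+   * {{{
+   *   spark.read.format("shapefile").load("/path/to/shapefiles")
+   *     .select("_metadata.file_path", "_metadata.file_name", "_metadata.file_size")
+   * }}}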
+ */ + private[shapefile] val fileMetadataColumns: Array[MetadataColumn] = Array(new MetadataColumn { + override def name: String = "_metadata" + override def dataType: DataType = FILE_METADATA_STRUCT_TYPE + override def isNullable: Boolean = false + }) +} diff --git a/spark/spark-3.4/src/test/scala/org/apache/sedona/sql/ShapefileTests.scala b/spark/spark-3.4/src/test/scala/org/apache/sedona/sql/ShapefileTests.scala index c2e88c469b..96071b8036 100644 --- a/spark/spark-3.4/src/test/scala/org/apache/sedona/sql/ShapefileTests.scala +++ b/spark/spark-3.4/src/test/scala/org/apache/sedona/sql/ShapefileTests.scala @@ -20,7 +20,7 @@ package org.apache.sedona.sql import org.apache.commons.io.FileUtils import org.apache.spark.sql.sedona_sql.UDT.GeometryUDT -import org.apache.spark.sql.types.{DateType, DecimalType, LongType, StringType, StructField, StructType} +import org.apache.spark.sql.types.{DateType, DecimalType, LongType, StringType, StructField, StructType, TimestampType} import org.locationtech.jts.geom.{Geometry, MultiPolygon, Point, Polygon} import org.locationtech.jts.io.{WKTReader, WKTWriter} import org.scalatest.BeforeAndAfterAll @@ -768,5 +768,180 @@ class ShapefileTests extends TestBaseScala with BeforeAndAfterAll { } } } + + it("should expose _metadata struct with all expected fields") { + val df = sparkSession.read + .format("shapefile") + .load(resourceFolder + "shapefiles/gis_osm_pois_free_1") + val metaDf = df.select("_metadata") + val metaSchema = metaDf.schema("_metadata").dataType.asInstanceOf[StructType] + val expectedFields = + Seq( + "file_path", + "file_name", + "file_size", + "file_block_start", + "file_block_length", + "file_modification_time") + assert(metaSchema.fieldNames.toSeq == expectedFields) + assert(metaSchema("file_path").dataType == StringType) + assert(metaSchema("file_name").dataType == StringType) + assert(metaSchema("file_size").dataType == LongType) + assert(metaSchema("file_block_start").dataType == LongType) + assert(metaSchema("file_block_length").dataType == LongType) + assert(metaSchema("file_modification_time").dataType == TimestampType) + } + + it("should not include _metadata in select(*)") { + val df = sparkSession.read + .format("shapefile") + .load(resourceFolder + "shapefiles/gis_osm_pois_free_1") + val columns = df.columns + assert(!columns.contains("_metadata")) + } + + it("should return correct file_path and file_name in _metadata") { + val df = sparkSession.read + .format("shapefile") + .load(resourceFolder + "shapefiles/gis_osm_pois_free_1") + val metaRows = df.select("_metadata.file_path", "_metadata.file_name").distinct().collect() + assert(metaRows.length == 1) + val filePath = metaRows.head.getString(0) + val fileName = metaRows.head.getString(1) + assert(filePath.endsWith("gis_osm_pois_free_1.shp")) + assert(fileName == "gis_osm_pois_free_1.shp") + } + + it("should return actual file_size matching the .shp file on disk") { + val shpFile = + new File(resourceFolder + "shapefiles/gis_osm_pois_free_1/gis_osm_pois_free_1.shp") + val expectedSize = shpFile.length() + val df = sparkSession.read + .format("shapefile") + .load(resourceFolder + "shapefiles/gis_osm_pois_free_1") + val sizes = df.select("_metadata.file_size").distinct().collect() + assert(sizes.length == 1) + assert(sizes.head.getLong(0) == expectedSize) + } + + it( + "should return file_block_start=0 and file_block_length=file_size for non-splittable shapefiles") { + val shpFile = + new File(resourceFolder + "shapefiles/gis_osm_pois_free_1/gis_osm_pois_free_1.shp") + val 
expectedSize = shpFile.length() + val df = sparkSession.read + .format("shapefile") + .load(resourceFolder + "shapefiles/gis_osm_pois_free_1") + val rows = df + .select("_metadata.file_block_start", "_metadata.file_block_length") + .distinct() + .collect() + assert(rows.length == 1) + assert(rows.head.getLong(0) == 0L) // file_block_start + assert(rows.head.getLong(1) == expectedSize) // file_block_length + } + + it("should return file_modification_time matching the .shp file on disk") { + val shpFile = + new File(resourceFolder + "shapefiles/gis_osm_pois_free_1/gis_osm_pois_free_1.shp") + // File.lastModified() returns milliseconds, Spark TimestampType stores microseconds + val expectedModTimeMs = shpFile.lastModified() + val df = sparkSession.read + .format("shapefile") + .load(resourceFolder + "shapefiles/gis_osm_pois_free_1") + val times = + df.select("_metadata.file_modification_time").distinct().collect() + assert(times.length == 1) + val modTime = times.head.getTimestamp(0) + assert(modTime != null) + // Timestamp.getTime() returns milliseconds + assert(modTime.getTime == expectedModTimeMs) + } + + it("should return correct metadata values per file when reading multiple shapefiles") { + val map1Shp = + new File(resourceFolder + "shapefiles/multipleshapefiles/map1.shp") + val map2Shp = + new File(resourceFolder + "shapefiles/multipleshapefiles/map2.shp") + val df = sparkSession.read + .format("shapefile") + .load(resourceFolder + "shapefiles/multipleshapefiles") + val metaRows = df + .select( + "_metadata.file_name", + "_metadata.file_size", + "_metadata.file_block_start", + "_metadata.file_block_length") + .distinct() + .collect() + assert(metaRows.length == 2) + val byName = metaRows.map(r => r.getString(0) -> r).toMap + // map1.shp + assert(byName("map1.shp").getLong(1) == map1Shp.length()) + assert(byName("map1.shp").getLong(2) == 0L) + assert(byName("map1.shp").getLong(3) == map1Shp.length()) + // map2.shp + assert(byName("map2.shp").getLong(1) == map2Shp.length()) + assert(byName("map2.shp").getLong(2) == 0L) + assert(byName("map2.shp").getLong(3) == map2Shp.length()) + } + + it("should allow filtering on _metadata fields") { + val df = sparkSession.read + .format("shapefile") + .load(resourceFolder + "shapefiles/multipleshapefiles") + val totalCount = df.count() + val map1Df = df.filter(df("_metadata.file_name") === "map1.shp") + val map2Df = df.filter(df("_metadata.file_name") === "map2.shp") + assert(map1Df.count() > 0) + assert(map2Df.count() > 0) + assert(map1Df.count() + map2Df.count() == totalCount) + } + + it("should select _metadata along with data columns") { + val df = sparkSession.read + .format("shapefile") + .load(resourceFolder + "shapefiles/gis_osm_pois_free_1") + val result = df.select("osm_id", "_metadata.file_name").collect() + assert(result.length == 12873) + result.foreach { row => + assert(row.getString(0).nonEmpty) + assert(row.getString(1) == "gis_osm_pois_free_1.shp") + } + } + + it("should return correct metadata for each file in multi-shapefile directory") { + val dt1Shp = new File(resourceFolder + "shapefiles/datatypes/datatypes1.shp") + val dt2Shp = new File(resourceFolder + "shapefiles/datatypes/datatypes2.shp") + val df = sparkSession.read + .format("shapefile") + .load(resourceFolder + "shapefiles/datatypes") + val result = df + .select( + "_metadata.file_path", + "_metadata.file_name", + "_metadata.file_size", + "_metadata.file_block_start", + "_metadata.file_block_length", + "_metadata.file_modification_time") + .distinct() + .collect() + 
assert(result.length == 2) + val byName = result.map(r => r.getString(1) -> r).toMap + // datatypes1.shp + val r1 = byName("datatypes1.shp") + assert(r1.getString(0).endsWith("datatypes1.shp")) + assert(r1.getLong(2) == dt1Shp.length()) + assert(r1.getLong(3) == 0L) + assert(r1.getLong(4) == dt1Shp.length()) + assert(r1.getTimestamp(5).getTime == dt1Shp.lastModified()) + // datatypes2.shp + val r2 = byName("datatypes2.shp") + assert(r2.getString(0).endsWith("datatypes2.shp")) + assert(r2.getLong(2) == dt2Shp.length()) + assert(r2.getLong(3) == 0L) + assert(r2.getLong(4) == dt2Shp.length()) + assert(r2.getTimestamp(5).getTime == dt2Shp.lastModified()) + } } } diff --git a/spark/spark-3.5/src/main/scala/org/apache/sedona/sql/datasources/shapefile/ShapefilePartitionReaderFactory.scala b/spark/spark-3.5/src/main/scala/org/apache/sedona/sql/datasources/shapefile/ShapefilePartitionReaderFactory.scala index 5a28af6d66..79c0638bd6 100644 --- a/spark/spark-3.5/src/main/scala/org/apache/sedona/sql/datasources/shapefile/ShapefilePartitionReaderFactory.scala +++ b/spark/spark-3.5/src/main/scala/org/apache/sedona/sql/datasources/shapefile/ShapefilePartitionReaderFactory.scala @@ -18,8 +18,11 @@ */ package org.apache.sedona.sql.datasources.shapefile +import org.apache.hadoop.fs.Path import org.apache.spark.broadcast.Broadcast import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.{BoundReference, JoinedRow} +import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection import org.apache.spark.sql.connector.read.InputPartition import org.apache.spark.sql.connector.read.PartitionReader import org.apache.spark.sql.connector.read.PartitionReaderFactory @@ -28,14 +31,19 @@ import org.apache.spark.sql.execution.datasources.v2.PartitionReaderWithPartitio import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.sources.Filter import org.apache.spark.sql.types.StructType +import org.apache.spark.unsafe.types.UTF8String import org.apache.spark.util.SerializableConfiguration +import java.util.Locale + case class ShapefilePartitionReaderFactory( sqlConf: SQLConf, broadcastedConf: Broadcast[SerializableConfiguration], dataSchema: StructType, readDataSchema: StructType, partitionSchema: StructType, + /** The metadata fields requested by the query (e.g., fields from `_metadata`). */ + metadataSchema: StructType, options: ShapefileReadOptions, filters: Seq[Filter]) extends PartitionReaderFactory { @@ -48,11 +56,51 @@ case class ShapefilePartitionReaderFactory( partitionedFiles, readDataSchema, options) - new PartitionReaderWithPartitionValues( + val withPartitionValues = new PartitionReaderWithPartitionValues( fileReader, readDataSchema, partitionSchema, partitionedFiles.head.partitionValues) + + if (metadataSchema.nonEmpty) { + // Build metadata values from the .shp file's partition information. + // We use the .shp file because it is the primary shapefile component and its path + // is what users would expect to see in _metadata.file_path / _metadata.file_name. + val shpFile = partitionedFiles + .find(_.filePath.toPath.getName.toLowerCase(Locale.ROOT).endsWith(".shp")) + .getOrElse(partitionedFiles.head) + val filePath = shpFile.filePath.toString + val fileName = new Path(filePath).getName + + // Complete map of all metadata field values keyed by field name. + // The modificationTime from PartitionedFile is in milliseconds but Spark's + // TimestampType uses microseconds, so we multiply by 1000. 
+ val allMetadataValues: Map[String, Any] = Map( + "file_path" -> UTF8String.fromString(filePath), + "file_name" -> UTF8String.fromString(fileName), + "file_size" -> shpFile.fileSize, + "file_block_start" -> shpFile.start, + "file_block_length" -> shpFile.length, + "file_modification_time" -> (shpFile.modificationTime * 1000L)) + + // The metadataSchema may be pruned by Spark's column pruning (e.g., when the query + // only selects `_metadata.file_name`). We must construct the inner struct to match + // the pruned schema exactly, otherwise field ordinals will be misaligned. + val innerStructType = metadataSchema.fields.head.dataType.asInstanceOf[StructType] + val prunedValues = innerStructType.fields.map(f => allMetadataValues(f.name)) + val metadataStruct = InternalRow.fromSeq(prunedValues.toSeq) + + // Wrap the struct in an outer row since _metadata is a single StructType column + val metadataRow = InternalRow.fromSeq(Seq(metadataStruct)) + val baseSchema = StructType(readDataSchema.fields ++ partitionSchema.fields) + new PartitionReaderWithMetadata( + withPartitionValues, + baseSchema, + metadataSchema, + metadataRow) + } else { + withPartitionValues + } } override def createReader(partition: InputPartition): PartitionReader[InternalRow] = { @@ -64,3 +112,43 @@ case class ShapefilePartitionReaderFactory( } } } + +/** + * Wraps a partition reader to append metadata column values to each row. This follows the same + * pattern as [[PartitionReaderWithPartitionValues]] but for metadata columns: it uses a + * [[JoinedRow]] to concatenate the base row (data + partition values) with the metadata row, then + * projects the combined row through an + * [[org.apache.spark.sql.catalyst.expressions.UnsafeProjection]] to produce a compact unsafe row. + * + * @param reader + * the underlying reader that produces data + partition value rows + * @param baseSchema + * the combined schema of data columns and partition columns + * @param metadataSchema + * the schema of the metadata columns being appended + * @param metadataValues + * the constant metadata values to append to every row + */ +private[shapefile] class PartitionReaderWithMetadata( + reader: PartitionReader[InternalRow], + baseSchema: StructType, + metadataSchema: StructType, + metadataValues: InternalRow) + extends PartitionReader[InternalRow] { + + private val joinedRow = new JoinedRow() + private val unsafeProjection = + GenerateUnsafeProjection.generate(baseSchema.fields.zipWithIndex.map { case (f, i) => + BoundReference(i, f.dataType, f.nullable) + } ++ metadataSchema.fields.zipWithIndex.map { case (f, i) => + BoundReference(baseSchema.length + i, f.dataType, f.nullable) + }) + + override def next(): Boolean = reader.next() + + override def get(): InternalRow = { + unsafeProjection(joinedRow(reader.get(), metadataValues)) + } + + override def close(): Unit = reader.close() +} diff --git a/spark/spark-3.5/src/main/scala/org/apache/sedona/sql/datasources/shapefile/ShapefileScan.scala b/spark/spark-3.5/src/main/scala/org/apache/sedona/sql/datasources/shapefile/ShapefileScan.scala index afdface894..3f6a9224aa 100644 --- a/spark/spark-3.5/src/main/scala/org/apache/sedona/sql/datasources/shapefile/ShapefileScan.scala +++ b/spark/spark-3.5/src/main/scala/org/apache/sedona/sql/datasources/shapefile/ShapefileScan.scala @@ -44,12 +44,22 @@ case class ShapefileScan( dataSchema: StructType, readDataSchema: StructType, readPartitionSchema: StructType, + /** The metadata fields requested by the query (e.g., fields from `_metadata`). 
*/ + metadataSchema: StructType, options: CaseInsensitiveStringMap, pushedFilters: Array[Filter], partitionFilters: Seq[Expression] = Seq.empty, dataFilters: Seq[Expression] = Seq.empty) extends FileScan { + /** + * Returns the complete read schema including data columns, partition columns, and any requested + * metadata columns. Metadata columns are appended last so the reader factory can construct a + * [[JoinedRow]] that appends metadata values after data and partition values. + */ + override def readSchema(): StructType = + StructType(readDataSchema.fields ++ readPartitionSchema.fields ++ metadataSchema.fields) + override def createReaderFactory(): PartitionReaderFactory = { val caseSensitiveMap = options.asScala.toMap val hadoopConf = sparkSession.sessionState.newHadoopConfWithOptions(caseSensitiveMap) @@ -61,6 +71,7 @@ case class ShapefileScan( dataSchema, readDataSchema, readPartitionSchema, + metadataSchema, ShapefileReadOptions.parse(options), pushedFilters) } diff --git a/spark/spark-3.5/src/main/scala/org/apache/sedona/sql/datasources/shapefile/ShapefileScanBuilder.scala b/spark/spark-3.5/src/main/scala/org/apache/sedona/sql/datasources/shapefile/ShapefileScanBuilder.scala index e5135e381d..48b5e45d53 100644 --- a/spark/spark-3.5/src/main/scala/org/apache/sedona/sql/datasources/shapefile/ShapefileScanBuilder.scala +++ b/spark/spark-3.5/src/main/scala/org/apache/sedona/sql/datasources/shapefile/ShapefileScanBuilder.scala @@ -33,6 +33,28 @@ case class ShapefileScanBuilder( options: CaseInsensitiveStringMap) extends FileScanBuilder(sparkSession, fileIndex, dataSchema) { + /** + * Tracks any metadata fields (e.g., from `_metadata`) requested in the query. Populated by + * [[pruneColumns]] when Spark pushes down column projections. + */ + private var _requiredMetadataSchema: StructType = StructType(Seq.empty) + + /** + * Intercepts Spark's column pruning to separate metadata columns from data/partition columns. + * Fields in [[requiredSchema]] that do not belong to the data schema or partition schema are + * assumed to be metadata fields (e.g., `_metadata`). These are captured in + * [[_requiredMetadataSchema]] so the scan can include them in the output. 
+ */ + override def pruneColumns(requiredSchema: StructType): Unit = { + val resolver = sparkSession.sessionState.conf.resolver + val metaFields = requiredSchema.fields.filter { field => + !dataSchema.fields.exists(df => resolver(df.name, field.name)) && + !fileIndex.partitionSchema.fields.exists(pf => resolver(pf.name, field.name)) + } + _requiredMetadataSchema = StructType(metaFields) + super.pruneColumns(requiredSchema) + } + override def build(): Scan = { ShapefileScan( sparkSession, @@ -40,6 +62,7 @@ case class ShapefileScanBuilder( dataSchema, readDataSchema(), readPartitionSchema(), + _requiredMetadataSchema, options, pushedDataFilters, partitionFilters, diff --git a/spark/spark-3.5/src/main/scala/org/apache/sedona/sql/datasources/shapefile/ShapefileTable.scala b/spark/spark-3.5/src/main/scala/org/apache/sedona/sql/datasources/shapefile/ShapefileTable.scala index 7db6bb8d1f..1f4bd093b9 100644 --- a/spark/spark-3.5/src/main/scala/org/apache/sedona/sql/datasources/shapefile/ShapefileTable.scala +++ b/spark/spark-3.5/src/main/scala/org/apache/sedona/sql/datasources/shapefile/ShapefileTable.scala @@ -21,19 +21,27 @@ package org.apache.sedona.sql.datasources.shapefile import org.apache.hadoop.fs.FileStatus import org.apache.sedona.core.formatMapper.shapefileParser.parseUtils.dbf.DbfParseUtil import org.apache.spark.sql.SparkSession -import org.apache.spark.sql.connector.catalog.TableCapability +import org.apache.spark.sql.connector.catalog.{MetadataColumn, SupportsMetadataColumns, TableCapability} import org.apache.spark.sql.connector.read.ScanBuilder import org.apache.spark.sql.connector.write.{LogicalWriteInfo, WriteBuilder} import org.apache.spark.sql.execution.datasources.FileFormat import org.apache.sedona.sql.datasources.shapefile.ShapefileUtils.{baseSchema, fieldDescriptorsToSchema, mergeSchemas} import org.apache.spark.sql.execution.datasources.v2.FileTable -import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.types.{DataType, LongType, StringType, StructField, StructType, TimestampType} import org.apache.spark.sql.util.CaseInsensitiveStringMap import org.apache.spark.util.SerializableConfiguration import java.util.Locale import scala.collection.JavaConverters._ +/** + * A Spark DataSource V2 table implementation for reading Shapefiles. + * + * Extends [[FileTable]] to leverage Spark's file-based scan infrastructure and implements + * [[SupportsMetadataColumns]] to expose hidden metadata columns (e.g., `_metadata`) that provide + * file-level information such as path, name, size, and modification time. These metadata columns + * are not part of the user-visible schema but can be explicitly selected in queries. + */ case class ShapefileTable( name: String, sparkSession: SparkSession, @@ -41,7 +49,8 @@ case class ShapefileTable( paths: Seq[String], userSpecifiedSchema: Option[StructType], fallbackFileFormat: Class[_ <: FileFormat]) - extends FileTable(sparkSession, options, paths, userSpecifiedSchema) { + extends FileTable(sparkSession, options, paths, userSpecifiedSchema) + with SupportsMetadataColumns { override def formatName: String = "Shapefile" @@ -95,9 +104,49 @@ case class ShapefileTable( } } + /** Returns the metadata columns that this table exposes as hidden columns. 
*/ + override def metadataColumns(): Array[MetadataColumn] = ShapefileTable.fileMetadataColumns + override def newScanBuilder(options: CaseInsensitiveStringMap): ScanBuilder = { ShapefileScanBuilder(sparkSession, fileIndex, schema, dataSchema, options) } override def newWriteBuilder(info: LogicalWriteInfo): WriteBuilder = null } + +object ShapefileTable { + + /** + * Schema of the `_metadata` struct column exposed by [[SupportsMetadataColumns]]. Each field + * provides file-level information about the source shapefile: + * + * - `file_path`: The fully qualified path of the `.shp` file (e.g., + * `hdfs://host/data/file.shp`). + * - `file_name`: The name of the `.shp` file without directory components (e.g., `file.shp`). + * - `file_size`: The total size of the `.shp` file in bytes. + * - `file_block_start`: The byte offset within the file where this partition's data begins. + * For non-splittable formats this is typically 0. + * - `file_block_length`: The number of bytes in this partition's data block. For + * non-splittable formats this equals the file size. + * - `file_modification_time`: The last modification timestamp of the `.shp` file. + */ + private val FILE_METADATA_STRUCT_TYPE: StructType = StructType( + Seq( + StructField("file_path", StringType, nullable = false), + StructField("file_name", StringType, nullable = false), + StructField("file_size", LongType, nullable = false), + StructField("file_block_start", LongType, nullable = false), + StructField("file_block_length", LongType, nullable = false), + StructField("file_modification_time", TimestampType, nullable = false))) + + /** + * The single metadata column `_metadata` exposed to Spark's catalog. This hidden column can be + * selected in queries (e.g., `SELECT _metadata.file_name FROM shapefile.`...``) but does not + * appear in `SELECT *`. 
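+   *
+   * A minimal usage sketch, assuming `spark` is an active `SparkSession`; the load path is illustrative:
+   * {{{
+   *   spark.read.format("shapefile").load("/path/to/shapefiles")
+   *     .select("_metadata.file_path", "_metadata.file_name", "_metadata.file_size")
+   * }}}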
+ */ + private[shapefile] val fileMetadataColumns: Array[MetadataColumn] = Array(new MetadataColumn { + override def name: String = "_metadata" + override def dataType: DataType = FILE_METADATA_STRUCT_TYPE + override def isNullable: Boolean = false + }) +} diff --git a/spark/spark-3.5/src/test/scala/org/apache/sedona/sql/ShapefileTests.scala b/spark/spark-3.5/src/test/scala/org/apache/sedona/sql/ShapefileTests.scala index 275bc3282f..4b4f218b36 100644 --- a/spark/spark-3.5/src/test/scala/org/apache/sedona/sql/ShapefileTests.scala +++ b/spark/spark-3.5/src/test/scala/org/apache/sedona/sql/ShapefileTests.scala @@ -20,7 +20,7 @@ package org.apache.sedona.sql import org.apache.commons.io.FileUtils import org.apache.spark.sql.sedona_sql.UDT.GeometryUDT -import org.apache.spark.sql.types.{DateType, DecimalType, LongType, StringType, StructField, StructType} +import org.apache.spark.sql.types.{DateType, DecimalType, LongType, StringType, StructField, StructType, TimestampType} import org.locationtech.jts.geom.{Geometry, MultiPolygon, Point, Polygon} import org.locationtech.jts.io.{WKTReader, WKTWriter} import org.scalatest.BeforeAndAfterAll @@ -502,7 +502,7 @@ class ShapefileTests extends TestBaseScala with BeforeAndAfterAll { .format("shapefile") .load(temporaryLocation) .select("part", "id", "aInt", "aUnicode", "geometry") - var rows = shapefileDf.collect() + val rows = shapefileDf.collect() assert(rows.length == 9) rows.foreach { row => assert(row.getAs[Geometry]("geometry").isInstanceOf[Point]) @@ -519,9 +519,9 @@ class ShapefileTests extends TestBaseScala with BeforeAndAfterAll { } // Using partition filters - rows = shapefileDf.where("part = 2").collect() - assert(rows.length == 4) - rows.foreach { row => + val filteredRows = shapefileDf.where("part = 2").collect() + assert(filteredRows.length == 4) + filteredRows.foreach { row => assert(row.getAs[Geometry]("geometry").isInstanceOf[Point]) assert(row.getAs[Int]("part") == 2) val id = row.getAs[Long]("id") @@ -780,5 +780,180 @@ class ShapefileTests extends TestBaseScala with BeforeAndAfterAll { } } } + + it("should expose _metadata struct with all expected fields") { + val df = sparkSession.read + .format("shapefile") + .load(resourceFolder + "shapefiles/gis_osm_pois_free_1") + val metaDf = df.select("_metadata") + val metaSchema = metaDf.schema("_metadata").dataType.asInstanceOf[StructType] + val expectedFields = + Seq( + "file_path", + "file_name", + "file_size", + "file_block_start", + "file_block_length", + "file_modification_time") + assert(metaSchema.fieldNames.toSeq == expectedFields) + assert(metaSchema("file_path").dataType == StringType) + assert(metaSchema("file_name").dataType == StringType) + assert(metaSchema("file_size").dataType == LongType) + assert(metaSchema("file_block_start").dataType == LongType) + assert(metaSchema("file_block_length").dataType == LongType) + assert(metaSchema("file_modification_time").dataType == TimestampType) + } + + it("should not include _metadata in select(*)") { + val df = sparkSession.read + .format("shapefile") + .load(resourceFolder + "shapefiles/gis_osm_pois_free_1") + val columns = df.columns + assert(!columns.contains("_metadata")) + } + + it("should return correct file_path and file_name in _metadata") { + val df = sparkSession.read + .format("shapefile") + .load(resourceFolder + "shapefiles/gis_osm_pois_free_1") + val metaRows = df.select("_metadata.file_path", "_metadata.file_name").distinct().collect() + assert(metaRows.length == 1) + val filePath = metaRows.head.getString(0) + val 
fileName = metaRows.head.getString(1) + assert(filePath.endsWith("gis_osm_pois_free_1.shp")) + assert(fileName == "gis_osm_pois_free_1.shp") + } + + it("should return actual file_size matching the .shp file on disk") { + val shpFile = + new File(resourceFolder + "shapefiles/gis_osm_pois_free_1/gis_osm_pois_free_1.shp") + val expectedSize = shpFile.length() + val df = sparkSession.read + .format("shapefile") + .load(resourceFolder + "shapefiles/gis_osm_pois_free_1") + val sizes = df.select("_metadata.file_size").distinct().collect() + assert(sizes.length == 1) + assert(sizes.head.getLong(0) == expectedSize) + } + + it( + "should return file_block_start=0 and file_block_length=file_size for non-splittable shapefiles") { + val shpFile = + new File(resourceFolder + "shapefiles/gis_osm_pois_free_1/gis_osm_pois_free_1.shp") + val expectedSize = shpFile.length() + val df = sparkSession.read + .format("shapefile") + .load(resourceFolder + "shapefiles/gis_osm_pois_free_1") + val rows = df + .select("_metadata.file_block_start", "_metadata.file_block_length") + .distinct() + .collect() + assert(rows.length == 1) + assert(rows.head.getLong(0) == 0L) // file_block_start + assert(rows.head.getLong(1) == expectedSize) // file_block_length + } + + it("should return file_modification_time matching the .shp file on disk") { + val shpFile = + new File(resourceFolder + "shapefiles/gis_osm_pois_free_1/gis_osm_pois_free_1.shp") + // File.lastModified() returns milliseconds, Spark TimestampType stores microseconds + val expectedModTimeMs = shpFile.lastModified() + val df = sparkSession.read + .format("shapefile") + .load(resourceFolder + "shapefiles/gis_osm_pois_free_1") + val times = + df.select("_metadata.file_modification_time").distinct().collect() + assert(times.length == 1) + val modTime = times.head.getTimestamp(0) + assert(modTime != null) + // Timestamp.getTime() returns milliseconds + assert(modTime.getTime == expectedModTimeMs) + } + + it("should return correct metadata values per file when reading multiple shapefiles") { + val map1Shp = + new File(resourceFolder + "shapefiles/multipleshapefiles/map1.shp") + val map2Shp = + new File(resourceFolder + "shapefiles/multipleshapefiles/map2.shp") + val df = sparkSession.read + .format("shapefile") + .load(resourceFolder + "shapefiles/multipleshapefiles") + val metaRows = df + .select( + "_metadata.file_name", + "_metadata.file_size", + "_metadata.file_block_start", + "_metadata.file_block_length") + .distinct() + .collect() + assert(metaRows.length == 2) + val byName = metaRows.map(r => r.getString(0) -> r).toMap + // map1.shp + assert(byName("map1.shp").getLong(1) == map1Shp.length()) + assert(byName("map1.shp").getLong(2) == 0L) + assert(byName("map1.shp").getLong(3) == map1Shp.length()) + // map2.shp + assert(byName("map2.shp").getLong(1) == map2Shp.length()) + assert(byName("map2.shp").getLong(2) == 0L) + assert(byName("map2.shp").getLong(3) == map2Shp.length()) + } + + it("should allow filtering on _metadata fields") { + val df = sparkSession.read + .format("shapefile") + .load(resourceFolder + "shapefiles/multipleshapefiles") + val totalCount = df.count() + val map1Df = df.filter(df("_metadata.file_name") === "map1.shp") + val map2Df = df.filter(df("_metadata.file_name") === "map2.shp") + assert(map1Df.count() > 0) + assert(map2Df.count() > 0) + assert(map1Df.count() + map2Df.count() == totalCount) + } + + it("should select _metadata along with data columns") { + val df = sparkSession.read + .format("shapefile") + .load(resourceFolder + 
"shapefiles/gis_osm_pois_free_1") + val result = df.select("osm_id", "_metadata.file_name").collect() + assert(result.length == 12873) + result.foreach { row => + assert(row.getString(0).nonEmpty) + assert(row.getString(1) == "gis_osm_pois_free_1.shp") + } + } + + it("should return correct metadata for each file in multi-shapefile directory") { + val dt1Shp = new File(resourceFolder + "shapefiles/datatypes/datatypes1.shp") + val dt2Shp = new File(resourceFolder + "shapefiles/datatypes/datatypes2.shp") + val df = sparkSession.read + .format("shapefile") + .load(resourceFolder + "shapefiles/datatypes") + val result = df + .select( + "_metadata.file_path", + "_metadata.file_name", + "_metadata.file_size", + "_metadata.file_block_start", + "_metadata.file_block_length", + "_metadata.file_modification_time") + .distinct() + .collect() + assert(result.length == 2) + val byName = result.map(r => r.getString(1) -> r).toMap + // datatypes1.shp + val r1 = byName("datatypes1.shp") + assert(r1.getString(0).endsWith("datatypes1.shp")) + assert(r1.getLong(2) == dt1Shp.length()) + assert(r1.getLong(3) == 0L) + assert(r1.getLong(4) == dt1Shp.length()) + assert(r1.getTimestamp(5).getTime == dt1Shp.lastModified()) + // datatypes2.shp + val r2 = byName("datatypes2.shp") + assert(r2.getString(0).endsWith("datatypes2.shp")) + assert(r2.getLong(2) == dt2Shp.length()) + assert(r2.getLong(3) == 0L) + assert(r2.getLong(4) == dt2Shp.length()) + assert(r2.getTimestamp(5).getTime == dt2Shp.lastModified()) + } } } diff --git a/spark/spark-4.0/src/main/scala/org/apache/sedona/sql/datasources/shapefile/ShapefilePartitionReaderFactory.scala b/spark/spark-4.0/src/main/scala/org/apache/sedona/sql/datasources/shapefile/ShapefilePartitionReaderFactory.scala index 5a28af6d66..79c0638bd6 100644 --- a/spark/spark-4.0/src/main/scala/org/apache/sedona/sql/datasources/shapefile/ShapefilePartitionReaderFactory.scala +++ b/spark/spark-4.0/src/main/scala/org/apache/sedona/sql/datasources/shapefile/ShapefilePartitionReaderFactory.scala @@ -18,8 +18,11 @@ */ package org.apache.sedona.sql.datasources.shapefile +import org.apache.hadoop.fs.Path import org.apache.spark.broadcast.Broadcast import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.{BoundReference, JoinedRow} +import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection import org.apache.spark.sql.connector.read.InputPartition import org.apache.spark.sql.connector.read.PartitionReader import org.apache.spark.sql.connector.read.PartitionReaderFactory @@ -28,14 +31,19 @@ import org.apache.spark.sql.execution.datasources.v2.PartitionReaderWithPartitio import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.sources.Filter import org.apache.spark.sql.types.StructType +import org.apache.spark.unsafe.types.UTF8String import org.apache.spark.util.SerializableConfiguration +import java.util.Locale + case class ShapefilePartitionReaderFactory( sqlConf: SQLConf, broadcastedConf: Broadcast[SerializableConfiguration], dataSchema: StructType, readDataSchema: StructType, partitionSchema: StructType, + /** The metadata fields requested by the query (e.g., fields from `_metadata`). 
*/ + metadataSchema: StructType, options: ShapefileReadOptions, filters: Seq[Filter]) extends PartitionReaderFactory { @@ -48,11 +56,51 @@ case class ShapefilePartitionReaderFactory( partitionedFiles, readDataSchema, options) - new PartitionReaderWithPartitionValues( + val withPartitionValues = new PartitionReaderWithPartitionValues( fileReader, readDataSchema, partitionSchema, partitionedFiles.head.partitionValues) + + if (metadataSchema.nonEmpty) { + // Build metadata values from the .shp file's partition information. + // We use the .shp file because it is the primary shapefile component and its path + // is what users would expect to see in _metadata.file_path / _metadata.file_name. + val shpFile = partitionedFiles + .find(_.filePath.toPath.getName.toLowerCase(Locale.ROOT).endsWith(".shp")) + .getOrElse(partitionedFiles.head) + val filePath = shpFile.filePath.toString + val fileName = new Path(filePath).getName + + // Complete map of all metadata field values keyed by field name. + // The modificationTime from PartitionedFile is in milliseconds but Spark's + // TimestampType uses microseconds, so we multiply by 1000. + val allMetadataValues: Map[String, Any] = Map( + "file_path" -> UTF8String.fromString(filePath), + "file_name" -> UTF8String.fromString(fileName), + "file_size" -> shpFile.fileSize, + "file_block_start" -> shpFile.start, + "file_block_length" -> shpFile.length, + "file_modification_time" -> (shpFile.modificationTime * 1000L)) + + // The metadataSchema may be pruned by Spark's column pruning (e.g., when the query + // only selects `_metadata.file_name`). We must construct the inner struct to match + // the pruned schema exactly, otherwise field ordinals will be misaligned. + val innerStructType = metadataSchema.fields.head.dataType.asInstanceOf[StructType] + val prunedValues = innerStructType.fields.map(f => allMetadataValues(f.name)) + val metadataStruct = InternalRow.fromSeq(prunedValues.toSeq) + + // Wrap the struct in an outer row since _metadata is a single StructType column + val metadataRow = InternalRow.fromSeq(Seq(metadataStruct)) + val baseSchema = StructType(readDataSchema.fields ++ partitionSchema.fields) + new PartitionReaderWithMetadata( + withPartitionValues, + baseSchema, + metadataSchema, + metadataRow) + } else { + withPartitionValues + } } override def createReader(partition: InputPartition): PartitionReader[InternalRow] = { @@ -64,3 +112,43 @@ case class ShapefilePartitionReaderFactory( } } } + +/** + * Wraps a partition reader to append metadata column values to each row. This follows the same + * pattern as [[PartitionReaderWithPartitionValues]] but for metadata columns: it uses a + * [[JoinedRow]] to concatenate the base row (data + partition values) with the metadata row, then + * projects the combined row through an + * [[org.apache.spark.sql.catalyst.expressions.UnsafeProjection]] to produce a compact unsafe row. 
+ * + * @param reader + * the underlying reader that produces data + partition value rows + * @param baseSchema + * the combined schema of data columns and partition columns + * @param metadataSchema + * the schema of the metadata columns being appended + * @param metadataValues + * the constant metadata values to append to every row + */ +private[shapefile] class PartitionReaderWithMetadata( + reader: PartitionReader[InternalRow], + baseSchema: StructType, + metadataSchema: StructType, + metadataValues: InternalRow) + extends PartitionReader[InternalRow] { + + private val joinedRow = new JoinedRow() + private val unsafeProjection = + GenerateUnsafeProjection.generate(baseSchema.fields.zipWithIndex.map { case (f, i) => + BoundReference(i, f.dataType, f.nullable) + } ++ metadataSchema.fields.zipWithIndex.map { case (f, i) => + BoundReference(baseSchema.length + i, f.dataType, f.nullable) + }) + + override def next(): Boolean = reader.next() + + override def get(): InternalRow = { + unsafeProjection(joinedRow(reader.get(), metadataValues)) + } + + override def close(): Unit = reader.close() +} diff --git a/spark/spark-4.0/src/main/scala/org/apache/sedona/sql/datasources/shapefile/ShapefileScan.scala b/spark/spark-4.0/src/main/scala/org/apache/sedona/sql/datasources/shapefile/ShapefileScan.scala index afdface894..3f6a9224aa 100644 --- a/spark/spark-4.0/src/main/scala/org/apache/sedona/sql/datasources/shapefile/ShapefileScan.scala +++ b/spark/spark-4.0/src/main/scala/org/apache/sedona/sql/datasources/shapefile/ShapefileScan.scala @@ -44,12 +44,22 @@ case class ShapefileScan( dataSchema: StructType, readDataSchema: StructType, readPartitionSchema: StructType, + /** The metadata fields requested by the query (e.g., fields from `_metadata`). */ + metadataSchema: StructType, options: CaseInsensitiveStringMap, pushedFilters: Array[Filter], partitionFilters: Seq[Expression] = Seq.empty, dataFilters: Seq[Expression] = Seq.empty) extends FileScan { + /** + * Returns the complete read schema including data columns, partition columns, and any requested + * metadata columns. Metadata columns are appended last so the reader factory can construct a + * [[JoinedRow]] that appends metadata values after data and partition values. + */ + override def readSchema(): StructType = + StructType(readDataSchema.fields ++ readPartitionSchema.fields ++ metadataSchema.fields) + override def createReaderFactory(): PartitionReaderFactory = { val caseSensitiveMap = options.asScala.toMap val hadoopConf = sparkSession.sessionState.newHadoopConfWithOptions(caseSensitiveMap) @@ -61,6 +71,7 @@ case class ShapefileScan( dataSchema, readDataSchema, readPartitionSchema, + metadataSchema, ShapefileReadOptions.parse(options), pushedFilters) } diff --git a/spark/spark-4.0/src/main/scala/org/apache/sedona/sql/datasources/shapefile/ShapefileScanBuilder.scala b/spark/spark-4.0/src/main/scala/org/apache/sedona/sql/datasources/shapefile/ShapefileScanBuilder.scala index e5135e381d..48b5e45d53 100644 --- a/spark/spark-4.0/src/main/scala/org/apache/sedona/sql/datasources/shapefile/ShapefileScanBuilder.scala +++ b/spark/spark-4.0/src/main/scala/org/apache/sedona/sql/datasources/shapefile/ShapefileScanBuilder.scala @@ -33,6 +33,28 @@ case class ShapefileScanBuilder( options: CaseInsensitiveStringMap) extends FileScanBuilder(sparkSession, fileIndex, dataSchema) { + /** + * Tracks any metadata fields (e.g., from `_metadata`) requested in the query. Populated by + * [[pruneColumns]] when Spark pushes down column projections. 
+ */ + private var _requiredMetadataSchema: StructType = StructType(Seq.empty) + + /** + * Intercepts Spark's column pruning to separate metadata columns from data/partition columns. + * Fields in [[requiredSchema]] that do not belong to the data schema or partition schema are + * assumed to be metadata fields (e.g., `_metadata`). These are captured in + * [[_requiredMetadataSchema]] so the scan can include them in the output. + */ + override def pruneColumns(requiredSchema: StructType): Unit = { + val resolver = sparkSession.sessionState.conf.resolver + val metaFields = requiredSchema.fields.filter { field => + !dataSchema.fields.exists(df => resolver(df.name, field.name)) && + !fileIndex.partitionSchema.fields.exists(pf => resolver(pf.name, field.name)) + } + _requiredMetadataSchema = StructType(metaFields) + super.pruneColumns(requiredSchema) + } + override def build(): Scan = { ShapefileScan( sparkSession, @@ -40,6 +62,7 @@ case class ShapefileScanBuilder( dataSchema, readDataSchema(), readPartitionSchema(), + _requiredMetadataSchema, options, pushedDataFilters, partitionFilters, diff --git a/spark/spark-4.0/src/main/scala/org/apache/sedona/sql/datasources/shapefile/ShapefileTable.scala b/spark/spark-4.0/src/main/scala/org/apache/sedona/sql/datasources/shapefile/ShapefileTable.scala index 7db6bb8d1f..1f4bd093b9 100644 --- a/spark/spark-4.0/src/main/scala/org/apache/sedona/sql/datasources/shapefile/ShapefileTable.scala +++ b/spark/spark-4.0/src/main/scala/org/apache/sedona/sql/datasources/shapefile/ShapefileTable.scala @@ -21,19 +21,27 @@ package org.apache.sedona.sql.datasources.shapefile import org.apache.hadoop.fs.FileStatus import org.apache.sedona.core.formatMapper.shapefileParser.parseUtils.dbf.DbfParseUtil import org.apache.spark.sql.SparkSession -import org.apache.spark.sql.connector.catalog.TableCapability +import org.apache.spark.sql.connector.catalog.{MetadataColumn, SupportsMetadataColumns, TableCapability} import org.apache.spark.sql.connector.read.ScanBuilder import org.apache.spark.sql.connector.write.{LogicalWriteInfo, WriteBuilder} import org.apache.spark.sql.execution.datasources.FileFormat import org.apache.sedona.sql.datasources.shapefile.ShapefileUtils.{baseSchema, fieldDescriptorsToSchema, mergeSchemas} import org.apache.spark.sql.execution.datasources.v2.FileTable -import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.types.{DataType, LongType, StringType, StructField, StructType, TimestampType} import org.apache.spark.sql.util.CaseInsensitiveStringMap import org.apache.spark.util.SerializableConfiguration import java.util.Locale import scala.collection.JavaConverters._ +/** + * A Spark DataSource V2 table implementation for reading Shapefiles. + * + * Extends [[FileTable]] to leverage Spark's file-based scan infrastructure and implements + * [[SupportsMetadataColumns]] to expose hidden metadata columns (e.g., `_metadata`) that provide + * file-level information such as path, name, size, and modification time. These metadata columns + * are not part of the user-visible schema but can be explicitly selected in queries. 
+ */ case class ShapefileTable( name: String, sparkSession: SparkSession, @@ -41,7 +49,8 @@ case class ShapefileTable( paths: Seq[String], userSpecifiedSchema: Option[StructType], fallbackFileFormat: Class[_ <: FileFormat]) - extends FileTable(sparkSession, options, paths, userSpecifiedSchema) { + extends FileTable(sparkSession, options, paths, userSpecifiedSchema) + with SupportsMetadataColumns { override def formatName: String = "Shapefile" @@ -95,9 +104,49 @@ case class ShapefileTable( } } + /** Returns the metadata columns that this table exposes as hidden columns. */ + override def metadataColumns(): Array[MetadataColumn] = ShapefileTable.fileMetadataColumns + override def newScanBuilder(options: CaseInsensitiveStringMap): ScanBuilder = { ShapefileScanBuilder(sparkSession, fileIndex, schema, dataSchema, options) } override def newWriteBuilder(info: LogicalWriteInfo): WriteBuilder = null } + +object ShapefileTable { + + /** + * Schema of the `_metadata` struct column exposed by [[SupportsMetadataColumns]]. Each field + * provides file-level information about the source shapefile: + * + * - `file_path`: The fully qualified path of the `.shp` file (e.g., + * `hdfs://host/data/file.shp`). + * - `file_name`: The name of the `.shp` file without directory components (e.g., `file.shp`). + * - `file_size`: The total size of the `.shp` file in bytes. + * - `file_block_start`: The byte offset within the file where this partition's data begins. + * For non-splittable formats this is typically 0. + * - `file_block_length`: The number of bytes in this partition's data block. For + * non-splittable formats this equals the file size. + * - `file_modification_time`: The last modification timestamp of the `.shp` file. + */ + private val FILE_METADATA_STRUCT_TYPE: StructType = StructType( + Seq( + StructField("file_path", StringType, nullable = false), + StructField("file_name", StringType, nullable = false), + StructField("file_size", LongType, nullable = false), + StructField("file_block_start", LongType, nullable = false), + StructField("file_block_length", LongType, nullable = false), + StructField("file_modification_time", TimestampType, nullable = false))) + + /** + * The single metadata column `_metadata` exposed to Spark's catalog. This hidden column can be + * selected in queries (e.g., `SELECT _metadata.file_name FROM shapefile.`...``) but does not + * appear in `SELECT *`. 
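+   *
+   * A minimal usage sketch, assuming `spark` is an active `SparkSession`; the load path is illustrative:
+   * {{{
+   *   spark.read.format("shapefile").load("/path/to/shapefiles")
+   *     .select("_metadata.file_path", "_metadata.file_name", "_metadata.file_size")
+   * }}}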
+ */ + private[shapefile] val fileMetadataColumns: Array[MetadataColumn] = Array(new MetadataColumn { + override def name: String = "_metadata" + override def dataType: DataType = FILE_METADATA_STRUCT_TYPE + override def isNullable: Boolean = false + }) +} diff --git a/spark/spark-4.0/src/test/scala/org/apache/sedona/sql/ShapefileTests.scala b/spark/spark-4.0/src/test/scala/org/apache/sedona/sql/ShapefileTests.scala index 275bc3282f..4b4f218b36 100644 --- a/spark/spark-4.0/src/test/scala/org/apache/sedona/sql/ShapefileTests.scala +++ b/spark/spark-4.0/src/test/scala/org/apache/sedona/sql/ShapefileTests.scala @@ -20,7 +20,7 @@ package org.apache.sedona.sql import org.apache.commons.io.FileUtils import org.apache.spark.sql.sedona_sql.UDT.GeometryUDT -import org.apache.spark.sql.types.{DateType, DecimalType, LongType, StringType, StructField, StructType} +import org.apache.spark.sql.types.{DateType, DecimalType, LongType, StringType, StructField, StructType, TimestampType} import org.locationtech.jts.geom.{Geometry, MultiPolygon, Point, Polygon} import org.locationtech.jts.io.{WKTReader, WKTWriter} import org.scalatest.BeforeAndAfterAll @@ -502,7 +502,7 @@ class ShapefileTests extends TestBaseScala with BeforeAndAfterAll { .format("shapefile") .load(temporaryLocation) .select("part", "id", "aInt", "aUnicode", "geometry") - var rows = shapefileDf.collect() + val rows = shapefileDf.collect() assert(rows.length == 9) rows.foreach { row => assert(row.getAs[Geometry]("geometry").isInstanceOf[Point]) @@ -519,9 +519,9 @@ class ShapefileTests extends TestBaseScala with BeforeAndAfterAll { } // Using partition filters - rows = shapefileDf.where("part = 2").collect() - assert(rows.length == 4) - rows.foreach { row => + val filteredRows = shapefileDf.where("part = 2").collect() + assert(filteredRows.length == 4) + filteredRows.foreach { row => assert(row.getAs[Geometry]("geometry").isInstanceOf[Point]) assert(row.getAs[Int]("part") == 2) val id = row.getAs[Long]("id") @@ -780,5 +780,180 @@ class ShapefileTests extends TestBaseScala with BeforeAndAfterAll { } } } + + it("should expose _metadata struct with all expected fields") { + val df = sparkSession.read + .format("shapefile") + .load(resourceFolder + "shapefiles/gis_osm_pois_free_1") + val metaDf = df.select("_metadata") + val metaSchema = metaDf.schema("_metadata").dataType.asInstanceOf[StructType] + val expectedFields = + Seq( + "file_path", + "file_name", + "file_size", + "file_block_start", + "file_block_length", + "file_modification_time") + assert(metaSchema.fieldNames.toSeq == expectedFields) + assert(metaSchema("file_path").dataType == StringType) + assert(metaSchema("file_name").dataType == StringType) + assert(metaSchema("file_size").dataType == LongType) + assert(metaSchema("file_block_start").dataType == LongType) + assert(metaSchema("file_block_length").dataType == LongType) + assert(metaSchema("file_modification_time").dataType == TimestampType) + } + + it("should not include _metadata in select(*)") { + val df = sparkSession.read + .format("shapefile") + .load(resourceFolder + "shapefiles/gis_osm_pois_free_1") + val columns = df.columns + assert(!columns.contains("_metadata")) + } + + it("should return correct file_path and file_name in _metadata") { + val df = sparkSession.read + .format("shapefile") + .load(resourceFolder + "shapefiles/gis_osm_pois_free_1") + val metaRows = df.select("_metadata.file_path", "_metadata.file_name").distinct().collect() + assert(metaRows.length == 1) + val filePath = metaRows.head.getString(0) + val 
fileName = metaRows.head.getString(1) + assert(filePath.endsWith("gis_osm_pois_free_1.shp")) + assert(fileName == "gis_osm_pois_free_1.shp") + } + + it("should return actual file_size matching the .shp file on disk") { + val shpFile = + new File(resourceFolder + "shapefiles/gis_osm_pois_free_1/gis_osm_pois_free_1.shp") + val expectedSize = shpFile.length() + val df = sparkSession.read + .format("shapefile") + .load(resourceFolder + "shapefiles/gis_osm_pois_free_1") + val sizes = df.select("_metadata.file_size").distinct().collect() + assert(sizes.length == 1) + assert(sizes.head.getLong(0) == expectedSize) + } + + it( + "should return file_block_start=0 and file_block_length=file_size for non-splittable shapefiles") { + val shpFile = + new File(resourceFolder + "shapefiles/gis_osm_pois_free_1/gis_osm_pois_free_1.shp") + val expectedSize = shpFile.length() + val df = sparkSession.read + .format("shapefile") + .load(resourceFolder + "shapefiles/gis_osm_pois_free_1") + val rows = df + .select("_metadata.file_block_start", "_metadata.file_block_length") + .distinct() + .collect() + assert(rows.length == 1) + assert(rows.head.getLong(0) == 0L) // file_block_start + assert(rows.head.getLong(1) == expectedSize) // file_block_length + } + + it("should return file_modification_time matching the .shp file on disk") { + val shpFile = + new File(resourceFolder + "shapefiles/gis_osm_pois_free_1/gis_osm_pois_free_1.shp") + // File.lastModified() returns milliseconds, Spark TimestampType stores microseconds + val expectedModTimeMs = shpFile.lastModified() + val df = sparkSession.read + .format("shapefile") + .load(resourceFolder + "shapefiles/gis_osm_pois_free_1") + val times = + df.select("_metadata.file_modification_time").distinct().collect() + assert(times.length == 1) + val modTime = times.head.getTimestamp(0) + assert(modTime != null) + // Timestamp.getTime() returns milliseconds + assert(modTime.getTime == expectedModTimeMs) + } + + it("should return correct metadata values per file when reading multiple shapefiles") { + val map1Shp = + new File(resourceFolder + "shapefiles/multipleshapefiles/map1.shp") + val map2Shp = + new File(resourceFolder + "shapefiles/multipleshapefiles/map2.shp") + val df = sparkSession.read + .format("shapefile") + .load(resourceFolder + "shapefiles/multipleshapefiles") + val metaRows = df + .select( + "_metadata.file_name", + "_metadata.file_size", + "_metadata.file_block_start", + "_metadata.file_block_length") + .distinct() + .collect() + assert(metaRows.length == 2) + val byName = metaRows.map(r => r.getString(0) -> r).toMap + // map1.shp + assert(byName("map1.shp").getLong(1) == map1Shp.length()) + assert(byName("map1.shp").getLong(2) == 0L) + assert(byName("map1.shp").getLong(3) == map1Shp.length()) + // map2.shp + assert(byName("map2.shp").getLong(1) == map2Shp.length()) + assert(byName("map2.shp").getLong(2) == 0L) + assert(byName("map2.shp").getLong(3) == map2Shp.length()) + } + + it("should allow filtering on _metadata fields") { + val df = sparkSession.read + .format("shapefile") + .load(resourceFolder + "shapefiles/multipleshapefiles") + val totalCount = df.count() + val map1Df = df.filter(df("_metadata.file_name") === "map1.shp") + val map2Df = df.filter(df("_metadata.file_name") === "map2.shp") + assert(map1Df.count() > 0) + assert(map2Df.count() > 0) + assert(map1Df.count() + map2Df.count() == totalCount) + } + + it("should select _metadata along with data columns") { + val df = sparkSession.read + .format("shapefile") + .load(resourceFolder + 
"shapefiles/gis_osm_pois_free_1") + val result = df.select("osm_id", "_metadata.file_name").collect() + assert(result.length == 12873) + result.foreach { row => + assert(row.getString(0).nonEmpty) + assert(row.getString(1) == "gis_osm_pois_free_1.shp") + } + } + + it("should return correct metadata for each file in multi-shapefile directory") { + val dt1Shp = new File(resourceFolder + "shapefiles/datatypes/datatypes1.shp") + val dt2Shp = new File(resourceFolder + "shapefiles/datatypes/datatypes2.shp") + val df = sparkSession.read + .format("shapefile") + .load(resourceFolder + "shapefiles/datatypes") + val result = df + .select( + "_metadata.file_path", + "_metadata.file_name", + "_metadata.file_size", + "_metadata.file_block_start", + "_metadata.file_block_length", + "_metadata.file_modification_time") + .distinct() + .collect() + assert(result.length == 2) + val byName = result.map(r => r.getString(1) -> r).toMap + // datatypes1.shp + val r1 = byName("datatypes1.shp") + assert(r1.getString(0).endsWith("datatypes1.shp")) + assert(r1.getLong(2) == dt1Shp.length()) + assert(r1.getLong(3) == 0L) + assert(r1.getLong(4) == dt1Shp.length()) + assert(r1.getTimestamp(5).getTime == dt1Shp.lastModified()) + // datatypes2.shp + val r2 = byName("datatypes2.shp") + assert(r2.getString(0).endsWith("datatypes2.shp")) + assert(r2.getLong(2) == dt2Shp.length()) + assert(r2.getLong(3) == 0L) + assert(r2.getLong(4) == dt2Shp.length()) + assert(r2.getTimestamp(5).getTime == dt2Shp.lastModified()) + } } } diff --git a/spark/spark-4.1/src/main/scala/org/apache/sedona/sql/datasources/shapefile/ShapefilePartitionReaderFactory.scala b/spark/spark-4.1/src/main/scala/org/apache/sedona/sql/datasources/shapefile/ShapefilePartitionReaderFactory.scala index 5a28af6d66..79c0638bd6 100644 --- a/spark/spark-4.1/src/main/scala/org/apache/sedona/sql/datasources/shapefile/ShapefilePartitionReaderFactory.scala +++ b/spark/spark-4.1/src/main/scala/org/apache/sedona/sql/datasources/shapefile/ShapefilePartitionReaderFactory.scala @@ -18,8 +18,11 @@ */ package org.apache.sedona.sql.datasources.shapefile +import org.apache.hadoop.fs.Path import org.apache.spark.broadcast.Broadcast import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.{BoundReference, JoinedRow} +import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection import org.apache.spark.sql.connector.read.InputPartition import org.apache.spark.sql.connector.read.PartitionReader import org.apache.spark.sql.connector.read.PartitionReaderFactory @@ -28,14 +31,19 @@ import org.apache.spark.sql.execution.datasources.v2.PartitionReaderWithPartitio import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.sources.Filter import org.apache.spark.sql.types.StructType +import org.apache.spark.unsafe.types.UTF8String import org.apache.spark.util.SerializableConfiguration +import java.util.Locale + case class ShapefilePartitionReaderFactory( sqlConf: SQLConf, broadcastedConf: Broadcast[SerializableConfiguration], dataSchema: StructType, readDataSchema: StructType, partitionSchema: StructType, + /** The metadata fields requested by the query (e.g., fields from `_metadata`). 
*/ + metadataSchema: StructType, options: ShapefileReadOptions, filters: Seq[Filter]) extends PartitionReaderFactory { @@ -48,11 +56,51 @@ case class ShapefilePartitionReaderFactory( partitionedFiles, readDataSchema, options) - new PartitionReaderWithPartitionValues( + val withPartitionValues = new PartitionReaderWithPartitionValues( fileReader, readDataSchema, partitionSchema, partitionedFiles.head.partitionValues) + + if (metadataSchema.nonEmpty) { + // Build metadata values from the .shp file's partition information. + // We use the .shp file because it is the primary shapefile component and its path + // is what users would expect to see in _metadata.file_path / _metadata.file_name. + val shpFile = partitionedFiles + .find(_.filePath.toPath.getName.toLowerCase(Locale.ROOT).endsWith(".shp")) + .getOrElse(partitionedFiles.head) + val filePath = shpFile.filePath.toString + val fileName = new Path(filePath).getName + + // Complete map of all metadata field values keyed by field name. + // The modificationTime from PartitionedFile is in milliseconds but Spark's + // TimestampType uses microseconds, so we multiply by 1000. + val allMetadataValues: Map[String, Any] = Map( + "file_path" -> UTF8String.fromString(filePath), + "file_name" -> UTF8String.fromString(fileName), + "file_size" -> shpFile.fileSize, + "file_block_start" -> shpFile.start, + "file_block_length" -> shpFile.length, + "file_modification_time" -> (shpFile.modificationTime * 1000L)) + + // The metadataSchema may be pruned by Spark's column pruning (e.g., when the query + // only selects `_metadata.file_name`). We must construct the inner struct to match + // the pruned schema exactly, otherwise field ordinals will be misaligned. + val innerStructType = metadataSchema.fields.head.dataType.asInstanceOf[StructType] + val prunedValues = innerStructType.fields.map(f => allMetadataValues(f.name)) + val metadataStruct = InternalRow.fromSeq(prunedValues.toSeq) + + // Wrap the struct in an outer row since _metadata is a single StructType column + val metadataRow = InternalRow.fromSeq(Seq(metadataStruct)) + val baseSchema = StructType(readDataSchema.fields ++ partitionSchema.fields) + new PartitionReaderWithMetadata( + withPartitionValues, + baseSchema, + metadataSchema, + metadataRow) + } else { + withPartitionValues + } } override def createReader(partition: InputPartition): PartitionReader[InternalRow] = { @@ -64,3 +112,43 @@ case class ShapefilePartitionReaderFactory( } } } + +/** + * Wraps a partition reader to append metadata column values to each row. This follows the same + * pattern as [[PartitionReaderWithPartitionValues]] but for metadata columns: it uses a + * [[JoinedRow]] to concatenate the base row (data + partition values) with the metadata row, then + * projects the combined row through an + * [[org.apache.spark.sql.catalyst.expressions.UnsafeProjection]] to produce a compact unsafe row. 
+ * + * @param reader + * the underlying reader that produces data + partition value rows + * @param baseSchema + * the combined schema of data columns and partition columns + * @param metadataSchema + * the schema of the metadata columns being appended + * @param metadataValues + * the constant metadata values to append to every row + */ +private[shapefile] class PartitionReaderWithMetadata( + reader: PartitionReader[InternalRow], + baseSchema: StructType, + metadataSchema: StructType, + metadataValues: InternalRow) + extends PartitionReader[InternalRow] { + + private val joinedRow = new JoinedRow() + private val unsafeProjection = + GenerateUnsafeProjection.generate(baseSchema.fields.zipWithIndex.map { case (f, i) => + BoundReference(i, f.dataType, f.nullable) + } ++ metadataSchema.fields.zipWithIndex.map { case (f, i) => + BoundReference(baseSchema.length + i, f.dataType, f.nullable) + }) + + override def next(): Boolean = reader.next() + + override def get(): InternalRow = { + unsafeProjection(joinedRow(reader.get(), metadataValues)) + } + + override def close(): Unit = reader.close() +} diff --git a/spark/spark-4.1/src/main/scala/org/apache/sedona/sql/datasources/shapefile/ShapefileScan.scala b/spark/spark-4.1/src/main/scala/org/apache/sedona/sql/datasources/shapefile/ShapefileScan.scala index afdface894..3f6a9224aa 100644 --- a/spark/spark-4.1/src/main/scala/org/apache/sedona/sql/datasources/shapefile/ShapefileScan.scala +++ b/spark/spark-4.1/src/main/scala/org/apache/sedona/sql/datasources/shapefile/ShapefileScan.scala @@ -44,12 +44,22 @@ case class ShapefileScan( dataSchema: StructType, readDataSchema: StructType, readPartitionSchema: StructType, + /** The metadata fields requested by the query (e.g., fields from `_metadata`). */ + metadataSchema: StructType, options: CaseInsensitiveStringMap, pushedFilters: Array[Filter], partitionFilters: Seq[Expression] = Seq.empty, dataFilters: Seq[Expression] = Seq.empty) extends FileScan { + /** + * Returns the complete read schema including data columns, partition columns, and any requested + * metadata columns. Metadata columns are appended last so the reader factory can construct a + * [[JoinedRow]] that appends metadata values after data and partition values. + */ + override def readSchema(): StructType = + StructType(readDataSchema.fields ++ readPartitionSchema.fields ++ metadataSchema.fields) + override def createReaderFactory(): PartitionReaderFactory = { val caseSensitiveMap = options.asScala.toMap val hadoopConf = sparkSession.sessionState.newHadoopConfWithOptions(caseSensitiveMap) @@ -61,6 +71,7 @@ case class ShapefileScan( dataSchema, readDataSchema, readPartitionSchema, + metadataSchema, ShapefileReadOptions.parse(options), pushedFilters) } diff --git a/spark/spark-4.1/src/main/scala/org/apache/sedona/sql/datasources/shapefile/ShapefileScanBuilder.scala b/spark/spark-4.1/src/main/scala/org/apache/sedona/sql/datasources/shapefile/ShapefileScanBuilder.scala index e5135e381d..48b5e45d53 100644 --- a/spark/spark-4.1/src/main/scala/org/apache/sedona/sql/datasources/shapefile/ShapefileScanBuilder.scala +++ b/spark/spark-4.1/src/main/scala/org/apache/sedona/sql/datasources/shapefile/ShapefileScanBuilder.scala @@ -33,6 +33,28 @@ case class ShapefileScanBuilder( options: CaseInsensitiveStringMap) extends FileScanBuilder(sparkSession, fileIndex, dataSchema) { + /** + * Tracks any metadata fields (e.g., from `_metadata`) requested in the query. Populated by + * [[pruneColumns]] when Spark pushes down column projections. 
+ */ + private var _requiredMetadataSchema: StructType = StructType(Seq.empty) + + /** + * Intercepts Spark's column pruning to separate metadata columns from data/partition columns. + * Fields in [[requiredSchema]] that do not belong to the data schema or partition schema are + * assumed to be metadata fields (e.g., `_metadata`). These are captured in + * [[_requiredMetadataSchema]] so the scan can include them in the output. + */ + override def pruneColumns(requiredSchema: StructType): Unit = { + val resolver = sparkSession.sessionState.conf.resolver + val metaFields = requiredSchema.fields.filter { field => + !dataSchema.fields.exists(df => resolver(df.name, field.name)) && + !fileIndex.partitionSchema.fields.exists(pf => resolver(pf.name, field.name)) + } + _requiredMetadataSchema = StructType(metaFields) + super.pruneColumns(requiredSchema) + } + override def build(): Scan = { ShapefileScan( sparkSession, @@ -40,6 +62,7 @@ case class ShapefileScanBuilder( dataSchema, readDataSchema(), readPartitionSchema(), + _requiredMetadataSchema, options, pushedDataFilters, partitionFilters, diff --git a/spark/spark-4.1/src/main/scala/org/apache/sedona/sql/datasources/shapefile/ShapefileTable.scala b/spark/spark-4.1/src/main/scala/org/apache/sedona/sql/datasources/shapefile/ShapefileTable.scala index 7db6bb8d1f..1f4bd093b9 100644 --- a/spark/spark-4.1/src/main/scala/org/apache/sedona/sql/datasources/shapefile/ShapefileTable.scala +++ b/spark/spark-4.1/src/main/scala/org/apache/sedona/sql/datasources/shapefile/ShapefileTable.scala @@ -21,19 +21,27 @@ package org.apache.sedona.sql.datasources.shapefile import org.apache.hadoop.fs.FileStatus import org.apache.sedona.core.formatMapper.shapefileParser.parseUtils.dbf.DbfParseUtil import org.apache.spark.sql.SparkSession -import org.apache.spark.sql.connector.catalog.TableCapability +import org.apache.spark.sql.connector.catalog.{MetadataColumn, SupportsMetadataColumns, TableCapability} import org.apache.spark.sql.connector.read.ScanBuilder import org.apache.spark.sql.connector.write.{LogicalWriteInfo, WriteBuilder} import org.apache.spark.sql.execution.datasources.FileFormat import org.apache.sedona.sql.datasources.shapefile.ShapefileUtils.{baseSchema, fieldDescriptorsToSchema, mergeSchemas} import org.apache.spark.sql.execution.datasources.v2.FileTable -import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.types.{DataType, LongType, StringType, StructField, StructType, TimestampType} import org.apache.spark.sql.util.CaseInsensitiveStringMap import org.apache.spark.util.SerializableConfiguration import java.util.Locale import scala.collection.JavaConverters._ +/** + * A Spark DataSource V2 table implementation for reading Shapefiles. + * + * Extends [[FileTable]] to leverage Spark's file-based scan infrastructure and implements + * [[SupportsMetadataColumns]] to expose hidden metadata columns (e.g., `_metadata`) that provide + * file-level information such as path, name, size, and modification time. These metadata columns + * are not part of the user-visible schema but can be explicitly selected in queries. 
+ */ case class ShapefileTable( name: String, sparkSession: SparkSession, @@ -41,7 +49,8 @@ case class ShapefileTable( paths: Seq[String], userSpecifiedSchema: Option[StructType], fallbackFileFormat: Class[_ <: FileFormat]) - extends FileTable(sparkSession, options, paths, userSpecifiedSchema) { + extends FileTable(sparkSession, options, paths, userSpecifiedSchema) + with SupportsMetadataColumns { override def formatName: String = "Shapefile" @@ -95,9 +104,49 @@ case class ShapefileTable( } } + /** Returns the metadata columns that this table exposes as hidden columns. */ + override def metadataColumns(): Array[MetadataColumn] = ShapefileTable.fileMetadataColumns + override def newScanBuilder(options: CaseInsensitiveStringMap): ScanBuilder = { ShapefileScanBuilder(sparkSession, fileIndex, schema, dataSchema, options) } override def newWriteBuilder(info: LogicalWriteInfo): WriteBuilder = null } + +object ShapefileTable { + + /** + * Schema of the `_metadata` struct column exposed by [[SupportsMetadataColumns]]. Each field + * provides file-level information about the source shapefile: + * + * - `file_path`: The fully qualified path of the `.shp` file (e.g., + * `hdfs://host/data/file.shp`). + * - `file_name`: The name of the `.shp` file without directory components (e.g., `file.shp`). + * - `file_size`: The total size of the `.shp` file in bytes. + * - `file_block_start`: The byte offset within the file where this partition's data begins. + * For non-splittable formats this is typically 0. + * - `file_block_length`: The number of bytes in this partition's data block. For + * non-splittable formats this equals the file size. + * - `file_modification_time`: The last modification timestamp of the `.shp` file. + */ + private val FILE_METADATA_STRUCT_TYPE: StructType = StructType( + Seq( + StructField("file_path", StringType, nullable = false), + StructField("file_name", StringType, nullable = false), + StructField("file_size", LongType, nullable = false), + StructField("file_block_start", LongType, nullable = false), + StructField("file_block_length", LongType, nullable = false), + StructField("file_modification_time", TimestampType, nullable = false))) + + /** + * The single metadata column `_metadata` exposed to Spark's catalog. This hidden column can be + * selected in queries (e.g., `SELECT _metadata.file_name FROM shapefile.`...``) but does not + * appear in `SELECT *`. 
+ */ + private[shapefile] val fileMetadataColumns: Array[MetadataColumn] = Array(new MetadataColumn { + override def name: String = "_metadata" + override def dataType: DataType = FILE_METADATA_STRUCT_TYPE + override def isNullable: Boolean = false + }) +} diff --git a/spark/spark-4.1/src/test/scala/org/apache/sedona/sql/ShapefileTests.scala b/spark/spark-4.1/src/test/scala/org/apache/sedona/sql/ShapefileTests.scala index 275bc3282f..4b4f218b36 100644 --- a/spark/spark-4.1/src/test/scala/org/apache/sedona/sql/ShapefileTests.scala +++ b/spark/spark-4.1/src/test/scala/org/apache/sedona/sql/ShapefileTests.scala @@ -20,7 +20,7 @@ package org.apache.sedona.sql import org.apache.commons.io.FileUtils import org.apache.spark.sql.sedona_sql.UDT.GeometryUDT -import org.apache.spark.sql.types.{DateType, DecimalType, LongType, StringType, StructField, StructType} +import org.apache.spark.sql.types.{DateType, DecimalType, LongType, StringType, StructField, StructType, TimestampType} import org.locationtech.jts.geom.{Geometry, MultiPolygon, Point, Polygon} import org.locationtech.jts.io.{WKTReader, WKTWriter} import org.scalatest.BeforeAndAfterAll @@ -502,7 +502,7 @@ class ShapefileTests extends TestBaseScala with BeforeAndAfterAll { .format("shapefile") .load(temporaryLocation) .select("part", "id", "aInt", "aUnicode", "geometry") - var rows = shapefileDf.collect() + val rows = shapefileDf.collect() assert(rows.length == 9) rows.foreach { row => assert(row.getAs[Geometry]("geometry").isInstanceOf[Point]) @@ -519,9 +519,9 @@ class ShapefileTests extends TestBaseScala with BeforeAndAfterAll { } // Using partition filters - rows = shapefileDf.where("part = 2").collect() - assert(rows.length == 4) - rows.foreach { row => + val filteredRows = shapefileDf.where("part = 2").collect() + assert(filteredRows.length == 4) + filteredRows.foreach { row => assert(row.getAs[Geometry]("geometry").isInstanceOf[Point]) assert(row.getAs[Int]("part") == 2) val id = row.getAs[Long]("id") @@ -780,5 +780,180 @@ class ShapefileTests extends TestBaseScala with BeforeAndAfterAll { } } } + + it("should expose _metadata struct with all expected fields") { + val df = sparkSession.read + .format("shapefile") + .load(resourceFolder + "shapefiles/gis_osm_pois_free_1") + val metaDf = df.select("_metadata") + val metaSchema = metaDf.schema("_metadata").dataType.asInstanceOf[StructType] + val expectedFields = + Seq( + "file_path", + "file_name", + "file_size", + "file_block_start", + "file_block_length", + "file_modification_time") + assert(metaSchema.fieldNames.toSeq == expectedFields) + assert(metaSchema("file_path").dataType == StringType) + assert(metaSchema("file_name").dataType == StringType) + assert(metaSchema("file_size").dataType == LongType) + assert(metaSchema("file_block_start").dataType == LongType) + assert(metaSchema("file_block_length").dataType == LongType) + assert(metaSchema("file_modification_time").dataType == TimestampType) + } + + it("should not include _metadata in select(*)") { + val df = sparkSession.read + .format("shapefile") + .load(resourceFolder + "shapefiles/gis_osm_pois_free_1") + val columns = df.columns + assert(!columns.contains("_metadata")) + } + + it("should return correct file_path and file_name in _metadata") { + val df = sparkSession.read + .format("shapefile") + .load(resourceFolder + "shapefiles/gis_osm_pois_free_1") + val metaRows = df.select("_metadata.file_path", "_metadata.file_name").distinct().collect() + assert(metaRows.length == 1) + val filePath = metaRows.head.getString(0) + val 
fileName = metaRows.head.getString(1) + assert(filePath.endsWith("gis_osm_pois_free_1.shp")) + assert(fileName == "gis_osm_pois_free_1.shp") + } + + it("should return actual file_size matching the .shp file on disk") { + val shpFile = + new File(resourceFolder + "shapefiles/gis_osm_pois_free_1/gis_osm_pois_free_1.shp") + val expectedSize = shpFile.length() + val df = sparkSession.read + .format("shapefile") + .load(resourceFolder + "shapefiles/gis_osm_pois_free_1") + val sizes = df.select("_metadata.file_size").distinct().collect() + assert(sizes.length == 1) + assert(sizes.head.getLong(0) == expectedSize) + } + + it( + "should return file_block_start=0 and file_block_length=file_size for non-splittable shapefiles") { + val shpFile = + new File(resourceFolder + "shapefiles/gis_osm_pois_free_1/gis_osm_pois_free_1.shp") + val expectedSize = shpFile.length() + val df = sparkSession.read + .format("shapefile") + .load(resourceFolder + "shapefiles/gis_osm_pois_free_1") + val rows = df + .select("_metadata.file_block_start", "_metadata.file_block_length") + .distinct() + .collect() + assert(rows.length == 1) + assert(rows.head.getLong(0) == 0L) // file_block_start + assert(rows.head.getLong(1) == expectedSize) // file_block_length + } + + it("should return file_modification_time matching the .shp file on disk") { + val shpFile = + new File(resourceFolder + "shapefiles/gis_osm_pois_free_1/gis_osm_pois_free_1.shp") + // File.lastModified() returns milliseconds, Spark TimestampType stores microseconds + val expectedModTimeMs = shpFile.lastModified() + val df = sparkSession.read + .format("shapefile") + .load(resourceFolder + "shapefiles/gis_osm_pois_free_1") + val times = + df.select("_metadata.file_modification_time").distinct().collect() + assert(times.length == 1) + val modTime = times.head.getTimestamp(0) + assert(modTime != null) + // Timestamp.getTime() returns milliseconds + assert(modTime.getTime == expectedModTimeMs) + } + + it("should return correct metadata values per file when reading multiple shapefiles") { + val map1Shp = + new File(resourceFolder + "shapefiles/multipleshapefiles/map1.shp") + val map2Shp = + new File(resourceFolder + "shapefiles/multipleshapefiles/map2.shp") + val df = sparkSession.read + .format("shapefile") + .load(resourceFolder + "shapefiles/multipleshapefiles") + val metaRows = df + .select( + "_metadata.file_name", + "_metadata.file_size", + "_metadata.file_block_start", + "_metadata.file_block_length") + .distinct() + .collect() + assert(metaRows.length == 2) + val byName = metaRows.map(r => r.getString(0) -> r).toMap + // map1.shp + assert(byName("map1.shp").getLong(1) == map1Shp.length()) + assert(byName("map1.shp").getLong(2) == 0L) + assert(byName("map1.shp").getLong(3) == map1Shp.length()) + // map2.shp + assert(byName("map2.shp").getLong(1) == map2Shp.length()) + assert(byName("map2.shp").getLong(2) == 0L) + assert(byName("map2.shp").getLong(3) == map2Shp.length()) + } + + it("should allow filtering on _metadata fields") { + val df = sparkSession.read + .format("shapefile") + .load(resourceFolder + "shapefiles/multipleshapefiles") + val totalCount = df.count() + val map1Df = df.filter(df("_metadata.file_name") === "map1.shp") + val map2Df = df.filter(df("_metadata.file_name") === "map2.shp") + assert(map1Df.count() > 0) + assert(map2Df.count() > 0) + assert(map1Df.count() + map2Df.count() == totalCount) + } + + it("should select _metadata along with data columns") { + val df = sparkSession.read + .format("shapefile") + .load(resourceFolder + 
"shapefiles/gis_osm_pois_free_1") + val result = df.select("osm_id", "_metadata.file_name").collect() + assert(result.length == 12873) + result.foreach { row => + assert(row.getString(0).nonEmpty) + assert(row.getString(1) == "gis_osm_pois_free_1.shp") + } + } + + it("should return correct metadata for each file in multi-shapefile directory") { + val dt1Shp = new File(resourceFolder + "shapefiles/datatypes/datatypes1.shp") + val dt2Shp = new File(resourceFolder + "shapefiles/datatypes/datatypes2.shp") + val df = sparkSession.read + .format("shapefile") + .load(resourceFolder + "shapefiles/datatypes") + val result = df + .select( + "_metadata.file_path", + "_metadata.file_name", + "_metadata.file_size", + "_metadata.file_block_start", + "_metadata.file_block_length", + "_metadata.file_modification_time") + .distinct() + .collect() + assert(result.length == 2) + val byName = result.map(r => r.getString(1) -> r).toMap + // datatypes1.shp + val r1 = byName("datatypes1.shp") + assert(r1.getString(0).endsWith("datatypes1.shp")) + assert(r1.getLong(2) == dt1Shp.length()) + assert(r1.getLong(3) == 0L) + assert(r1.getLong(4) == dt1Shp.length()) + assert(r1.getTimestamp(5).getTime == dt1Shp.lastModified()) + // datatypes2.shp + val r2 = byName("datatypes2.shp") + assert(r2.getString(0).endsWith("datatypes2.shp")) + assert(r2.getLong(2) == dt2Shp.length()) + assert(r2.getLong(3) == 0L) + assert(r2.getLong(4) == dt2Shp.length()) + assert(r2.getTimestamp(5).getTime == dt2Shp.lastModified()) + } } }