[GH-2360] Support fetching libpostal model data from HDFS/object store #2637
LibPostalDataLoader.scala
@@ -0,0 +1,141 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */
package org.apache.spark.sql.sedona_sql.expressions

import org.apache.spark.SparkFiles
import org.slf4j.LoggerFactory

import java.io.File
import java.net.URI

/**
 * Resolves libpostal data directory paths. When the configured data directory points to a remote
 * filesystem (HDFS, S3, GCS, ABFS, etc.), the data is expected to have been distributed to
 * executors via `SparkContext.addFile()` and is resolved through `SparkFiles.get()`.
 */
object LibPostalDataLoader {

  private val logger = LoggerFactory.getLogger(getClass)

  /**
   * Resolve the data directory to a local filesystem path. If the configured path already points
   * to the local filesystem, it is returned as-is. If it points to a remote filesystem, the data
   * is looked up via Spark's `SparkFiles` mechanism (the user must have called
   * `sc.addFile(remotePath, recursive = true)` before running queries).
   *
   * @param configuredDir
   *   the data directory path from Sedona configuration (may be local or remote)
   * @return
   *   a local filesystem path suitable for jpostal
   */
  def resolveDataDir(configuredDir: String): String = {
    if (isRemotePath(configuredDir)) {
      resolveFromSparkFiles(configuredDir)
    } else {
      normalizeLocalPath(configuredDir)
    }
  }

  /**
   * Normalize a local path. Converts `file:` URIs (e.g. `file:///tmp/libpostal`) to plain
   * filesystem paths (`/tmp/libpostal`) so that jpostal receives a path it can use directly.
   * Non-URI paths are returned unchanged.
   */
  private[expressions] def normalizeLocalPath(path: String): String = {
    try {
      val uri = new URI(path)
      if (uri.getScheme != null && uri.getScheme.equalsIgnoreCase("file")) {
        new File(uri).getAbsolutePath
      } else {
        path
      }
    } catch {
      case _: Exception => path
    }
  }

  /**
   * Determine whether a path string refers to a remote (non-local) filesystem.
   */
  def isRemotePath(path: String): Boolean = {
    try {
      val uri = new URI(path)
      val scheme = uri.getScheme
      scheme != null && !scheme.equalsIgnoreCase("file") && scheme.length > 1
    } catch {
      case _: Exception => false
    }
  }

  /**
   * Resolve a remote data directory via Spark's file distribution mechanism. Extracts the
   * basename (last path component) from the remote URI and looks it up through `SparkFiles.get`.
   * The user must have previously called `sc.addFile(remotePath, recursive = true)`.
   *
   * @throws IllegalStateException
   *   if the data directory was not found via SparkFiles
   */
  private def resolveFromSparkFiles(remotePath: String): String = {
    val basename = extractBasename(remotePath)

    try {
      val localPath = SparkFiles.get(basename)
      val localFile = new File(localPath)
Comment on lines +98 to +99

Collaborator: In contrast with the docs, it seems that this code downloads the data from HDFS itself.

Member (Author): No, it does not. It checks whether Spark actually downloaded it. See: https://spark.apache.org/docs/latest/api/python/reference/api/pyspark.SparkContext.addFile.html#pyspark.SparkContext.addFile
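To make the division of labor concrete, a minimal sketch (the HDFS path is illustrative): `addFile` is the call that downloads the directory to each node, while `SparkFiles.get` only maps a basename to the already-downloaded local copy.

    // Driver side, once per application: Spark downloads the directory tree
    // to every node's SparkFiles root. This is the only step that touches HDFS.
    sc.addFile("hdfs:///data/libpostal", recursive = true)

    // Any node, later: a purely local lookup under SparkFiles.getRootDirectory();
    // no remote I/O happens here.
    val resolved = SparkFiles.get("libpostal")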
      if (localFile.exists() && localFile.isDirectory) {
        logger.info(
          "Resolved libpostal data from SparkFiles: {} -> {}",
          remotePath: Any,
          localPath: Any)
        ensureTrailingSlash(localPath)
      } else {
        throw new IllegalStateException(
          s"libpostal data directory '$basename' was not found via SparkFiles. " +
            "Please call sc.addFile(\"" + remotePath + "\", recursive = true) before running libpostal queries.")
      }
    } catch {
      case e: IllegalStateException => throw e
      case e: Exception =>
        throw new IllegalStateException(
          s"Failed to resolve libpostal data from SparkFiles for '$remotePath'. " +
            "Please call sc.addFile(\"" + remotePath + "\", recursive = true) before running libpostal queries.",
          e)
    }
  }

  /**
   * Extract the basename (last path component) from a URI string. Trailing slashes are stripped
   * before extracting the last component.
   */
  private[expressions] def extractBasename(path: String): String = {
    val trimmed = path.replaceAll("/+$", "")
    val uri = new URI(trimmed)
    val uriPath = uri.getPath
    if (uriPath == null || uriPath.isEmpty) {
      trimmed.split("/").last
    } else {
      uriPath.split("/").last
    }
  }

  private def ensureTrailingSlash(path: String): String = {
    if (path.endsWith("/") || path.endsWith(File.separator)) path
    else path + File.separator
  }
}
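For reference, a minimal end-to-end sketch of the workflow this object's scaladoc describes, assuming a running SparkSession and an illustrative HDFS model directory:

    import org.apache.spark.sql.SparkSession
    import org.apache.spark.sql.sedona_sql.expressions.LibPostalDataLoader

    val spark = SparkSession.builder().appName("libpostal-demo").getOrCreate()
    val sc = spark.sparkContext

    // Distribute the model directory to all executors up front; recursive = true
    // is required because the libpostal data is a directory tree.
    sc.addFile("hdfs:///data/libpostal", recursive = true)

    // The loader sees the remote scheme, extracts the basename "libpostal",
    // and resolves the local copy that addFile placed under SparkFiles.
    val dataDir = LibPostalDataLoader.resolveDataDir("hdfs:///data/libpostal")
    // dataDir is now a local path with a trailing separator, ready for jpostal.

A local directory such as /opt/libpostal (or a file: URI) would skip the SparkFiles lookup entirely and be returned after normalization.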
LibPostalDataLoaderTest.scala
@@ -0,0 +1,161 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */
package org.apache.sedona.sql

import org.apache.spark.SparkFiles
import org.apache.spark.sql.sedona_sql.expressions.LibPostalDataLoader
import org.scalatest.matchers.should.Matchers

import java.io.File
import java.nio.file.Files

class LibPostalDataLoaderTest extends TestBaseScala with Matchers {

  describe("LibPostalDataLoader") {

    describe("isRemotePath") {
      it("should return false for local paths") {
        LibPostalDataLoader.isRemotePath("/tmp/libpostal/") shouldBe false
      }

      it("should return false for relative paths") {
        LibPostalDataLoader.isRemotePath("data/libpostal/") shouldBe false
      }

      it("should return false for file:// URIs") {
        LibPostalDataLoader.isRemotePath("file:///tmp/libpostal/") shouldBe false
      }

      it("should return true for hdfs:// URIs") {
        LibPostalDataLoader.isRemotePath("hdfs:///data/libpostal/") shouldBe true
      }

      it("should return true for hdfs:// URIs with host") {
        LibPostalDataLoader.isRemotePath("hdfs://namenode:9000/data/libpostal/") shouldBe true
      }

      it("should return true for s3a:// URIs") {
        LibPostalDataLoader.isRemotePath("s3a://my-bucket/libpostal/") shouldBe true
      }

      it("should return true for gs:// URIs") {
        LibPostalDataLoader.isRemotePath("gs://my-bucket/libpostal/") shouldBe true
      }

      it("should return true for abfs:// URIs") {
        LibPostalDataLoader.isRemotePath(
          "abfs://container@account.dfs.core.windows.net/libpostal/") shouldBe true
      }

      it("should return true for wasb:// URIs") {
        LibPostalDataLoader.isRemotePath(
          "wasb://container@account.blob.core.windows.net/libpostal/") shouldBe true
      }

      it("should return false for empty string") {
        LibPostalDataLoader.isRemotePath("") shouldBe false
      }

      it("should return false for Windows-like paths") {
        // A single-letter scheme like C: should not be treated as remote
        LibPostalDataLoader.isRemotePath("C:\\libpostal\\data\\") shouldBe false
      }
    }

    describe("resolveDataDir") {
      it("should return local path unchanged") {
        val tempDir = Files.createTempDirectory("sedona-libpostal-test").toFile
        try {
          val result = LibPostalDataLoader.resolveDataDir(tempDir.getAbsolutePath)
          result shouldBe tempDir.getAbsolutePath
        } finally {
          tempDir.delete()
        }
      }

      it("should normalize file: URI to plain local path") {
        val tempDir = Files.createTempDirectory("sedona-libpostal-test").toFile
        try {
          val fileUri = tempDir.toURI.toString
          val result = LibPostalDataLoader.resolveDataDir(fileUri)
          result should not startWith "file:"
          result shouldBe tempDir.getAbsolutePath
        } finally {
          tempDir.delete()
        }
      }

      it("should throw IllegalStateException when remote data not found in SparkFiles") {
        val remoteUri = "hdfs:///data/nonexistent-libpostal-data/"

        val exception = intercept[IllegalStateException] {
          LibPostalDataLoader.resolveDataDir(remoteUri)
        }
        exception.getMessage should include("not found via SparkFiles")
        exception.getMessage should include("sc.addFile")
        exception.getMessage should include("recursive = true")
      }

      // This test simulates the SparkFiles resolution path without actually calling
      // sc.addFile, which would permanently register the remote URI in SparkContext's
      // internal state and cause downstream test failures when the remote endpoint is
      // no longer available. Instead, we place mock data directly in the SparkFiles root.
      it("should resolve remote path when data is present in SparkFiles directory") {
        val sparkFilesDir = new File(SparkFiles.getRootDirectory())
        val mockDataDir = new File(sparkFilesDir, "libpostal-sparkfiles-test")
        try {
          // Create mock libpostal data in the SparkFiles root directory
          val subdirs = Seq("address_parser", "language_classifier", "transliteration")
          for (subdir <- subdirs) {
            val subdirFile = new File(mockDataDir, subdir)
            subdirFile.mkdirs()
            Files.write(new File(subdirFile, "model.dat").toPath, s"data for $subdir".getBytes)
          }

          // resolveDataDir should find the data via SparkFiles.get(basename)
          val remotePath = "s3a://my-bucket/data/libpostal-sparkfiles-test"
          val localPath = LibPostalDataLoader.resolveDataDir(remotePath)

          // Verify the resolved path is local and contains all expected data
          localPath should not startWith "s3a://"
          val localDir = new File(localPath)
          localDir.exists() shouldBe true
          localDir.isDirectory shouldBe true
          localPath should endWith(File.separator)

          for (subdir <- subdirs) {
            val localSubdir = new File(localDir, subdir)
            localSubdir.exists() shouldBe true
            localSubdir.isDirectory shouldBe true
            new File(localSubdir, "model.dat").exists() shouldBe true
          }
        } finally {
          // Clean up the mock data directory
          if (mockDataDir.exists()) {
            mockDataDir.listFiles().foreach { sub =>
              sub.listFiles().foreach(_.delete())
              sub.delete()
            }
            mockDataDir.delete()
          }
        }
      }
    }
  }
}
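As a quick sanity check on the basename extraction that drives the SparkFiles lookup, here are expected results derived from the implementation above. The URIs are illustrative, and since `extractBasename` is `private[expressions]`, this would have to run from code in the `org.apache.spark.sql.sedona_sql.expressions` package:

    // Trailing slashes are stripped before taking the last path component.
    LibPostalDataLoader.extractBasename("hdfs://namenode:9000/data/libpostal/") // "libpostal"
    LibPostalDataLoader.extractBasename("s3a://my-bucket/libpostal")            // "libpostal"
    LibPostalDataLoader.extractBasename("gs://bucket/a/b/models///")            // "models"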