Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 30 additions & 0 deletions lance-spark-3.4_2.12/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,18 @@
<artifactId>lance-spark-base_${scala.compat.version}</artifactId>
<version>${lance-spark.version}</version>
</dependency>
<dependency>
<groupId>org.antlr</groupId>
<artifactId>antlr4</artifactId>
<version>${antlr4.version}</version>
<scope>provided</scope>
</dependency>
<dependency>
<groupId>org.antlr</groupId>
<artifactId>antlr4-runtime</artifactId>
<version>${antlr4.version}</version>
<scope>provided</scope>
</dependency>
<dependency>
<groupId>com.lancedb</groupId>
<artifactId>lance-spark-base_${scala.compat.version}</artifactId>
Expand All @@ -56,6 +68,22 @@
</testResource>
</testResources>
<plugins>
<plugin>
<groupId>org.antlr</groupId>
<artifactId>antlr4-maven-plugin</artifactId>
<executions>
<execution>
<goals>
<goal>antlr4</goal>
</goals>
</execution>
</executions>
<configuration>
<visitor>true</visitor>
<listener>true</listener>
<sourceDirectory>../lance-spark-base_2.12/src/main/antlr4</sourceDirectory>
</configuration>
</plugin>
<plugin>
<groupId>org.codehaus.mojo</groupId>
<artifactId>build-helper-maven-plugin</artifactId>
Expand All @@ -71,6 +99,8 @@
<sources>
<source>../lance-spark-base_2.12/src/main/java</source>
<source>src/main/java</source>
<source>src/main/scala</source>
<source>${project.build.directory}/generated-sources/antlr4</source>
</sources>
</configuration>
</execution>
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.lancedb.lance.spark.extensions

import org.apache.spark.sql.SparkSessionExtensions
import org.apache.spark.sql.catalyst.parser.extensions.LanceSparkSqlExtensionsParser
import org.apache.spark.sql.execution.datasources.v2.LanceDataSourceV2Strategy

/**
 * Entry point for wiring Lance support into a Spark session via
 * `spark.sql.extensions`.
 */
class LanceSparkSessionExtensions extends (SparkSessionExtensions => Unit) {

  /** Registers the Lance SQL parser wrapper and the Lance planner strategy. */
  override def apply(extensions: SparkSessionExtensions): Unit = {
    // Wrap Spark's parser; the wrapper tries the delegate first and falls back
    // to the Lance extension grammar when the delegate fails.
    extensions.injectParser((_, delegateParser) => new LanceSparkSqlExtensionsParser(delegateParser))

    // Physical planning for Lance-specific logical plans.
    extensions.injectPlannerStrategy(sparkSession => LanceDataSourceV2Strategy(sparkSession))
  }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,143 @@
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.catalyst.parser.extensions

import org.antlr.v4.runtime._
import org.antlr.v4.runtime.atn.PredictionMode
import org.antlr.v4.runtime.misc.{Interval, ParseCancellationException}
import org.apache.spark.sql.catalyst.{FunctionIdentifier, TableIdentifier}
import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.catalyst.parser.ParserInterface
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.types.{DataType, StructType}

import scala.util.control.NonFatal

/**
 * A [[ParserInterface]] wrapper that first delegates to Spark's own parser and,
 * only when that fails, retries the statement against the Lance extensions
 * grammar (e.g. the ADD COLUMNS ... BACKFILL statement).
 *
 * @param delegate Spark's session parser; handles everything that is not a
 *                 Lance-specific statement.
 */
class LanceSparkSqlExtensionsParser(delegate: ParserInterface) extends ParserInterface {

  // Translates ANTLR parse trees of Lance-specific statements into Catalyst logical plans.
  private lazy val astBuilder = new LanceSqlExtensionsAstBuilder(delegate)

  /** Parse a string to a DataType. Delegated entirely to Spark's parser. */
  override def parseDataType(sqlText: String): DataType = {
    delegate.parseDataType(sqlText)
  }

  /**
   * Parse a string to a raw DataType without CHAR/VARCHAR replacement.
   *
   * Not an override of [[ParserInterface]]; kept for signature parity with other
   * extension parsers but deliberately unsupported here.
   */
  def parseRawDataType(sqlText: String): DataType = throw new UnsupportedOperationException()

  /** Parse a string to an Expression. Delegated entirely to Spark's parser. */
  override def parseExpression(sqlText: String): Expression = {
    delegate.parseExpression(sqlText)
  }

  /** Parse a string to a TableIdentifier. Delegated entirely to Spark's parser. */
  override def parseTableIdentifier(sqlText: String): TableIdentifier = {
    delegate.parseTableIdentifier(sqlText)
  }

  /** Parse a string to a FunctionIdentifier. Delegated entirely to Spark's parser. */
  override def parseFunctionIdentifier(sqlText: String): FunctionIdentifier = {
    delegate.parseFunctionIdentifier(sqlText)
  }

  /** Parse a string to a multi-part identifier. Delegated entirely to Spark's parser. */
  override def parseMultipartIdentifier(sqlText: String): Seq[String] = {
    delegate.parseMultipartIdentifier(sqlText)
  }

  /**
   * Creates StructType for a given SQL string, which is a comma separated list of field
   * definitions which will preserve the correct Hive metadata.
   */
  override def parseTableSchema(sqlText: String): StructType = {
    delegate.parseTableSchema(sqlText)
  }

  /**
   * Parse a string to a LogicalPlan.
   *
   * Spark's parser is tried first so standard SQL keeps its normal error messages;
   * only on failure is the Lance extensions grammar attempted.
   */
  override def parsePlan(sqlText: String): LogicalPlan = {
    try {
      delegate.parsePlan(sqlText)
    } catch {
      // Was `case _: Exception`, which also swallowed InterruptedException and hid
      // thread interruption. NonFatal lets fatal throwables (VM errors, interruption,
      // linkage errors) propagate while still falling back on ordinary parse failures.
      case NonFatal(_) => parse(sqlText)
    }
  }

  // NOTE(review): this routes through delegate.parsePlan rather than
  // delegate.parseQuery, so Spark's "must be a query" validation is skipped and the
  // Lance fallback above is not applied — confirm both are intentional.
  override def parseQuery(sqlText: String): LogicalPlan = {
    delegate.parsePlan(sqlText)
  }

  /**
   * Parses a command with the Lance extensions grammar using ANTLR's two-stage
   * strategy: fast SLL prediction first, full LL prediction as the fallback.
   */
  protected def parse(command: String): LogicalPlan = {
    val lexer =
      new LanceSqlExtensionsLexer(new UpperCaseCharStream(CharStreams.fromString(command)))
    lexer.removeErrorListeners()

    val tokenStream = new CommonTokenStream(lexer)
    val parser = new LanceSqlExtensionsParser(tokenStream)
    parser.removeErrorListeners()
    // NOTE(review): no BailErrorStrategy is installed, so the default error strategy
    // recovers in place and ParseCancellationException is never thrown — the LL
    // fallback below looks unreachable, and syntax errors are silently dropped
    // (no listener is re-registered). Confirm against the upstream two-stage pattern.

    try {
      // first, try parsing with potentially faster SLL mode
      parser.getInterpreter.setPredictionMode(PredictionMode.SLL)
      astBuilder.visit(parser.singleStatement()).asInstanceOf[LogicalPlan]
    } catch {
      case _: ParseCancellationException =>
        // if we fail, parse with LL mode
        tokenStream.seek(0) // rewind input stream
        parser.reset()

        // Try Again.
        parser.getInterpreter.setPredictionMode(PredictionMode.LL)
        astBuilder.visit(parser.singleStatement()).asInstanceOf[LogicalPlan]
    }
  }
}

/* Copied from Apache Spark's to avoid dependency on Spark Internals */

/**
 * A [[CharStream]] decorator that upper-cases every looked-ahead character so the
 * lexer can match keywords case-insensitively, while [[getText]] still returns the
 * original, unmodified input text.
 */
class UpperCaseCharStream(wrapped: CodePointCharStream) extends CharStream {

  override def consume(): Unit = wrapped.consume()

  override def getSourceName(): String = wrapped.getSourceName

  override def index(): Int = wrapped.index

  override def mark(): Int = wrapped.mark

  override def release(marker: Int): Unit = wrapped.release(marker)

  override def seek(where: Int): Unit = wrapped.seek(where)

  override def size(): Int = wrapped.size

  // Returns the original-cased text: only lookahead is upper-cased, not content.
  override def getText(interval: Interval): String = wrapped.getText(interval)

  // scalastyle:off
  override def LA(i: Int): Int = {
    val lookAhead = wrapped.LA(i)
    // 0 and EOF are sentinels, not code points — pass them through untouched.
    if (lookAhead == 0 || lookAhead == IntStream.EOF) lookAhead
    else Character.toUpperCase(lookAhead)
  }
  // scalastyle:on
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.catalyst.parser.extensions

import org.apache.spark.sql.catalyst.analysis.{UnresolvedIdentifier, UnresolvedRelation}
import org.apache.spark.sql.catalyst.parser.ParserInterface
import org.apache.spark.sql.catalyst.plans.logical.{AddColumnsBackfill, LogicalPlan}

import scala.jdk.CollectionConverters._

/**
 * Converts ANTLR parse-tree nodes produced by the Lance extensions grammar into
 * Catalyst logical plans.
 *
 * @param delegate Spark's session parser, retained for constructor parity.
 */
class LanceSqlExtensionsAstBuilder(delegate: ParserInterface)
  extends LanceSqlExtensionsBaseVisitor[AnyRef] {

  /** Unwraps the single-statement root node and visits the statement beneath it. */
  override def visitSingleStatement(ctx: LanceSqlExtensionsParser.SingleStatementContext)
    : LogicalPlan =
    visit(ctx.statement).asInstanceOf[LogicalPlan]

  /** Builds an [[AddColumnsBackfill]] plan from its parsed target, columns, and source. */
  override def visitAddColumnsBackfill(ctx: LanceSqlExtensionsParser.AddColumnsBackfillContext)
    : AddColumnsBackfill = {
    val targetTable = UnresolvedIdentifier(visitMultipartIdentifier(ctx.multipartIdentifier()))
    val backfillSource = UnresolvedRelation(Seq(ctx.identifier().getText))
    AddColumnsBackfill(targetTable, visitColumnList(ctx.columnList()), backfillSource)
  }

  /** Collects the parts of a multi-part identifier as plain strings. */
  // NOTE(review): scala.jdk.CollectionConverters is 2.13 stdlib; on the 2.12
  // cross-build this needs scala-collection-compat — confirm it is on the classpath.
  override def visitMultipartIdentifier(ctx: LanceSqlExtensionsParser.MultipartIdentifierContext)
    : Seq[String] =
    ctx.parts.asScala.toSeq.map(_.getText)

  /** Collects the column names of a column list as plain strings. */
  override def visitColumnList(ctx: LanceSqlExtensionsParser.ColumnListContext): Seq[String] =
    ctx.columns.asScala.toSeq.map(_.getText)
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.lancedb.lance.spark.update;

/**
 * Runs the shared add-columns-backfill test suite against this module's
 * Spark/Scala version combination; all test methods are inherited from
 * {@link BaseAddColumnsBackfillTest}.
 */
public class AddColumnsBackfillTest extends BaseAddColumnsBackfillTest {
  // All test methods are inherited from BaseAddColumnsBackfillTest
}
30 changes: 30 additions & 0 deletions lance-spark-3.4_2.13/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,18 @@
<artifactId>lance-spark-base_${scala.compat.version}</artifactId>
<version>${lance-spark.version}</version>
</dependency>
<dependency>
<groupId>org.antlr</groupId>
<artifactId>antlr4</artifactId>
<version>${antlr4.version}</version>
<scope>provided</scope>
</dependency>
<dependency>
<groupId>org.antlr</groupId>
<artifactId>antlr4-runtime</artifactId>
<version>${antlr4.version}</version>
<scope>provided</scope>
</dependency>
<dependency>
<groupId>com.lancedb</groupId>
<artifactId>lance-spark-base_${scala.compat.version}</artifactId>
Expand Down Expand Up @@ -61,6 +73,22 @@
</testResource>
</testResources>
<plugins>
<plugin>
<groupId>org.antlr</groupId>
<artifactId>antlr4-maven-plugin</artifactId>
<executions>
<execution>
<goals>
<goal>antlr4</goal>
</goals>
</execution>
</executions>
<configuration>
<visitor>true</visitor>
<listener>true</listener>
<sourceDirectory>../lance-spark-base_2.12/src/main/antlr4</sourceDirectory>
</configuration>
</plugin>
<plugin>
<groupId>org.codehaus.mojo</groupId>
<artifactId>build-helper-maven-plugin</artifactId>
Expand All @@ -76,6 +104,8 @@
<sources>
<source>../lance-spark-base_2.12/src/main/java</source>
<source>../lance-spark-3.4_2.12/src/main/java</source>
<source>../lance-spark-3.4_2.12/src/main/scala</source>
<source>${project.build.directory}/generated-sources/antlr4</source>
</sources>
</configuration>
</execution>
Expand Down
Loading