diff --git a/lance-spark-3.4_2.12/pom.xml b/lance-spark-3.4_2.12/pom.xml
index caa08b8..66ef7e6 100644
--- a/lance-spark-3.4_2.12/pom.xml
+++ b/lance-spark-3.4_2.12/pom.xml
@@ -31,6 +31,18 @@
      <artifactId>lance-spark-base_${scala.compat.version}</artifactId>
      <version>${lance-spark.version}</version>
+    <dependency>
+      <groupId>org.antlr</groupId>
+      <artifactId>antlr4</artifactId>
+      <version>${antlr4.version}</version>
+      <scope>provided</scope>
+    </dependency>
+    <dependency>
+      <groupId>org.antlr</groupId>
+      <artifactId>antlr4-runtime</artifactId>
+      <version>${antlr4.version}</version>
+      <scope>provided</scope>
+    </dependency>
      <groupId>com.lancedb</groupId>
      <artifactId>lance-spark-base_${scala.compat.version}</artifactId>
@@ -56,6 +68,22 @@
+      <plugin>
+        <groupId>org.antlr</groupId>
+        <artifactId>antlr4-maven-plugin</artifactId>
+        <executions>
+          <execution>
+            <goals>
+              <goal>antlr4</goal>
+            </goals>
+          </execution>
+        </executions>
+        <configuration>
+          <visitor>true</visitor>
+          <listener>true</listener>
+          <sourceDirectory>../lance-spark-base_2.12/src/main/antlr4</sourceDirectory>
+        </configuration>
+      </plugin>
        <groupId>org.codehaus.mojo</groupId>
        <artifactId>build-helper-maven-plugin</artifactId>
@@ -71,6 +99,8 @@
                <source>../lance-spark-base_2.12/src/main/java</source>
                <source>src/main/java</source>
+                <source>src/main/scala</source>
+                <source>${project.build.directory}/generated-sources/antlr4</source>
diff --git a/lance-spark-3.4_2.12/src/main/scala/com/lancedb/lance/spark/extensions/LanceSparkSessionExtensions.scala b/lance-spark-3.4_2.12/src/main/scala/com/lancedb/lance/spark/extensions/LanceSparkSessionExtensions.scala
new file mode 100644
index 0000000..c081a92
--- /dev/null
+++ b/lance-spark-3.4_2.12/src/main/scala/com/lancedb/lance/spark/extensions/LanceSparkSessionExtensions.scala
@@ -0,0 +1,28 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package com.lancedb.lance.spark.extensions
+
+import org.apache.spark.sql.SparkSessionExtensions
+import org.apache.spark.sql.catalyst.parser.extensions.LanceSparkSqlExtensionsParser
+import org.apache.spark.sql.execution.datasources.v2.LanceDataSourceV2Strategy
+
+class LanceSparkSessionExtensions extends (SparkSessionExtensions => Unit) {
+
+ override def apply(extensions: SparkSessionExtensions): Unit = {
+ // parser extensions
+ extensions.injectParser { case (_, parser) => new LanceSparkSqlExtensionsParser(parser) }
+
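+ // planner strategy: plans Lance-specific commands such as AddColumnsBackfill into physical nodes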
+ extensions.injectPlannerStrategy(LanceDataSourceV2Strategy(_))
+ }
+}
diff --git a/lance-spark-3.4_2.12/src/main/scala/org/apache/spark/sql/catalyst/parser/extensions/LanceSparkSqlExtensionsParser.scala b/lance-spark-3.4_2.12/src/main/scala/org/apache/spark/sql/catalyst/parser/extensions/LanceSparkSqlExtensionsParser.scala
new file mode 100644
index 0000000..06dce5b
--- /dev/null
+++ b/lance-spark-3.4_2.12/src/main/scala/org/apache/spark/sql/catalyst/parser/extensions/LanceSparkSqlExtensionsParser.scala
@@ -0,0 +1,143 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.catalyst.parser.extensions
+
+import org.antlr.v4.runtime._
+import org.antlr.v4.runtime.atn.PredictionMode
+import org.antlr.v4.runtime.misc.{Interval, ParseCancellationException}
+import org.apache.spark.sql.catalyst.{FunctionIdentifier, TableIdentifier}
+import org.apache.spark.sql.catalyst.expressions.Expression
+import org.apache.spark.sql.catalyst.parser.ParserInterface
+import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
+import org.apache.spark.sql.types.{DataType, StructType}
+
+class LanceSparkSqlExtensionsParser(delegate: ParserInterface) extends ParserInterface {
+
+ private lazy val astBuilder = new LanceSqlExtensionsAstBuilder(delegate)
+
+ /**
+ * Parse a string to a DataType.
+ */
+ override def parseDataType(sqlText: String): DataType = {
+ delegate.parseDataType(sqlText)
+ }
+
+ /**
+ * Parse a string to a raw DataType without CHAR/VARCHAR replacement.
+ */
+ def parseRawDataType(sqlText: String): DataType = throw new UnsupportedOperationException()
+
+ /**
+ * Parse a string to an Expression.
+ */
+ override def parseExpression(sqlText: String): Expression = {
+ delegate.parseExpression(sqlText)
+ }
+
+ /**
+ * Parse a string to a TableIdentifier.
+ */
+ override def parseTableIdentifier(sqlText: String): TableIdentifier = {
+ delegate.parseTableIdentifier(sqlText)
+ }
+
+ /**
+ * Parse a string to a FunctionIdentifier.
+ */
+ override def parseFunctionIdentifier(sqlText: String): FunctionIdentifier = {
+ delegate.parseFunctionIdentifier(sqlText)
+ }
+
+ /**
+ * Parse a string to a multi-part identifier.
+ */
+ override def parseMultipartIdentifier(sqlText: String): Seq[String] = {
+ delegate.parseMultipartIdentifier(sqlText)
+ }
+
+ /**
+ * Creates StructType for a given SQL string, which is a comma separated list of field
+ * definitions which will preserve the correct Hive metadata.
+ */
+ override def parseTableSchema(sqlText: String): StructType = {
+ delegate.parseTableSchema(sqlText)
+ }
+
+ /**
+ * Parse a string to a LogicalPlan.
+ */
+ override def parsePlan(sqlText: String): LogicalPlan = {
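+ // Try Spark's built-in parser first; if it rejects the statement, fall back to the
+ // Lance extension grammar (e.g. ALTER TABLE ... ADD COLUMNS ... FROM ...).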
+ try {
+ delegate.parsePlan(sqlText)
+ } catch {
+ case _: Exception => parse(sqlText)
+ }
+ }
+
+ override def parseQuery(sqlText: String): LogicalPlan = {
+ delegate.parsePlan(sqlText)
+ }
+
+ protected def parse(command: String): LogicalPlan = {
+ val lexer =
+ new LanceSqlExtensionsLexer(new UpperCaseCharStream(CharStreams.fromString(command)))
+ lexer.removeErrorListeners()
+
+ val tokenStream = new CommonTokenStream(lexer)
+ val parser = new LanceSqlExtensionsParser(tokenStream)
+ parser.removeErrorListeners()
+
+ try {
+ // first, try parsing with potentially faster SLL mode
+ parser.getInterpreter.setPredictionMode(PredictionMode.SLL)
+ astBuilder.visit(parser.singleStatement()).asInstanceOf[LogicalPlan]
+ } catch {
+ case _: ParseCancellationException =>
+ // if we fail, parse with LL mode
+ tokenStream.seek(0) // rewind input stream
+ parser.reset()
+
+ // Try Again.
+ parser.getInterpreter.setPredictionMode(PredictionMode.LL)
+ astBuilder.visit(parser.singleStatement()).asInstanceOf[LogicalPlan]
+ }
+ }
+}
+
+/* Copied from Apache Spark to avoid a dependency on Spark internals */
+class UpperCaseCharStream(wrapped: CodePointCharStream) extends CharStream {
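+ // Upper-cases each lookahead character so keywords match case-insensitively, while
+ // getText still returns the original input text.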
+ override def consume(): Unit = wrapped.consume
+
+ override def getSourceName(): String = wrapped.getSourceName
+
+ override def index(): Int = wrapped.index
+
+ override def mark(): Int = wrapped.mark
+
+ override def release(marker: Int): Unit = wrapped.release(marker)
+
+ override def seek(where: Int): Unit = wrapped.seek(where)
+
+ override def size(): Int = wrapped.size
+
+ override def getText(interval: Interval): String = wrapped.getText(interval)
+
+ // scalastyle:off
+ override def LA(i: Int): Int = {
+ val la = wrapped.LA(i)
+ if (la == 0 || la == IntStream.EOF) la
+ else Character.toUpperCase(la)
+ }
+ // scalastyle:on
+}
diff --git a/lance-spark-3.4_2.12/src/main/scala/org/apache/spark/sql/catalyst/parser/extensions/LanceSqlExtensionsAstBuilder.scala b/lance-spark-3.4_2.12/src/main/scala/org/apache/spark/sql/catalyst/parser/extensions/LanceSqlExtensionsAstBuilder.scala
new file mode 100644
index 0000000..ce276aa
--- /dev/null
+++ b/lance-spark-3.4_2.12/src/main/scala/org/apache/spark/sql/catalyst/parser/extensions/LanceSqlExtensionsAstBuilder.scala
@@ -0,0 +1,49 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.catalyst.parser.extensions
+
+import org.apache.spark.sql.catalyst.analysis.{UnresolvedIdentifier, UnresolvedRelation}
+import org.apache.spark.sql.catalyst.parser.ParserInterface
+import org.apache.spark.sql.catalyst.plans.logical.{AddColumnsBackfill, LogicalPlan}
+
+import scala.jdk.CollectionConverters._
+
+class LanceSqlExtensionsAstBuilder(delegate: ParserInterface)
+ extends LanceSqlExtensionsBaseVisitor[AnyRef] {
+
+ override def visitSingleStatement(ctx: LanceSqlExtensionsParser.SingleStatementContext)
+ : LogicalPlan = {
+ visit(ctx.statement).asInstanceOf[LogicalPlan]
+ }
+
+ override def visitAddColumnsBackfill(ctx: LanceSqlExtensionsParser.AddColumnsBackfillContext)
+ : AddColumnsBackfill = {
+ val table = UnresolvedIdentifier(visitMultipartIdentifier(ctx.multipartIdentifier()))
+ val columnNames = visitColumnList(ctx.columnList())
+ val source = UnresolvedRelation(Seq(ctx.identifier().getText))
+ AddColumnsBackfill(table, columnNames, source)
+ }
+
+ override def visitMultipartIdentifier(ctx: LanceSqlExtensionsParser.MultipartIdentifierContext)
+ : Seq[String] = {
+ ctx.parts.asScala.map(_.getText).toSeq
+ }
+
+ /**
+ * Visit identifier list.
+ */
+ override def visitColumnList(ctx: LanceSqlExtensionsParser.ColumnListContext): Seq[String] = {
+ ctx.columns.asScala.map(_.getText).toSeq
+ }
+}
diff --git a/lance-spark-3.4_2.12/src/test/java/com/lancedb/lance/spark/update/AddColumnsBackfillTest.java b/lance-spark-3.4_2.12/src/test/java/com/lancedb/lance/spark/update/AddColumnsBackfillTest.java
new file mode 100644
index 0000000..18fd6cb
--- /dev/null
+++ b/lance-spark-3.4_2.12/src/test/java/com/lancedb/lance/spark/update/AddColumnsBackfillTest.java
@@ -0,0 +1,18 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package com.lancedb.lance.spark.update;
+
+public class AddColumnsBackfillTest extends BaseAddColumnsBackfillTest {
+ // All test methods are inherited from BaseAddColumnsBackfillTest
+}
diff --git a/lance-spark-3.4_2.13/pom.xml b/lance-spark-3.4_2.13/pom.xml
index 6f28832..83a0c5a 100644
--- a/lance-spark-3.4_2.13/pom.xml
+++ b/lance-spark-3.4_2.13/pom.xml
@@ -31,6 +31,18 @@
      <artifactId>lance-spark-base_${scala.compat.version}</artifactId>
      <version>${lance-spark.version}</version>
+    <dependency>
+      <groupId>org.antlr</groupId>
+      <artifactId>antlr4</artifactId>
+      <version>${antlr4.version}</version>
+      <scope>provided</scope>
+    </dependency>
+    <dependency>
+      <groupId>org.antlr</groupId>
+      <artifactId>antlr4-runtime</artifactId>
+      <version>${antlr4.version}</version>
+      <scope>provided</scope>
+    </dependency>
      <groupId>com.lancedb</groupId>
      <artifactId>lance-spark-base_${scala.compat.version}</artifactId>
@@ -61,6 +73,22 @@
+      <plugin>
+        <groupId>org.antlr</groupId>
+        <artifactId>antlr4-maven-plugin</artifactId>
+        <executions>
+          <execution>
+            <goals>
+              <goal>antlr4</goal>
+            </goals>
+          </execution>
+        </executions>
+        <configuration>
+          <visitor>true</visitor>
+          <listener>true</listener>
+          <sourceDirectory>../lance-spark-base_2.12/src/main/antlr4</sourceDirectory>
+        </configuration>
+      </plugin>
        <groupId>org.codehaus.mojo</groupId>
        <artifactId>build-helper-maven-plugin</artifactId>
@@ -76,6 +104,8 @@
                <source>../lance-spark-base_2.12/src/main/java</source>
                <source>../lance-spark-3.4_2.12/src/main/java</source>
+                <source>../lance-spark-3.4_2.12/src/main/scala</source>
+                <source>${project.build.directory}/generated-sources/antlr4</source>
diff --git a/lance-spark-3.5_2.12/pom.xml b/lance-spark-3.5_2.12/pom.xml
index 40ed184..abbf072 100644
--- a/lance-spark-3.5_2.12/pom.xml
+++ b/lance-spark-3.5_2.12/pom.xml
@@ -31,6 +31,18 @@
      <artifactId>lance-spark-base_${scala.compat.version}</artifactId>
      <version>${lance-spark.version}</version>
+    <dependency>
+      <groupId>org.antlr</groupId>
+      <artifactId>antlr4</artifactId>
+      <version>${antlr4.version}</version>
+      <scope>provided</scope>
+    </dependency>
+    <dependency>
+      <groupId>org.antlr</groupId>
+      <artifactId>antlr4-runtime</artifactId>
+      <version>${antlr4.version}</version>
+      <scope>provided</scope>
+    </dependency>
      <groupId>com.lancedb</groupId>
      <artifactId>lance-spark-base_${scala.compat.version}</artifactId>
@@ -56,6 +68,22 @@
+      <plugin>
+        <groupId>org.antlr</groupId>
+        <artifactId>antlr4-maven-plugin</artifactId>
+        <executions>
+          <execution>
+            <goals>
+              <goal>antlr4</goal>
+            </goals>
+          </execution>
+        </executions>
+        <configuration>
+          <visitor>true</visitor>
+          <listener>true</listener>
+          <sourceDirectory>../lance-spark-base_2.12/src/main/antlr4</sourceDirectory>
+        </configuration>
+      </plugin>
        <groupId>org.codehaus.mojo</groupId>
        <artifactId>build-helper-maven-plugin</artifactId>
@@ -72,6 +100,8 @@
                <source>../lance-spark-base_2.12/src/main/java</source>
                <source>src/main/java</source>
                <source>src/main/java11</source>
+                <source>src/main/scala</source>
+                <source>${project.build.directory}/generated-sources/antlr4</source>
@@ -157,4 +187,4 @@
-</project>
+</project>
\ No newline at end of file
diff --git a/lance-spark-3.5_2.12/src/main/scala/com/lancedb/lance/spark/extensions/LanceSparkSessionExtensions.scala b/lance-spark-3.5_2.12/src/main/scala/com/lancedb/lance/spark/extensions/LanceSparkSessionExtensions.scala
new file mode 100644
index 0000000..c081a92
--- /dev/null
+++ b/lance-spark-3.5_2.12/src/main/scala/com/lancedb/lance/spark/extensions/LanceSparkSessionExtensions.scala
@@ -0,0 +1,28 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package com.lancedb.lance.spark.extensions
+
+import org.apache.spark.sql.SparkSessionExtensions
+import org.apache.spark.sql.catalyst.parser.extensions.LanceSparkSqlExtensionsParser
+import org.apache.spark.sql.execution.datasources.v2.LanceDataSourceV2Strategy
+
+class LanceSparkSessionExtensions extends (SparkSessionExtensions => Unit) {
+
+ override def apply(extensions: SparkSessionExtensions): Unit = {
+ // parser extensions
+ extensions.injectParser { case (_, parser) => new LanceSparkSqlExtensionsParser(parser) }
+
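+ // planner strategy: plans Lance-specific commands such as AddColumnsBackfill into physical nodes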
+ extensions.injectPlannerStrategy(LanceDataSourceV2Strategy(_))
+ }
+}
diff --git a/lance-spark-3.5_2.12/src/main/scala/org/apache/spark/sql/catalyst/parser/extensions/LanceSparkSqlExtensionsParser.scala b/lance-spark-3.5_2.12/src/main/scala/org/apache/spark/sql/catalyst/parser/extensions/LanceSparkSqlExtensionsParser.scala
new file mode 100644
index 0000000..06dce5b
--- /dev/null
+++ b/lance-spark-3.5_2.12/src/main/scala/org/apache/spark/sql/catalyst/parser/extensions/LanceSparkSqlExtensionsParser.scala
@@ -0,0 +1,143 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.catalyst.parser.extensions
+
+import org.antlr.v4.runtime._
+import org.antlr.v4.runtime.atn.PredictionMode
+import org.antlr.v4.runtime.misc.{Interval, ParseCancellationException}
+import org.apache.spark.sql.catalyst.{FunctionIdentifier, TableIdentifier}
+import org.apache.spark.sql.catalyst.expressions.Expression
+import org.apache.spark.sql.catalyst.parser.ParserInterface
+import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
+import org.apache.spark.sql.types.{DataType, StructType}
+
+class LanceSparkSqlExtensionsParser(delegate: ParserInterface) extends ParserInterface {
+
+ private lazy val astBuilder = new LanceSqlExtensionsAstBuilder(delegate)
+
+ /**
+ * Parse a string to a DataType.
+ */
+ override def parseDataType(sqlText: String): DataType = {
+ delegate.parseDataType(sqlText)
+ }
+
+ /**
+ * Parse a string to a raw DataType without CHAR/VARCHAR replacement.
+ */
+ def parseRawDataType(sqlText: String): DataType = throw new UnsupportedOperationException()
+
+ /**
+ * Parse a string to an Expression.
+ */
+ override def parseExpression(sqlText: String): Expression = {
+ delegate.parseExpression(sqlText)
+ }
+
+ /**
+ * Parse a string to a TableIdentifier.
+ */
+ override def parseTableIdentifier(sqlText: String): TableIdentifier = {
+ delegate.parseTableIdentifier(sqlText)
+ }
+
+ /**
+ * Parse a string to a FunctionIdentifier.
+ */
+ override def parseFunctionIdentifier(sqlText: String): FunctionIdentifier = {
+ delegate.parseFunctionIdentifier(sqlText)
+ }
+
+ /**
+ * Parse a string to a multi-part identifier.
+ */
+ override def parseMultipartIdentifier(sqlText: String): Seq[String] = {
+ delegate.parseMultipartIdentifier(sqlText)
+ }
+
+ /**
+ * Creates StructType for a given SQL string, which is a comma separated list of field
+ * definitions which will preserve the correct Hive metadata.
+ */
+ override def parseTableSchema(sqlText: String): StructType = {
+ delegate.parseTableSchema(sqlText)
+ }
+
+ /**
+ * Parse a string to a LogicalPlan.
+ */
+ override def parsePlan(sqlText: String): LogicalPlan = {
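+ // Try Spark's built-in parser first; if it rejects the statement, fall back to the
+ // Lance extension grammar (e.g. ALTER TABLE ... ADD COLUMNS ... FROM ...).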
+ try {
+ delegate.parsePlan(sqlText)
+ } catch {
+ case _: Exception => parse(sqlText)
+ }
+ }
+
+ override def parseQuery(sqlText: String): LogicalPlan = {
+ delegate.parsePlan(sqlText)
+ }
+
+ protected def parse(command: String): LogicalPlan = {
+ val lexer =
+ new LanceSqlExtensionsLexer(new UpperCaseCharStream(CharStreams.fromString(command)))
+ lexer.removeErrorListeners()
+
+ val tokenStream = new CommonTokenStream(lexer)
+ val parser = new LanceSqlExtensionsParser(tokenStream)
+ parser.removeErrorListeners()
+
+ try {
+ // first, try parsing with potentially faster SLL mode
+ parser.getInterpreter.setPredictionMode(PredictionMode.SLL)
+ astBuilder.visit(parser.singleStatement()).asInstanceOf[LogicalPlan]
+ } catch {
+ case _: ParseCancellationException =>
+ // if we fail, parse with LL mode
+ tokenStream.seek(0) // rewind input stream
+ parser.reset()
+
+ // Try Again.
+ parser.getInterpreter.setPredictionMode(PredictionMode.LL)
+ astBuilder.visit(parser.singleStatement()).asInstanceOf[LogicalPlan]
+ }
+ }
+}
+
+/* Copied from Apache Spark to avoid a dependency on Spark internals */
+class UpperCaseCharStream(wrapped: CodePointCharStream) extends CharStream {
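+ // Upper-cases each lookahead character so keywords match case-insensitively, while
+ // getText still returns the original input text.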
+ override def consume(): Unit = wrapped.consume
+
+ override def getSourceName(): String = wrapped.getSourceName
+
+ override def index(): Int = wrapped.index
+
+ override def mark(): Int = wrapped.mark
+
+ override def release(marker: Int): Unit = wrapped.release(marker)
+
+ override def seek(where: Int): Unit = wrapped.seek(where)
+
+ override def size(): Int = wrapped.size
+
+ override def getText(interval: Interval): String = wrapped.getText(interval)
+
+ // scalastyle:off
+ override def LA(i: Int): Int = {
+ val la = wrapped.LA(i)
+ if (la == 0 || la == IntStream.EOF) la
+ else Character.toUpperCase(la)
+ }
+ // scalastyle:on
+}
diff --git a/lance-spark-3.5_2.12/src/main/scala/org/apache/spark/sql/catalyst/parser/extensions/LanceSqlExtensionsAstBuilder.scala b/lance-spark-3.5_2.12/src/main/scala/org/apache/spark/sql/catalyst/parser/extensions/LanceSqlExtensionsAstBuilder.scala
new file mode 100644
index 0000000..ce276aa
--- /dev/null
+++ b/lance-spark-3.5_2.12/src/main/scala/org/apache/spark/sql/catalyst/parser/extensions/LanceSqlExtensionsAstBuilder.scala
@@ -0,0 +1,49 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.catalyst.parser.extensions
+
+import org.apache.spark.sql.catalyst.analysis.{UnresolvedIdentifier, UnresolvedRelation}
+import org.apache.spark.sql.catalyst.parser.ParserInterface
+import org.apache.spark.sql.catalyst.plans.logical.{AddColumnsBackfill, LogicalPlan}
+
+import scala.jdk.CollectionConverters._
+
+class LanceSqlExtensionsAstBuilder(delegate: ParserInterface)
+ extends LanceSqlExtensionsBaseVisitor[AnyRef] {
+
+ override def visitSingleStatement(ctx: LanceSqlExtensionsParser.SingleStatementContext)
+ : LogicalPlan = {
+ visit(ctx.statement).asInstanceOf[LogicalPlan]
+ }
+
+ override def visitAddColumnsBackfill(ctx: LanceSqlExtensionsParser.AddColumnsBackfillContext)
+ : AddColumnsBackfill = {
+ val table = UnresolvedIdentifier(visitMultipartIdentifier(ctx.multipartIdentifier()))
+ val columnNames = visitColumnList(ctx.columnList())
+ val source = UnresolvedRelation(Seq(ctx.identifier().getText))
+ AddColumnsBackfill(table, columnNames, source)
+ }
+
+ override def visitMultipartIdentifier(ctx: LanceSqlExtensionsParser.MultipartIdentifierContext)
+ : Seq[String] = {
+ ctx.parts.asScala.map(_.getText).toSeq
+ }
+
+ /**
+ * Visit identifier list.
+ */
+ override def visitColumnList(ctx: LanceSqlExtensionsParser.ColumnListContext): Seq[String] = {
+ ctx.columns.asScala.map(_.getText).toSeq
+ }
+}
diff --git a/lance-spark-3.5_2.12/src/test/java/com/lancedb/lance/spark/update/AddColumnsBackfillTest.java b/lance-spark-3.5_2.12/src/test/java/com/lancedb/lance/spark/update/AddColumnsBackfillTest.java
new file mode 100644
index 0000000..2906e3e
--- /dev/null
+++ b/lance-spark-3.5_2.12/src/test/java/com/lancedb/lance/spark/update/AddColumnsBackfillTest.java
@@ -0,0 +1,43 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package com.lancedb.lance.spark.update;
+
+import org.junit.jupiter.api.Assertions;
+import org.junit.jupiter.api.Test;
+
+public class AddColumnsBackfillTest extends BaseAddColumnsBackfillTest {
+ // All test methods are inherited from BaseAddColumnsBackfillTest
+
+ @Test
+ public void testWithDeletedRecords() {
+ prepareDataset();
+
+ // Delete some rows
+ spark.sql(String.format("delete from %s where id in (0, 1, 4, 8, 9);", fullTable));
+
+ spark.sql(
+ String.format(
+ "create temporary view tmp_view as select _rowaddr, _fragid, id * 100 as new_col1, id * 2 as new_col2, id * 3 as new_col3 from %s;",
+ fullTable));
+ spark.sql(
+ String.format("alter table %s add columns new_col1, new_col2 from tmp_view", fullTable));
+
+ Assertions.assertEquals(
+ "[[2,200,4,text_2], [3,300,6,text_3], [5,500,10,text_5], [6,600,12,text_6], [7,700,14,text_7]]",
+ spark
+ .sql(String.format("select id, new_col1, new_col2, text from %s", fullTable))
+ .collectAsList()
+ .toString());
+ }
+}
diff --git a/lance-spark-3.5_2.13/pom.xml b/lance-spark-3.5_2.13/pom.xml
index 900e514..b6d8ac1 100644
--- a/lance-spark-3.5_2.13/pom.xml
+++ b/lance-spark-3.5_2.13/pom.xml
@@ -31,6 +31,18 @@
      <artifactId>lance-spark-base_${scala.compat.version}</artifactId>
      <version>${lance-spark.version}</version>
+    <dependency>
+      <groupId>org.antlr</groupId>
+      <artifactId>antlr4</artifactId>
+      <version>${antlr4.version}</version>
+      <scope>provided</scope>
+    </dependency>
+    <dependency>
+      <groupId>org.antlr</groupId>
+      <artifactId>antlr4-runtime</artifactId>
+      <version>${antlr4.version}</version>
+      <scope>provided</scope>
+    </dependency>
      <groupId>com.lancedb</groupId>
      <artifactId>lance-spark-base_${scala.compat.version}</artifactId>
@@ -61,6 +73,22 @@
+      <plugin>
+        <groupId>org.antlr</groupId>
+        <artifactId>antlr4-maven-plugin</artifactId>
+        <executions>
+          <execution>
+            <goals>
+              <goal>antlr4</goal>
+            </goals>
+          </execution>
+        </executions>
+        <configuration>
+          <visitor>true</visitor>
+          <listener>true</listener>
+          <sourceDirectory>../lance-spark-base_2.12/src/main/antlr4</sourceDirectory>
+        </configuration>
+      </plugin>
        <groupId>org.codehaus.mojo</groupId>
        <artifactId>build-helper-maven-plugin</artifactId>
@@ -76,6 +104,8 @@
                <source>../lance-spark-base_2.12/src/main/java</source>
                <source>../lance-spark-3.5_2.12/src/main/java</source>
+                <source>../lance-spark-3.5_2.12/src/main/scala</source>
+                <source>${project.build.directory}/generated-sources/antlr4</source>
diff --git a/lance-spark-4.0_2.13/pom.xml b/lance-spark-4.0_2.13/pom.xml
index e77f76e..2f8a576 100644
--- a/lance-spark-4.0_2.13/pom.xml
+++ b/lance-spark-4.0_2.13/pom.xml
@@ -17,6 +17,7 @@
    <scala.version>${scala213.version}</scala.version>
    <scala.compat.version>${scala213.compat.version}</scala.compat.version>
+    <antlr4.version>4.13.1</antlr4.version>
@@ -31,6 +32,18 @@
      <artifactId>lance-spark-base_${scala.compat.version}</artifactId>
      <version>${lance-spark.version}</version>
+    <dependency>
+      <groupId>org.antlr</groupId>
+      <artifactId>antlr4</artifactId>
+      <version>${antlr4.version}</version>
+      <scope>provided</scope>
+    </dependency>
+    <dependency>
+      <groupId>org.antlr</groupId>
+      <artifactId>antlr4-runtime</artifactId>
+      <version>${antlr4.version}</version>
+      <scope>provided</scope>
+    </dependency>
      <groupId>com.lancedb</groupId>
      <artifactId>lance-spark-base_${scala.compat.version}</artifactId>
@@ -61,6 +74,22 @@
+      <plugin>
+        <groupId>org.antlr</groupId>
+        <artifactId>antlr4-maven-plugin</artifactId>
+        <executions>
+          <execution>
+            <goals>
+              <goal>antlr4</goal>
+            </goals>
+          </execution>
+        </executions>
+        <configuration>
+          <visitor>true</visitor>
+          <listener>true</listener>
+          <sourceDirectory>../lance-spark-base_2.12/src/main/antlr4</sourceDirectory>
+        </configuration>
+      </plugin>
        <groupId>org.codehaus.mojo</groupId>
        <artifactId>build-helper-maven-plugin</artifactId>
@@ -77,6 +106,8 @@
                <source>../lance-spark-base_2.12/src/main/java</source>
                <source>../lance-spark-3.5_2.12/src/main/java</source>
                <source>src/main/java</source>
+                <source>src/main/scala</source>
+                <source>${project.build.directory}/generated-sources/antlr4</source>
diff --git a/lance-spark-4.0_2.13/src/main/scala/com/lancedb/lance/spark/extensions/LanceSparkSessionExtensions.scala b/lance-spark-4.0_2.13/src/main/scala/com/lancedb/lance/spark/extensions/LanceSparkSessionExtensions.scala
new file mode 100644
index 0000000..c081a92
--- /dev/null
+++ b/lance-spark-4.0_2.13/src/main/scala/com/lancedb/lance/spark/extensions/LanceSparkSessionExtensions.scala
@@ -0,0 +1,28 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package com.lancedb.lance.spark.extensions
+
+import org.apache.spark.sql.SparkSessionExtensions
+import org.apache.spark.sql.catalyst.parser.extensions.LanceSparkSqlExtensionsParser
+import org.apache.spark.sql.execution.datasources.v2.LanceDataSourceV2Strategy
+
+class LanceSparkSessionExtensions extends (SparkSessionExtensions => Unit) {
+
+ override def apply(extensions: SparkSessionExtensions): Unit = {
+ // parser extensions
+ extensions.injectParser { case (_, parser) => new LanceSparkSqlExtensionsParser(parser) }
+
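+ // planner strategy: plans Lance-specific commands such as AddColumnsBackfill into physical nodes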
+ extensions.injectPlannerStrategy(LanceDataSourceV2Strategy(_))
+ }
+}
diff --git a/lance-spark-4.0_2.13/src/main/scala/org/apache/spark/sql/catalyst/parser/extensions/LanceSparkSqlExtensionsParser.scala b/lance-spark-4.0_2.13/src/main/scala/org/apache/spark/sql/catalyst/parser/extensions/LanceSparkSqlExtensionsParser.scala
new file mode 100644
index 0000000..5fc1698
--- /dev/null
+++ b/lance-spark-4.0_2.13/src/main/scala/org/apache/spark/sql/catalyst/parser/extensions/LanceSparkSqlExtensionsParser.scala
@@ -0,0 +1,146 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.catalyst.parser.extensions
+
+import org.antlr.v4.runtime._
+import org.antlr.v4.runtime.atn.PredictionMode
+import org.antlr.v4.runtime.misc.{Interval, ParseCancellationException}
+import org.apache.spark.sql.catalyst.{FunctionIdentifier, TableIdentifier}
+import org.apache.spark.sql.catalyst.expressions.Expression
+import org.apache.spark.sql.catalyst.parser.ParserInterface
+import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
+import org.apache.spark.sql.types.{DataType, StructType}
+
+class LanceSparkSqlExtensionsParser(delegate: ParserInterface) extends ParserInterface {
+
+ private lazy val astBuilder = new LanceSqlExtensionsAstBuilder(delegate)
+
+ /**
+ * Parse a string to a DataType.
+ */
+ override def parseDataType(sqlText: String): DataType = {
+ delegate.parseDataType(sqlText)
+ }
+
+ /**
+ * Parse a string to a raw DataType without CHAR/VARCHAR replacement.
+ */
+ def parseRawDataType(sqlText: String): DataType = throw new UnsupportedOperationException()
+
+ /**
+ * Parse a string to an Expression.
+ */
+ override def parseExpression(sqlText: String): Expression = {
+ delegate.parseExpression(sqlText)
+ }
+
+ /**
+ * Parse a string to a TableIdentifier.
+ */
+ override def parseTableIdentifier(sqlText: String): TableIdentifier = {
+ delegate.parseTableIdentifier(sqlText)
+ }
+
+ /**
+ * Parse a string to a FunctionIdentifier.
+ */
+ override def parseFunctionIdentifier(sqlText: String): FunctionIdentifier = {
+ delegate.parseFunctionIdentifier(sqlText)
+ }
+
+ /**
+ * Parse a string to a multi-part identifier.
+ */
+ override def parseMultipartIdentifier(sqlText: String): Seq[String] = {
+ delegate.parseMultipartIdentifier(sqlText)
+ }
+
+ /**
+ * Creates StructType for a given SQL string, which is a comma separated list of field
+ * definitions which will preserve the correct Hive metadata.
+ */
+ override def parseTableSchema(sqlText: String): StructType = {
+ delegate.parseTableSchema(sqlText)
+ }
+
+ /**
+ * Parse a string to a LogicalPlan.
+ */
+ override def parsePlan(sqlText: String): LogicalPlan = {
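+ // Try Spark's built-in parser first; if it rejects the statement, fall back to the
+ // Lance extension grammar (e.g. ALTER TABLE ... ADD COLUMNS ... FROM ...).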
+ try {
+ delegate.parsePlan(sqlText)
+ } catch {
+ case _: Exception => parse(sqlText)
+ }
+ }
+
+ override def parseQuery(sqlText: String): LogicalPlan = {
+ delegate.parsePlan(sqlText)
+ }
+
+ override def parseRoutineParam(sqlText: String): StructType =
+ throw new UnsupportedOperationException()
+
+ protected def parse(command: String): LogicalPlan = {
+ val lexer =
+ new LanceSqlExtensionsLexer(new UpperCaseCharStream(CharStreams.fromString(command)))
+ lexer.removeErrorListeners()
+
+ val tokenStream = new CommonTokenStream(lexer)
+ val parser = new LanceSqlExtensionsParser(tokenStream)
+ parser.removeErrorListeners()
+
+ try {
+ // first, try parsing with potentially faster SLL mode
+ parser.getInterpreter.setPredictionMode(PredictionMode.SLL)
+ astBuilder.visit(parser.singleStatement()).asInstanceOf[LogicalPlan]
+ } catch {
+ case _: ParseCancellationException =>
+ // if we fail, parse with LL mode
+ tokenStream.seek(0) // rewind input stream
+ parser.reset()
+
+ // Try Again.
+ parser.getInterpreter.setPredictionMode(PredictionMode.LL)
+ astBuilder.visit(parser.singleStatement()).asInstanceOf[LogicalPlan]
+ }
+ }
+}
+
+/* Copied from Apache Spark to avoid a dependency on Spark internals */
+class UpperCaseCharStream(wrapped: CodePointCharStream) extends CharStream {
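+ // Upper-cases each lookahead character so keywords match case-insensitively, while
+ // getText still returns the original input text.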
+ override def consume(): Unit = wrapped.consume
+
+ override def getSourceName(): String = wrapped.getSourceName
+
+ override def index(): Int = wrapped.index
+
+ override def mark(): Int = wrapped.mark
+
+ override def release(marker: Int): Unit = wrapped.release(marker)
+
+ override def seek(where: Int): Unit = wrapped.seek(where)
+
+ override def size(): Int = wrapped.size
+
+ override def getText(interval: Interval): String = wrapped.getText(interval)
+
+ // scalastyle:off
+ override def LA(i: Int): Int = {
+ val la = wrapped.LA(i)
+ if (la == 0 || la == IntStream.EOF) la
+ else Character.toUpperCase(la)
+ }
+ // scalastyle:on
+}
diff --git a/lance-spark-4.0_2.13/src/main/scala/org/apache/spark/sql/catalyst/parser/extensions/LanceSqlExtensionsAstBuilder.scala b/lance-spark-4.0_2.13/src/main/scala/org/apache/spark/sql/catalyst/parser/extensions/LanceSqlExtensionsAstBuilder.scala
new file mode 100644
index 0000000..ce276aa
--- /dev/null
+++ b/lance-spark-4.0_2.13/src/main/scala/org/apache/spark/sql/catalyst/parser/extensions/LanceSqlExtensionsAstBuilder.scala
@@ -0,0 +1,49 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.catalyst.parser.extensions
+
+import org.apache.spark.sql.catalyst.analysis.{UnresolvedIdentifier, UnresolvedRelation}
+import org.apache.spark.sql.catalyst.parser.ParserInterface
+import org.apache.spark.sql.catalyst.plans.logical.{AddColumnsBackfill, LogicalPlan}
+
+import scala.jdk.CollectionConverters._
+
+class LanceSqlExtensionsAstBuilder(delegate: ParserInterface)
+ extends LanceSqlExtensionsBaseVisitor[AnyRef] {
+
+ override def visitSingleStatement(ctx: LanceSqlExtensionsParser.SingleStatementContext)
+ : LogicalPlan = {
+ visit(ctx.statement).asInstanceOf[LogicalPlan]
+ }
+
+ override def visitAddColumnsBackfill(ctx: LanceSqlExtensionsParser.AddColumnsBackfillContext)
+ : AddColumnsBackfill = {
+ val table = UnresolvedIdentifier(visitMultipartIdentifier(ctx.multipartIdentifier()))
+ val columnNames = visitColumnList(ctx.columnList())
+ val source = UnresolvedRelation(Seq(ctx.identifier().getText))
+ AddColumnsBackfill(table, columnNames, source)
+ }
+
+ override def visitMultipartIdentifier(ctx: LanceSqlExtensionsParser.MultipartIdentifierContext)
+ : Seq[String] = {
+ ctx.parts.asScala.map(_.getText).toSeq
+ }
+
+ /**
+ * Visit identifier list.
+ */
+ override def visitColumnList(ctx: LanceSqlExtensionsParser.ColumnListContext): Seq[String] = {
+ ctx.columns.asScala.map(_.getText).toSeq
+ }
+}
diff --git a/lance-spark-base_2.12/src/main/antlr4/org/apache/spark/sql/catalyst/parser/extensions/LanceSqlExtensions.g4 b/lance-spark-base_2.12/src/main/antlr4/org/apache/spark/sql/catalyst/parser/extensions/LanceSqlExtensions.g4
new file mode 100644
index 0000000..83a671b
--- /dev/null
+++ b/lance-spark-base_2.12/src/main/antlr4/org/apache/spark/sql/catalyst/parser/extensions/LanceSqlExtensions.g4
@@ -0,0 +1,66 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+grammar LanceSqlExtensions;
+
+singleStatement
+ : statement EOF
+ ;
+
+statement
+ : ALTER TABLE multipartIdentifier ADD COLUMNS columnList FROM identifier #addColumnsBackfill
+ ;
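+
+// Example: ALTER TABLE db.tbl ADD COLUMNS new_col1, new_col2 FROM source_view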
+
+multipartIdentifier
+ : parts+=identifier ('.' parts+=identifier)*
+ ;
+
+identifier
+ : IDENTIFIER #unquotedIdentifier
+ | quotedIdentifier #quotedIdentifierAlternative
+ ;
+
+quotedIdentifier
+ : BACKQUOTED_IDENTIFIER
+ ;
+
+columnList
+ : columns+=identifier (',' columns+=identifier)*
+ ;
+
+
+ADD: 'ADD';
+ALTER: 'ALTER';
+COLUMNS: 'COLUMNS';
+FROM: 'FROM';
+TABLE: 'TABLE';
+
+
+IDENTIFIER
+ : (LETTER | DIGIT | '_')+
+ ;
+
+BACKQUOTED_IDENTIFIER
+ : '`' ( ~'`' | '``' )* '`'
+ ;
+
+fragment DIGIT
+ : [0-9]
+ ;
+
+fragment LETTER
+ : [A-Z]
+ ;
+
+// Skip whitespace between tokens.
+WS
+ : [ \t\r\n]+ -> skip
+ ;
diff --git a/lance-spark-base_2.12/src/main/java/com/lancedb/lance/spark/LanceConstant.java b/lance-spark-base_2.12/src/main/java/com/lancedb/lance/spark/LanceConstant.java
index 74e0e2d..9035814 100644
--- a/lance-spark-base_2.12/src/main/java/com/lancedb/lance/spark/LanceConstant.java
+++ b/lance-spark-base_2.12/src/main/java/com/lancedb/lance/spark/LanceConstant.java
@@ -21,4 +21,6 @@ public class LanceConstant {
// Blob metadata column suffixes
public static final String BLOB_POSITION_SUFFIX = "__blob_pos";
public static final String BLOB_SIZE_SUFFIX = "__blob_size";
+
+ public static final String BACKFILL_COLUMNS_KEY = "backfill_columns";
}
diff --git a/lance-spark-base_2.12/src/main/java/com/lancedb/lance/spark/LanceDataset.java b/lance-spark-base_2.12/src/main/java/com/lancedb/lance/spark/LanceDataset.java
index 4d5da0f..7c38a50 100644
--- a/lance-spark-base_2.12/src/main/java/com/lancedb/lance/spark/LanceDataset.java
+++ b/lance-spark-base_2.12/src/main/java/com/lancedb/lance/spark/LanceDataset.java
@@ -15,6 +15,7 @@
import com.lancedb.lance.spark.read.LanceScanBuilder;
import com.lancedb.lance.spark.utils.BlobUtils;
+import com.lancedb.lance.spark.write.AddColumnsBackfillWrite;
import com.lancedb.lance.spark.write.SparkWrite;
import com.google.common.collect.ImmutableSet;
@@ -33,8 +34,10 @@
import org.apache.spark.sql.util.CaseInsensitiveStringMap;
import java.util.ArrayList;
+import java.util.Arrays;
import java.util.List;
import java.util.Set;
+import java.util.stream.Collectors;
/** Lance Spark Dataset. */
public class LanceDataset implements SupportsRead, SupportsWrite, SupportsMetadataColumns {
@@ -42,53 +45,58 @@ public class LanceDataset implements SupportsRead, SupportsWrite, SupportsMetadataColumns {
ImmutableSet.of(
TableCapability.BATCH_READ, TableCapability.BATCH_WRITE, TableCapability.TRUNCATE);
- public static final MetadataColumn[] METADATA_COLUMNS =
- new MetadataColumn[] {
- new MetadataColumn() {
- @Override
- public String name() {
- return LanceConstant.FRAGMENT_ID;
- }
-
- @Override
- public DataType dataType() {
- return DataTypes.IntegerType;
- }
-
- @Override
- public boolean isNullable() {
- return false;
- }
- },
- new MetadataColumn() {
- @Override
- public String name() {
- return LanceConstant.ROW_ID;
- }
-
- @Override
- public DataType dataType() {
- return DataTypes.LongType;
- }
- },
- new MetadataColumn() {
- @Override
- public String name() {
- return LanceConstant.ROW_ADDRESS;
- }
-
- @Override
- public DataType dataType() {
- return DataTypes.LongType;
- }
-
- @Override
- public boolean isNullable() {
- return false;
- }
- },
+ public static final MetadataColumn FRAGMENT_ID_COLUMN =
+ new MetadataColumn() {
+ @Override
+ public String name() {
+ return LanceConstant.FRAGMENT_ID;
+ }
+
+ @Override
+ public DataType dataType() {
+ return DataTypes.IntegerType;
+ }
+
+ @Override
+ public boolean isNullable() {
+ return false;
+ }
+ };
+
+ public static final MetadataColumn ROW_ID_COLUMN =
+ new MetadataColumn() {
+ @Override
+ public String name() {
+ return LanceConstant.ROW_ID;
+ }
+
+ @Override
+ public DataType dataType() {
+ return DataTypes.LongType;
+ }
+ };
+
+ public static final MetadataColumn ROW_ADDRESS_COLUMN =
+ new MetadataColumn() {
+ @Override
+ public String name() {
+ return LanceConstant.ROW_ADDRESS;
+ }
+
+ @Override
+ public DataType dataType() {
+ return DataTypes.LongType;
+ }
+
+ @Override
+ public boolean isNullable() {
+ return false;
+ }
};
+ public static final MetadataColumn[] METADATA_COLUMNS =
+ new MetadataColumn[] {ROW_ID_COLUMN, ROW_ADDRESS_COLUMN, FRAGMENT_ID_COLUMN};
+
LanceConfig config;
protected final StructType sparkSchema;
@@ -129,6 +137,20 @@ public Set<TableCapability> capabilities() {
@Override
public WriteBuilder newWriteBuilder(LogicalWriteInfo logicalWriteInfo) {
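+ // AddColumnsBackfillExec passes the new column names via the "backfill_columns" write option;
+ // when present, route the write to the column-backfill path instead of a plain append.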
+ List<String> backfillColumns =
+ Arrays.stream(
+ logicalWriteInfo
+ .options()
+ .getOrDefault(LanceConstant.BACKFILL_COLUMNS_KEY, "")
+ .split(","))
+ .map(String::trim)
+ .filter(t -> !t.isEmpty())
+ .collect(Collectors.toList());
+ if (!backfillColumns.isEmpty()) {
+ return new AddColumnsBackfillWrite.AddColumnsWriteBuilder(
+ sparkSchema, config, backfillColumns);
+ }
+
return new SparkWrite.SparkWriteBuilder(sparkSchema, config);
}
diff --git a/lance-spark-base_2.12/src/main/java/com/lancedb/lance/spark/internal/LanceDatasetAdapter.java b/lance-spark-base_2.12/src/main/java/com/lancedb/lance/spark/internal/LanceDatasetAdapter.java
index d920f8c..576ac8c 100644
--- a/lance-spark-base_2.12/src/main/java/com/lancedb/lance/spark/internal/LanceDatasetAdapter.java
+++ b/lance-spark-base_2.12/src/main/java/com/lancedb/lance/spark/internal/LanceDatasetAdapter.java
@@ -13,7 +13,14 @@
*/
package com.lancedb.lance.spark.internal;
-import com.lancedb.lance.*;
+import com.lancedb.lance.Dataset;
+import com.lancedb.lance.Fragment;
+import com.lancedb.lance.FragmentMetadata;
+import com.lancedb.lance.FragmentOperation;
+import com.lancedb.lance.ReadOptions;
+import com.lancedb.lance.WriteParams;
+import com.lancedb.lance.fragment.FragmentMergeResult;
+import com.lancedb.lance.operation.Merge;
import com.lancedb.lance.operation.Update;
import com.lancedb.lance.spark.LanceConfig;
import com.lancedb.lance.spark.SparkOptions;
@@ -144,6 +151,29 @@ public static void updateFragments(
}
}
+ public static void mergeFragments(
+ LanceConfig config, List<FragmentMetadata> fragments, Schema schema) {
+ String uri = config.getDatasetUri();
+ ReadOptions options = SparkOptions.genReadOptionFromConfig(config);
+ try (Dataset dataset = Dataset.open(allocator, uri, options)) {
+ dataset
+ .newTransactionBuilder()
+ .operation(Merge.builder().fragments(fragments).schema(schema).build())
+ .build()
+ .commit();
+ }
+ }
+
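+ // Merge a stream of new column values into a single fragment, joining the stream's
+ // leftOn column against the fragment's rightOn column.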
+ public static FragmentMergeResult mergeFragmentColumn(
+ LanceConfig config, int fragmentId, ArrowArrayStream stream, String leftOn, String rightOn) {
+ String uri = config.getDatasetUri();
+ ReadOptions options = SparkOptions.genReadOptionFromConfig(config);
+ try (Dataset dataset = Dataset.open(allocator, uri, options)) {
+ Fragment fragment = dataset.getFragment(fragmentId);
+ return fragment.mergeColumns(stream, leftOn, rightOn);
+ }
+ }
+
public static FragmentMetadata deleteRows(
LanceConfig config, int fragmentId, List rowIndexes) {
String uri = config.getDatasetUri();
diff --git a/lance-spark-base_2.12/src/main/java/com/lancedb/lance/spark/write/AddColumnsBackfillBatchWrite.java b/lance-spark-base_2.12/src/main/java/com/lancedb/lance/spark/write/AddColumnsBackfillBatchWrite.java
new file mode 100644
index 0000000..4206b70
--- /dev/null
+++ b/lance-spark-base_2.12/src/main/java/com/lancedb/lance/spark/write/AddColumnsBackfillBatchWrite.java
@@ -0,0 +1,233 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package com.lancedb.lance.spark.write;
+
+import com.lancedb.lance.FragmentMetadata;
+import com.lancedb.lance.fragment.FragmentMergeResult;
+import com.lancedb.lance.spark.LanceConfig;
+import com.lancedb.lance.spark.LanceDataset;
+import com.lancedb.lance.spark.internal.LanceDatasetAdapter;
+
+import org.apache.arrow.c.ArrowArrayStream;
+import org.apache.arrow.c.Data;
+import org.apache.arrow.vector.VectorSchemaRoot;
+import org.apache.arrow.vector.ipc.ArrowStreamReader;
+import org.apache.arrow.vector.ipc.ArrowStreamWriter;
+import org.apache.arrow.vector.types.pojo.Schema;
+import org.apache.spark.sql.catalyst.InternalRow;
+import org.apache.spark.sql.connector.write.BatchWrite;
+import org.apache.spark.sql.connector.write.DataWriter;
+import org.apache.spark.sql.connector.write.DataWriterFactory;
+import org.apache.spark.sql.connector.write.PhysicalWriteInfo;
+import org.apache.spark.sql.connector.write.WriterCommitMessage;
+import org.apache.spark.sql.types.StructType;
+import org.apache.spark.sql.util.LanceArrowUtils;
+
+import java.io.ByteArrayInputStream;
+import java.io.ByteArrayOutputStream;
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.List;
+import java.util.stream.Collectors;
+
+public class AddColumnsBackfillBatchWrite implements BatchWrite {
+ private final StructType schema;
+ private final LanceConfig config;
+ private final List<String> newColumns;
+
+ public AddColumnsBackfillBatchWrite(
+ StructType schema, LanceConfig config, List<String> newColumns) {
+ this.schema = schema;
+ this.config = config;
+ this.newColumns = newColumns;
+ }
+
+ @Override
+ public DataWriterFactory createBatchWriterFactory(PhysicalWriteInfo info) {
+ return new AddColumnsWriterFactory(schema, config, newColumns);
+ }
+
+ @Override
+ public boolean useCommitCoordinator() {
+ return false;
+ }
+
+ @Override
+ public void commit(WriterCommitMessage[] messages) {
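+ // Gather the merged fragment metadata reported by every task and commit a single Merge
+ // transaction that installs the new columns on the dataset.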
+ List<FragmentMetadata> fragments =
+ Arrays.stream(messages)
+ .map(m -> (TaskCommit) m)
+ .map(TaskCommit::getFragments)
+ .flatMap(List::stream)
+ .collect(Collectors.toList());
+
+ StructType sparkSchema = ((TaskCommit) messages[0]).schema;
+ Schema schema = LanceArrowUtils.toArrowSchema(sparkSchema, "UTC", false, false);
+ LanceDatasetAdapter.mergeFragments(config, fragments, schema);
+ }
+
+ public static class AddColumnsWriter implements DataWriter<InternalRow> {
+ private final LanceConfig config;
+ private final StructType schema;
+ private final int fragmentIdField;
+ private final List<FragmentMetadata> fragments;
+
+ private Schema mergedSchema;
+ private StructType writerSchema;
+ private int fragmentId = -1;
+ private VectorSchemaRoot data;
+ private com.lancedb.lance.spark.arrow.LanceArrowWriter writer = null;
+
+ public AddColumnsWriter(LanceConfig config, StructType schema, List<String> newColumns) {
+ this.config = config;
+ this.schema = schema;
+ this.fragmentIdField = schema.fieldIndex(LanceDataset.FRAGMENT_ID_COLUMN.name());
+ this.fragments = new ArrayList<>();
+
+ this.writerSchema = new StructType();
+ Arrays.stream(schema.fields())
+ .filter(
+ f ->
+ newColumns.contains(f.name())
+ || f.name().equals(LanceDataset.ROW_ADDRESS_COLUMN.name()))
+ .forEach(f -> writerSchema = writerSchema.add(f));
+
+ createWriter();
+ }
+
+ @Override
+ public void write(InternalRow record) throws IOException {
+ int fragId = record.getInt(fragmentIdField);
+
+ if (fragmentId == -1) {
+ fragmentId = fragId;
+ }
+
+ if (fragId != fragmentId && data != null) {
+ // Rows for a new fragment have started arriving; merge the buffered data for the current fragment first.
+ mergeFragment();
+
+ fragmentId = fragId;
+ createWriter();
+ }
+
+ for (int i = 0; i < writerSchema.fields().length; i++) {
+ writer.field(i).write(record, schema.fieldIndex(writerSchema.fields()[i].name()));
+ }
+ }
+
+ private void createWriter() {
+ data =
+ VectorSchemaRoot.create(
+ LanceArrowUtils.toArrowSchema(writerSchema, "UTC", false, false),
+ LanceDatasetAdapter.allocator);
+
+ writer = com.lancedb.lance.spark.arrow.LanceArrowWriter$.MODULE$.create(data, writerSchema);
+ }
+
+ private void mergeFragment() {
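+ // Flush the buffered rows, round-trip them through Arrow IPC to get an ArrowArrayStream,
+ // and merge them into the current fragment as new columns, joined on the row address.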
+ writer.finish();
+
+ ByteArrayOutputStream out = new ByteArrayOutputStream();
+ try (ArrowStreamWriter writer = new ArrowStreamWriter(data, null, out)) {
+ writer.start();
+ writer.writeBatch();
+ writer.end();
+ } catch (IOException e) {
+ throw new RuntimeException("Cannot write schema root", e);
+ }
+
+ byte[] arrowData = out.toByteArray();
+ ByteArrayInputStream in = new ByteArrayInputStream(arrowData);
+
+ try (ArrowStreamReader reader = new ArrowStreamReader(in, LanceDatasetAdapter.allocator);
+ ArrowArrayStream stream = ArrowArrayStream.allocateNew(LanceDatasetAdapter.allocator)) {
+ Data.exportArrayStream(LanceDatasetAdapter.allocator, reader, stream);
+
+ FragmentMergeResult result =
+ LanceDatasetAdapter.mergeFragmentColumn(
+ config,
+ fragmentId,
+ stream,
+ LanceDataset.ROW_ADDRESS_COLUMN.name(),
+ LanceDataset.ROW_ADDRESS_COLUMN.name());
+
+ fragments.add(result.getFragmentMetadata());
+ mergedSchema = result.getSchema().asArrowSchema();
+ } catch (Exception e) {
+ throw new RuntimeException("Cannot read arrow stream.", e);
+ }
+
+ data.close();
+ }
+
+ @Override
+ public WriterCommitMessage commit() {
+ if (fragmentId >= 0 && data != null) {
+ mergeFragment();
+ }
+ return new TaskCommit(fragments, LanceArrowUtils.fromArrowSchema(mergedSchema));
+ }
+
+ @Override
+ public void abort() {}
+
+ @Override
+ public void close() throws IOException {}
+ }
+
+ public static class AddColumnsWriterFactory implements DataWriterFactory {
+ private final LanceConfig config;
+ private final StructType schema;
+ private final List<String> newColumns;
+
+ protected AddColumnsWriterFactory(
+ StructType schema, LanceConfig config, List<String> newColumns) {
+ // Everything passed to writer factory should be serializable
+ this.schema = schema;
+ this.config = config;
+ this.newColumns = newColumns;
+ }
+
+ @Override
+ public DataWriter<InternalRow> createWriter(int partitionId, long taskId) {
+ return new AddColumnsWriter(config, schema, newColumns);
+ }
+ }
+
+ @Override
+ public void abort(WriterCommitMessage[] messages) {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public String toString() {
+ return String.format("AddColumnsBackfillBatchWrite(datasetUri=%s)", config.getDatasetUri());
+ }
+
+ public static class TaskCommit implements WriterCommitMessage {
+ private final List<FragmentMetadata> fragments;
+ private final StructType schema;
+
+ TaskCommit(List<FragmentMetadata> fragments, StructType schema) {
+ this.fragments = fragments;
+ this.schema = schema;
+ }
+
+ List<FragmentMetadata> getFragments() {
+ return fragments;
+ }
+ }
+}
diff --git a/lance-spark-base_2.12/src/main/java/com/lancedb/lance/spark/write/AddColumnsBackfillWrite.java b/lance-spark-base_2.12/src/main/java/com/lancedb/lance/spark/write/AddColumnsBackfillWrite.java
new file mode 100644
index 0000000..446dc15
--- /dev/null
+++ b/lance-spark-base_2.12/src/main/java/com/lancedb/lance/spark/write/AddColumnsBackfillWrite.java
@@ -0,0 +1,89 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package com.lancedb.lance.spark.write;
+
+import com.lancedb.lance.spark.LanceConfig;
+import com.lancedb.lance.spark.LanceConstant;
+
+import org.apache.spark.sql.connector.distributions.Distribution;
+import org.apache.spark.sql.connector.distributions.Distributions;
+import org.apache.spark.sql.connector.expressions.Expressions;
+import org.apache.spark.sql.connector.expressions.NamedReference;
+import org.apache.spark.sql.connector.expressions.NullOrdering;
+import org.apache.spark.sql.connector.expressions.SortDirection;
+import org.apache.spark.sql.connector.expressions.SortOrder;
+import org.apache.spark.sql.connector.expressions.SortValue;
+import org.apache.spark.sql.connector.write.BatchWrite;
+import org.apache.spark.sql.connector.write.RequiresDistributionAndOrdering;
+import org.apache.spark.sql.connector.write.Write;
+import org.apache.spark.sql.connector.write.WriteBuilder;
+import org.apache.spark.sql.connector.write.streaming.StreamingWrite;
+import org.apache.spark.sql.types.StructType;
+
+import java.util.List;
+
+/** Write implementation that backfills new columns into existing Lance fragments. */
+public class AddColumnsBackfillWrite implements Write, RequiresDistributionAndOrdering {
+ private final LanceConfig config;
+ private final StructType schema;
+ private final List<String> newColumns;
+
+ AddColumnsBackfillWrite(StructType schema, LanceConfig config, List<String> newColumns) {
+ this.schema = schema;
+ this.config = config;
+ this.newColumns = newColumns;
+ }
+
+ @Override
+ public BatchWrite toBatch() {
+ return new AddColumnsBackfillBatchWrite(schema, config, newColumns);
+ }
+
+ @Override
+ public StreamingWrite toStreaming() {
+ throw new UnsupportedOperationException();
+ }
+
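+ // Rows must be clustered by fragment id (so each task sees whole fragments) and sorted by
+ // row address, so that the new column values line up with the existing rows of a fragment.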
+ @Override
+ public Distribution requiredDistribution() {
+ NamedReference fragmentId = Expressions.column(LanceConstant.FRAGMENT_ID);
+ return Distributions.clustered(new NamedReference[] {fragmentId});
+ }
+
+ @Override
+ public SortOrder[] requiredOrdering() {
+ NamedReference rowAddress = Expressions.column(LanceConstant.ROW_ADDRESS);
+ SortValue sortValue =
+ new SortValue(rowAddress, SortDirection.ASCENDING, NullOrdering.NULLS_FIRST);
+ return new SortValue[] {sortValue};
+ }
+
+ /** Builder for the column-backfill write. */
+ public static class AddColumnsWriteBuilder implements WriteBuilder {
+ private final LanceConfig config;
+ private final StructType schema;
+ private final List<String> newColumns;
+
+ public AddColumnsWriteBuilder(StructType schema, LanceConfig config, List<String> newColumns) {
+ this.schema = schema;
+ this.config = config;
+ this.newColumns = newColumns;
+ }
+
+ @Override
+ public Write build() {
+ return new AddColumnsBackfillWrite(schema, config, newColumns);
+ }
+ }
+}
diff --git a/lance-spark-base_2.12/src/main/scala/com/lancedb/lance/spark/arrow/LanceArrowWriter.scala b/lance-spark-base_2.12/src/main/scala/com/lancedb/lance/spark/arrow/LanceArrowWriter.scala
index cca047c..04c5e57 100644
--- a/lance-spark-base_2.12/src/main/scala/com/lancedb/lance/spark/arrow/LanceArrowWriter.scala
+++ b/lance-spark-base_2.12/src/main/scala/com/lancedb/lance/spark/arrow/LanceArrowWriter.scala
@@ -81,6 +81,7 @@ object LanceArrowWriter {
case (ShortType, vector: SmallIntVector) => new ShortWriter(vector)
case (IntegerType, vector: IntVector) => new IntegerWriter(vector)
case (LongType, vector: BigIntVector) => new LongWriter(vector)
+ case (LongType, vector: UInt8Vector) => new UnsignedLongWriter(vector)
case (FloatType, vector: Float4Vector) => new FloatWriter(vector)
case (DoubleType, vector: Float8Vector) => new DoubleWriter(vector)
case (dt: DecimalType, vector: DecimalVector) =>
@@ -150,6 +151,8 @@ class LanceArrowWriter(root: VectorSchemaRoot, fields: Array[LanceArrowFieldWrit
fields.foreach(_.reset())
root.setRowCount(0)
}
+
+ def field(index: Int): LanceArrowFieldWriter = fields(index)
}
/**
@@ -233,6 +236,14 @@ private[arrow] class LongWriter(val valueVector: BigIntVector) extends LanceArrowFieldWriter {
}
}
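+
+// Writes Spark LongType values into an Arrow unsigned 64-bit vector (UInt8Vector),
+// used when the target Arrow field is unsigned (e.g. row addresses).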
+private[arrow] class UnsignedLongWriter(val valueVector: UInt8Vector)
+ extends LanceArrowFieldWriter {
+ override def setNull(): Unit = {}
+ override def setValue(input: SpecializedGetters, ordinal: Int): Unit = {
+ valueVector.setSafe(count, input.getLong(ordinal))
+ }
+}
+
private[arrow] class FloatWriter(val valueVector: Float4Vector) extends LanceArrowFieldWriter {
override def setNull(): Unit = {}
override def setValue(input: SpecializedGetters, ordinal: Int): Unit = {
diff --git a/lance-spark-base_2.12/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/AddColumnsBackfill.scala b/lance-spark-base_2.12/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/AddColumnsBackfill.scala
new file mode 100755
index 0000000..8ec6e87
--- /dev/null
+++ b/lance-spark-base_2.12/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/AddColumnsBackfill.scala
@@ -0,0 +1,46 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.catalyst.plans.logical
+
+import org.apache.spark.sql.catalyst.expressions.Attribute
+
+/**
+ * Logical plan node representing the ALTER TABLE ADD COLUMNS FROM TABLE/VIEW command.
+ *
+ * This command adds new columns to an existing table by computing their values
+ * from a source table or view. The source is scanned and its rows are used
+ * to populate the new columns.
+ *
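+ * Example: ALTER TABLE target_table ADD COLUMNS new_col1, new_col2 FROM source_view
+ *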
+ * @param table The target table to add columns to
+ * @param columnNames The names of the new columns to add
+ * @param source The source table or view that provides the values for the new columns
+ */
+case class AddColumnsBackfill(
+ table: LogicalPlan,
+ columnNames: Seq[String],
+ source: LogicalPlan) extends Command {
+
+ override def children: Seq[LogicalPlan] = Seq(table, source)
+
+ override def output: Seq[Attribute] = Seq.empty
+
+ override protected def withNewChildrenInternal(
+ newChildren: IndexedSeq[LogicalPlan]): AddColumnsBackfill = {
+ copy(table = newChildren(0), source = newChildren(1))
+ }
+
+ override def simpleString(maxFields: Int): String = {
+ s"AddColumnsBackfill columns=[${columnNames.mkString(", ")}]"
+ }
+}
diff --git a/lance-spark-base_2.12/src/main/scala/org/apache/spark/sql/execution/datasources/v2/AddColumnsBackfillExec.scala b/lance-spark-base_2.12/src/main/scala/org/apache/spark/sql/execution/datasources/v2/AddColumnsBackfillExec.scala
new file mode 100644
index 0000000..277a74e
--- /dev/null
+++ b/lance-spark-base_2.12/src/main/scala/org/apache/spark/sql/execution/datasources/v2/AddColumnsBackfillExec.scala
@@ -0,0 +1,74 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.execution.datasources.v2
+
+import com.lancedb.lance.spark.{LanceConstant, LanceDataset}
+
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.Attribute
+import org.apache.spark.sql.catalyst.plans.logical.{AppendData, LogicalPlan, Project}
+import org.apache.spark.sql.connector.catalog._
+
+case class AddColumnsBackfillExec(
+ catalog: TableCatalog,
+ ident: Identifier,
+ columnNames: Seq[String],
+ query: LogicalPlan)
+ extends LeafV2CommandExec {
+
+ override def output: Seq[Attribute] = Seq.empty
+
+ override protected def run(): Seq[InternalRow] = {
+ val originalTable = catalog.loadTable(ident) match {
+ case lanceTable: LanceDataset => lanceTable
+ case _ =>
+ throw new UnsupportedOperationException("AddColumnsBackfill is only supported for LanceDataset")
+ }
+
+ // The columns being added must not already exist in the target table
+ val originalFields = originalTable.schema().fieldNames.toSet
+ val existingFields = columnNames.filter(p => originalFields.contains(p))
+ if (existingFields.nonEmpty) {
+ throw new IllegalArgumentException(
+ s"Cannot add columns that already exist: ${existingFields.mkString(", ")}")
+ }
+
+ // Add a Project if the source relation carries more fields than needed
+ val needFields = query.output.filter(p =>
+ columnNames.contains(p.name)
+ || LanceDataset.ROW_ADDRESS_COLUMN.name().equals(p.name)
+ || LanceDataset.FRAGMENT_ID_COLUMN.name().equals(p.name))
+
+ val actualQuery = if (needFields.length != query.output.length) {
+ Project(needFields, query)
+ } else {
+ query
+ }
+
+ val relation = DataSourceV2Relation.create(
+ new LanceDataset(originalTable.config(), actualQuery.schema),
+ Some(catalog),
+ Some(ident))
+
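+ // Append the new column values by position; BACKFILL_COLUMNS_KEY marks which
+ // columns are being backfilled (presumably letting the Lance writer align values
+ // with existing rows through the _rowaddr/_fragid columns carried by the query)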
+ val append =
+ AppendData.byPosition(
+ relation,
+ actualQuery,
+ Map(LanceConstant.BACKFILL_COLUMNS_KEY -> columnNames.mkString(",")))
+ val qe = session.sessionState.executePlan(append)
+ qe.assertCommandExecuted()
+
+ Nil
+ }
+}
diff --git a/lance-spark-base_2.12/src/main/scala/org/apache/spark/sql/execution/datasources/v2/LanceDataSourceV2Strategy.scala b/lance-spark-base_2.12/src/main/scala/org/apache/spark/sql/execution/datasources/v2/LanceDataSourceV2Strategy.scala
new file mode 100644
index 0000000..605787e
--- /dev/null
+++ b/lance-spark-base_2.12/src/main/scala/org/apache/spark/sql/execution/datasources/v2/LanceDataSourceV2Strategy.scala
@@ -0,0 +1,40 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.execution.datasources.v2
+
+import org.apache.spark.sql.SparkSession
+import org.apache.spark.sql.catalyst.analysis.ResolvedIdentifier
+import org.apache.spark.sql.catalyst.expressions.PredicateHelper
+import org.apache.spark.sql.catalyst.plans.logical._
+import org.apache.spark.sql.connector.catalog._
+import org.apache.spark.sql.execution.{SparkPlan, SparkStrategy}
+
+case class LanceDataSourceV2Strategy(session: SparkSession) extends SparkStrategy
+ with PredicateHelper {
+
+ override def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
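+ // Plan the Lance-specific ALTER TABLE ... ADD COLUMNS ... FROM command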
+ case AddColumnsBackfill(ResolvedIdentifier(catalog, ident), columnNames, source) =>
+ AddColumnsBackfillExec(asTableCatalog(catalog), ident, columnNames, source) :: Nil
+
+ case _ => Nil
+ }
+
+ private def asTableCatalog(plugin: CatalogPlugin): TableCatalog = {
+ plugin match {
+ case t: TableCatalog => t
+ case _ => throw new IllegalArgumentException(s"Catalog $plugin is not a TableCatalog")
+ }
+ }
+
+}
diff --git a/lance-spark-base_2.12/src/main/scala/org/apache/spark/sql/util/LanceArrowUtils.scala b/lance-spark-base_2.12/src/main/scala/org/apache/spark/sql/util/LanceArrowUtils.scala
index 0775b50..7643041 100644
--- a/lance-spark-base_2.12/src/main/scala/org/apache/spark/sql/util/LanceArrowUtils.scala
+++ b/lance-spark-base_2.12/src/main/scala/org/apache/spark/sql/util/LanceArrowUtils.scala
@@ -130,7 +130,9 @@ object LanceArrowUtils {
}
implicit val formats: Formats = DefaultFormats
- meta = metadata.jsonValue.extract[Map[String, String]]
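+ // Metadata values may be non-string JSON values, so coerce each value to a string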
+ meta = metadata.jsonValue.extract[Map[String, Object]].map { case (k, v) =>
+ (k, String.valueOf(v))
+ }
}
dt match {
@@ -196,7 +198,8 @@ object LanceArrowUtils {
case ByteType => new ArrowType.Int(8, true)
case ShortType => new ArrowType.Int(8 * 2, true)
case IntegerType => new ArrowType.Int(8 * 4, true)
- case LongType if name.equals(LanceConstant.ROW_ID) => new ArrowType.Int(8 * 8, false)
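+ // Row id and row address columns map to unsigned 64-bit Arrow integers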
+ case LongType if name.equals(LanceConstant.ROW_ID) || name.equals(LanceConstant.ROW_ADDRESS) =>
+ new ArrowType.Int(8 * 8, false)
case LongType => new ArrowType.Int(8 * 8, true)
case FloatType => new ArrowType.FloatingPoint(FloatingPointPrecision.SINGLE)
case DoubleType => new ArrowType.FloatingPoint(FloatingPointPrecision.DOUBLE)
diff --git a/lance-spark-base_2.12/src/test/java/com/lancedb/lance/spark/update/BaseAddColumnsBackfillTest.java b/lance-spark-base_2.12/src/test/java/com/lancedb/lance/spark/update/BaseAddColumnsBackfillTest.java
new file mode 100644
index 0000000..421f08f
--- /dev/null
+++ b/lance-spark-base_2.12/src/test/java/com/lancedb/lance/spark/update/BaseAddColumnsBackfillTest.java
@@ -0,0 +1,161 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package com.lancedb.lance.spark.update;
+
+import org.apache.spark.sql.Dataset;
+import org.apache.spark.sql.Row;
+import org.apache.spark.sql.SparkSession;
+import org.apache.spark.sql.functions;
+import org.junit.jupiter.api.AfterEach;
+import org.junit.jupiter.api.Assertions;
+import org.junit.jupiter.api.BeforeEach;
+import org.junit.jupiter.api.Test;
+import org.junit.jupiter.api.io.TempDir;
+
+import java.io.IOException;
+import java.nio.file.Path;
+import java.util.stream.Collectors;
+import java.util.stream.IntStream;
+
+import static org.junit.jupiter.api.Assertions.assertEquals;
+
+public abstract class BaseAddColumnsBackfillTest {
+ protected String catalogName = "lance_test";
+ protected String tableName = "add_column_backfill";
+ protected String fullTable = catalogName + ".default." + tableName;
+
+ protected SparkSession spark;
+
+ @TempDir Path tempDir;
+
+ @BeforeEach
+ public void setup() throws IOException {
+ spark =
+ SparkSession.builder()
+ .appName("dataframe-addcolumn-test")
+ .master("local[2]")
+ .config(
+ "spark.sql.catalog." + catalogName,
+ "com.lancedb.lance.spark.LanceNamespaceSparkCatalog")
+ .config(
+ "spark.sql.extensions",
+ "com.lancedb.lance.spark.extensions.LanceSparkSessionExtensions")
+ .config("spark.sql.catalog." + catalogName + ".impl", "dir")
+ .config("spark.sql.catalog." + catalogName + ".root", tempDir.toString())
+ .getOrCreate();
+ }
+
+ @AfterEach
+ public void tearDown() throws IOException {
+ if (spark != null) {
+ spark.close();
+ }
+ }
+
+ protected void prepareDataset() {
+ spark.sql(String.format("create table %s (id int, text string) using lance;", fullTable));
+ spark.sql(
+ String.format(
+ "insert into %s (id, text) values %s ;",
+ fullTable,
+ IntStream.range(0, 10)
+ .boxed()
+ .map(i -> String.format("(%d, 'text_%d')", i, i))
+ .collect(Collectors.joining(","))));
+ }
+
+ @Test
+ public void testWithDataFrame() {
+ prepareDataset();
+
+ // Read back and verify
+ Dataset<Row> result = spark.table(fullTable);
+ assertEquals(10, result.count(), "Should have 10 rows");
+
+ result = result.select("_rowaddr", "_fragid", "id");
+
+ // Add new column
+ Dataset<Row> df2 =
+ result
+ .withColumn("new_col1", functions.expr("id * 100"))
+ .withColumn("new_col2", functions.expr("id * 2"));
+
+ df2.createOrReplaceTempView("tmp_view");
+ spark.sql(
+ String.format("alter table %s add columns new_col1, new_col2 from tmp_view", fullTable));
+
+ Assertions.assertEquals(
+ "[[0,0,0,text_0], [1,100,2,text_1], [2,200,4,text_2], [3,300,6,text_3], [4,400,8,text_4], [5,500,10,text_5], [6,600,12,text_6], [7,700,14,text_7], [8,800,16,text_8], [9,900,18,text_9]]",
+ spark
+ .table(fullTable)
+ .select("id", "new_col1", "new_col2", "text")
+ .collectAsList()
+ .toString());
+ }
+
+ @Test
+ public void testWithSql() {
+ prepareDataset();
+
+ spark.sql(
+ String.format(
+ "create temporary view tmp_view as select _rowaddr, _fragid, id * 100 as new_col1, id * 2 as new_col2, id * 3 as new_col3 from %s;",
+ fullTable));
+ spark.sql(
+ String.format("alter table %s add columns new_col1, new_col2 from tmp_view", fullTable));
+
+ Assertions.assertEquals(
+ "[[0,0,0,text_0], [1,100,2,text_1], [2,200,4,text_2], [3,300,6,text_3], [4,400,8,text_4], [5,500,10,text_5], [6,600,12,text_6], [7,700,14,text_7], [8,800,16,text_8], [9,900,18,text_9]]",
+ spark
+ .sql(String.format("select id, new_col1, new_col2, text from %s", fullTable))
+ .collectAsList()
+ .toString());
+ }
+
+ @Test
+ public void testAddExistedColumns() {
+ prepareDataset();
+
+ spark.sql(
+ String.format(
+ "create temporary view tmp_view as select _rowaddr, _fragid, id * 100 as id, id * 2 as new_col2 from %s;",
+ fullTable));
+ Assertions.assertThrows(
+ IllegalArgumentException.class,
+ () ->
+ spark.sql(
+ String.format("alter table %s add columns id, new_col2 from tmp_view", fullTable)),
+ "Can't add existed columns: id");
+ }
+
+ @Test
+ public void testAddRowsNotAligned() {
+ prepareDataset();
+
+ // Add a new nullable String column from a view that covers only a subset of rows;
+ // rows without a matching source record are backfilled with null
+ spark.sql(
+ String.format(
+ "create temporary view tmp_view as select _rowaddr, _fragid, concat('new_col_1_', id) as new_col1 from %s where id in (0, 1, 4, 8, 9);",
+ fullTable));
+ spark.sql(String.format("alter table %s add columns new_col1 from tmp_view", fullTable));
+
+ Assertions.assertEquals(
+ "[[0,new_col_1_0,text_0], [1,new_col_1_1,text_1], [2,null,text_2], [3,null,text_3], [4,new_col_1_4,text_4], [5,null,text_5], [6,null,text_6], [7,null,text_7], [8,new_col_1_8,text_8], [9,new_col_1_9,text_9]]",
+ spark
+ .sql(String.format("select id, new_col1, text from %s", fullTable))
+ .collectAsList()
+ .toString());
+ }
+}
diff --git a/pom.xml b/pom.xml
index f37d5e2..78bcd9a 100644
--- a/pom.xml
+++ b/pom.xml
@@ -51,9 +51,11 @@
0.0.14
- 0.35.0
+ 0.37.0
0.0.19
+ 4.9.3
+
3.4.4
3.5.5
4.0.0
@@ -371,6 +373,11 @@
+
+ org.antlr
+ antlr4-maven-plugin
+ ${antlr4.version}
+
maven-jar-plugin
${maven-jar-plugin.version}