[SPARK-49369][CONNECT][SQL] Add implicit Column conversions
### What changes were proposed in this pull request?
This introduces an implicit conversion for the Column companion object that allows a user/developer to create a Column from a Catalyst Expression (for Classic) or a proto Expression / Expression.Builder (for Connect). This mostly recreates what was available before the Column API refactor, at the price of adding an import.
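
As a minimal sketch of the Classic side, assuming a trivial Catalyst `Literal` expression (the names and values below are illustrative only):

```scala
import org.apache.spark.sql.Column
import org.apache.spark.sql.catalyst.expressions.Literal
// The implicit ColumnConstructorExt re-enables Column(<Catalyst Expression>).
import org.apache.spark.sql.classic.ClassicConversions._

// Construct a Column directly from a Catalyst Expression, as was possible before the refactor.
val answer: Column = Column(Literal(42)).as("answer")
// The result is an ordinary Column, e.g. usable in df.select(answer).
```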

### Why are the changes needed?
Improved upgrade experience for developers and users who create their own Columns from expressions.
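
On the Connect side a plugin could do something along these lines (a sketch only; the `integer` literal field used on the proto builder is an assumption about the Spark Connect protobuf schema):

```scala
import org.apache.spark.connect.proto
import org.apache.spark.sql.Column
// Brings in Column(proto.Expression) plus the column(...) helpers.
import org.apache.spark.sql.connect.ConnectConversions._

// Variant 1: build the proto Expression up front and wrap it in a Column.
val builder = proto.Expression.newBuilder()
builder.getLiteralBuilder.setInteger(42)
val c1: Column = Column(builder.build())

// Variant 2: let the column(...) helper hand us a fresh builder.
val c2: Column = column { b =>
  b.getLiteralBuilder.setInteger(42)
}
```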

### Does this PR introduce _any_ user-facing change?
No.

### How was this patch tested?
I applied the new conversion in a couple of places in the code (for example Dataset and DataFrameNaFunctions) and it works.

### Was this patch authored or co-authored using generative AI tooling?
No.

Closes #48020 from hvanhovell/SPARK-49369.

Authored-by: Herman van Hovell <[email protected]>
Signed-off-by: Herman van Hovell <[email protected]>
hvanhovell committed Sep 25, 2024
1 parent 5fb0ff9 commit 0c234bb
Showing 7 changed files with 76 additions and 51 deletions.
@@ -19,13 +19,15 @@ package org.apache.spark.sql.connect
import scala.language.implicitConversions

import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.connect.proto
import org.apache.spark.sql._
import org.apache.spark.sql.internal.ProtoColumnNode

/**
* Conversions from sql interfaces to the Connect specific implementation.
*
* This class is mainly used by the implementation. In the case of connect it should be extremely
* rare that a developer needs these classes.
* This class is mainly used by the implementation. It is also meant to be used by extension
* developers.
*
* We provide both a trait and an object. The trait is useful in situations where an extension
* developer needs to use these conversions in a project covering multiple Spark versions. They
@@ -46,6 +48,40 @@ trait ConnectConversions {
implicit def castToImpl[K, V](
kvds: api.KeyValueGroupedDataset[K, V]): KeyValueGroupedDataset[K, V] =
kvds.asInstanceOf[KeyValueGroupedDataset[K, V]]

/**
* Create a [[Column]] from a [[proto.Expression]]
*
* This method is meant to be used by Connect plugins. We do not guarantee any compatibility
* between (minor) versions.
*/
@DeveloperApi
def column(expr: proto.Expression): Column = {
Column(ProtoColumnNode(expr))
}

/**
* Create a [[Column]] using a function that manipulates an [[proto.Expression.Builder]].
*
* This method is meant to be used by Connect plugins. We do not guarantee any compatibility
* between (minor) versions.
*/
@DeveloperApi
def column(f: proto.Expression.Builder => Unit): Column = {
val builder = proto.Expression.newBuilder()
f(builder)
column(builder.build())
}

/**
* Implicit helper that makes it easy to construct a Column from an Expression or an Expression
* builder. This allows developers to create a Column in the same way as in earlier versions of
* Spark (before 4.0).
*/
@DeveloperApi
implicit class ColumnConstructorExt(val c: Column.type) {
def apply(e: proto.Expression): Column = column(e)
}
}

object ConnectConversions extends ConnectConversions
@@ -17,39 +17,12 @@

package org.apache.spark

import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.connect.proto
import org.apache.spark.sql.catalyst.encoders.AgnosticEncoder
import org.apache.spark.sql.internal.ProtoColumnNode

package object sql {
type DataFrame = Dataset[Row]

private[sql] def encoderFor[E: Encoder]: AgnosticEncoder[E] = {
implicitly[Encoder[E]].asInstanceOf[AgnosticEncoder[E]]
}

/**
* Create a [[Column]] from a [[proto.Expression]]
*
* This method is meant to be used by Connect plugins. We do not guarantee any compatility
* between (minor) versions.
*/
@DeveloperApi
def column(expr: proto.Expression): Column = {
Column(ProtoColumnNode(expr))
}

/**
* Creat a [[Column]] using a function that manipulates an [[proto.Expression.Builder]].
*
* This method is meant to be used by Connect plugins. We do not guarantee any compatility
* between (minor) versions.
*/
@DeveloperApi
def column(f: proto.Expression.Builder => Unit): Column = {
val builder = proto.Expression.newBuilder()
f(builder)
column(builder.build())
}
}
@@ -37,6 +37,7 @@ import org.apache.spark.sql.avro.{functions => avroFn}
import org.apache.spark.sql.catalyst.ScalaReflection
import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.StringEncoder
import org.apache.spark.sql.catalyst.util.CollationFactory
import org.apache.spark.sql.connect.ConnectConversions._
import org.apache.spark.sql.connect.client.SparkConnectClient
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions.lit
@@ -24,7 +24,6 @@ import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.classic.ClassicConversions._
import org.apache.spark.sql.errors.QueryExecutionErrors
import org.apache.spark.sql.functions._
import org.apache.spark.sql.internal.ExpressionUtils.column
import org.apache.spark.sql.types._

/**
@@ -122,7 +121,7 @@ final class DataFrameNaFunctions private[sql](df: DataFrame)
(attr.dataType.isInstanceOf[NumericType] && targetColumnType == DoubleType))) {
replaceCol(attr, replacementMap)
} else {
column(attr)
Column(attr)
}
}
df.select(projections : _*)
@@ -131,7 +130,7 @@
protected def fillMap(values: Seq[(String, Any)]): DataFrame = {
// Error handling
val attrToValue = AttributeMap(values.map { case (colName, replaceValue) =>
// Check column name exists
// Check Column name exists
val attr = df.resolve(colName) match {
case a: Attribute => a
case _ => throw QueryExecutionErrors.nestedFieldUnsupportedError(colName)
@@ -155,7 +154,7 @@
case v: jl.Integer => fillCol[Integer](attr, v)
case v: jl.Boolean => fillCol[Boolean](attr, v.booleanValue())
case v: String => fillCol[String](attr, v)
}.getOrElse(column(attr))
}.getOrElse(Column(attr))
}
df.select(projections : _*)
}
@@ -165,7 +164,7 @@
* with `replacement`.
*/
private def fillCol[T](attr: Attribute, replacement: T): Column = {
fillCol(attr.dataType, attr.name, column(attr), replacement)
fillCol(attr.dataType, attr.name, Column(attr), replacement)
}

/**
@@ -192,7 +191,7 @@
val branches = replacementMap.flatMap { case (source, target) =>
Seq(Literal(source), buildExpr(target))
}.toSeq
column(CaseKeyWhen(attr, branches :+ attr)).as(attr.name)
Column(CaseKeyWhen(attr, branches :+ attr)).as(attr.name)
}

private def convertToDouble(v: Any): Double = v match {
@@ -219,7 +218,7 @@
// Filtering condition:
// only keep the row if it has at least `minNonNulls` non-null and non-NaN values.
val predicate = AtLeastNNonNulls(minNonNulls.getOrElse(cols.size), cols)
df.filter(column(predicate))
df.filter(Column(predicate))
}

private[sql] def fillValue(value: Any, cols: Option[Seq[String]]): DataFrame = {
@@ -255,9 +254,9 @@
}
// Only fill if the column is part of the cols list.
if (typeMatches && cols.exists(_.semanticEquals(col))) {
fillCol(col.dataType, col.name, column(col), value)
fillCol(col.dataType, col.name, Column(col), value)
} else {
column(col)
Column(col)
}
}
df.select(projections : _*)
29 changes: 18 additions & 11 deletions sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -63,7 +63,6 @@ import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2Relation, Data
import org.apache.spark.sql.execution.python.EvaluatePython
import org.apache.spark.sql.execution.stat.StatFunctions
import org.apache.spark.sql.internal.{DataFrameWriterImpl, DataFrameWriterV2Impl, MergeIntoWriterImpl, SQLConf}
import org.apache.spark.sql.internal.ExpressionUtils.column
import org.apache.spark.sql.internal.TypedAggUtils.withInputType
import org.apache.spark.sql.streaming.DataStreamWriter
import org.apache.spark.sql.types._
@@ -303,7 +302,7 @@ class Dataset[T] private[sql](
truncate: Int): Seq[Seq[String]] = {
val newDf = commandResultOptimized.toDF()
val castCols = newDf.logicalPlan.output.map { col =>
column(ToPrettyString(col))
Column(ToPrettyString(col))
}
val data = newDf.select(castCols: _*).take(numRows + 1)

@@ -505,7 +504,7 @@ class Dataset[T] private[sql](
s"New column names (${colNames.size}): " + colNames.mkString(", "))

val newCols = logicalPlan.output.zip(colNames).map { case (oldAttribute, newName) =>
column(oldAttribute).as(newName)
Column(oldAttribute).as(newName)
}
select(newCols : _*)
}
@@ -760,18 +759,18 @@ class Dataset[T] private[sql](
/** @inheritdoc */
def col(colName: String): Column = colName match {
case "*" =>
column(ResolvedStar(queryExecution.analyzed.output))
Column(ResolvedStar(queryExecution.analyzed.output))
case _ =>
if (sparkSession.sessionState.conf.supportQuotedRegexColumnName) {
colRegex(colName)
} else {
column(addDataFrameIdToCol(resolve(colName)))
Column(addDataFrameIdToCol(resolve(colName)))
}
}

/** @inheritdoc */
def metadataColumn(colName: String): Column =
column(queryExecution.analyzed.getMetadataAttributeByName(colName))
Column(queryExecution.analyzed.getMetadataAttributeByName(colName))

// Attach the dataset id and column position to the column reference, so that we can detect
// ambiguous self-join correctly. See the rule `DetectAmbiguousSelfJoin`.
@@ -797,11 +796,11 @@
val caseSensitive = sparkSession.sessionState.conf.caseSensitiveAnalysis
colName match {
case ParserUtils.escapedIdentifier(columnNameRegex) =>
column(UnresolvedRegex(columnNameRegex, None, caseSensitive))
Column(UnresolvedRegex(columnNameRegex, None, caseSensitive))
case ParserUtils.qualifiedEscapedIdentifier(nameParts, columnNameRegex) =>
column(UnresolvedRegex(columnNameRegex, Some(nameParts), caseSensitive))
Column(UnresolvedRegex(columnNameRegex, Some(nameParts), caseSensitive))
case _ =>
column(addDataFrameIdToCol(resolve(colName)))
Column(addDataFrameIdToCol(resolve(colName)))
}
}

@@ -1194,7 +1193,7 @@ class Dataset[T] private[sql](
resolver(field.name, colName)
} match {
case Some((colName: String, col: Column)) => col.as(colName)
case _ => column(field)
case _ => Column(field)
}
}

@@ -1264,7 +1263,7 @@ class Dataset[T] private[sql](
val allColumns = queryExecution.analyzed.output
val remainingCols = allColumns.filter { attribute =>
colNames.forall(n => !resolver(attribute.name, n))
}.map(attribute => column(attribute))
}.map(attribute => Column(attribute))
if (remainingCols.size == allColumns.size) {
toDF()
} else {
@@ -1975,6 +1974,14 @@ class Dataset[T] private[sql](
// For Python API
////////////////////////////////////////////////////////////////////////////

/**
* It adds a new long column with the name `name` that increases one by one.
* This is for 'distributed-sequence' default index in pandas API on Spark.
*/
private[sql] def withSequenceColumn(name: String) = {
select(Column(DistributedSequenceID()).alias(name), col("*"))
}

/**
* Converts a JavaRDD to a PythonRDD.
*/
@@ -20,11 +20,13 @@ import scala.language.implicitConversions

import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.sql._
import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.internal.ExpressionUtils

/**
* Conversions from sql interfaces to the Classic specific implementation.
*
* This class is mainly used by the implementation, but is also meant to be used by extension
* This class is mainly used by the implementation. It is also meant to be used by extension
* developers.
*
* We provide both a trait and an object. The trait is useful in situations where an extension
@@ -45,6 +47,13 @@ trait ClassicConversions {

implicit def castToImpl[K, V](kvds: api.KeyValueGroupedDataset[K, V])
: KeyValueGroupedDataset[K, V] = kvds.asInstanceOf[KeyValueGroupedDataset[K, V]]

/**
* Helper that makes it easy to construct a Column from an Expression.
*/
implicit class ColumnConstructorExt(val c: Column.type) {
def apply(e: Expression): Column = ExpressionUtils.column(e)
}
}

object ClassicConversions extends ClassicConversions
@@ -84,7 +84,7 @@ class RuntimeConfigImpl private[sql](val sqlConf: SQLConf = new SQLConf) extends
sqlConf.contains(key)
}

private def requireNonStaticConf(key: String): Unit = {
private[sql] def requireNonStaticConf(key: String): Unit = {
if (SQLConf.isStaticConfigKey(key)) {
throw QueryCompilationErrors.cannotModifyValueOfStaticConfigError(key)
}