Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,8 @@ class StateDataSource extends TableProvider with DataSourceRegister with Logging
stateStoreReaderInfo.stateStoreColFamilySchemaOpt,
stateStoreReaderInfo.stateSchemaProviderOpt,
stateStoreReaderInfo.joinColFamilyOpt,
stateStoreReaderInfo.allColumnFamiliesReaderInfo)
stateStoreReaderInfo.allColumnFamiliesReaderInfo,
stateStoreReaderInfo.joinStateFormatVersion)
}

override def inferSchema(options: CaseInsensitiveStringMap): StructType = {
Expand Down Expand Up @@ -162,9 +163,9 @@ class StateDataSource extends TableProvider with DataSourceRegister with Logging

/**
* Returns true if this is a read-all-column-families request for a stream-stream join
* that uses virtual column families (state format version 3).
* that uses virtual column families (state format version >= 3).
*/
private def isReadAllColFamiliesOnJoinV3(
private def isReadAllColFamiliesOnJoinWithVCF(
sourceOptions: StateSourceOptions,
storeMetadata: Array[StateMetadataTableEntry]): Boolean = {
sourceOptions.internalOnlyReadAllColumnFamilies &&
Expand Down Expand Up @@ -243,9 +244,9 @@ class StateDataSource extends TableProvider with DataSourceRegister with Logging
opMetadata.operatorName match {
case opName: String if opName ==
StatefulOperatorsUtils.SYMMETRIC_HASH_JOIN_EXEC_OP_NAME =>
// Verify that the store name is valid
val possibleStoreNames = SymmetricHashJoinStateManager.allStateStoreNames(
LeftSide, RightSide)
LeftSide, RightSide) ++
SymmetricHashJoinStateManager.allStateStoreNamesV4(LeftSide, RightSide)
if (!possibleStoreNames.contains(name)) {
val errorMsg = s"Store name $name not allowed for join operator. Allowed names are " +
s"$possibleStoreNames. " +
Expand Down Expand Up @@ -393,7 +394,7 @@ class StateDataSource extends TableProvider with DataSourceRegister with Logging
// However, Join V3 does not have a "default" column family. Therefore, we pick the first
// schema as resultSchema which will be used as placeholder schema for default schema
// in StatePartitionAllColumnFamiliesReader
val resultSchema = if (isReadAllColFamiliesOnJoinV3(sourceOptions, storeMetadata)) {
val resultSchema = if (isReadAllColFamiliesOnJoinWithVCF(sourceOptions, storeMetadata)) {
stateSchema.head
} else {
stateSchema.filter(_.colFamilyName == stateVarName).head
Expand All @@ -408,17 +409,18 @@ class StateDataSource extends TableProvider with DataSourceRegister with Logging
}
}

// Resolve the state format version once here so it can be reused both for the
// all-column-families reader info below and for the StateStoreReaderInfo result.
// NOTE(review): name suggests this is join-specific, but getStateFormatVersion is
// called unconditionally — confirm it is a no-op/None for non-join operators.
val joinFormatVersion = getStateFormatVersion(
storeMetadata,
sourceOptions.resolvedCpLocation,
sourceOptions.batchId
)

val allColFamilyReaderInfoOpt: Option[AllColumnFamiliesReaderInfo] =
if (sourceOptions.internalOnlyReadAllColumnFamilies) {
assert(storeMetadata.nonEmpty, "storeMetadata shouldn't be empty")
val operatorName = storeMetadata.head.operatorName
val stateFormatVersion = getStateFormatVersion(
storeMetadata,
sourceOptions.resolvedCpLocation,
sourceOptions.batchId
)
Some(AllColumnFamiliesReaderInfo(
stateStoreColFamilySchemas, stateVariableInfos, operatorName, stateFormatVersion))
stateStoreColFamilySchemas, stateVariableInfos, operatorName, joinFormatVersion))
} else {
None
}
Expand All @@ -428,7 +430,8 @@ class StateDataSource extends TableProvider with DataSourceRegister with Logging
transformWithStateVariableInfoOpt,
stateSchemaProvider,
joinColFamilyOpt,
allColFamilyReaderInfoOpt
allColFamilyReaderInfoOpt,
joinFormatVersion
)
}

Expand Down Expand Up @@ -819,7 +822,8 @@ case class StateStoreReaderInfo(
stateSchemaProviderOpt: Option[StateSchemaProvider],
joinColFamilyOpt: Option[String], // Only used for join op with state format v3
// List of all column family schemas - used when internalOnlyReadAllColumnFamilies=true
allColumnFamiliesReaderInfo: Option[AllColumnFamiliesReaderInfo]
allColumnFamiliesReaderInfo: Option[AllColumnFamiliesReaderInfo],
joinStateFormatVersion: Option[Int] = None
)

object StateDataSource {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, UnsafeRow}
import org.apache.spark.sql.connector.read.{InputPartition, PartitionReader, PartitionReaderFactory}
import org.apache.spark.sql.execution.datasources.v2.state.utils.SchemaUtil
import org.apache.spark.sql.execution.streaming.operators.stateful.{StatefulOperatorsUtils, StatePartitionKeyExtractorFactory}
import org.apache.spark.sql.execution.streaming.operators.stateful.join.StreamingSymmetricHashJoinHelper.{LeftSide, RightSide}
import org.apache.spark.sql.execution.streaming.operators.stateful.join.SymmetricHashJoinStateManager
import org.apache.spark.sql.execution.streaming.operators.stateful.transformwithstate.{StateStoreColumnFamilySchemaUtils, StateVariableType, TransformWithStateVariableInfo}
import org.apache.spark.sql.execution.streaming.state._
Expand Down Expand Up @@ -138,6 +139,10 @@ abstract class StatePartitionReaderBase(
getStoreUniqueId(partition.sourceOptions.endOperatorStateUniqueIds)
}

// True when the column family targeted by this read is one of the stream-stream
// join v4 state stores. Those column families keep multiple values per key, so
// the store provider below must be created with multi-value support enabled.
protected val isJoinV4MultiValuedCF: Boolean = joinColFamilyOpt match {
  case Some(cf) =>
    SymmetricHashJoinStateManager.allStateStoreNamesV4(LeftSide, RightSide).contains(cf)
  case None =>
    false
}

protected lazy val provider: StateStoreProvider = {
val stateStoreId = StateStoreId(partition.sourceOptions.stateCheckpointLocation.toString,
partition.sourceOptions.operatorId, partition.partition, partition.sourceOptions.storeName)
Expand All @@ -146,7 +151,7 @@ abstract class StatePartitionReaderBase(
val useColFamilies = stateVariableInfoOpt.isDefined || joinColFamilyOpt.isDefined

val useMultipleValuesPerKey = SchemaUtil.checkVariableType(stateVariableInfoOpt,
StateVariableType.ListState)
StateVariableType.ListState) || isJoinV4MultiValuedCF

val provider = StateStoreProvider.createAndInit(
stateStoreProviderId, keySchema, valueSchema, keyStateEncoderSpec,
Expand Down Expand Up @@ -249,6 +254,12 @@ class StatePartitionReader(
val stateVarType = stateVariableInfo.stateVariableType
SchemaUtil.processStateEntries(stateVarType, colFamilyName, store,
keySchema, partition.partition, partition.sourceOptions)
} else if (isJoinV4MultiValuedCF) {
store
.iteratorWithMultiValues(colFamilyName)
.map { pair =>
SchemaUtil.unifyStateRowPair((pair.key, pair.value), partition.partition)
}
} else {
store
.iterator(colFamilyName)
Expand Down Expand Up @@ -333,6 +344,13 @@ class StatePartitionAllColumnFamiliesReader(
StateVariableType.ListState)
}

// Names of the column families used by stream-stream join state format v4.
private val v4JoinCFNames: Set[String] =
  SymmetricHashJoinStateManager.allStateStoreNamesV4(LeftSide, RightSide).toSet

// A column family stores multiple values per key when it either backs a
// list-type state variable or is one of the v4 join column families.
private def isMultiValuedCF(colFamilyName: String): Boolean =
  isListType(colFamilyName) || v4JoinCFNames(colFamilyName)

override protected lazy val provider: StateStoreProvider = {
val stateStoreId = StateStoreId(partition.sourceOptions.stateCheckpointLocation.toString,
partition.sourceOptions.operatorId, partition.partition, partition.sourceOptions.storeName)
Expand Down Expand Up @@ -386,7 +404,7 @@ class StatePartitionAllColumnFamiliesReader(
case _ =>
val isInternal =
StateStoreColumnFamilySchemaUtils.isInternalColFamily(cfSchema.colFamilyName)
val useMultipleValuesPerKey = isListType(cfSchema.colFamilyName)
val useMultipleValuesPerKey = isMultiValuedCF(cfSchema.colFamilyName)
require(cfSchema.keyStateEncoderSpec.isDefined,
s"keyStateEncoderSpec must be defined for column family ${cfSchema.colFamilyName}")
stateStore.createColFamilyIfAbsent(
Expand All @@ -410,15 +428,11 @@ class StatePartitionAllColumnFamiliesReader(
.filter(schema => !isDefaultColFamilyInTWS(operatorName, schema.colFamilyName))
.flatMap { cfSchema =>
val extractor = cfPartitionKeyExtractors(cfSchema.colFamilyName)
if (isListType(cfSchema.colFamilyName)) {
store.iterator(cfSchema.colFamilyName).flatMap(
pair =>
store.valuesIterator(pair.key, cfSchema.colFamilyName).map {
value =>
SchemaUtil.unifyStateRowPairAsRawBytes(
(pair.key, value), cfSchema.colFamilyName, extractor)
}
)
if (isMultiValuedCF(cfSchema.colFamilyName)) {
store.iteratorWithMultiValues(cfSchema.colFamilyName).map { pair =>
SchemaUtil.unifyStateRowPairAsRawBytes(
(pair.key, pair.value), cfSchema.colFamilyName, extractor)
}
} else {
store.iterator(cfSchema.colFamilyName).map { pair =>
SchemaUtil.unifyStateRowPairAsRawBytes(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,11 +48,12 @@ class StateScanBuilder(
stateStoreColFamilySchemaOpt: Option[StateStoreColFamilySchema],
stateSchemaProviderOpt: Option[StateSchemaProvider],
joinColFamilyOpt: Option[String],
allColumnFamiliesReaderInfo: Option[AllColumnFamiliesReaderInfo]) extends ScanBuilder {
allColumnFamiliesReaderInfo: Option[AllColumnFamiliesReaderInfo],
joinStateFormatVersion: Option[Int] = None) extends ScanBuilder {
override def build(): Scan = new StateScan(session, schema, sourceOptions, stateStoreConf,
batchNumPartitions, keyStateEncoderSpec,
stateVariableInfoOpt, stateStoreColFamilySchemaOpt, stateSchemaProviderOpt,
joinColFamilyOpt, allColumnFamiliesReaderInfo)
joinColFamilyOpt, allColumnFamiliesReaderInfo, joinStateFormatVersion)
}

/** An implementation of [[InputPartition]] for State Store data source. */
Expand All @@ -73,7 +74,8 @@ class StateScan(
stateStoreColFamilySchemaOpt: Option[StateStoreColFamilySchema],
stateSchemaProviderOpt: Option[StateSchemaProvider],
joinColFamilyOpt: Option[String],
allColumnFamiliesReaderInfo: Option[AllColumnFamiliesReaderInfo])
allColumnFamiliesReaderInfo: Option[AllColumnFamiliesReaderInfo],
joinStateFormatVersion: Option[Int] = None)
extends Scan with Batch {

// A Hadoop Configuration can be about 10 KB, which is pretty big, so broadcast it
Expand Down Expand Up @@ -138,7 +140,7 @@ class StateScan(
sourceOptions.stateCheckpointLocation.toString, sourceOptions.operatorId, LeftSide,
oldSchemaFilePaths, excludeAuxColumns = false)
new StreamStreamJoinStatePartitionReaderFactory(stateStoreConf,
hadoopConfBroadcast.value, userFacingSchema, stateSchema)
hadoopConfBroadcast.value, userFacingSchema, stateSchema, joinStateFormatVersion)

case JoinSideValues.right =>
val userFacingSchema = schema
Expand All @@ -148,7 +150,7 @@ class StateScan(
sourceOptions.stateCheckpointLocation.toString, sourceOptions.operatorId, RightSide,
oldSchemaFilePaths, excludeAuxColumns = false)
new StreamStreamJoinStatePartitionReaderFactory(stateStoreConf,
hadoopConfBroadcast.value, userFacingSchema, stateSchema)
hadoopConfBroadcast.value, userFacingSchema, stateSchema, joinStateFormatVersion)

case JoinSideValues.none =>
new StatePartitionReaderFactory(stateStoreConf, hadoopConfBroadcast.value, schema,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,8 @@ class StateTable(
stateStoreColFamilySchemaOpt: Option[StateStoreColFamilySchema],
stateSchemaProviderOpt: Option[StateSchemaProvider],
joinColFamilyOpt: Option[String],
allColumnFamiliesReaderInfo: Option[AllColumnFamiliesReaderInfo] = None)
allColumnFamiliesReaderInfo: Option[AllColumnFamiliesReaderInfo] = None,
joinStateFormatVersion: Option[Int] = None)
extends Table with SupportsRead with SupportsMetadataColumns {

import StateTable._
Expand Down Expand Up @@ -90,7 +91,7 @@ class StateTable(
new StateScanBuilder(session, schema, sourceOptions, stateConf,
batchNumPartitions, keyStateEncoderSpec,
stateVariableInfoOpt, stateStoreColFamilySchemaOpt, stateSchemaProviderOpt,
joinColFamilyOpt, allColumnFamiliesReaderInfo)
joinColFamilyOpt, allColumnFamiliesReaderInfo, joinStateFormatVersion)

override def properties(): util.Map[String, String] = Map.empty[String, String].asJava

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,13 +50,13 @@ object StreamStreamJoinStateHelper {
.add("value", valueSchema)
}

// Returns whether the checkpoint uses stateFormatVersion 3 which uses VCF for the join.
// Returns whether the checkpoint uses VCF for the join (stateFormatVersion >= 3).
def usesVirtualColumnFamilies(
hadoopConf: Configuration,
stateCheckpointLocation: String,
operatorId: Int): Boolean = {
// If the schema exists for operatorId/partitionId/left-keyToNumValues, it is not
// stateFormatVersion 3.
// stateFormatVersion >= 3 (which uses VCF).
val partitionId = StateStore.PARTITION_ID_TO_CHECK_SCHEMA
val storeId = new StateStoreId(stateCheckpointLocation, operatorId,
partitionId, SymmetricHashJoinStateManager.allStateStoreNames(LeftSide).toList.head)
Expand All @@ -76,12 +76,12 @@ object StreamStreamJoinStateHelper {

val newHadoopConf = session.sessionState.newHadoopConf()
val partitionId = StateStore.PARTITION_ID_TO_CHECK_SCHEMA
// KeyToNumValuesType, KeyWithIndexToValueType
val storeNames = SymmetricHashJoinStateManager.allStateStoreNames(side).toList

val (keySchema, valueSchema) =
if (!usesVirtualColumnFamilies(
newHadoopConf, stateCheckpointLocation, operatorId)) {
// v1/v2: separate state stores per store type
val storeIdForKeyToNumValues = new StateStoreId(stateCheckpointLocation, operatorId,
partitionId, storeNames(0))
val providerIdForKeyToNumValues = new StateStoreProviderId(storeIdForKeyToNumValues,
Expand All @@ -105,29 +105,40 @@ object StreamStreamJoinStateHelper {

(kSchema, vSchema)
} else {
// v3/v4: single state store with virtual column families
val storeId = new StateStoreId(stateCheckpointLocation, operatorId,
partitionId, StateStoreId.DEFAULT_STORE_NAME)
val providerId = new StateStoreProviderId(storeId, UUID.randomUUID())

val manager = new StateSchemaCompatibilityChecker(
providerId, newHadoopConf, oldSchemaFilePaths, createSchemaDir = false)
val kSchema = manager.readSchemaFile().find { schema =>
schema.colFamilyName == storeNames(0)
}.map(_.keySchema).get
val schemas = manager.readSchemaFile()

val vSchema = manager.readSchemaFile().find { schema =>
schema.colFamilyName == storeNames(1)
}.map(_.valueSchema).get
// Try v3 CF names first; if not found, use v4 CF names
val v3Names = storeNames
val v4Names = SymmetricHashJoinStateManager.allStateStoreNamesV4(side).toList

val primaryCfName = schemas.find(_.colFamilyName == v3Names(1)) match {
case Some(_) => v3Names(1) // v3: keyWithIndexToValue
case None => v4Names(0) // v4: keyWithTsToValues
}
val keyCfName = schemas.find(_.colFamilyName == v3Names(0)) match {
case Some(_) => v3Names(0) // v3: keyToNumValues
case None => v4Names(0) // v4: keyWithTsToValues (key schema is the join key)
}

val kSchema = schemas.find(_.colFamilyName == keyCfName).map(_.keySchema).get
val vSchema = schemas.find(_.colFamilyName == primaryCfName).map(_.valueSchema).get

(kSchema, vSchema)
}

val maybeMatchedColumn = valueSchema.last

// remove internal column `matched` for format version >= 2
if (excludeAuxColumns
&& maybeMatchedColumn.name == "matched"
&& maybeMatchedColumn.dataType == BooleanType) {
// remove internal column `matched` for format version 2
(keySchema, StructType(valueSchema.dropRight(1)))
} else {
(keySchema, valueSchema)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,10 +38,12 @@ class StreamStreamJoinStatePartitionReaderFactory(
storeConf: StateStoreConf,
hadoopConf: SerializableConfiguration,
userFacingSchema: StructType,
stateSchema: StructType) extends PartitionReaderFactory {
stateSchema: StructType,
joinStateFormatVersion: Option[Int] = None) extends PartitionReaderFactory {
override def createReader(partition: InputPartition): PartitionReader[InternalRow] = {
new StreamStreamJoinStatePartitionReader(storeConf, hadoopConf,
partition.asInstanceOf[StateStoreInputPartition], userFacingSchema, stateSchema)
partition.asInstanceOf[StateStoreInputPartition], userFacingSchema, stateSchema,
joinStateFormatVersion)
}
}

Expand All @@ -54,7 +56,9 @@ class StreamStreamJoinStatePartitionReader(
hadoopConf: SerializableConfiguration,
partition: StateStoreInputPartition,
userFacingSchema: StructType,
stateSchema: StructType) extends PartitionReader[InternalRow] with Logging {
stateSchema: StructType,
joinStateFormatVersion: Option[Int] = None)
extends PartitionReader[InternalRow] with Logging {

private val keySchema = SchemaUtil.getSchemaAsDataType(stateSchema, "key")
.asInstanceOf[StructType]
Expand Down Expand Up @@ -112,28 +116,25 @@ class StreamStreamJoinStatePartitionReader(
endStateStoreCheckpointIds.right.keyWithIndexToValue
}

/*
* This is to handle the difference of schema across state format versions. The major difference
* is whether we have added new field(s) in addition to the fields from input schema.
*
* - version 1: no additional field
* - version 2: the field "matched" is added to the last
*/
private val (inputAttributes, formatVersion) = {
val maybeMatchedColumn = valueSchema.last
val (fields, version) = {
// If there is a matched column, version is either 2 or 3. We need to drop the matched
// column from the value schema to get the actual fields.
if (maybeMatchedColumn.name == "matched" && maybeMatchedColumn.dataType == BooleanType) {
// If checkpoint is using one store and virtual column families, version is 3
if (usesVirtualColumnFamilies) {
(valueSchema.dropRight(1), 3)
val (fields, version) = joinStateFormatVersion match {
// Use explicit format version when available from offset log
case Some(v) if v >= 2 =>
(valueSchema.dropRight(1), v)
case Some(1) =>
(valueSchema, 1)
// Fall back to heuristic-based detection for old checkpoints
case _ =>
if (maybeMatchedColumn.name == "matched" && maybeMatchedColumn.dataType == BooleanType) {
if (usesVirtualColumnFamilies) {
(valueSchema.dropRight(1), 3)
} else {
(valueSchema.dropRight(1), 2)
}
} else {
(valueSchema.dropRight(1), 2)
(valueSchema, 1)
}
} else {
(valueSchema, 1)
}
}

assert(fields.toArray.sameElements(userFacingValueSchema.fields),
Expand Down
Loading