diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateEncoder.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateEncoder.scala index c8f68dfd0388f..5acf7cdc9b975 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateEncoder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateEncoder.scala @@ -624,6 +624,8 @@ class UnsafeRowDataEncoder( decodeToUnsafeRow(bytes, reusedKeyRow) case PrefixKeyScanStateEncoderSpec(_, numColsPrefixKey) => decodeToUnsafeRow(bytes, numFields = numColsPrefixKey) + case _: TimestampAsPrefixKeyStateEncoderSpec | _: TimestampAsPostfixKeyStateEncoderSpec => + decodeToUnsafeRow(bytes, numFields = keySchema.length - 1) case _ => throw unsupportedOperationForKeyStateEncoder("decodeKey") } } @@ -748,8 +750,15 @@ class AvroStateEncoder( ) // Avro schema used by the avro encoders - private lazy val keyAvroType: Schema = SchemaConverters.toAvroTypeWithDefaults(keySchema) - private lazy val keyProj = UnsafeProjection.create(keySchema) + // For timestamp specs, the key part excludes the timestamp column (always the last field). + private lazy val effectiveKeySchema: StructType = keyStateEncoderSpec match { + case TimestampAsPrefixKeyStateEncoderSpec(s) => StructType(s.dropRight(1)) + case TimestampAsPostfixKeyStateEncoderSpec(s) => StructType(s.dropRight(1)) + case _ => keySchema + } + private lazy val keyAvroType: Schema = + SchemaConverters.toAvroTypeWithDefaults(effectiveKeySchema) + private lazy val keyProj = UnsafeProjection.create(effectiveKeySchema) private lazy val valueAvroType: Schema = SchemaConverters.toAvroTypeWithDefaults(valueSchema) private lazy val valueProj = UnsafeProjection.create(valueSchema) @@ -847,8 +856,10 @@ class AvroStateEncoder( } } StructType(remainingSchema) - case _ => - throw unsupportedOperationForKeyStateEncoder("createAvroEnc") + case TimestampAsPrefixKeyStateEncoderSpec(schema) => + StructType(schema.dropRight(1)) + case TimestampAsPostfixKeyStateEncoderSpec(schema) => + StructType(schema.dropRight(1)) } // Handle suffix key schema for prefix scan case @@ -1005,6 +1016,11 @@ class AvroStateEncoder( StateSchemaIdRow(currentKeySchemaId, avroRow)) case PrefixKeyScanStateEncoderSpec(_, _) => encodeUnsafeRowToAvro(row, avroEncoder.keySerializer, prefixKeyAvroType, out) + case _: TimestampAsPrefixKeyStateEncoderSpec | _: TimestampAsPostfixKeyStateEncoderSpec => + val avroRow = + encodeUnsafeRowToAvro(row, avroEncoder.keySerializer, keyAvroType, out) + encodeWithStateSchemaId( + StateSchemaIdRow(currentKeySchemaId, avroRow)) case _ => throw unsupportedOperationForKeyStateEncoder("encodeKey") } prependVersionByte(keyBytes) @@ -1179,6 +1195,10 @@ class AvroStateEncoder( case PrefixKeyScanStateEncoderSpec(_, _) => decodeFromAvroToUnsafeRow( bytes, avroEncoder.keyDeserializer, prefixKeyAvroType, prefixKeyProj) + case _: TimestampAsPrefixKeyStateEncoderSpec | _: TimestampAsPostfixKeyStateEncoderSpec => + val schemaIdRow = decodeStateSchemaIdRow(bytes) + decodeFromAvroToUnsafeRow( + schemaIdRow.bytes, avroEncoder.keyDeserializer, keyAvroType, keyProj) case _ => throw unsupportedOperationForKeyStateEncoder("decodeKey") } } @@ -1782,9 +1802,7 @@ abstract class TimestampKeyStateEncoder( rowBytes, Platform.BYTE_ARRAY_OFFSET, rowBytesLength ) - // The encoded row does not include the timestamp (it's stored separately), - // so decode with keySchema.length - 1 fields. - dataEncoder.decodeToUnsafeRow(rowBytes, keySchema.length - 1) + dataEncoder.decodeKey(rowBytes) } // NOTE: We reuse the ByteBuffer to avoid allocating a new one for every encoding/decoding, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala index 8bb1f609b2b44..ec0b8733ec67f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala @@ -673,7 +673,10 @@ case class RangeKeyScanStateEncoderSpec( } } -/** The encoder specification for [[TimestampAsPrefixKeyStateEncoder]]. */ +/** + * The encoder specification for [[TimestampAsPrefixKeyStateEncoder]]. + * The encoder expects the provided key schema to have [original key fields..., timestamp field]. + */ case class TimestampAsPrefixKeyStateEncoderSpec(keySchema: StructType) extends KeyStateEncoderSpec { @@ -688,7 +691,10 @@ case class TimestampAsPrefixKeyStateEncoderSpec(keySchema: StructType) } } -/** The encoder specification for [[TimestampAsPostfixKeyStateEncoder]]. */ +/** + * The encoder specification for [[TimestampAsPostfixKeyStateEncoder]]. + * The encoder expects the provided key schema to have [original key fields..., timestamp field]. + */ case class TimestampAsPostfixKeyStateEncoderSpec(keySchema: StructType) extends KeyStateEncoderSpec { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBTimestampEncoderOperationsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBTimestampEncoderOperationsSuite.scala index a65565d8b245e..a9540a4ad623e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBTimestampEncoderOperationsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBTimestampEncoderOperationsSuite.scala @@ -68,9 +68,7 @@ class RocksDBTimestampEncoderOperationsSuite extends SharedSparkSession private def newDir(): String = Utils.createTempDir().getCanonicalPath - // TODO: [SPARK-55145] Address the new state format with Avro and enable the test with Avro - // encoding - Seq("unsaferow").foreach { encoding => + Seq("unsaferow", "avro").foreach { encoding => Seq("prefix", "postfix").foreach { encoderType => test(s"Event time as $encoderType: basic put and get operations (encoding = $encoding)") { tryWithProviderResource( @@ -223,9 +221,7 @@ class RocksDBTimestampEncoderOperationsSuite extends SharedSparkSession } } - // TODO: [SPARK-55145] Address the new state format with Avro and enable the test with Avro - // encoding - Seq("unsaferow").foreach { encoding => + Seq("unsaferow", "avro").foreach { encoding => test(s"Event time as prefix: iterator operations (encoding = $encoding)") { tryWithProviderResource( newStoreProviderWithTimestampEncoder( @@ -558,9 +554,7 @@ class RocksDBTimestampEncoderOperationsSuite extends SharedSparkSession } } - // TODO: [SPARK-55145] Address the new state format with Avro and enable the test with Avro - // encoding - Seq("unsaferow").foreach { encoding => + Seq("unsaferow", "avro").foreach { encoding => Seq("prefix", "postfix").foreach { encoderType => Seq(false, true).foreach { useMultipleValuesPerKey => val multiValueSuffix = if (useMultipleValuesPerKey) " and multiple values" else ""