
Commit 568158b

Add Python KafkaSource.set_deserializer API
1 parent f6a077a commit 568158b

2 files changed (+69, -2 lines)

flink-python/pyflink/datastream/connectors/kafka.py

Lines changed: 51 additions & 1 deletion
```diff
@@ -44,7 +44,9 @@
     'KafkaOffsetResetStrategy',
     'KafkaRecordSerializationSchema',
     'KafkaRecordSerializationSchemaBuilder',
-    'KafkaTopicSelector'
+    'KafkaTopicSelector',
+    'KafkaRecordDeserializationSchema',
+    'SimpleStringValueKafkaRecordDeserializationSchema'
 ]
 
 
@@ -353,6 +355,38 @@ def ignore_failures_after_transaction_timeout(self) -> 'FlinkKafkaProducer':
 
 # ---- KafkaSource ----
 
+class KafkaRecordDeserializationSchema:
+    """
+    Base class for KafkaRecordDeserializationSchema. The Kafka record deserialization schema
+    describes how to turn the byte messages delivered by Apache Kafka into data types (Java/
+    Scala objects) that are processed by Flink.
+
+    In addition, the KafkaRecordDeserializationSchema describes the produced type, which lets
+    Flink create internal serializers and structures to handle the type.
+    """
+    def __init__(self, j_kafka_record_deserialization_schema=None):
+        self.j_kafka_record_deserialization_schema = j_kafka_record_deserialization_schema
+
+
+class SimpleStringValueKafkaRecordDeserializationSchema(KafkaRecordDeserializationSchema):
+    """
+    Very simple deserialization schema for string values. By default, the deserializer uses
+    'UTF-8' for byte-to-string conversion.
+    """
+
+    def __init__(self, charset: str = 'UTF-8'):
+        gate_way = get_gateway()
+        j_char_set = gate_way.jvm.java.nio.charset.Charset.forName(charset)
+        j_simple_string_serialization_schema = gate_way.jvm \
+            .org.apache.flink.api.common.serialization.SimpleStringSchema(j_char_set)
+        j_kafka_record_deserialization_schema = gate_way.jvm \
+            .org.apache.flink.connector.kafka.source.reader.deserializer \
+            .KafkaRecordDeserializationSchema.valueOnly(j_simple_string_serialization_schema)
+        KafkaRecordDeserializationSchema.__init__(
+            self, j_kafka_record_deserialization_schema=j_kafka_record_deserialization_schema)
+
+
+# ---- KafkaSource ----
 
 class KafkaSource(Source):
     """
@@ -611,6 +645,22 @@ def set_value_only_deserializer(self, deserialization_schema: DeserializationSch
         self._j_builder.setValueOnlyDeserializer(deserialization_schema._j_deserialization_schema)
         return self
 
+    def set_deserializer(
+            self,
+            kafka_record_deserialization_schema: KafkaRecordDeserializationSchema
+    ) -> 'KafkaSourceBuilder':
+        """
+        Sets the :class:`~pyflink.datastream.connectors.kafka.KafkaRecordDeserializationSchema`
+        for deserializing Kafka ConsumerRecords.
+
+        :param kafka_record_deserialization_schema: the :class:`KafkaRecordDeserializationSchema`
+            to use for deserialization.
+        :return: this KafkaSourceBuilder.
+        """
+        self._j_builder.setDeserializer(
+            kafka_record_deserialization_schema.j_kafka_record_deserialization_schema)
+        return self
+
     def set_client_id_prefix(self, prefix: str) -> 'KafkaSourceBuilder':
         """
         Sets the client id prefix of this KafkaSource.
```
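For reference, here is a minimal usage sketch of the new API (not part of this commit): it wires the new `SimpleStringValueKafkaRecordDeserializationSchema` into a `KafkaSource` via `set_deserializer()`. The broker address, topic, and job name are illustrative placeholders; the surrounding calls (`StreamExecutionEnvironment`, `WatermarkStrategy`, `from_source`) are existing PyFlink APIs.

```python
from pyflink.common import WatermarkStrategy
from pyflink.datastream import StreamExecutionEnvironment
from pyflink.datastream.connectors.kafka import (
    KafkaSource, SimpleStringValueKafkaRecordDeserializationSchema)

env = StreamExecutionEnvironment.get_execution_environment()

# Build a source whose record values are decoded as UTF-8 strings
# through the set_deserializer() method added in this commit.
source = KafkaSource.builder() \
    .set_bootstrap_servers('localhost:9092') \
    .set_topics('my_topic') \
    .set_deserializer(SimpleStringValueKafkaRecordDeserializationSchema()) \
    .build()

# 'localhost:9092', 'my_topic', and the names below are placeholders.
ds = env.from_source(source, WatermarkStrategy.no_watermarks(), 'kafka_source')
ds.print()
env.execute('kafka_set_deserializer_example')
```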

flink-python/pyflink/datastream/connectors/tests/test_kafka.py

Lines changed: 18 additions & 1 deletion
```diff
@@ -29,7 +29,8 @@
 from pyflink.datastream.connectors.base import DeliveryGuarantee
 from pyflink.datastream.connectors.kafka import KafkaSource, KafkaTopicPartition, \
     KafkaOffsetsInitializer, KafkaOffsetResetStrategy, KafkaRecordSerializationSchema, KafkaSink, \
-    FlinkKafkaProducer, FlinkKafkaConsumer
+    FlinkKafkaProducer, FlinkKafkaConsumer, KafkaRecordDeserializationSchema, \
+    SimpleStringValueKafkaRecordDeserializationSchema
 from pyflink.datastream.formats.avro import AvroRowDeserializationSchema, AvroRowSerializationSchema
 from pyflink.datastream.formats.csv import CsvRowDeserializationSchema, CsvRowSerializationSchema
 from pyflink.datastream.formats.json import JsonRowDeserializationSchema, JsonRowSerializationSchema
@@ -332,6 +333,22 @@ def _check(schema: DeserializationSchema, class_name: str):
             'org.apache.flink.formats.avro.AvroRowDeserializationSchema'
         )
 
+    def test_set_kafka_record_deserialization_schema(self):
+        def _check(schema: KafkaRecordDeserializationSchema, java_class_name: str):
+            source = KafkaSource.builder() \
+                .set_bootstrap_servers('localhost:9092') \
+                .set_topics('test_topic') \
+                .set_deserializer(schema) \
+                .build()
+            kafka_record_deserialization_schema = get_field_value(source.get_java_function(),
+                                                                  'deserializationSchema')
+            self.assertEqual(kafka_record_deserialization_schema.getClass().getCanonicalName(),
+                             java_class_name)
+
+        _check(SimpleStringValueKafkaRecordDeserializationSchema(),
+               'org.apache.flink.connector.kafka.source.reader.deserializer.'
+               'KafkaValueOnlyDeserializationSchemaWrapper')
+
     def _check_reader_handled_offsets_initializer(self,
                                                   source: KafkaSource,
                                                   offset: int,
```
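The assertion pins the Java-side wrapper class produced by `KafkaRecordDeserializationSchema.valueOnly`. Following the same pattern as `SimpleStringValueKafkaRecordDeserializationSchema`, other value-only schemas could be lifted from Python as well; the sketch below is hypothetical (the helper class name is invented, not part of this commit) and leans on the `_j_deserialization_schema` attribute already used by `set_value_only_deserializer` in the diff above.

```python
from pyflink.common.serialization import DeserializationSchema
from pyflink.datastream.connectors.kafka import KafkaRecordDeserializationSchema
from pyflink.java_gateway import get_gateway


class ValueOnlyKafkaRecordDeserializationSchema(KafkaRecordDeserializationSchema):
    """Hypothetical helper (not in this commit): lifts any PyFlink
    DeserializationSchema into a KafkaRecordDeserializationSchema via the
    Java-side valueOnly() factory, mirroring what
    SimpleStringValueKafkaRecordDeserializationSchema does for strings."""

    def __init__(self, deserialization_schema: DeserializationSchema):
        gate_way = get_gateway()
        # Wrap the Java DeserializationSchema held by the Python object.
        j_schema = gate_way.jvm \
            .org.apache.flink.connector.kafka.source.reader.deserializer \
            .KafkaRecordDeserializationSchema.valueOnly(
                deserialization_schema._j_deserialization_schema)
        super().__init__(j_kafka_record_deserialization_schema=j_schema)
```

An instance of such a wrapper would then pass through `set_deserializer()` exactly like the string schema in the test above.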
