Skip to content

CommitOffset non-stream #581

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Apr 22, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docker-compose-tls.yml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
version: "3.9"
services:
ydb:
image: ydbplatform/local-ydb:trunk
image: ydbplatform/local-ydb:latest
restart: always
ports:
- 2136:2136
Expand Down
2 changes: 1 addition & 1 deletion docker-compose.yml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
version: "3.3"
services:
ydb:
image: ydbplatform/local-ydb:trunk
image: ydbplatform/local-ydb:latest
restart: always
ports:
- 2136:2136
Expand Down
24 changes: 23 additions & 1 deletion tests/query/test_query_parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,73 +4,84 @@
import ydb


query = """SELECT $a AS value"""
query_template = "DECLARE $a as %s; SELECT $a AS value"


def test_select_implicit_int(pool: ydb.QuerySessionPool):
query = query_template % "Int64"
expected_value = 111
res = pool.execute_with_retries(query, parameters={"$a": expected_value})
actual_value = res[0].rows[0]["value"]
assert expected_value == actual_value


def test_select_implicit_float(pool: ydb.QuerySessionPool):
query = query_template % "Double"
expected_value = 11.1
res = pool.execute_with_retries(query, parameters={"$a": expected_value})
actual_value = res[0].rows[0]["value"]
assert expected_value == pytest.approx(actual_value)


def test_select_implicit_bool(pool: ydb.QuerySessionPool):
query = query_template % "Bool"
expected_value = False
res = pool.execute_with_retries(query, parameters={"$a": expected_value})
actual_value = res[0].rows[0]["value"]
assert expected_value == actual_value


def test_select_implicit_str(pool: ydb.QuerySessionPool):
query = query_template % "Utf8"
expected_value = "text"
res = pool.execute_with_retries(query, parameters={"$a": expected_value})
actual_value = res[0].rows[0]["value"]
assert expected_value == actual_value


def test_select_implicit_bytes(pool: ydb.QuerySessionPool):
query = query_template % "String"
expected_value = b"text"
res = pool.execute_with_retries(query, parameters={"$a": expected_value})
actual_value = res[0].rows[0]["value"]
assert expected_value == actual_value


def test_select_implicit_list(pool: ydb.QuerySessionPool):
query = query_template % "List<Int64>"
expected_value = [1, 2, 3]
res = pool.execute_with_retries(query, parameters={"$a": expected_value})
actual_value = res[0].rows[0]["value"]
assert expected_value == actual_value


def test_select_implicit_dict(pool: ydb.QuerySessionPool):
query = query_template % "Dict<Utf8, Int64>"
expected_value = {"a": 1, "b": 2}
res = pool.execute_with_retries(query, parameters={"$a": expected_value})
actual_value = res[0].rows[0]["value"]
assert expected_value == actual_value


def test_select_implicit_list_nested(pool: ydb.QuerySessionPool):
query = query_template % "List<Dict<Utf8, Int64>>"
expected_value = [{"a": 1}, {"b": 2}]
res = pool.execute_with_retries(query, parameters={"$a": expected_value})
actual_value = res[0].rows[0]["value"]
assert expected_value == actual_value


def test_select_implicit_dict_nested(pool: ydb.QuerySessionPool):
query = query_template % "Dict<Utf8, List<Int64>>"
expected_value = {"a": [1, 2, 3], "b": [4, 5]}
res = pool.execute_with_retries(query, parameters={"$a": expected_value})
actual_value = res[0].rows[0]["value"]
assert expected_value == actual_value


def test_select_implicit_custom_type_raises(pool: ydb.QuerySessionPool):
query = query_template % "Struct"

class CustomClass:
pass

Expand All @@ -80,25 +91,29 @@ class CustomClass:


def test_select_implicit_empty_list_raises(pool: ydb.QuerySessionPool):
query = query_template % "List<Int64>"
expected_value = []
with pytest.raises(ValueError):
pool.execute_with_retries(query, parameters={"$a": expected_value})


def test_select_implicit_empty_dict_raises(pool: ydb.QuerySessionPool):
query = query_template % "Dict<Int64, Int64>"
expected_value = {}
with pytest.raises(ValueError):
pool.execute_with_retries(query, parameters={"$a": expected_value})


def test_select_explicit_primitive(pool: ydb.QuerySessionPool):
query = query_template % "Int64"
expected_value = 111
res = pool.execute_with_retries(query, parameters={"$a": (expected_value, ydb.PrimitiveType.Int64)})
actual_value = res[0].rows[0]["value"]
assert expected_value == actual_value


def test_select_explicit_list(pool: ydb.QuerySessionPool):
query = query_template % "List<Int64>"
expected_value = [1, 2, 3]
type_ = ydb.ListType(ydb.PrimitiveType.Int64)
res = pool.execute_with_retries(query, parameters={"$a": (expected_value, type_)})
Expand All @@ -107,6 +122,7 @@ def test_select_explicit_list(pool: ydb.QuerySessionPool):


def test_select_explicit_dict(pool: ydb.QuerySessionPool):
query = query_template % "Dict<Utf8, Utf8>"
expected_value = {"key": "value"}
type_ = ydb.DictType(ydb.PrimitiveType.Utf8, ydb.PrimitiveType.Utf8)
res = pool.execute_with_retries(query, parameters={"$a": (expected_value, type_)})
Expand All @@ -115,6 +131,7 @@ def test_select_explicit_dict(pool: ydb.QuerySessionPool):


def test_select_explicit_empty_list_not_raises(pool: ydb.QuerySessionPool):
query = query_template % "List<Int64>"
expected_value = []
type_ = ydb.ListType(ydb.PrimitiveType.Int64)
res = pool.execute_with_retries(query, parameters={"$a": (expected_value, type_)})
Expand All @@ -123,6 +140,7 @@ def test_select_explicit_empty_list_not_raises(pool: ydb.QuerySessionPool):


def test_select_explicit_empty_dict_not_raises(pool: ydb.QuerySessionPool):
query = query_template % "Dict<Utf8, Utf8>"
expected_value = {}
type_ = ydb.DictType(ydb.PrimitiveType.Utf8, ydb.PrimitiveType.Utf8)
res = pool.execute_with_retries(query, parameters={"$a": (expected_value, type_)})
Expand All @@ -131,6 +149,7 @@ def test_select_explicit_empty_dict_not_raises(pool: ydb.QuerySessionPool):


def test_select_typedvalue_full_primitive(pool: ydb.QuerySessionPool):
query = query_template % "Int64"
expected_value = 111
typed_value = ydb.TypedValue(expected_value, ydb.PrimitiveType.Int64)
res = pool.execute_with_retries(query, parameters={"$a": typed_value})
Expand All @@ -139,6 +158,7 @@ def test_select_typedvalue_full_primitive(pool: ydb.QuerySessionPool):


def test_select_typedvalue_implicit_primitive(pool: ydb.QuerySessionPool):
query = query_template % "Int64"
expected_value = 111
typed_value = ydb.TypedValue(expected_value)
res = pool.execute_with_retries(query, parameters={"$a": typed_value})
Expand All @@ -147,6 +167,8 @@ def test_select_typedvalue_implicit_primitive(pool: ydb.QuerySessionPool):


def test_select_typevalue_custom_type_raises(pool: ydb.QuerySessionPool):
query = query_template % "Struct"

class CustomClass:
pass

Expand Down
40 changes: 40 additions & 0 deletions tests/topics/test_topic_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,26 @@ async def test_read_and_commit_with_ack(self, driver, topic_with_messages, topic

assert message != batch.messages[0]

async def test_commit_offset_works(self, driver, topic_with_messages, topic_consumer):
for out in ["123", "456", "789", "0"]:
async with driver.topic_client.reader(topic_with_messages, topic_consumer) as reader:
message = await reader.receive_message()
assert message.data.decode() == out

await driver.topic_client.commit_offset(
topic_with_messages, topic_consumer, message.partition_id, message.offset + 1
)

async def test_reader_reconnect_after_commit_offset(self, driver, topic_with_messages, topic_consumer):
async with driver.topic_client.reader(topic_with_messages, topic_consumer) as reader:
for out in ["123", "456", "789", "0"]:
message = await reader.receive_message()
assert message.data.decode() == out

await driver.topic_client.commit_offset(
topic_with_messages, topic_consumer, message.partition_id, message.offset + 1
)

async def test_read_compressed_messages(self, driver, topic_path, topic_consumer):
async with driver.topic_client.writer(topic_path, codec=ydb.TopicCodec.GZIP) as writer:
await writer.write("123")
Expand Down Expand Up @@ -183,6 +203,26 @@ def test_read_and_commit_with_ack(self, driver_sync, topic_with_messages, topic_

assert message != batch.messages[0]

def test_commit_offset_works(self, driver_sync, topic_with_messages, topic_consumer):
for out in ["123", "456", "789", "0"]:
with driver_sync.topic_client.reader(topic_with_messages, topic_consumer) as reader:
message = reader.receive_message()
assert message.data.decode() == out

driver_sync.topic_client.commit_offset(
topic_with_messages, topic_consumer, message.partition_id, message.offset + 1
)

def test_reader_reconnect_after_commit_offset(self, driver_sync, topic_with_messages, topic_consumer):
with driver_sync.topic_client.reader(topic_with_messages, topic_consumer) as reader:
for out in ["123", "456", "789", "0"]:
message = reader.receive_message()
assert message.data.decode() == out

driver_sync.topic_client.commit_offset(
topic_with_messages, topic_consumer, message.partition_id, message.offset + 1
)

def test_read_compressed_messages(self, driver_sync, topic_path, topic_consumer):
with driver_sync.topic_client.writer(topic_path, codec=ydb.TopicCodec.GZIP) as writer:
writer.write("123")
Expand Down
1 change: 1 addition & 0 deletions ydb/_apis.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,7 @@ class TopicService(object):
StreamRead = "StreamRead"
StreamWrite = "StreamWrite"
UpdateOffsetsInTransaction = "UpdateOffsetsInTransaction"
CommitOffset = "CommitOffset"


class QueryService(object):
Expand Down
16 changes: 16 additions & 0 deletions ydb/_grpc/grpcwrapper/ydb_topic.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,22 @@ def from_proto(msg: ydb_topic_pb2.UpdateTokenResponse) -> typing.Any:
return UpdateTokenResponse()


@dataclass
class CommitOffsetRequest(IToProto):
path: str
consumer: str
partition_id: int
offset: int

def to_proto(self) -> ydb_topic_pb2.CommitOffsetRequest:
return ydb_topic_pb2.CommitOffsetRequest(
path=self.path,
consumer=self.consumer,
partition_id=self.partition_id,
offset=self.offset,
)


########################################################################################################################
# StreamWrite
########################################################################################################################
Expand Down
4 changes: 4 additions & 0 deletions ydb/_topic_reader/datatypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,10 @@ def _commit_get_offsets_range(self) -> OffsetsRange:
def alive(self) -> bool:
return not self._partition_session.closed

@property
def partition_id(self) -> int:
return self._partition_session.partition_id


@dataclass
class PartitionSession:
Expand Down
30 changes: 30 additions & 0 deletions ydb/topic.py
Original file line number Diff line number Diff line change
Expand Up @@ -340,6 +340,21 @@ def tx_writer(

return TopicTxWriterAsyncIO(tx=tx, driver=self._driver, settings=settings, _client=self)

async def commit_offset(self, path: str, consumer: str, partition_id: int, offset: int) -> None:
req = _ydb_topic.CommitOffsetRequest(
path=path,
consumer=consumer,
partition_id=partition_id,
offset=offset,
)

await self._driver(
req.to_proto(),
_apis.TopicService.Stub,
_apis.TopicService.CommitOffset,
_wrap_operation,
)

def close(self):
if self._closed:
return
Expand Down Expand Up @@ -603,6 +618,21 @@ def tx_writer(

return TopicTxWriter(tx, self._driver, settings, _parent=self)

def commit_offset(self, path: str, consumer: str, partition_id: int, offset: int) -> None:
req = _ydb_topic.CommitOffsetRequest(
path=path,
consumer=consumer,
partition_id=partition_id,
offset=offset,
)

self._driver(
req.to_proto(),
_apis.TopicService.Stub,
_apis.TopicService.CommitOffset,
_wrap_operation,
)

def close(self):
if self._closed:
return
Expand Down
Loading