
Commit 04aeb98

committed
wip
1 parent 8a70e1f commit 04aeb98

File tree

5 files changed: +178 -117 lines changed


quixstreams/state/rocksdb/metadata.py

+3

@@ -1,2 +1,5 @@
 PROCESSED_OFFSET_KEY = b"__topic_offset__"
 CHANGELOG_OFFSET_KEY = b"__changelog_offset__"
+
+LATEST_TIMESTAMPS_CF_NAME = "__latest-timestamps__"
+LATEST_TIMESTAMP_KEY = b"__latest_timestamp__"

quixstreams/state/rocksdb/partition.py

+6

@@ -375,3 +375,9 @@ def _update_changelog_offset(self, batch: WriteBatch, offset: int):
            int_to_int64_bytes(offset),
            self.get_column_family_handle(METADATA_CF_NAME),
        )
+
+    def _ensure_column_family(self, cf_name: str) -> None:
+        try:
+            self.get_column_family(cf_name)
+        except ColumnFamilyDoesNotExist:
+            self.create_column_family(cf_name)
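
Moving `_ensure_column_family` up to `RocksDBStorePartition` lets any subclass create a column family idempotently when the partition is opened. A minimal sketch of the intended call pattern, assuming only what this diff shows (the subclass and column family name below are hypothetical):

# Sketch only: a hypothetical subclass that guarantees an extra column family
# exists before the first transaction touches it.
from quixstreams.state.rocksdb.partition import RocksDBStorePartition

MY_CF_NAME = "__my-extra-cf__"  # hypothetical column family name


class MyStorePartition(RocksDBStorePartition):
    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        # Creates the column family only if it does not already exist,
        # so reopening the same partition stays idempotent.
        self._ensure_column_family(MY_CF_NAME)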

quixstreams/state/rocksdb/timestamped.py

+122 -24
@@ -1,12 +1,24 @@
-from typing import Any, Optional
+from typing import Any, Optional, Union, cast

 from quixstreams.state.base.transaction import (
     PartitionTransaction,
     PartitionTransactionStatus,
     validate_transaction_status,
 )
 from quixstreams.state.metadata import SEPARATOR
-from quixstreams.state.serialization import int_to_int64_bytes, serialize
+from quixstreams.state.recovery import ChangelogProducer
+from quixstreams.state.rocksdb.metadata import (
+    LATEST_TIMESTAMP_KEY,
+    LATEST_TIMESTAMPS_CF_NAME,
+)
+from quixstreams.state.rocksdb.types import RocksDBOptionsType
+from quixstreams.state.rocksdb.windowed.transaction import TimestampsCache
+from quixstreams.state.serialization import (
+    DumpsFunc,
+    LoadsFunc,
+    int_to_int64_bytes,
+    serialize,
+)

 from .partition import RocksDBStorePartition
 from .store import RocksDBStore
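
The new transaction reuses `TimestampsCache` from the windowed transaction module to memoize the per-prefix latest timestamp. A hedged sketch of the shape the code below relies on; the real class lives in `quixstreams.state.rocksdb.windowed.transaction` and may differ in detail:

# Sketch of the cache shape assumed by the code in this diff, not the actual class.
from dataclasses import dataclass, field
from typing import Dict, Optional


@dataclass
class TimestampsCacheSketch:
    key: bytes                  # storage key, e.g. LATEST_TIMESTAMP_KEY
    cf_name: str                # column family, e.g. LATEST_TIMESTAMPS_CF_NAME
    # per-prefix memo of the latest timestamp seen by this transaction
    timestamps: Dict[bytes, Optional[int]] = field(default_factory=dict)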
@@ -18,7 +30,6 @@
 )

 DAYS_7 = 7 * 24 * 60 * 60 * 1000
-EXPIRATION_COUNTER = 0


 class TimestampedPartitionTransaction(PartitionTransaction):
@@ -30,12 +41,26 @@ class TimestampedPartitionTransaction(PartitionTransaction):
     It interacts with both an in-memory update cache and the persistent RocksDB store.
     """

-    # Override the type hint from the parent class (`PartitionTransaction`).
-    # This informs type checkers like mypy that in this specific subclass,
-    # `_partition` is a `TimestampedStorePartition` (defined below),
-    # which has methods like `.iter_items()` that the base type might lack.
-    # The string quotes are necessary for the forward reference.
-    _partition: "TimestampedStorePartition"
+    def __init__(
+        self,
+        partition: "TimestampedStorePartition",
+        dumps: DumpsFunc,
+        loads: LoadsFunc,
+        changelog_producer: Optional[ChangelogProducer] = None,
+    ):
+        super().__init__(
+            partition=partition,
+            dumps=dumps,
+            loads=loads,
+            changelog_producer=changelog_producer,
+        )
+        self._partition: TimestampedStorePartition = cast(
+            "TimestampedStorePartition", self._partition
+        )
+        self._latest_timestamps: TimestampsCache = TimestampsCache(
+            key=LATEST_TIMESTAMP_KEY,
+            cf_name=LATEST_TIMESTAMPS_CF_NAME,
+        )

     @validate_transaction_status(PartitionTransactionStatus.STARTED)
     def get_last(
@@ -60,13 +85,20 @@ def get_last(
         :param cf_name: The column family name.
         :return: The deserialized value if found, otherwise None.
         """
-        global EXPIRATION_COUNTER

         prefix = self._ensure_bytes(prefix)

+        latest_timestamp = max(
+            self._get_timestamp(
+                prefix=prefix, cache=self._latest_timestamps, default=0
+            ),
+            timestamp,
+        )
+
         # Negative retention is not allowed
-        lower_bound_timestamp = max(timestamp - retention_ms, 0)
-        lower_bound = self._serialize_key(lower_bound_timestamp, prefix)
+        lower_bound = self._serialize_key(
+            max(latest_timestamp - retention_ms, 0), prefix
+        )
         # +1 because upper bound is exclusive
         upper_bound = self._serialize_key(timestamp + 1, prefix)
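
To make the new bound computation concrete, a small worked example with made-up values; only the arithmetic shown in the hunk above is assumed:

# Illustrative arithmetic only; all values are made up.
DAYS_7 = 7 * 24 * 60 * 60 * 1000            # 604_800_000 ms, as defined in this module

stored_latest = 1_700_000_000_000           # per-prefix latest timestamp from the cache/CF
incoming = 1_699_999_000_000                # the timestamp passed to get_last()

# The effective "latest" never moves backwards for out-of-order lookups:
latest_timestamp = max(stored_latest, incoming)        # 1_700_000_000_000

# Read window used by get_last():
lower_bound_ts = max(latest_timestamp - DAYS_7, 0)     # 1_699_395_200_000
upper_bound_ts = incoming + 1                          # exclusive upper bound

assert lower_bound_ts == 1_699_395_200_000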

@@ -101,15 +133,18 @@ def get_last(
             # iterating backwards from the upper bound.
             break

-        if not EXPIRATION_COUNTER % 1000:
-            self._expire(lower_bound_timestamp, prefix, cf_name=cf_name)
-        EXPIRATION_COUNTER += 1
-
         return self._deserialize_value(value) if value is not None else None

     @validate_transaction_status(PartitionTransactionStatus.STARTED)
-    def set(self, timestamp: int, value: Any, prefix: Any, cf_name: str = "default"):
-        """Set a value associated with a prefix and timestamp.
+    def set_for_timestamp(
+        self,
+        timestamp: int,
+        value: Any,
+        prefix: Any,
+        retention_ms: int = DAYS_7,
+        cf_name: str = "default",
+    ):
+        """Set a value for the timestamp.

         This method acts as a proxy, passing the provided `timestamp` and `prefix`
         to the parent `set` method. The parent method internally serializes these
@@ -122,8 +157,16 @@ def set(self, timestamp: int, value: Any, prefix: Any, cf_name: str = "default")
         """
         prefix = self._ensure_bytes(prefix)
         super().set(timestamp, value, prefix, cf_name=cf_name)
+        self._expire(
+            timestamp=timestamp,
+            prefix=prefix,
+            retention_ms=retention_ms,
+            cf_name=cf_name,
+        )

-    def _expire(self, timestamp: int, prefix: bytes, cf_name: str = "default"):
+    def _expire(
+        self, timestamp: int, prefix: bytes, retention_ms: int, cf_name: str = "default"
+    ):
         """
         Delete all entries for a given prefix with timestamps less than the
         provided timestamp.
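
A minimal usage sketch of the renamed method, assuming `tx` is an already-open `TimestampedPartitionTransaction`; the surrounding store and partition setup is outside this diff, and the prefix, values, and keyword usage of `get_last` are illustrative:

def write_and_read(tx) -> None:
    # `tx` is assumed to be an open TimestampedPartitionTransaction.
    # Write a value keyed by its event timestamp; entries older than the
    # retention window are expired as part of the same call.
    tx.set_for_timestamp(
        timestamp=1_700_000_000_000,       # event time in ms (made-up value)
        value={"count": 3},
        prefix=b"sensor-1",                # hypothetical key prefix
        retention_ms=24 * 60 * 60 * 1000,  # override the 7-day default with 1 day
    )
    # Fetch the newest value at or before a given timestamp:
    latest = tx.get_last(timestamp=1_700_000_000_500, prefix=b"sensor-1")
    print(latest)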
@@ -136,11 +179,23 @@ def _expire(self, timestamp: int, prefix: bytes, cf_name: str = "default"):
         :param prefix: The key prefix.
         :param cf_name: Column family name.
         """
-        key = self._serialize_key(timestamp, prefix)

-        cached = self._update_cache.get_updates_for_prefix(
+        latest_timestamp = max(
+            self._get_timestamp(
+                prefix=prefix, cache=self._latest_timestamps, default=0
+            ),
+            timestamp,
+        )
+        self._set_timestamp(
+            cache=self._latest_timestamps,
             prefix=prefix,
-            cf_name=cf_name,
+            timestamp_ms=latest_timestamp,
+        )
+
+        key = self._serialize_key(max(timestamp - retention_ms, 0), prefix)
+
+        cached = self._update_cache.get_updates_for_prefix(
+            prefix=prefix, cf_name=cf_name
         )
         # Cast to list to avoid RuntimeError: dictionary changed size during iteration
         for cached_key in list(cached):
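
The stored per-prefix latest timestamp is kept monotonic: a late, out-of-order write still triggers expiration, but it never pulls the high-water mark backwards. In plain numbers, derived only from the hunk above (values are made up):

# Illustrative only: how the high-water mark behaves on an out-of-order write.
stored_latest = 1_700_000_000_000    # value currently cached / persisted for the prefix
late_write_ts = 1_699_999_000_000    # out-of-order timestamp passed to _expire()

# The persisted latest timestamp never decreases:
new_latest = max(stored_latest, late_write_ts)   # still 1_700_000_000_000

# The deletion cutoff comes from the incoming timestamp; max(..., 0) keeps it
# from going negative when the timestamp is smaller than the retention window.
retention_ms = 7 * 24 * 60 * 60 * 1000
cutoff = max(late_write_ts - retention_ms, 0)    # 1_699_394_200_000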
@@ -160,8 +215,42 @@ def _ensure_bytes(self, prefix: Any) -> bytes:
             return prefix
         return serialize(prefix, dumps=self._dumps)

-    def _serialize_key(self, timestamp: int, prefix: bytes) -> bytes:
-        return prefix + SEPARATOR + int_to_int64_bytes(timestamp)
+    def _serialize_key(self, key: Union[int, bytes], prefix: bytes) -> bytes:
+        match key:
+            case int():
+                return prefix + SEPARATOR + int_to_int64_bytes(key)
+            case bytes():
+                return prefix + SEPARATOR + key
+            case _:
+                raise ValueError(f"Invalid key type: {type(key)}")
+
+    def _get_timestamp(
+        self, cache: TimestampsCache, prefix: bytes, default: Any = None
+    ) -> Any:
+        cached_ts = cache.timestamps.get(prefix)
+        if cached_ts is not None:
+            return cached_ts
+
+        stored_ts = self.get(
+            key=cache.key,
+            prefix=prefix,
+            cf_name=cache.cf_name,
+            default=default,
+        )
+        if stored_ts is not None and not isinstance(stored_ts, int):
+            raise ValueError(f"invalid timestamp {stored_ts}")
+
+        cache.timestamps[prefix] = stored_ts or default
+        return stored_ts
+
+    def _set_timestamp(self, cache: TimestampsCache, prefix: bytes, timestamp_ms: int):
+        cache.timestamps[prefix] = timestamp_ms
+        self.set(
+            key=cache.key,
+            value=timestamp_ms,
+            prefix=prefix,
+            cf_name=cache.cf_name,
+        )


 class TimestampedStorePartition(RocksDBStorePartition):
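
A quick sketch of what `_serialize_key` produces, assuming `SEPARATOR` is a single delimiter byte and `int_to_int64_bytes` is a fixed-width big-endian encoding; both stand-ins below are assumptions about helpers defined elsewhere in quixstreams:

# Sketch with assumed stand-ins for SEPARATOR and int_to_int64_bytes.
import struct

SEPARATOR = b"|"                                  # assumed delimiter byte

def int_to_int64_bytes(value: int) -> bytes:
    # assumed: fixed-width big-endian encoding, so byte order matches numeric order
    return struct.pack(">q", value)

prefix = b"sensor-1"

int_key = prefix + SEPARATOR + int_to_int64_bytes(1_700_000_000_000)   # int branch
raw_key = prefix + SEPARATOR + b"__latest_timestamp__"                 # bytes branch, e.g. cache.key

# Fixed-width big-endian timestamps keep RocksDB's lexicographic ordering
# aligned with numeric ordering, which the range scans in get_last() rely on.
assert int_to_int64_bytes(1) < int_to_int64_bytes(2)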
@@ -174,6 +263,15 @@ class TimestampedStorePartition(RocksDBStorePartition):

     partition_transaction_class = TimestampedPartitionTransaction

+    def __init__(
+        self,
+        path: str,
+        options: Optional[RocksDBOptionsType] = None,
+        changelog_producer: Optional[ChangelogProducer] = None,
+    ) -> None:
+        super().__init__(path, options=options, changelog_producer=changelog_producer)
+        self._ensure_column_family(LATEST_TIMESTAMPS_CF_NAME)
+

 class TimestampedStore(RocksDBStore):
     """

quixstreams/state/rocksdb/windowed/partition.py

-7

@@ -1,7 +1,6 @@
 import logging
 from typing import Iterator, Optional, cast

-from quixstreams.state.exceptions import ColumnFamilyDoesNotExist
 from quixstreams.state.recovery import ChangelogProducer

 from ..partition import RocksDBStorePartition
@@ -60,9 +59,3 @@ def iter_keys(self, cf_name: str = "default") -> Iterator[bytes]:
         """
         cf_dict = self.get_column_family(cf_name)
         return cast(Iterator[bytes], cf_dict.keys())
-
-    def _ensure_column_family(self, cf_name: str):
-        try:
-            self.get_column_family(cf_name)
-        except ColumnFamilyDoesNotExist:
-            self.create_column_family(cf_name)
