Skip to content

Commit 85559da

Browse files
committed
Replace cache-level iter-items with get_updates_for_prefix
1 parent 21e69ed commit 85559da

File tree

3 files changed

+49
-47
lines changed

3 files changed

+49
-47
lines changed

quixstreams/state/base/transaction.py

+16-18
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66
from typing import (
77
TYPE_CHECKING,
88
Any,
9-
Dict,
109
Generic,
1110
Optional,
1211
Set,
@@ -97,21 +96,6 @@ def get(
9796
# UNDEFINED to signify that
9897
return self._updated[cf_name][prefix].get(key, Marker.UNDEFINED)
9998

100-
def iter_items(
101-
self,
102-
prefix: bytes,
103-
backwards: bool = False,
104-
cf_name: str = "default",
105-
) -> list[tuple[bytes, bytes]]:
106-
"""
107-
Iterate over sorted items in the cache
108-
for the given prefix and column family.
109-
"""
110-
return sorted(
111-
self._updated[cf_name][prefix].items(),
112-
reverse=backwards,
113-
)
114-
11599
def set(self, key: bytes, value: bytes, prefix: bytes, cf_name: str = "default"):
116100
"""
117101
Set a value for the key.
@@ -150,15 +134,29 @@ def get_column_families(self) -> Set[str]:
150134
"""
151135
return set(self._updated.keys()) | set(self._deleted.keys())
152136

153-
def get_updates(self, cf_name: str = "default") -> Dict[bytes, Dict[bytes, bytes]]:
137+
def get_updates(self, cf_name: str = "default") -> dict[bytes, dict[bytes, bytes]]:
154138
"""
155139
Get all updated keys (excluding deleted)
156-
in the format "{<prefix>: {<key>: <value>}}".
140+
in the format "{<prefix>: {<key>: <value>, ...}, ...}".
157141
158142
:param: cf_name: column family name
159143
"""
160144
return self._updated.get(cf_name, {})
161145

146+
def get_updates_for_prefix(
147+
self,
148+
prefix: bytes,
149+
cf_name: str = "default",
150+
) -> dict[bytes, bytes]:
151+
"""
152+
Get all updated keys (excluding deleted)
153+
in the format "{<key>: <value>, ...}".
154+
155+
:param: prefix: key prefix
156+
:param: cf_name: column family name
157+
"""
158+
return self._updated.get(cf_name, {}).get(prefix, {})
159+
162160
def get_deletes(self, cf_name: str = "default") -> Set[bytes]:
163161
"""
164162
Get all deleted keys (excluding updated) as a set.

quixstreams/state/rocksdb/timestamped.py

+10-7
Original file line numberDiff line numberDiff line change
@@ -61,12 +61,11 @@ def get_last(
6161
value: Optional[bytes] = None
6262

6363
deletes = self._update_cache.get_deletes(cf_name=cf_name)
64-
65-
cached = self._update_cache.iter_items(
66-
prefix=prefix,
67-
backwards=True,
68-
cf_name=cf_name,
64+
updates = self._update_cache.get_updates_for_prefix(
65+
cf_name=cf_name, prefix=prefix
6966
)
67+
68+
cached = sorted(updates.items(), reverse=True)
7069
for cached_key, cached_value in cached:
7170
if prefix < cached_key < key and cached_key not in deletes:
7271
value = cached_value
@@ -110,8 +109,12 @@ def set(self, timestamp: int, value: Any, prefix: Any, cf_name: str = "default")
110109
def expire(self, timestamp: int, prefix: bytes, cf_name: str = "default"):
111110
key = self._serialize_key(timestamp + 1, prefix)
112111

113-
cached = self._update_cache.iter_items(prefix=prefix, cf_name=cf_name)
114-
for cached_key, _ in cached:
112+
cached = self._update_cache.get_updates_for_prefix(
113+
prefix=prefix,
114+
cf_name=cf_name,
115+
)
116+
# Cast to list to avoid RuntimeError: dictionary changed size during iteration
117+
for cached_key in list(cached):
115118
if cached_key < key:
116119
self._update_cache.delete(cached_key, prefix, cf_name=cf_name)
117120

tests/test_quixstreams/test_state/test_transaction.py

+23-22
Original file line numberDiff line numberDiff line change
@@ -534,6 +534,29 @@ def test_get_updates_after_delete(self, cache: PartitionTransactionCache):
534534
cache.delete(key=b"key", prefix=b"prefix", cf_name="cf_name")
535535
assert cache.get_updates(cf_name="cf_name") == {b"prefix": {}}
536536

537+
def test_get_updates_for_prefix_empty(self, cache: PartitionTransactionCache):
538+
assert cache.get_updates_for_prefix(prefix=b"prefix", cf_name="cf_name") == {}
539+
540+
# Delete an item and make sure it's not in "updates"
541+
cache.delete(key=b"key", prefix=b"prefix", cf_name="cf_name")
542+
assert cache.get_updates_for_prefix(prefix=b"prefix", cf_name="cf_name") == {}
543+
544+
def test_get_updates_for_prefix_present(self, cache: PartitionTransactionCache):
545+
cache.set(key=b"key", value=b"value", prefix=b"prefix", cf_name="cf_name")
546+
cache.set(key=b"key", value=b"value", prefix=b"other_prefix", cf_name="cf_name")
547+
cache.set(key=b"key", value=b"value", prefix=b"other", cf_name="other_cf_name")
548+
549+
assert cache.get_updates_for_prefix(prefix=b"prefix", cf_name="cf_name") == {
550+
b"key": b"value"
551+
}
552+
553+
def test_get_updates_for_prefix_after_delete(
554+
self, cache: PartitionTransactionCache
555+
):
556+
cache.set(key=b"key", value=b"value", prefix=b"prefix", cf_name="cf_name")
557+
cache.delete(key=b"key", prefix=b"prefix", cf_name="cf_name")
558+
assert cache.get_updates_for_prefix(prefix=b"prefix", cf_name="cf_name") == {}
559+
537560
def test_get_deletes_empty(self, cache: PartitionTransactionCache):
538561
assert cache.get_deletes(cf_name="cf_name") == set()
539562

@@ -567,25 +590,3 @@ def test_get_deletes_after_set(self, cache: PartitionTransactionCache):
567590
def test_empty(self, action, expected, cache):
568591
action(cache)
569592
assert cache.is_empty() == expected
570-
571-
def test_iter_items(self, cache: PartitionTransactionCache):
572-
cache.set(key=b"key1", value=b"value1", prefix=b"prefix")
573-
cache.set(key=b"key2", value=b"value2", prefix=b"prefix")
574-
cache.set(key=b"key3", value=b"value3", prefix=b"prefix")
575-
cache.set(key=b"key4", value=b"value4", prefix=b"prefix", cf_name="other")
576-
577-
assert cache.iter_items(prefix=b"prefix") == [
578-
(b"key1", b"value1"),
579-
(b"key2", b"value2"),
580-
(b"key3", b"value3"),
581-
]
582-
583-
assert cache.iter_items(prefix=b"prefix", backwards=True) == [
584-
(b"key3", b"value3"),
585-
(b"key2", b"value2"),
586-
(b"key1", b"value1"),
587-
]
588-
589-
assert cache.iter_items(prefix=b"prefix", cf_name="other") == [
590-
(b"key4", b"value4"),
591-
]

0 commit comments

Comments
 (0)