Commit f0756a5

Added stream broker, updated redis.
1 parent df015e5 commit f0756a5
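
For orientation, a minimal usage sketch of the stream broker this commit adds. It is a sketch, not part of the commit: it assumes a local Redis at redis://localhost:6379, and the BrokerMessage fields are illustrative; the constructor arguments and the kick/listen/ack flow mirror what test_stream_broker in tests/test_broker.py below exercises.

import asyncio
import uuid

from taskiq.message import BrokerMessage

from taskiq_redis.redis_broker import RedisStreamBroker


async def main() -> None:
    # Broker construction mirrors test_stream_broker below.
    broker = RedisStreamBroker(
        url="redis://localhost:6379",  # assumed local Redis instance
        queue_name=uuid.uuid4().hex,
        consumer_group_name=uuid.uuid4().hex,
    )
    await broker.startup()

    # Publish one raw message, then read it back from the stream and ack it.
    await broker.kick(
        BrokerMessage(task_id="1", task_name="demo", message=b"hello", labels={}),
    )
    async for received in broker.listen():
        print(received.data)  # b"hello"
        await received.ack()  # type: ignore
        break

    await broker.shutdown()


asyncio.run(main())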

7 files changed: +79 -51 lines changed

pyproject.toml

Lines changed: 1 addition & 0 deletions
@@ -55,6 +55,7 @@ warn_return_any = false
 [[tool.mypy.overrides]]
 module = ['redis']
 ignore_missing_imports = true
+ignore_errors = true
 strict = false

 [build-system]

taskiq_redis/redis_backend.py

Lines changed: 13 additions & 21 deletions
@@ -34,8 +34,8 @@
 from typing_extensions import TypeAlias

 if TYPE_CHECKING:
-    _Redis: TypeAlias = Redis[bytes]
-    _BlockingConnectionPool: TypeAlias = BlockingConnectionPool[Connection]
+    _Redis: TypeAlias = Redis[bytes]  # type: ignore
+    _BlockingConnectionPool: TypeAlias = BlockingConnectionPool[Connection]  # type: ignore
 else:
     _Redis: TypeAlias = Redis
     _BlockingConnectionPool: TypeAlias = BlockingConnectionPool
@@ -258,7 +258,7 @@ def __init__(
         :raises ExpireTimeMustBeMoreThanZeroError: if result_ex_time
             and result_px_time are equal zero.
         """
-        self.redis: RedisCluster[bytes] = RedisCluster.from_url(
+        self.redis: "RedisCluster" = RedisCluster.from_url(
             redis_url,
             **connection_kwargs,
         )
@@ -275,14 +275,10 @@ def __init__(
             ),
         )
         if unavailable_conditions:
-            raise ExpireTimeMustBeMoreThanZeroError(
-                "You must select one expire time param and it must be more than zero.",
-            )
+            raise ExpireTimeMustBeMoreThanZeroError

         if self.result_ex_time and self.result_px_time:
-            raise DuplicateExpireTimeSelectedError(
-                "Choose either result_ex_time or result_px_time.",
-            )
+            raise DuplicateExpireTimeSelectedError

     def _task_name(self, task_id: str) -> str:
         if self.prefix_str is None:
@@ -291,7 +287,7 @@ def _task_name(self, task_id: str) -> str:

     async def shutdown(self) -> None:
         """Closes redis connection."""
-        await self.redis.aclose()  # type: ignore[attr-defined]
+        await self.redis.aclose()
         await super().shutdown()

     async def set_result(
@@ -327,7 +323,7 @@ async def is_result_ready(self, task_id: str) -> bool:

         :returns: True if the result is ready else False.
         """
-        return bool(await self.redis.exists(self._task_name(task_id)))  # type: ignore[attr-defined]
+        return bool(await self.redis.exists(self._task_name(task_id)))

     async def get_result(
         self,
@@ -344,11 +340,11 @@ async def get_result(
         """
         task_name = self._task_name(task_id)
         if self.keep_results:
-            result_value = await self.redis.get(  # type: ignore[attr-defined]
+            result_value = await self.redis.get(
                 name=task_name,
             )
         else:
-            result_value = await self.redis.getdel(  # type: ignore[attr-defined]
+            result_value = await self.redis.getdel(
                 name=task_name,
             )

@@ -400,7 +396,7 @@ async def get_progress(
         :param task_id: task's id.
         :return: task's TaskProgress instance.
         """
-        result_value = await self.redis.get(  # type: ignore[attr-defined]
+        result_value = await self.redis.get(
             name=self._task_name(task_id) + PROGRESS_KEY_SUFFIX,
         )

@@ -465,14 +461,10 @@ def __init__(
             ),
         )
         if unavailable_conditions:
-            raise ExpireTimeMustBeMoreThanZeroError(
-                "You must select one expire time param and it must be more than zero.",
-            )
+            raise ExpireTimeMustBeMoreThanZeroError

         if self.result_ex_time and self.result_px_time:
-            raise DuplicateExpireTimeSelectedError(
-                "Choose either result_ex_time or result_px_time.",
-            )
+            raise DuplicateExpireTimeSelectedError

     def _task_name(self, task_id: str) -> str:
         if self.prefix_str is None:
@@ -610,4 +602,4 @@ async def get_progress(
     async def shutdown(self) -> None:
         """Shutdown sentinel connections."""
         for sentinel in self.sentinel.sentinels:
-            await sentinel.aclose()  # type: ignore[attr-defined]
+            await sentinel.aclose()
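
The recurring change in this and the following files is the trailing "# type: ignore" added to the TYPE_CHECKING-only aliases. A self-contained sketch of that pattern, assuming the redis.asyncio import path used by the package (not shown in this diff), illustrates why there are two branches:

from typing import TYPE_CHECKING

from redis.asyncio import BlockingConnectionPool, Connection, Redis
from typing_extensions import TypeAlias

if TYPE_CHECKING:
    # Only evaluated by the type checker: the parametrised aliases give precise
    # types, and the trailing ignores added in this commit keep mypy quiet here
    # alongside the looser override for the redis module in pyproject.toml.
    _Redis: TypeAlias = Redis[bytes]  # type: ignore
    _BlockingConnectionPool: TypeAlias = BlockingConnectionPool[Connection]  # type: ignore
else:
    # At runtime the plain classes are aliased, so the redis-py classes are
    # never subscripted outside of type checking.
    _Redis: TypeAlias = Redis
    _BlockingConnectionPool: TypeAlias = BlockingConnectionPool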

taskiq_redis/redis_broker.py

Lines changed: 4 additions & 4 deletions
@@ -28,7 +28,7 @@
 from typing_extensions import TypeAlias

 if TYPE_CHECKING:
-    _BlockingConnectionPool: TypeAlias = BlockingConnectionPool[Connection]
+    _BlockingConnectionPool: TypeAlias = BlockingConnectionPool[Connection]  # type: ignore
 else:
     _BlockingConnectionPool: TypeAlias = BlockingConnectionPool

@@ -122,7 +122,7 @@ async def kick(self, message: BrokerMessage) -> None:
         """
         queue_name = message.labels.get("queue_name") or self.queue_name
         async with Redis(connection_pool=self.connection_pool) as redis_conn:
-            await redis_conn.lpush(queue_name, message.message)
+            await redis_conn.lpush(queue_name, message.message)  # type: ignore

     async def listen(self) -> AsyncGenerator[bytes, None]:
         """
@@ -137,7 +137,7 @@ async def listen(self) -> AsyncGenerator[bytes, None]:
         while True:
             try:
                 async with Redis(connection_pool=self.connection_pool) as redis_conn:
-                    yield (await redis_conn.brpop(self.queue_name))[
+                    yield (await redis_conn.brpop(self.queue_name))[  # type: ignore
                         redis_brpop_data_position
                     ]
             except ConnectionError as exc:
@@ -238,7 +238,7 @@ async def listen(self) -> AsyncGenerator[AckableMessage, None]:
                 self.consumer_name,
                 {
                     self.queue_name: ">",
-                    **self.additional_streams,
+                    **self.additional_streams,  # type: ignore
                 },
                 block=self.block,
                 noack=False,

taskiq_redis/redis_cluster_broker.py

Lines changed: 5 additions & 5 deletions
@@ -30,7 +30,7 @@ def __init__(
         """
         super().__init__()

-        self.redis: RedisCluster[bytes] = RedisCluster.from_url(
+        self.redis: "RedisCluster[bytes]" = RedisCluster.from_url(  # type: ignore
             url=url,
             max_connections=max_connection_pool_size,
             **connection_kwargs,
@@ -40,7 +40,7 @@ def __init__(

     async def shutdown(self) -> None:
         """Closes redis connection pool."""
-        await self.redis.aclose()  # type: ignore[attr-defined]
+        await self.redis.aclose()
         await super().shutdown()


@@ -55,7 +55,7 @@ async def kick(self, message: BrokerMessage) -> None:

         :param message: message to append.
         """
-        await self.redis.lpush(self.queue_name, message.message)  # type: ignore[attr-defined]
+        await self.redis.lpush(self.queue_name, message.message)  # type: ignore

     async def listen(self) -> AsyncGenerator[bytes, None]:
         """
@@ -68,7 +68,7 @@ async def listen(self) -> AsyncGenerator[bytes, None]:
         """
         redis_brpop_data_position = 1
         while True:
-            value = await self.redis.brpop([self.queue_name])  # type: ignore[attr-defined]
+            value = await self.redis.brpop([self.queue_name])  # type: ignore
             yield value[redis_brpop_data_position]


@@ -155,7 +155,7 @@ async def listen(self) -> AsyncGenerator[AckableMessage, None]:
                 self.consumer_name,
                 {
                     self.queue_name: ">",
-                    **self.additional_streams,
+                    **self.additional_streams,  # type: ignore
                 },
                 block=self.block,
                 noack=False,

taskiq_redis/redis_sentinel_broker.py

Lines changed: 4 additions & 4 deletions
@@ -27,7 +27,7 @@
 from typing_extensions import TypeAlias

 if TYPE_CHECKING:
-    _Redis: TypeAlias = Redis[bytes]
+    _Redis: TypeAlias = Redis[bytes]  # type: ignore
 else:
     _Redis: TypeAlias = Redis

@@ -117,7 +117,7 @@ async def kick(self, message: BrokerMessage) -> None:
         """
         queue_name = message.labels.get("queue_name") or self.queue_name
         async with self._acquire_master_conn() as redis_conn:
-            await redis_conn.lpush(queue_name, message.message)
+            await redis_conn.lpush(queue_name, message.message)  # type: ignore

     async def listen(self) -> AsyncGenerator[bytes, None]:
         """
@@ -131,7 +131,7 @@ async def listen(self) -> AsyncGenerator[bytes, None]:
         redis_brpop_data_position = 1
         async with self._acquire_master_conn() as redis_conn:
             while True:
-                yield (await redis_conn.brpop(self.queue_name))[
+                yield (await redis_conn.brpop(self.queue_name))[  # type: ignore
                     redis_brpop_data_position
                 ]

@@ -226,7 +226,7 @@ async def listen(self) -> AsyncGenerator[AckableMessage, None]:
                 self.consumer_name,
                 {
                     self.queue_name: ">",
-                    **self.additional_streams,
+                    **self.additional_streams,  # type: ignore
                 },
                 block=self.block,
                 noack=False,

taskiq_redis/schedule_source.py

Lines changed: 9 additions & 9 deletions
@@ -21,8 +21,8 @@
 from typing_extensions import TypeAlias

 if TYPE_CHECKING:
-    _Redis: TypeAlias = Redis[bytes]
-    _BlockingConnectionPool: TypeAlias = BlockingConnectionPool[Connection]
+    _Redis: TypeAlias = Redis[bytes]  # type: ignore
+    _BlockingConnectionPool: TypeAlias = BlockingConnectionPool[Connection]  # type: ignore
 else:
     _Redis: TypeAlias = Redis
     _BlockingConnectionPool: TypeAlias = BlockingConnectionPool
@@ -140,7 +140,7 @@ def __init__(
         **connection_kwargs: Any,
     ) -> None:
         self.prefix = prefix
-        self.redis: RedisCluster[bytes] = RedisCluster.from_url(
+        self.redis: "RedisCluster" = RedisCluster.from_url(
            url,
            **connection_kwargs,
        )
@@ -150,7 +150,7 @@ def __init__(

     async def delete_schedule(self, schedule_id: str) -> None:
         """Remove schedule by id."""
-        await self.redis.delete(f"{self.prefix}:{schedule_id}")  # type: ignore[attr-defined]
+        await self.redis.delete(f"{self.prefix}:{schedule_id}")

     async def add_schedule(self, schedule: ScheduledTask) -> None:
         """
@@ -159,7 +159,7 @@ async def add_schedule(self, schedule: ScheduledTask) -> None:
         :param schedule: schedule to add.
         :param schedule_id: schedule id.
         """
-        await self.redis.set(  # type: ignore[attr-defined]
+        await self.redis.set(
             f"{self.prefix}:{schedule.schedule_id}",
             self.serializer.dumpb(model_dump(schedule)),
         )
@@ -173,8 +173,8 @@ async def get_schedules(self) -> List[ScheduledTask]:
         :return: list of schedules.
         """
         schedules = []
-        async for key in self.redis.scan_iter(f"{self.prefix}:*"):  # type: ignore[attr-defined]
-            raw_schedule = await self.redis.get(key)  # type: ignore[attr-defined]
+        async for key in self.redis.scan_iter(f"{self.prefix}:*"):
+            raw_schedule = await self.redis.get(key)
             parsed_schedule = model_validate(
                 ScheduledTask,
                 self.serializer.loadb(raw_schedule),
@@ -189,7 +189,7 @@ async def post_send(self, task: ScheduledTask) -> None:

     async def shutdown(self) -> None:
         """Shut down the schedule source."""
-        await self.redis.aclose()  # type: ignore[attr-defined]
+        await self.redis.aclose()


 class RedisSentinelScheduleSource(ScheduleSource):
@@ -288,4 +288,4 @@ async def post_send(self, task: ScheduledTask) -> None:
     async def shutdown(self) -> None:
         """Shut down the schedule source."""
         for sentinel in self.sentinel.sentinels:
-            await sentinel.aclose()  # type: ignore[attr-defined]
+            await sentinel.aclose()

tests/test_broker.py

Lines changed: 43 additions & 8 deletions
@@ -14,6 +14,7 @@
     RedisStreamClusterBroker,
     RedisStreamSentinelBroker,
 )
+from taskiq_redis.redis_broker import RedisStreamBroker


 def test_no_url_should_raise_typeerror() -> None:
@@ -129,6 +130,44 @@ async def test_list_queue_broker(
     await broker.shutdown()


+@pytest.mark.anyio
+async def test_stream_broker(
+    valid_broker_message: BrokerMessage,
+    redis_url: str,
+) -> None:
+    """
+    Test that messages are published and read correctly by ListQueueBroker.
+
+    We create two workers that listen and send a message to them.
+    Expect only one worker to receive the same message we sent.
+    """
+    broker = RedisStreamBroker(
+        url=redis_url,
+        queue_name=uuid.uuid4().hex,
+        consumer_group_name=uuid.uuid4().hex,
+    )
+    await broker.startup()
+
+    worker1_task = asyncio.create_task(get_message(broker))
+    worker2_task = asyncio.create_task(get_message(broker))
+
+    await broker.kick(valid_broker_message)
+
+    await asyncio.wait(
+        [worker1_task, worker2_task],
+        return_when=asyncio.FIRST_COMPLETED,
+    )
+
+    assert worker1_task.done() != worker2_task.done()
+    message = worker1_task.result() if worker1_task.done() else worker2_task.result()
+    assert isinstance(message, AckableMessage)
+    assert message.data == valid_broker_message.message
+    await message.ack()  # type: ignore
+    worker1_task.cancel()
+    worker2_task.cancel()
+    await broker.shutdown()
+
+
 @pytest.mark.anyio
 async def test_list_queue_broker_max_connections(
     valid_broker_message: BrokerMessage,
@@ -196,14 +235,13 @@ async def test_stream_cluster_broker(
         consumer_group_name=uuid.uuid4().hex,
     )
     await broker.startup()
+
     worker_task = asyncio.create_task(get_message(broker))
-    await asyncio.sleep(0.3)

     await broker.kick(valid_broker_message)
-    await asyncio.sleep(0.3)

-    assert worker_task.done()
-    result = worker_task.result()
+    result = await worker_task
+
     assert isinstance(result, AckableMessage)
     assert result.data == valid_broker_message.message
     await result.ack()  # type: ignore
@@ -291,13 +329,10 @@ async def test_streams_sentinel_broker(
     )
     await broker.startup()
     worker_task = asyncio.create_task(get_message(broker))
-    await asyncio.sleep(0.3)

    await broker.kick(valid_broker_message)
-    await asyncio.sleep(0.3)

-    assert worker_task.done()
-    result = worker_task.result()
+    result = await worker_task
     assert isinstance(result, AckableMessage)
     assert result.data == valid_broker_message.message
     await result.ack()  # type: ignore
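
These tests rely on a get_message helper that is not part of this diff. A hypothetical sketch (the import paths and the loose AsyncBroker typing are assumptions, not shown in the commit) would simply return the first message yielded by listen():

from typing import Union

from taskiq import AckableMessage
from taskiq.abc.broker import AsyncBroker


async def get_message(broker: AsyncBroker) -> Union[bytes, AckableMessage]:
    # Return the first message the broker yields: list-queue brokers yield raw
    # bytes, while the stream brokers yield AckableMessage objects.
    async for message in broker.listen():
        return message
    return b""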
