Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ for channel in rocket.channels_list():
rocket.chat_post_message('good news everyone!', channel='GENERAL', alias='Farnsworth')

# Get channel history
rocket.channels_history('GENERAL', count=5)
rocket.channels_history('GENERAL', max_count=5)
```

### Token-Based Authentication
Expand Down
81 changes: 56 additions & 25 deletions rocketchat_API/APISections/base.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,8 @@
import itertools
import re

from functools import wraps

from json import JSONDecodeError
from typing import Any

from typing import Any, Callable, Generator

import requests

Expand All @@ -16,7 +14,40 @@
)


def paginated(data_key):
def _paginated_generator(
self,
func: Callable[..., dict[str, Any]],
data_key: str,
first_data: dict[str, Any],
offset: int,
count: int,
args: tuple[Any, ...],
kwargs: dict[str, Any],
) -> Generator[dict[str, Any], None, None]:
"""Inner generator that yields items from paginated API responses."""
data = first_data
while True:
items = data.get(data_key, [])
if not items:
break

yield from items

# If we got fewer items than requested, we've reached the end
if len(items) < count:
break

offset += count
# Call the original function with pagination parameters
data = func(self, *args, offset=offset, count=count, **kwargs)


def paginated(
data_key: str,
) -> Callable[
[Callable[..., dict[str, Any]]],
Callable[..., Generator[dict[str, Any], None, None]],
]:
"""
Decorator that converts a paginated API method into an iterator.

Expand All @@ -28,41 +59,41 @@ def paginated(data_key):
A decorator that wraps the original method to yield items one by one,
automatically handling pagination with offset and count parameters.

Kwargs (handled by the wrapper):
offset: Starting offset for pagination (default: 0)
count: Number of items per page (default: 50)
max_count: Maximum total number of items to return (default: None, returns all)

Example:
@paginated('groups')
def groups_list_all(self, **kwargs):
return self.call_api_get("groups.listAll", kwargs=kwargs)
"""

def decorator(func):
def _generator(self, first_data, offset, count, args, kwargs):
"""Inner generator that yields items from paginated API responses."""
data = first_data
while True:
items = data.get(data_key, [])
if not items:
break

for item in items:
yield item
# Get all groups
list(rocket.groups_list_all())

# If we got fewer items than requested, we've reached the end
if len(items) < count:
break

offset += count
# Call the original function with pagination parameters
data = func(self, *args, offset=offset, count=count, **kwargs)
# Get at most 100 groups
list(rocket.groups_list_all(max_count=100))
"""

def decorator(func):
@wraps(func)
def wrapper(self, *args, **kwargs):
offset = kwargs.pop("offset", 0)
count = kwargs.pop("count", 50)
max_count = kwargs.pop("max_count", None)

# Call the original function eagerly to propagate any exceptions
first_data = func(self, *args, offset=offset, count=count, **kwargs)

return _generator(self, first_data, offset, count, args, kwargs)
items_gen = _paginated_generator(
self, func, data_key, first_data, offset, count, args, kwargs
)

if max_count is not None:
return itertools.islice(items_gen, max_count)

return items_gen

return wrapper

Expand Down
17 changes: 10 additions & 7 deletions tests/test_channels.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,8 +52,8 @@ def test_channels_list(logged_rocket):
assert "_id" in channel
assert "name" in channel

iterated_channels_custom = list(logged_rocket.channels_list(count=1))
assert len(iterated_channels_custom) > 0
iterated_channels_custom = list(logged_rocket.channels_list(max_count=1))
assert len(iterated_channels_custom) == 1

for channel in logged_rocket.channels_list():
assert "_id" in channel
Expand All @@ -67,8 +67,8 @@ def test_channels_list_joined(logged_rocket):
assert "_id" in channel
assert "name" in channel

iterated_channels_custom = list(logged_rocket.channels_list_joined(count=1))
assert len(iterated_channels_custom) > 0
iterated_channels_custom = list(logged_rocket.channels_list_joined(max_count=1))
assert len(iterated_channels_custom) == 1

for channel in logged_rocket.channels_list_joined():
assert "_id" in channel
Expand Down Expand Up @@ -96,8 +96,11 @@ def test_channels_history(logged_rocket):

# Test with custom count parameter
iterated_messages_custom = list(
logged_rocket.channels_history(room_id="GENERAL", count=1)
logged_rocket.channels_history(room_id="GENERAL", max_count=1)
)

assert len(iterated_messages_custom) == 1

for message in iterated_messages_custom:
assert "_id" in message

Expand Down Expand Up @@ -384,9 +387,9 @@ def test_channels_members(logged_rocket):

# Test with custom count parameter
iterated_members_custom = list(
logged_rocket.channels_members(room_id="GENERAL", count=1)
logged_rocket.channels_members(room_id="GENERAL", max_count=1)
)
assert len(iterated_members_custom) > 0
assert len(iterated_members_custom) == 1

for member in logged_rocket.channels_members(room_id="GENERAL"):
assert "_id" in member
Expand Down
8 changes: 4 additions & 4 deletions tests/test_groups.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,8 +56,8 @@ def test_groups_list_all(logged_rocket):
assert "_id" in group
assert "name" in group

iterated_groups_custom = list(logged_rocket.groups_list_all(count=1))
assert len(iterated_groups_custom) > 0
iterated_groups_custom = list(logged_rocket.groups_list_all(max_count=1))
assert len(iterated_groups_custom) == 1

for group in logged_rocket.groups_list_all():
assert "_id" in group
Expand Down Expand Up @@ -363,9 +363,9 @@ def test_groups_members(logged_rocket, test_group_name, test_group_id):

# Test with custom count parameter
iterated_members_custom = list(
logged_rocket.groups_members(room_id=test_group_id, count=1)
logged_rocket.groups_members(room_id=test_group_id, max_count=1)
)
assert len(iterated_members_custom) > 0
assert len(iterated_members_custom) == 1

with pytest.raises(RocketMissingParamException):
logged_rocket.groups_members()
Expand Down
135 changes: 135 additions & 0 deletions tests/test_paginated.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,135 @@
from typing import Any

from rocketchat_API.APISections.base import paginated


class MockAPI:
def __init__(self, total_items: int) -> None:
self.total_items = total_items
self.call_count = 0

@paginated("items")
def get_items(self, **kwargs: Any) -> Any:
self.call_count += 1
offset = kwargs.get("offset", 0)
count = kwargs.get("count", 50)
start = offset
end = min(offset + count, self.total_items)
items = [{"id": i, "name": f"item_{i}"} for i in range(start, end)]
return {"items": items, "success": True}


def test_basic_pagination():
api = MockAPI(total_items=150)
result = list(api.get_items(count=50))

assert len(result) == 150
assert api.call_count == 4


def test_max_count_less_than_total():
api = MockAPI(total_items=150)
result: list[dict[str, Any]] = list(api.get_items(max_count=100))

assert len(result) == 100
first_item: dict[str, Any] = result[0]
last_item: dict[str, Any] = result[99]
assert first_item.get("id") == 0
assert last_item.get("id") == 99


def test_max_count_greater_than_total():
api = MockAPI(total_items=50)
result = list(api.get_items(max_count=100))

assert len(result) == 50


def test_max_count_exact_page_boundary():
api = MockAPI(total_items=150)
result = list(api.get_items(count=50, max_count=100))

assert len(result) == 100
assert api.call_count == 2


def test_max_count_mid_page():
api = MockAPI(total_items=150)
result = list(api.get_items(count=50, max_count=75))

assert len(result) == 75
assert api.call_count == 2


def test_max_count_one():
api = MockAPI(total_items=150)
result: list[dict[str, Any]] = list(api.get_items(max_count=1))

assert len(result) == 1
first_item: dict[str, Any] = result[0]
assert first_item.get("id") == 0
assert api.call_count == 1


def test_max_count_zero_returns_empty():
api = MockAPI(total_items=150)
result = list(api.get_items(max_count=0))

assert len(result) == 0


def test_max_count_none_returns_all():
api = MockAPI(total_items=75)
result = list(api.get_items(count=50))

assert len(result) == 75
assert api.call_count == 2


def test_offset_with_max_count():
api = MockAPI(total_items=150)
result: list[dict[str, Any]] = list(api.get_items(offset=50, max_count=50))

assert len(result) == 50
first_item: dict[str, Any] = result[0]
last_item: dict[str, Any] = result[49]
assert first_item.get("id") == 50
assert last_item.get("id") == 99


def test_custom_count_with_max_count():
api = MockAPI(total_items=100)
result = list(api.get_items(count=10, max_count=25))

assert len(result) == 25
assert api.call_count == 3


def test_generator_behavior():
api = MockAPI(total_items=100)
gen = api.get_items(max_count=10)

assert hasattr(gen, "__iter__")
assert hasattr(gen, "__next__")

items = []
for item in gen:
items.append(item)

assert len(items) == 10


def test_empty_response():
api = MockAPI(total_items=0)
result = list(api.get_items())

assert len(result) == 0
assert api.call_count == 1


def test_empty_response_with_max_count():
api = MockAPI(total_items=0)
result = list(api.get_items(max_count=100))

assert len(result) == 0
assert api.call_count == 1
4 changes: 2 additions & 2 deletions tests/test_rooms.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,8 +71,8 @@ def test_rooms_admin_rooms(logged_rocket):
assert "t" in room

# Test with custom count parameter
iterated_rooms_custom = list(logged_rocket.rooms_admin_rooms(count=1))
assert len(iterated_rooms_custom) > 0
iterated_rooms_custom = list(logged_rocket.rooms_admin_rooms(max_count=1))
assert len(iterated_rooms_custom) == 1

rooms_with_filter = list(logged_rocket.rooms_admin_rooms(filter="general"))
assert len(rooms_with_filter) == 1
Expand Down
7 changes: 0 additions & 7 deletions tests/test_user.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,13 +174,6 @@ def test_users_list(logged_rocket):
assert "_id" in user
assert "username" in user

iterated_users_custom = list(logged_rocket.users_list(count=1))
assert len(iterated_users_custom) > 0
assert len(iterated_users_custom) == len(iterated_users)

for user in logged_rocket.users_list():
assert "_id" in user


def test_users_set_status(logged_rocket):
logged_rocket.users_set_status(message="working on it", status="online")
Loading