# Patch summary (commentary only; the patch payload below is unchanged).
#
# Purpose: add an optional cap on how many result pages `_Paginator` buffers
# in memory, exposed to callers as `execute_futures(..., max_in_memory_pages=N)`.
#
# aiocassandra.py
#   * `_Paginator.__init__` gains keyword-only `max_in_memory_pages=None` and
#     three new attributes: `_no_fetching_page` (asyncio.Event),
#     `_max_in_memory_pages`, and `_page_size` (initially None).
#   * New `_start_fetching_next_page()`: clears `_no_fetching_page`, submits
#     `cassandra_fut.start_fetching_next_page` through `loop.run_in_executor`,
#     and tracks the resulting future in `__pages` until it completes.
#   * New `_maybe_start_prefetch_next_page()`:
#       - returns if `_finish_event` is set or `_no_fetching_page` is not set;
#       - sets `_finish_event` and returns if `has_more_pages` is false;
#       - when a limit is configured, returns if
#         `len(_deque) > _page_size * (max_in_memory_pages - 1)`;
#       - otherwise calls `_start_fetching_next_page()`.
#   * `_handle_page()` now records `_page_size` from the first page it handles
#     and, after appending rows, schedules `_no_fetching_page.set`,
#     `_drain_event.set` and `_maybe_start_prefetch_next_page` via
#     `call_soon_threadsafe`. The previous unconditional "fetch next page or
#     set `_finish_event`" logic is removed.
#   * `_paginator()` calls `_maybe_start_prefetch_next_page()` before draining
#     `_deque` and again after the `await yield_(...)` statement.
#     NOTE(review): indentation is lost in this copy, so whether the second
#     call sits inside the `while self._deque:` body cannot be confirmed here.
#   * `execute_futures()` accepts keyword-only `max_in_memory_pages=None` and
#     forwards it to `_Paginator`.
#
# tests/test_aiocassandra.py
#   * New `test_execute_futures_simple_statement_limit_pages`: runs a
#     `LIMIT 50` query with `fetch_size=10` and `max_in_memory_pages=3`,
#     asserts `len(paginator._deque) == 30` after a 0.5 s sleep, `<= 30` while
#     iterating, and 50 rows in total.
#
# NOTE(review): `max_in_memory_pages` is not validated. With a value <= 0 and
#   a positive `_page_size`, the threshold is negative, so the limit check
#   always returns early and no further page is requested. Consider rejecting
#   values < 1; confirm intended behaviour.
# NOTE(review): the new test asserts an exact buffer length after a fixed
#   sleep and expects exactly 50 rows from `system.size_estimates`. Both look
#   timing/environment dependent; verify on CI.
# NOTE(review): this copy of the patch has had its line breaks and indentation
#   collapsed; it presumably needs to be regenerated from the repository
#   before it can be applied.
diff --git a/aiocassandra.py b/aiocassandra.py index bf53af9..190d17f 100644 --- a/aiocassandra.py +++ b/aiocassandra.py @@ -19,7 +19,7 @@ class _Paginator: - def __init__(self, request, *, executor, loop): + def __init__(self, request, *, executor, loop, max_in_memory_pages=None): self.cassandra_fut = None self._request = request @@ -30,10 +30,35 @@ def __init__(self, request, *, executor, loop): self._deque = deque() self._exc = None self._drain_event = asyncio.Event(loop=loop) + self._no_fetching_page = asyncio.Event(loop=loop) self._finish_event = asyncio.Event(loop=loop) self._exit_event = Event() self.__pages = set() + self._max_in_memory_pages = max_in_memory_pages + self._page_size = None + + def _start_fetching_next_page(self): + self._no_fetching_page.clear() + _fn = self.cassandra_fut.start_fetching_next_page + fut = self._loop.run_in_executor(self._executor, _fn) + self.__pages.add(fut) + fut.add_done_callback(self.__pages.remove) + + def _maybe_start_prefetch_next_page(self): + if self._finish_event.is_set() or not self._no_fetching_page.is_set(): + return + + if not self.cassandra_fut.has_more_pages: + self._finish_event.set() + return + + if self._max_in_memory_pages is None: + pass + elif len(self._deque) > self._page_size * (self._max_in_memory_pages - 1): + return + + self._start_fetching_next_page() def _handle_page(self, rows): if self._exit_event.is_set(): @@ -42,19 +67,15 @@ def _handle_page(self, rows): 'Paginator is closed, skipping new %i records', _len) return + if self._page_size is None: + self._page_size = len(rows) + for row in rows: self._deque.append(row) + self._loop.call_soon_threadsafe(self._no_fetching_page.set) self._loop.call_soon_threadsafe(self._drain_event.set) - - if self.cassandra_fut.has_more_pages: - _fn = self.cassandra_fut.start_fetching_next_page - fut = self._loop.run_in_executor(self._executor, _fn) - self.__pages.add(fut) - fut.add_done_callback(self.__pages.remove) - return - - 
self._loop.call_soon_threadsafe(self._finish_event.set) + self._loop.call_soon_threadsafe(self._maybe_start_prefetch_next_page) def _handle_err(self, exc): self._exc = exc @@ -102,8 +123,11 @@ async def _paginator(self): if self._exc is not None: raise self._exc + self._maybe_start_prefetch_next_page() + while self._deque: await yield_(self._deque.popleft()) + self._maybe_start_prefetch_next_page() await asyncio.wait( ( @@ -153,12 +177,13 @@ async def execute_future(self, *args, **kwargs): return await asyncio_fut -def execute_futures(self, *args, **kwargs): +def execute_futures(self, *args, max_in_memory_pages=None, **kwargs): _request = partial(self.execute_async, *args, **kwargs) return _Paginator( _request, executor=self._asyncio_executor, - loop=self._asyncio_loop + loop=self._asyncio_loop, + max_in_memory_pages=max_in_memory_pages ) diff --git a/tests/test_aiocassandra.py b/tests/test_aiocassandra.py index 83ed093..4eacd00 100755 --- a/tests/test_aiocassandra.py +++ b/tests/test_aiocassandra.py @@ -144,6 +144,25 @@ async def test_execute_futures_simple_statement(cassandra): assert len(ret) != 0 +@pytest.mark.asyncio +async def test_execute_futures_simple_statement_limit_pages(cassandra): + cql = 'SELECT * FROM system.size_estimates LIMIT 50;' + statement = SimpleStatement(cql, fetch_size=10) + + ret = [] + + async with cassandra.execute_futures(statement, max_in_memory_pages=3) as paginator: + await asyncio.sleep(0.5) # wait for fetching pages + assert len(paginator._deque) == 30 + async for row in paginator: + await asyncio.sleep(0.2) # slow down consumer + assert isinstance(row, tuple) + assert len(paginator._deque) <= 30 + ret.append(row) + + assert len(ret) == 50 + + @pytest.mark.asyncio async def test_execute_futures_break(cassandra): cql = 'SELECT * FROM system.size_estimates;'