Skip to content

Commit

Permalink
Change iterator implementation so it's not based on generators (#68)
Browse files Browse the repository at this point in the history
  • Loading branch information
carlosjcabello authored Aug 7, 2021
1 parent 9bd3d27 commit 654a194
Show file tree
Hide file tree
Showing 5 changed files with 69 additions and 93 deletions.
4 changes: 4 additions & 0 deletions tests/test_iterator.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ def test_next(httpserver, iterator_response):
it = client.iterator('/dummy_collection/foo', limit=10, batch_size=3)
assert next(it).id == 'dummy_id_1'
assert next(it).id == 'dummy_id_2'
assert it._batch_cursor == 2

# iteration must resume right where next() left off
last = None
Expand All @@ -72,6 +73,7 @@ def test_next(httpserver, iterator_response):

assert last.id == 'dummy_id_4'
assert it._count == 4
assert it._batch_cursor == 1

with pytest.raises(StopIteration):
# there shouldn't be more available elements after the for loop
Expand Down Expand Up @@ -113,6 +115,7 @@ async def test_anext(httpserver, iterator_response):
async with new_client(httpserver) as client:
it = client.iterator('/dummy_collection/foo', limit=10, batch_size=3)
assert (await it.__anext__()).id == 'dummy_id_1'
assert it._batch_cursor == 1

# iteration must resume right where __anext__() left off
last, i = None, 0
Expand All @@ -123,6 +126,7 @@ async def test_anext(httpserver, iterator_response):

assert last.id == 'dummy_id_4'
assert it._count == 4
assert it._batch_cursor == 1

with pytest.raises(StopAsyncIteration):
# there shouldn't be more available elements after the for loop
Expand Down
56 changes: 23 additions & 33 deletions vt/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,9 @@

from .error import APIError
from .feed import Feed
from .object import Object
from .iterator import Iterator
from .object import Object
from .utils import make_sync
from .version import __version__


Expand All @@ -42,17 +43,6 @@
_USER_AGENT_FMT = '{agent}; vtpy {version}; gzip'


def _make_sync(future):
"""Utility function that waits for an async call, making it sync."""
try:
event_loop = asyncio.get_event_loop()
except RuntimeError:
# Generate an event loop if there isn't any.
event_loop = asyncio.new_event_loop()
asyncio.set_event_loop(event_loop)
return event_loop.run_until_complete(future)


def url_id(url):
"""Generates the object ID for an URL.
Expand Down Expand Up @@ -102,7 +92,7 @@ async def read_async(self):
return await self._aiohttp_resp.read()

def read(self):
return _make_sync(self.read_async())
return make_sync(self.read_async())

async def json_async(self):
if self.headers.get('Transfer-encoding') == 'chunked':
Expand All @@ -112,7 +102,7 @@ async def json_async(self):
return await self._aiohttp_resp.json()

def json(self):
return _make_sync(self.json_async())
return make_sync(self.json_async())

async def text_async(self):
if self.headers.get('Transfer-encoding') == 'chunked':
Expand All @@ -122,7 +112,7 @@ async def text_async(self):
return await self._aiohttp_resp.text()

def text(self):
return _make_sync(self.text_async())
return make_sync(self.text_async())


class StreamReader:
Expand Down Expand Up @@ -151,31 +141,31 @@ async def read_async(self, n=-1):
return await self._aiohttp_stream_reader.read(n)

def read(self, n=-1):
return _make_sync(self.read_async(n))
return make_sync(self.read_async(n))

async def readany_async(self):
return await self._aiohttp_stream_reader.readany()

def readany(self):
return _make_sync(self.readany_async())
return make_sync(self.readany_async())

async def readexactly_async(self, n):
return await self._aiohttp_stream_reader.readexactly(n)

def readexactly(self, n):
return _make_sync(self.readexactly_async(n))
return make_sync(self.readexactly_async(n))

async def readline_async(self):
return await self._aiohttp_stream_reader.readline()

def readline(self):
return _make_sync(self.readline_async())
return make_sync(self.readline_async())

async def readchunk_async(self):
return await self._aiohttp_stream_reader.readchunk()

def readchunk(self):
return _make_sync(self.readchunk_async())
return make_sync(self.readchunk_async())


class Client:
Expand Down Expand Up @@ -274,7 +264,7 @@ def close(self):
When the client is not needed anymore it should be closed for releasing
resources like TCP connections.
"""
return _make_sync(self.close_async( ))
return make_sync(self.close_async())

def delete(self, path, *path_args):
"""Sends a DELETE request to a given API endpoint.
Expand All @@ -285,7 +275,7 @@ def delete(self, path, *path_args):
:type path: str
:returns: An instance of :class:`ClientResponse`.
"""
return _make_sync(self.delete_async(path, *path_args))
return make_sync(self.delete_async(path, *path_args))

async def delete_async(self, path, *path_args):
"""Like :func:`delete` but returns a coroutine."""
Expand All @@ -303,7 +293,7 @@ def download_file(self, hash, file):
:type hash: str
:type file: file-like object
"""
return _make_sync(self.download_file_async(hash, file))
return make_sync(self.download_file_async(hash, file))

async def download_file_async(self, hash, file):
"""Like :func:`download_file` but returns a coroutine."""
Expand Down Expand Up @@ -349,7 +339,7 @@ def get(self, path, *path_args, params=None):
:type params: dict
:returns: An instance of :class:`ClientResponse`.
"""
return _make_sync(self.get_async(path, *path_args, params=params))
return make_sync(self.get_async(path, *path_args, params=params))

async def get_async(self, path, *path_args, params=None):
"""Like :func:`get` but returns a coroutine."""
Expand Down Expand Up @@ -380,7 +370,7 @@ def get_data(self, path, *path_args, params=None):
dict, list, string or some other Python type, depending on the endpoint
called.
"""
return _make_sync(self.get_data_async(path, *path_args, params=params))
return make_sync(self.get_data_async(path, *path_args, params=params))

async def get_data_async(self, path, *path_args, params=None):
"""Like :func:`get_data` but returns a coroutine."""
Expand Down Expand Up @@ -423,7 +413,7 @@ def get_json(self, path, *path_args, params=None):
:returns:
A dictionary with the backend's response.
"""
return _make_sync(self.get_json_async(path, *path_args, params=params))
return make_sync(self.get_json_async(path, *path_args, params=params))

async def get_json_async(self, path, *path_args, params=None):
"""Like :func:`get_json` but returns a coroutine."""
Expand All @@ -447,7 +437,7 @@ def get_object(self, path, *path_args, params=None):
:returns:
An instance of :class:`Object`.
"""
return _make_sync(self.get_object_async(path, *path_args, params=params))
return make_sync(self.get_object_async(path, *path_args, params=params))

async def get_object_async(self, path, *path_args, params=None):
"""Like :func:`get_object` but returns a coroutine."""
Expand All @@ -469,7 +459,7 @@ def patch(self, path, *path_args, data=None):
:type data: A string or bytes
:returns: An instance of :class:`ClientResponse`.
"""
return _make_sync(self.patch_async(path, *path_args, data))
return make_sync(self.patch_async(path, *path_args, data))

async def patch_async(self, path, *path_args, data=None):
"""Like :func:`patch` but returns a coroutine."""
Expand All @@ -493,7 +483,7 @@ def patch_object(self, path, *path_args, obj):
:returns: An instance of :class:`Object` representing the same object after
the changes have been applied.
"""
return _make_sync(self.patch_object_async(path, *path_args, obj=obj))
return make_sync(self.patch_object_async(path, *path_args, obj=obj))

async def patch_object_async(self, path, *path_args, obj):
"""Like :func:`patch_object` but returns a coroutine."""
Expand All @@ -516,7 +506,7 @@ def post(self, path, *path_args, data=None):
:type data: A string or bytes
:returns: An instance of :class:`ClientResponse`.
"""
return _make_sync(self.post_async(path, *path_args, data=data))
return make_sync(self.post_async(path, *path_args, data=data))

async def post_async(self, path, *path_args, data=None):
"""Like :func:`post` but returns a coroutine."""
Expand All @@ -540,7 +530,7 @@ def post_object(self, path, *path_args, obj):
:type obj: :class:`Object`
:returns: An instance of :class:`Object` representing the new object.
"""
return _make_sync(self.post_object_async(path, *path_args, obj=obj))
return make_sync(self.post_object_async(path, *path_args, obj=obj))

async def post_object_async(self, path, *path_args, obj):
"""Like :func:`post_object` but returns a coroutine."""
Expand Down Expand Up @@ -588,7 +578,7 @@ def scan_file(self, file, wait_for_completion=False):
:type wait_for_completion: bool
:returns: An instance of :class:`Object` of analysis type.
"""
return _make_sync(self.scan_file_async(
return make_sync(self.scan_file_async(
file, wait_for_completion=wait_for_completion))

async def scan_file_async(self, file, wait_for_completion=False):
Expand Down Expand Up @@ -635,7 +625,7 @@ def scan_url(self, url, wait_for_completion=False):
:type wait_for_completion: bool
:returns: An instance of :class:`Object` of analysis type.
"""
return _make_sync(self.scan_url_async(
return make_sync(self.scan_url_async(
url, wait_for_completion=wait_for_completion))

async def scan_url_async(self, url, wait_for_completion=False):
Expand Down
74 changes: 15 additions & 59 deletions vt/iterator.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@


from .object import Object

from .utils import make_sync

__all__ = ['Iterator']

Expand Down Expand Up @@ -109,70 +109,26 @@ async def _get_batch_async(self, batch_cursor=0):
self._path, params=self._build_params())
return self._parse_response(json_resp, batch_cursor)

def _get_batch(self, batch_cursor=0):
json_resp = self._client.get_json(
self._path, params=self._build_params())
return self._parse_response(json_resp, batch_cursor)

def _iterate(self):
if len(self._items) == 0:
self._items, self._server_cursor = self._get_batch()
self._batch_cursor = 0
item = self._items.pop(0)
self._count += 1
self._batch_cursor += 1
return Object.from_dict(item)

async def _aiterate(self):
if len(self._items) == 0:
self._items, self._server_cursor = await self._get_batch_async()
self._batch_cursor = 0
item = self._items.pop(0)
self._count += 1
self._batch_cursor += 1
return Object.from_dict(item)

def __iter__(self):
if not self._items and self._count == 0: # iter called before next
self._items, self._server_cursor = self._get_batch()
if self._limit:
while (self._items or self._server_cursor) and self._count < self._limit:
yield self._iterate()
else:
while (self._items or self._server_cursor):
yield self._iterate()

async def __aiter__(self):
if not self._items and self._count == 0: # iter called before next
self._items, self._server_cursor = await self._get_batch_async()
if self._limit:
while (self._items or self._server_cursor) and self._count < self._limit:
yield await self._aiterate()
else:
while self._items or self._server_cursor:
yield await self._aiterate()
return self

def __aiter__(self):
return self

def __next__(self):
if not self._items and self._count == 0: # next is called before iter
self._items, self._server_cursor = self._get_batch()
if self._limit:
if (not self._items and self._count > 0) or self._count >= self._limit:
raise StopIteration()
elif (not self._items and self._count > 0):
raise StopIteration()
item = self._items.pop(0)
self._count += 1
self._batch_cursor += 1
return Object.from_dict(item)
try:
return make_sync(self.__anext__())
except StopAsyncIteration:
raise StopIteration()

async def __anext__(self):
if not self._items and self._count == 0: # next is called before iter
if not self._items and (self._server_cursor or self._count == 0):
self._items, self._server_cursor = await self._get_batch_async()
if self._limit:
if (not self._items and self._count > 0) or self._count >= self._limit:
raise StopAsyncIteration()
elif (not self._items and self._count > 0):
raise StopAsyncIteration()
self._batch_cursor = 0
if self._limit and self._count == self._limit:
raise StopAsyncIteration()
if not self._items and not self._server_cursor:
raise StopAsyncIteration()
item = self._items.pop(0)
self._count += 1
self._batch_cursor += 1
Expand Down
26 changes: 26 additions & 0 deletions vt/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
# Copyright © 2019 The vt-py authors. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import asyncio


def make_sync(future):
    """Utility function that waits for an async call, making it sync.

    Runs ``future`` to completion on the current thread's event loop and
    returns its result. A new event loop is created and installed as the
    current one when no loop is available, or when the current loop has
    already been closed.

    :param future: A coroutine object or future to wait for.
    :returns: Whatever ``future`` produces once it completes.
    :raises RuntimeError: If called while the event loop is already running
      (i.e. from inside a coroutine); use the ``*_async`` variant there.
    """
    try:
        event_loop = asyncio.get_event_loop()
    except RuntimeError:
        # There is no current event loop for this thread.
        event_loop = None
    # A closed loop is still returned by get_event_loop(), but it cannot run
    # anything anymore, so it is handled like a missing loop.
    if event_loop is None or event_loop.is_closed():
        event_loop = asyncio.new_event_loop()
        asyncio.set_event_loop(event_loop)
    return event_loop.run_until_complete(future)
2 changes: 1 addition & 1 deletion vt/version.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = '0.7.1'
__version__ = '0.7.2'

0 comments on commit 654a194

Please sign in to comment.