diff --git a/tests/test_client.py b/tests/test_client.py index a2ce5a8..25d1da1 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -15,6 +15,7 @@ import datetime import io +import functools import json import pickle @@ -23,8 +24,10 @@ from vt import Client from vt import Object +from tests import wsgi_app -def new_client(httpserver): + +def new_client(httpserver, unused_apikey=''): return Client( "dummy_api_key", host="http://" + httpserver.host + ":" + str(httpserver.port), @@ -480,3 +483,26 @@ def test_user_headers(httpserver): assert "Accept-Encoding" in headers assert "User-Agent" in headers assert "foo" in headers + + +def test_wsgi_app(httpserver, monkeypatch): + app = wsgi_app.app + app.config.update({'TESTING': True}) + client = app.test_client() + expected_response = {'data': { + 'id': 'google.com', + 'type': 'domain', + 'attributes': {'foo': 'foo'}, + }} + + + httpserver.expect_request( + "/api/v3/domains/google.com", method="GET", + headers={"X-Apikey": "dummy_api_key"} + ).respond_with_json(expected_response) + monkeypatch.setattr( + 'tests.wsgi_app.vt.Client', functools.partial(new_client, httpserver) + ) + response = client.get('/') + assert response.status_code == 200 + assert response.json == expected_response diff --git a/tests/wsgi_app.py b/tests/wsgi_app.py new file mode 100644 index 0000000..c91c247 --- /dev/null +++ b/tests/wsgi_app.py @@ -0,0 +1,12 @@ +import flask +import os +import vt + +app = flask.Flask(__name__) + + +@app.get('/') +def home(): + with vt.Client(os.getenv('VT_APIKEY')) as c: + g = c.get('/domains/google.com') + return g.json() diff --git a/vt/client.py b/vt/client.py index 4fc7c5f..f5ed657 100644 --- a/vt/client.py +++ b/vt/client.py @@ -238,7 +238,17 @@ def __init__( if connector is not None: self._connector = connector else: - self._connector = aiohttp.TCPConnector(ssl=self._verify_ssl) + # the TCPConnector class expects to be instantiated inside a event loop. + # If there is none, create one. + try: + event_loop = asyncio.get_event_loop() + except RuntimeError: + event_loop = asyncio.new_event_loop() + asyncio.set_event_loop(event_loop) + + self._connector = aiohttp.TCPConnector( + ssl=self._verify_ssl, loop=event_loop + ) def _full_url(self, path:str, *args: typing.Any) -> str: try: