diff --git a/tests/test_client.py b/tests/test_client.py index d4a6e0d..687891e 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -459,3 +459,18 @@ def test_scan_url(httpserver): analysis = client.scan_url('https://www.dummy.url') assert analysis.type == 'analysis' + +def test_user_headers(httpserver): + user_headers = {'foo': 'bar'} + + client = Client( + 'dummy_api_key', + host='http://' + httpserver.host + ':' + str(httpserver.port), + timeout=500, headers=user_headers) + + headers = client._get_session().headers + + assert 'X-Apikey' in headers + assert 'Accept-Encoding' in headers + assert 'User-Agent' in headers + assert 'foo' in headers diff --git a/vt/client.py b/vt/client.py index 3973788..0d37146 100644 --- a/vt/client.py +++ b/vt/client.py @@ -182,16 +182,18 @@ class Client: a request to timeout (300 by default). :param proxy: A string indicating the proxy to use for requests made by the client (None by default). + :param headers: Dict of headers defined by the user. :type apikey: str :type agent: str :type host: str :type trust_env: bool :type timeout: int :type proxy: str + :type headers: dict """ def __init__(self, apikey, agent="unknown", host=None, trust_env=False, - timeout=300, proxy=None): + timeout=300, proxy=None, headers=None): """Initialize the client with the provided API key.""" if not isinstance(apikey, str): @@ -207,6 +209,7 @@ def __init__(self, apikey, agent="unknown", host=None, trust_env=False, self._trust_env = trust_env self._timeout = timeout self._proxy = proxy + self._user_headers = headers def _full_url(self, path, *args): try: @@ -219,13 +222,19 @@ def _full_url(self, path, *args): def _get_session(self): if not self._session: + headers = { + 'X-Apikey': self._apikey, + 'Accept-Encoding': 'gzip', + 'User-Agent': _USER_AGENT_FMT.format_map({ + 'agent': self._agent, 'version': __version__}) + } + + if self._user_headers: + headers.update(self._user_headers) + self._session = aiohttp.ClientSession( connector=aiohttp.TCPConnector(ssl=False), - headers={ - 'X-Apikey': self._apikey, - 'Accept-Encoding': 'gzip', - 'User-Agent': _USER_AGENT_FMT.format_map({ - 'agent': self._agent, 'version': __version__})}, + headers=headers, trust_env=self._trust_env, timeout=aiohttp.ClientTimeout(total=self._timeout)) return self._session diff --git a/vt/version.py b/vt/version.py index 2d7893e..deea98b 100644 --- a/vt/version.py +++ b/vt/version.py @@ -1 +1 @@ -__version__ = '0.13.0' +__version__ = '0.13.1'