Skip to content

Commit a273e0e

Browse files
authored
Add typing to two objects in connection_utils (#1198)
1 parent bae282e commit a273e0e

File tree

1 file changed

+30
-5
lines changed

1 file changed

+30
-5
lines changed

asyncpg/connect_utils.py

+30-5
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,11 @@
44
# This module is part of asyncpg and is released under
55
# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0
66

7+
from __future__ import annotations
78

89
import asyncio
910
import collections
11+
from collections.abc import Callable
1012
import enum
1113
import functools
1214
import getpass
@@ -764,14 +766,21 @@ def _parse_connect_arguments(*, dsn, host, port, user, password, passfile,
764766

765767

766768
class TLSUpgradeProto(asyncio.Protocol):
767-
def __init__(self, loop, host, port, ssl_context, ssl_is_advisory):
769+
def __init__(
770+
self,
771+
loop: asyncio.AbstractEventLoop,
772+
host: str,
773+
port: int,
774+
ssl_context: ssl_module.SSLContext,
775+
ssl_is_advisory: bool,
776+
) -> None:
768777
self.on_data = _create_future(loop)
769778
self.host = host
770779
self.port = port
771780
self.ssl_context = ssl_context
772781
self.ssl_is_advisory = ssl_is_advisory
773782

774-
def data_received(self, data):
783+
def data_received(self, data: bytes) -> None:
775784
if data == b'S':
776785
self.on_data.set_result(True)
777786
elif (self.ssl_is_advisory and
@@ -789,15 +798,30 @@ def data_received(self, data):
789798
'rejected SSL upgrade'.format(
790799
host=self.host, port=self.port)))
791800

792-
def connection_lost(self, exc):
801+
def connection_lost(self, exc: typing.Optional[Exception]) -> None:
793802
if not self.on_data.done():
794803
if exc is None:
795804
exc = ConnectionError('unexpected connection_lost() call')
796805
self.on_data.set_exception(exc)
797806

798807

799-
async def _create_ssl_connection(protocol_factory, host, port, *,
800-
loop, ssl_context, ssl_is_advisory=False):
808+
_ProctolFactoryR = typing.TypeVar(
809+
"_ProctolFactoryR", bound=asyncio.protocols.Protocol
810+
)
811+
812+
813+
async def _create_ssl_connection(
814+
# TODO: The return type is a specific combination of subclasses of
815+
# asyncio.protocols.Protocol that we can't express. For now, having the
816+
# return type be dependent on signature of the factory is an improvement
817+
protocol_factory: Callable[[], _ProctolFactoryR],
818+
host: str,
819+
port: int,
820+
*,
821+
loop: asyncio.AbstractEventLoop,
822+
ssl_context: ssl_module.SSLContext,
823+
ssl_is_advisory: bool = False,
824+
) -> typing.Tuple[asyncio.Transport, _ProctolFactoryR]:
801825

802826
tr, pr = await loop.create_connection(
803827
lambda: TLSUpgradeProto(loop, host, port,
@@ -817,6 +841,7 @@ async def _create_ssl_connection(protocol_factory, host, port, *,
817841
try:
818842
new_tr = await loop.start_tls(
819843
tr, pr, ssl_context, server_hostname=host)
844+
assert new_tr is not None
820845
except (Exception, asyncio.CancelledError):
821846
tr.close()
822847
raise

0 commit comments

Comments
 (0)