4
4
# This module is part of asyncpg and is released under
5
5
# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0
6
6
7
+ from __future__ import annotations
7
8
8
9
import asyncio
9
10
import collections
11
+ from collections .abc import Callable
10
12
import enum
11
13
import functools
12
14
import getpass
@@ -764,14 +766,21 @@ def _parse_connect_arguments(*, dsn, host, port, user, password, passfile,
764
766
765
767
766
768
class TLSUpgradeProto (asyncio .Protocol ):
767
- def __init__ (self , loop , host , port , ssl_context , ssl_is_advisory ):
769
+ def __init__ (
770
+ self ,
771
+ loop : asyncio .AbstractEventLoop ,
772
+ host : str ,
773
+ port : int ,
774
+ ssl_context : ssl_module .SSLContext ,
775
+ ssl_is_advisory : bool ,
776
+ ) -> None :
768
777
self .on_data = _create_future (loop )
769
778
self .host = host
770
779
self .port = port
771
780
self .ssl_context = ssl_context
772
781
self .ssl_is_advisory = ssl_is_advisory
773
782
774
- def data_received (self , data ) :
783
+ def data_received (self , data : bytes ) -> None :
775
784
if data == b'S' :
776
785
self .on_data .set_result (True )
777
786
elif (self .ssl_is_advisory and
@@ -789,15 +798,30 @@ def data_received(self, data):
789
798
'rejected SSL upgrade' .format (
790
799
host = self .host , port = self .port )))
791
800
792
- def connection_lost (self , exc ) :
801
+ def connection_lost (self , exc : typing . Optional [ Exception ]) -> None :
793
802
if not self .on_data .done ():
794
803
if exc is None :
795
804
exc = ConnectionError ('unexpected connection_lost() call' )
796
805
self .on_data .set_exception (exc )
797
806
798
807
799
- async def _create_ssl_connection (protocol_factory , host , port , * ,
800
- loop , ssl_context , ssl_is_advisory = False ):
808
+ _ProctolFactoryR = typing .TypeVar (
809
+ "_ProctolFactoryR" , bound = asyncio .protocols .Protocol
810
+ )
811
+
812
+
813
+ async def _create_ssl_connection (
814
+ # TODO: The return type is a specific combination of subclasses of
815
+ # asyncio.protocols.Protocol that we can't express. For now, having the
816
+ # return type be dependent on signature of the factory is an improvement
817
+ protocol_factory : Callable [[], _ProctolFactoryR ],
818
+ host : str ,
819
+ port : int ,
820
+ * ,
821
+ loop : asyncio .AbstractEventLoop ,
822
+ ssl_context : ssl_module .SSLContext ,
823
+ ssl_is_advisory : bool = False ,
824
+ ) -> typing .Tuple [asyncio .Transport , _ProctolFactoryR ]:
801
825
802
826
tr , pr = await loop .create_connection (
803
827
lambda : TLSUpgradeProto (loop , host , port ,
@@ -817,6 +841,7 @@ async def _create_ssl_connection(protocol_factory, host, port, *,
817
841
try :
818
842
new_tr = await loop .start_tls (
819
843
tr , pr , ssl_context , server_hostname = host )
844
+ assert new_tr is not None
820
845
except (Exception , asyncio .CancelledError ):
821
846
tr .close ()
822
847
raise
0 commit comments