@@ -53,7 +53,7 @@ def connect(host="localhost", user=None, password="",
5353            connect_timeout = None , read_default_group = None ,
5454            autocommit = False , echo = False ,
5555            local_infile = False , loop = None , ssl = None , auth_plugin = '' ,
56-             program_name = '' , server_public_key = None ):
56+             program_name = '' , server_public_key = None ,  implicit_tls = False ):
5757    """See connections.Connection.__init__() for information about 
5858    defaults.""" 
5959    coro  =  _connect (host = host , user = user , password = password , db = db ,
@@ -66,7 +66,8 @@ def connect(host="localhost", user=None, password="",
6666                    read_default_group = read_default_group ,
6767                    autocommit = autocommit , echo = echo ,
6868                    local_infile = local_infile , loop = loop , ssl = ssl ,
69-                     auth_plugin = auth_plugin , program_name = program_name )
69+                     auth_plugin = auth_plugin , program_name = program_name ,
70+                     implicit_tls = implicit_tls )
7071    return  _ConnectionContextManager (coro )
7172
7273
@@ -142,7 +143,7 @@ def __init__(self, host="localhost", user=None, password="",
142143                 connect_timeout = None , read_default_group = None ,
143144                 autocommit = False , echo = False ,
144145                 local_infile = False , loop = None , ssl = None , auth_plugin = '' ,
145-                  program_name = '' , server_public_key = None ):
146+                  program_name = '' , server_public_key = None ,  implicit_tls = False ):
146147        """ 
147148        Establish a connection to the MySQL database. Accepts several 
148149        arguments: 
@@ -184,6 +185,9 @@ def __init__(self, host="localhost", user=None, password="",
184185            handshaking with MySQL. (omitted by default) 
185186        :param server_public_key: SHA256 authentication plugin public 
186187            key value. 
188+         :param implicit_tls: Establish TLS immediately, skipping non-TLS 
189+             preamble before upgrading to TLS. 
190+             (default: False) 
187191        :param loop: asyncio loop 
188192        """ 
189193        self ._loop  =  loop  or  asyncio .get_event_loop ()
@@ -218,6 +222,7 @@ def __init__(self, host="localhost", user=None, password="",
218222        self ._auth_plugin_used  =  "" 
219223        self ._secure  =  False 
220224        self .server_public_key  =  server_public_key 
225+         self ._implicit_tls  =  implicit_tls 
221226        self .salt  =  None 
222227
223228        from  . import  __version__ 
@@ -241,7 +246,7 @@ def __init__(self, host="localhost", user=None, password="",
241246            self .use_unicode  =  use_unicode 
242247
243248        self ._ssl_context  =  ssl 
244-         if  ssl :
249+         if  ssl   and   not   implicit_tls :
245250            client_flag  |=  CLIENT .SSL 
246251
247252        self ._encoding  =  charset_by_name (self ._charset ).encoding 
@@ -536,7 +541,8 @@ async def _connect(self):
536541
537542            self ._next_seq_id  =  0 
538543
539-             await  self ._get_server_information ()
544+             if  not  self ._implicit_tls :
545+                 await  self ._get_server_information ()
540546            await  self ._request_authentication ()
541547
542548            self .connected_time  =  self ._loop .time ()
@@ -727,7 +733,8 @@ async def _execute_command(self, command, sql):
727733
728734    async  def  _request_authentication (self ):
729735        # https://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::HandshakeResponse 
730-         if  int (self .server_version .split ('.' , 1 )[0 ]) >=  5 :
736+         # FIXME: change this before merge 
737+         if  self ._implicit_tls  or  int (self .server_version .split ('.' , 1 )[0 ]) >=  5 :
731738            self .client_flag  |=  CLIENT .MULTI_RESULTS 
732739
733740        if  self .user  is  None :
@@ -737,8 +744,10 @@ async def _request_authentication(self):
737744        data_init  =  struct .pack ('<iIB23s' , self .client_flag , MAX_PACKET_LEN ,
738745                                charset_id , b'' )
739746
740-         if  self ._ssl_context  and  self .server_capabilities  &  CLIENT .SSL :
741-             self .write_packet (data_init )
747+         if  self ._ssl_context  and  \
748+                 (self ._implicit_tls  or  self .server_capabilities  &  CLIENT .SSL ):
749+             if  not  self ._implicit_tls :
750+                 self .write_packet (data_init )
742751
743752            # Stop sending events to data_received 
744753            self ._writer .transport .pause_reading ()
@@ -760,6 +769,9 @@ async def _request_authentication(self):
760769                server_hostname = self ._host 
761770            )
762771
772+             if  self ._implicit_tls :
773+                 await  self ._get_server_information ()
774+ 
763775            self ._secure  =  True 
764776
765777        if  isinstance (self .user , str ):
0 commit comments