Skip to content

Commit 6dd7caa

Browse files
committed
Add implicit_tls connect arg to support non-standard implicit TLS connections, such as Google Cloud SQL
fixes #757
1 parent b21f0ed commit 6dd7caa

File tree

3 files changed

+28
-9
lines changed

3 files changed

+28
-9
lines changed

CHANGES.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@ next (unreleased)
88

99
* Remove deprecated Pool.get #706
1010

11+
* Add `implicit_tls` connect arg to support non-standard implicit TLS connections, such as Google Cloud SQL #757
12+
1113
0.1.1 (2022-05-08)
1214
^^^^^^^^^^^^^^^^^^
1315

aiomysql/connection.py

Lines changed: 20 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ def connect(host="localhost", user=None, password="",
5353
connect_timeout=None, read_default_group=None,
5454
autocommit=False, echo=False,
5555
local_infile=False, loop=None, ssl=None, auth_plugin='',
56-
program_name='', server_public_key=None):
56+
program_name='', server_public_key=None, implicit_tls=False):
5757
"""See connections.Connection.__init__() for information about
5858
defaults."""
5959
coro = _connect(host=host, user=user, password=password, db=db,
@@ -66,7 +66,8 @@ def connect(host="localhost", user=None, password="",
6666
read_default_group=read_default_group,
6767
autocommit=autocommit, echo=echo,
6868
local_infile=local_infile, loop=loop, ssl=ssl,
69-
auth_plugin=auth_plugin, program_name=program_name)
69+
auth_plugin=auth_plugin, program_name=program_name,
70+
implicit_tls=implicit_tls)
7071
return _ConnectionContextManager(coro)
7172

7273

@@ -142,7 +143,7 @@ def __init__(self, host="localhost", user=None, password="",
142143
connect_timeout=None, read_default_group=None,
143144
autocommit=False, echo=False,
144145
local_infile=False, loop=None, ssl=None, auth_plugin='',
145-
program_name='', server_public_key=None):
146+
program_name='', server_public_key=None, implicit_tls=False):
146147
"""
147148
Establish a connection to the MySQL database. Accepts several
148149
arguments:
@@ -184,6 +185,9 @@ def __init__(self, host="localhost", user=None, password="",
184185
handshaking with MySQL. (omitted by default)
185186
:param server_public_key: SHA256 authentication plugin public
186187
key value.
188+
:param implicit_tls: Establish TLS immediately, skipping non-TLS
189+
preamble before upgrading to TLS.
190+
(default: False)
187191
:param loop: asyncio loop
188192
"""
189193
self._loop = loop or asyncio.get_event_loop()
@@ -218,6 +222,7 @@ def __init__(self, host="localhost", user=None, password="",
218222
self._auth_plugin_used = ""
219223
self._secure = False
220224
self.server_public_key = server_public_key
225+
self._implicit_tls = implicit_tls
221226
self.salt = None
222227

223228
from . import __version__
@@ -241,7 +246,7 @@ def __init__(self, host="localhost", user=None, password="",
241246
self.use_unicode = use_unicode
242247

243248
self._ssl_context = ssl
244-
if ssl:
249+
if ssl and not implicit_tls:
245250
client_flag |= CLIENT.SSL
246251

247252
self._encoding = charset_by_name(self._charset).encoding
@@ -536,7 +541,8 @@ async def _connect(self):
536541

537542
self._next_seq_id = 0
538543

539-
await self._get_server_information()
544+
if not self._implicit_tls:
545+
await self._get_server_information()
540546
await self._request_authentication()
541547

542548
self.connected_time = self._loop.time()
@@ -727,7 +733,8 @@ async def _execute_command(self, command, sql):
727733

728734
async def _request_authentication(self):
729735
# https://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::HandshakeResponse
730-
if int(self.server_version.split('.', 1)[0]) >= 5:
736+
# FIXME: change this before merge
737+
if self._implicit_tls or int(self.server_version.split('.', 1)[0]) >= 5:
731738
self.client_flag |= CLIENT.MULTI_RESULTS
732739

733740
if self.user is None:
@@ -737,8 +744,10 @@ async def _request_authentication(self):
737744
data_init = struct.pack('<iIB23s', self.client_flag, MAX_PACKET_LEN,
738745
charset_id, b'')
739746

740-
if self._ssl_context and self.server_capabilities & CLIENT.SSL:
741-
self.write_packet(data_init)
747+
if self._ssl_context and \
748+
(self._implicit_tls or self.server_capabilities & CLIENT.SSL):
749+
if not self._implicit_tls:
750+
self.write_packet(data_init)
742751

743752
# Stop sending events to data_received
744753
self._writer.transport.pause_reading()
@@ -760,6 +769,9 @@ async def _request_authentication(self):
760769
server_hostname=self._host
761770
)
762771

772+
if self._implicit_tls:
773+
await self._get_server_information()
774+
763775
self._secure = True
764776

765777
if isinstance(self.user, str):

docs/connection.rst

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ Example::
4747
connect_timeout=None, read_default_group=None,
4848
autocommit=False, echo=False
4949
ssl=None, auth_plugin='', program_name='',
50-
server_public_key=None, loop=None)
50+
server_public_key=None, loop=None, implicit_tls=False)
5151

5252
A :ref:`coroutine <coroutine>` that connects to MySQL.
5353

@@ -93,6 +93,11 @@ Example::
9393
``sys.argv[0]`` is no longer passed by default
9494
:param server_public_key: SHA256 authenticaiton plugin public key value.
9595
:param loop: asyncio event loop instance or ``None`` for default one.
96+
:param implicit_tls: Establish TLS immediately, skipping non-TLS
97+
preamble before upgrading to TLS.
98+
(default: False)
99+
100+
.. versionadded:: 0.2
96101
:returns: :class:`Connection` instance.
97102

98103

0 commit comments

Comments
 (0)