Skip to content

Commit 5f83273

Browse files
committed
fix unit socket implementation, most tests should be fine now
configure custom socket path for mysql container, working around implicitly created volume folders being owned by root we should probably just not use service containers for this to avoid having to do this patching
1 parent 6191960 commit 5f83273

File tree

12 files changed

+191
-44
lines changed

12 files changed

+191
-44
lines changed

.github/workflows/ci.yml

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,8 @@ jobs:
4949
image: '${{ matrix.db }}'
5050
ports:
5151
- 3306:3306
52+
volumes:
53+
- "/tmp/run-${{ join(matrix.db, '-') }}/:/socket-mount/"
5254
options: '--name=mysqld'
5355
env:
5456
MYSQL_ROOT_PASSWORD: rootpw
@@ -104,6 +106,19 @@ jobs:
104106
docker container stop mysqld
105107
docker container cp "${{ github.workspace }}/tests/ssl_resources/ssl" mysqld:/etc/mysql/ssl
106108
docker container cp "${{ github.workspace }}/tests/ssl_resources/tls.cnf" mysqld:/etc/mysql/conf.d/aiomysql-tls.cnf
109+
110+
# use custom socket path
111+
# we need to ensure that the socket path is writable for the user running the DB process in the container
112+
sudo chmod 0777 /tmp/run-${{ join(matrix.db, '-') }}
113+
114+
# mysql 5.7 container overrides the socket path in /etc/mysql/mysql.conf.d/mysqld.cnf
115+
if [ "${{ join(matrix.db, '-') }}" = "mysql-5.7" ]
116+
then
117+
docker container cp "${{ github.workspace }}/tests/ssl_resources/socket.cnf" mysqld:/etc/mysql/mysql.conf.d/zz-aiomysql-socket.cnf
118+
else
119+
docker container cp "${{ github.workspace }}/tests/ssl_resources/socket.cnf" mysqld:/etc/mysql/conf.d/aiomysql-socket.cnf
120+
fi
121+
107122
docker container start mysqld
108123
109124
# ensure server is started up
@@ -122,7 +137,7 @@ jobs:
122137
123138
# timeout ensures a more or less clean stop by sending a KeyboardInterrupt which will still provide useful logs
124139
timeout --preserve-status --signal=INT --verbose 5m \
125-
pytest --color=yes --capture=no --verbosity 2 --cov-report term --cov-report xml --cov aiomysql ./tests
140+
pytest --color=yes --capture=no --verbosity 2 --cov-report term --cov-report xml --cov aiomysql ./tests --mysql-unix-socket "unix-${{ join(matrix.db, '') }}=/tmp/run-${{ join(matrix.db, '-') }}/mysql.sock" --mysql-address "tcp-${{ join(matrix.db, '') }}=127.0.0.1:3306"
126141
env:
127142
PYTHONUNBUFFERED: 1
128143
MATRIX_DB: '${{ matrix.db }}'

aiomysql/connection.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -229,6 +229,7 @@ def __init__(self, host="localhost", user=None, password="",
229229
self._client_auth_plugin = auth_plugin
230230
self._server_auth_plugin = ""
231231
self._auth_plugin_used = ""
232+
self._secure = False
232233
self.server_public_key = server_public_key
233234
self.salt = None
234235

@@ -526,14 +527,15 @@ async def _connect(self):
526527
# raise OperationalError(CR.CR_SERVER_GONE_ERROR,
527528
# "MySQL server has gone away (%r)" % (e,))
528529
try:
529-
if self._unix_socket and self._host in ('localhost', '127.0.0.1'):
530+
if self._unix_socket:
530531
self._reader, self._writer = await \
531532
asyncio.wait_for(
532533
_open_unix_connection(
533534
self._unix_socket),
534535
timeout=self.connect_timeout)
535536
self.host_info = "Localhost via UNIX socket: " + \
536537
self._unix_socket
538+
self._secure = True
537539
else:
538540
self._reader, self._writer = await \
539541
asyncio.wait_for(
@@ -743,7 +745,7 @@ async def _request_authentication(self):
743745
if self.user is None:
744746
raise ValueError("Did not specify a username")
745747

746-
if self._ssl_context:
748+
if self._ssl_context and self.server_capabilities & CLIENT.SSL:
747749
# capablities, max packet, charset
748750
data = struct.pack('<IIB', self.client_flag, 16777216, 33)
749751
data += b'\x00' * (32 - len(data))
@@ -770,6 +772,8 @@ async def _request_authentication(self):
770772
server_hostname=self._host
771773
)
772774

775+
self._secure = True
776+
773777
charset_id = charset_by_name(self.charset).id
774778
if isinstance(self.user, str):
775779
_user = self.user.encode(self.encoding)
@@ -798,7 +802,7 @@ async def _request_authentication(self):
798802
)
799803
# Else: empty password
800804
elif auth_plugin == 'sha256_password':
801-
if self._ssl_context and self.server_capabilities & CLIENT.SSL:
805+
if self._secure:
802806
authresp = self._password.encode('latin1') + b'\0'
803807
elif self._password:
804808
authresp = b'\1' # request public key
@@ -960,7 +964,7 @@ async def caching_sha2_password_auth(self, pkt):
960964

961965
logger.debug("caching sha2: Trying full auth...")
962966

963-
if self._ssl_context:
967+
if self._secure:
964968
logger.debug("caching sha2: Sending plain "
965969
"password via secure connection")
966970
self.write_packet(self._password.encode('latin1') + b'\0')
@@ -991,7 +995,7 @@ async def caching_sha2_password_auth(self, pkt):
991995
pkt.check_error()
992996

993997
async def sha256_password_auth(self, pkt):
994-
if self._ssl_context:
998+
if self._secure:
995999
logger.debug("sha256: Sending plain password")
9961000
data = self._password.encode('latin1') + b'\0'
9971001
self.write_packet(data)

tests/conftest.py

Lines changed: 90 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,51 @@ def pytest_generate_tests(metafunc):
2727
loop_type = ['asyncio', 'uvloop'] if uvloop else ['asyncio']
2828
metafunc.parametrize("loop_type", loop_type)
2929

30+
if "mysql_address" in metafunc.fixturenames:
31+
mysql_addresses = []
32+
ids = []
33+
34+
opt_mysql_unix_socket = \
35+
list(metafunc.config.getoption("mysql_unix_socket"))
36+
for i in range(len(opt_mysql_unix_socket)):
37+
if "=" in opt_mysql_unix_socket[i]:
38+
label, path = opt_mysql_unix_socket[i].rsplit("=", 1)
39+
mysql_addresses.append(path)
40+
ids.append(label)
41+
else:
42+
mysql_addresses.append(opt_mysql_unix_socket[i])
43+
ids.append("unix{}".format(i))
44+
45+
opt_mysql_address = list(metafunc.config.getoption("mysql_address"))
46+
for i in range(len(opt_mysql_address)):
47+
if "=" in opt_mysql_address[i]:
48+
label, addr = opt_mysql_address[i].rsplit("=", 1)
49+
ids.append(label)
50+
else:
51+
addr = opt_mysql_address[i]
52+
ids.append("tcp{}".format(i))
53+
54+
if ":" in addr:
55+
addr = addr.rsplit(":", 1)
56+
mysql_addresses.append((addr[0], int(addr[1])))
57+
else:
58+
mysql_addresses.append((addr, 3306))
59+
60+
# default to connecting to localhost
61+
if len(mysql_addresses) == 0:
62+
mysql_addresses = [("127.0.0.1", 3306)]
63+
ids = ["tcp-local"]
64+
65+
assert len(mysql_addresses) == len(set(mysql_addresses))
66+
assert len(ids) == len(set(ids))
67+
assert len(mysql_addresses) == len(ids)
68+
69+
metafunc.parametrize("mysql_address",
70+
mysql_addresses,
71+
ids=ids,
72+
scope="session",
73+
)
74+
3075

3176
# This is here unless someone fixes the generate_tests bit
3277
@pytest.fixture(scope='session')
@@ -101,6 +146,21 @@ def pytest_configure(config):
101146
)
102147

103148

149+
def pytest_addoption(parser):
150+
parser.addoption(
151+
"--mysql-address",
152+
action="append",
153+
default=[],
154+
help="list of addresses to connect to",
155+
)
156+
parser.addoption(
157+
"--mysql-unix-socket",
158+
action="append",
159+
default=[],
160+
help="list of unix sockets to connect to",
161+
)
162+
163+
104164
@pytest.fixture
105165
def mysql_params(mysql_server):
106166
params = {**mysql_server['conn_params'],
@@ -205,24 +265,31 @@ def ensure_mysql_version(request, mysql_image, mysql_tag):
205265

206266

207267
@pytest.fixture(scope='session')
208-
def mysql_server(mysql_image, mysql_tag):
209-
ssl_directory = os.path.join(os.path.dirname(__file__),
210-
'ssl_resources', 'ssl')
211-
ca_file = os.path.join(ssl_directory, 'ca.pem')
268+
def mysql_server(mysql_image, mysql_tag, mysql_address):
269+
unix_socket = type(mysql_address) is str
270+
271+
if not unix_socket:
272+
ssl_directory = os.path.join(os.path.dirname(__file__),
273+
'ssl_resources', 'ssl')
274+
ca_file = os.path.join(ssl_directory, 'ca.pem')
212275

213-
ctx = ssl.SSLContext(ssl.PROTOCOL_TLSv1_2)
214-
ctx.check_hostname = False
215-
ctx.load_verify_locations(cafile=ca_file)
216-
# ctx.verify_mode = ssl.CERT_NONE
276+
ctx = ssl.SSLContext(ssl.PROTOCOL_TLSv1_2)
277+
ctx.check_hostname = False
278+
ctx.load_verify_locations(cafile=ca_file)
279+
# ctx.verify_mode = ssl.CERT_NONE
217280

218281
server_params = {
219-
'host': '127.0.0.1',
220-
'port': 3306,
221282
'user': 'root',
222283
'password': os.environ.get("MYSQL_ROOT_PASSWORD"),
223-
'ssl': ctx,
224284
}
225285

286+
if unix_socket:
287+
server_params["unix_socket"] = mysql_address
288+
else:
289+
server_params["host"] = mysql_address[0]
290+
server_params["port"] = mysql_address[1]
291+
server_params["ssl"] = ctx
292+
226293
try:
227294
connection = pymysql.connect(
228295
db='mysql',
@@ -231,21 +298,22 @@ def mysql_server(mysql_image, mysql_tag):
231298
**server_params)
232299

233300
with connection.cursor() as cursor:
234-
cursor.execute("SHOW VARIABLES LIKE '%ssl%';")
301+
if not unix_socket:
302+
cursor.execute("SHOW VARIABLES LIKE '%ssl%';")
235303

236-
result = cursor.fetchall()
237-
result = {item['Variable_name']:
238-
item['Value'] for item in result}
304+
result = cursor.fetchall()
305+
result = {item['Variable_name']:
306+
item['Value'] for item in result}
239307

240-
assert result['have_ssl'] == "YES", \
241-
"SSL Not Enabled on MySQL"
308+
assert result['have_ssl'] == "YES", \
309+
"SSL Not Enabled on MySQL"
242310

243-
cursor.execute("SHOW STATUS LIKE 'Ssl_version%'")
311+
cursor.execute("SHOW STATUS LIKE 'Ssl_version%'")
244312

245-
result = cursor.fetchone()
246-
# As we connected with TLS, it should start with that :D
247-
assert result['Value'].startswith('TLS'), \
248-
"Not connected to the database with TLS"
313+
result = cursor.fetchone()
314+
# As we connected with TLS, it should start with that :D
315+
assert result['Value'].startswith('TLS'), \
316+
"Not connected to the database with TLS"
249317

250318
# Drop possibly existing old databases
251319
cursor.execute('DROP DATABASE IF EXISTS test_pymysql;')
File renamed without changes.

tests/fixtures/my.cnf.unix.tmpl

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
#
2+
# The MySQL database server configuration file.
3+
#
4+
[client]
5+
user = {user}
6+
socket = {unix_socket}
7+
password = {password}
8+
database = {db}
9+
default-character-set = utf8
10+
11+
[client_with_unix_socket]
12+
user = {user}
13+
socket = {unix_socket}
14+
password = {password}
15+
database = {db}
16+
default-character-set = utf8

tests/sa/test_sa_compiled_cache.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,12 +15,19 @@
1515
@pytest.fixture()
1616
def make_engine(mysql_params, connection):
1717
async def _make_engine(**kwargs):
18+
if "unix_socket" in mysql_params:
19+
conn_args = {"unix_socket": mysql_params["unix_socket"]}
20+
else:
21+
conn_args = {
22+
"host": mysql_params['host'],
23+
"port": mysql_params['port'],
24+
}
25+
1826
return (await sa.create_engine(db=mysql_params['db'],
1927
user=mysql_params['user'],
2028
password=mysql_params['password'],
21-
host=mysql_params['host'],
22-
port=mysql_params['port'],
2329
minsize=10,
30+
**conn_args,
2431
**kwargs))
2532

2633
return _make_engine

tests/sa/test_sa_default.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,12 +22,19 @@
2222
@pytest.fixture()
2323
def make_engine(mysql_params, connection):
2424
async def _make_engine(**kwargs):
25+
if "unix_socket" in mysql_params:
26+
conn_args = {"unix_socket": mysql_params["unix_socket"]}
27+
else:
28+
conn_args = {
29+
"host": mysql_params['host'],
30+
"port": mysql_params['port'],
31+
}
32+
2533
return (await sa.create_engine(db=mysql_params['db'],
2634
user=mysql_params['user'],
2735
password=mysql_params['password'],
28-
host=mysql_params['host'],
29-
port=mysql_params['port'],
3036
minsize=10,
37+
**conn_args,
3138
**kwargs))
3239

3340
return _make_engine

tests/sa/test_sa_engine.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,12 +15,19 @@
1515
@pytest.fixture()
1616
def make_engine(connection, mysql_params):
1717
async def _make_engine(**kwargs):
18+
if "unix_socket" in mysql_params:
19+
conn_args = {"unix_socket": mysql_params["unix_socket"]}
20+
else:
21+
conn_args = {
22+
"host": mysql_params['host'],
23+
"port": mysql_params['port'],
24+
}
25+
1826
return (await sa.create_engine(db=mysql_params['db'],
1927
user=mysql_params['user'],
2028
password=mysql_params['password'],
21-
host=mysql_params['host'],
22-
port=mysql_params['port'],
2329
minsize=10,
30+
**conn_args,
2431
**kwargs))
2532
return _make_engine
2633

tests/ssl_resources/socket.cnf

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
[mysqld]
2+
socket = /socket-mount/mysql.sock

0 commit comments

Comments
 (0)