Skip to content

Commit c1e8aae

Browse files
committed
fix unit socket implementation, most tests should be fine now
1 parent b115108 commit c1e8aae

File tree

11 files changed

+185
-67
lines changed

11 files changed

+185
-67
lines changed

.github/workflows/ci.yml

Lines changed: 15 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -28,14 +28,14 @@ jobs:
2828
# - '3.10'
2929
# - '3.11.0-alpha.3'
3030
db:
31-
- 'mysql:5.7'
32-
- 'mysql:8.0'
33-
- 'mariadb:10.2'
34-
- 'mariadb:10.3'
35-
- 'mariadb:10.4'
36-
- 'mariadb:10.5'
37-
- 'mariadb:10.6'
38-
- 'mariadb:10.7'
31+
- [mysql, '5.7']
32+
- [mysql, '8.0']
33+
- [mariadb, '10.2']
34+
- [mariadb, '10.3']
35+
- [mariadb, '10.4']
36+
- [mariadb, '10.5']
37+
- [mariadb, '10.6']
38+
- [mariadb, '10.7']
3939

4040
fail-fast: false
4141
runs-on: ${{ matrix.os }}
@@ -46,9 +46,11 @@ jobs:
4646

4747
services:
4848
mysql:
49-
image: '${{ matrix.db }}'
49+
image: "${{ join(matrix.db, ':') }}"
5050
ports:
5151
- 3306:3306
52+
volumes:
53+
- "/tmp/run-${{ join(matrix.db, '-') }}/:/run/mysqld/"
5254
options: '--name=mysqld'
5355
env:
5456
MYSQL_ROOT_PASSWORD: rootpw
@@ -117,29 +119,18 @@ jobs:
117119
118120
- name: Run tests
119121
run: |
120-
export DB="${MATRIX_DB%%:*}"
121-
export DBTAG="${MATRIX_DB##*:}"
122-
123122
# timeout ensures a more or less clean stop by sending a KeyboardInterrupt which will still provide useful logs
124123
timeout --preserve-status --signal=INT --verbose 5m \
125-
pytest --color=yes --capture=no --verbosity 2 --cov-report term --cov-report xml --cov aiomysql ./tests
124+
pytest --color=yes --capture=no --verbosity 2 --cov-report term --cov-report xml --cov aiomysql ./tests --mysql-unix-socket "${{ join(matrix.db, '') }}=/tmp/run-${{ join(matrix.db, '-') }}/mysql.sock" --mysql-address "${{ join(matrix.db, '') }}=127.0.0.1:3306"
126125
env:
127126
PYTHONUNBUFFERED: 1
128-
MATRIX_DB: '${{ matrix.db }}'
127+
DB: '${{ matrix.db[0] }}'
128+
DBTAG: '${{ matrix.db[1] }}'
129129
timeout-minutes: 6
130130

131-
- name: Build coverage flag
132-
run: |
133-
COVERAGE_FLAG="${MATRIX_OS}_${MATRIX_PY}_${MATRIX_DB//:/-}"
134-
echo "COVERAGE_FLAG=$COVERAGE_FLAG" | tee -a "$GITHUB_ENV"
135-
env:
136-
MATRIX_OS: '${{ matrix.os }}'
137-
MATRIX_PY: '${{ matrix.py }}'
138-
MATRIX_DB: '${{ matrix.db }}'
139-
140131
- name: Upload coverage
141132
uses: codecov/[email protected]
142133
with:
143134
file: ./coverage.xml
144-
flags: "${{ env.COVERAGE_FLAG }}"
135+
flags: "${{ matrix.os }}_${{ matrix.py }}_${{ join(matrix.db, '-') }}"
145136
fail_ci_if_error: true

aiomysql/connection.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -226,6 +226,7 @@ def __init__(self, host="localhost", user=None, password="",
226226
self._client_auth_plugin = auth_plugin
227227
self._server_auth_plugin = ""
228228
self._auth_plugin_used = ""
229+
self._secure = False
229230
self.server_public_key = server_public_key
230231
self.salt = None
231232

@@ -523,14 +524,15 @@ async def _connect(self):
523524
# raise OperationalError(CR.CR_SERVER_GONE_ERROR,
524525
# "MySQL server has gone away (%r)" % (e,))
525526
try:
526-
if self._unix_socket and self._host in ('localhost', '127.0.0.1'):
527+
if self._unix_socket:
527528
self._reader, self._writer = await \
528529
asyncio.wait_for(
529530
_open_unix_connection(
530531
self._unix_socket),
531532
timeout=self.connect_timeout)
532533
self.host_info = "Localhost via UNIX socket: " + \
533534
self._unix_socket
535+
self._secure = True
534536
else:
535537
self._reader, self._writer = await \
536538
asyncio.wait_for(
@@ -740,7 +742,7 @@ async def _request_authentication(self):
740742
if self.user is None:
741743
raise ValueError("Did not specify a username")
742744

743-
if self._ssl_context:
745+
if self._ssl_context and self.server_capabilities & CLIENT.SSL:
744746
# capablities, max packet, charset
745747
data = struct.pack('<IIB', self.client_flag, 16777216, 33)
746748
data += b'\x00' * (32 - len(data))
@@ -767,6 +769,8 @@ async def _request_authentication(self):
767769
server_hostname=self._host
768770
)
769771

772+
self._secure = True
773+
770774
charset_id = charset_by_name(self.charset).id
771775
if isinstance(self.user, str):
772776
_user = self.user.encode(self.encoding)
@@ -795,7 +799,7 @@ async def _request_authentication(self):
795799
)
796800
# Else: empty password
797801
elif auth_plugin == 'sha256_password':
798-
if self._ssl_context and self.server_capabilities & CLIENT.SSL:
802+
if self._secure:
799803
authresp = self._password.encode('latin1') + b'\0'
800804
elif self._password:
801805
authresp = b'\1' # request public key
@@ -957,7 +961,7 @@ async def caching_sha2_password_auth(self, pkt):
957961

958962
logger.debug("caching sha2: Trying full auth...")
959963

960-
if self._ssl_context:
964+
if self._secure:
961965
logger.debug("caching sha2: Sending plain "
962966
"password via secure connection")
963967
self.write_packet(self._password.encode('latin1') + b'\0')
@@ -988,7 +992,7 @@ async def caching_sha2_password_auth(self, pkt):
988992
pkt.check_error()
989993

990994
async def sha256_password_auth(self, pkt):
991-
if self._ssl_context:
995+
if self._secure:
992996
logger.debug("sha256: Sending plain password")
993997
data = self._password.encode('latin1') + b'\0'
994998
self.write_packet(data)

tests/conftest.py

Lines changed: 87 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,48 @@ def pytest_generate_tests(metafunc):
2727
loop_type = ['asyncio', 'uvloop'] if uvloop else ['asyncio']
2828
metafunc.parametrize("loop_type", loop_type)
2929

30+
if "mysql_address" in metafunc.fixturenames:
31+
mysql_addresses = []
32+
ids = []
33+
34+
opt_mysql_unix_socket = \
35+
list(metafunc.config.getoption("mysql_unix_socket"))
36+
for i in range(len(opt_mysql_unix_socket)):
37+
if "=" in opt_mysql_unix_socket[i]:
38+
label, path = opt_mysql_unix_socket[i].rsplit("=", 1)
39+
mysql_addresses.append(path)
40+
ids.append(label)
41+
else:
42+
mysql_addresses.append(opt_mysql_unix_socket[i])
43+
ids.append("socket{}".format(i))
44+
45+
opt_mysql_address = list(metafunc.config.getoption("mysql_address"))
46+
for i in range(len(opt_mysql_address)):
47+
if "=" in opt_mysql_address[i]:
48+
label, addr = opt_mysql_unix_socket[i].rsplit("=", 1)
49+
ids.append(label)
50+
else:
51+
addr = opt_mysql_address[i]
52+
ids.append("socket{}".format(i))
53+
54+
if ":" in addr:
55+
addr = addr.rsplit(":", 1)
56+
mysql_addresses.append((addr[0], int(addr[1])))
57+
else:
58+
mysql_addresses.append((addr, 3306))
59+
ids.append("address{}".format(i))
60+
61+
# default to connecting to localhost
62+
if len(mysql_addresses) == 0:
63+
mysql_addresses = [("127.0.0.1", 3306)]
64+
ids = ["tcp-local"]
65+
66+
metafunc.parametrize("mysql_address",
67+
mysql_addresses,
68+
ids=ids,
69+
scope="session",
70+
)
71+
3072

3173
# This is here unless someone fixes the generate_tests bit
3274
@pytest.fixture(scope='session')
@@ -101,6 +143,21 @@ def pytest_configure(config):
101143
)
102144

103145

146+
def pytest_addoption(parser):
147+
parser.addoption(
148+
"--mysql-address",
149+
action="append",
150+
default=[],
151+
help="list of addresses to connect to",
152+
)
153+
parser.addoption(
154+
"--mysql-unix-socket",
155+
action="append",
156+
default=[],
157+
help="list of unix sockets to connect to",
158+
)
159+
160+
104161
@pytest.fixture
105162
def mysql_params(mysql_server):
106163
params = {**mysql_server['conn_params'],
@@ -209,24 +266,31 @@ def ensure_mysql_version(request, mysql_image, mysql_tag):
209266

210267

211268
@pytest.fixture(scope='session')
212-
def mysql_server(mysql_image, mysql_tag):
213-
ssl_directory = os.path.join(os.path.dirname(__file__),
214-
'ssl_resources', 'ssl')
215-
ca_file = os.path.join(ssl_directory, 'ca.pem')
269+
def mysql_server(mysql_image, mysql_tag, mysql_address):
270+
unix_socket = type(mysql_address) is str
271+
272+
if not unix_socket:
273+
ssl_directory = os.path.join(os.path.dirname(__file__),
274+
'ssl_resources', 'ssl')
275+
ca_file = os.path.join(ssl_directory, 'ca.pem')
216276

217-
ctx = ssl.SSLContext(ssl.PROTOCOL_TLSv1_2)
218-
ctx.check_hostname = False
219-
ctx.load_verify_locations(cafile=ca_file)
220-
# ctx.verify_mode = ssl.CERT_NONE
277+
ctx = ssl.SSLContext(ssl.PROTOCOL_TLSv1_2)
278+
ctx.check_hostname = False
279+
ctx.load_verify_locations(cafile=ca_file)
280+
# ctx.verify_mode = ssl.CERT_NONE
221281

222282
server_params = {
223-
'host': '127.0.0.1',
224-
'port': 3306,
225283
'user': 'root',
226284
'password': os.environ.get("MYSQL_ROOT_PASSWORD"),
227-
'ssl': ctx,
228285
}
229286

287+
if unix_socket:
288+
server_params["unix_socket"] = mysql_address
289+
else:
290+
server_params["host"] = mysql_address[0]
291+
server_params["port"] = mysql_address[1]
292+
server_params["ssl"] = ctx
293+
230294
try:
231295
connection = pymysql.connect(
232296
db='mysql',
@@ -235,21 +299,22 @@ def mysql_server(mysql_image, mysql_tag):
235299
**server_params)
236300

237301
with connection.cursor() as cursor:
238-
cursor.execute("SHOW VARIABLES LIKE '%ssl%';")
302+
if not unix_socket:
303+
cursor.execute("SHOW VARIABLES LIKE '%ssl%';")
239304

240-
result = cursor.fetchall()
241-
result = {item['Variable_name']:
242-
item['Value'] for item in result}
305+
result = cursor.fetchall()
306+
result = {item['Variable_name']:
307+
item['Value'] for item in result}
243308

244-
assert result['have_ssl'] == "YES", \
245-
"SSL Not Enabled on MySQL"
309+
assert result['have_ssl'] == "YES", \
310+
"SSL Not Enabled on MySQL"
246311

247-
cursor.execute("SHOW STATUS LIKE 'Ssl_version%'")
312+
cursor.execute("SHOW STATUS LIKE 'Ssl_version%'")
248313

249-
result = cursor.fetchone()
250-
# As we connected with TLS, it should start with that :D
251-
assert result['Value'].startswith('TLS'), \
252-
"Not connected to the database with TLS"
314+
result = cursor.fetchone()
315+
# As we connected with TLS, it should start with that :D
316+
assert result['Value'].startswith('TLS'), \
317+
"Not connected to the database with TLS"
253318

254319
# Drop possibly existing old databases
255320
cursor.execute('DROP DATABASE IF EXISTS test_pymysql;')
File renamed without changes.

tests/fixtures/my.cnf.unix.tmpl

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
#
2+
# The MySQL database server configuration file.
3+
#
4+
[client]
5+
user = {user}
6+
socket = {unix_socket}
7+
password = {password}
8+
database = {db}
9+
default-character-set = utf8
10+
11+
[client_with_unix_socket]
12+
user = {user}
13+
socket = {unix_socket}
14+
password = {password}
15+
database = {db}
16+
default-character-set = utf8

tests/sa/test_sa_compiled_cache.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,12 +15,19 @@
1515
@pytest.fixture()
1616
def make_engine(mysql_params, connection):
1717
async def _make_engine(**kwargs):
18+
if "unix_socket" in mysql_params:
19+
conn_args = {"unix_socket": mysql_params["unix_socket"]}
20+
else:
21+
conn_args = {
22+
"host": mysql_params['host'],
23+
"port": mysql_params['port'],
24+
}
25+
1826
return (await sa.create_engine(db=mysql_params['db'],
1927
user=mysql_params['user'],
2028
password=mysql_params['password'],
21-
host=mysql_params['host'],
22-
port=mysql_params['port'],
2329
minsize=10,
30+
**conn_args,
2431
**kwargs))
2532

2633
return _make_engine

tests/sa/test_sa_default.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,12 +22,19 @@
2222
@pytest.fixture()
2323
def make_engine(mysql_params, connection):
2424
async def _make_engine(**kwargs):
25+
if "unix_socket" in mysql_params:
26+
conn_args = {"unix_socket": mysql_params["unix_socket"]}
27+
else:
28+
conn_args = {
29+
"host": mysql_params['host'],
30+
"port": mysql_params['port'],
31+
}
32+
2533
return (await sa.create_engine(db=mysql_params['db'],
2634
user=mysql_params['user'],
2735
password=mysql_params['password'],
28-
host=mysql_params['host'],
29-
port=mysql_params['port'],
3036
minsize=10,
37+
**conn_args,
3138
**kwargs))
3239

3340
return _make_engine

tests/sa/test_sa_engine.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,12 +15,19 @@
1515
@pytest.fixture()
1616
def make_engine(connection, mysql_params):
1717
async def _make_engine(**kwargs):
18+
if "unix_socket" in mysql_params:
19+
conn_args = {"unix_socket": mysql_params["unix_socket"]}
20+
else:
21+
conn_args = {
22+
"host": mysql_params['host'],
23+
"port": mysql_params['port'],
24+
}
25+
1826
return (await sa.create_engine(db=mysql_params['db'],
1927
user=mysql_params['user'],
2028
password=mysql_params['password'],
21-
host=mysql_params['host'],
22-
port=mysql_params['port'],
2329
minsize=10,
30+
**conn_args,
2431
**kwargs))
2532
return _make_engine
2633

0 commit comments

Comments
 (0)