Skip to content

Commit 163a791

Browse files
committed
fix unit socket implementation, most tests should be fine now
1 parent 62fd1b2 commit 163a791

File tree

11 files changed

+173
-44
lines changed

11 files changed

+173
-44
lines changed

.github/workflows/ci.yml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,8 @@ jobs:
4949
image: "${{ join(matrix.db, ':') }}"
5050
ports:
5151
- 3306:3306
52+
volumes:
53+
- "/tmp/run-${{ join(matrix.db, '-') }}/:/run/mysqld/"
5254
options: '--name=mysqld'
5355
env:
5456
MYSQL_ROOT_PASSWORD: rootpw
@@ -119,7 +121,7 @@ jobs:
119121
run: |
120122
# timeout ensures a more or less clean stop by sending a KeyboardInterrupt which will still provide useful logs
121123
timeout --preserve-status --signal=INT --verbose 5m \
122-
pytest --color=yes --capture=no --verbosity 2 --cov-report term --cov-report xml --cov aiomysql ./tests
124+
pytest --color=yes --capture=no --verbosity 2 --cov-report term --cov-report xml --cov aiomysql ./tests --mysql-unix-socket "${{ join(matrix.db, '') }}=/tmp/run-${{ join(matrix.db, '-') }}/mysql.sock" --mysql-address "${{ join(matrix.db, '') }}=127.0.0.1:3306"
123125
env:
124126
PYTHONUNBUFFERED: 1
125127
DB: '${{ matrix.db[0] }}'

aiomysql/connection.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -229,6 +229,7 @@ def __init__(self, host="localhost", user=None, password="",
229229
self._client_auth_plugin = auth_plugin
230230
self._server_auth_plugin = ""
231231
self._auth_plugin_used = ""
232+
self._secure = False
232233
self.server_public_key = server_public_key
233234
self.salt = None
234235

@@ -526,14 +527,15 @@ async def _connect(self):
526527
# raise OperationalError(CR.CR_SERVER_GONE_ERROR,
527528
# "MySQL server has gone away (%r)" % (e,))
528529
try:
529-
if self._unix_socket and self._host in ('localhost', '127.0.0.1'):
530+
if self._unix_socket:
530531
self._reader, self._writer = await \
531532
asyncio.wait_for(
532533
_open_unix_connection(
533534
self._unix_socket),
534535
timeout=self.connect_timeout)
535536
self.host_info = "Localhost via UNIX socket: " + \
536537
self._unix_socket
538+
self._secure = True
537539
else:
538540
self._reader, self._writer = await \
539541
asyncio.wait_for(
@@ -743,7 +745,7 @@ async def _request_authentication(self):
743745
if self.user is None:
744746
raise ValueError("Did not specify a username")
745747

746-
if self._ssl_context:
748+
if self._ssl_context and self.server_capabilities & CLIENT.SSL:
747749
# capablities, max packet, charset
748750
data = struct.pack('<IIB', self.client_flag, 16777216, 33)
749751
data += b'\x00' * (32 - len(data))
@@ -770,6 +772,8 @@ async def _request_authentication(self):
770772
server_hostname=self._host
771773
)
772774

775+
self._secure = True
776+
773777
charset_id = charset_by_name(self.charset).id
774778
if isinstance(self.user, str):
775779
_user = self.user.encode(self.encoding)
@@ -798,7 +802,7 @@ async def _request_authentication(self):
798802
)
799803
# Else: empty password
800804
elif auth_plugin == 'sha256_password':
801-
if self._ssl_context and self.server_capabilities & CLIENT.SSL:
805+
if self._secure:
802806
authresp = self._password.encode('latin1') + b'\0'
803807
elif self._password:
804808
authresp = b'\1' # request public key
@@ -960,7 +964,7 @@ async def caching_sha2_password_auth(self, pkt):
960964

961965
logger.debug("caching sha2: Trying full auth...")
962966

963-
if self._ssl_context:
967+
if self._secure:
964968
logger.debug("caching sha2: Sending plain "
965969
"password via secure connection")
966970
self.write_packet(self._password.encode('latin1') + b'\0')
@@ -991,7 +995,7 @@ async def caching_sha2_password_auth(self, pkt):
991995
pkt.check_error()
992996

993997
async def sha256_password_auth(self, pkt):
994-
if self._ssl_context:
998+
if self._secure:
995999
logger.debug("sha256: Sending plain password")
9961000
data = self._password.encode('latin1') + b'\0'
9971001
self.write_packet(data)

tests/conftest.py

Lines changed: 87 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,48 @@ def pytest_generate_tests(metafunc):
2727
loop_type = ['asyncio', 'uvloop'] if uvloop else ['asyncio']
2828
metafunc.parametrize("loop_type", loop_type)
2929

30+
if "mysql_address" in metafunc.fixturenames:
31+
mysql_addresses = []
32+
ids = []
33+
34+
opt_mysql_unix_socket = \
35+
list(metafunc.config.getoption("mysql_unix_socket"))
36+
for i in range(len(opt_mysql_unix_socket)):
37+
if "=" in opt_mysql_unix_socket[i]:
38+
label, path = opt_mysql_unix_socket[i].rsplit("=", 1)
39+
mysql_addresses.append(path)
40+
ids.append(label)
41+
else:
42+
mysql_addresses.append(opt_mysql_unix_socket[i])
43+
ids.append("socket{}".format(i))
44+
45+
opt_mysql_address = list(metafunc.config.getoption("mysql_address"))
46+
for i in range(len(opt_mysql_address)):
47+
if "=" in opt_mysql_address[i]:
48+
label, addr = opt_mysql_unix_socket[i].rsplit("=", 1)
49+
ids.append(label)
50+
else:
51+
addr = opt_mysql_address[i]
52+
ids.append("socket{}".format(i))
53+
54+
if ":" in addr:
55+
addr = addr.rsplit(":", 1)
56+
mysql_addresses.append((addr[0], int(addr[1])))
57+
else:
58+
mysql_addresses.append((addr, 3306))
59+
ids.append("address{}".format(i))
60+
61+
# default to connecting to localhost
62+
if len(mysql_addresses) == 0:
63+
mysql_addresses = [("127.0.0.1", 3306)]
64+
ids = ["tcp-local"]
65+
66+
metafunc.parametrize("mysql_address",
67+
mysql_addresses,
68+
ids=ids,
69+
scope="session",
70+
)
71+
3072

3173
# This is here unless someone fixes the generate_tests bit
3274
@pytest.fixture(scope='session')
@@ -101,6 +143,21 @@ def pytest_configure(config):
101143
)
102144

103145

146+
def pytest_addoption(parser):
147+
parser.addoption(
148+
"--mysql-address",
149+
action="append",
150+
default=[],
151+
help="list of addresses to connect to",
152+
)
153+
parser.addoption(
154+
"--mysql-unix-socket",
155+
action="append",
156+
default=[],
157+
help="list of unix sockets to connect to",
158+
)
159+
160+
104161
@pytest.fixture
105162
def mysql_params(mysql_server):
106163
params = {**mysql_server['conn_params'],
@@ -205,24 +262,31 @@ def ensure_mysql_version(request, mysql_image, mysql_tag):
205262

206263

207264
@pytest.fixture(scope='session')
208-
def mysql_server(mysql_image, mysql_tag):
209-
ssl_directory = os.path.join(os.path.dirname(__file__),
210-
'ssl_resources', 'ssl')
211-
ca_file = os.path.join(ssl_directory, 'ca.pem')
265+
def mysql_server(mysql_image, mysql_tag, mysql_address):
266+
unix_socket = type(mysql_address) is str
267+
268+
if not unix_socket:
269+
ssl_directory = os.path.join(os.path.dirname(__file__),
270+
'ssl_resources', 'ssl')
271+
ca_file = os.path.join(ssl_directory, 'ca.pem')
212272

213-
ctx = ssl.SSLContext(ssl.PROTOCOL_TLSv1_2)
214-
ctx.check_hostname = False
215-
ctx.load_verify_locations(cafile=ca_file)
216-
# ctx.verify_mode = ssl.CERT_NONE
273+
ctx = ssl.SSLContext(ssl.PROTOCOL_TLSv1_2)
274+
ctx.check_hostname = False
275+
ctx.load_verify_locations(cafile=ca_file)
276+
# ctx.verify_mode = ssl.CERT_NONE
217277

218278
server_params = {
219-
'host': '127.0.0.1',
220-
'port': 3306,
221279
'user': 'root',
222280
'password': os.environ.get("MYSQL_ROOT_PASSWORD"),
223-
'ssl': ctx,
224281
}
225282

283+
if unix_socket:
284+
server_params["unix_socket"] = mysql_address
285+
else:
286+
server_params["host"] = mysql_address[0]
287+
server_params["port"] = mysql_address[1]
288+
server_params["ssl"] = ctx
289+
226290
try:
227291
connection = pymysql.connect(
228292
db='mysql',
@@ -231,21 +295,22 @@ def mysql_server(mysql_image, mysql_tag):
231295
**server_params)
232296

233297
with connection.cursor() as cursor:
234-
cursor.execute("SHOW VARIABLES LIKE '%ssl%';")
298+
if not unix_socket:
299+
cursor.execute("SHOW VARIABLES LIKE '%ssl%';")
235300

236-
result = cursor.fetchall()
237-
result = {item['Variable_name']:
238-
item['Value'] for item in result}
301+
result = cursor.fetchall()
302+
result = {item['Variable_name']:
303+
item['Value'] for item in result}
239304

240-
assert result['have_ssl'] == "YES", \
241-
"SSL Not Enabled on MySQL"
305+
assert result['have_ssl'] == "YES", \
306+
"SSL Not Enabled on MySQL"
242307

243-
cursor.execute("SHOW STATUS LIKE 'Ssl_version%'")
308+
cursor.execute("SHOW STATUS LIKE 'Ssl_version%'")
244309

245-
result = cursor.fetchone()
246-
# As we connected with TLS, it should start with that :D
247-
assert result['Value'].startswith('TLS'), \
248-
"Not connected to the database with TLS"
310+
result = cursor.fetchone()
311+
# As we connected with TLS, it should start with that :D
312+
assert result['Value'].startswith('TLS'), \
313+
"Not connected to the database with TLS"
249314

250315
# Drop possibly existing old databases
251316
cursor.execute('DROP DATABASE IF EXISTS test_pymysql;')
File renamed without changes.

tests/fixtures/my.cnf.unix.tmpl

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
#
2+
# The MySQL database server configuration file.
3+
#
4+
[client]
5+
user = {user}
6+
socket = {unix_socket}
7+
password = {password}
8+
database = {db}
9+
default-character-set = utf8
10+
11+
[client_with_unix_socket]
12+
user = {user}
13+
socket = {unix_socket}
14+
password = {password}
15+
database = {db}
16+
default-character-set = utf8

tests/sa/test_sa_compiled_cache.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,12 +15,19 @@
1515
@pytest.fixture()
1616
def make_engine(mysql_params, connection):
1717
async def _make_engine(**kwargs):
18+
if "unix_socket" in mysql_params:
19+
conn_args = {"unix_socket": mysql_params["unix_socket"]}
20+
else:
21+
conn_args = {
22+
"host": mysql_params['host'],
23+
"port": mysql_params['port'],
24+
}
25+
1826
return (await sa.create_engine(db=mysql_params['db'],
1927
user=mysql_params['user'],
2028
password=mysql_params['password'],
21-
host=mysql_params['host'],
22-
port=mysql_params['port'],
2329
minsize=10,
30+
**conn_args,
2431
**kwargs))
2532

2633
return _make_engine

tests/sa/test_sa_default.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,12 +22,19 @@
2222
@pytest.fixture()
2323
def make_engine(mysql_params, connection):
2424
async def _make_engine(**kwargs):
25+
if "unix_socket" in mysql_params:
26+
conn_args = {"unix_socket": mysql_params["unix_socket"]}
27+
else:
28+
conn_args = {
29+
"host": mysql_params['host'],
30+
"port": mysql_params['port'],
31+
}
32+
2533
return (await sa.create_engine(db=mysql_params['db'],
2634
user=mysql_params['user'],
2735
password=mysql_params['password'],
28-
host=mysql_params['host'],
29-
port=mysql_params['port'],
3036
minsize=10,
37+
**conn_args,
3138
**kwargs))
3239

3340
return _make_engine

tests/sa/test_sa_engine.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,12 +15,19 @@
1515
@pytest.fixture()
1616
def make_engine(connection, mysql_params):
1717
async def _make_engine(**kwargs):
18+
if "unix_socket" in mysql_params:
19+
conn_args = {"unix_socket": mysql_params["unix_socket"]}
20+
else:
21+
conn_args = {
22+
"host": mysql_params['host'],
23+
"port": mysql_params['port'],
24+
}
25+
1826
return (await sa.create_engine(db=mysql_params['db'],
1927
user=mysql_params['user'],
2028
password=mysql_params['password'],
21-
host=mysql_params['host'],
22-
port=mysql_params['port'],
2329
minsize=10,
30+
**conn_args,
2431
**kwargs))
2532
return _make_engine
2633

tests/test_connection.py

Lines changed: 22 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,13 @@
1010
@pytest.fixture()
1111
def fill_my_cnf(mysql_params):
1212
tests_root = os.path.abspath(os.path.dirname(__file__))
13-
path1 = os.path.join(tests_root, 'fixtures/my.cnf.tmpl')
13+
14+
if "unix_socket" in mysql_params:
15+
tmpl_path = "fixtures/my.cnf.unix.tmpl"
16+
else:
17+
tmpl_path = "fixtures/my.cnf.tcp.tmpl"
18+
19+
path1 = os.path.join(tests_root, tmpl_path)
1420
path2 = os.path.join(tests_root, 'fixtures/my.cnf')
1521
with open(path1) as f1:
1622
tmpl = f1.read()
@@ -31,8 +37,11 @@ async def test_config_file(fill_my_cnf, connection_creator, mysql_params):
3137
path = os.path.join(tests_root, 'fixtures/my.cnf')
3238
conn = await connection_creator(read_default_file=path)
3339

34-
assert conn.host == mysql_params['host']
35-
assert conn.port == mysql_params['port']
40+
if "unix_socket" in mysql_params:
41+
assert conn.unix_socket == mysql_params["unix_socket"]
42+
else:
43+
assert conn.host == mysql_params['host']
44+
assert conn.port == mysql_params['port']
3645
assert conn.user, mysql_params['user']
3746

3847
# make sure connection is working
@@ -167,12 +176,15 @@ async def test_connection_gone_away(connection_creator):
167176

168177

169178
@pytest.mark.run_loop
170-
async def test_connection_info_methods(connection_creator):
179+
async def test_connection_info_methods(connection_creator, mysql_params):
171180
conn = await connection_creator()
172181
# trhead id is int
173182
assert isinstance(conn.thread_id(), int)
174183
assert conn.character_set_name() in ('latin1', 'utf8mb4')
175-
assert str(conn.port) in conn.get_host_info()
184+
if "unix_socket" in mysql_params:
185+
assert mysql_params["unix_socket"] in conn.get_host_info()
186+
else:
187+
assert str(conn.port) in conn.get_host_info()
176188
assert isinstance(conn.get_server_info(), str)
177189
# protocol id is int
178190
assert isinstance(conn.get_proto_info(), int)
@@ -200,8 +212,11 @@ async def test_connection_ping(connection_creator):
200212
@pytest.mark.run_loop
201213
async def test_connection_properties(connection_creator, mysql_params):
202214
conn = await connection_creator()
203-
assert conn.host == mysql_params['host']
204-
assert conn.port == mysql_params['port']
215+
if "unix_socket" in mysql_params:
216+
assert conn.unix_socket == mysql_params["unix_socket"]
217+
else:
218+
assert conn.host == mysql_params['host']
219+
assert conn.port == mysql_params['port']
205220
assert conn.user == mysql_params['user']
206221
assert conn.db == mysql_params['db']
207222
assert conn.echo is False

tests/test_issues.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -184,7 +184,7 @@ async def test_issue_17(connection, connection_creator, mysql_params):
184184
async def test_issue_34(connection_creator):
185185
try:
186186
await connection_creator(host="localhost", port=1237,
187-
user="root")
187+
user="root", unix_socket=None)
188188
pytest.fail()
189189
except aiomysql.OperationalError as e:
190190
assert 2003 == e.args[0]

0 commit comments

Comments
 (0)