From 549dd461f2a0724143cd8786e0b5a3676aeb1922 Mon Sep 17 00:00:00 2001 From: Richard Schwab Date: Wed, 19 Jan 2022 19:29:51 +0100 Subject: [PATCH] fix unit socket implementation, most tests should be fine now configure custom socket path for mysql container, working around implicitly created volume folders being owned by root we should probably just not use service containers for this to avoid having to do this patching --- .github/workflows/ci.yml | 17 ++++- CHANGES.txt | 1 + tests/conftest.py | 67 +++++++++++++------ .../fixtures/{my.cnf.tmpl => my.cnf.tcp.tmpl} | 0 tests/fixtures/my.cnf.unix.tmpl | 16 +++++ tests/sa/test_sa_compiled_cache.py | 11 ++- tests/sa/test_sa_default.py | 11 ++- tests/sa/test_sa_engine.py | 11 ++- tests/ssl_resources/socket.cnf | 2 + tests/test_connection.py | 29 ++++++-- tests/test_issues.py | 2 +- tests/test_sha_connection.py | 7 ++ tests/test_ssl.py | 10 ++- 13 files changed, 146 insertions(+), 38 deletions(-) rename tests/fixtures/{my.cnf.tmpl => my.cnf.tcp.tmpl} (100%) create mode 100644 tests/fixtures/my.cnf.unix.tmpl create mode 100644 tests/ssl_resources/socket.cnf diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 83465c67..b8db8d57 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -49,6 +49,8 @@ jobs: image: "${{ join(matrix.db, ':') }}" ports: - 3306:3306 + volumes: + - "/tmp/run-${{ join(matrix.db, '-') }}/:/socket-mount/" options: '--name=mysqld' env: MYSQL_ROOT_PASSWORD: rootpw @@ -104,6 +106,19 @@ jobs: docker container stop mysqld docker container cp "${{ github.workspace }}/tests/ssl_resources/ssl" mysqld:/etc/mysql/ssl docker container cp "${{ github.workspace }}/tests/ssl_resources/tls.cnf" mysqld:/etc/mysql/conf.d/aiomysql-tls.cnf + + # use custom socket path + # we need to ensure that the socket path is writable for the user running the DB process in the container + sudo chmod 0777 /tmp/run-${{ join(matrix.db, '-') }} + + # mysql 5.7 container overrides the socket path in /etc/mysql/mysql.conf.d/mysqld.cnf + if [ "${{ join(matrix.db, '-') }}" = "mysql-5.7" ] + then + docker container cp "${{ github.workspace }}/tests/ssl_resources/socket.cnf" mysqld:/etc/mysql/mysql.conf.d/zz-aiomysql-socket.cnf + else + docker container cp "${{ github.workspace }}/tests/ssl_resources/socket.cnf" mysqld:/etc/mysql/conf.d/aiomysql-socket.cnf + fi + docker container start mysqld # ensure server is started up @@ -119,7 +134,7 @@ jobs: run: | # timeout ensures a more or less clean stop by sending a KeyboardInterrupt which will still provide useful logs timeout --preserve-status --signal=INT --verbose 5m \ - pytest --color=yes --capture=no --verbosity 2 --cov-report term --cov-report xml --cov aiomysql ./tests --mysql-address "tcp-${{ join(matrix.db, '') }}=127.0.0.1:3306" + pytest --color=yes --capture=no --verbosity 2 --cov-report term --cov-report xml --cov aiomysql ./tests --mysql-unix-socket "unix-${{ join(matrix.db, '') }}=/tmp/run-${{ join(matrix.db, '-') }}/mysql.sock" --mysql-address "tcp-${{ join(matrix.db, '') }}=127.0.0.1:3306" env: PYTHONUNBUFFERED: 1 DB: '${{ matrix.db[0] }}' diff --git a/CHANGES.txt b/CHANGES.txt index b537f279..9262950a 100644 --- a/CHANGES.txt +++ b/CHANGES.txt @@ -11,6 +11,7 @@ To be included in 1.0.0 (unreleased) * Ensure connections are properly closed before raising an OperationalError when the server connection is lost #660 * Ensure connections are properly closed before raising an InternalError when packet sequence numbers are out of sync #660 * Unix sockets are now internally considered secure, allowing sha256_password and caching_sha2_password auth methods to be used #695 +* Test suite now also tests unix socket connections #686 0.0.22 (2021-11-14) diff --git a/tests/conftest.py b/tests/conftest.py index a42172c1..d6b0a923 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -31,6 +31,17 @@ def pytest_generate_tests(metafunc): mysql_addresses = [] ids = [] + opt_mysql_unix_socket = \ + list(metafunc.config.getoption("mysql_unix_socket")) + for i in range(len(opt_mysql_unix_socket)): + if "=" in opt_mysql_unix_socket[i]: + label, path = opt_mysql_unix_socket[i].split("=", 1) + mysql_addresses.append(path) + ids.append(label) + else: + mysql_addresses.append(opt_mysql_unix_socket[i]) + ids.append("unix{}".format(i)) + opt_mysql_address = list(metafunc.config.getoption("mysql_address")) for i in range(len(opt_mysql_address)): if "=" in opt_mysql_address[i]: @@ -143,6 +154,12 @@ def pytest_addoption(parser): default=[], help="list of addresses to connect to: [name=]host[:port]", ) + parser.addoption( + "--mysql-unix-socket", + action="append", + default=[], + help="list of unix sockets to connect to: [name=]/path/to/socket", + ) @pytest.fixture @@ -250,23 +267,30 @@ def ensure_mysql_version(request, mysql_image, mysql_tag): @pytest.fixture(scope='session') def mysql_server(mysql_image, mysql_tag, mysql_address): - ssl_directory = os.path.join(os.path.dirname(__file__), - 'ssl_resources', 'ssl') - ca_file = os.path.join(ssl_directory, 'ca.pem') + unix_socket = type(mysql_address) is str + + if not unix_socket: + ssl_directory = os.path.join(os.path.dirname(__file__), + 'ssl_resources', 'ssl') + ca_file = os.path.join(ssl_directory, 'ca.pem') - ctx = ssl.SSLContext(ssl.PROTOCOL_TLSv1_2) - ctx.check_hostname = False - ctx.load_verify_locations(cafile=ca_file) - # ctx.verify_mode = ssl.CERT_NONE + ctx = ssl.SSLContext(ssl.PROTOCOL_TLSv1_2) + ctx.check_hostname = False + ctx.load_verify_locations(cafile=ca_file) + # ctx.verify_mode = ssl.CERT_NONE server_params = { - 'host': mysql_address[0], - 'port': mysql_address[1], 'user': 'root', 'password': os.environ.get("MYSQL_ROOT_PASSWORD"), - 'ssl': ctx, } + if unix_socket: + server_params["unix_socket"] = mysql_address + else: + server_params["host"] = mysql_address[0] + server_params["port"] = mysql_address[1] + server_params["ssl"] = ctx + try: connection = pymysql.connect( db='mysql', @@ -275,21 +299,22 @@ def mysql_server(mysql_image, mysql_tag, mysql_address): **server_params) with connection.cursor() as cursor: - cursor.execute("SHOW VARIABLES LIKE '%ssl%';") + if not unix_socket: + cursor.execute("SHOW VARIABLES LIKE '%ssl%';") - result = cursor.fetchall() - result = {item['Variable_name']: - item['Value'] for item in result} + result = cursor.fetchall() + result = {item['Variable_name']: + item['Value'] for item in result} - assert result['have_ssl'] == "YES", \ - "SSL Not Enabled on MySQL" + assert result['have_ssl'] == "YES", \ + "SSL Not Enabled on MySQL" - cursor.execute("SHOW STATUS LIKE 'Ssl_version%'") + cursor.execute("SHOW STATUS LIKE 'Ssl_version%'") - result = cursor.fetchone() - # As we connected with TLS, it should start with that :D - assert result['Value'].startswith('TLS'), \ - "Not connected to the database with TLS" + result = cursor.fetchone() + # As we connected with TLS, it should start with that :D + assert result['Value'].startswith('TLS'), \ + "Not connected to the database with TLS" # Drop possibly existing old databases cursor.execute('DROP DATABASE IF EXISTS test_pymysql;') diff --git a/tests/fixtures/my.cnf.tmpl b/tests/fixtures/my.cnf.tcp.tmpl similarity index 100% rename from tests/fixtures/my.cnf.tmpl rename to tests/fixtures/my.cnf.tcp.tmpl diff --git a/tests/fixtures/my.cnf.unix.tmpl b/tests/fixtures/my.cnf.unix.tmpl new file mode 100644 index 00000000..2aad4432 --- /dev/null +++ b/tests/fixtures/my.cnf.unix.tmpl @@ -0,0 +1,16 @@ +# +# The MySQL database server configuration file. +# +[client] +user = {user} +socket = {unix_socket} +password = {password} +database = {db} +default-character-set = utf8 + +[client_with_unix_socket] +user = {user} +socket = {unix_socket} +password = {password} +database = {db} +default-character-set = utf8 diff --git a/tests/sa/test_sa_compiled_cache.py b/tests/sa/test_sa_compiled_cache.py index e8c0f5f2..38906551 100644 --- a/tests/sa/test_sa_compiled_cache.py +++ b/tests/sa/test_sa_compiled_cache.py @@ -15,12 +15,19 @@ @pytest.fixture() def make_engine(mysql_params, connection): async def _make_engine(**kwargs): + if "unix_socket" in mysql_params: + conn_args = {"unix_socket": mysql_params["unix_socket"]} + else: + conn_args = { + "host": mysql_params['host'], + "port": mysql_params['port'], + } + return (await sa.create_engine(db=mysql_params['db'], user=mysql_params['user'], password=mysql_params['password'], - host=mysql_params['host'], - port=mysql_params['port'], minsize=10, + **conn_args, **kwargs)) return _make_engine diff --git a/tests/sa/test_sa_default.py b/tests/sa/test_sa_default.py index 42c34f5b..e5f270ec 100644 --- a/tests/sa/test_sa_default.py +++ b/tests/sa/test_sa_default.py @@ -22,12 +22,19 @@ @pytest.fixture() def make_engine(mysql_params, connection): async def _make_engine(**kwargs): + if "unix_socket" in mysql_params: + conn_args = {"unix_socket": mysql_params["unix_socket"]} + else: + conn_args = { + "host": mysql_params['host'], + "port": mysql_params['port'], + } + return (await sa.create_engine(db=mysql_params['db'], user=mysql_params['user'], password=mysql_params['password'], - host=mysql_params['host'], - port=mysql_params['port'], minsize=10, + **conn_args, **kwargs)) return _make_engine diff --git a/tests/sa/test_sa_engine.py b/tests/sa/test_sa_engine.py index e514260d..ed74a96d 100644 --- a/tests/sa/test_sa_engine.py +++ b/tests/sa/test_sa_engine.py @@ -15,12 +15,19 @@ @pytest.fixture() def make_engine(connection, mysql_params): async def _make_engine(**kwargs): + if "unix_socket" in mysql_params: + conn_args = {"unix_socket": mysql_params["unix_socket"]} + else: + conn_args = { + "host": mysql_params['host'], + "port": mysql_params['port'], + } + return (await sa.create_engine(db=mysql_params['db'], user=mysql_params['user'], password=mysql_params['password'], - host=mysql_params['host'], - port=mysql_params['port'], minsize=10, + **conn_args, **kwargs)) return _make_engine diff --git a/tests/ssl_resources/socket.cnf b/tests/ssl_resources/socket.cnf new file mode 100644 index 00000000..32100e93 --- /dev/null +++ b/tests/ssl_resources/socket.cnf @@ -0,0 +1,2 @@ +[mysqld] +socket = /socket-mount/mysql.sock diff --git a/tests/test_connection.py b/tests/test_connection.py index 075039d0..af6788f3 100644 --- a/tests/test_connection.py +++ b/tests/test_connection.py @@ -10,7 +10,13 @@ @pytest.fixture() def fill_my_cnf(mysql_params): tests_root = os.path.abspath(os.path.dirname(__file__)) - path1 = os.path.join(tests_root, 'fixtures/my.cnf.tmpl') + + if "unix_socket" in mysql_params: + tmpl_path = "fixtures/my.cnf.unix.tmpl" + else: + tmpl_path = "fixtures/my.cnf.tcp.tmpl" + + path1 = os.path.join(tests_root, tmpl_path) path2 = os.path.join(tests_root, 'fixtures/my.cnf') with open(path1) as f1: tmpl = f1.read() @@ -31,8 +37,11 @@ async def test_config_file(fill_my_cnf, connection_creator, mysql_params): path = os.path.join(tests_root, 'fixtures/my.cnf') conn = await connection_creator(read_default_file=path) - assert conn.host == mysql_params['host'] - assert conn.port == mysql_params['port'] + if "unix_socket" in mysql_params: + assert conn.unix_socket == mysql_params["unix_socket"] + else: + assert conn.host == mysql_params['host'] + assert conn.port == mysql_params['port'] assert conn.user, mysql_params['user'] # make sure connection is working @@ -167,12 +176,15 @@ async def test_connection_gone_away(connection_creator): @pytest.mark.run_loop -async def test_connection_info_methods(connection_creator): +async def test_connection_info_methods(connection_creator, mysql_params): conn = await connection_creator() # trhead id is int assert isinstance(conn.thread_id(), int) assert conn.character_set_name() in ('latin1', 'utf8mb4') - assert str(conn.port) in conn.get_host_info() + if "unix_socket" in mysql_params: + assert mysql_params["unix_socket"] in conn.get_host_info() + else: + assert str(conn.port) in conn.get_host_info() assert isinstance(conn.get_server_info(), str) # protocol id is int assert isinstance(conn.get_proto_info(), int) @@ -200,8 +212,11 @@ async def test_connection_ping(connection_creator): @pytest.mark.run_loop async def test_connection_properties(connection_creator, mysql_params): conn = await connection_creator() - assert conn.host == mysql_params['host'] - assert conn.port == mysql_params['port'] + if "unix_socket" in mysql_params: + assert conn.unix_socket == mysql_params["unix_socket"] + else: + assert conn.host == mysql_params['host'] + assert conn.port == mysql_params['port'] assert conn.user == mysql_params['user'] assert conn.db == mysql_params['db'] assert conn.echo is False diff --git a/tests/test_issues.py b/tests/test_issues.py index 942bc8ed..c25e292f 100644 --- a/tests/test_issues.py +++ b/tests/test_issues.py @@ -184,7 +184,7 @@ async def test_issue_17(connection, connection_creator, mysql_params): async def test_issue_34(connection_creator): try: await connection_creator(host="localhost", port=1237, - user="root") + user="root", unix_socket=None) pytest.fail() except aiomysql.OperationalError as e: assert 2003 == e.args[0] diff --git a/tests/test_sha_connection.py b/tests/test_sha_connection.py index eb57ec3d..0789d162 100644 --- a/tests/test_sha_connection.py +++ b/tests/test_sha_connection.py @@ -39,6 +39,13 @@ async def test_sha256_nopw(mysql_server, loop): @pytest.mark.mysql_version('mysql', '8.0') @pytest.mark.run_loop async def test_sha256_pw(mysql_server, loop): + # https://dev.mysql.com/doc/refman/8.0/en/sha256-pluggable-authentication.html + # Unlike caching_sha2_password, the sha256_password plugin does not treat + # shared-memory connections as secure, even though share-memory transport + # is secure by default. + if "unix_socket" in mysql_server['conn_params']: + pytest.skip("sha256_password is not supported on unix sockets") + connection_data = copy.copy(mysql_server['conn_params']) connection_data['user'] = 'user_sha256' connection_data['password'] = 'pass_sha256' diff --git a/tests/test_ssl.py b/tests/test_ssl.py index ff1ea740..140c164f 100644 --- a/tests/test_ssl.py +++ b/tests/test_ssl.py @@ -4,7 +4,10 @@ @pytest.mark.run_loop -async def test_tls_connect(mysql_server, loop): +async def test_tls_connect(mysql_server, loop, mysql_params): + if "unix_socket" in mysql_params: + pytest.skip("TLS is not supported on unix sockets") + async with create_pool(**mysql_server['conn_params'], loop=loop) as pool: async with pool.get() as conn: @@ -32,7 +35,10 @@ async def test_tls_connect(mysql_server, loop): # MySQL will get you to renegotiate if sent a cleartext password @pytest.mark.run_loop -async def test_auth_plugin_renegotiation(mysql_server, loop): +async def test_auth_plugin_renegotiation(mysql_server, loop, mysql_params): + if "unix_socket" in mysql_params: + pytest.skip("TLS is not supported on unix sockets") + async with create_pool(**mysql_server['conn_params'], auth_plugin='mysql_clear_password', loop=loop) as pool: