@@ -27,6 +27,51 @@ def pytest_generate_tests(metafunc):
2727 loop_type = ['asyncio' , 'uvloop' ] if uvloop else ['asyncio' ]
2828 metafunc .parametrize ("loop_type" , loop_type )
2929
30+ if "mysql_address" in metafunc .fixturenames :
31+ mysql_addresses = []
32+ ids = []
33+
34+ opt_mysql_unix_socket = \
35+ list (metafunc .config .getoption ("mysql_unix_socket" ))
36+ for i in range (len (opt_mysql_unix_socket )):
37+ if "=" in opt_mysql_unix_socket [i ]:
38+ label , path = opt_mysql_unix_socket [i ].rsplit ("=" , 1 )
39+ mysql_addresses .append (path )
40+ ids .append (label )
41+ else :
42+ mysql_addresses .append (opt_mysql_unix_socket [i ])
43+ ids .append ("unix{}" .format (i ))
44+
45+ opt_mysql_address = list (metafunc .config .getoption ("mysql_address" ))
46+ for i in range (len (opt_mysql_address )):
47+ if "=" in opt_mysql_address [i ]:
48+ label , addr = opt_mysql_address [i ].rsplit ("=" , 1 )
49+ ids .append (label )
50+ else :
51+ addr = opt_mysql_address [i ]
52+ ids .append ("tcp{}" .format (i ))
53+
54+ if ":" in addr :
55+ addr = addr .rsplit (":" , 1 )
56+ mysql_addresses .append ((addr [0 ], int (addr [1 ])))
57+ else :
58+ mysql_addresses .append ((addr , 3306 ))
59+
60+ # default to connecting to localhost
61+ if len (mysql_addresses ) == 0 :
62+ mysql_addresses = [("127.0.0.1" , 3306 )]
63+ ids = ["tcp-local" ]
64+
65+ assert len (mysql_addresses ) == len (set (mysql_addresses ))
66+ assert len (ids ) == len (set (ids ))
67+ assert len (mysql_addresses ) == len (ids )
68+
69+ metafunc .parametrize ("mysql_address" ,
70+ mysql_addresses ,
71+ ids = ids ,
72+ scope = "session" ,
73+ )
74+
3075
3176# This is here unless someone fixes the generate_tests bit
3277@pytest .fixture (scope = 'session' )
@@ -101,6 +146,21 @@ def pytest_configure(config):
101146 )
102147
103148
149+ def pytest_addoption (parser ):
150+ parser .addoption (
151+ "--mysql-address" ,
152+ action = "append" ,
153+ default = [],
154+ help = "list of addresses to connect to" ,
155+ )
156+ parser .addoption (
157+ "--mysql-unix-socket" ,
158+ action = "append" ,
159+ default = [],
160+ help = "list of unix sockets to connect to" ,
161+ )
162+
163+
104164@pytest .fixture
105165def mysql_params (mysql_server ):
106166 params = {** mysql_server ['conn_params' ],
@@ -205,24 +265,31 @@ def ensure_mysql_version(request, mysql_image, mysql_tag):
205265
206266
207267@pytest .fixture (scope = 'session' )
208- def mysql_server (mysql_image , mysql_tag ):
209- ssl_directory = os .path .join (os .path .dirname (__file__ ),
210- 'ssl_resources' , 'ssl' )
211- ca_file = os .path .join (ssl_directory , 'ca.pem' )
268+ def mysql_server (mysql_image , mysql_tag , mysql_address ):
269+ unix_socket = type (mysql_address ) is str
270+
271+ if not unix_socket :
272+ ssl_directory = os .path .join (os .path .dirname (__file__ ),
273+ 'ssl_resources' , 'ssl' )
274+ ca_file = os .path .join (ssl_directory , 'ca.pem' )
212275
213- ctx = ssl .SSLContext (ssl .PROTOCOL_TLSv1_2 )
214- ctx .check_hostname = False
215- ctx .load_verify_locations (cafile = ca_file )
216- # ctx.verify_mode = ssl.CERT_NONE
276+ ctx = ssl .SSLContext (ssl .PROTOCOL_TLSv1_2 )
277+ ctx .check_hostname = False
278+ ctx .load_verify_locations (cafile = ca_file )
279+ # ctx.verify_mode = ssl.CERT_NONE
217280
218281 server_params = {
219- 'host' : '127.0.0.1' ,
220- 'port' : 3306 ,
221282 'user' : 'root' ,
222283 'password' : os .environ .get ("MYSQL_ROOT_PASSWORD" ),
223- 'ssl' : ctx ,
224284 }
225285
286+ if unix_socket :
287+ server_params ["unix_socket" ] = mysql_address
288+ else :
289+ server_params ["host" ] = mysql_address [0 ]
290+ server_params ["port" ] = mysql_address [1 ]
291+ server_params ["ssl" ] = ctx
292+
226293 try :
227294 connection = pymysql .connect (
228295 db = 'mysql' ,
@@ -231,21 +298,22 @@ def mysql_server(mysql_image, mysql_tag):
231298 ** server_params )
232299
233300 with connection .cursor () as cursor :
234- cursor .execute ("SHOW VARIABLES LIKE '%ssl%';" )
301+ if not unix_socket :
302+ cursor .execute ("SHOW VARIABLES LIKE '%ssl%';" )
235303
236- result = cursor .fetchall ()
237- result = {item ['Variable_name' ]:
238- item ['Value' ] for item in result }
304+ result = cursor .fetchall ()
305+ result = {item ['Variable_name' ]:
306+ item ['Value' ] for item in result }
239307
240- assert result ['have_ssl' ] == "YES" , \
241- "SSL Not Enabled on MySQL"
308+ assert result ['have_ssl' ] == "YES" , \
309+ "SSL Not Enabled on MySQL"
242310
243- cursor .execute ("SHOW STATUS LIKE 'Ssl_version%'" )
311+ cursor .execute ("SHOW STATUS LIKE 'Ssl_version%'" )
244312
245- result = cursor .fetchone ()
246- # As we connected with TLS, it should start with that :D
247- assert result ['Value' ].startswith ('TLS' ), \
248- "Not connected to the database with TLS"
313+ result = cursor .fetchone ()
314+ # As we connected with TLS, it should start with that :D
315+ assert result ['Value' ].startswith ('TLS' ), \
316+ "Not connected to the database with TLS"
249317
250318 # Drop possibly existing old databases
251319 cursor .execute ('DROP DATABASE IF EXISTS test_pymysql;' )
0 commit comments