@@ -27,6 +27,48 @@ def pytest_generate_tests(metafunc):
2727 loop_type = ['asyncio' , 'uvloop' ] if uvloop else ['asyncio' ]
2828 metafunc .parametrize ("loop_type" , loop_type )
2929
30+ if "mysql_address" in metafunc .fixturenames :
31+ mysql_addresses = []
32+ ids = []
33+
34+ opt_mysql_unix_socket = \
35+ list (metafunc .config .getoption ("mysql_unix_socket" ))
36+ for i in range (len (opt_mysql_unix_socket )):
37+ if "=" in opt_mysql_unix_socket [i ]:
38+ label , path = opt_mysql_unix_socket [i ].rsplit ("=" , 1 )
39+ mysql_addresses .append (path )
40+ ids .append (label )
41+ else :
42+ mysql_addresses .append (opt_mysql_unix_socket [i ])
43+ ids .append ("socket{}" .format (i ))
44+
45+ opt_mysql_address = list (metafunc .config .getoption ("mysql_address" ))
46+ for i in range (len (opt_mysql_address )):
47+ if "=" in opt_mysql_address [i ]:
48+ label , addr = opt_mysql_unix_socket [i ].rsplit ("=" , 1 )
49+ ids .append (label )
50+ else :
51+ addr = opt_mysql_address [i ]
52+ ids .append ("socket{}" .format (i ))
53+
54+ if ":" in addr :
55+ addr = addr .rsplit (":" , 1 )
56+ mysql_addresses .append ((addr [0 ], int (addr [1 ])))
57+ else :
58+ mysql_addresses .append ((addr , 3306 ))
59+ ids .append ("address{}" .format (i ))
60+
61+ # default to connecting to localhost
62+ if len (mysql_addresses ) == 0 :
63+ mysql_addresses = [("127.0.0.1" , 3306 )]
64+ ids = ["tcp-local" ]
65+
66+ metafunc .parametrize ("mysql_address" ,
67+ mysql_addresses ,
68+ ids = ids ,
69+ scope = "session" ,
70+ )
71+
3072
3173# This is here unless someone fixes the generate_tests bit
3274@pytest .fixture (scope = 'session' )
@@ -101,6 +143,21 @@ def pytest_configure(config):
101143 )
102144
103145
146+ def pytest_addoption (parser ):
147+ parser .addoption (
148+ "--mysql-address" ,
149+ action = "append" ,
150+ default = [],
151+ help = "list of addresses to connect to" ,
152+ )
153+ parser .addoption (
154+ "--mysql-unix-socket" ,
155+ action = "append" ,
156+ default = [],
157+ help = "list of unix sockets to connect to" ,
158+ )
159+
160+
104161@pytest .fixture
105162def mysql_params (mysql_server ):
106163 params = {** mysql_server ['conn_params' ],
@@ -205,24 +262,31 @@ def ensure_mysql_version(request, mysql_image, mysql_tag):
205262
206263
207264@pytest .fixture (scope = 'session' )
208- def mysql_server (mysql_image , mysql_tag ):
209- ssl_directory = os .path .join (os .path .dirname (__file__ ),
210- 'ssl_resources' , 'ssl' )
211- ca_file = os .path .join (ssl_directory , 'ca.pem' )
265+ def mysql_server (mysql_image , mysql_tag , mysql_address ):
266+ unix_socket = type (mysql_address ) is str
267+
268+ if not unix_socket :
269+ ssl_directory = os .path .join (os .path .dirname (__file__ ),
270+ 'ssl_resources' , 'ssl' )
271+ ca_file = os .path .join (ssl_directory , 'ca.pem' )
212272
213- ctx = ssl .SSLContext (ssl .PROTOCOL_TLSv1_2 )
214- ctx .check_hostname = False
215- ctx .load_verify_locations (cafile = ca_file )
216- # ctx.verify_mode = ssl.CERT_NONE
273+ ctx = ssl .SSLContext (ssl .PROTOCOL_TLSv1_2 )
274+ ctx .check_hostname = False
275+ ctx .load_verify_locations (cafile = ca_file )
276+ # ctx.verify_mode = ssl.CERT_NONE
217277
218278 server_params = {
219- 'host' : '127.0.0.1' ,
220- 'port' : 3306 ,
221279 'user' : 'root' ,
222280 'password' : os .environ .get ("MYSQL_ROOT_PASSWORD" ),
223- 'ssl' : ctx ,
224281 }
225282
283+ if unix_socket :
284+ server_params ["unix_socket" ] = mysql_address
285+ else :
286+ server_params ["host" ] = mysql_address [0 ]
287+ server_params ["port" ] = mysql_address [1 ]
288+ server_params ["ssl" ] = ctx
289+
226290 try :
227291 connection = pymysql .connect (
228292 db = 'mysql' ,
@@ -231,21 +295,22 @@ def mysql_server(mysql_image, mysql_tag):
231295 ** server_params )
232296
233297 with connection .cursor () as cursor :
234- cursor .execute ("SHOW VARIABLES LIKE '%ssl%';" )
298+ if not unix_socket :
299+ cursor .execute ("SHOW VARIABLES LIKE '%ssl%';" )
235300
236- result = cursor .fetchall ()
237- result = {item ['Variable_name' ]:
238- item ['Value' ] for item in result }
301+ result = cursor .fetchall ()
302+ result = {item ['Variable_name' ]:
303+ item ['Value' ] for item in result }
239304
240- assert result ['have_ssl' ] == "YES" , \
241- "SSL Not Enabled on MySQL"
305+ assert result ['have_ssl' ] == "YES" , \
306+ "SSL Not Enabled on MySQL"
242307
243- cursor .execute ("SHOW STATUS LIKE 'Ssl_version%'" )
308+ cursor .execute ("SHOW STATUS LIKE 'Ssl_version%'" )
244309
245- result = cursor .fetchone ()
246- # As we connected with TLS, it should start with that :D
247- assert result ['Value' ].startswith ('TLS' ), \
248- "Not connected to the database with TLS"
310+ result = cursor .fetchone ()
311+ # As we connected with TLS, it should start with that :D
312+ assert result ['Value' ].startswith ('TLS' ), \
313+ "Not connected to the database with TLS"
249314
250315 # Drop possibly existing old databases
251316 cursor .execute ('DROP DATABASE IF EXISTS test_pymysql;' )
0 commit comments