@@ -27,6 +27,48 @@ def pytest_generate_tests(metafunc):
2727 loop_type = ['asyncio' , 'uvloop' ] if uvloop else ['asyncio' ]
2828 metafunc .parametrize ("loop_type" , loop_type )
2929
30+ if "mysql_address" in metafunc .fixturenames :
31+ mysql_addresses = []
32+ ids = []
33+
34+ opt_mysql_unix_socket = \
35+ list (metafunc .config .getoption ("mysql_unix_socket" ))
36+ for i in range (len (opt_mysql_unix_socket )):
37+ if "=" in opt_mysql_unix_socket [i ]:
38+ label , path = opt_mysql_unix_socket [i ].rsplit ("=" , 1 )
39+ mysql_addresses .append (path )
40+ ids .append (label )
41+ else :
42+ mysql_addresses .append (opt_mysql_unix_socket [i ])
43+ ids .append ("socket{}" .format (i ))
44+
45+ opt_mysql_address = list (metafunc .config .getoption ("mysql_address" ))
46+ for i in range (len (opt_mysql_address )):
47+ if "=" in opt_mysql_address [i ]:
48+ label , addr = opt_mysql_unix_socket [i ].rsplit ("=" , 1 )
49+ ids .append (label )
50+ else :
51+ addr = opt_mysql_address [i ]
52+ ids .append ("socket{}" .format (i ))
53+
54+ if ":" in addr :
55+ addr = addr .rsplit (":" , 1 )
56+ mysql_addresses .append ((addr [0 ], int (addr [1 ])))
57+ else :
58+ mysql_addresses .append ((addr , 3306 ))
59+ ids .append ("address{}" .format (i ))
60+
61+ # default to connecting to localhost
62+ if len (mysql_addresses ) == 0 :
63+ mysql_addresses = [("127.0.0.1" , 3306 )]
64+ ids = ["tcp-local" ]
65+
66+ metafunc .parametrize ("mysql_address" ,
67+ mysql_addresses ,
68+ ids = ids ,
69+ scope = "session" ,
70+ )
71+
3072
3173# This is here unless someone fixes the generate_tests bit
3274@pytest .fixture (scope = 'session' )
@@ -101,6 +143,21 @@ def pytest_configure(config):
101143 )
102144
103145
146+ def pytest_addoption (parser ):
147+ parser .addoption (
148+ "--mysql-address" ,
149+ action = "append" ,
150+ default = [],
151+ help = "list of addresses to connect to" ,
152+ )
153+ parser .addoption (
154+ "--mysql-unix-socket" ,
155+ action = "append" ,
156+ default = [],
157+ help = "list of unix sockets to connect to" ,
158+ )
159+
160+
104161@pytest .fixture
105162def mysql_params (mysql_server ):
106163 params = {** mysql_server ['conn_params' ],
@@ -209,24 +266,31 @@ def ensure_mysql_version(request, mysql_image, mysql_tag):
209266
210267
211268@pytest .fixture (scope = 'session' )
212- def mysql_server (mysql_image , mysql_tag ):
213- ssl_directory = os .path .join (os .path .dirname (__file__ ),
214- 'ssl_resources' , 'ssl' )
215- ca_file = os .path .join (ssl_directory , 'ca.pem' )
269+ def mysql_server (mysql_image , mysql_tag , mysql_address ):
270+ unix_socket = type (mysql_address ) is str
271+
272+ if not unix_socket :
273+ ssl_directory = os .path .join (os .path .dirname (__file__ ),
274+ 'ssl_resources' , 'ssl' )
275+ ca_file = os .path .join (ssl_directory , 'ca.pem' )
216276
217- ctx = ssl .SSLContext (ssl .PROTOCOL_TLSv1_2 )
218- ctx .check_hostname = False
219- ctx .load_verify_locations (cafile = ca_file )
220- # ctx.verify_mode = ssl.CERT_NONE
277+ ctx = ssl .SSLContext (ssl .PROTOCOL_TLSv1_2 )
278+ ctx .check_hostname = False
279+ ctx .load_verify_locations (cafile = ca_file )
280+ # ctx.verify_mode = ssl.CERT_NONE
221281
222282 server_params = {
223- 'host' : '127.0.0.1' ,
224- 'port' : 3306 ,
225283 'user' : 'root' ,
226284 'password' : os .environ .get ("MYSQL_ROOT_PASSWORD" ),
227- 'ssl' : ctx ,
228285 }
229286
287+ if unix_socket :
288+ server_params ["unix_socket" ] = mysql_address
289+ else :
290+ server_params ["host" ] = mysql_address [0 ]
291+ server_params ["port" ] = mysql_address [1 ]
292+ server_params ["ssl" ] = ctx
293+
230294 try :
231295 connection = pymysql .connect (
232296 db = 'mysql' ,
@@ -235,21 +299,22 @@ def mysql_server(mysql_image, mysql_tag):
235299 ** server_params )
236300
237301 with connection .cursor () as cursor :
238- cursor .execute ("SHOW VARIABLES LIKE '%ssl%';" )
302+ if not unix_socket :
303+ cursor .execute ("SHOW VARIABLES LIKE '%ssl%';" )
239304
240- result = cursor .fetchall ()
241- result = {item ['Variable_name' ]:
242- item ['Value' ] for item in result }
305+ result = cursor .fetchall ()
306+ result = {item ['Variable_name' ]:
307+ item ['Value' ] for item in result }
243308
244- assert result ['have_ssl' ] == "YES" , \
245- "SSL Not Enabled on MySQL"
309+ assert result ['have_ssl' ] == "YES" , \
310+ "SSL Not Enabled on MySQL"
246311
247- cursor .execute ("SHOW STATUS LIKE 'Ssl_version%'" )
312+ cursor .execute ("SHOW STATUS LIKE 'Ssl_version%'" )
248313
249- result = cursor .fetchone ()
250- # As we connected with TLS, it should start with that :D
251- assert result ['Value' ].startswith ('TLS' ), \
252- "Not connected to the database with TLS"
314+ result = cursor .fetchone ()
315+ # As we connected with TLS, it should start with that :D
316+ assert result ['Value' ].startswith ('TLS' ), \
317+ "Not connected to the database with TLS"
253318
254319 # Drop possibly existing old databases
255320 cursor .execute ('DROP DATABASE IF EXISTS test_pymysql;' )
0 commit comments