Skip to content

Commit 3ee19ba

Browse files
d1mansonelprans
andauthored
Add connect_fn kwarg to Pool to better support GCP's CloudSQL (#1170)
Co-authored-by: Elvis Pranskevichus <[email protected]>
1 parent 73f2209 commit 3ee19ba

File tree

3 files changed

+71
-12
lines changed

3 files changed

+71
-12
lines changed

asyncpg/_testbase/__init__.py

+10-3
Original file line numberDiff line numberDiff line change
@@ -274,6 +274,7 @@ def create_pool(dsn=None, *,
274274
max_size=10,
275275
max_queries=50000,
276276
max_inactive_connection_lifetime=60.0,
277+
connect=None,
277278
setup=None,
278279
init=None,
279280
loop=None,
@@ -283,12 +284,18 @@ def create_pool(dsn=None, *,
283284
**connect_kwargs):
284285
return pool_class(
285286
dsn,
286-
min_size=min_size, max_size=max_size,
287-
max_queries=max_queries, loop=loop, setup=setup, init=init,
287+
min_size=min_size,
288+
max_size=max_size,
289+
max_queries=max_queries,
290+
loop=loop,
291+
connect=connect,
292+
setup=setup,
293+
init=init,
288294
max_inactive_connection_lifetime=max_inactive_connection_lifetime,
289295
connection_class=connection_class,
290296
record_class=record_class,
291-
**connect_kwargs)
297+
**connect_kwargs,
298+
)
292299

293300

294301
class ClusterTestCase(TestCase):

asyncpg/pool.py

+41-8
Original file line numberDiff line numberDiff line change
@@ -313,7 +313,7 @@ class Pool:
313313

314314
__slots__ = (
315315
'_queue', '_loop', '_minsize', '_maxsize',
316-
'_init', '_connect_args', '_connect_kwargs',
316+
'_init', '_connect', '_connect_args', '_connect_kwargs',
317317
'_holders', '_initialized', '_initializing', '_closing',
318318
'_closed', '_connection_class', '_record_class', '_generation',
319319
'_setup', '_max_queries', '_max_inactive_connection_lifetime'
@@ -324,8 +324,9 @@ def __init__(self, *connect_args,
324324
max_size,
325325
max_queries,
326326
max_inactive_connection_lifetime,
327-
setup,
328-
init,
327+
connect=None,
328+
setup=None,
329+
init=None,
329330
loop,
330331
connection_class,
331332
record_class,
@@ -385,11 +386,14 @@ def __init__(self, *connect_args,
385386
self._closing = False
386387
self._closed = False
387388
self._generation = 0
388-
self._init = init
389+
390+
self._connect = connect if connect is not None else connection.connect
389391
self._connect_args = connect_args
390392
self._connect_kwargs = connect_kwargs
391393

392394
self._setup = setup
395+
self._init = init
396+
393397
self._max_queries = max_queries
394398
self._max_inactive_connection_lifetime = \
395399
max_inactive_connection_lifetime
@@ -503,13 +507,25 @@ def set_connect_args(self, dsn=None, **connect_kwargs):
503507
self._connect_kwargs = connect_kwargs
504508

505509
async def _get_new_connection(self):
506-
con = await connection.connect(
510+
con = await self._connect(
507511
*self._connect_args,
508512
loop=self._loop,
509513
connection_class=self._connection_class,
510514
record_class=self._record_class,
511515
**self._connect_kwargs,
512516
)
517+
if not isinstance(con, self._connection_class):
518+
good = self._connection_class
519+
good_n = f'{good.__module__}.{good.__name__}'
520+
bad = type(con)
521+
if bad.__module__ == "builtins":
522+
bad_n = bad.__name__
523+
else:
524+
bad_n = f'{bad.__module__}.{bad.__name__}'
525+
raise exceptions.InterfaceError(
526+
"expected pool connect callback to return an instance of "
527+
f"'{good_n}', got " f"'{bad_n}'"
528+
)
513529

514530
if self._init is not None:
515531
try:
@@ -1017,6 +1033,7 @@ def create_pool(dsn=None, *,
10171033
max_size=10,
10181034
max_queries=50000,
10191035
max_inactive_connection_lifetime=300.0,
1036+
connect=None,
10201037
setup=None,
10211038
init=None,
10221039
loop=None,
@@ -1099,6 +1116,13 @@ def create_pool(dsn=None, *,
10991116
Number of seconds after which inactive connections in the
11001117
pool will be closed. Pass ``0`` to disable this mechanism.
11011118
1119+
:param coroutine connect:
1120+
A coroutine that is called instead of
1121+
:func:`~asyncpg.connection.connect` whenever the pool needs to make a
1122+
new connection. Must return an instance of type specified by
1123+
*connection_class* or :class:`~asyncpg.connection.Connection` if
1124+
*connection_class* was not specified.
1125+
11021126
:param coroutine setup:
11031127
A coroutine to prepare a connection right before it is returned
11041128
from :meth:`Pool.acquire() <pool.Pool.acquire>`. An example use
@@ -1139,12 +1163,21 @@ def create_pool(dsn=None, *,
11391163
11401164
.. versionchanged:: 0.22.0
11411165
Added the *record_class* parameter.
1166+
1167+
.. versionchanged:: 0.30.0
1168+
Added the *connect* parameter.
11421169
"""
11431170
return Pool(
11441171
dsn,
11451172
connection_class=connection_class,
11461173
record_class=record_class,
1147-
min_size=min_size, max_size=max_size,
1148-
max_queries=max_queries, loop=loop, setup=setup, init=init,
1174+
min_size=min_size,
1175+
max_size=max_size,
1176+
max_queries=max_queries,
1177+
loop=loop,
1178+
connect=connect,
1179+
setup=setup,
1180+
init=init,
11491181
max_inactive_connection_lifetime=max_inactive_connection_lifetime,
1150-
**connect_kwargs)
1182+
**connect_kwargs,
1183+
)

tests/test_pool.py

+20-1
Original file line numberDiff line numberDiff line change
@@ -136,6 +136,12 @@ async def setup(con):
136136

137137
async def test_pool_07(self):
138138
cons = set()
139+
connect_called = 0
140+
141+
async def connect(*args, **kwargs):
142+
nonlocal connect_called
143+
connect_called += 1
144+
return await pg_connection.connect(*args, **kwargs)
139145

140146
async def setup(con):
141147
if con._con not in cons: # `con` is `PoolConnectionProxy`.
@@ -152,13 +158,26 @@ async def user(pool):
152158
raise RuntimeError('init was not called')
153159

154160
async with self.create_pool(database='postgres',
155-
min_size=2, max_size=5,
161+
min_size=2,
162+
max_size=5,
163+
connect=connect,
156164
init=init,
157165
setup=setup) as pool:
158166
users = asyncio.gather(*[user(pool) for _ in range(10)])
159167
await users
160168

161169
self.assertEqual(len(cons), 5)
170+
self.assertEqual(connect_called, 5)
171+
172+
async def bad_connect(*args, **kwargs):
173+
return 1
174+
175+
with self.assertRaisesRegex(
176+
asyncpg.InterfaceError,
177+
"expected pool connect callback to return an instance of "
178+
"'asyncpg\\.connection\\.Connection', got 'int'"
179+
):
180+
await self.create_pool(database='postgres', connect=bad_connect)
162181

163182
async def test_pool_08(self):
164183
pool = await self.create_pool(database='postgres',

0 commit comments

Comments
 (0)