Skip to content

Commit a4b2d16

Browse files
committed
Extract client.check_connection()
1 parent 7e4853f commit a4b2d16

7 files changed

+72
-42
lines changed

gel/__init__.py

+2
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
from .abstract import (
3131
Executor, AsyncIOExecutor, ReadOnlyExecutor, AsyncIOReadOnlyExecutor,
3232
)
33+
from .base_client import ConnectionInfo
3334

3435
from .asyncio_client import (
3536
create_async_client,
@@ -52,6 +53,7 @@
5253
"Cardinality",
5354
"Client",
5455
"ConfigMemory",
56+
"ConnectionInfo",
5557
"DateDuration",
5658
"EdgeDBError",
5759
"EdgeDBMessage",

gel/ai/core.py

+17-16
Original file line numberDiff line numberDiff line change
@@ -27,15 +27,15 @@
2727

2828

2929
def create_rag_client(client: gel.Client, **kwargs) -> RAGClient:
30-
client.ensure_connected()
31-
return RAGClient(client, types.RAGOptions(**kwargs))
30+
info = client.check_connection()
31+
return RAGClient(info, types.RAGOptions(**kwargs))
3232

3333

3434
async def create_async_rag_client(
3535
client: gel.AsyncIOClient, **kwargs
3636
) -> AsyncRAGClient:
37-
await client.ensure_connected()
38-
return AsyncRAGClient(client, types.RAGOptions(**kwargs))
37+
info = await client.check_connection()
38+
return AsyncRAGClient(info, types.RAGOptions(**kwargs))
3939

4040

4141
class BaseRAGClient:
@@ -45,25 +45,26 @@ class BaseRAGClient:
4545

4646
def __init__(
4747
self,
48-
client: typing.Union[gel.Client, gel.AsyncIOClient],
48+
info: gel.ConnectionInfo,
4949
options: types.RAGOptions,
5050
**kwargs,
5151
):
52-
pool = client._impl
53-
host, port = pool._working_addr
54-
params = pool._working_params
55-
proto = "http" if params.tls_security == "insecure" else "https"
56-
branch = params.branch
52+
proto = "http" if info.params.tls_security == "insecure" else "https"
53+
branch = info.params.branch
5754
self.options = options
5855
self.context = types.QueryContext(**kwargs)
5956
args = dict(
60-
base_url=f"{proto}://{host}:{port}/branch/{branch}/ext/ai",
61-
verify=params.ssl_ctx,
57+
base_url=(
58+
f"{proto}://{info.host}:{info.port}/branch/{branch}/ext/ai"
59+
),
60+
verify=info.params.ssl_ctx,
6261
)
63-
if params.password is not None:
64-
args["auth"] = (params.user, params.password)
65-
elif params.secret_key is not None:
66-
args["headers"] = {"Authorization": f"Bearer {params.secret_key}"}
62+
if info.params.password is not None:
63+
args["auth"] = (info.params.user, info.params.password)
64+
elif info.params.secret_key is not None:
65+
args["headers"] = {
66+
"Authorization": f"Bearer {info.params.secret_key}"
67+
}
6768
self._init_client(**args)
6869

6970
def _init_client(self, **kwargs):

gel/asyncio_client.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -361,8 +361,11 @@ class AsyncIOClient(base_client.BaseClient, abstract.AsyncIOExecutor):
361361
__slots__ = ()
362362
_impl_class = _AsyncIOPoolImpl
363363

364+
async def check_connection(self) -> base_client.ConnectionInfo:
365+
return await self._impl.ensure_connected()
366+
364367
async def ensure_connected(self):
365-
await self._impl.ensure_connected()
368+
await self.check_connection()
366369
return self
367370

368371
async def aclose(self):

gel/auth/email_password.py

+4-7
Original file line numberDiff line numberDiff line change
@@ -170,13 +170,10 @@ async def make(
170170
verify_url: str,
171171
reset_url: str,
172172
) -> EmailPassword:
173-
await client.ensure_connected()
174-
pool = client._impl
175-
host, port = pool._working_addr
176-
params = pool._working_params
177-
proto = "http" if params.tls_security == "insecure" else "https"
178-
branch = params.branch
179-
auth_ext_url = f"{proto}://{host}:{port}/branch/{branch}/ext/auth/"
173+
info = await client.check_connection()
174+
proto = "http" if info.params.tls_security == "insecure" else "https"
175+
branch = info.params.branch
176+
auth_ext_url = f"{proto}://{info.host}:{info.port}/branch/{branch}/ext/auth/"
180177
return EmailPassword(
181178
auth_ext_url=auth_ext_url, verify_url=verify_url, reset_url=reset_url
182179
)

gel/base_client.py

+20-1
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818

1919

2020
import abc
21+
import dataclasses
2122
import random
2223
import time
2324
import typing
@@ -411,6 +412,14 @@ def _release(self):
411412
self._pool._queue.put_nowait(self)
412413

413414

415+
@dataclasses.dataclass
416+
class ConnectionInfo:
417+
host: str
418+
port: int
419+
params: con_utils.ResolvedConnectConfig
420+
config: con_utils.ClientConfiguration
421+
422+
414423
class BasePoolImpl(abc.ABC):
415424
__slots__ = (
416425
"_connect_args",
@@ -624,7 +633,7 @@ def expire_connections(self):
624633
"""
625634
self._generation += 1
626635

627-
async def ensure_connected(self):
636+
async def ensure_connected(self) -> ConnectionInfo:
628637
self._ensure_initialized()
629638

630639
for ch in self._holders:
@@ -635,6 +644,16 @@ async def ensure_connected(self):
635644
ch._con = None
636645
await ch.connect()
637646

647+
assert self._working_addr is not None
648+
assert self._working_config is not None
649+
assert self._working_params is not None
650+
return ConnectionInfo(
651+
host=self._working_addr[0],
652+
port=self._working_addr[1],
653+
config=self._working_config,
654+
params=self._working_params,
655+
)
656+
638657

639658
class BaseClient(abstract.BaseReadOnlyExecutor, _options._OptionsMixin):
640659
__slots__ = ("_impl", "_options")

gel/blocking_client.py

+23-16
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,18 @@
3939
MINIMUM_PING_WAIT_TIME = datetime.timedelta(seconds=1)
4040

4141

42+
T = typing.TypeVar("T")
43+
44+
45+
def iter_coroutine(coro: typing.Coroutine[None, None, T]) -> T:
46+
try:
47+
coro.send(None)
48+
except StopIteration as ex:
49+
return ex.value
50+
finally:
51+
coro.close()
52+
53+
4254
class BlockingIOConnection(base_client.BaseConnection):
4355
__slots__ = ("_ping_wait_time",)
4456

@@ -328,7 +340,7 @@ def __enter__(self):
328340
def __exit__(self, extype, ex, tb):
329341
with self._exclusive():
330342
self._managed = False
331-
return self._client._iter_coroutine(self._exit(extype, ex))
343+
return iter_coroutine(self._exit(extype, ex))
332344

333345
async def _ensure_transaction(self):
334346
if not self._managed:
@@ -340,11 +352,11 @@ async def _ensure_transaction(self):
340352

341353
def _query(self, query_context: abstract.QueryContext):
342354
with self._exclusive():
343-
return self._client._iter_coroutine(super()._query(query_context))
355+
return iter_coroutine(super()._query(query_context))
344356

345357
def _execute(self, execute_context: abstract.ExecuteContext) -> None:
346358
with self._exclusive():
347-
self._client._iter_coroutine(super()._execute(execute_context))
359+
iter_coroutine(super()._execute(execute_context))
348360

349361
@contextlib.contextmanager
350362
def _exclusive(self):
@@ -392,22 +404,17 @@ class Client(base_client.BaseClient, abstract.Executor):
392404
__slots__ = ()
393405
_impl_class = _PoolImpl
394406

395-
def _iter_coroutine(self, coro):
396-
try:
397-
coro.send(None)
398-
except StopIteration as ex:
399-
return ex.value
400-
finally:
401-
coro.close()
402-
403407
def _query(self, query_context: abstract.QueryContext):
404-
return self._iter_coroutine(super()._query(query_context))
408+
return iter_coroutine(super()._query(query_context))
405409

406410
def _execute(self, execute_context: abstract.ExecuteContext) -> None:
407-
self._iter_coroutine(super()._execute(execute_context))
411+
iter_coroutine(super()._execute(execute_context))
412+
413+
def check_connection(self) -> base_client.ConnectionInfo:
414+
return iter_coroutine(self._impl.ensure_connected())
408415

409416
def ensure_connected(self):
410-
self._iter_coroutine(self._impl.ensure_connected())
417+
self.check_connection()
411418
return self
412419

413420
def transaction(self) -> Retry:
@@ -421,7 +428,7 @@ def close(self, timeout=None):
421428
in ``close()`` the pool will terminate by calling
422429
Client.terminate() .
423430
"""
424-
self._iter_coroutine(self._impl.close(timeout))
431+
iter_coroutine(self._impl.close(timeout))
425432

426433
def __enter__(self):
427434
return self.ensure_connected()
@@ -438,7 +445,7 @@ def _describe_query(
438445
output_format: OutputFormat = OutputFormat.BINARY,
439446
expect_one: bool = False,
440447
) -> abstract.DescribeResult:
441-
return self._iter_coroutine(self._describe(abstract.DescribeContext(
448+
return iter_coroutine(self._describe(abstract.DescribeContext(
442449
query=query,
443450
state=self._get_state(),
444451
inject_type_names=inject_type_names,

tests/test_sync_query.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828

2929
from edgedb import abstract
3030
from gel import _testbase as tb
31+
from gel import blocking_client
3132
from edgedb.protocol import protocol
3233

3334

@@ -778,7 +779,7 @@ def test_json(self):
778779

779780
def test_json_elements(self):
780781
self.client.ensure_connected()
781-
result = self.client._iter_coroutine(
782+
result = blocking_client.iter_coroutine(
782783
self.client.connection.raw_query(
783784
abstract.QueryContext(
784785
query=abstract.QueryWithArgs(

0 commit comments

Comments
 (0)