Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add auth extension handling #586

Open
wants to merge 13 commits into
base: master
Choose a base branch
from
2 changes: 2 additions & 0 deletions gel/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
from .abstract import (
Executor, AsyncIOExecutor, ReadOnlyExecutor, AsyncIOReadOnlyExecutor,
)
from .base_client import ConnectionInfo

from .asyncio_client import (
create_async_client,
Expand All @@ -52,6 +53,7 @@
"Cardinality",
"Client",
"ConfigMemory",
"ConnectionInfo",
"DateDuration",
"EdgeDBError",
"EdgeDBMessage",
Expand Down
33 changes: 17 additions & 16 deletions gel/ai/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,15 +27,15 @@


def create_rag_client(client: gel.Client, **kwargs) -> RAGClient:
client.ensure_connected()
return RAGClient(client, types.RAGOptions(**kwargs))
info = client.check_connection()
return RAGClient(info, types.RAGOptions(**kwargs))


async def create_async_rag_client(
client: gel.AsyncIOClient, **kwargs
) -> AsyncRAGClient:
await client.ensure_connected()
return AsyncRAGClient(client, types.RAGOptions(**kwargs))
info = await client.check_connection()
return AsyncRAGClient(info, types.RAGOptions(**kwargs))


class BaseRAGClient:
Expand All @@ -45,25 +45,26 @@ class BaseRAGClient:

def __init__(
self,
client: typing.Union[gel.Client, gel.AsyncIOClient],
info: gel.ConnectionInfo,
options: types.RAGOptions,
**kwargs,
):
pool = client._impl
host, port = pool._working_addr
params = pool._working_params
proto = "http" if params.tls_security == "insecure" else "https"
branch = params.branch
proto = "http" if info.params.tls_security == "insecure" else "https"
branch = info.params.branch
self.options = options
self.context = types.QueryContext(**kwargs)
args = dict(
base_url=f"{proto}://{host}:{port}/branch/{branch}/ext/ai",
verify=params.ssl_ctx,
base_url=(
f"{proto}://{info.host}:{info.port}/branch/{branch}/ext/ai"
),
verify=info.params.ssl_ctx,
)
if params.password is not None:
args["auth"] = (params.user, params.password)
elif params.secret_key is not None:
args["headers"] = {"Authorization": f"Bearer {params.secret_key}"}
if info.params.password is not None:
args["auth"] = (info.params.user, info.params.password)
elif info.params.secret_key is not None:
args["headers"] = {
"Authorization": f"Bearer {info.params.secret_key}"
}
self._init_client(**args)

def _init_client(self, **kwargs):
Expand Down
5 changes: 4 additions & 1 deletion gel/asyncio_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -361,8 +361,11 @@ class AsyncIOClient(base_client.BaseClient, abstract.AsyncIOExecutor):
__slots__ = ()
_impl_class = _AsyncIOPoolImpl

async def check_connection(self) -> base_client.ConnectionInfo:
return await self._impl.ensure_connected()

async def ensure_connected(self):
await self._impl.ensure_connected()
await self.check_connection()
return self

async def aclose(self):
Expand Down
30 changes: 30 additions & 0 deletions gel/auth/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
#
# This source file is part of the EdgeDB open source project.
#
# Copyright 2025-present MagicStack Inc. and the EdgeDB authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

from . import email_password
from .token_data import TokenData
from .pkce import PKCE, generate_pkce, AsyncPKCE, generate_async_pkce

__all__ = [
"email_password",
"TokenData",
"PKCE",
"generate_pkce",
"AsyncPKCE",
"generate_async_pkce",
]
Loading
Loading