diff --git a/edgedb/ai/__init__.py b/edgedb/ai/__init__.py new file mode 100644 index 00000000..96111c2b --- /dev/null +++ b/edgedb/ai/__init__.py @@ -0,0 +1,32 @@ +# +# This source file is part of the EdgeDB open source project. +# +# Copyright 2024-present MagicStack Inc. and the EdgeDB authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from .types import AIOptions, ChatParticipantRole, Prompt, QueryContext +from .core import create_ai, EdgeDBAI +from .core import create_async_ai, AsyncEdgeDBAI + +__all__ = [ + "AIOptions", + "ChatParticipantRole", + "Prompt", + "QueryContext", + "create_ai", + "EdgeDBAI", + "create_async_ai", + "AsyncEdgeDBAI", +] diff --git a/edgedb/ai/core.py b/edgedb/ai/core.py new file mode 100644 index 00000000..e7fd0700 --- /dev/null +++ b/edgedb/ai/core.py @@ -0,0 +1,174 @@ +# +# This source file is part of the EdgeDB open source project. +# +# Copyright 2024-present MagicStack Inc. and the EdgeDB authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from __future__ import annotations +import typing + +import edgedb +import httpx +import httpx_sse + +from . import types + + +def create_ai(client: edgedb.Client, **kwargs) -> EdgeDBAI: + client.ensure_connected() + return EdgeDBAI(client, types.AIOptions(**kwargs)) + + +async def create_async_ai( + client: edgedb.AsyncIOClient, **kwargs +) -> AsyncEdgeDBAI: + await client.ensure_connected() + return AsyncEdgeDBAI(client, types.AIOptions(**kwargs)) + + +class BaseEdgeDBAI: + options: types.AIOptions + context: types.QueryContext + client_cls = NotImplemented + + def __init__( + self, + client: typing.Union[edgedb.Client, edgedb.AsyncIOClient], + options: types.AIOptions, + **kwargs, + ): + pool = client._impl + host, port = pool._working_addr + params = pool._working_params + proto = "http" if params.tls_security == "insecure" else "https" + branch = params.branch + self.options = options + self.context = types.QueryContext(**kwargs) + args = dict( + base_url=f"{proto}://{host}:{port}/branch/{branch}/ext/ai", + verify=params.ssl_ctx, + ) + if params.password is not None: + args["auth"] = (params.user, params.password) + elif params.secret_key is not None: + args["headers"] = {"Authorization": f"Bearer {params.secret_key}"} + self._init_client(**args) + + def _init_client(self, **kwargs): + raise NotImplementedError + + def with_config(self, **kwargs) -> typing.Self: + cls = type(self) + rv = cls.__new__(cls) + rv.options = self.options.derive(kwargs) + rv.context = self.context + rv.client = self.client + return rv + + def with_context(self, **kwargs) -> typing.Self: + cls = type(self) + rv = cls.__new__(cls) + rv.options = self.options + rv.context = self.context.derive(kwargs) + rv.client = self.client + return rv + + +class EdgeDBAI(BaseEdgeDBAI): + client: httpx.Client + + def _init_client(self, **kwargs): + self.client = httpx.Client(**kwargs) + + def query_rag( + self, message: str, context: typing.Optional[types.QueryContext] = None + ) -> str: + if context is None: + context = self.context + resp = self.client.post( + **types.RAGRequest( + model=self.options.model, + prompt=self.options.prompt, + context=context, + query=message, + stream=False, + ).to_httpx_request() + ) + resp.raise_for_status() + return resp.json()["response"] + + def stream_rag( + self, message: str, context: typing.Optional[types.QueryContext] = None + ): + if context is None: + context = self.context + with httpx_sse.connect_sse( + self.client, + "post", + **types.RAGRequest( + model=self.options.model, + prompt=self.options.prompt, + context=context, + query=message, + stream=True, + ).to_httpx_request(), + ) as event_source: + event_source.response.raise_for_status() + for sse in event_source.iter_sse(): + yield sse.data + + +class AsyncEdgeDBAI(BaseEdgeDBAI): + client: httpx.AsyncClient + + def _init_client(self, **kwargs): + self.client = httpx.AsyncClient(**kwargs) + + async def query_rag( + self, message: str, context: typing.Optional[types.QueryContext] = None + ) -> str: + if context is None: + context = self.context + resp = await self.client.post( + **types.RAGRequest( + model=self.options.model, + prompt=self.options.prompt, + context=context, + query=message, + stream=False, + ).to_httpx_request() + ) + resp.raise_for_status() + return resp.json()["response"] + + async def stream_rag( + self, message: str, context: typing.Optional[types.QueryContext] = None + ): + if context is None: + context = self.context + async with httpx_sse.aconnect_sse( + self.client, + "post", + **types.RAGRequest( + model=self.options.model, + prompt=self.options.prompt, + context=context, + query=message, + stream=True, + ).to_httpx_request(), + ) as event_source: + event_source.response.raise_for_status() + async for sse in event_source.aiter_sse(): + yield sse.data diff --git a/edgedb/ai/types.py b/edgedb/ai/types.py new file mode 100644 index 00000000..41bf24c0 --- /dev/null +++ b/edgedb/ai/types.py @@ -0,0 +1,81 @@ +# +# This source file is part of the EdgeDB open source project. +# +# Copyright 2024-present MagicStack Inc. and the EdgeDB authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import typing + +import dataclasses as dc +import enum + + +class ChatParticipantRole(enum.Enum): + SYSTEM = "system" + USER = "user" + ASSISTANT = "assistant" + TOOL = "tool" + + +class Custom(typing.TypedDict): + role: ChatParticipantRole + content: str + + +class Prompt: + name: typing.Optional[str] + id: typing.Optional[str] + custom: typing.Optional[typing.List[Custom]] + + +@dc.dataclass +class AIOptions: + model: str + prompt: typing.Optional[Prompt] = None + + def derive(self, kwargs): + return AIOptions(**{**dc.asdict(self), **kwargs}) + + +@dc.dataclass +class QueryContext: + query: str = "" + variables: typing.Optional[typing.Dict[str, typing.Any]] = None + globals: typing.Optional[typing.Dict[str, typing.Any]] = None + max_object_count: typing.Optional[int] = None + + def derive(self, kwargs): + return QueryContext(**{**dc.asdict(self), **kwargs}) + + +@dc.dataclass +class RAGRequest: + model: str + prompt: typing.Optional[Prompt] + context: QueryContext + query: str + stream: typing.Optional[bool] + + def to_httpx_request(self) -> typing.Dict[str, typing.Any]: + return dict( + url="/rag", + headers={ + "Content-Type": "application/json", + "Accept": ( + "text/event-stream" if self.stream else "application/json" + ), + }, + json=dc.asdict(self), + ) diff --git a/setup.py b/setup.py index b8f83148..8884fed2 100644 --- a/setup.py +++ b/setup.py @@ -59,7 +59,13 @@ 'sphinx_rtd_theme~=1.0.0', ] +AI_DEPENDENCIES = [ + 'httpx~=0.27.0', + 'httpx-sse~=0.4.0', +] + EXTRA_DEPENDENCIES = { + 'ai': AI_DEPENDENCIES, 'docs': DOC_DEPENDENCIES, 'test': TEST_DEPENDENCIES, # Dependencies required to develop edgedb.