Skip to content

Commit e5aae27

Browse files
authored
Rename AI to RAGClient and add compat names (#578)
1 parent 023697a commit e5aae27

File tree

6 files changed

+69
-25
lines changed

6 files changed

+69
-25
lines changed

edgedb/ai/__init__.py

+18
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,30 @@
22
TYPE_CHECKING = False
33
if TYPE_CHECKING:
44
from gel.ai import * # noqa
5+
create_ai = create_rag_client # noqa
6+
EdgeDBAI = RAGClient # noqa
7+
create_async_ai = create_async_rag_client # noqa
8+
AsyncEdgeDBAI = AsyncRAGClient # noqa
9+
AIOptions = RAGOptions # noqa
510
import gel.ai as _mod
611
import sys as _sys
712
_cur = _sys.modules['edgedb.ai']
813
for _k in vars(_mod):
914
if not _k.startswith('__') or _k in ('__all__', '__doc__'):
1015
setattr(_cur, _k, getattr(_mod, _k))
16+
_cur.create_ai = _mod.create_rag_client
17+
_cur.EdgeDBAI = _mod.RAGClient
18+
_cur.create_async_ai = _mod.create_async_rag_client
19+
_cur.AsyncEdgeDBAI = _mod.AsyncRAGClient
20+
_cur.AIOptions = _mod.RAGOptions
21+
if hasattr(_cur, '__all__'):
22+
_cur.__all__ = _cur.__all__ + [
23+
'create_ai',
24+
'EdgeDBAI',
25+
'create_async_ai',
26+
'AsyncEdgeDBAI',
27+
'AIOptions',
28+
]
1129
del _cur
1230
del _sys
1331
del _mod

gel/ai/__init__.py

+8-8
Original file line numberDiff line numberDiff line change
@@ -16,17 +16,17 @@
1616
# limitations under the License.
1717
#
1818

19-
from .types import AIOptions, ChatParticipantRole, Prompt, QueryContext
20-
from .core import create_ai, EdgeDBAI
21-
from .core import create_async_ai, AsyncEdgeDBAI
19+
from .types import RAGOptions, ChatParticipantRole, Prompt, QueryContext
20+
from .core import create_rag_client, RAGClient
21+
from .core import create_async_rag_client, AsyncRAGClient
2222

2323
__all__ = [
24-
"AIOptions",
24+
"RAGOptions",
2525
"ChatParticipantRole",
2626
"Prompt",
2727
"QueryContext",
28-
"create_ai",
29-
"EdgeDBAI",
30-
"create_async_ai",
31-
"AsyncEdgeDBAI",
28+
"create_rag_client",
29+
"RAGClient",
30+
"create_async_rag_client",
31+
"AsyncRAGClient",
3232
]

gel/ai/core.py

+10-10
Original file line numberDiff line numberDiff line change
@@ -26,27 +26,27 @@
2626
from . import types
2727

2828

29-
def create_ai(client: gel.Client, **kwargs) -> EdgeDBAI:
29+
def create_rag_client(client: gel.Client, **kwargs) -> RAGClient:
3030
client.ensure_connected()
31-
return EdgeDBAI(client, types.AIOptions(**kwargs))
31+
return RAGClient(client, types.RAGOptions(**kwargs))
3232

3333

34-
async def create_async_ai(
34+
async def create_async_rag_client(
3535
client: gel.AsyncIOClient, **kwargs
36-
) -> AsyncEdgeDBAI:
36+
) -> AsyncRAGClient:
3737
await client.ensure_connected()
38-
return AsyncEdgeDBAI(client, types.AIOptions(**kwargs))
38+
return AsyncRAGClient(client, types.RAGOptions(**kwargs))
3939

4040

41-
class BaseEdgeDBAI:
42-
options: types.AIOptions
41+
class BaseRAGClient:
42+
options: types.RAGOptions
4343
context: types.QueryContext
4444
client_cls = NotImplemented
4545

4646
def __init__(
4747
self,
4848
client: typing.Union[gel.Client, gel.AsyncIOClient],
49-
options: types.AIOptions,
49+
options: types.RAGOptions,
5050
**kwargs,
5151
):
5252
pool = client._impl
@@ -103,7 +103,7 @@ def _make_rag_request(
103103
)
104104

105105

106-
class EdgeDBAI(BaseEdgeDBAI):
106+
class RAGClient(BaseRAGClient):
107107
client: httpx.Client
108108

109109
def _init_client(self, **kwargs):
@@ -146,7 +146,7 @@ def generate_embeddings(self, *inputs: str, model: str) -> list[float]:
146146
return resp.json()["data"][0]["embedding"]
147147

148148

149-
class AsyncEdgeDBAI(BaseEdgeDBAI):
149+
class AsyncRAGClient(BaseRAGClient):
150150
client: httpx.AsyncClient
151151

152152
def _init_client(self, **kwargs):

gel/ai/types.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -41,12 +41,12 @@ class Prompt:
4141

4242

4343
@dc.dataclass
44-
class AIOptions:
44+
class RAGOptions:
4545
model: str
4646
prompt: typing.Optional[Prompt] = None
4747

4848
def derive(self, kwargs):
49-
return AIOptions(**{**dc.asdict(self), **kwargs})
49+
return RAGOptions(**{**dc.asdict(self), **kwargs})
5050

5151

5252
@dc.dataclass

tools/gen_init.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -24,8 +24,8 @@
2424
if __name__ == '__main__':
2525
this = pathlib.Path(__file__)
2626

27-
errors_fn = this.parent.parent / 'edgedb' / 'errors' / '__init__.py'
28-
init_fn = this.parent.parent / 'edgedb' / '__init__.py'
27+
errors_fn = this.parent.parent / 'gel' / 'errors' / '__init__.py'
28+
init_fn = this.parent.parent / 'gel' / '__init__.py'
2929

3030
with open(errors_fn, 'rt') as f:
3131
errors_txt = f.read()

tools/make_import_shims.py

+29-3
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,15 @@
11
import os
2-
import sys
32

43
MODS = sorted(['gel', 'gel._taskgroup', 'gel._version', 'gel.abstract', 'gel.ai', 'gel.ai.core', 'gel.ai.types', 'gel.asyncio_client', 'gel.base_client', 'gel.blocking_client', 'gel.codegen', 'gel.color', 'gel.con_utils', 'gel.credentials', 'gel.datatypes', 'gel.datatypes.datatypes', 'gel.datatypes.range', 'gel.describe', 'gel.enums', 'gel.errors', 'gel.errors._base', 'gel.errors.tags', 'gel.introspect', 'gel.options', 'gel.pgproto', 'gel.pgproto.pgproto', 'gel.pgproto.types', 'gel.platform', 'gel.protocol', 'gel.protocol.asyncio_proto', 'gel.protocol.blocking_proto', 'gel.protocol.protocol', 'gel.scram', 'gel.scram.saslprep', 'gel.transaction'])
5-
4+
COMPAT = {
5+
'gel.ai': {
6+
'create_ai': 'create_rag_client',
7+
'EdgeDBAI': 'RAGClient',
8+
'create_async_ai': 'create_async_rag_client',
9+
'AsyncEdgeDBAI': 'AsyncRAGClient',
10+
'AIOptions': 'RAGOptions',
11+
},
12+
}
613

714

815
def main():
@@ -12,7 +19,10 @@ def main():
1219
nmod = 'edgedb' + mod[len('gel'):]
1320
slash_name = nmod.replace('.', '/')
1421
if is_package:
15-
os.mkdir(slash_name)
22+
try:
23+
os.mkdir(slash_name)
24+
except FileExistsError:
25+
pass
1626
fname = slash_name + '/__init__.py'
1727
else:
1828
fname = slash_name + '.py'
@@ -25,12 +35,28 @@ def main():
2535
TYPE_CHECKING = False
2636
if TYPE_CHECKING:
2737
from {mod} import * # noqa
38+
''')
39+
if mod in COMPAT:
40+
for k, v in COMPAT[mod].items():
41+
f.write(f' {k} = {v} # noqa\n')
42+
f.write(f'''\
2843
import {mod} as _mod
2944
import sys as _sys
3045
_cur = _sys.modules['{nmod}']
3146
for _k in vars(_mod):
3247
if not _k.startswith('__') or _k in ('__all__', '__doc__'):
3348
setattr(_cur, _k, getattr(_mod, _k))
49+
''')
50+
if mod in COMPAT:
51+
for k, v in COMPAT[mod].items():
52+
f.write(f"_cur.{k} = _mod.{v}\n")
53+
f.write(f'''\
54+
if hasattr(_cur, '__all__'):
55+
_cur.__all__ = _cur.__all__ + [
56+
{',\n '.join(repr(k) for k in COMPAT[mod])},
57+
]
58+
''')
59+
f.write(f'''\
3460
del _cur
3561
del _sys
3662
del _mod

0 commit comments

Comments
 (0)