diff --git a/gel/abstract.py b/gel/abstract.py index 22264405..2f21f5d8 100644 --- a/gel/abstract.py +++ b/gel/abstract.py @@ -49,6 +49,84 @@ class QueryWithArgs(typing.NamedTuple): kwargs: typing.Dict[str, typing.Any] input_language: protocol.InputLanguage = protocol.InputLanguage.EDGEQL + @classmethod + def from_( + cls, + query: Query, + args: typing.Tuple, + kwargs: typing.Dict[str, typing.Any], + default_input_language: protocol.InputLanguage = ( + protocol.InputLanguage.EDGEQL + ), + ) -> typing.Self: + if isinstance(query, str): + return cls(query, args, kwargs, default_input_language) + else: + return query.as_query_with_args(*args, **kwargs) + + +class AsQueryWithArgs(abc.ABC): + @abc.abstractmethod + def as_query_with_args( + self, + *args, + **kwargs, + ) -> QueryWithArgs: ... + + +Query = typing.Union[str, AsQueryWithArgs] + + +# class Batch(AsQueryWithArgs): +# def __init__(self, *queries: Query): +# self._queries = queries +# +# def as_query_with_args(self, *args, **kwargs) -> QueryWithArgs: +# query_texts = [] +# query_args = () +# query_kwargs = {} +# for q in self._queries: +# query = QueryWithArgs.from_(q, args, kwargs, input_language) +# +# text = query.query.rstrip() +# if not text.endswith(";"): +# text += ";" +# query_texts.append(text) +# +# query_args += query.args +# +# for k, v in query.kwargs.items(): +# if k in query_kwargs: +# if query_kwargs[k] != v: +# raise errors.InvalidArgumentError( +# f"conflicting values for argument {k!r}" +# ) +# else: +# query_kwargs[k] = v +# +# return QueryWithArgs( +# query="\n\n".join(query_texts), +# args=query_args, +# kwargs=query_kwargs, +# input_language=input_language, +# ) +# +# +def batch( + *queries: Query, + input_language: protocol.InputLanguage = protocol.InputLanguage.EDGEQL, +) -> QueryWithArgs: + query_text = "" + for q in queries: + if isinstance(q, str): + query = QueryWithArgs(q, (), {}, input_language) + else: + query = q.as_query_with_args(input_language=input_language) + + return QueryWithArgs( + query="\n\n".join(q.query for q in queries), + ) + class QueryCache(typing.NamedTuple): codecs_registry: protocol.CodecsRegistry @@ -211,9 +289,9 @@ class ReadOnlyExecutor(BaseReadOnlyExecutor): def _query(self, query_context: QueryContext): ... - def query(self, query: str, *args, **kwargs) -> list: + def query(self, query: Query, *args, **kwargs) -> list: return self._query(QueryContext( - query=QueryWithArgs(query, args, kwargs), + query=QueryWithArgs.from_(query, args, kwargs), cache=self._get_query_cache(), query_options=_query_opts, retry_options=self._get_retry_options(), @@ -223,10 +301,10 @@ def query(self, query: str, *args, **kwargs) -> list: )) def query_single( - self, query: str, *args, **kwargs + self, query: Query, *args, **kwargs ) -> typing.Union[typing.Any, None]: return self._query(QueryContext( - query=QueryWithArgs(query, args, kwargs), + query=QueryWithArgs.from_(query, args, kwargs), cache=self._get_query_cache(), query_options=_query_single_opts, retry_options=self._get_retry_options(), @@ -235,9 +313,11 @@ def query_single( annotations=self._get_annotations(), )) - def query_required_single(self, query: str, *args, **kwargs) -> typing.Any: + def query_required_single( + self, query: Query, *args, **kwargs + ) -> typing.Any: return self._query(QueryContext( - query=QueryWithArgs(query, args, kwargs), + query=QueryWithArgs.from_(query, args, kwargs), cache=self._get_query_cache(), query_options=_query_required_single_opts, retry_options=self._get_retry_options(), @@ -246,9 +326,9 @@ def query_required_single(self, query: str, *args, **kwargs) -> typing.Any: annotations=self._get_annotations(), )) - def query_json(self, query: str, *args, **kwargs) -> str: + def query_json(self, query: Query, *args, **kwargs) -> str: return self._query(QueryContext( - query=QueryWithArgs(query, args, kwargs), + query=QueryWithArgs.from_(query, args, kwargs), cache=self._get_query_cache(), query_options=_query_json_opts, retry_options=self._get_retry_options(), @@ -257,9 +337,9 @@ def query_json(self, query: str, *args, **kwargs) -> str: annotations=self._get_annotations(), )) - def query_single_json(self, query: str, *args, **kwargs) -> str: + def query_single_json(self, query: Query, *args, **kwargs) -> str: return self._query(QueryContext( - query=QueryWithArgs(query, args, kwargs), + query=QueryWithArgs.from_(query, args, kwargs), cache=self._get_query_cache(), query_options=_query_single_json_opts, retry_options=self._get_retry_options(), @@ -268,9 +348,9 @@ def query_single_json(self, query: str, *args, **kwargs) -> str: annotations=self._get_annotations(), )) - def query_required_single_json(self, query: str, *args, **kwargs) -> str: + def query_required_single_json(self, query: Query, *args, **kwargs) -> str: return self._query(QueryContext( - query=QueryWithArgs(query, args, kwargs), + query=QueryWithArgs.from_(query, args, kwargs), cache=self._get_query_cache(), query_options=_query_required_single_json_opts, retry_options=self._get_retry_options(), @@ -299,9 +379,9 @@ def query_sql(self, query: str, *args, **kwargs) -> list[datatypes.Record]: def _execute(self, execute_context: ExecuteContext): ... - def execute(self, commands: str, *args, **kwargs) -> None: + def execute(self, commands: Query, *args, **kwargs) -> None: self._execute(ExecuteContext( - query=QueryWithArgs(commands, args, kwargs), + query=QueryWithArgs.from_(commands, args, kwargs), cache=self._get_query_cache(), state=self._get_state(), warning_handler=self._get_warning_handler(), @@ -338,9 +418,9 @@ class AsyncIOReadOnlyExecutor(BaseReadOnlyExecutor): async def _query(self, query_context: QueryContext): ... - async def query(self, query: str, *args, **kwargs) -> list: + async def query(self, query: Query, *args, **kwargs) -> list: return await self._query(QueryContext( - query=QueryWithArgs(query, args, kwargs), + query=QueryWithArgs.from_(query, args, kwargs), cache=self._get_query_cache(), query_options=_query_opts, retry_options=self._get_retry_options(), @@ -349,9 +429,9 @@ async def query(self, query: str, *args, **kwargs) -> list: annotations=self._get_annotations(), )) - async def query_single(self, query: str, *args, **kwargs) -> typing.Any: + async def query_single(self, query: Query, *args, **kwargs) -> typing.Any: return await self._query(QueryContext( - query=QueryWithArgs(query, args, kwargs), + query=QueryWithArgs.from_(query, args, kwargs), cache=self._get_query_cache(), query_options=_query_single_opts, retry_options=self._get_retry_options(), @@ -361,13 +441,10 @@ async def query_single(self, query: str, *args, **kwargs) -> typing.Any: )) async def query_required_single( - self, - query: str, - *args, - **kwargs + self, query: Query, *args, **kwargs ) -> typing.Any: return await self._query(QueryContext( - query=QueryWithArgs(query, args, kwargs), + query=QueryWithArgs.from_(query, args, kwargs), cache=self._get_query_cache(), query_options=_query_required_single_opts, retry_options=self._get_retry_options(), @@ -376,9 +453,9 @@ async def query_required_single( annotations=self._get_annotations(), )) - async def query_json(self, query: str, *args, **kwargs) -> str: + async def query_json(self, query: Query, *args, **kwargs) -> str: return await self._query(QueryContext( - query=QueryWithArgs(query, args, kwargs), + query=QueryWithArgs.from_(query, args, kwargs), cache=self._get_query_cache(), query_options=_query_json_opts, retry_options=self._get_retry_options(), @@ -387,9 +464,9 @@ async def query_json(self, query: str, *args, **kwargs) -> str: annotations=self._get_annotations(), )) - async def query_single_json(self, query: str, *args, **kwargs) -> str: + async def query_single_json(self, query: Query, *args, **kwargs) -> str: return await self._query(QueryContext( - query=QueryWithArgs(query, args, kwargs), + query=QueryWithArgs.from_(query, args, kwargs), cache=self._get_query_cache(), query_options=_query_single_json_opts, retry_options=self._get_retry_options(), @@ -399,13 +476,10 @@ async def query_single_json(self, query: str, *args, **kwargs) -> str: )) async def query_required_single_json( - self, - query: str, - *args, - **kwargs + self, query: Query, *args, **kwargs ) -> str: return await self._query(QueryContext( - query=QueryWithArgs(query, args, kwargs), + query=QueryWithArgs.from_(query, args, kwargs), cache=self._get_query_cache(), query_options=_query_required_single_json_opts, retry_options=self._get_retry_options(), @@ -414,7 +488,7 @@ async def query_required_single_json( annotations=self._get_annotations(), )) - async def query_sql(self, query: str, *args, **kwargs) -> typing.Any: + async def query_sql(self, query: Query, *args, **kwargs) -> typing.Any: return await self._query(QueryContext( query=QueryWithArgs( query, @@ -434,16 +508,16 @@ async def query_sql(self, query: str, *args, **kwargs) -> typing.Any: async def _execute(self, execute_context: ExecuteContext) -> None: ... - async def execute(self, commands: str, *args, **kwargs) -> None: + async def execute(self, commands: Query, *args, **kwargs) -> None: await self._execute(ExecuteContext( - query=QueryWithArgs(commands, args, kwargs), + query=QueryWithArgs.from_(commands, args, kwargs), cache=self._get_query_cache(), state=self._get_state(), warning_handler=self._get_warning_handler(), annotations=self._get_annotations(), )) - async def execute_sql(self, commands: str, *args, **kwargs) -> None: + async def execute_sql(self, commands: Query, *args, **kwargs) -> None: await self._execute(ExecuteContext( query=QueryWithArgs( commands, diff --git a/gel/ai/metadata_filter.py b/gel/ai/metadata_filter.py deleted file mode 100644 index 5d39f306..00000000 --- a/gel/ai/metadata_filter.py +++ /dev/null @@ -1,132 +0,0 @@ -from typing import Union, List -from enum import Enum - - -class FilterOperator(str, Enum): - EQ = "=" - NE = "!=" - GT = ">" - LT = "<" - GTE = ">=" - LTE = "<=" - IN = "in" - NOT_IN = "not in" - LIKE = "like" - ILIKE = "ilike" - ANY = "any" - ALL = "all" - CONTAINS = "contains" - EXISTS = "exists" - - -class FilterCondition(str, Enum): - AND = "and" - OR = "or" - - -class MetadataFilter: - """Represents a single metadata filter condition.""" - - def __init__( - self, - key: str, - value: Union[int, float, str], - operator: FilterOperator = FilterOperator.EQ, - ): - self.key = key - self.value = value - self.operator = operator - - def __repr__(self): - value = f'"{self.value}"' if isinstance(self.value, str) else self.value - return ( - f'MetadataFilter(key="{self.key}", ' - f"value={value}, " - f'operator="{self.operator.value}")' - ) - - -class CompositeFilter: - """ - Allows grouping multiple MetadataFilter instances using AND/OR conditions. - """ - - def __init__( - self, - filters: List[Union["CompositeFilter", MetadataFilter]], - condition: FilterCondition = FilterCondition.AND, - ): - self.filters = filters - self.condition = condition - - def __repr__(self): - return ( - f'CompositeFilter(condition="{self.condition.value}", ' - f"filters={self.filters})" - ) - - -def get_filter_clause(filters: CompositeFilter) -> str: - """ - Get the filter clause for a given CompositeFilter. - """ - subclauses = [] - for filter in filters.filters: - subclause = "" - - if isinstance(filter, CompositeFilter): - subclause = get_filter_clause(filter) - elif isinstance(filter, MetadataFilter): - formatted_value = ( - f'"{filter.value}"' - if isinstance(filter.value, str) - else filter.value - ) - - match filter.operator: - case ( - FilterOperator.EQ - | FilterOperator.NE - | FilterOperator.GT - | FilterOperator.GTE - | FilterOperator.LT - | FilterOperator.LTE - | FilterOperator.LIKE - | FilterOperator.ILIKE - ): - subclause = ( - f'json_get(.metadata, "{filter.key}") ' - f"{filter.operator.value} {formatted_value}" - ) - - case FilterOperator.IN | FilterOperator.NOT_IN: - subclause = ( - f'json_get(.metadata, "{filter.key}") ' - f"{filter.operator.value} " - f"array_unpack({formatted_value})" - ) - - case FilterOperator.ANY | FilterOperator.ALL: - subclause = ( - f"{filter.operator.value}" - f'(json_get(.metadata, "{filter.key}") = ' - f"array_unpack({formatted_value}))" - ) - - case FilterOperator.CONTAINS | FilterOperator.EXISTS: - subclause = ( - f'contains(json_get(.metadata, "{filter.key}"), ' - f"{formatted_value})" - ) - case _: - raise ValueError(f"Unknown operator: {filter.operator}") - - subclauses.append(subclause) - - if filters.condition in {FilterCondition.AND, FilterCondition.OR}: - filter_clause = f" {filters.condition.value} ".join(subclauses) - return ( - "(" + filter_clause + ")" if len(subclauses) > 1 else filter_clause - ) - else: - raise ValueError(f"Unknown condition: {filters.condition}") diff --git a/gel/ai/vectorstore.py b/gel/ai/vectorstore.py index 3ce5b3db..5086f700 100644 --- a/gel/ai/vectorstore.py +++ b/gel/ai/vectorstore.py @@ -1,7 +1,24 @@ +# +# This source file is part of the EdgeDB open source project. +# +# Copyright 2025-present MagicStack Inc. and the EdgeDB authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# # Extension Vectorstore Binding # ---------------------------- # -# `GelVectorstore` is designed to integrate with vector databases following +# `Vectorstore` is designed to integrate with vector databases following # LangChain-LlamaIndex conventions. It enables interaction with embedding models # (both within and outside of Gel) through a simple interface. # @@ -9,297 +26,408 @@ # text, images, or any other embeddings. For example, CLIP can be wrapped into # this interface to generate and store image embeddings. -import gel + +from __future__ import annotations +from typing import ( + Optional, + TypeVar, + Any, + List, + Dict, + Generic, + Self, + Union, + overload, + TYPE_CHECKING, + Coroutine, +) + +import abc +import array +import dataclasses +import enum import json +import textwrap import uuid -from dataclasses import dataclass, field -from abc import abstractmethod -from typing import Optional, TypeVar, Any, List, Dict, Generic -from jinja2 import Template -from .metadata_filter import ( - get_filter_clause, - CompositeFilter, -) +from gel import abstract +from gel import errors +from gel import quote +from gel.protocol import protocol -BATCH_ADD_QUERY = Template( - """ - with items := json_array_unpack($items) - select ( - for item in items union ( - insert {{record_type}} { - collection := $collection_name, - text := item['text'], - embedding := >item['embedding'], - metadata := item['metadata'] - } - ) - ) - """.strip() -) -DELETE_BY_IDS_QUERY = Template( - """ - delete {{record_type}} - filter .id in array_unpack(>$ids) - and .collection = $collection_name; - """.strip() -) +if TYPE_CHECKING: + try: + import numpy as np + import numpy.typing as npt -SEARCH_QUERY = Template( - """ - with collection_records := ( - select {{record_type}} - filter .collection = $collection_name - and exists(.embedding) - ) - select collection_records { - id, - text, - embedding, - metadata, - cosine_similarity := 1 - ext::pgvector::cosine_distance( - .embedding, $query_embedding), - } - {{filter_expression}} - order by .cosine_similarity desc empty last - limit $limit; - """ -) + Vector = Union[ + List[float], array.array[float], npt.NDArray[np.float32] + ] + except ImportError: + Vector = Union[List[float], array.array[float]] -GET_BY_IDS_QUERY = Template( - """ - select {{record_type}} { - id, - text, - embedding, - metadata, - } - filter .id in array_unpack(>$ids) - and .collection = $collection_name; - """.strip() -) -UPDATE_QUERY = Template( - """ - with updates := array_unpack(>$updates) - update {{record_type}} - filter .id = $id and .collection = $collection_name - set { - text := $text if 'text' in updates else .text, - embedding := $embedding - if 'embedding' in updates - else .embedding, - metadata := $metadata - if 'metadata' in updates - else .metadata, - }; - """.strip() -) +class Query(abstract.AsQueryWithArgs): + def __init__(self, query: str, **kwargs): + self.query = query + self.kwargs = kwargs + def as_query_with_args(self, *args, **kwargs) -> abstract.QueryWithArgs: + if args: + raise errors.InvalidArgumentError( + "this query does not accept positional arguments" + ) + return abstract.QueryWithArgs( + query=self.query, + args=args, + kwargs={**self.kwargs, **kwargs}, + input_language=protocol.InputLanguage.EdgeQL, + ) -@dataclass -class InsertItem: - """An item whose embedding will be created and stored - alongside the item in the vector store.""" - text: str - metadata: Dict[str, Any] = field(default_factory=dict) +@dataclasses.dataclass(kw_only=True) +class AddRecord(abstract.AsQueryWithArgs): + """A record to be added to the vector store with embedding pre-computed.""" + record_type: str + collection_name: str + embedding: Vector + text: Optional[str] + metadata: Optional[Dict[str, Any]] + + def asdict(self, json_compat: bool = False, **override) -> Dict[str, Any]: + rv = dataclasses.asdict(self) + rv.pop("record_type") + if self.metadata is not None: + rv["metadata"] = json.dumps(self.metadata) + rv.update(override) + if json_compat and hasattr(rv["embedding"], "tolist"): + rv["embedding"] = rv["embedding"].tolist() + return rv + + def as_query_with_args(self, *args, **kwargs) -> abstract.QueryWithArgs: + if args: + raise errors.InvalidArgumentError( + "this query does not accept positional arguments" + ) + return abstract.QueryWithArgs( + query=textwrap.dedent( + f""" + with rec := insert {quote.quote_ident(self.record_type)} {{ + collection := $collection_name, + text := $text, + embedding := $embedding, + metadata := $metadata, + }} + select rec.id + """ + ), + args=args, + kwargs=self.asdict(**kwargs), + input_language=protocol.InputLanguage.EdgeQL, + ) -@dataclass -class InsertRecord: - """A record to be added to the vector store with embedding pre-computed.""" - embedding: List[float] - text: Optional[str] = None - metadata: Dict[str, Any] = field(default_factory=dict) +class AddRecords(abstract.AsQueryWithArgs): + """Add multiple records to the vector store in a single transaction.""" + def __init__(self, *records: AddRecord): + record_type = set(record.record_type for record in records) + if len(record_type) == 0: + raise errors.InvalidArgumentError("no records provided") + if len(record_type) > 1: + raise errors.InvalidArgumentError( + f"all records must have the same record type, " + f"got {record_type}" + ) + self.record_type = record_type.pop() + self.records = records -@dataclass(init=False) -class Record: - """A record retrieved from the vector store, or an update record. + def as_query_with_args(self, *args, **kwargs) -> abstract.QueryWithArgs: + if args: + raise errors.InvalidArgumentError( + "this query does not accept positional arguments" + ) + return abstract.QueryWithArgs( + query=textwrap.dedent( + f""" + with + items := json_array_unpack($items), + recs := for item in items union ( + insert {quote.quote_ident(self.record_type)} {{ + collection := item['collection_name'], + text := item['text'], + embedding := >item['embedding'], + metadata := to_json(item['metadata']), + }} + ) + select recs.id + """ + ), + args=args, + kwargs={ + "items": json.dumps( + [ + r.asdict(json_compat=True, **kwargs) + for r in self.records + ] + ) + }, + input_language=protocol.InputLanguage.EdgeQL, + ) - Custom `__init__` so we can detect which fields the user passed - (even if they pass None or {}). - """ - id: uuid.UUID - text: Optional[str] = None - embedding: Optional[List[float]] = None - metadata: Dict[str, Any] = field(default_factory=dict) - # We'll fill these dynamically in __init__ - _explicitly_set_fields: set = field(default_factory=set, repr=False) +JsonValue = Union[int, float, str, bool] - def __init__(self, id: uuid.UUID, **kwargs): - """ - Force the user to provide `id` positionally/explicitly, - then capture any *other* fields in **kwargs. - """ - # For text, embedding, metadata, we use what's in kwargs - # or fall back to the default already on the class. - self.id = id - self.text = kwargs.get("text", None) - self.embedding = kwargs.get("embedding", None) - self.metadata = kwargs.get("metadata", {}) +class FilterOperator(str, enum.Enum): + EQ = "=" + NE = "!=" + GT = ">" + GTE = ">=" + LT = "<" + LTE = "<=" - # Mark which fields were actually passed by the user (ignore 'id'). - # So if user calls Record(id=..., text=None), "text" will appear here. - object.__setattr__(self, "_explicitly_set_fields", set(kwargs.keys())) - def is_field_set(self, field: str) -> bool: - """Check if a field was explicitly set in the constructor call.""" - return field in self._explicitly_set_fields +class SearchRecordQuery(abstract.AsQueryWithArgs): + def __init__( + self, + record_type: str, + collection_name: str, + embedding: Vector, + limit: int = 4, + ): + self.record_type = record_type + self.collection_name = collection_name + self.embedding = embedding + self.limit = limit + self.filter_args = [] + self.filters = [] + + def filter(self, expr: str, *args: Any) -> Self: + if self.filter_args and args: + raise errors.InvalidArgumentError( + "filter() with arguments can only be called once before " + "adding any filter with arguments" + ) + self.filters.append(f"({expr})") + self.filter_args.extend(args) + return self + + @overload + def filter_metadata( + self, *path: str, eq: JsonValue, default: bool = False + ) -> Self: ... + + @overload + def filter_metadata( + self, *path: str, ne: JsonValue, default: bool = False + ) -> Self: ... + + @overload + def filter_metadata( + self, *path: str, gt: JsonValue, default: bool = False + ) -> Self: ... + + @overload + def filter_metadata( + self, *path: str, gte: JsonValue, default: bool = False + ) -> Self: ... + + @overload + def filter_metadata( + self, *path: str, lt: JsonValue, default: bool = False + ) -> Self: ... + + @overload + def filter_metadata( + self, *path: str, lte: JsonValue, default: bool = False + ) -> Self: ... + + def filter_metadata(self, *path: str, **op_vals: JsonValue) -> Self: + if not path: + raise errors.InterfaceError( + "at least one path element is required" + ) + default = str(op_vals.pop("default", False)).lower() + if len(op_vals) != 1: + raise errors.InterfaceError( + "expected exactly one operator-value pair" + ) -@dataclass -class SearchResult(Record): - """A search result from the vector store.""" + op_str, value = op_vals.popitem() + if isinstance(value, str): + typ = "str" + elif isinstance(value, bool): # bool is a subclass of int, goes first + typ = "bool" + elif isinstance(value, int): + typ = "int64" + elif isinstance(value, float): + typ = "float64" + else: + raise errors.InterfaceError( + f"unsupported value type: {type(value).__name__}" + ) - cosine_similarity: float = 0.0 + path_param = ", ".join(quote.quote_literal(p) for p in path) + left = f"<{typ}>json_get(.metadata, {path_param})" + op = FilterOperator(op_str.upper()) + right = f"<{typ}>${len(self.filter_args)}" + self.filters.append(f"(({left} {op} {right}) ?? {default})") + self.filter_args.append(value) + return self -T = TypeVar("T") + def limit(self, limit: int) -> Self: + self.limit = limit + return self + def as_query_with_args(self, *args, **kwargs) -> abstract.QueryWithArgs: + if args: + raise errors.InvalidArgumentError( + "this query does not accept positional arguments" + ) + c = len(self.filter_args) + if self.filters: + filter_expression = "filter " + " and ".join(self.filters) + else: + filter_expression = "" + return abstract.QueryWithArgs( + query=textwrap.dedent( + f""" + with collection_records := ( + select {quote.quote_ident(self.record_type)} + filter .collection = ${c} + and exists(.embedding) + ) + select collection_records {{ + id, + text, + embedding, + metadata, + cosine_similarity := 1 - ext::pgvector::cosine_distance( + .embedding, ${c + 1}), + }} + {filter_expression} + order by .cosine_similarity desc empty last + limit ${c + 2}; + """ + ), + args=( + *self.filter_args, + kwargs.pop("collection_name", self.collection_name), + kwargs.pop("embedding", self.embedding), + kwargs.pop("limit", self.limit), + ), + kwargs={}, + input_language=protocol.InputLanguage.EdgeQL, + ) -class BaseEmbeddingModel(Generic[T]): - """ - Abstract base class for embedding models. - Any embedding model used with `GelVectorstore` must implement this - interface. The model is expected to convert input data (text, images, etc.) - into a numerical vector representation. - """ +T = TypeVar("T") - @abstractmethod - def __call__(self, item: T) -> List[float]: - """ - Convert an input item into a list of floating-point values (vector - embedding). Must be implemented in subclasses. - """ - raise NotImplementedError +class BaseEmbeddingModel(abc.ABC, Generic[T]): @property - @abstractmethod + @abc.abstractmethod def dimensions(self) -> int: """ Return the number of dimensions in the embedding vector. Must be implemented in subclasses. """ - raise NotImplementedError + ... @property - @abstractmethod + @abc.abstractmethod def target_type(self) -> TypeVar: """ Return the expected data type of the input (e.g., str for text, image for vision models). Must be implemented in subclasses. """ - raise NotImplementedError + ... - -class GelVectorstore: """ - A framework-agnostic interface for interacting with Gel's ext::vectorstore. + Abstract base class for embedding models. - This class provides methods for storing, retrieving, and searching - vector embeddings. It follows vector database conventions and supports - different embedding models. + Any embedding model used with `Vectorstore` must implement this + interface. The model is expected to convert input data (text, images, etc.) + into a numerical vector representation. """ - def __init__( + +class EmbeddingModel(BaseEmbeddingModel[T], Generic[T]): + def store( self, - embedding_model: Optional[BaseEmbeddingModel] = None, collection_name: str = "default", record_type: str = "ext::vectorstore::DefaultRecord", - client_config: Optional[Dict[str, Any]] = None, - ): - """Initialize a new vector store instance. + ) -> Vectorstore: + return Vectorstore( + embedding_model=self, + collection_name=collection_name, + record_type=record_type, + ) - Args: - embedding_model (BaseEmbeddingModel): The embedding model used to - generate vectors. - collection_name (str): The name of the collection. - record_type (str): The schema type (table name) for storing records. - client_config (Optional[dict]): The config for the Gel client. + @abc.abstractmethod + def generate(self, item: T) -> Vector: """ - self.embedding_model = embedding_model - self.collection_name = collection_name - self.record_type = record_type - self.gel_client = gel.create_client(**(client_config or {})) - - def add_items(self, items: List[InsertItem]) -> List[uuid.UUID]: + Convert an input item into a list of floating-point values (vector + embedding). Must be implemented in subclasses. """ - Add multiple items to the vector store in a single transaction. - Embeddinsg will be generated and stored for all items. + ... - Args: - items (List[InsertItem]): List of items to add. Each contains: - - text (str): The text content to be embedded - - metadata (Dict[str, Any]): Additional data to store + @abc.abstractmethod + def generate_text(self, text: str) -> Vector: ... - Returns: - List[uuid.UUID]: List of database record IDs for the inserted items. - """ - items_with_embeddings = [ - InsertRecord( - text=item.text, - embedding=(self.embedding_model(item.text)), - metadata=item.metadata, - ) - for item in items - ] - return self.add_vectors(items_with_embeddings) - def add_vectors(self, records: List[InsertRecord]) -> List[uuid.UUID]: - """Add pre-computed vector embeddings to the store. +@dataclasses.dataclass +class BaseVectorstore: + """ + A framework-agnostic interface for interacting with Gel's ext::vectorstore. - Use this method when you have already generated embeddings and want to - store them directly without re-computing them. + This class provides methods for storing, retrieving, and searching + vector embeddings. It follows vector database conventions and supports + different embedding models. + """ - Args: - records (List[InsertRecord]): List of records. Each contains: - - embedding ([List[float]): Pre-computed embeddings - - text (Optional[str]): Original text content - - metadata ([Dict[str, Any]): Additional data to store + collection_name: str = "default" + record_type: str = "ext::vectorstore::DefaultRecord" - Returns: - List[uuid.UUID]: List of database record IDs for the inserted items. - """ - results = self.gel_client.query( - query=BATCH_ADD_QUERY.render(record_type=self.record_type), + def add_embedding( + self, embedding: Vector, text: Optional[str] = None, **metadata + ) -> AddRecord: + return AddRecord( + record_type=self.record_type, collection_name=self.collection_name, - items=json.dumps( - [ - { - "text": record.text, - "embedding": record.embedding, - "metadata": record.metadata or {}, - } - for record in records - ] - ), + embedding=embedding, + text=text, + metadata=metadata or None, ) - return [result.id for result in results] - def delete(self, ids: List[uuid.UUID]) -> List[uuid.UUID]: + def delete(self, *ids: uuid.UUID) -> Query: """Delete records from the vector store by their IDs. Args: ids (List[uuid.UUID]): List of record IDs to delete. Returns: - List[uuid.UUID]: List of deleted record IDs. + Query: Executable Query, returning the deleted IDs. """ - results = self.gel_client.query( - query=DELETE_BY_IDS_QUERY.render(record_type=self.record_type), + return Query( + textwrap.dedent( + f""" + with recs := delete {quote.quote_ident(self.record_type)} + filter .id in array_unpack(>$ids) + and .collection = $collection_name; + select recs.id; + """ + ), collection_name=self.collection_name, - ids=ids, + ids=list(ids), ) - return [result.id for result in results] - def get_by_ids(self, ids: List[uuid.UUID]) -> List[Record]: + def get_by_ids(self, *ids: uuid.UUID) -> Query: """Retrieve specific records by their IDs. Args: @@ -312,27 +440,125 @@ def get_by_ids(self, ids: List[uuid.UUID]) -> List[Record]: - embedding (Optional[List[float]]): The stored vector embedding - metadata (Optional[Dict[str, Any]]): Any associated metadata """ - results = self.gel_client.query( - query=GET_BY_IDS_QUERY.render(record_type=self.record_type), + return Query( + textwrap.dedent( + f""" + select {self.record_type} {{ + id, + text, + embedding, + metadata, + }} + filter .id in array_unpack(>$ids) + and .collection = $collection_name; + """ + ), collection_name=self.collection_name, - ids=ids, + ids=list(ids), ) - return [ - Record( - id=result.id, - text=result.text, - embedding=result.embedding and list(result.embedding), - metadata=(json.loads(result.metadata)), + + def search_by_vector(self, vector: Vector) -> SearchRecordQuery: + """Search using a pre-computed vector embedding. + + Useful when you have already computed the embedding or want to search + with a modified/combined embedding vector. + + Args: + vector (List[float]): The query embedding to search with. + Must match the dimensionality of stored embeddings. + filter_expression (str): Filter expression for metadata filtering. + limit (Optional[int]): Max number of results to return. + Defaults to 4. + + Returns: + List[SearchResult]: List of similar items, ordered by similarity. + Each result contains: + - id (uuid.UUID): The record's unique identifier + - text (Optional[str]): The original text content + - embedding (List[float]): The stored vector embedding + - metadata (Optional[Dict[str, Any]]): Any associated metadata + - cosine_similarity (float): Similarity score + (higher is more similar) + """ + return SearchRecordQuery( + self.record_type, self.collection_name, vector + ) + + async def _update_record(self, id: uuid.UUID, **kwargs) -> Query: + conditions = [] + params = {"id": id, "collection_name": self.collection_name} + if "text" in kwargs: + text = kwargs.pop("text") + conditions.append("text := $text") + params["text"] = text + if "embedding" not in kwargs: + kwargs["embedding"] = await self._generate_vector_from_text( + text + ) + if "embedding" in kwargs: + conditions.append("embedding := $embedding") + params["embedding"] = kwargs.pop("embedding") + if "metadata" in kwargs: + conditions.append("metadata := $metadata") + params["metadata"] = json.dumps(kwargs.pop("metadata")) + if not conditions: + raise errors.InterfaceError("No fields specified for update.") + if kwargs: + raise errors.InterfaceError( + f"Unexpected fields for update: {', '.join(kwargs.keys())}" ) - for result in results - ] + return Query( + textwrap.dedent( + f""" + update {quote.quote_ident(self.record_type)} + filter .id = $id + and .collection = $collection_name + set {{ + {", ".join(conditions)} + }}; + """ + ), + **params, + ) - def search_by_item( - self, - item: Any, - filters: Optional[CompositeFilter] = None, - limit: Optional[int] = 4, - ) -> List[SearchResult]: + async def _generate_vector_from_text(self, text: str) -> Vector: + raise NotImplementedError() + + +V = TypeVar("V") + + +def _iter_coroutine(coro: Coroutine[Any, Any, V]) -> V: + try: + coro.send(None) + except StopIteration as ex: + return ex.value + finally: + coro.close() + + +@dataclasses.dataclass +class Vectorstore(BaseVectorstore, Generic[T]): + embedding_model: Optional[EmbeddingModel] = None + + async def _generate_vector_from_text(self, text: str) -> Vector: + if self.embedding_model is None: + raise errors.InterfaceError( + "No embedding model provided to generate vector for text." + ) + + return self.embedding_model.generate_text(text) + + def add_text(self, text: str, **metadata) -> AddRecord: + return AddRecord( + record_type=self.record_type, + collection_name=self.collection_name, + embedding=_iter_coroutine(self._generate_vector_from_text(text)), + text=text, + metadata=metadata or None, + ) + + def search_by_item(self, item: T) -> SearchRecordQuery: """Search for similar items in the vector store. This method: @@ -344,9 +570,6 @@ def search_by_item( Args: item (Any): The query item to find similar matches for. Must be compatible with the embedding model's target_type. - filters (Optional[CompositeFilter]): Metadata-based filters to use. - limit (Optional[int]): Max number of results to return. - Defaults to 4. Returns: List[SearchResult]: List of similar items, ordered by similarity. @@ -358,31 +581,142 @@ def search_by_item( - cosine_similarity (float): Similarity score (higher is more similar) """ - vector = self.embedding_model(item) - filter_expression = ( - f"filter {get_filter_clause(filters)}" if filters else "" - ) - return self.search_by_vector( - vector=vector, filter_expression=filter_expression, limit=limit + if self.embedding_model is None: + raise errors.InterfaceError( + "No embedding model provided to generate vector." + ) + + return SearchRecordQuery( + self.record_type, + self.collection_name, + self.embedding_model.generate(item), ) - def search_by_vector( + @overload + def update_record(self, id: uuid.UUID, *, embedding: Vector) -> Query: ... + + @overload + def update_record(self, id: uuid.UUID, *, text: str) -> Query: ... + + @overload + def update_record( + self, id: uuid.UUID, *, metadata: Optional[Dict[str, Any]] + ) -> Query: ... + + @overload + def update_record( + self, id: uuid.UUID, *, text: str, embedding: Vector + ) -> Query: ... + + @overload + def update_record( + self, id: uuid.UUID, *, text: str, metadata: Optional[Dict[str, Any]] + ) -> Query: ... + + @overload + def update_record( self, - vector: List[float], - filter_expression: str = "", - limit: Optional[int] = 4, - ) -> List[SearchResult]: - """Search using a pre-computed vector embedding. + id: uuid.UUID, + *, + embedding: Vector, + metadata: Optional[Dict[str, Any]], + ) -> Query: ... + + @overload + def update_record( + self, + id: uuid.UUID, + *, + text: str, + embedding: Vector, + metadata: Optional[Dict[str, Any]], + ) -> Query: ... + + def update_record(self, id: uuid.UUID, **kwargs) -> Query: + """Update an existing record in the vector store. - Useful when you have already computed the embedding or want to search - with a modified/combined embedding vector. + Only specified fields will be updated. If text is provided but not + embedding, a new embedding will be automatically generated. Args: - vector (List[float]): The query embedding to search with. - Must match the dimensionality of stored embeddings. - filter_expression (str): Filter expression for metadata filtering. - limit (Optional[int]): Max number of results to return. - Defaults to 4. + Record: + - id (uuid.UUID): The ID of the record to update + - text (Optional[str]): New text content. If provided without + embedding, a new embedding will be generated. + - embedding (Optional[List[float]]): New vector embedding. + - metadata (Optional[Dict[str, Any]]): New metadata to store + with the record. Completely replaces existing metadata. + Returns: + Optional[IdRecord]: The updated record's ID if found and updated, + None if no record was found with the given ID. + Raises: + ValueError: If no fields are specified for update. + """ + return _iter_coroutine(self._update_record(id, **kwargs)) + + +class AsyncEmbeddingModel(BaseEmbeddingModel, Generic[T]): + """ + Abstract base class for embedding models. + + Any embedding model used with `Vectorstore` must implement this + interface. The model is expected to convert input data (text, images, etc.) + into a numerical vector representation. + """ + + def store( + self, + collection_name: str = "default", + record_type: str = "ext::vectorstore::DefaultRecord", + ) -> AsyncVectorstore: + return AsyncVectorstore( + embedding_model=self, + collection_name=collection_name, + record_type=record_type, + ) + + @abc.abstractmethod + async def generate(self, item: T) -> Vector: + """ + Convert an input item into a list of floating-point values (vector + embedding). Must be implemented in subclasses. + """ + ... + + @abc.abstractmethod + async def generate_text(self, text: str) -> Vector: ... + + +@dataclasses.dataclass +class AsyncVectorstore(BaseVectorstore, Generic[T]): + embedding_model: Optional[AsyncEmbeddingModel] = None + + async def add_text(self, text: str, **metadata) -> AddRecord: + if self.embedding_model is None: + raise errors.InterfaceError( + "No embedding model provided to generate vector for text." + ) + + return AddRecord( + record_type=self.record_type, + collection_name=self.collection_name, + embedding=await self.embedding_model.generate_text(text), + text=text, + metadata=metadata or None, + ) + + async def search_by_item(self, item: str) -> SearchRecordQuery: + """Search for similar items in the vector store. + + This method: + 1. Generates an embedding for the input item + 2. Finds records with similar embeddings + 3. Optionally filters results based on metadata + 4. Returns the most similar items up to the specified limit + + Args: + item (Any): The query item to find similar matches for. + Must be compatible with the embedding model's target_type. Returns: List[SearchResult]: List of similar items, ordered by similarity. @@ -394,27 +728,60 @@ def search_by_vector( - cosine_similarity (float): Similarity score (higher is more similar) """ - results = self.gel_client.query( - query=SEARCH_QUERY.render( - record_type=self.record_type, - filter_expression=filter_expression, - ), - collection_name=self.collection_name, - query_embedding=vector, - limit=limit, - ) - return [ - SearchResult( - id=result.id, - text=result.text, - embedding=list(result.embedding) if result.embedding else None, - metadata=json.loads(result.metadata), - cosine_similarity=result.cosine_similarity, + if self.embedding_model is None: + raise errors.InterfaceError( + "No embedding model provided to generate vector." ) - for result in results - ] - def update_record(self, record: Record) -> Optional[uuid.UUID]: + return SearchRecordQuery( + self.record_type, + self.collection_name, + await self.embedding_model.generate(item), + ) + + @overload + async def update_record( + self, id: uuid.UUID, *, embedding: Vector + ) -> Query: ... + + @overload + async def update_record(self, id: uuid.UUID, *, text: str) -> Query: ... + + @overload + async def update_record( + self, id: uuid.UUID, *, metadata: Optional[Dict[str, Any]] + ) -> Query: ... + + @overload + async def update_record( + self, id: uuid.UUID, *, text: str, embedding: Vector + ) -> Query: ... + + @overload + async def update_record( + self, id: uuid.UUID, *, text: str, metadata: Optional[Dict[str, Any]] + ) -> Query: ... + + @overload + async def update_record( + self, + id: uuid.UUID, + *, + embedding: Vector, + metadata: Optional[Dict[str, Any]], + ) -> Query: ... + + @overload + async def update_record( + self, + id: uuid.UUID, + *, + text: str, + embedding: Vector, + metadata: Optional[Dict[str, Any]], + ) -> Query: ... + + async def update_record(self, id: uuid.UUID, **kwargs) -> Query: """Update an existing record in the vector store. Only specified fields will be updated. If text is provided but not @@ -429,34 +796,9 @@ def update_record(self, record: Record) -> Optional[uuid.UUID]: - metadata (Optional[Dict[str, Any]]): New metadata to store with the record. Completely replaces existing metadata. Returns: - Optional[uuid.UUID]: The updated record's ID if found and updated, + Optional[IdRecord]: The updated record's ID if found and updated, None if no record was found with the given ID. Raises: ValueError: If no fields are specified for update. """ - if not any( - record.is_field_set(field) - for field in ["text", "embedding", "metadata"] - ): - raise ValueError("No fields specified for update.") - - updates = { - field - for field in ["text", "embedding", "metadata"] - if record.is_field_set(field) - } - - if "text" in updates and record.text is not None and "embedding" not in updates: - updates.add("embedding") - record.embedding = self.embedding_model(record.text) - - result = self.gel_client.query_single( - query=UPDATE_QUERY.render(record_type=self.record_type), - collection_name=self.collection_name, - id=record.id, - updates=list(updates), - text=record.text, - embedding=record.embedding, - metadata=json.dumps(record.metadata or {}), - ) - return result.id if result else None + return await self._update_record(id, **kwargs) diff --git a/gel/quote.py b/gel/quote.py new file mode 100644 index 00000000..cb132da8 --- /dev/null +++ b/gel/quote.py @@ -0,0 +1,176 @@ +# +# This source file is part of the EdgeDB open source project. +# +# Copyright 2025-present MagicStack Inc. and the EdgeDB authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import re + + +_re_ident = re.compile( + r"""(?x) + [^\W\d]\w* # alphanumeric identifier +""" +) + +_re_ident_or_num = re.compile( + r"""(?x) + [^\W\d]\w* # alphanumeric identifier + | + ([1-9]\d* | 0) # purely integer identifier +""" +) + +_reserved_keyword = { + "like", + "do", + "listen", + "grant", + "anytype", + "ilike", + "single", + "__edgedbsys__", + "end", + "set", + "never", + "typeof", + "start", + "configure", + "rollback", + "when", + "__edgedbtpl__", + "or", + "__source__", + "filter", + "global", + "case", + "introspect", + "__default__", + "not", + "begin", + "over", + "if", + "lock", + "refresh", + "else", + "alter", + "notify", + "distinct", + "and", + "module", + "offset", + "drop", + "is", + "discard", + "anyobject", + "import", + "group", + "__subject__", + "limit", + "match", + "anyarray", + "insert", + "get", + "administer", + "delete", + "__old__", + "exists", + "true", + "select", + "analyze", + "by", + "move", + "load", + "deallocate", + "partition", + "with", + "window", + "in", + "false", + "raise", + "revoke", + "anytuple", + "commit", + "update", + "for", + "describe", + "variadic", + "fetch", + "__specified__", + "optional", + "explain", + "__new__", + "create", + "prepare", + "check", + "extending", + "detached", + "on", +} + + +def escape_string(s: str) -> str: + # characters escaped according to + # https://www.edgedb.com/docs/reference/edgeql/lexical#strings + result = s + + # escape backslash first + result = result.replace("\\", "\\\\") + + result = result.replace("'", "\\'") + result = result.replace("\b", "\\b") + result = result.replace("\f", "\\f") + result = result.replace("\n", "\\n") + result = result.replace("\r", "\\r") + result = result.replace("\t", "\\t") + + return result + + +def quote_literal(string: str) -> str: + return "'" + escape_string(string) + "'" + + +def needs_quoting(string: str, allow_reserved: bool, allow_num: bool) -> bool: + if not string or string.startswith("@") or "::" in string: + # some strings are illegal as identifiers and as such don't + # require quoting + return False + + r = _re_ident_or_num if allow_num else _re_ident + isalnum = r.fullmatch(string) + + string = string.lower() + + is_reserved = string in _reserved_keyword + + return not isalnum or (not allow_reserved and is_reserved) + + +def _quote_ident(string: str) -> str: + return "`" + string.replace("`", "``") + "`" + + +def quote_ident( + string: str, + *, + force: bool = False, + allow_reserved: bool = False, + allow_num: bool = False, +) -> str: + if force or needs_quoting(string, allow_reserved, allow_num): + return _quote_ident(string) + else: + return string