diff --git a/gel/_testbase.py b/gel/_testbase.py index b581ac90..610023d5 100644 --- a/gel/_testbase.py +++ b/gel/_testbase.py @@ -178,12 +178,12 @@ def _start_cluster(*, cleanup_atexit=True): if line.startswith(b'READY='): break else: - raise RuntimeError('not ready') + raise RuntimeError("not ready") break except Exception: time.sleep(1) else: - raise RuntimeError('server status file not found') + raise RuntimeError("server status file not found") data = json.loads(line.split(b'READY=')[1]) con_args = dict(host='localhost', port=data['port']) @@ -203,8 +203,7 @@ def _start_cluster(*, cleanup_atexit=True): ] ) else: - con_args['tls_ca_file'] = data['tls_cert_file'] - + con_args["tls_ca_file"] = data["tls_cert_file"] client = gel.create_client(password='test', **con_args) client.ensure_connected() client.execute(""" diff --git a/gel/ai/__init__.py b/gel/ai/__init__.py index c61fd7e5..619495b0 100644 --- a/gel/ai/__init__.py +++ b/gel/ai/__init__.py @@ -1,7 +1,7 @@ # -# This source file is part of the EdgeDB open source project. +# This source file is part of the Gel open source project. # -# Copyright 2024-present MagicStack Inc. and the EdgeDB authors. +# Copyright 2024-present MagicStack Inc. and the Gel authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -19,6 +19,21 @@ from .types import RAGOptions, ChatParticipantRole, Prompt, QueryContext from .core import create_rag_client, RAGClient from .core import create_async_rag_client, AsyncRAGClient +from .vectorstore import ( + create_vstore, + create_async_vstore, + Record, + Vector, + SearchResult, + EmbeddingModel, + AsyncEmbeddingModel, +) +from .metadata_filter import ( + MetadataFilter, + CompositeFilter, + FilterOperator, + FilterCondition, +) __all__ = [ "RAGOptions", @@ -29,4 +44,15 @@ "RAGClient", "create_async_rag_client", "AsyncRAGClient", + "create_vstore", + "create_async_vstore", + "Record", + "Vector", + "SearchResult", + "EmbeddingModel", + "AsyncEmbeddingModel", + "MetadataFilter", + "CompositeFilter", + "FilterOperator", + "FilterCondition", ] diff --git a/gel/ai/metadata_filter.py b/gel/ai/metadata_filter.py new file mode 100644 index 00000000..17142ad7 --- /dev/null +++ b/gel/ai/metadata_filter.py @@ -0,0 +1,134 @@ +from __future__ import annotations +from dataclasses import dataclass +from typing import Union, List +from enum import Enum + + +class FilterOperator(str, Enum): + EQ = "=" + NE = "!=" + GT = ">" + LT = "<" + GTE = ">=" + LTE = "<=" + IN = "in" + NOT_IN = "not in" + LIKE = "like" + ILIKE = "ilike" + ANY = "any" + ALL = "all" + CONTAINS = "contains" + EXISTS = "exists" + + +class FilterCondition(str, Enum): + AND = "and" + OR = "or" + + +@dataclass +class MetadataFilter: + """Represents a single metadata filter condition.""" + + key: str + value: Union[int, float, str] + operator: FilterOperator = FilterOperator.EQ + + def __repr__(self): + value = f'"{self.value}"' if isinstance(self.value, str) else self.value + return ( + f'MetadataFilter(key="{self.key}", ' + f"value={value}, " + f'operator="{self.operator.value}")' + ) + + +@dataclass +class CompositeFilter: + """ + Allows grouping multiple MetadataFilter instances using AND/OR. + """ + + filters: List[Union[CompositeFilter, MetadataFilter]] + condition: FilterCondition = FilterCondition.AND + + def __repr__(self): + return ( + f'CompositeFilter(condition="{self.condition.value}", ' + f"filters={self.filters})" + ) + + +def get_filter_clause(filters: CompositeFilter) -> str: + """ + Get the filter clause for a given CompositeFilter. + """ + + subclauses = [] + for filter in filters.filters: + subclause = "" + + if isinstance(filter, CompositeFilter): + subclause = get_filter_clause(filter) + elif isinstance(filter, MetadataFilter): + formatted_value = ( + f'"{filter.value}"' + if isinstance(filter.value, str) + else filter.value + ) + + # Simple comparison operators + if filter.operator in { + FilterOperator.EQ, + FilterOperator.NE, + FilterOperator.GT, + FilterOperator.GTE, + FilterOperator.LT, + FilterOperator.LTE, + FilterOperator.LIKE, + FilterOperator.ILIKE, + }: + subclause = ( + f'json_get(.metadata, "{filter.key}") ' + f"{filter.operator.value} {formatted_value}" + ) + # casting should be fixed + elif filter.operator in {FilterOperator.IN, FilterOperator.NOT_IN}: + subclause = ( + f"{formatted_value} " + f"{filter.operator.value} " + f'json_get(.metadata, "{filter.key}")' + ) + # casting should be fixed + # works only for equality, should be updated to support the rest, + # example: select all({1, 2, 3, 4} < 4); + elif filter.operator in {FilterOperator.ANY, FilterOperator.ALL}: + subclause = ( + f"{filter.operator.value} (" + f'json_get(.metadata, "{filter.key}")' + f" = {formatted_value})" + ) + + elif filter.operator == FilterOperator.EXISTS: + subclause = f'exists json_get(.metadata, "{filter.key}")' + + # casting should be fixed + # edgeql contains supports different types like range etc which + # we don't support here + elif filter.operator == FilterOperator.CONTAINS: + subclause = ( + f'contains (json_get(.metadata, "{filter.key}"), ' + f"{formatted_value})" + ) + else: + raise ValueError(f"Unknown operator: {filter.operator}") + + subclauses.append(subclause) + + if filters.condition in {FilterCondition.AND, FilterCondition.OR}: + filter_clause = f" {filters.condition.value} ".join(subclauses) + return ( + "(" + filter_clause + ")" if len(subclauses) > 1 else filter_clause + ) + else: + raise ValueError(f"Unknown condition: {filters.condition}") diff --git a/gel/ai/vectorstore.py b/gel/ai/vectorstore.py new file mode 100644 index 00000000..db230ecf --- /dev/null +++ b/gel/ai/vectorstore.py @@ -0,0 +1,711 @@ +# +# This source file is part of the EdgeDB open source project. +# +# Copyright 2025-present MagicStack Inc. and the EdgeDB authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Extension VectorStore Binding +# ---------------------------- +# +# `VectorStore` is designed to integrate with vector databases following +# LangChain-LlamaIndex conventions. It enables interaction with embedding models +# (both within and outside of Gel) through a simple interface. +# +# This binding does not assume a specific data type, allowing it to support +# text, images, or any other embeddings. For example, CLIP can be wrapped into +# this interface to generate and store image embeddings. + +from __future__ import annotations +from typing import ( + Optional, + TypeVar, + Generic, + Union, + Any, + List, + Dict, + TYPE_CHECKING, +) + +import gel +import json +import uuid +import abc +import array +import textwrap +import dataclasses + +from gel import quote +from .metadata_filter import ( + get_filter_clause, + CompositeFilter, +) + + +if TYPE_CHECKING: + try: + import numpy as np + import numpy.typing as npt + + Vector = Union[List[float], array.array[float], npt.NDArray[np.float32]] + except ImportError: + Vector = Union[List[float], array.array[float]] + + +def create_vstore(client: gel.Client, **kwargs) -> VectorStore: + """Create a new vector store instance.""" + + client.ensure_connected() + return VectorStore(client, **kwargs) + + +async def create_async_vstore( + client: gel.AsyncIOClient, **kwargs +) -> AsyncVectorStore: + """Create a new async vector store instance.""" + + await client.ensure_connected() + return AsyncVectorStore(client, **kwargs) + + +BATCH_ADD_QUERY = textwrap.dedent( + """ + with items := json_array_unpack($items) + for item in items union ( + insert {record_type} {{ + collection := $collection_name, + text := item['text'], + embedding := >item['embedding'], + metadata := to_json(item['metadata']) + }} + ) + """ +) + + +DELETE_BY_IDS_QUERY = textwrap.dedent( + """ + delete {record_type} + filter .id in array_unpack(>$ids) + and .collection = $collection_name; + """ +) + + +SEARCH_QUERY = """ + with collection_records := ( + select {record_type} + filter .collection = $collection_name + and exists(.embedding) + ) + select collection_records {{ + id, + text, + embedding, + metadata, + cosine_similarity := 1 - ext::pgvector::cosine_distance( + .embedding, $query_embedding), + }} + {filter_expression} + order by .cosine_similarity desc empty last + limit $limit; + """ + + +GET_BY_IDS_QUERY = """ + select {record_type} {{ + id, + text, + embedding, + metadata, + }} + filter .id in array_unpack(>$ids) + and .collection = $collection_name; + """.strip() + + +UPDATE_QUERY = """ + with updates := array_unpack(>$updates) + update {record_type} + filter .id = $id and .collection = $collection_name + set {{ + text := $text if 'text' in updates else .text, + embedding := $embedding + if 'embedding' in updates + else .embedding, + metadata := to_json($metadata) + if 'metadata' in updates + else .metadata, + }}; + """.strip() + + +@dataclasses.dataclass +class Vector: + """A vector (embeddings) along with its text and metadata. + If id is None, it is considered a new record to be inserted. Use + this when you have pre-calculated embeddings. + """ + + id: Optional[uuid.UUID] = None + embedding: Optional[Vector] = None + text: Optional[str] = None + metadata: Optional[Dict[str, Any]] = None + + +@dataclasses.dataclass +class Record: + """A record to be inserted into the vector store, where its + embedding will be automatically generated in the vectorstore. + Use this when you expect the vectorstore to generate the + embeddings using the embedding model you provided. + """ + + text: str + metadata: Optional[Dict[str, Any]] = None + + def to_vector(self, embedding_model: EmbeddingModel) -> Vector: + """Convert this item to a Record using the provided embedding model.""" + + return Vector( + text=self.text, + embedding=embedding_model.generate(self.text), + metadata=self.metadata, + ) + + async def ato_vector(self, embedding_model: AsyncEmbeddingModel) -> Vector: + """ + Convert this item to a Vector using the provided async embedding model. + """ + + return Vector( + text=self.text, + embedding=await embedding_model.generate(self.text), + metadata=self.metadata, + ) + + +@dataclasses.dataclass +class SearchResult: + """A search result from the vector store.""" + + id: uuid.UUID + text: Optional[str] = None + embedding: Optional[Vector] = None + metadata: Optional[Dict[str, Any]] = None + cosine_similarity: float = 0.0 + + +def _serialize_metadata(metadata: Optional[Dict[str, Any]]) -> Optional[str]: + """Helper to serialize metadata to JSON string.""" + + return json.dumps(metadata) if metadata else None + + +def _deserialize_metadata(metadata: Optional[str]) -> Optional[Dict[str, Any]]: + """Helper to deserialize metadata from JSON string.""" + + return json.loads(metadata) if metadata else None + + +_sentinel = object() + +T = TypeVar("T") + + +class EmbeddingModel(abc.ABC, Generic[T]): + """ + Embedding model. Any embedding model used with `VectorStore` + must implement this interface. The model is expected to convert input data + (text, images, etc.) into a numerical vector representation. + """ + + @abc.abstractmethod + def generate(self, item: T) -> Vector: + """Convert an input item into a list of floating-point values.""" + ... + + +class BaseVectorStore(Generic[T]): + """ + Base class for vector store. + T: The type of items that can be embedded (e.g., str for text...) + """ + + def __init__( + self, + client: Union[gel.Client, gel.AsyncIOClient], + collection_name: str = "default", + record_type: str = "ext::vectorstore::DefaultRecord", + ): + """Initialize a new vector store instance. + + Args: + client (Union[gel.Client, gel.AsyncIOClient]): Gel client instance, + collection_name (str): Collection name, + record_type (str): Schema type (table name) for storing records, + embedding_model (EmbeddingModel[T]): Embedding model. + """ + + self.client = client + self.collection_name = collection_name + self.record_type = record_type + + +class VectorStore(BaseVectorStore[T]): + """ + A framework-agnostic interface for interacting with Gel's ext::vectorstore. + + This class provides methods for storing, retrieving, and searching + vector embeddings. It follows vector database conventions and supports + different embedding models. + """ + + def __init__( + self, + client: gel.Client, + embedding_model: Optional[EmbeddingModel] = None, + collection_name: str = "default", + record_type: str = "ext::vectorstore::DefaultRecord", + ): + super().__init__( + client=client, + collection_name=collection_name, + record_type=record_type, + ) + self.embedding_model = embedding_model + + def add_records(self, *records: Record) -> List[uuid.UUID]: + """ + Add multiple items to the vector store in a single transaction. + Embeddinsg will be generated and stored for all items. + + Args: + *records (Record): Records to insert. Each contains: + - text (str): The text content to be embedded, + - metadata (Optional[Dict[str, Any]]): Additional data to store. + + Returns: + List[uuid.UUID]: List of database record IDs for the inserted items. + """ + + if not self.embedding_model: + raise ValueError("Embedding model is not set.") + + vectors = [record.to_vector(self.embedding_model) for record in records] + return self.add_vectors(*vectors) + + def add_vectors(self, *vectors: Vector) -> List[uuid.UUID]: + """Add pre-computed vector embeddings to the store. + + Use this method when you have already generated embeddings and want to + store them directly without re-computing them. + + Args: + *vectors (Vector): Vectors to insert. Each contains: + - embedding (Optional[Vector]): Pre-computed embeddings, + - text (Optional[str]): Original text content, + - metadata (Optional[Dict[str, Any]]): Additional data to store. + + Returns: + List[uuid.UUID]: List of database record IDs for the inserted items. + """ + + results = self.client.query( + query=BATCH_ADD_QUERY.format( + record_type=quote.quote_ident(self.record_type) + ), + collection_name=self.collection_name, + items=json.dumps( + [ + { + "text": vector.text, + "embedding": vector.embedding, + "metadata": _serialize_metadata(vector.metadata), + } + for vector in vectors + ] + ), + ) + return [result.id for result in results] + + def delete(self, *ids: uuid.UUID) -> List[uuid.UUID]: + """Delete records from the vector store by their IDs. + + Args: + *ids (uuid.UUID): Ids of records to delete. + + Returns: + List[uuid.UUID]: List of deleted record IDs. + """ + + results = self.client.query( + query=DELETE_BY_IDS_QUERY.format( + record_type=quote.quote_ident(self.record_type) + ), + collection_name=self.collection_name, + ids=ids, + ) + return [result.id for result in results] + + def get_by_ids(self, *ids: uuid.UUID) -> List[Vector]: + """Retrieve specific vectors by their IDs. + + Args: + *ids (uuid.UUID): IDs of vectors to retrieve. + + Returns: + List[Vector]: List of retrieved vectors. Each contains: + - id (uuid.UUID): Vector's unique identifier, + - text (Optional[str]): Text content, + - embedding (Optional[Vector]): Stored vector embedding, + - metadata (Optional[Dict[str, Any]]): Any associated metadata. + """ + + results = self.client.query( + query=GET_BY_IDS_QUERY.format( + record_type=quote.quote_ident(self.record_type) + ), + collection_name=self.collection_name, + ids=ids, + ) + + return [ + Vector( + id=result.id, + text=result.text, + embedding=result.embedding, + metadata=_deserialize_metadata(result.metadata), + ) + for result in results + ] + + def search_by_record( + self, + item: T, + filters: Optional[CompositeFilter] = None, + limit: Optional[int] = 4, + ) -> List[SearchResult]: + """Search for similar vectors in the vector store. + + This method: + 1. Generates an embedding for the input item, + 2. Finds items with similar embeddings, + 3. Optionally filters results based on metadata, + 4. Returns the most similar items up to the specified limit. + + Args: + item (T): The query item to find similar matches for. + Must be compatible with the embedding model's type. + filters (Optional[CompositeFilter]): Metadata-based filters to use. + limit (Optional[int]): Max number of results to return. + Defaults to 4. + + Returns: + List[SearchResult]: List of similar items, ordered by similarity. + Each contains: + - id (uuid.UUID): Item's unique identifier, + - text (Optional[str]): Text content, + - embedding (Vector): Stored vector embedding, + - metadata (Optional[Dict[str, Any]]): Any associated metadata, + - cosine_similarity (float): Similarity score. + """ + + vector = self.embedding_model.generate(item) + filter_expression = ( + f"filter {get_filter_clause(filters)}" if filters else "" + ) + return self.search_by_vector( + vector=vector, filter_expression=filter_expression, limit=limit + ) + + def search_by_vector( + self, + vector: Vector, + filter_expression: str = "", + limit: Optional[int] = 4, + ) -> List[SearchResult]: + """Search using a pre-computed vector embedding. + + Useful when you have already computed the embedding or want to search + with a modified/combined embedding vector. + + Args: + vector (Vector): The query embedding to search with. + Must match the dimensionality of stored embeddings. + filter_expression (str): Filter expression for metadata filtering. + limit (Optional[int]): Max num of results to return. Defaults to 4. + + Returns: + List[SearchResult]: List of similar items, ordered by similarity. + Each contains: + - id (uuid.UUID): Item's unique identifier, + - text (Optional[str]): Text content, + - embedding (Vector): Stored vector embedding, + - metadata (Optional[Dict[str, Any]]): Any associated metadata, + - cosine_similarity (float): Similarity score. + """ + + results = self.client.query( + query=SEARCH_QUERY.format( + record_type=quote.quote_ident(self.record_type), + filter_expression=filter_expression, + ), + collection_name=self.collection_name, + query_embedding=vector, + limit=limit, + ) + return [ + SearchResult( + id=result.id, + text=result.text, + embedding=list(result.embedding) if result.embedding else None, + metadata=_deserialize_metadata(result.metadata), + cosine_similarity=result.cosine_similarity, + ) + for result in results + ] + + def update_vector( + self, + id: uuid.UUID, + text: Union[str, None, object] = _sentinel, + embedding: Union[Vector, None, object] = _sentinel, + metadata: Union[Dict[str, Any], None, object] = _sentinel, + ) -> Optional[uuid.UUID]: + """Update an existing item in the vector store. + + Only specified fields will be updated. If text is provided + but not embedding, a new embedding will be automatically + generated using the embedding model you provided. + + Args: + - id (uuid.UUID): ID of the item to update, + - text (Optional[str]): New text content. If provided without + embedding, a new embedding will be generated. + - embedding (Optional[Vector]): New vector embedding. + - metadata (Optional[Dict[str, Any]]): New metadata to store. + Completely replaces existing metadata. + Returns: + Optional[uuid.UUID]: Updated record's ID if found and updated. + Raises: + ValueError: If no fields are specified for update. + """ + + updates = [] + + if text is not _sentinel: + updates.append("text") + if embedding is not _sentinel: + updates.append("embedding") + if metadata is not _sentinel: + updates.append("metadata") + + if not updates: + raise ValueError("No fields specified for update.") + + if ( + "text" in updates + and text is not None + and "embedding" not in updates + ): + updates.append("embedding") + embedding = self.embedding_model.generate(text) + + result = self.client.query_single( + query=UPDATE_QUERY.format( + record_type=quote.quote_ident(self.record_type) + ), + collection_name=self.collection_name, + id=id, + updates=list(updates), + text=text if text is not _sentinel else None, + embedding=embedding if embedding is not _sentinel else None, + metadata=( + _serialize_metadata(metadata) + if metadata is not _sentinel + else None + ), + ) + return result.id if result else None + + +class AsyncEmbeddingModel(abc.ABC, Generic[T]): + """ + Async embedding model. Any async embedding model used with `VectorStore` + must implement this interface. The model is expected to convert input data + (text, images, etc.) into a numerical vector representation. + """ + + @abc.abstractmethod + async def generate(self, item: T) -> Vector: ... + + +class AsyncVectorStore(BaseVectorStore[T]): + """ + A framework-agnostic interface for interacting with Gel's ext::vectorstore. + + This class provides methods for storing, retrieving, and searching + vector embeddings. It follows vector database conventions and supports + different embedding models. + """ + + def __init__( + self, + embedding_model: Optional[AsyncEmbeddingModel] = None, + **kwargs, + ): + super().__init__(**kwargs) + self.embedding_model = embedding_model + + async def add_records(self, *records: Record) -> List[uuid.UUID]: + if not self.embedding_model: + raise ValueError("Embedding model is not set") + + vectors = [ + record.ato_vector(self.embedding_model) for record in records + ] + return self.add_vectors(*vectors) + + async def add_vectors(self, *vectors: Vector) -> List[uuid.UUID]: + results = await self.client.query( + query=BATCH_ADD_QUERY.format( + record_type=quote.quote_ident(self.record_type) + ), + collection_name=self.collection_name, + items=json.dumps( + [ + { + "text": vector.text, + "embedding": vector.embedding, + "metadata": _serialize_metadata(vector.metadata), + } + for vector in vectors + ] + ), + ) + return [result.id for result in results] + + async def delete(self, *ids: uuid.UUID) -> List[uuid.UUID]: + results = await self.client.query( + query=DELETE_BY_IDS_QUERY.format( + record_type=quote.quote_ident(self.record_type) + ), + collection_name=self.collection_name, + ids=ids, + ) + return [result.id for result in results] + + async def get_by_ids(self, *ids: uuid.UUID) -> List[Record]: + results = await self.client.query( + query=GET_BY_IDS_QUERY.format( + record_type=quote.quote_ident(self.record_type) + ), + collection_name=self.collection_name, + ids=ids, + ) + + return [ + Record( + id=result.id, + text=result.text, + embedding=result.embedding, + metadata=_deserialize_metadata(result.metadata), + ) + for result in results + ] + + async def search_by_record( + self, + item: T, + filters: Optional[CompositeFilter] = None, + limit: Optional[int] = 4, + ) -> List[SearchResult]: + vector = await self.embedding_model.generate(item) + filter_expression = ( + f"filter {get_filter_clause(filters)}" if filters else "" + ) + return self.search_by_vector( + vector=vector, filter_expression=filter_expression, limit=limit + ) + + async def search_by_vector( + self, + vector: Vector, + filter_expression: str = "", + limit: Optional[int] = 4, + ) -> List[SearchResult]: + results = await self.client.query( + query=SEARCH_QUERY.format( + record_type=quote.quote_ident(self.record_type), + filter_expression=filter_expression, + ), + collection_name=self.collection_name, + query_embedding=vector, + limit=limit, + ) + return [ + SearchResult( + id=result.id, + text=result.text, + embedding=list(result.embedding) if result.embedding else None, + metadata=_deserialize_metadata(result.metadata), + cosine_similarity=result.cosine_similarity, + ) + for result in results + ] + + async def update_vector( + self, + id: uuid.UUID, + text: Union[str, None, object] = _sentinel, + embedding: Union[Vector, None, object] = _sentinel, + metadata: Union[Dict[str, Any], None, object] = _sentinel, + ) -> Optional[uuid.UUID]: + updates = [] + + if text is not _sentinel: + updates.append("text") + if embedding is not _sentinel: + updates.append("embedding") + if metadata is not _sentinel: + updates.append("metadata") + + if not updates: + raise ValueError("No fields specified for update.") + + if ( + "text" in updates + and text is not None + and "embedding" not in updates + ): + updates.append("embedding") + embedding = await self.embedding_model.generate(text) + + result = await self.client.query_single( + query=UPDATE_QUERY.format( + record_type=quote.quote_ident(self.record_type) + ), + collection_name=self.collection_name, + id=id, + updates=list(updates), + text=text if text is not _sentinel else None, + embedding=embedding if embedding is not _sentinel else None, + metadata=( + _serialize_metadata(metadata) + if metadata is not _sentinel + else None + ), + ) + return result.id if result else None diff --git a/gel/quote.py b/gel/quote.py new file mode 100644 index 00000000..3e7aa1ae --- /dev/null +++ b/gel/quote.py @@ -0,0 +1,175 @@ +# This source file is part of the EdgeDB open source project. +# +# Copyright 2025-present MagicStack Inc. and the EdgeDB authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import re + + +_re_ident = re.compile( + r"""(?x) + [^\W\d]\w* # alphanumeric identifier +""" +) + +_re_ident_or_num = re.compile( + r"""(?x) + [^\W\d]\w* # alphanumeric identifier + | + ([1-9]\d* | 0) # purely integer identifier +""" +) + +_reserved_keyword = { + "like", + "do", + "listen", + "grant", + "anytype", + "ilike", + "single", + "__edgedbsys__", + "end", + "set", + "never", + "typeof", + "start", + "configure", + "rollback", + "when", + "__edgedbtpl__", + "or", + "__source__", + "filter", + "global", + "case", + "introspect", + "__default__", + "not", + "begin", + "over", + "if", + "lock", + "refresh", + "else", + "alter", + "notify", + "distinct", + "and", + "module", + "offset", + "drop", + "is", + "discard", + "anyobject", + "import", + "group", + "__subject__", + "limit", + "match", + "anyarray", + "insert", + "get", + "administer", + "delete", + "__old__", + "exists", + "true", + "select", + "analyze", + "by", + "move", + "load", + "deallocate", + "partition", + "with", + "window", + "in", + "false", + "raise", + "revoke", + "anytuple", + "commit", + "update", + "for", + "describe", + "variadic", + "fetch", + "__specified__", + "optional", + "explain", + "__new__", + "create", + "prepare", + "check", + "extending", + "detached", + "on", +} + + +def escape_string(s: str) -> str: + # characters escaped according to + # https://www.edgedb.com/docs/reference/edgeql/lexical#strings + result = s + + # escape backslash first + result = result.replace("\\", "\\\\") + + result = result.replace("'", "\\'") + result = result.replace("\b", "\\b") + result = result.replace("\f", "\\f") + result = result.replace("\n", "\\n") + result = result.replace("\r", "\\r") + result = result.replace("\t", "\\t") + + return result + + +def quote_literal(string: str) -> str: + return "'" + escape_string(string) + "'" + + +def needs_quoting(string: str, allow_reserved: bool, allow_num: bool) -> bool: + if not string or string.startswith("@") or "::" in string: + # some strings are illegal as identifiers and as such don't + # require quoting + return False + + r = _re_ident_or_num if allow_num else _re_ident + isalnum = r.fullmatch(string) + + string = string.lower() + + is_reserved = string in _reserved_keyword + + return not isalnum or (not allow_reserved and is_reserved) + + +def _quote_ident(string: str) -> str: + return "`" + string.replace("`", "``") + "`" + + +def quote_ident( + string: str, + *, + force: bool = False, + allow_reserved: bool = False, + allow_num: bool = False, +) -> str: + if force or needs_quoting(string, allow_reserved, allow_num): + return _quote_ident(string) + else: + return string diff --git a/tests/ai/__init__.py b/tests/ai/__init__.py new file mode 100644 index 00000000..90213119 --- /dev/null +++ b/tests/ai/__init__.py @@ -0,0 +1,17 @@ +# +# This source file is part of the Gel open source project. +# +# Copyright 2016-present MagicStack Inc. and the Gel authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# diff --git a/tests/ai/schema/vectorstore.esdl b/tests/ai/schema/vectorstore.esdl new file mode 100644 index 00000000..e69de29b diff --git a/tests/ai/test_ai_metadata_filter.py b/tests/ai/test_ai_metadata_filter.py new file mode 100644 index 00000000..d00feeee --- /dev/null +++ b/tests/ai/test_ai_metadata_filter.py @@ -0,0 +1,61 @@ +import unittest + +from gel.ai import ( + MetadataFilter, + CompositeFilter, + FilterOperator, + FilterCondition, +) + + +class TestAICompositeFilter(unittest.TestCase): + + # Test MetadataFilter with default EQ operator + def test_metadata_EQ_filter(self): + filter_obj = MetadataFilter( + key="category", value="science", operator=FilterOperator.EQ + ) + expected_repr = ( + 'MetadataFilter(key="category", value="science", operator="=")' + ) + + self.assertEqual(repr(filter_obj), expected_repr) + self.assertEqual(filter_obj.key, "category") + self.assertEqual(filter_obj.value, "science") + self.assertEqual(filter_obj.operator, FilterOperator.EQ) + + # Test MetadataFilter with NE operator + def test_metadata_NE_filter(self): + filter_obj = MetadataFilter( + key="author", value="Alice", operator=FilterOperator.NE + ) + expected_repr = ( + 'MetadataFilter(key="author", value="Alice", operator="!=")' + ) + + self.assertEqual(repr(filter_obj), expected_repr) + + # Test CompositeFilter with AND condition + def test_metadata_filters_and_condition(self): + filters = CompositeFilter( + [ + MetadataFilter( + key="category", value="AI", operator=FilterOperator.EQ + ), + MetadataFilter( + key="views", value=1000, operator=FilterOperator.GT + ), + ], + condition=FilterCondition.AND, + ) + expected_repr = ( + f'CompositeFilter(condition="and", filters=[' + f'MetadataFilter(key="category", value="AI", operator="="), ' + f'MetadataFilter(key="views", value=1000, operator=">")])' + ) + + self.assertEqual(repr(filters), expected_repr) + self.assertEqual(len(filters.filters), 2) + self.assertEqual(filters.condition, FilterCondition.AND) + self.assertEqual(filters.filters[1].operator, FilterOperator.GT) + self.assertEqual(filters.filters[1].value, 1000) diff --git a/tests/ai/test_ai_vectorstore.py b/tests/ai/test_ai_vectorstore.py new file mode 100644 index 00000000..839d4610 --- /dev/null +++ b/tests/ai/test_ai_vectorstore.py @@ -0,0 +1,274 @@ +# +# This source file is part of the Gel open source project. +# +# Copyright 2024-present MagicStack Inc. and the Gel authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import os +import unittest +import uuid + +from gel import _testbase as tb +from gel import create_client + +from gel.ai import ( + create_vstore, + EmbeddingModel, + Vector, + Record, + SearchResult, + MetadataFilter, + CompositeFilter, + FilterOperator, +) + +# records to be reused in tests +records = [ + Record( + text="""EdgeQL is a next-generation query language designed + to match SQL in power and surpass it in terms of clarity, + brevity, and intuitiveness. It's used to query the database, + insert/update/delete data, modify/introspect the schema, + manage transactions, and more.""", + metadata={"category": "edgeql"}, + ), + Record( + text="""Gel schemas are declared using SDL (Gel's + Schema Definition Language). Your schema is defined inside + .esdl files. It's common to define your entire schema in a + single file called default.esdl, but you can split it across + multiple files if you wish.""", + metadata={"category": "schema"}, + ), + Record( + text="""Object types can contain computed properties and + links. Computed properties and links are not persisted in the + database. Instead, they are evaluated on the fly whenever + that field is queried""", + metadata={"category": "schema"}, + ), +] + + +class MockEmbeddingModel(EmbeddingModel[str]): + """Mocked embedding model returns fixed embeddings.""" + + def generate(self, item): + return [0.1] * 1536 + + +class TestAIVectorStore(tb.SyncQueryTestCase): + VECTORSTORE_VER = None + + SCHEMA = os.path.join( + os.path.dirname(__file__), "schema", "vectorstore.esdl" + ) + + SETUP = None + + @classmethod + def setUpClass(cls): + super().setUpClass() + + cls.VECTORSTORE_VER = cls.client.query_single( + """ + select assert_single(( + select sys::ExtensionPackage filter .name = 'vectorstore' + )).version + """ + ) + + if cls.VECTORSTORE_VER is None: + raise unittest.SkipTest("feature not implemented") + + cls.client.execute( + """ + create extension pgvector; + create extension vectorstore; + """ + ) + + @classmethod + def tearDownClass(cls): + try: + cls.client.execute( + """ + drop extension vectorstore; + drop extension pgvector; + """ + ) + finally: + super().tearDownClass() + + def setUp(self): + super().setUp() + + self.vectorstore = create_vstore( + client=self.client, + embedding_model=MockEmbeddingModel(), + ) + + def tearDown(self): + try: + self._clean_vectorstore() + finally: + super().tearDown() + + def _clean_vectorstore(self): + """Helper method to remove all records from the vectorstore.""" + + self.client.execute( + f""" + delete {self.vectorstore.record_type} + filter .collection = $collection_name; + """, + collection_name=self.vectorstore.collection_name, + ) + + # will be used in tests for comparing embeddings + def assertListAlmostEqual(self, first, second, places=7): + """Assert that two lists of floats are almost equal.""" + self.assertEqual(len(first), len(second)) + for a, b in zip(first, second): + self.assertAlmostEqual(a, b, places=places) + + def test_add_get_and_delete(self): + text = """Gel is an open-source database engineered to advance SQL + into a sophisticated graph data model, supporting composable + hierarchical queries and accelerated development cycles.""" + + # insert a record + ids = self.vectorstore.add_records(Record(text=text)) + vector_id = ids[0] + self.assertIsNotNone(ids) + self.assertEqual(len(ids), 1) + self.assertIsInstance(ids[0], uuid.UUID) + + # verify the inserted records + vectors = self.vectorstore.get_by_ids(vector_id) + self.assertEqual(len(vectors), 1) + self.assertEqual(vectors[0].id, vector_id) + self.assertIsInstance(vectors[0], Vector) + + # delete a vector + deleted_vectors_ids = self.vectorstore.delete(vector_id) + self.assertEqual(deleted_vectors_ids[0], vector_id) + self.assertIsInstance(deleted_vectors_ids[0], uuid.UUID) + + # verify that the vector is deleted + vectors = self.vectorstore.get_by_ids(vector_id) + self.assertEqual(len(vectors), 0) + + def test_add_multiple(self): + ids = self.vectorstore.add_records(*records) + self.assertEqual(len(ids), 3) + + def test_search_no_filters(self): + self.vectorstore.add_records(*records) + + query = "Tell me about edgeql" + results = self.vectorstore.search_by_record(item=query, limit=2) + + self.assertEqual(len(results), 2) + # since we're using a mock embedding model that returns the same vector + # for all inputs, the results will be ordered by insertion order + for result in results: + self.assertIsInstance(result, SearchResult) + self.assertIsNotNone(result.text) + self.assertIsNotNone(result.cosine_similarity) + self.assertIsNotNone(result.metadata) + + def test_search_with_filters(self): + self.vectorstore.add_records(*records) + + filters = CompositeFilter( + filters=[ + MetadataFilter( + key="category", operator=FilterOperator.EQ, value="schema" + ) + ] + ) + + query = "How do I use computed properties?" + results = self.vectorstore.search_by_record( + item=query, filters=filters, limit=3 + ) + + self.assertEqual(len(results), 2) + # verify all results are from the schema category + for result in results: + self.assertIsInstance(result, SearchResult) + self.assertEqual(result.metadata["category"], "schema") + + def test_update_vector(self): + # insert a record + ids = self.vectorstore.add_vectors(Vector(embedding=[0.1] * 1536)) + vector_id = ids[0] + + # verify the inserted record + vector = self.vectorstore.get_by_ids(vector_id)[0] + self.assertIsInstance(vector, Vector) + self.assertIsNone(vector.metadata) + self.assertIsNone(vector.text) + self.assertListAlmostEqual(vector.embedding, [0.1] * 1536) + + # update just metadata + new_metadata = {"test": "test"} + updated_id = self.vectorstore.update_vector(id=vector_id, metadata=new_metadata) + self.assertEqual(updated_id, vector_id) + + # verify the updated vector + vector = self.vectorstore.get_by_ids(vector_id)[0] + self.assertIsNone(vector.text) + self.assertEqual(vector.metadata, new_metadata) + + # update both text & embedding + new_text = "Update text content and embedding" + new_embedding = [0.0] * 1536 + self.vectorstore.update_vector( + id=vector_id, text=new_text, embedding=new_embedding + ) + + # verify the updated record + vector = self.vectorstore.get_by_ids(vector_id)[0] + self.assertEqual(vector.metadata, new_metadata) + self.assertEqual(vector.text, new_text) + self.assertListAlmostEqual(vector.embedding, new_embedding) + + # update just text: embedding should be auto-generated + new_text = "Update just text content" + self.vectorstore.update_vector(id=vector_id, text=new_text) + + # verify the update + vector = self.vectorstore.get_by_ids(vector_id)[0] + self.assertEqual(vector.text, new_text) + self.assertEqual(vector.metadata, new_metadata) + + # remove text and metadata + self.vectorstore.update_vector(id=vector_id, text=None, metadata=None) + + # verify the update + vector = self.vectorstore.get_by_ids(vector_id)[0] + self.assertIsNone(vector.text) + self.assertIsNone(vector.metadata) + + def test_update_nonexistent_vector(self): + fake_id = uuid.uuid4() + updated = self.vectorstore.update_vector(id=fake_id, text="This shouldn't work") + self.assertIsNone(updated) + + def test_update_no_fields_specified(self): + with self.assertRaises(ValueError): + self.vectorstore.update_vector(id=uuid.uuid4()) diff --git a/tests/dbsetup/vectorstore.edgeql b/tests/dbsetup/vectorstore.edgeql new file mode 100644 index 00000000..280232b3 --- /dev/null +++ b/tests/dbsetup/vectorstore.edgeql @@ -0,0 +1,11 @@ +for i in range_unpack(range(1, 8)) +union ( + insert ext::vectorstore::DefaultRecord { + collection := "test", + external_id := "00000000-0000-0000-0000-00000000000" ++ i, + text := "some text", + embedding := (array_fill(1, i * 192) ++ array_fill(0, (8 - i) * 192)), + metadata := to_json('{ "str_field": ' ++ ('"least_similar"' if i <= 4 else '"most_similar"') ++ '}'), + } +); +