diff --git a/examples/wait_for_knowledge_base.py b/examples/wait_for_knowledge_base.py new file mode 100644 index 00000000..739ff80e --- /dev/null +++ b/examples/wait_for_knowledge_base.py @@ -0,0 +1,64 @@ +""" +Example demonstrating how to use the wait_for_database helper function. + +This example shows how to: +1. Create a knowledge base +2. Wait for its database to be ready +3. Handle errors and timeouts appropriately +""" + +import os + +from gradient import Gradient +from gradient.resources.knowledge_bases import KnowledgeBaseTimeoutError, KnowledgeBaseDatabaseError + + +def main() -> None: + """Create a knowledge base and wait for its database to be ready.""" + # Initialize the Gradient client + # Note: DIGITALOCEAN_ACCESS_TOKEN must be set in your environment + client = Gradient( + access_token=os.environ.get("DIGITALOCEAN_ACCESS_TOKEN"), + ) + + # Create a knowledge base + # Replace these values with your actual configuration + kb_response = client.knowledge_bases.create( + name="My Knowledge Base", + region="nyc1", # Choose your preferred region + embedding_model_uuid="your-embedding-model-uuid", # Use your embedding model UUID + ) + + if not kb_response.knowledge_base or not kb_response.knowledge_base.uuid: + print("Failed to create knowledge base") + return + + kb_uuid = kb_response.knowledge_base.uuid + print(f"Created knowledge base: {kb_uuid}") + + try: + # Wait for the database to be ready + # Default: 10 minute timeout, 5 second poll interval + print("Waiting for database to be ready...") + result = client.knowledge_bases.wait_for_database(kb_uuid) + print(f"Database status: {result.database_status}") # "ONLINE" + print("Knowledge base is ready!") + + # Alternative: Custom timeout and poll interval + # result = client.knowledge_bases.wait_for_database( + # kb_uuid, + # timeout=900.0, # 15 minutes + # poll_interval=10.0 # Check every 10 seconds + # ) + + except KnowledgeBaseDatabaseError as e: + # Database entered a failed state (DECOMMISSIONED or UNHEALTHY) + print(f"Database failed: {e}") + + except KnowledgeBaseTimeoutError as e: + # Database did not become ready within the timeout period + print(f"Timeout: {e}") + + +if __name__ == "__main__": + main() diff --git a/pyproject.toml b/pyproject.toml index 54be5413..0e83a25b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -253,5 +253,4 @@ known-first-party = ["gradient", "tests"] [tool.ruff.lint.per-file-ignores] "bin/**.py" = ["T201", "T203"] "scripts/**.py" = ["T201", "T203"] -"tests/**.py" = ["T201", "T203"] "examples/**.py" = ["T201", "T203"] diff --git a/src/gradient/resources/knowledge_bases/__init__.py b/src/gradient/resources/knowledge_bases/__init__.py index 80d04328..353dc05c 100644 --- a/src/gradient/resources/knowledge_bases/__init__.py +++ b/src/gradient/resources/knowledge_bases/__init__.py @@ -18,6 +18,8 @@ ) from .knowledge_bases import ( KnowledgeBasesResource, + KnowledgeBaseTimeoutError, + KnowledgeBaseDatabaseError, AsyncKnowledgeBasesResource, KnowledgeBasesResourceWithRawResponse, AsyncKnowledgeBasesResourceWithRawResponse, @@ -40,6 +42,8 @@ "AsyncIndexingJobsResourceWithStreamingResponse", "KnowledgeBasesResource", "AsyncKnowledgeBasesResource", + "KnowledgeBaseDatabaseError", + "KnowledgeBaseTimeoutError", "KnowledgeBasesResourceWithRawResponse", "AsyncKnowledgeBasesResourceWithRawResponse", "KnowledgeBasesResourceWithStreamingResponse", diff --git a/src/gradient/resources/knowledge_bases/knowledge_bases.py b/src/gradient/resources/knowledge_bases/knowledge_bases.py index 00fa0659..4325148c 100644 --- a/src/gradient/resources/knowledge_bases/knowledge_bases.py +++ b/src/gradient/resources/knowledge_bases/knowledge_bases.py @@ -2,6 +2,8 @@ from __future__ import annotations +import time +import asyncio from typing import Iterable import httpx @@ -40,7 +42,24 @@ from ...types.knowledge_base_update_response import KnowledgeBaseUpdateResponse from ...types.knowledge_base_retrieve_response import KnowledgeBaseRetrieveResponse -__all__ = ["KnowledgeBasesResource", "AsyncKnowledgeBasesResource"] +__all__ = [ + "KnowledgeBasesResource", + "AsyncKnowledgeBasesResource", + "KnowledgeBaseDatabaseError", + "KnowledgeBaseTimeoutError", +] + + +class KnowledgeBaseDatabaseError(Exception): + """Raised when a knowledge base database enters a failed state.""" + + pass + + +class KnowledgeBaseTimeoutError(Exception): + """Raised when waiting for a knowledge base database times out.""" + + pass class KnowledgeBasesResource(SyncAPIResource): @@ -330,6 +349,81 @@ def delete( cast_to=KnowledgeBaseDeleteResponse, ) + def wait_for_database( + self, + uuid: str, + *, + timeout: float = 600.0, + poll_interval: float = 5.0, + # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. + # The extra values given here take precedence over values defined on the client or passed to this method. + extra_headers: Headers | None = None, + extra_query: Query | None = None, + extra_body: Body | None = None, + ) -> KnowledgeBaseRetrieveResponse: + """ + Poll the knowledge base until the database status is ONLINE or a failed state is reached. + + This helper function repeatedly calls retrieve() to check the database_status field. + It will wait for the database to become ONLINE, or raise an exception if it enters + a failed state (DECOMMISSIONED or UNHEALTHY) or if the timeout is exceeded. + + Args: + uuid: The knowledge base UUID to poll + + timeout: Maximum time to wait in seconds (default: 600 seconds / 10 minutes) + + poll_interval: Time to wait between polls in seconds (default: 5 seconds) + + extra_headers: Send extra headers + + extra_query: Add additional query parameters to the request + + extra_body: Add additional JSON properties to the request + + Returns: + The final KnowledgeBaseRetrieveResponse when the database status is ONLINE + + Raises: + KnowledgeBaseDatabaseError: If the database enters a failed state (DECOMMISSIONED, UNHEALTHY) + + KnowledgeBaseTimeoutError: If the timeout is exceeded before the database becomes ONLINE + """ + if not uuid: + raise ValueError(f"Expected a non-empty value for `uuid` but received {uuid!r}") + + start_time = time.time() + failed_states = {"DECOMMISSIONED", "UNHEALTHY"} + + while True: + elapsed = time.time() - start_time + if elapsed >= timeout: + raise KnowledgeBaseTimeoutError( + f"Timeout waiting for knowledge base database to become ready. " + f"Database did not reach ONLINE status within {timeout} seconds." + ) + + response = self.retrieve( + uuid, + extra_headers=extra_headers, + extra_query=extra_query, + extra_body=extra_body, + ) + + status = response.database_status + + if status == "ONLINE": + return response + + if status in failed_states: + raise KnowledgeBaseDatabaseError(f"Knowledge base database entered failed state: {status}") + + # Sleep before next poll, but don't exceed timeout + remaining_time = timeout - elapsed + sleep_time = min(poll_interval, remaining_time) + if sleep_time > 0: + time.sleep(sleep_time) + class AsyncKnowledgeBasesResource(AsyncAPIResource): @cached_property @@ -618,6 +712,81 @@ async def delete( cast_to=KnowledgeBaseDeleteResponse, ) + async def wait_for_database( + self, + uuid: str, + *, + timeout: float = 600.0, + poll_interval: float = 5.0, + # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. + # The extra values given here take precedence over values defined on the client or passed to this method. + extra_headers: Headers | None = None, + extra_query: Query | None = None, + extra_body: Body | None = None, + ) -> KnowledgeBaseRetrieveResponse: + """ + Poll the knowledge base until the database status is ONLINE or a failed state is reached. + + This helper function repeatedly calls retrieve() to check the database_status field. + It will wait for the database to become ONLINE, or raise an exception if it enters + a failed state (DECOMMISSIONED or UNHEALTHY) or if the timeout is exceeded. + + Args: + uuid: The knowledge base UUID to poll + + timeout: Maximum time to wait in seconds (default: 600 seconds / 10 minutes) + + poll_interval: Time to wait between polls in seconds (default: 5 seconds) + + extra_headers: Send extra headers + + extra_query: Add additional query parameters to the request + + extra_body: Add additional JSON properties to the request + + Returns: + The final KnowledgeBaseRetrieveResponse when the database status is ONLINE + + Raises: + KnowledgeBaseDatabaseError: If the database enters a failed state (DECOMMISSIONED, UNHEALTHY) + + KnowledgeBaseTimeoutError: If the timeout is exceeded before the database becomes ONLINE + """ + if not uuid: + raise ValueError(f"Expected a non-empty value for `uuid` but received {uuid!r}") + + start_time = time.time() + failed_states = {"DECOMMISSIONED", "UNHEALTHY"} + + while True: + elapsed = time.time() - start_time + if elapsed >= timeout: + raise KnowledgeBaseTimeoutError( + f"Timeout waiting for knowledge base database to become ready. " + f"Database did not reach ONLINE status within {timeout} seconds." + ) + + response = await self.retrieve( + uuid, + extra_headers=extra_headers, + extra_query=extra_query, + extra_body=extra_body, + ) + + status = response.database_status + + if status == "ONLINE": + return response + + if status in failed_states: + raise KnowledgeBaseDatabaseError(f"Knowledge base database entered failed state: {status}") + + # Sleep before next poll, but don't exceed timeout + remaining_time = timeout - elapsed + sleep_time = min(poll_interval, remaining_time) + if sleep_time > 0: + await asyncio.sleep(sleep_time) + class KnowledgeBasesResourceWithRawResponse: def __init__(self, knowledge_bases: KnowledgeBasesResource) -> None: @@ -638,6 +807,9 @@ def __init__(self, knowledge_bases: KnowledgeBasesResource) -> None: self.delete = to_raw_response_wrapper( knowledge_bases.delete, ) + self.wait_for_database = to_raw_response_wrapper( + knowledge_bases.wait_for_database, + ) @cached_property def data_sources(self) -> DataSourcesResourceWithRawResponse: @@ -667,6 +839,9 @@ def __init__(self, knowledge_bases: AsyncKnowledgeBasesResource) -> None: self.delete = async_to_raw_response_wrapper( knowledge_bases.delete, ) + self.wait_for_database = async_to_raw_response_wrapper( + knowledge_bases.wait_for_database, + ) @cached_property def data_sources(self) -> AsyncDataSourcesResourceWithRawResponse: @@ -696,6 +871,9 @@ def __init__(self, knowledge_bases: KnowledgeBasesResource) -> None: self.delete = to_streamed_response_wrapper( knowledge_bases.delete, ) + self.wait_for_database = to_streamed_response_wrapper( + knowledge_bases.wait_for_database, + ) @cached_property def data_sources(self) -> DataSourcesResourceWithStreamingResponse: @@ -725,6 +903,9 @@ def __init__(self, knowledge_bases: AsyncKnowledgeBasesResource) -> None: self.delete = async_to_streamed_response_wrapper( knowledge_bases.delete, ) + self.wait_for_database = async_to_streamed_response_wrapper( + knowledge_bases.wait_for_database, + ) @cached_property def data_sources(self) -> AsyncDataSourcesResourceWithStreamingResponse: diff --git a/tests/api_resources/test_knowledge_bases.py b/tests/api_resources/test_knowledge_bases.py index 62965775..a42277e4 100644 --- a/tests/api_resources/test_knowledge_bases.py +++ b/tests/api_resources/test_knowledge_bases.py @@ -275,6 +275,102 @@ def test_path_params_delete(self, client: Gradient) -> None: "", ) + @parametrize + def test_method_wait_for_database_success(self, client: Gradient) -> None: + """Test wait_for_database with successful database status transition.""" + from unittest.mock import Mock + + call_count = [0] + + def mock_retrieve(uuid: str, **kwargs: object) -> Mock: # noqa: ARG001 + call_count[0] += 1 + response = Mock() + # Simulate CREATING -> ONLINE transition + response.database_status = "CREATING" if call_count[0] == 1 else "ONLINE" + return response + + client.knowledge_bases.retrieve = mock_retrieve # type: ignore[method-assign] + + result = client.knowledge_bases.wait_for_database( + "test-uuid", + timeout=10.0, + poll_interval=0.1, + ) + + assert result.database_status == "ONLINE" + assert call_count[0] == 2 + + @parametrize + def test_method_wait_for_database_failed_state(self, client: Gradient) -> None: + """Test wait_for_database with failed database status.""" + from unittest.mock import Mock + + from gradient.resources.knowledge_bases import KnowledgeBaseDatabaseError + + def mock_retrieve(uuid: str, **kwargs: object) -> Mock: # noqa: ARG001 + response = Mock() + response.database_status = "UNHEALTHY" + return response + + client.knowledge_bases.retrieve = mock_retrieve # type: ignore[method-assign] + + with pytest.raises(KnowledgeBaseDatabaseError, match="UNHEALTHY"): + client.knowledge_bases.wait_for_database( + "test-uuid", + timeout=10.0, + poll_interval=0.1, + ) + + @parametrize + def test_method_wait_for_database_timeout(self, client: Gradient) -> None: + """Test wait_for_database with timeout.""" + from unittest.mock import Mock + + from gradient.resources.knowledge_bases import KnowledgeBaseTimeoutError + + def mock_retrieve(uuid: str, **kwargs: object) -> Mock: # noqa: ARG001 + response = Mock() + response.database_status = "CREATING" + return response + + client.knowledge_bases.retrieve = mock_retrieve # type: ignore[method-assign] + + with pytest.raises(KnowledgeBaseTimeoutError): + client.knowledge_bases.wait_for_database( + "test-uuid", + timeout=0.3, + poll_interval=0.1, + ) + + @parametrize + def test_method_wait_for_database_decommissioned(self, client: Gradient) -> None: + """Test wait_for_database with DECOMMISSIONED status.""" + from unittest.mock import Mock + + from gradient.resources.knowledge_bases import KnowledgeBaseDatabaseError + + def mock_retrieve(uuid: str, **kwargs: object) -> Mock: # noqa: ARG001 + response = Mock() + response.database_status = "DECOMMISSIONED" + return response + + client.knowledge_bases.retrieve = mock_retrieve # type: ignore[method-assign] + + with pytest.raises(KnowledgeBaseDatabaseError, match="DECOMMISSIONED"): + client.knowledge_bases.wait_for_database( + "test-uuid", + timeout=10.0, + poll_interval=0.1, + ) + + @parametrize + def test_path_params_wait_for_database(self, client: Gradient) -> None: + """Test wait_for_database validates uuid parameter.""" + with pytest.raises(ValueError, match=r"Expected a non-empty value for `uuid` but received ''"): + client.knowledge_bases.wait_for_database( + "", + ) + class TestAsyncKnowledgeBases: parametrize = pytest.mark.parametrize( @@ -532,3 +628,99 @@ async def test_path_params_delete(self, async_client: AsyncGradient) -> None: await async_client.knowledge_bases.with_raw_response.delete( "", ) + + @parametrize + async def test_method_wait_for_database_success(self, async_client: AsyncGradient) -> None: + """Test async wait_for_database with successful database status transition.""" + from unittest.mock import Mock + + call_count = [0] + + async def mock_retrieve(uuid: str, **kwargs: object) -> Mock: # noqa: ARG001 + call_count[0] += 1 + response = Mock() + # Simulate CREATING -> ONLINE transition + response.database_status = "CREATING" if call_count[0] == 1 else "ONLINE" + return response + + async_client.knowledge_bases.retrieve = mock_retrieve # type: ignore[method-assign] + + result = await async_client.knowledge_bases.wait_for_database( + "test-uuid", + timeout=10.0, + poll_interval=0.1, + ) + + assert result.database_status == "ONLINE" + assert call_count[0] == 2 + + @parametrize + async def test_method_wait_for_database_failed_state(self, async_client: AsyncGradient) -> None: + """Test async wait_for_database with failed database status.""" + from unittest.mock import Mock + + from gradient.resources.knowledge_bases import KnowledgeBaseDatabaseError + + async def mock_retrieve(uuid: str, **kwargs: object) -> Mock: # noqa: ARG001 + response = Mock() + response.database_status = "UNHEALTHY" + return response + + async_client.knowledge_bases.retrieve = mock_retrieve # type: ignore[method-assign] + + with pytest.raises(KnowledgeBaseDatabaseError, match="UNHEALTHY"): + await async_client.knowledge_bases.wait_for_database( + "test-uuid", + timeout=10.0, + poll_interval=0.1, + ) + + @parametrize + async def test_method_wait_for_database_timeout(self, async_client: AsyncGradient) -> None: + """Test async wait_for_database with timeout.""" + from unittest.mock import Mock + + from gradient.resources.knowledge_bases import KnowledgeBaseTimeoutError + + async def mock_retrieve(uuid: str, **kwargs: object) -> Mock: # noqa: ARG001 + response = Mock() + response.database_status = "CREATING" + return response + + async_client.knowledge_bases.retrieve = mock_retrieve # type: ignore[method-assign] + + with pytest.raises(KnowledgeBaseTimeoutError): + await async_client.knowledge_bases.wait_for_database( + "test-uuid", + timeout=0.3, + poll_interval=0.1, + ) + + @parametrize + async def test_method_wait_for_database_decommissioned(self, async_client: AsyncGradient) -> None: + """Test async wait_for_database with DECOMMISSIONED status.""" + from unittest.mock import Mock + + from gradient.resources.knowledge_bases import KnowledgeBaseDatabaseError + + async def mock_retrieve(uuid: str, **kwargs: object) -> Mock: # noqa: ARG001 + response = Mock() + response.database_status = "DECOMMISSIONED" + return response + + async_client.knowledge_bases.retrieve = mock_retrieve # type: ignore[method-assign] + + with pytest.raises(KnowledgeBaseDatabaseError, match="DECOMMISSIONED"): + await async_client.knowledge_bases.wait_for_database( + "test-uuid", + timeout=10.0, + poll_interval=0.1, + ) + + @parametrize + async def test_path_params_wait_for_database(self, async_client: AsyncGradient) -> None: + """Test async wait_for_database validates uuid parameter.""" + with pytest.raises(ValueError, match=r"Expected a non-empty value for `uuid` but received ''"): + await async_client.knowledge_bases.wait_for_database( + "", + ) diff --git a/tests/test_client.py b/tests/test_client.py index ddf1c4db..846c0bb6 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -287,9 +287,9 @@ def add_leak(leaks: list[tracemalloc.StatisticDiff], diff: tracemalloc.Statistic add_leak(leaks, diff) if leaks: for leak in leaks: - print("MEMORY LEAK:", leak) + print("MEMORY LEAK:", leak) # noqa: T201 for frame in leak.traceback: - print(frame) + print(frame) # noqa: T201 raise AssertionError() def test_request_timeout(self) -> None: @@ -1304,9 +1304,9 @@ def add_leak(leaks: list[tracemalloc.StatisticDiff], diff: tracemalloc.Statistic add_leak(leaks, diff) if leaks: for leak in leaks: - print("MEMORY LEAK:", leak) + print("MEMORY LEAK:", leak) # noqa: T201 for frame in leak.traceback: - print(frame) + print(frame) # noqa: T201 raise AssertionError() async def test_request_timeout(self) -> None: diff --git a/tests/test_files.py b/tests/test_files.py index 4d9f4066..54210e83 100644 --- a/tests/test_files.py +++ b/tests/test_files.py @@ -11,34 +11,34 @@ def test_pathlib_includes_file_name() -> None: result = to_httpx_files({"file": readme_path}) - print(result) + print(result) # noqa: T201 assert result == IsDict({"file": IsTuple("README.md", IsBytes())}) def test_tuple_input() -> None: result = to_httpx_files([("file", readme_path)]) - print(result) + print(result) # noqa: T201 assert result == IsList(IsTuple("file", IsTuple("README.md", IsBytes()))) @pytest.mark.asyncio async def test_async_pathlib_includes_file_name() -> None: result = await async_to_httpx_files({"file": readme_path}) - print(result) + print(result) # noqa: T201 assert result == IsDict({"file": IsTuple("README.md", IsBytes())}) @pytest.mark.asyncio async def test_async_supports_anyio_path() -> None: result = await async_to_httpx_files({"file": anyio.Path(readme_path)}) - print(result) + print(result) # noqa: T201 assert result == IsDict({"file": IsTuple("README.md", IsBytes())}) @pytest.mark.asyncio async def test_async_tuple_input() -> None: result = await async_to_httpx_files([("file", readme_path)]) - print(result) + print(result) # noqa: T201 assert result == IsList(IsTuple("file", IsTuple("README.md", IsBytes())))