Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 64 additions & 0 deletions examples/wait_for_knowledge_base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
"""
Example demonstrating how to use the wait_for_database helper function.

This example shows how to:
1. Create a knowledge base
2. Wait for its database to be ready
3. Handle errors and timeouts appropriately
"""

import os

from gradient import Gradient
from gradient.resources.knowledge_bases import KnowledgeBaseTimeoutError, KnowledgeBaseDatabaseError


def main() -> None:
"""Create a knowledge base and wait for its database to be ready."""
# Initialize the Gradient client
# Note: DIGITALOCEAN_ACCESS_TOKEN must be set in your environment
client = Gradient(
access_token=os.environ.get("DIGITALOCEAN_ACCESS_TOKEN"),
)

# Create a knowledge base
# Replace these values with your actual configuration
kb_response = client.knowledge_bases.create(
name="My Knowledge Base",
region="nyc1", # Choose your preferred region
embedding_model_uuid="your-embedding-model-uuid", # Use your embedding model UUID
)

if not kb_response.knowledge_base or not kb_response.knowledge_base.uuid:
print("Failed to create knowledge base")
return

kb_uuid = kb_response.knowledge_base.uuid
print(f"Created knowledge base: {kb_uuid}")

try:
# Wait for the database to be ready
# Default: 10 minute timeout, 5 second poll interval
print("Waiting for database to be ready...")
result = client.knowledge_bases.wait_for_database(kb_uuid)
print(f"Database status: {result.database_status}") # "ONLINE"
print("Knowledge base is ready!")

# Alternative: Custom timeout and poll interval
# result = client.knowledge_bases.wait_for_database(
# kb_uuid,
# timeout=900.0, # 15 minutes
# poll_interval=10.0 # Check every 10 seconds
# )

except KnowledgeBaseDatabaseError as e:
# Database entered a failed state (DECOMMISSIONED or UNHEALTHY)
print(f"Database failed: {e}")

except KnowledgeBaseTimeoutError as e:
# Database did not become ready within the timeout period
print(f"Timeout: {e}")


if __name__ == "__main__":
main()
1 change: 0 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -253,5 +253,4 @@ known-first-party = ["gradient", "tests"]
[tool.ruff.lint.per-file-ignores]
"bin/**.py" = ["T201", "T203"]
"scripts/**.py" = ["T201", "T203"]
"tests/**.py" = ["T201", "T203"]
"examples/**.py" = ["T201", "T203"]
4 changes: 4 additions & 0 deletions src/gradient/resources/knowledge_bases/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@
)
from .knowledge_bases import (
KnowledgeBasesResource,
KnowledgeBaseTimeoutError,
KnowledgeBaseDatabaseError,
AsyncKnowledgeBasesResource,
KnowledgeBasesResourceWithRawResponse,
AsyncKnowledgeBasesResourceWithRawResponse,
Expand All @@ -40,6 +42,8 @@
"AsyncIndexingJobsResourceWithStreamingResponse",
"KnowledgeBasesResource",
"AsyncKnowledgeBasesResource",
"KnowledgeBaseDatabaseError",
"KnowledgeBaseTimeoutError",
"KnowledgeBasesResourceWithRawResponse",
"AsyncKnowledgeBasesResourceWithRawResponse",
"KnowledgeBasesResourceWithStreamingResponse",
Expand Down
183 changes: 182 additions & 1 deletion src/gradient/resources/knowledge_bases/knowledge_bases.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@

from __future__ import annotations

import time
import asyncio
from typing import Iterable

import httpx
Expand Down Expand Up @@ -40,7 +42,24 @@
from ...types.knowledge_base_update_response import KnowledgeBaseUpdateResponse
from ...types.knowledge_base_retrieve_response import KnowledgeBaseRetrieveResponse

__all__ = ["KnowledgeBasesResource", "AsyncKnowledgeBasesResource"]
__all__ = [
"KnowledgeBasesResource",
"AsyncKnowledgeBasesResource",
"KnowledgeBaseDatabaseError",
"KnowledgeBaseTimeoutError",
]


class KnowledgeBaseDatabaseError(Exception):
"""Raised when a knowledge base database enters a failed state."""

pass


class KnowledgeBaseTimeoutError(Exception):
"""Raised when waiting for a knowledge base database times out."""

pass


class KnowledgeBasesResource(SyncAPIResource):
Expand Down Expand Up @@ -330,6 +349,81 @@ def delete(
cast_to=KnowledgeBaseDeleteResponse,
)

def wait_for_database(
self,
uuid: str,
*,
timeout: float = 600.0,
poll_interval: float = 5.0,
# Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
# The extra values given here take precedence over values defined on the client or passed to this method.
extra_headers: Headers | None = None,
extra_query: Query | None = None,
extra_body: Body | None = None,
) -> KnowledgeBaseRetrieveResponse:
"""
Poll the knowledge base until the database status is ONLINE or a failed state is reached.

This helper function repeatedly calls retrieve() to check the database_status field.
It will wait for the database to become ONLINE, or raise an exception if it enters
a failed state (DECOMMISSIONED or UNHEALTHY) or if the timeout is exceeded.

Args:
uuid: The knowledge base UUID to poll

timeout: Maximum time to wait in seconds (default: 600 seconds / 10 minutes)

poll_interval: Time to wait between polls in seconds (default: 5 seconds)

extra_headers: Send extra headers

extra_query: Add additional query parameters to the request

extra_body: Add additional JSON properties to the request

Returns:
The final KnowledgeBaseRetrieveResponse when the database status is ONLINE

Raises:
KnowledgeBaseDatabaseError: If the database enters a failed state (DECOMMISSIONED, UNHEALTHY)

KnowledgeBaseTimeoutError: If the timeout is exceeded before the database becomes ONLINE
"""
if not uuid:
raise ValueError(f"Expected a non-empty value for `uuid` but received {uuid!r}")

start_time = time.time()
failed_states = {"DECOMMISSIONED", "UNHEALTHY"}

while True:
elapsed = time.time() - start_time
if elapsed >= timeout:
raise KnowledgeBaseTimeoutError(
f"Timeout waiting for knowledge base database to become ready. "
f"Database did not reach ONLINE status within {timeout} seconds."
)

response = self.retrieve(
uuid,
extra_headers=extra_headers,
extra_query=extra_query,
extra_body=extra_body,
)

status = response.database_status

if status == "ONLINE":
return response

if status in failed_states:
raise KnowledgeBaseDatabaseError(f"Knowledge base database entered failed state: {status}")

# Sleep before next poll, but don't exceed timeout
remaining_time = timeout - elapsed
sleep_time = min(poll_interval, remaining_time)
if sleep_time > 0:
time.sleep(sleep_time)


class AsyncKnowledgeBasesResource(AsyncAPIResource):
@cached_property
Expand Down Expand Up @@ -618,6 +712,81 @@ async def delete(
cast_to=KnowledgeBaseDeleteResponse,
)

async def wait_for_database(
self,
uuid: str,
*,
timeout: float = 600.0,
poll_interval: float = 5.0,
# Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
# The extra values given here take precedence over values defined on the client or passed to this method.
extra_headers: Headers | None = None,
extra_query: Query | None = None,
extra_body: Body | None = None,
) -> KnowledgeBaseRetrieveResponse:
"""
Poll the knowledge base until the database status is ONLINE or a failed state is reached.

This helper function repeatedly calls retrieve() to check the database_status field.
It will wait for the database to become ONLINE, or raise an exception if it enters
a failed state (DECOMMISSIONED or UNHEALTHY) or if the timeout is exceeded.

Args:
uuid: The knowledge base UUID to poll

timeout: Maximum time to wait in seconds (default: 600 seconds / 10 minutes)

poll_interval: Time to wait between polls in seconds (default: 5 seconds)

extra_headers: Send extra headers

extra_query: Add additional query parameters to the request

extra_body: Add additional JSON properties to the request

Returns:
The final KnowledgeBaseRetrieveResponse when the database status is ONLINE

Raises:
KnowledgeBaseDatabaseError: If the database enters a failed state (DECOMMISSIONED, UNHEALTHY)

KnowledgeBaseTimeoutError: If the timeout is exceeded before the database becomes ONLINE
"""
if not uuid:
raise ValueError(f"Expected a non-empty value for `uuid` but received {uuid!r}")

start_time = time.time()
failed_states = {"DECOMMISSIONED", "UNHEALTHY"}

while True:
elapsed = time.time() - start_time
if elapsed >= timeout:
raise KnowledgeBaseTimeoutError(
f"Timeout waiting for knowledge base database to become ready. "
f"Database did not reach ONLINE status within {timeout} seconds."
)

response = await self.retrieve(
uuid,
extra_headers=extra_headers,
extra_query=extra_query,
extra_body=extra_body,
)

status = response.database_status

if status == "ONLINE":
return response

if status in failed_states:
raise KnowledgeBaseDatabaseError(f"Knowledge base database entered failed state: {status}")

# Sleep before next poll, but don't exceed timeout
remaining_time = timeout - elapsed
sleep_time = min(poll_interval, remaining_time)
if sleep_time > 0:
await asyncio.sleep(sleep_time)


class KnowledgeBasesResourceWithRawResponse:
def __init__(self, knowledge_bases: KnowledgeBasesResource) -> None:
Expand All @@ -638,6 +807,9 @@ def __init__(self, knowledge_bases: KnowledgeBasesResource) -> None:
self.delete = to_raw_response_wrapper(
knowledge_bases.delete,
)
self.wait_for_database = to_raw_response_wrapper(
knowledge_bases.wait_for_database,
)

@cached_property
def data_sources(self) -> DataSourcesResourceWithRawResponse:
Expand Down Expand Up @@ -667,6 +839,9 @@ def __init__(self, knowledge_bases: AsyncKnowledgeBasesResource) -> None:
self.delete = async_to_raw_response_wrapper(
knowledge_bases.delete,
)
self.wait_for_database = async_to_raw_response_wrapper(
knowledge_bases.wait_for_database,
)

@cached_property
def data_sources(self) -> AsyncDataSourcesResourceWithRawResponse:
Expand Down Expand Up @@ -696,6 +871,9 @@ def __init__(self, knowledge_bases: KnowledgeBasesResource) -> None:
self.delete = to_streamed_response_wrapper(
knowledge_bases.delete,
)
self.wait_for_database = to_streamed_response_wrapper(
knowledge_bases.wait_for_database,
)

@cached_property
def data_sources(self) -> DataSourcesResourceWithStreamingResponse:
Expand Down Expand Up @@ -725,6 +903,9 @@ def __init__(self, knowledge_bases: AsyncKnowledgeBasesResource) -> None:
self.delete = async_to_streamed_response_wrapper(
knowledge_bases.delete,
)
self.wait_for_database = async_to_streamed_response_wrapper(
knowledge_bases.wait_for_database,
)

@cached_property
def data_sources(self) -> AsyncDataSourcesResourceWithStreamingResponse:
Expand Down
Loading