diff --git a/examples/agent_wait_until_ready.py b/examples/agent_wait_until_ready.py index 3ea7b4a3..df8c8cc6 100644 --- a/examples/agent_wait_until_ready.py +++ b/examples/agent_wait_until_ready.py @@ -24,7 +24,7 @@ if agent_id: print(f"Agent created with ID: {agent_id}") print("Waiting for agent to be ready...") - + try: # Wait for the agent to be deployed and ready # This will poll the agent status every 5 seconds (default) @@ -32,24 +32,24 @@ ready_agent = client.agents.wait_until_ready( agent_id, poll_interval=5.0, # Check every 5 seconds - timeout=300.0, # Wait up to 5 minutes + timeout=300.0, # Wait up to 5 minutes ) - + if ready_agent.agent and ready_agent.agent.deployment: print(f"Agent is ready! Status: {ready_agent.agent.deployment.status}") print(f"Agent URL: {ready_agent.agent.url}") - + # Now you can use the agent # ... - + except AgentDeploymentError as e: print(f"Agent deployment failed: {e}") print(f"Failed status: {e.status}") - + except AgentDeploymentTimeoutError as e: print(f"Agent deployment timed out: {e}") print(f"Agent ID: {e.agent_id}") - + except Exception as e: print(f"Unexpected error: {e}") @@ -60,7 +60,7 @@ async def main() -> None: async_client = AsyncGradient() - + # Create a new agent agent_response = await async_client.agents.create( name="My Async Agent", @@ -68,13 +68,13 @@ async def main() -> None: model_uuid="", region="nyc1", ) - + agent_id = agent_response.agent.uuid if agent_response.agent else None - + if agent_id: print(f"Agent created with ID: {agent_id}") print("Waiting for agent to be ready...") - + try: # Wait for the agent to be deployed and ready (async) ready_agent = await async_client.agents.wait_until_ready( @@ -82,15 +82,15 @@ async def main() -> None: poll_interval=5.0, timeout=300.0, ) - + if ready_agent.agent and ready_agent.agent.deployment: print(f"Agent is ready! Status: {ready_agent.agent.deployment.status}") print(f"Agent URL: {ready_agent.agent.url}") - + except AgentDeploymentError as e: print(f"Agent deployment failed: {e}") print(f"Failed status: {e.status}") - + except AgentDeploymentTimeoutError as e: print(f"Agent deployment timed out: {e}") print(f"Agent ID: {e.agent_id}") diff --git a/examples/knowledge_base_indexing_wait.py b/examples/knowledge_base_indexing_wait.py new file mode 100644 index 00000000..1171fea3 --- /dev/null +++ b/examples/knowledge_base_indexing_wait.py @@ -0,0 +1,197 @@ +#!/usr/bin/env python3 +""" +Example: Waiting for Knowledge Base Indexing Job Completion + +This example demonstrates how to use the wait_for_completion() method +to automatically wait for a knowledge base indexing job to finish, +without needing to write manual polling loops. +""" + +import os + +from gradient import Gradient, IndexingJobError, IndexingJobTimeoutError + + +def main() -> None: + # Initialize the Gradient client + client = Gradient() + + # Example 1: Basic usage - wait for indexing job to complete + print("Example 1: Basic usage") + print("-" * 50) + + # Create an indexing job (replace with your actual knowledge base UUID) + knowledge_base_uuid = os.getenv("KNOWLEDGE_BASE_UUID", "your-kb-uuid-here") + + print(f"Creating indexing job for knowledge base: {knowledge_base_uuid}") + indexing_job = client.knowledge_bases.indexing_jobs.create( + knowledge_base_uuid=knowledge_base_uuid, + ) + + job_uuid = indexing_job.job.uuid if indexing_job.job else None + if not job_uuid: + print("Error: Could not create indexing job") + return + + print(f"Indexing job created with UUID: {job_uuid}") + print("Waiting for indexing job to complete...") + + try: + # Wait for the job to complete (polls every 5 seconds by default) + completed_job = client.knowledge_bases.indexing_jobs.wait_for_completion(job_uuid) + + print("\n✅ Indexing job completed successfully!") + if completed_job.job: + print(f"Phase: {completed_job.job.phase}") + print(f"Total items indexed: {completed_job.job.total_items_indexed}") + print(f"Total items failed: {completed_job.job.total_items_failed}") + print(f"Total datasources: {completed_job.job.total_datasources}") + print(f"Completed datasources: {completed_job.job.completed_datasources}") + + except IndexingJobTimeoutError as e: + print(f"\n⏱️ Timeout: {e}") + except IndexingJobError as e: + print(f"\n❌ Error: {e}") + except Exception as e: + print(f"\n❌ Unexpected error: {e}") + + +def example_with_custom_polling() -> None: + """Example with custom polling interval and timeout""" + print("\n\nExample 2: Custom polling interval and timeout") + print("-" * 50) + + client = Gradient() + knowledge_base_uuid = os.getenv("KNOWLEDGE_BASE_UUID", "your-kb-uuid-here") + + print(f"Creating indexing job for knowledge base: {knowledge_base_uuid}") + indexing_job = client.knowledge_bases.indexing_jobs.create( + knowledge_base_uuid=knowledge_base_uuid, + ) + + job_uuid = indexing_job.job.uuid if indexing_job.job else None + if not job_uuid: + print("Error: Could not create indexing job") + return + + print(f"Indexing job created with UUID: {job_uuid}") + print("Waiting for indexing job to complete (polling every 10 seconds, 5 minute timeout)...") + + try: + # Wait with custom poll interval (10 seconds) and timeout (5 minutes = 300 seconds) + completed_job = client.knowledge_bases.indexing_jobs.wait_for_completion( + job_uuid, + poll_interval=10, # Poll every 10 seconds + timeout=300, # Timeout after 5 minutes + ) + + print("\n✅ Indexing job completed successfully!") + if completed_job.job: + print(f"Phase: {completed_job.job.phase}") + + except IndexingJobTimeoutError: + print("\n⏱️ Job did not complete within 5 minutes") + # You can still check the current status + current_status = client.knowledge_bases.indexing_jobs.retrieve(job_uuid) + if current_status.job: + print(f"Current phase: {current_status.job.phase}") + print( + f"Completed datasources: {current_status.job.completed_datasources}/{current_status.job.total_datasources}" + ) + except IndexingJobError as e: + print(f"\n❌ Job failed: {e}") + + +def example_manual_polling() -> None: + """Example of the old manual polling approach (for comparison)""" + print("\n\nExample 3: Manual polling (old approach)") + print("-" * 50) + + client = Gradient() + knowledge_base_uuid = os.getenv("KNOWLEDGE_BASE_UUID", "your-kb-uuid-here") + + indexing_job = client.knowledge_bases.indexing_jobs.create( + knowledge_base_uuid=knowledge_base_uuid, + ) + + job_uuid = indexing_job.job.uuid if indexing_job.job else None + if not job_uuid: + print("Error: Could not create indexing job") + return + + print(f"Indexing job created with UUID: {job_uuid}") + print("Manual polling (old approach)...") + + import time + + while True: + indexing_job_status = client.knowledge_bases.indexing_jobs.retrieve(job_uuid) + + if indexing_job_status.job and indexing_job_status.job.phase: + phase = indexing_job_status.job.phase + print(f"Current phase: {phase}") + + if phase in ["BATCH_JOB_PHASE_UNKNOWN", "BATCH_JOB_PHASE_PENDING", "BATCH_JOB_PHASE_RUNNING"]: + time.sleep(5) + continue + elif phase == "BATCH_JOB_PHASE_SUCCEEDED": + print("✅ Job completed successfully!") + break + else: + print(f"❌ Job ended with phase: {phase}") + break + + +async def example_async() -> None: + """Example using async/await""" + print("\n\nExample 4: Async usage") + print("-" * 50) + + from gradient import AsyncGradient + + client = AsyncGradient() + knowledge_base_uuid = os.getenv("KNOWLEDGE_BASE_UUID", "your-kb-uuid-here") + + print(f"Creating indexing job for knowledge base: {knowledge_base_uuid}") + indexing_job = await client.knowledge_bases.indexing_jobs.create( + knowledge_base_uuid=knowledge_base_uuid, + ) + + job_uuid = indexing_job.job.uuid if indexing_job.job else None + if not job_uuid: + print("Error: Could not create indexing job") + return + + print(f"Indexing job created with UUID: {job_uuid}") + print("Waiting for indexing job to complete (async)...") + + try: + completed_job = await client.knowledge_bases.indexing_jobs.wait_for_completion( + job_uuid, + poll_interval=5, + timeout=600, # 10 minute timeout + ) + + print("\n✅ Indexing job completed successfully!") + if completed_job.job: + print(f"Phase: {completed_job.job.phase}") + + except IndexingJobTimeoutError as e: + print(f"\n⏱️ Timeout: {e}") + except IndexingJobError as e: + print(f"\n❌ Error: {e}") + finally: + await client.close() + + +if __name__ == "__main__": + # Run the basic example + main() + + # Uncomment to run other examples: + # example_with_custom_polling() + # example_manual_polling() + + # For async example, you would need to run: + # import asyncio + # asyncio.run(example_async()) diff --git a/src/gradient/__init__.py b/src/gradient/__init__.py index a67cd2a7..864f1484 100644 --- a/src/gradient/__init__.py +++ b/src/gradient/__init__.py @@ -29,12 +29,16 @@ RateLimitError, APITimeoutError, BadRequestError, + IndexingJobError, APIConnectionError, AuthenticationError, InternalServerError, + AgentDeploymentError, PermissionDeniedError, + IndexingJobTimeoutError, UnprocessableEntityError, APIResponseValidationError, + AgentDeploymentTimeoutError, ) from ._base_client import DefaultHttpxClient, DefaultAioHttpClient, DefaultAsyncHttpxClient from ._utils._logs import setup_logging as _setup_logging @@ -65,6 +69,10 @@ "UnprocessableEntityError", "RateLimitError", "InternalServerError", + "IndexingJobError", + "IndexingJobTimeoutError", + "AgentDeploymentError", + "AgentDeploymentTimeoutError", "Timeout", "RequestOptions", "Client", diff --git a/src/gradient/_exceptions.py b/src/gradient/_exceptions.py index 0ced4aba..f0c6671d 100644 --- a/src/gradient/_exceptions.py +++ b/src/gradient/_exceptions.py @@ -15,6 +15,8 @@ "UnprocessableEntityError", "RateLimitError", "InternalServerError", + "IndexingJobError", + "IndexingJobTimeoutError", "AgentDeploymentError", "AgentDeploymentTimeoutError", ] @@ -110,6 +112,32 @@ class InternalServerError(APIStatusError): pass +class IndexingJobError(GradientError): + """Raised when an indexing job fails, encounters an error, or is cancelled.""" + + uuid: str + phase: str + + def __init__(self, message: str, *, uuid: str, phase: str) -> None: + super().__init__(message) + self.uuid = uuid + self.phase = phase + + +class IndexingJobTimeoutError(GradientError): + """Raised when polling for an indexing job times out.""" + + uuid: str + phase: str + timeout: float + + def __init__(self, message: str, *, uuid: str, phase: str, timeout: float) -> None: + super().__init__(message) + self.uuid = uuid + self.phase = phase + self.timeout = timeout + + class AgentDeploymentError(GradientError): """Raised when an agent deployment fails.""" diff --git a/src/gradient/resources/knowledge_bases/indexing_jobs.py b/src/gradient/resources/knowledge_bases/indexing_jobs.py index 95898c2a..647ab308 100644 --- a/src/gradient/resources/knowledge_bases/indexing_jobs.py +++ b/src/gradient/resources/knowledge_bases/indexing_jobs.py @@ -2,6 +2,9 @@ from __future__ import annotations +import time +import asyncio + import httpx from ..._types import Body, Omit, Query, Headers, NotGiven, SequenceNotStr, omit, not_given @@ -14,6 +17,7 @@ async_to_raw_response_wrapper, async_to_streamed_response_wrapper, ) +from ..._exceptions import IndexingJobError, IndexingJobTimeoutError from ..._base_client import make_request_options from ...types.knowledge_bases import ( indexing_job_list_params, @@ -259,6 +263,110 @@ def update_cancel( cast_to=IndexingJobUpdateCancelResponse, ) + def wait_for_completion( + self, + uuid: str, + *, + poll_interval: float = 5, + timeout: float | None = None, + # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. + # The extra values given here take precedence over values defined on the client or passed to this method. + extra_headers: Headers | None = None, + extra_query: Query | None = None, + extra_body: Body | None = None, + request_timeout: float | httpx.Timeout | None | NotGiven = not_given, + ) -> IndexingJobRetrieveResponse: + """ + Wait for an indexing job to complete by polling its status. + + This method polls the indexing job status at regular intervals until it reaches + a terminal state (succeeded, failed, error, or cancelled). It raises an exception + if the job fails or times out. + + Args: + uuid: The UUID of the indexing job to wait for. + + poll_interval: Time in seconds between status checks (default: 5 seconds). + + timeout: Maximum time in seconds to wait for completion. If None, waits indefinitely. + + extra_headers: Send extra headers + + extra_query: Add additional query parameters to the request + + extra_body: Add additional JSON properties to the request + + request_timeout: Override the client-level default timeout for this request, in seconds + + Returns: + The final IndexingJobRetrieveResponse when the job completes successfully. + + Raises: + IndexingJobTimeoutError: If the job doesn't complete within the specified timeout. + IndexingJobError: If the job fails, errors, or is cancelled. + """ + if not uuid: + raise ValueError(f"Expected a non-empty value for `uuid` but received {uuid!r}") + + start_time = time.time() + + while True: + response = self.retrieve( + uuid, + extra_headers=extra_headers, + extra_query=extra_query, + extra_body=extra_body, + timeout=request_timeout, + ) + + # Check if job is in a terminal state + if response.job and response.job.phase: + phase = response.job.phase + + # Success state + if phase == "BATCH_JOB_PHASE_SUCCEEDED": + return response + + # Failure states + if phase == "BATCH_JOB_PHASE_FAILED": + raise IndexingJobError( + f"Indexing job {uuid} failed. " + f"Total items indexed: {response.job.total_items_indexed}, " + f"Total items failed: {response.job.total_items_failed}", + uuid=uuid, + phase=phase, + ) + + if phase == "BATCH_JOB_PHASE_ERROR": + raise IndexingJobError( + f"Indexing job {uuid} encountered an error", + uuid=uuid, + phase=phase, + ) + + if phase == "BATCH_JOB_PHASE_CANCELLED": + raise IndexingJobError( + f"Indexing job {uuid} was cancelled", + uuid=uuid, + phase=phase, + ) + + # Still in progress (UNKNOWN, PENDING, or RUNNING) + # Check timeout + if timeout is not None: + elapsed = time.time() - start_time + if elapsed >= timeout: + raise IndexingJobTimeoutError( + f"Indexing job {uuid} did not complete within {timeout} seconds. " + f"Current phase: {phase}", + uuid=uuid, + phase=phase, + timeout=timeout, + ) + + # Wait before next poll + time.sleep(poll_interval) + class AsyncIndexingJobsResource(AsyncAPIResource): @cached_property @@ -490,6 +598,110 @@ async def update_cancel( cast_to=IndexingJobUpdateCancelResponse, ) + async def wait_for_completion( + self, + uuid: str, + *, + poll_interval: float = 5, + timeout: float | None = None, + # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. + # The extra values given here take precedence over values defined on the client or passed to this method. + extra_headers: Headers | None = None, + extra_query: Query | None = None, + extra_body: Body | None = None, + request_timeout: float | httpx.Timeout | None | NotGiven = not_given, + ) -> IndexingJobRetrieveResponse: + """ + Wait for an indexing job to complete by polling its status. + + This method polls the indexing job status at regular intervals until it reaches + a terminal state (succeeded, failed, error, or cancelled). It raises an exception + if the job fails or times out. + + Args: + uuid: The UUID of the indexing job to wait for. + + poll_interval: Time in seconds between status checks (default: 5 seconds). + + timeout: Maximum time in seconds to wait for completion. If None, waits indefinitely. + + extra_headers: Send extra headers + + extra_query: Add additional query parameters to the request + + extra_body: Add additional JSON properties to the request + + request_timeout: Override the client-level default timeout for this request, in seconds + + Returns: + The final IndexingJobRetrieveResponse when the job completes successfully. + + Raises: + IndexingJobTimeoutError: If the job doesn't complete within the specified timeout. + IndexingJobError: If the job fails, errors, or is cancelled. + """ + if not uuid: + raise ValueError(f"Expected a non-empty value for `uuid` but received {uuid!r}") + + start_time = time.time() + + while True: + response = await self.retrieve( + uuid, + extra_headers=extra_headers, + extra_query=extra_query, + extra_body=extra_body, + timeout=request_timeout, + ) + + # Check if job is in a terminal state + if response.job and response.job.phase: + phase = response.job.phase + + # Success state + if phase == "BATCH_JOB_PHASE_SUCCEEDED": + return response + + # Failure states + if phase == "BATCH_JOB_PHASE_FAILED": + raise IndexingJobError( + f"Indexing job {uuid} failed. " + f"Total items indexed: {response.job.total_items_indexed}, " + f"Total items failed: {response.job.total_items_failed}", + uuid=uuid, + phase=phase, + ) + + if phase == "BATCH_JOB_PHASE_ERROR": + raise IndexingJobError( + f"Indexing job {uuid} encountered an error", + uuid=uuid, + phase=phase, + ) + + if phase == "BATCH_JOB_PHASE_CANCELLED": + raise IndexingJobError( + f"Indexing job {uuid} was cancelled", + uuid=uuid, + phase=phase, + ) + + # Still in progress (UNKNOWN, PENDING, or RUNNING) + # Check timeout + if timeout is not None: + elapsed = time.time() - start_time + if elapsed >= timeout: + raise IndexingJobTimeoutError( + f"Indexing job {uuid} did not complete within {timeout} seconds. " + f"Current phase: {phase}", + uuid=uuid, + phase=phase, + timeout=timeout, + ) + + # Wait before next poll + await asyncio.sleep(poll_interval) + class IndexingJobsResourceWithRawResponse: def __init__(self, indexing_jobs: IndexingJobsResource) -> None: @@ -510,6 +722,9 @@ def __init__(self, indexing_jobs: IndexingJobsResource) -> None: self.update_cancel = to_raw_response_wrapper( indexing_jobs.update_cancel, ) + self.wait_for_completion = to_raw_response_wrapper( + indexing_jobs.wait_for_completion, + ) class AsyncIndexingJobsResourceWithRawResponse: @@ -531,6 +746,9 @@ def __init__(self, indexing_jobs: AsyncIndexingJobsResource) -> None: self.update_cancel = async_to_raw_response_wrapper( indexing_jobs.update_cancel, ) + self.wait_for_completion = async_to_raw_response_wrapper( + indexing_jobs.wait_for_completion, + ) class IndexingJobsResourceWithStreamingResponse: @@ -552,6 +770,9 @@ def __init__(self, indexing_jobs: IndexingJobsResource) -> None: self.update_cancel = to_streamed_response_wrapper( indexing_jobs.update_cancel, ) + self.wait_for_completion = to_streamed_response_wrapper( + indexing_jobs.wait_for_completion, + ) class AsyncIndexingJobsResourceWithStreamingResponse: @@ -573,3 +794,6 @@ def __init__(self, indexing_jobs: AsyncIndexingJobsResource) -> None: self.update_cancel = async_to_streamed_response_wrapper( indexing_jobs.update_cancel, ) + self.wait_for_completion = async_to_streamed_response_wrapper( + indexing_jobs.wait_for_completion, + ) diff --git a/tests/api_resources/knowledge_bases/test_indexing_jobs.py b/tests/api_resources/knowledge_bases/test_indexing_jobs.py index 3dffaa69..88c551c8 100644 --- a/tests/api_resources/knowledge_bases/test_indexing_jobs.py +++ b/tests/api_resources/knowledge_bases/test_indexing_jobs.py @@ -5,9 +5,10 @@ import os from typing import Any, cast +import httpx import pytest -from gradient import Gradient, AsyncGradient +from gradient import Gradient, AsyncGradient, IndexingJobError, IndexingJobTimeoutError from tests.utils import assert_matches_type from gradient.types.knowledge_bases import ( IndexingJobListResponse, @@ -232,6 +233,125 @@ def test_path_params_update_cancel(self, client: Gradient) -> None: path_uuid="", ) + @parametrize + def test_wait_for_completion_raises_indexing_job_error_on_failed(self, client: Gradient, respx_mock: Any) -> None: + """Test that IndexingJobError is raised when job phase is FAILED""" + job_uuid = "test-job-uuid" + respx_mock.get(f"{base_url}/v2/gen-ai/indexing_jobs/{job_uuid}").mock( + return_value=httpx.Response( + 200, + json={ + "job": { + "uuid": job_uuid, + "phase": "BATCH_JOB_PHASE_FAILED", + "total_items_indexed": "10", + "total_items_failed": "5", + } + }, + ) + ) + + with pytest.raises(IndexingJobError) as exc_info: + client.knowledge_bases.indexing_jobs.wait_for_completion(job_uuid) + + assert exc_info.value.uuid == job_uuid + assert exc_info.value.phase == "BATCH_JOB_PHASE_FAILED" + assert "failed" in str(exc_info.value).lower() + + @parametrize + def test_wait_for_completion_raises_indexing_job_error_on_error(self, client: Gradient, respx_mock: Any) -> None: + """Test that IndexingJobError is raised when job phase is ERROR""" + job_uuid = "test-job-uuid" + respx_mock.get(f"{base_url}/v2/gen-ai/indexing_jobs/{job_uuid}").mock( + return_value=httpx.Response( + 200, + json={ + "job": { + "uuid": job_uuid, + "phase": "BATCH_JOB_PHASE_ERROR", + } + }, + ) + ) + + with pytest.raises(IndexingJobError) as exc_info: + client.knowledge_bases.indexing_jobs.wait_for_completion(job_uuid) + + assert exc_info.value.uuid == job_uuid + assert exc_info.value.phase == "BATCH_JOB_PHASE_ERROR" + assert "error" in str(exc_info.value).lower() + + @parametrize + def test_wait_for_completion_raises_indexing_job_error_on_cancelled( + self, client: Gradient, respx_mock: Any + ) -> None: + """Test that IndexingJobError is raised when job phase is CANCELLED""" + job_uuid = "test-job-uuid" + respx_mock.get(f"{base_url}/v2/gen-ai/indexing_jobs/{job_uuid}").mock( + return_value=httpx.Response( + 200, + json={ + "job": { + "uuid": job_uuid, + "phase": "BATCH_JOB_PHASE_CANCELLED", + } + }, + ) + ) + + with pytest.raises(IndexingJobError) as exc_info: + client.knowledge_bases.indexing_jobs.wait_for_completion(job_uuid) + + assert exc_info.value.uuid == job_uuid + assert exc_info.value.phase == "BATCH_JOB_PHASE_CANCELLED" + assert "cancelled" in str(exc_info.value).lower() + + @parametrize + def test_wait_for_completion_raises_timeout_error(self, client: Gradient, respx_mock: Any) -> None: + """Test that IndexingJobTimeoutError is raised on timeout""" + job_uuid = "test-job-uuid" + respx_mock.get(f"{base_url}/v2/gen-ai/indexing_jobs/{job_uuid}").mock( + return_value=httpx.Response( + 200, + json={ + "job": { + "uuid": job_uuid, + "phase": "BATCH_JOB_PHASE_RUNNING", + } + }, + ) + ) + + with pytest.raises(IndexingJobTimeoutError) as exc_info: + client.knowledge_bases.indexing_jobs.wait_for_completion(job_uuid, poll_interval=0.1, timeout=0.2) + + assert exc_info.value.uuid == job_uuid + assert exc_info.value.phase == "BATCH_JOB_PHASE_RUNNING" + assert exc_info.value.timeout == 0.2 + + @parametrize + def test_wait_for_completion_succeeds(self, client: Gradient, respx_mock: Any) -> None: + """Test that wait_for_completion returns successfully when job succeeds""" + job_uuid = "test-job-uuid" + respx_mock.get(f"{base_url}/v2/gen-ai/indexing_jobs/{job_uuid}").mock( + return_value=httpx.Response( + 200, + json={ + "job": { + "uuid": job_uuid, + "phase": "BATCH_JOB_PHASE_SUCCEEDED", + "total_items_indexed": "100", + "total_items_failed": "0", + } + }, + ) + ) + + result = client.knowledge_bases.indexing_jobs.wait_for_completion(job_uuid) + assert_matches_type(IndexingJobRetrieveResponse, result, path=["response"]) + assert result.job is not None + assert result.job.phase == "BATCH_JOB_PHASE_SUCCEEDED" + class TestAsyncIndexingJobs: parametrize = pytest.mark.parametrize( @@ -446,3 +566,128 @@ async def test_path_params_update_cancel(self, async_client: AsyncGradient) -> N await async_client.knowledge_bases.indexing_jobs.with_raw_response.update_cancel( path_uuid="", ) + + @parametrize + async def test_wait_for_completion_raises_indexing_job_error_on_failed( + self, async_client: AsyncGradient, respx_mock: Any + ) -> None: + """Test that IndexingJobError is raised when job phase is FAILED""" + job_uuid = "test-job-uuid" + respx_mock.get(f"{base_url}/v2/gen-ai/indexing_jobs/{job_uuid}").mock( + return_value=httpx.Response( + 200, + json={ + "job": { + "uuid": job_uuid, + "phase": "BATCH_JOB_PHASE_FAILED", + "total_items_indexed": "10", + "total_items_failed": "5", + } + }, + ) + ) + + with pytest.raises(IndexingJobError) as exc_info: + await async_client.knowledge_bases.indexing_jobs.wait_for_completion(job_uuid) + + assert exc_info.value.uuid == job_uuid + assert exc_info.value.phase == "BATCH_JOB_PHASE_FAILED" + assert "failed" in str(exc_info.value).lower() + + @parametrize + async def test_wait_for_completion_raises_indexing_job_error_on_error( + self, async_client: AsyncGradient, respx_mock: Any + ) -> None: + """Test that IndexingJobError is raised when job phase is ERROR""" + job_uuid = "test-job-uuid" + respx_mock.get(f"{base_url}/v2/gen-ai/indexing_jobs/{job_uuid}").mock( + return_value=httpx.Response( + 200, + json={ + "job": { + "uuid": job_uuid, + "phase": "BATCH_JOB_PHASE_ERROR", + } + }, + ) + ) + + with pytest.raises(IndexingJobError) as exc_info: + await async_client.knowledge_bases.indexing_jobs.wait_for_completion(job_uuid) + + assert exc_info.value.uuid == job_uuid + assert exc_info.value.phase == "BATCH_JOB_PHASE_ERROR" + assert "error" in str(exc_info.value).lower() + + @parametrize + async def test_wait_for_completion_raises_indexing_job_error_on_cancelled( + self, async_client: AsyncGradient, respx_mock: Any + ) -> None: + """Test that IndexingJobError is raised when job phase is CANCELLED""" + job_uuid = "test-job-uuid" + respx_mock.get(f"{base_url}/v2/gen-ai/indexing_jobs/{job_uuid}").mock( + return_value=httpx.Response( + 200, + json={ + "job": { + "uuid": job_uuid, + "phase": "BATCH_JOB_PHASE_CANCELLED", + } + }, + ) + ) + + with pytest.raises(IndexingJobError) as exc_info: + await async_client.knowledge_bases.indexing_jobs.wait_for_completion(job_uuid) + + assert exc_info.value.uuid == job_uuid + assert exc_info.value.phase == "BATCH_JOB_PHASE_CANCELLED" + assert "cancelled" in str(exc_info.value).lower() + + @parametrize + async def test_wait_for_completion_raises_timeout_error(self, async_client: AsyncGradient, respx_mock: Any) -> None: + """Test that IndexingJobTimeoutError is raised on timeout""" + job_uuid = "test-job-uuid" + respx_mock.get(f"{base_url}/v2/gen-ai/indexing_jobs/{job_uuid}").mock( + return_value=httpx.Response( + 200, + json={ + "job": { + "uuid": job_uuid, + "phase": "BATCH_JOB_PHASE_RUNNING", + } + }, + ) + ) + + with pytest.raises(IndexingJobTimeoutError) as exc_info: + await async_client.knowledge_bases.indexing_jobs.wait_for_completion( + job_uuid, poll_interval=0.1, timeout=0.2 + ) + + assert exc_info.value.uuid == job_uuid + assert exc_info.value.phase == "BATCH_JOB_PHASE_RUNNING" + assert exc_info.value.timeout == 0.2 + + @parametrize + async def test_wait_for_completion_succeeds(self, async_client: AsyncGradient, respx_mock: Any) -> None: + """Test that wait_for_completion returns successfully when job succeeds""" + job_uuid = "test-job-uuid" + respx_mock.get(f"{base_url}/v2/gen-ai/indexing_jobs/{job_uuid}").mock( + return_value=httpx.Response( + 200, + json={ + "job": { + "uuid": job_uuid, + "phase": "BATCH_JOB_PHASE_SUCCEEDED", + "total_items_indexed": "100", + "total_items_failed": "0", + } + }, + ) + ) + + result = await async_client.knowledge_bases.indexing_jobs.wait_for_completion(job_uuid) + assert_matches_type(IndexingJobRetrieveResponse, result, path=["response"]) + assert result.job is not None + assert result.job.phase == "BATCH_JOB_PHASE_SUCCEEDED" diff --git a/tests/api_resources/test_agents.py b/tests/api_resources/test_agents.py index 5777c3ea..1ba3e093 100644 --- a/tests/api_resources/test_agents.py +++ b/tests/api_resources/test_agents.py @@ -368,9 +368,10 @@ def test_path_params_update_status(self, client: Gradient) -> None: def test_method_wait_until_ready(self, client: Gradient, respx_mock: Any) -> None: """Test successful wait_until_ready when agent becomes ready.""" agent_uuid = "test-agent-id" - + # Create side effect that returns different responses call_count = [0] + def get_response(_: httpx.Request) -> httpx.Response: call_count[0] += 1 if call_count[0] == 1: @@ -395,9 +396,9 @@ def get_response(_: httpx.Request) -> httpx.Response: } }, ) - + respx_mock.get(f"/v2/gen-ai/agents/{agent_uuid}").mock(side_effect=get_response) - + agent = client.agents.wait_until_ready(agent_uuid, poll_interval=0.1, timeout=10.0) assert_matches_type(AgentRetrieveResponse, agent, path=["response"]) assert agent.agent is not None @@ -408,9 +409,9 @@ def get_response(_: httpx.Request) -> httpx.Response: def test_wait_until_ready_timeout(self, client: Gradient, respx_mock: Any) -> None: """Test that wait_until_ready raises timeout error.""" from gradient._exceptions import AgentDeploymentTimeoutError - + agent_uuid = "test-agent-id" - + # Mock always returns deploying respx_mock.get(f"/v2/gen-ai/agents/{agent_uuid}").mock( return_value=httpx.Response( @@ -423,10 +424,10 @@ def test_wait_until_ready_timeout(self, client: Gradient, respx_mock: Any) -> No }, ) ) - + with pytest.raises(AgentDeploymentTimeoutError) as exc_info: client.agents.wait_until_ready(agent_uuid, poll_interval=0.1, timeout=0.5) - + assert "did not reach STATUS_RUNNING within" in str(exc_info.value) assert exc_info.value.agent_id == agent_uuid @@ -434,9 +435,9 @@ def test_wait_until_ready_timeout(self, client: Gradient, respx_mock: Any) -> No def test_wait_until_ready_deployment_failed(self, client: Gradient, respx_mock: Any) -> None: """Test that wait_until_ready raises error on deployment failure.""" from gradient._exceptions import AgentDeploymentError - + agent_uuid = "test-agent-id" - + # Mock returns failed status respx_mock.get(f"/v2/gen-ai/agents/{agent_uuid}").mock( return_value=httpx.Response( @@ -449,10 +450,10 @@ def test_wait_until_ready_deployment_failed(self, client: Gradient, respx_mock: }, ) ) - + with pytest.raises(AgentDeploymentError) as exc_info: client.agents.wait_until_ready(agent_uuid, poll_interval=0.1, timeout=10.0) - + assert "deployment failed with status: STATUS_FAILED" in str(exc_info.value) assert exc_info.value.status == "STATUS_FAILED" @@ -810,9 +811,10 @@ async def test_path_params_update_status(self, async_client: AsyncGradient) -> N async def test_method_wait_until_ready(self, async_client: AsyncGradient, respx_mock: Any) -> None: """Test successful async wait_until_ready when agent becomes ready.""" agent_uuid = "test-agent-id" - + # Create side effect that returns different responses call_count = [0] + def get_response(_: httpx.Request) -> httpx.Response: call_count[0] += 1 if call_count[0] == 1: @@ -837,9 +839,9 @@ def get_response(_: httpx.Request) -> httpx.Response: } }, ) - + respx_mock.get(f"/v2/gen-ai/agents/{agent_uuid}").mock(side_effect=get_response) - + agent = await async_client.agents.wait_until_ready(agent_uuid, poll_interval=0.1, timeout=10.0) assert_matches_type(AgentRetrieveResponse, agent, path=["response"]) assert agent.agent is not None @@ -850,9 +852,9 @@ def get_response(_: httpx.Request) -> httpx.Response: async def test_wait_until_ready_timeout(self, async_client: AsyncGradient, respx_mock: Any) -> None: """Test that async wait_until_ready raises timeout error.""" from gradient._exceptions import AgentDeploymentTimeoutError - + agent_uuid = "test-agent-id" - + # Mock always returns deploying respx_mock.get(f"/v2/gen-ai/agents/{agent_uuid}").mock( return_value=httpx.Response( @@ -865,10 +867,10 @@ async def test_wait_until_ready_timeout(self, async_client: AsyncGradient, respx }, ) ) - + with pytest.raises(AgentDeploymentTimeoutError) as exc_info: await async_client.agents.wait_until_ready(agent_uuid, poll_interval=0.1, timeout=0.5) - + assert "did not reach STATUS_RUNNING within" in str(exc_info.value) assert exc_info.value.agent_id == agent_uuid @@ -876,9 +878,9 @@ async def test_wait_until_ready_timeout(self, async_client: AsyncGradient, respx async def test_wait_until_ready_deployment_failed(self, async_client: AsyncGradient, respx_mock: Any) -> None: """Test that async wait_until_ready raises error on deployment failure.""" from gradient._exceptions import AgentDeploymentError - + agent_uuid = "test-agent-id" - + # Mock returns failed status respx_mock.get(f"/v2/gen-ai/agents/{agent_uuid}").mock( return_value=httpx.Response( @@ -891,9 +893,9 @@ async def test_wait_until_ready_deployment_failed(self, async_client: AsyncGradi }, ) ) - + with pytest.raises(AgentDeploymentError) as exc_info: await async_client.agents.wait_until_ready(agent_uuid, poll_interval=0.1, timeout=10.0) - + assert "deployment failed with status: STATUS_FAILED" in str(exc_info.value) assert exc_info.value.status == "STATUS_FAILED"