diff --git a/docs/my-website/docs/vertex_batch_passthrough.md b/docs/my-website/docs/vertex_batch_passthrough.md new file mode 100644 index 000000000000..3203d7d792a7 --- /dev/null +++ b/docs/my-website/docs/vertex_batch_passthrough.md @@ -0,0 +1,160 @@ +import Tabs from '@theme/Tabs'; +import TabItem from '@theme/TabItem'; + +# /batchPredictionJobs + +LiteLLM supports Vertex AI batch prediction jobs through passthrough endpoints, allowing you to create and manage batch jobs directly through the proxy server. + +## Features + +- **Batch Job Creation**: Create batch prediction jobs using Vertex AI models +- **Cost Tracking**: Automatic cost calculation and usage tracking for batch operations +- **Status Monitoring**: Track job status and retrieve results +- **Model Support**: Works with all supported Vertex AI models (Gemini, Text Embedding) + +## Cost Tracking Support + +| Feature | Supported | Notes | +|---------|-----------|-------| +| Cost Tracking | ✅ | Automatic cost calculation for batch operations | +| Usage Monitoring | ✅ | Track token usage and costs across batch jobs | +| Logging | ✅ | Supported | + +## Quick Start + +1. **Configure your model** in the proxy configuration: + +```yaml +model_list: + - model_name: gemini-1.5-flash + litellm_params: + model: vertex_ai/gemini-1.5-flash + vertex_project: your-project-id + vertex_location: us-central1 + vertex_credentials: path/to/service-account.json +``` + +2. **Create a batch job**: + +```bash +curl -X POST "http://localhost:4000/v1/projects/your-project/locations/us-central1/batchPredictionJobs" \ + -H "Authorization: Bearer your-api-key" \ + -H "Content-Type: application/json" \ + -d '{ + "displayName": "my-batch-job", + "model": "projects/your-project/locations/us-central1/publishers/google/models/gemini-1.5-flash", + "inputConfig": { + "gcsSource": { + "uris": ["gs://my-bucket/input.jsonl"] + }, + "instancesFormat": "jsonl" + }, + "outputConfig": { + "gcsDestination": { + "outputUriPrefix": "gs://my-bucket/output/" + }, + "predictionsFormat": "jsonl" + } + }' +``` + +3. **Monitor job status**: + +```bash +curl -X GET "http://localhost:4000/v1/projects/your-project/locations/us-central1/batchPredictionJobs/job-id" \ + -H "Authorization: Bearer your-api-key" +``` + +## Model Configuration + +When configuring models for batch operations, use these naming conventions: + +- **`model_name`**: Base model name (e.g., `gemini-1.5-flash`) +- **`model`**: Full LiteLLM identifier (e.g., `vertex_ai/gemini-1.5-flash`) + +## Supported Models + +- `gemini-1.5-flash` / `vertex_ai/gemini-1.5-flash` +- `gemini-1.5-pro` / `vertex_ai/gemini-1.5-pro` +- `gemini-2.0-flash` / `vertex_ai/gemini-2.0-flash` +- `gemini-2.0-pro` / `vertex_ai/gemini-2.0-pro` + +## Advanced Usage + +### Batch Job with Custom Parameters + +```bash +curl -X POST "http://localhost:4000/v1/projects/your-project/locations/us-central1/batchPredictionJobs" \ + -H "Authorization: Bearer your-api-key" \ + -H "Content-Type: application/json" \ + -d '{ + "displayName": "advanced-batch-job", + "model": "projects/your-project/locations/us-central1/publishers/google/models/gemini-1.5-pro", + "inputConfig": { + "gcsSource": { + "uris": ["gs://my-bucket/advanced-input.jsonl"] + }, + "instancesFormat": "jsonl" + }, + "outputConfig": { + "gcsDestination": { + "outputUriPrefix": "gs://my-bucket/advanced-output/" + }, + "predictionsFormat": "jsonl" + }, + "labels": { + "environment": "production", + "team": "ml-engineering" + } + }' +``` + +### List All Batch Jobs + +```bash +curl -X GET "http://localhost:4000/v1/projects/your-project/locations/us-central1/batchPredictionJobs" \ + -H "Authorization: Bearer your-api-key" +``` + +### Cancel a Batch Job + +```bash +curl -X POST "http://localhost:4000/v1/projects/your-project/locations/us-central1/batchPredictionJobs/job-id:cancel" \ + -H "Authorization: Bearer your-api-key" +``` + +## Cost Tracking Details + +LiteLLM provides comprehensive cost tracking for Vertex AI batch operations: + +- **Token Usage**: Tracks input and output tokens for each batch request +- **Cost Calculation**: Automatically calculates costs based on current Vertex AI pricing +- **Usage Aggregation**: Aggregates costs across all requests in a batch job +- **Real-time Monitoring**: Monitor costs as batch jobs progress + +The cost tracking works seamlessly with the `generateContent` API and provides detailed insights into your batch processing expenses. + +## Error Handling + +Common error scenarios and their solutions: + +| Error | Description | Solution | +|-------|-------------|----------| +| `INVALID_ARGUMENT` | Invalid model or configuration | Verify model name and project settings | +| `PERMISSION_DENIED` | Insufficient permissions | Check Vertex AI IAM roles | +| `RESOURCE_EXHAUSTED` | Quota exceeded | Check Vertex AI quotas and limits | +| `NOT_FOUND` | Job or resource not found | Verify job ID and project configuration | + +## Best Practices + +1. **Use appropriate batch sizes**: Balance between processing efficiency and resource usage +2. **Monitor job status**: Regularly check job status to handle failures promptly +3. **Set up alerts**: Configure monitoring for job completion and failures +4. **Optimize costs**: Use cost tracking to identify optimization opportunities +5. **Test with small batches**: Validate your setup with small test batches first + +## Related Documentation + +- [Vertex AI Provider Documentation](./vertex.md) +- [General Batches API Documentation](../batches.md) +- [Cost Tracking and Monitoring](../observability/telemetry.md) diff --git a/enterprise/litellm_enterprise/proxy/common_utils/check_batch_cost.py b/enterprise/litellm_enterprise/proxy/common_utils/check_batch_cost.py index 4b1bb024ac6a..d4ee4042b1ae 100644 --- a/enterprise/litellm_enterprise/proxy/common_utils/check_batch_cost.py +++ b/enterprise/litellm_enterprise/proxy/common_utils/check_batch_cost.py @@ -57,7 +57,6 @@ async def check_batch_cost(self): "file_purpose": "batch", } ) - completed_jobs = [] for job in jobs: @@ -139,7 +138,7 @@ async def check_batch_cost(self): custom_llm_provider = deployment_info.litellm_params.custom_llm_provider litellm_model_name = deployment_info.litellm_params.model - _, llm_provider, _, _ = get_llm_provider( + model_name, llm_provider, _, _ = get_llm_provider( model=litellm_model_name, custom_llm_provider=custom_llm_provider, ) @@ -148,9 +147,9 @@ async def check_batch_cost(self): await calculate_batch_cost_and_usage( file_content_dictionary=file_content_as_dict, custom_llm_provider=llm_provider, # type: ignore + model_name=model_name, ) ) - logging_obj = LiteLLMLogging( model=batch_models[0], messages=[{"role": "user", "content": ""}], diff --git a/litellm/batches/batch_utils.py b/litellm/batches/batch_utils.py index 814851e560b5..027c0b219c88 100644 --- a/litellm/batches/batch_utils.py +++ b/litellm/batches/batch_utils.py @@ -1,5 +1,5 @@ import json -from typing import Any, List, Literal, Tuple +from typing import Any, List, Literal, Tuple, Optional import litellm from litellm._logging import verbose_logger @@ -10,21 +10,22 @@ async def calculate_batch_cost_and_usage( file_content_dictionary: List[dict], custom_llm_provider: Literal["openai", "azure", "vertex_ai"], + model_name: Optional[str] = None, ) -> Tuple[float, Usage, List[str]]: """ Calculate the cost and usage of a batch """ - # Calculate costs and usage batch_cost = _batch_cost_calculator( custom_llm_provider=custom_llm_provider, file_content_dictionary=file_content_dictionary, + model_name=model_name, ) batch_usage = _get_batch_job_total_usage_from_file_content( file_content_dictionary=file_content_dictionary, custom_llm_provider=custom_llm_provider, + model_name=model_name, ) - - batch_models = _get_batch_models_from_file_content(file_content_dictionary) + batch_models = _get_batch_models_from_file_content(file_content_dictionary, model_name) return batch_cost, batch_usage, batch_models @@ -32,6 +33,7 @@ async def calculate_batch_cost_and_usage( async def _handle_completed_batch( batch: Batch, custom_llm_provider: Literal["openai", "azure", "vertex_ai"], + model_name: Optional[str] = None, ) -> Tuple[float, Usage, List[str]]: """Helper function to process a completed batch and handle logging""" # Get batch results @@ -43,23 +45,28 @@ async def _handle_completed_batch( batch_cost = _batch_cost_calculator( custom_llm_provider=custom_llm_provider, file_content_dictionary=file_content_dictionary, + model_name=model_name, ) batch_usage = _get_batch_job_total_usage_from_file_content( file_content_dictionary=file_content_dictionary, custom_llm_provider=custom_llm_provider, + model_name=model_name, ) - batch_models = _get_batch_models_from_file_content(file_content_dictionary) + batch_models = _get_batch_models_from_file_content(file_content_dictionary, model_name) return batch_cost, batch_usage, batch_models def _get_batch_models_from_file_content( file_content_dictionary: List[dict], + model_name: Optional[str] = None, ) -> List[str]: """ Get the models from the file content """ + if model_name: + return [model_name] batch_models = [] for _item in file_content_dictionary: if _batch_response_was_successful(_item): @@ -73,12 +80,18 @@ def _get_batch_models_from_file_content( def _batch_cost_calculator( file_content_dictionary: List[dict], custom_llm_provider: Literal["openai", "azure", "vertex_ai"] = "openai", + model_name: Optional[str] = None, ) -> float: """ Calculate the cost of a batch based on the output file id """ - if custom_llm_provider == "vertex_ai": - raise ValueError("Vertex AI does not support file content retrieval") + # Handle Vertex AI with specialized method + if custom_llm_provider == "vertex_ai" and model_name: + batch_cost, _ = calculate_vertex_ai_batch_cost_and_usage(file_content_dictionary, model_name) + verbose_logger.debug("vertex_ai_total_cost=%s", batch_cost) + return batch_cost + + # For other providers, use the existing logic total_cost = _get_batch_job_cost_from_file_content( file_content_dictionary=file_content_dictionary, custom_llm_provider=custom_llm_provider, @@ -87,6 +100,87 @@ def _batch_cost_calculator( return total_cost +def calculate_vertex_ai_batch_cost_and_usage( + vertex_ai_batch_responses: List[dict], + model_name: Optional[str] = None, +) -> Tuple[float, Usage]: + """ + Calculate both cost and usage from Vertex AI batch responses + """ + total_cost = 0.0 + total_tokens = 0 + prompt_tokens = 0 + completion_tokens = 0 + + for response in vertex_ai_batch_responses: + if response.get("status") == "JOB_STATE_SUCCEEDED": # Check if response was successful + # Transform Vertex AI response to OpenAI format if needed + from litellm.llms.vertex_ai.gemini.vertex_and_google_ai_studio_gemini import VertexGeminiConfig + from litellm import ModelResponse + from litellm.litellm_core_utils.litellm_logging import Logging + from litellm.types.utils import CallTypes + from litellm._uuid import uuid + import httpx + import time + + # Create required arguments for the transformation method + model_response = ModelResponse() + + # Ensure model_name is not None + actual_model_name = model_name or "gemini-2.5-flash" + + # Create a real LiteLLM logging object + logging_obj = Logging( + model=actual_model_name, + messages=[{"role": "user", "content": "batch_request"}], + stream=False, + call_type=CallTypes.aretrieve_batch, + start_time=time.time(), + litellm_call_id="batch_" + str(uuid.uuid4()), + function_id="batch_processing", + litellm_trace_id=str(uuid.uuid4()), + kwargs={"optional_params": {}} + ) + + # Add the optional_params attribute that the Vertex AI transformation expects + logging_obj.optional_params = {} + raw_response = httpx.Response(200) # Mock response object + + openai_format_response = VertexGeminiConfig()._transform_google_generate_content_to_openai_model_response( + completion_response=response["response"], + model_response=model_response, + model=actual_model_name, + logging_obj=logging_obj, + raw_response=raw_response, + ) + + # Calculate cost using existing function + cost = litellm.completion_cost( + completion_response=openai_format_response, + custom_llm_provider="vertex_ai", + call_type=CallTypes.aretrieve_batch.value, + ) + total_cost += cost + + # Extract usage from the transformed response + if hasattr(openai_format_response, 'usage') and openai_format_response.usage: + usage = openai_format_response.usage + else: + # Fallback: create usage from response dict + response_dict = openai_format_response.dict() if hasattr(openai_format_response, 'dict') else {} + usage = _get_batch_job_usage_from_response_body(response_dict) + + total_tokens += usage.total_tokens + prompt_tokens += usage.prompt_tokens + completion_tokens += usage.completion_tokens + + return total_cost, Usage( + total_tokens=total_tokens, + prompt_tokens=prompt_tokens, + completion_tokens=completion_tokens, + ) + + async def _get_batch_output_file_content_as_dictionary( batch: Batch, custom_llm_provider: Literal["openai", "azure", "vertex_ai"] = "openai", @@ -157,10 +251,17 @@ def _get_batch_job_cost_from_file_content( def _get_batch_job_total_usage_from_file_content( file_content_dictionary: List[dict], custom_llm_provider: Literal["openai", "azure", "vertex_ai"] = "openai", + model_name: Optional[str] = None, ) -> Usage: """ Get the tokens of a batch job from the file content """ + # Handle Vertex AI with specialized method + if custom_llm_provider == "vertex_ai" and model_name: + _, batch_usage = calculate_vertex_ai_batch_cost_and_usage(file_content_dictionary, model_name) + return batch_usage + + # For other providers, use the existing logic total_tokens: int = 0 prompt_tokens: int = 0 completion_tokens: int = 0 diff --git a/litellm/proxy/pass_through_endpoints/llm_provider_handlers/vertex_passthrough_logging_handler.py b/litellm/proxy/pass_through_endpoints/llm_provider_handlers/vertex_passthrough_logging_handler.py index 5b22b2746c9a..c7695704330a 100644 --- a/litellm/proxy/pass_through_endpoints/llm_provider_handlers/vertex_passthrough_logging_handler.py +++ b/litellm/proxy/pass_through_endpoints/llm_provider_handlers/vertex_passthrough_logging_handler.py @@ -18,14 +18,19 @@ ImageResponse, ModelResponse, TextCompletionResponse, + Choices, ) +from litellm.types.utils import SpecialEnums if TYPE_CHECKING: from ..success_handler import PassThroughEndpointLogging - from ..types import EndpointType + from litellm.types.utils import LiteLLMBatch else: PassThroughEndpointLogging = Any - EndpointType = Any + LiteLLMBatch = Any + +# Define EndpointType locally to avoid import issues +EndpointType = Any class VertexPassthroughLoggingHandler: @@ -204,6 +209,17 @@ def vertex_passthrough_handler( "result": litellm_prediction_response, "kwargs": kwargs, } + elif "batchPredictionJobs" in url_route: + return VertexPassthroughLoggingHandler.batch_prediction_jobs_handler( + httpx_response=httpx_response, + logging_obj=logging_obj, + url_route=url_route, + result=result, + start_time=start_time, + end_time=end_time, + cache_hit=cache_hit, + **kwargs, + ) else: return { "result": None, @@ -324,6 +340,38 @@ def extract_model_from_url(url: str) -> str: return match.group(1) return "unknown" + @staticmethod + def extract_model_name_from_vertex_path(vertex_model_path: str) -> str: + """ + Extract the actual model name from a Vertex AI model path. + + Examples: + - publishers/google/models/gemini-2.5-flash -> gemini-2.5-flash + - projects/PROJECT_ID/locations/LOCATION/models/MODEL_ID -> MODEL_ID + + Args: + vertex_model_path: The full Vertex AI model path + + Returns: + The extracted model name for use with LiteLLM + """ + # Handle publishers/google/models/ format + if "publishers/" in vertex_model_path and "models/" in vertex_model_path: + # Extract everything after the last models/ + parts = vertex_model_path.split("models/") + if len(parts) > 1: + return parts[-1] + + # Handle projects/PROJECT_ID/locations/LOCATION/models/MODEL_ID format + elif "projects/" in vertex_model_path and "models/" in vertex_model_path: + # Extract everything after the last models/ + parts = vertex_model_path.split("models/") + if len(parts) > 1: + return parts[-1] + + # If no recognized pattern, return the original path + return vertex_model_path + @staticmethod def _get_vertex_publisher_or_api_spec_from_url(url: str) -> Optional[str]: # Check for specific Vertex AI partner publishers @@ -388,7 +436,7 @@ def _create_vertex_response_logging_payload_for_generate_content( end_time: datetime, logging_obj: LiteLLMLoggingObj, custom_llm_provider: str, - ): + ) -> dict: """ Create the standard logging object for Vertex passthrough generateContent (streaming and non-streaming) @@ -412,3 +460,255 @@ def _create_vertex_response_logging_payload_for_generate_content( logging_obj.model_call_details["model"] = logging_obj.model logging_obj.model_call_details["custom_llm_provider"] = custom_llm_provider return kwargs + + @staticmethod + def batch_prediction_jobs_handler( + httpx_response: httpx.Response, + logging_obj: LiteLLMLoggingObj, + url_route: str, + result: str, + start_time: datetime, + end_time: datetime, + cache_hit: bool, + **kwargs, + ) -> PassThroughEndpointLoggingTypedDict: + """ + Handle batch prediction jobs passthrough logging. + Creates a managed object for cost tracking when batch job is successfully created. + """ + from litellm.llms.vertex_ai.batches.transformation import VertexAIBatchTransformation + from litellm._uuid import uuid + import base64 + + try: + _json_response = httpx_response.json() + + # Only handle successful batch job creation (POST requests) + if httpx_response.status_code == 200 and "name" in _json_response: + # Transform Vertex AI response to LiteLLM batch format + litellm_batch_response = VertexAIBatchTransformation.transform_vertex_ai_batch_response_to_openai_batch_response( + response=_json_response + ) + + # Extract batch ID and model from the response + batch_id = VertexAIBatchTransformation._get_batch_id_from_vertex_ai_batch_response(_json_response) + model_name = _json_response.get("model", "unknown") + + # Create unified object ID for tracking + # Format: base64(litellm_proxy;model_id:{};llm_batch_id:{}) + actual_model_id = VertexPassthroughLoggingHandler.get_actual_model_id_from_router(model_name) + + unified_id_string = SpecialEnums.LITELLM_MANAGED_BATCH_COMPLETE_STR.value.format(actual_model_id, batch_id) + unified_object_id = base64.urlsafe_b64encode(unified_id_string.encode()).decode().rstrip("=") + + # Store the managed object for cost tracking + # This will be picked up by check_batch_cost polling mechanism + VertexPassthroughLoggingHandler._store_batch_managed_object( + unified_object_id=unified_object_id, + batch_object=litellm_batch_response, + model_object_id=batch_id, + logging_obj=logging_obj, + **kwargs, + ) + + # Create a batch job response for logging + litellm_model_response = ModelResponse() + litellm_model_response.id = str(uuid.uuid4()) + litellm_model_response.model = model_name + litellm_model_response.object = "batch_prediction_job" + litellm_model_response.created = int(start_time.timestamp()) + + # Add batch-specific metadata to indicate this is a pending batch job + litellm_model_response.choices = [Choices( + finish_reason="batch_pending", + index=0, + message={ + "role": "assistant", + "content": f"Batch prediction job {batch_id} created and is pending. Status will be updated when the batch completes.", + "tool_calls": None, + "function_call": None, + "provider_specific_fields": { + "batch_job_id": batch_id, + "batch_job_state": "JOB_STATE_PENDING", + "unified_object_id": unified_object_id + } + } + )] + + # Set response cost to 0 initially (will be updated when batch completes) + response_cost = 0.0 + kwargs["response_cost"] = response_cost + kwargs["model"] = model_name + kwargs["batch_id"] = batch_id + kwargs["unified_object_id"] = unified_object_id + kwargs["batch_job_state"] = "JOB_STATE_PENDING" + + logging_obj.model = model_name + logging_obj.model_call_details["model"] = logging_obj.model + logging_obj.model_call_details["response_cost"] = response_cost + logging_obj.model_call_details["batch_id"] = batch_id + + return { + "result": litellm_model_response, + "kwargs": kwargs, + } + else: + # Handle non-successful responses + litellm_model_response = ModelResponse() + litellm_model_response.id = str(uuid.uuid4()) + litellm_model_response.model = "vertex_ai_batch" + litellm_model_response.object = "batch_prediction_job" + litellm_model_response.created = int(start_time.timestamp()) + + # Add error-specific metadata + litellm_model_response.choices = [Choices( + finish_reason="batch_error", + index=0, + message={ + "role": "assistant", + "content": f"Batch prediction job creation failed. Status: {httpx_response.status_code}", + "tool_calls": None, + "function_call": None, + "provider_specific_fields": { + "batch_job_state": "JOB_STATE_FAILED", + "status_code": httpx_response.status_code + } + } + )] + + kwargs["response_cost"] = 0.0 + kwargs["model"] = "vertex_ai_batch" + kwargs["batch_job_state"] = "JOB_STATE_FAILED" + + return { + "result": litellm_model_response, + "kwargs": kwargs, + } + + except Exception as e: + verbose_proxy_logger.error(f"Error in batch_prediction_jobs_handler: {e}") + # Return basic response on error + litellm_model_response = ModelResponse() + litellm_model_response.id = str(uuid.uuid4()) + litellm_model_response.model = "vertex_ai_batch" + litellm_model_response.object = "batch_prediction_job" + litellm_model_response.created = int(start_time.timestamp()) + + # Add error-specific metadata + litellm_model_response.choices = [Choices( + finish_reason="batch_error", + index=0, + message={ + "role": "assistant", + "content": f"Error creating batch prediction job: {str(e)}", + "tool_calls": None, + "function_call": None, + "provider_specific_fields": { + "batch_job_state": "JOB_STATE_FAILED", + "error": str(e) + } + } + )] + + kwargs["response_cost"] = 0.0 + kwargs["model"] = "vertex_ai_batch" + kwargs["batch_job_state"] = "JOB_STATE_FAILED" + + return { + "result": litellm_model_response, + "kwargs": kwargs, + } + + @staticmethod + def _store_batch_managed_object( + unified_object_id: str, + batch_object: LiteLLMBatch, + model_object_id: str, + logging_obj: LiteLLMLoggingObj, + **kwargs, + ) -> None: + """ + Store batch managed object for cost tracking. + This will be picked up by the check_batch_cost polling mechanism. + """ + try: + # Get the managed files hook from the logging object + # This is a bit of a hack, but we need access to the proxy logging system + from litellm.proxy.proxy_server import proxy_logging_obj + + managed_files_hook = proxy_logging_obj.get_proxy_hook("managed_files") + if managed_files_hook is not None and hasattr(managed_files_hook, 'store_unified_object_id'): + # Create a mock user API key dict for the managed object storage + from litellm.proxy._types import UserAPIKeyAuth + from litellm.proxy._types import LitellmUserRoles + user_api_key_dict = UserAPIKeyAuth( + user_id=kwargs.get("user_id", "default-user"), + api_key="", + team_id=None, + team_alias=None, + user_role=LitellmUserRoles.CUSTOMER, # Use proper enum value + user_email=None, + max_budget=None, + spend=0.0, # Set to 0.0 instead of None + models=[], # Set to empty list instead of None + tpm_limit=None, + rpm_limit=None, + budget_duration=None, + budget_reset_at=None, + max_parallel_requests=None, + allowed_model_region=None, + metadata={}, # Set to empty dict instead of None + key_alias=None, + permissions={}, # Set to empty dict instead of None + model_max_budget={}, # Set to empty dict instead of None + model_spend={}, # Set to empty dict instead of None + ) + + # Store the unified object for batch cost tracking + import asyncio + asyncio.create_task( + managed_files_hook.store_unified_object_id( + unified_object_id=unified_object_id, + file_object=batch_object, + litellm_parent_otel_span=None, + model_object_id=model_object_id, + file_purpose="batch", + user_api_key_dict=user_api_key_dict, + ) + ) + + verbose_proxy_logger.info( + f"Stored batch managed object with unified_object_id={unified_object_id}, batch_id={model_object_id}" + ) + else: + verbose_proxy_logger.warning("Managed files hook not available, cannot store batch object for cost tracking") + + except Exception as e: + verbose_proxy_logger.error(f"Error storing batch managed object: {e}") + + @staticmethod + def get_actual_model_id_from_router(model_name: str) -> str: + from litellm.proxy.proxy_server import llm_router + + if llm_router is not None: + # Try to find the model in the router by the extracted model name + extracted_model_name = VertexPassthroughLoggingHandler.extract_model_name_from_vertex_path(model_name) + + # Use the existing get_model_ids method from router + model_ids = llm_router.get_model_ids(model_name=extracted_model_name) + if model_ids and len(model_ids) > 0: + # Use the first model ID found + actual_model_id = model_ids[0] + verbose_proxy_logger.info(f"Found model ID in router: {actual_model_id}") + return actual_model_id + else: + # Fallback to constructed model name + actual_model_id = extracted_model_name + verbose_proxy_logger.warning(f"Model not found in router, using constructed name: {actual_model_id}") + return actual_model_id + else: + # Fallback if router is not available + extracted_model_name = VertexPassthroughLoggingHandler.extract_model_name_from_vertex_path(model_name) + verbose_proxy_logger.warning(f"Router not available, using constructed model name: {extracted_model_name}") + return extracted_model_name + diff --git a/litellm/proxy/pass_through_endpoints/success_handler.py b/litellm/proxy/pass_through_endpoints/success_handler.py index a819c429f103..0c0ade848eaa 100644 --- a/litellm/proxy/pass_through_endpoints/success_handler.py +++ b/litellm/proxy/pass_through_endpoints/success_handler.py @@ -40,6 +40,7 @@ def __init__(self): "predict", "rawPredict", "streamRawPredict", + "batchPredictionJobs", ] # Anthropic diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py index 1360433e3213..696edfb443a7 100644 --- a/litellm/proxy/proxy_server.py +++ b/litellm/proxy/proxy_server.py @@ -4079,8 +4079,12 @@ async def initialize_scheduled_background_jobs( "interval", seconds=proxy_batch_polling_interval, # these can run infrequently, as batch jobs take time to complete ) + verbose_proxy_logger.info("Batch cost check job scheduled successfully") - except Exception: + except Exception as e: + verbose_proxy_logger.error( + f"Failed to setup batch cost checking: {e}" + ) verbose_proxy_logger.debug( "Checking batch cost for LiteLLM Managed Files is an Enterprise Feature. Skipping..." ) diff --git a/tests/test_litellm/proxy/pass_through_endpoints/test_vertex_ai_batch_passthrough.py b/tests/test_litellm/proxy/pass_through_endpoints/test_vertex_ai_batch_passthrough.py new file mode 100644 index 000000000000..68c6bf98cb27 --- /dev/null +++ b/tests/test_litellm/proxy/pass_through_endpoints/test_vertex_ai_batch_passthrough.py @@ -0,0 +1,539 @@ +""" +Test cases for Vertex AI passthrough batch prediction functionality +""" +import base64 +import json +import pytest +from unittest.mock import Mock, patch, MagicMock +from datetime import datetime +from typing import Dict, Any + +from litellm.proxy.pass_through_endpoints.llm_provider_handlers.vertex_passthrough_logging_handler import ( + VertexPassthroughLoggingHandler, +) +from litellm.types.utils import SpecialEnums +from litellm.types.llms.openai import BatchJobStatus + + +class TestVertexAIBatchPassthroughHandler: + """Test cases for Vertex AI batch prediction passthrough functionality""" + + @pytest.fixture + def mock_httpx_response(self): + """Mock httpx response for batch job creation""" + mock_response = Mock() + mock_response.status_code = 200 + mock_response.json.return_value = { + "name": "projects/test-project/locations/us-central1/batchPredictionJobs/123456789", + "displayName": "litellm-vertex-batch-test", + "model": "projects/test-project/locations/us-central1/publishers/google/models/gemini-1.5-flash", + "createTime": "2024-01-01T00:00:00Z", + "state": "JOB_STATE_PENDING", + "inputConfig": { + "gcsSource": { + "uris": ["gs://test-bucket/input.jsonl"] + }, + "instancesFormat": "jsonl" + }, + "outputConfig": { + "gcsDestination": { + "outputUriPrefix": "gs://test-bucket/output/" + }, + "predictionsFormat": "jsonl" + } + } + return mock_response + + @pytest.fixture + def mock_logging_obj(self): + """Mock logging object""" + mock = Mock() + mock.litellm_call_id = "test-call-id-123" + mock.model_call_details = {} + mock.optional_params = {} + return mock + + @pytest.fixture + def mock_managed_files_hook(self): + """Mock managed files hook""" + mock_hook = Mock() + mock_hook.afile_content.return_value = Mock(content=b'{"test": "data"}') + return mock_hook + + def test_batch_prediction_jobs_handler_success(self, mock_httpx_response, mock_logging_obj): + """Test successful batch job creation and tracking""" + with patch('litellm.proxy.pass_through_endpoints.llm_provider_handlers.vertex_passthrough_logging_handler.verbose_proxy_logger') as mock_logger: + with patch('litellm.proxy.pass_through_endpoints.llm_provider_handlers.vertex_passthrough_logging_handler.VertexPassthroughLoggingHandler.get_actual_model_id_from_router') as mock_get_model_id: + with patch('litellm.proxy.pass_through_endpoints.llm_provider_handlers.vertex_passthrough_logging_handler.VertexPassthroughLoggingHandler._store_batch_managed_object') as mock_store: + with patch('litellm.llms.vertex_ai.batches.transformation.VertexAIBatchTransformation') as mock_transformation: + + # Setup mocks + mock_get_model_id.return_value = "vertex_ai/gemini-1.5-flash" + mock_transformation.transform_vertex_ai_batch_response_to_openai_batch_response.return_value = { + "id": "123456789", + "object": "batch", + "status": "validating", + "created_at": 1704067200, + "input_file_id": "file-123", + "output_file_id": "file-456", + "error_file_id": None, + "completion_window": "24hrs" + } + mock_transformation._get_batch_id_from_vertex_ai_batch_response.return_value = "123456789" + + # Test the handler + result = VertexPassthroughLoggingHandler.batch_prediction_jobs_handler( + httpx_response=mock_httpx_response, + logging_obj=mock_logging_obj, + url_route="/v1/projects/test-project/locations/us-central1/batchPredictionJobs", + result="success", + start_time=datetime.now(), + end_time=datetime.now(), + cache_hit=False, + user_api_key_dict={"user_id": "test-user"} + ) + + # Verify the result + assert result is not None + assert "kwargs" in result + assert result["kwargs"]["model"] == "projects/test-project/locations/us-central1/publishers/google/models/gemini-1.5-flash" + assert result["kwargs"]["batch_id"] == "123456789" + + # Verify mocks were called + mock_get_model_id.assert_called_once() + mock_store.assert_called_once() + + def test_batch_prediction_jobs_handler_failure(self, mock_logging_obj): + """Test batch job creation failure handling""" + # Mock failed response + mock_httpx_response = Mock() + mock_httpx_response.status_code = 400 + mock_httpx_response.json.return_value = {"error": "Invalid request"} + + with patch('litellm.proxy.pass_through_endpoints.llm_provider_handlers.vertex_passthrough_logging_handler.verbose_proxy_logger') as mock_logger: + # Test the handler with failed response + result = VertexPassthroughLoggingHandler.batch_prediction_jobs_handler( + httpx_response=mock_httpx_response, + logging_obj=mock_logging_obj, + url_route="/v1/projects/test-project/locations/us-central1/batchPredictionJobs", + result="error", + start_time=datetime.now(), + end_time=datetime.now(), + cache_hit=False, + user_api_key_dict={"user_id": "test-user"} + ) + + # Should return a structured response for failed responses + assert result is not None + assert "result" in result + assert "kwargs" in result + assert result["result"].choices[0].finish_reason == "batch_error" + assert result["kwargs"]["batch_job_state"] == "JOB_STATE_FAILED" + + def test_get_actual_model_id_from_router_with_router(self): + """Test getting model ID when router is available""" + with patch('litellm.proxy.proxy_server.llm_router') as mock_router: + with patch('litellm.proxy.pass_through_endpoints.llm_provider_handlers.vertex_passthrough_logging_handler.VertexPassthroughLoggingHandler.extract_model_name_from_vertex_path') as mock_extract: + + # Setup mocks + mock_router.get_model_ids.return_value = ["vertex_ai/gemini-1.5-flash"] + mock_extract.return_value = "gemini-1.5-flash" + + # Test the method + result = VertexPassthroughLoggingHandler.get_actual_model_id_from_router( + "projects/test-project/locations/us-central1/publishers/google/models/gemini-1.5-flash" + ) + + # Verify result + assert result == "vertex_ai/gemini-1.5-flash" + mock_router.get_model_ids.assert_called_once_with(model_name="gemini-1.5-flash") + + def test_get_actual_model_id_from_router_without_router(self): + """Test getting model ID when router is not available""" + with patch('litellm.proxy.proxy_server.llm_router', None): + with patch('litellm.proxy.pass_through_endpoints.llm_provider_handlers.vertex_passthrough_logging_handler.VertexPassthroughLoggingHandler.extract_model_name_from_vertex_path') as mock_extract: + + # Setup mocks + mock_extract.return_value = "gemini-1.5-flash" + + # Test the method + result = VertexPassthroughLoggingHandler.get_actual_model_id_from_router( + "projects/test-project/locations/us-central1/publishers/google/models/gemini-1.5-flash" + ) + + # Verify result + assert result == "gemini-1.5-flash" + + def test_get_actual_model_id_from_router_model_not_found(self): + """Test getting model ID when model is not found in router""" + with patch('litellm.proxy.proxy_server.llm_router') as mock_router: + with patch('litellm.proxy.pass_through_endpoints.llm_provider_handlers.vertex_passthrough_logging_handler.VertexPassthroughLoggingHandler.extract_model_name_from_vertex_path') as mock_extract: + + # Setup mocks - router returns empty list + mock_router.get_model_ids.return_value = [] + mock_extract.return_value = "gemini-1.5-flash" + + # Test the method + result = VertexPassthroughLoggingHandler.get_actual_model_id_from_router( + "projects/test-project/locations/us-central1/publishers/google/models/gemini-1.5-flash" + ) + + # Verify result - should fallback to extracted model name + assert result == "gemini-1.5-flash" + + def test_unified_object_id_generation(self): + """Test unified object ID generation for batch tracking""" + model_id = "vertex_ai/gemini-1.5-flash" + batch_id = "123456789" + + # Generate the expected unified ID + unified_id_string = SpecialEnums.LITELLM_MANAGED_BATCH_COMPLETE_STR.value.format(model_id, batch_id) + expected_unified_id = base64.urlsafe_b64encode(unified_id_string.encode()).decode().rstrip("=") + + # Test the generation + actual_unified_id = base64.urlsafe_b64encode(unified_id_string.encode()).decode().rstrip("=") + + assert actual_unified_id == expected_unified_id + assert isinstance(actual_unified_id, str) + assert len(actual_unified_id) > 0 + + def test_store_batch_managed_object(self, mock_logging_obj, mock_managed_files_hook): + """Test storing batch managed object for cost tracking""" + with patch('litellm.proxy.proxy_server.proxy_logging_obj') as mock_proxy_logging_obj: + with patch('litellm.proxy.pass_through_endpoints.llm_provider_handlers.vertex_passthrough_logging_handler.verbose_proxy_logger') as mock_logger: + + # Setup mock proxy logging obj + mock_proxy_logging_obj.get_proxy_hook.return_value = mock_managed_files_hook + + # Test data + unified_object_id = "test-unified-id" + batch_object = { + "id": "123456789", + "object": "batch", + "status": "validating" + } + model_object_id = "123456789" + + # Test the method + VertexPassthroughLoggingHandler._store_batch_managed_object( + unified_object_id=unified_object_id, + batch_object=batch_object, + model_object_id=model_object_id, + logging_obj=mock_logging_obj, + user_api_key_dict={"user_id": "test-user"} + ) + + # Verify the managed files hook was called + mock_managed_files_hook.store_unified_object_id.assert_called_once() + + def test_batch_cost_calculation_integration(self): + """Test integration with batch cost calculation""" + from litellm.batches.batch_utils import calculate_vertex_ai_batch_cost_and_usage + + # Mock Vertex AI batch responses + vertex_ai_batch_responses = [ + { + "status": "JOB_STATE_SUCCEEDED", + "response": { + "candidates": [ + { + "content": { + "parts": [ + {"text": "Hello, world!"} + ] + } + } + ], + "usageMetadata": { + "promptTokenCount": 10, + "candidatesTokenCount": 5, + "totalTokenCount": 15 + } + } + } + ] + + with patch('litellm.llms.vertex_ai.gemini.vertex_and_google_ai_studio_gemini.VertexGeminiConfig') as mock_config: + with patch('litellm.completion_cost') as mock_completion_cost: + + # Setup mocks + mock_config.return_value._transform_google_generate_content_to_openai_model_response.return_value = Mock( + usage=Mock(total_tokens=15, prompt_tokens=10, completion_tokens=5) + ) + mock_completion_cost.return_value = 0.001 + + # Test the cost calculation + total_cost, usage = calculate_vertex_ai_batch_cost_and_usage( + vertex_ai_batch_responses, + model_name="gemini-1.5-flash" + ) + + # Verify results + assert total_cost == 0.001 + assert usage.total_tokens == 15 + assert usage.prompt_tokens == 10 + assert usage.completion_tokens == 5 + + def test_batch_response_transformation(self): + """Test transformation of Vertex AI batch responses to OpenAI format""" + from litellm.llms.vertex_ai.batches.transformation import VertexAIBatchTransformation + + # Mock Vertex AI batch response + vertex_ai_response = { + "name": "projects/test-project/locations/us-central1/batchPredictionJobs/123456789", + "displayName": "test-batch", + "model": "projects/test-project/locations/us-central1/publishers/google/models/gemini-1.5-flash", + "createTime": "2024-01-01T00:00:00.000Z", + "state": "JOB_STATE_SUCCEEDED" + } + + # Test transformation + result = VertexAIBatchTransformation.transform_vertex_ai_batch_response_to_openai_batch_response( + vertex_ai_response + ) + + # Verify the transformation + assert result["id"] == "123456789" + assert result["object"] == "batch" + assert result["status"] == "completed" # JOB_STATE_SUCCEEDED should map to completed + + def test_batch_id_extraction(self): + """Test extraction of batch ID from Vertex AI response""" + from litellm.llms.vertex_ai.batches.transformation import VertexAIBatchTransformation + + # Test various batch ID formats + test_cases = [ + "projects/123/locations/us-central1/batchPredictionJobs/456789", + "projects/abc/locations/europe-west1/batchPredictionJobs/def123", + "batchPredictionJobs/999", + "invalid-format" + ] + + expected_results = ["456789", "def123", "999", "invalid-format"] + + for test_case, expected in zip(test_cases, expected_results): + result = VertexAIBatchTransformation._get_batch_id_from_vertex_ai_batch_response( + {"name": test_case} + ) + assert result == expected + + def test_model_name_extraction_from_vertex_path(self): + """Test extraction of model name from Vertex AI path""" + from litellm.proxy.pass_through_endpoints.llm_provider_handlers.vertex_passthrough_logging_handler import ( + VertexPassthroughLoggingHandler + ) + + # Test various model path formats + test_cases = [ + "projects/test-project/locations/us-central1/publishers/google/models/gemini-1.5-flash", + "projects/abc/locations/europe-west1/publishers/google/models/gemini-2.0-flash", + "publishers/google/models/gemini-pro", + "invalid-path" + ] + + expected_results = ["gemini-1.5-flash", "gemini-2.0-flash", "gemini-pro", "invalid-path"] + + for test_case, expected in zip(test_cases, expected_results): + result = VertexPassthroughLoggingHandler.extract_model_name_from_vertex_path(test_case) + assert result == expected + + @pytest.mark.asyncio + async def test_batch_completion_workflow(self, mock_httpx_response, mock_logging_obj, mock_managed_files_hook): + """Test the complete batch completion workflow""" + with patch('litellm.proxy.pass_through_endpoints.llm_provider_handlers.vertex_passthrough_logging_handler.verbose_proxy_logger') as mock_logger: + with patch('litellm.proxy.pass_through_endpoints.llm_provider_handlers.vertex_passthrough_logging_handler.VertexPassthroughLoggingHandler.get_actual_model_id_from_router') as mock_get_model_id: + with patch('litellm.proxy.proxy_server.proxy_logging_obj') as mock_proxy_logging_obj: + mock_proxy_logging_obj.get_proxy_hook.return_value = mock_managed_files_hook + with patch('litellm.llms.vertex_ai.batches.transformation.VertexAIBatchTransformation') as mock_transformation: + + # Setup mocks + mock_get_model_id.return_value = "vertex_ai/gemini-1.5-flash" + mock_transformation.transform_vertex_ai_batch_response_to_openai_batch_response.return_value = { + "id": "123456789", + "object": "batch", + "status": "completed", + "created_at": 1704067200, + "input_file_id": "file-123", + "output_file_id": "file-456", + "error_file_id": None, + "completion_window": "24hrs" + } + mock_transformation._get_batch_id_from_vertex_ai_batch_response.return_value = "123456789" + + # Test the complete workflow + result = VertexPassthroughLoggingHandler.batch_prediction_jobs_handler( + httpx_response=mock_httpx_response, + logging_obj=mock_logging_obj, + url_route="/v1/projects/test-project/locations/us-central1/batchPredictionJobs", + result="success", + start_time=datetime.now(), + end_time=datetime.now(), + cache_hit=False, + user_api_key_dict={"user_id": "test-user"} + ) + + # Verify the complete workflow + assert result is not None + assert "kwargs" in result + assert result["kwargs"]["model"] == "projects/test-project/locations/us-central1/publishers/google/models/gemini-1.5-flash" + assert result["kwargs"]["batch_id"] == "123456789" + + # Verify all mocks were called + mock_get_model_id.assert_called_once() + mock_transformation.transform_vertex_ai_batch_response_to_openai_batch_response.assert_called_once() + # Note: store_unified_object_id is called asynchronously, so we can't easily verify it in this test + + +class TestVertexAIBatchCostCalculation: + """Test cases for Vertex AI batch cost calculation functionality""" + + def test_calculate_vertex_ai_batch_cost_and_usage_success(self): + """Test successful batch cost and usage calculation""" + from litellm.batches.batch_utils import calculate_vertex_ai_batch_cost_and_usage + + # Mock successful batch responses + vertex_ai_batch_responses = [ + { + "status": "JOB_STATE_SUCCEEDED", + "response": { + "candidates": [ + { + "content": { + "parts": [ + {"text": "Hello, world!"} + ] + } + } + ], + "usageMetadata": { + "promptTokenCount": 10, + "candidatesTokenCount": 5, + "totalTokenCount": 15 + } + } + }, + { + "status": "JOB_STATE_SUCCEEDED", + "response": { + "candidates": [ + { + "content": { + "parts": [ + {"text": "How are you?"} + ] + } + } + ], + "usageMetadata": { + "promptTokenCount": 8, + "candidatesTokenCount": 3, + "totalTokenCount": 11 + } + } + } + ] + + with patch('litellm.llms.vertex_ai.gemini.vertex_and_google_ai_studio_gemini.VertexGeminiConfig') as mock_config: + with patch('litellm.completion_cost') as mock_completion_cost: + + # Setup mocks + mock_model_response = Mock() + mock_model_response.usage = Mock(total_tokens=15, prompt_tokens=10, completion_tokens=5) + mock_config.return_value._transform_google_generate_content_to_openai_model_response.return_value = mock_model_response + mock_completion_cost.return_value = 0.001 + + # Test the calculation + total_cost, usage = calculate_vertex_ai_batch_cost_and_usage( + vertex_ai_batch_responses, + model_name="gemini-1.5-flash" + ) + + # Verify results + assert total_cost == 0.002 # 2 responses * 0.001 each + assert usage.total_tokens == 30 # 15 + 15 + assert usage.prompt_tokens == 20 # 10 + 10 + assert usage.completion_tokens == 10 # 5 + 5 + + def test_calculate_vertex_ai_batch_cost_and_usage_with_failed_responses(self): + """Test batch cost calculation with some failed responses""" + from litellm.batches.batch_utils import calculate_vertex_ai_batch_cost_and_usage + + # Mock batch responses with some failures + vertex_ai_batch_responses = [ + { + "status": "JOB_STATE_SUCCEEDED", + "response": { + "candidates": [ + { + "content": { + "parts": [ + {"text": "Hello, world!"} + ] + } + } + ], + "usageMetadata": { + "promptTokenCount": 10, + "candidatesTokenCount": 5, + "totalTokenCount": 15 + } + } + }, + { + "status": "JOB_STATE_FAILED", # Failed response + "response": None + }, + { + "status": "JOB_STATE_SUCCEEDED", + "response": { + "candidates": [ + { + "content": { + "parts": [ + {"text": "How are you?"} + ] + } + } + ], + "usageMetadata": { + "promptTokenCount": 8, + "candidatesTokenCount": 3, + "totalTokenCount": 11 + } + } + } + ] + + with patch('litellm.llms.vertex_ai.gemini.vertex_and_google_ai_studio_gemini.VertexGeminiConfig') as mock_config: + with patch('litellm.completion_cost') as mock_completion_cost: + + # Setup mocks + mock_model_response = Mock() + mock_model_response.usage = Mock(total_tokens=15, prompt_tokens=10, completion_tokens=5) + mock_config.return_value._transform_google_generate_content_to_openai_model_response.return_value = mock_model_response + mock_completion_cost.return_value = 0.001 + + # Test the calculation + total_cost, usage = calculate_vertex_ai_batch_cost_and_usage( + vertex_ai_batch_responses, + model_name="gemini-1.5-flash" + ) + + # Verify results - should only process successful responses + assert total_cost == 0.002 # 2 successful responses * 0.001 each + assert usage.total_tokens == 30 # 15 + 15 + assert usage.prompt_tokens == 20 # 10 + 10 + assert usage.completion_tokens == 10 # 5 + 5 + + def test_calculate_vertex_ai_batch_cost_and_usage_empty_responses(self): + """Test batch cost calculation with empty response list""" + from litellm.batches.batch_utils import calculate_vertex_ai_batch_cost_and_usage + + # Test with empty list + total_cost, usage = calculate_vertex_ai_batch_cost_and_usage([], model_name="gemini-1.5-flash") + + # Verify results + assert total_cost == 0.0 + assert usage.total_tokens == 0 + assert usage.prompt_tokens == 0 + assert usage.completion_tokens == 0