160 changes: 160 additions & 0 deletions docs/my-website/docs/vertex_batch_passthrough.md
@@ -0,0 +1,160 @@
import Tabs from '@theme/Tabs';
import TabItem from '@theme/TabItem';

# /batchPredictionJobs

LiteLLM supports Vertex AI batch prediction jobs through passthrough endpoints, allowing you to create and manage batch jobs directly via the proxy server.

## Features

- **Batch Job Creation**: Create batch prediction jobs using Vertex AI models
- **Cost Tracking**: Automatic cost calculation and usage tracking for batch operations
- **Status Monitoring**: Track job status and retrieve results
- **Model Support**: Works with all supported Vertex AI models (Gemini, Text Embedding)

## Cost Tracking Support

| Feature | Supported | Notes |
|---------|-----------|-------|
| Cost Tracking | ✅ | Automatic cost calculation for batch operations |
| Usage Monitoring | ✅ | Track token usage and costs across batch jobs |
| Logging | ✅ | Supported |

## Quick Start

1. **Configure your model** in the proxy configuration:

```yaml
model_list:
- model_name: gemini-1.5-flash
litellm_params:
model: vertex_ai/gemini-1.5-flash
vertex_project: your-project-id
vertex_location: us-central1
vertex_credentials: path/to/service-account.json
```
2. **Create a batch job**:

```bash
curl -X POST "http://localhost:4000/v1/projects/your-project/locations/us-central1/batchPredictionJobs" \
-H "Authorization: Bearer your-api-key" \
-H "Content-Type: application/json" \
-d '{
"displayName": "my-batch-job",
"model": "projects/your-project/locations/us-central1/publishers/google/models/gemini-1.5-flash",
"inputConfig": {
"gcsSource": {
"uris": ["gs://my-bucket/input.jsonl"]
},
"instancesFormat": "jsonl"
},
"outputConfig": {
"gcsDestination": {
"outputUriPrefix": "gs://my-bucket/output/"
},
"predictionsFormat": "jsonl"
}
}'
```

3. **Monitor job status**:

```bash
curl -X GET "http://localhost:4000/v1/projects/your-project/locations/us-central1/batchPredictionJobs/job-id" \
-H "Authorization: Bearer your-api-key"
```

## Model Configuration

When configuring models for batch operations, use the following naming conventions (see the example after this list):

- **`model_name`**: Base model name (e.g., `gemini-1.5-flash`)
- **`model`**: Full LiteLLM identifier (e.g., `vertex_ai/gemini-1.5-flash`)
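
For example, a config entry that follows both conventions (reusing the placeholder project, location, and credentials from the Quick Start) looks like:

```yaml
model_list:
  - model_name: gemini-1.5-pro            # base model name used by clients
    litellm_params:
      model: vertex_ai/gemini-1.5-pro     # full LiteLLM identifier
      vertex_project: your-project-id
      vertex_location: us-central1
      vertex_credentials: path/to/service-account.json
```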

## Supported Models

- `gemini-1.5-flash` / `vertex_ai/gemini-1.5-flash`
- `gemini-1.5-pro` / `vertex_ai/gemini-1.5-pro`
- `gemini-2.0-flash` / `vertex_ai/gemini-2.0-flash`
- `gemini-2.0-pro` / `vertex_ai/gemini-2.0-pro`

## Advanced Usage

### Batch Job with Custom Parameters

```bash
curl -X POST "http://localhost:4000/v1/projects/your-project/locations/us-central1/batchPredictionJobs" \
-H "Authorization: Bearer your-api-key" \
-H "Content-Type: application/json" \
-d '{
"displayName": "advanced-batch-job",
"model": "projects/your-project/locations/us-central1/publishers/google/models/gemini-1.5-pro",
"inputConfig": {
"gcsSource": {
"uris": ["gs://my-bucket/advanced-input.jsonl"]
},
"instancesFormat": "jsonl"
},
"outputConfig": {
"gcsDestination": {
"outputUriPrefix": "gs://my-bucket/advanced-output/"
},
"predictionsFormat": "jsonl"
},
"labels": {
"environment": "production",
"team": "ml-engineering"
}
}'
```

### List All Batch Jobs

```bash
curl -X GET "http://localhost:4000/v1/projects/your-project/locations/us-central1/batchPredictionJobs" \
-H "Authorization: Bearer your-api-key"
```
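
Because this is a passthrough route, standard Vertex AI list parameters should also pass straight through. As a hedged sketch (assuming the upstream `pageSize` and `pageToken` query parameters of the Vertex AI list API, which are not LiteLLM-specific):

```bash
curl -X GET "http://localhost:4000/v1/projects/your-project/locations/us-central1/batchPredictionJobs?pageSize=10" \
  -H "Authorization: Bearer your-api-key"
```

Subsequent pages can then be fetched with the `pageToken` value returned in the previous response.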

### Cancel a Batch Job

```bash
curl -X POST "http://localhost:4000/v1/projects/your-project/locations/us-central1/batchPredictionJobs/job-id:cancel" \
-H "Authorization: Bearer your-api-key"
```

## Cost Tracking Details

LiteLLM provides comprehensive cost tracking for Vertex AI batch operations:

- **Token Usage**: Tracks input and output tokens for each batch request
- **Cost Calculation**: Automatically calculates costs based on current Vertex AI pricing
- **Usage Aggregation**: Aggregates costs across all requests in a batch job
- **Real-time Monitoring**: Monitor costs as batch jobs progress

Cost tracking is applied to the `generateContent` responses returned in the batch output, giving detailed insight into your batch processing spend.
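
If your proxy has spend tracking enabled (a database is configured), recorded batch spend can be inspected through the proxy's spend endpoints; as a minimal sketch, assuming the standard `/spend/logs` endpoint is exposed:

```bash
curl -X GET "http://localhost:4000/spend/logs" \
  -H "Authorization: Bearer your-api-key"
```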

## Error Handling

Common error scenarios and their solutions:

| Error | Description | Solution |
|-------|-------------|----------|
| `INVALID_ARGUMENT` | Invalid model or configuration | Verify model name and project settings |
| `PERMISSION_DENIED` | Insufficient permissions | Check Vertex AI IAM roles |
| `RESOURCE_EXHAUSTED` | Quota exceeded | Check Vertex AI quotas and limits |
| `NOT_FOUND` | Job or resource not found | Verify job ID and project configuration |
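
Since these are passthrough endpoints, error bodies are returned in the standard Google API error envelope. An illustrative (not exact) `PERMISSION_DENIED` response might look like:

```json
{
  "error": {
    "code": 403,
    "message": "Permission 'aiplatform.batchPredictionJobs.create' denied on resource 'projects/your-project'.",
    "status": "PERMISSION_DENIED"
  }
}
```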

## Best Practices

1. **Use appropriate batch sizes**: Balance between processing efficiency and resource usage
2. **Monitor job status**: Regularly check job status to handle failures promptly
3. **Set up alerts**: Configure monitoring for job completion and failures
4. **Optimize costs**: Use cost tracking to identify optimization opportunities
5. **Test with small batches**: Validate your setup with a small test batch first (a sample input file follows below)
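
For example, a minimal test input file for a Gemini model can contain just a few JSONL lines. The request schema below is an illustrative sketch of the Gemini batch input format; check the Vertex AI batch prediction documentation for the exact schema of your model:

```json
{"request": {"contents": [{"role": "user", "parts": [{"text": "Summarize LiteLLM in one sentence."}]}]}}
{"request": {"contents": [{"role": "user", "parts": [{"text": "Write a haiku about batch jobs."}]}]}}
```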

## Related Documentation

- [Vertex AI Provider Documentation](./vertex.md)
- [General Batches API Documentation](../batches.md)
- [Cost Tracking and Monitoring](../observability/telemetry.md)
@@ -57,7 +57,6 @@ async def check_batch_cost(self):
"file_purpose": "batch",
}
)

completed_jobs = []

for job in jobs:
@@ -139,7 +138,7 @@ async def check_batch_cost(self):
custom_llm_provider = deployment_info.litellm_params.custom_llm_provider
litellm_model_name = deployment_info.litellm_params.model

_, llm_provider, _, _ = get_llm_provider(
model_name, llm_provider, _, _ = get_llm_provider(
model=litellm_model_name,
custom_llm_provider=custom_llm_provider,
)
@@ -148,9 +147,9 @@
await calculate_batch_cost_and_usage(
file_content_dictionary=file_content_as_dict,
custom_llm_provider=llm_provider, # type: ignore
model_name=model_name,
)
)

logging_obj = LiteLLMLogging(
model=batch_models[0],
messages=[{"role": "user", "content": "<retrieve_batch>"}],
115 changes: 108 additions & 7 deletions litellm/batches/batch_utils.py
@@ -1,5 +1,5 @@
import json
from typing import Any, List, Literal, Tuple
from typing import Any, List, Literal, Tuple, Optional

import litellm
from litellm._logging import verbose_logger
@@ -10,28 +10,30 @@
async def calculate_batch_cost_and_usage(
file_content_dictionary: List[dict],
custom_llm_provider: Literal["openai", "azure", "vertex_ai"],
model_name: Optional[str] = None,
) -> Tuple[float, Usage, List[str]]:
"""
Calculate the cost and usage of a batch
"""
# Calculate costs and usage
batch_cost = _batch_cost_calculator(
custom_llm_provider=custom_llm_provider,
file_content_dictionary=file_content_dictionary,
model_name=model_name,
)
batch_usage = _get_batch_job_total_usage_from_file_content(
file_content_dictionary=file_content_dictionary,
custom_llm_provider=custom_llm_provider,
model_name=model_name,
)

batch_models = _get_batch_models_from_file_content(file_content_dictionary)
batch_models = _get_batch_models_from_file_content(file_content_dictionary, model_name)

return batch_cost, batch_usage, batch_models


async def _handle_completed_batch(
batch: Batch,
custom_llm_provider: Literal["openai", "azure", "vertex_ai"],
model_name: Optional[str] = None,
) -> Tuple[float, Usage, List[str]]:
"""Helper function to process a completed batch and handle logging"""
# Get batch results
@@ -43,23 +45,28 @@ async def _handle_completed_batch(
batch_cost = _batch_cost_calculator(
custom_llm_provider=custom_llm_provider,
file_content_dictionary=file_content_dictionary,
model_name=model_name,
)
batch_usage = _get_batch_job_total_usage_from_file_content(
file_content_dictionary=file_content_dictionary,
custom_llm_provider=custom_llm_provider,
model_name=model_name,
)

batch_models = _get_batch_models_from_file_content(file_content_dictionary)
batch_models = _get_batch_models_from_file_content(file_content_dictionary, model_name)

return batch_cost, batch_usage, batch_models


def _get_batch_models_from_file_content(
file_content_dictionary: List[dict],
model_name: Optional[str] = None,
) -> List[str]:
"""
Get the models from the file content
"""
if model_name:
return [model_name]
batch_models = []
for _item in file_content_dictionary:
if _batch_response_was_successful(_item):
@@ -73,12 +80,18 @@ def _get_batch_models_from_file_content(
def _batch_cost_calculator(
file_content_dictionary: List[dict],
custom_llm_provider: Literal["openai", "azure", "vertex_ai"] = "openai",
model_name: Optional[str] = None,
) -> float:
"""
Calculate the cost of a batch based on the output file id
"""
if custom_llm_provider == "vertex_ai":
raise ValueError("Vertex AI does not support file content retrieval")
# Handle Vertex AI with specialized method
if custom_llm_provider == "vertex_ai" and model_name:
batch_cost, _ = calculate_vertex_ai_batch_cost_and_usage(file_content_dictionary, model_name)
verbose_logger.debug("vertex_ai_total_cost=%s", batch_cost)
return batch_cost

# For other providers, use the existing logic
total_cost = _get_batch_job_cost_from_file_content(
file_content_dictionary=file_content_dictionary,
custom_llm_provider=custom_llm_provider,
@@ -87,6 +100,87 @@ def _batch_cost_calculator(
return total_cost


def calculate_vertex_ai_batch_cost_and_usage(
vertex_ai_batch_responses: List[dict],
model_name: Optional[str] = None,
) -> Tuple[float, Usage]:
"""
Calculate both cost and usage from Vertex AI batch responses
"""
total_cost = 0.0
total_tokens = 0
prompt_tokens = 0
completion_tokens = 0

for response in vertex_ai_batch_responses:
if response.get("status") == "JOB_STATE_SUCCEEDED": # Check if response was successful
# Transform Vertex AI response to OpenAI format if needed
from litellm.llms.vertex_ai.gemini.vertex_and_google_ai_studio_gemini import VertexGeminiConfig
from litellm import ModelResponse
from litellm.litellm_core_utils.litellm_logging import Logging
from litellm.types.utils import CallTypes
from litellm._uuid import uuid
import httpx
import time

# Create required arguments for the transformation method
model_response = ModelResponse()

# Ensure model_name is not None
actual_model_name = model_name or "gemini-2.5-flash"

# Create a real LiteLLM logging object
logging_obj = Logging(
model=actual_model_name,
messages=[{"role": "user", "content": "batch_request"}],
stream=False,
call_type=CallTypes.aretrieve_batch,
start_time=time.time(),
litellm_call_id="batch_" + str(uuid.uuid4()),
function_id="batch_processing",
litellm_trace_id=str(uuid.uuid4()),
kwargs={"optional_params": {}}
)

# Add the optional_params attribute that the Vertex AI transformation expects
logging_obj.optional_params = {}
raw_response = httpx.Response(200) # Mock response object

openai_format_response = VertexGeminiConfig()._transform_google_generate_content_to_openai_model_response(
completion_response=response["response"],
model_response=model_response,
model=actual_model_name,
logging_obj=logging_obj,
raw_response=raw_response,
)

# Calculate cost using existing function
cost = litellm.completion_cost(
completion_response=openai_format_response,
custom_llm_provider="vertex_ai",
call_type=CallTypes.aretrieve_batch.value,
)
total_cost += cost

# Extract usage from the transformed response
if hasattr(openai_format_response, 'usage') and openai_format_response.usage:
usage = openai_format_response.usage
else:
# Fallback: create usage from response dict
response_dict = openai_format_response.dict() if hasattr(openai_format_response, 'dict') else {}
usage = _get_batch_job_usage_from_response_body(response_dict)

total_tokens += usage.total_tokens
prompt_tokens += usage.prompt_tokens
completion_tokens += usage.completion_tokens

return total_cost, Usage(
total_tokens=total_tokens,
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
)


async def _get_batch_output_file_content_as_dictionary(
batch: Batch,
custom_llm_provider: Literal["openai", "azure", "vertex_ai"] = "openai",
@@ -157,10 +251,17 @@ def _get_batch_job_cost_from_file_content(
def _get_batch_job_total_usage_from_file_content(
file_content_dictionary: List[dict],
custom_llm_provider: Literal["openai", "azure", "vertex_ai"] = "openai",
model_name: Optional[str] = None,
) -> Usage:
"""
Get the tokens of a batch job from the file content
"""
# Handle Vertex AI with specialized method
if custom_llm_provider == "vertex_ai" and model_name:
_, batch_usage = calculate_vertex_ai_batch_cost_and_usage(file_content_dictionary, model_name)
return batch_usage

# For other providers, use the existing logic
total_tokens: int = 0
prompt_tokens: int = 0
completion_tokens: int = 0