diff --git a/src/agents/usage.py b/src/agents/usage.py index 3639cf944..a2b41529e 100644 --- a/src/agents/usage.py +++ b/src/agents/usage.py @@ -4,6 +4,26 @@ from pydantic.dataclasses import dataclass +@dataclass +class RequestUsage: + """Usage details for a single API request.""" + + input_tokens: int + """Input tokens for this individual request.""" + + output_tokens: int + """Output tokens for this individual request.""" + + total_tokens: int + """Total tokens (input + output) for this individual request.""" + + input_tokens_details: InputTokensDetails + """Details about the input tokens for this individual request.""" + + output_tokens_details: OutputTokensDetails + """Details about the output tokens for this individual request.""" + + @dataclass class Usage: requests: int = 0 @@ -27,7 +47,27 @@ class Usage: total_tokens: int = 0 """Total tokens sent and received, across all requests.""" + request_usage_entries: list[RequestUsage] = field(default_factory=list) + """List of RequestUsage entries for accurate per-request cost calculation. + + Each call to `add()` automatically creates an entry in this list if the added usage + represents a new request (i.e., has non-zero tokens). + + Example: + For a run that makes 3 API calls with 100K, 150K, and 80K input tokens each, + the aggregated `input_tokens` would be 330K, but `request_usage_entries` would + preserve the [100K, 150K, 80K] breakdown, which could be helpful for detailed + cost calculation or context window management. + """ + def add(self, other: "Usage") -> None: + """Add another Usage object to this one, aggregating all fields. + + This method automatically preserves request_usage_entries. + + Args: + other: The Usage object to add to this one. + """ self.requests += other.requests if other.requests else 0 self.input_tokens += other.input_tokens if other.input_tokens else 0 self.output_tokens += other.output_tokens if other.output_tokens else 0 @@ -41,3 +81,18 @@ def add(self, other: "Usage") -> None: reasoning_tokens=self.output_tokens_details.reasoning_tokens + other.output_tokens_details.reasoning_tokens ) + + # Automatically preserve request_usage_entries. + # If the other Usage represents a single request with tokens, record it. + if other.requests == 1 and other.total_tokens > 0: + request_usage = RequestUsage( + input_tokens=other.input_tokens, + output_tokens=other.output_tokens, + total_tokens=other.total_tokens, + input_tokens_details=other.input_tokens_details, + output_tokens_details=other.output_tokens_details, + ) + self.request_usage_entries.append(request_usage) + elif other.request_usage_entries: + # If the other Usage already has individual request breakdowns, merge them. + self.request_usage_entries.extend(other.request_usage_entries) diff --git a/tests/test_usage.py b/tests/test_usage.py index 405f99ddf..d0e674111 100644 --- a/tests/test_usage.py +++ b/tests/test_usage.py @@ -1,6 +1,6 @@ from openai.types.responses.response_usage import InputTokensDetails, OutputTokensDetails -from agents.usage import Usage +from agents.usage import RequestUsage, Usage def test_usage_add_aggregates_all_fields(): @@ -50,3 +50,220 @@ def test_usage_add_aggregates_with_none_values(): assert u1.total_tokens == 15 assert u1.input_tokens_details.cached_tokens == 4 assert u1.output_tokens_details.reasoning_tokens == 6 + + +def test_request_usage_creation(): + """Test that RequestUsage is created correctly.""" + request_usage = RequestUsage( + input_tokens=100, + output_tokens=200, + total_tokens=300, + input_tokens_details=InputTokensDetails(cached_tokens=10), + output_tokens_details=OutputTokensDetails(reasoning_tokens=20), + ) + + assert request_usage.input_tokens == 100 + assert request_usage.output_tokens == 200 + assert request_usage.total_tokens == 300 + assert request_usage.input_tokens_details.cached_tokens == 10 + assert request_usage.output_tokens_details.reasoning_tokens == 20 + + +def test_usage_add_preserves_single_request(): + """Test that adding a single request Usage creates an RequestUsage entry.""" + u1 = Usage() + u2 = Usage( + requests=1, + input_tokens=100, + input_tokens_details=InputTokensDetails(cached_tokens=10), + output_tokens=200, + output_tokens_details=OutputTokensDetails(reasoning_tokens=20), + total_tokens=300, + ) + + u1.add(u2) + + # Should preserve the request usage details + assert len(u1.request_usage_entries) == 1 + request_usage = u1.request_usage_entries[0] + assert request_usage.input_tokens == 100 + assert request_usage.output_tokens == 200 + assert request_usage.total_tokens == 300 + assert request_usage.input_tokens_details.cached_tokens == 10 + assert request_usage.output_tokens_details.reasoning_tokens == 20 + + +def test_usage_add_ignores_zero_token_requests(): + """Test that zero-token requests don't create request_usage_entries.""" + u1 = Usage() + u2 = Usage( + requests=1, + input_tokens=0, + input_tokens_details=InputTokensDetails(cached_tokens=0), + output_tokens=0, + output_tokens_details=OutputTokensDetails(reasoning_tokens=0), + total_tokens=0, + ) + + u1.add(u2) + + # Should not create a request_usage_entry for zero tokens + assert len(u1.request_usage_entries) == 0 + + +def test_usage_add_ignores_multi_request_usage(): + """Test that multi-request Usage objects don't create request_usage_entries.""" + u1 = Usage() + u2 = Usage( + requests=3, # Multiple requests + input_tokens=100, + input_tokens_details=InputTokensDetails(cached_tokens=10), + output_tokens=200, + output_tokens_details=OutputTokensDetails(reasoning_tokens=20), + total_tokens=300, + ) + + u1.add(u2) + + # Should not create a request usage entry for multi-request usage + assert len(u1.request_usage_entries) == 0 + + +def test_usage_add_merges_existing_request_usage_entries(): + """Test that existing request_usage_entries are merged when adding Usage objects.""" + # Create first usage with request_usage_entries + u1 = Usage() + u2 = Usage( + requests=1, + input_tokens=100, + input_tokens_details=InputTokensDetails(cached_tokens=10), + output_tokens=200, + output_tokens_details=OutputTokensDetails(reasoning_tokens=20), + total_tokens=300, + ) + u1.add(u2) + + # Create second usage with request_usage_entries + u3 = Usage( + requests=1, + input_tokens=50, + input_tokens_details=InputTokensDetails(cached_tokens=5), + output_tokens=75, + output_tokens_details=OutputTokensDetails(reasoning_tokens=10), + total_tokens=125, + ) + + u1.add(u3) + + # Should have both request_usage_entries + assert len(u1.request_usage_entries) == 2 + + # First request + first = u1.request_usage_entries[0] + assert first.input_tokens == 100 + assert first.output_tokens == 200 + assert first.total_tokens == 300 + + # Second request + second = u1.request_usage_entries[1] + assert second.input_tokens == 50 + assert second.output_tokens == 75 + assert second.total_tokens == 125 + + +def test_usage_add_with_pre_existing_request_usage_entries(): + """Test adding Usage objects that already have request_usage_entries.""" + u1 = Usage() + + # Create a usage with request_usage_entries + u2 = Usage( + requests=1, + input_tokens=100, + input_tokens_details=InputTokensDetails(cached_tokens=10), + output_tokens=200, + output_tokens_details=OutputTokensDetails(reasoning_tokens=20), + total_tokens=300, + ) + u1.add(u2) + + # Create another usage with request_usage_entries + u3 = Usage( + requests=1, + input_tokens=50, + input_tokens_details=InputTokensDetails(cached_tokens=5), + output_tokens=75, + output_tokens_details=OutputTokensDetails(reasoning_tokens=10), + total_tokens=125, + ) + + # Add u3 to u1 + u1.add(u3) + + # Should have both request_usage_entries + assert len(u1.request_usage_entries) == 2 + assert u1.request_usage_entries[0].input_tokens == 100 + assert u1.request_usage_entries[1].input_tokens == 50 + + +def test_usage_request_usage_entries_default_empty(): + """Test that request_usage_entries defaults to an empty list.""" + u = Usage() + assert u.request_usage_entries == [] + + +def test_anthropic_cost_calculation_scenario(): + """Test a realistic scenario for Sonnet 4.5 cost calculation with 200K token thresholds.""" + # Simulate 3 API calls: 100K, 150K, and 80K input tokens each + # None exceed 200K, so they should all use the lower pricing tier + + usage = Usage() + + # First request: 100K input tokens + req1 = Usage( + requests=1, + input_tokens=100_000, + input_tokens_details=InputTokensDetails(cached_tokens=0), + output_tokens=50_000, + output_tokens_details=OutputTokensDetails(reasoning_tokens=0), + total_tokens=150_000, + ) + usage.add(req1) + + # Second request: 150K input tokens + req2 = Usage( + requests=1, + input_tokens=150_000, + input_tokens_details=InputTokensDetails(cached_tokens=0), + output_tokens=75_000, + output_tokens_details=OutputTokensDetails(reasoning_tokens=0), + total_tokens=225_000, + ) + usage.add(req2) + + # Third request: 80K input tokens + req3 = Usage( + requests=1, + input_tokens=80_000, + input_tokens_details=InputTokensDetails(cached_tokens=0), + output_tokens=40_000, + output_tokens_details=OutputTokensDetails(reasoning_tokens=0), + total_tokens=120_000, + ) + usage.add(req3) + + # Verify aggregated totals + assert usage.requests == 3 + assert usage.input_tokens == 330_000 # 100K + 150K + 80K + assert usage.output_tokens == 165_000 # 50K + 75K + 40K + assert usage.total_tokens == 495_000 # 150K + 225K + 120K + + # Verify request_usage_entries preservation + assert len(usage.request_usage_entries) == 3 + assert usage.request_usage_entries[0].input_tokens == 100_000 + assert usage.request_usage_entries[1].input_tokens == 150_000 + assert usage.request_usage_entries[2].input_tokens == 80_000 + + # All request_usage_entries are under 200K threshold + for req in usage.request_usage_entries: + assert req.input_tokens < 200_000 + assert req.output_tokens < 200_000