Skip to content

Commit e102319

Browse files
committed
Add extra_headers parameter to ModelSettings
1 parent 5639606 commit e102319

File tree

5 files changed

+137
-4
lines changed

5 files changed

+137
-4
lines changed

src/agents/extensions/models/litellm_model.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -286,7 +286,7 @@ async def _fetch_response(
286286
stream=stream,
287287
stream_options=stream_options,
288288
reasoning_effort=reasoning_effort,
289-
extra_headers=HEADERS,
289+
extra_headers={**HEADERS, **(model_settings.extra_headers or {})},
290290
api_key=self.api_key,
291291
base_url=self.base_url,
292292
**extra_kwargs,

src/agents/model_settings.py

+5-1
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from dataclasses import dataclass, fields, replace
44
from typing import Literal
55

6-
from openai._types import Body, Query
6+
from openai._types import Body, Headers, Query
77
from openai.types.shared import Reasoning
88

99

@@ -67,6 +67,10 @@ class ModelSettings:
6767
"""Additional body fields to provide with the request.
6868
Defaults to None if not provided."""
6969

70+
extra_headers: Headers | None = None
71+
"""Additional headers to provide with the request.
72+
Defaults to None if not provided."""
73+
7074
def resolve(self, override: ModelSettings | None) -> ModelSettings:
7175
"""Produce a new ModelSettings by overlaying any non-None values from the
7276
override on top of this instance."""

src/agents/models/openai_chatcompletions.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -255,7 +255,7 @@ async def _fetch_response(
255255
stream_options=self._non_null_or_not_given(stream_options),
256256
store=self._non_null_or_not_given(store),
257257
reasoning_effort=self._non_null_or_not_given(reasoning_effort),
258-
extra_headers=HEADERS,
258+
extra_headers={**HEADERS, **(model_settings.extra_headers or {})},
259259
extra_query=model_settings.extra_query,
260260
extra_body=model_settings.extra_body,
261261
metadata=self._non_null_or_not_given(model_settings.metadata),

src/agents/models/openai_responses.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -253,7 +253,7 @@ async def _fetch_response(
253253
tool_choice=tool_choice,
254254
parallel_tool_calls=parallel_tool_calls,
255255
stream=stream,
256-
extra_headers=_HEADERS,
256+
extra_headers={**_HEADERS, **(model_settings.extra_headers or {})},
257257
extra_query=model_settings.extra_query,
258258
extra_body=model_settings.extra_body,
259259
text=response_format,

tests/test_extra_headers.py

+129
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,129 @@
1+
import pytest
2+
from agents import (
3+
OpenAIChatCompletionsModel,
4+
OpenAIResponsesModel,
5+
ModelSettings,
6+
ModelTracing
7+
)
8+
from openai.types.chat.chat_completion import ChatCompletion, Choice
9+
from openai.types.chat.chat_completion_message import ChatCompletionMessage
10+
11+
@pytest.mark.allow_call_model_methods
@pytest.mark.asyncio
async def test_extra_headers_passed_to_openai_responses_model():
    """Verify that ModelSettings.extra_headers is forwarded to the Responses API call."""
    captured: dict = {}

    class FakeResponses:
        async def create(self, **kwargs):
            # Record every keyword the model implementation forwarded to the SDK.
            captured.update(kwargs)
            usage_stub = type(
                "Usage", (), {"input_tokens": 0, "output_tokens": 0, "total_tokens": 0}
            )()
            return type(
                "DummyResponse", (), {"id": "dummy", "output": [], "usage": usage_stub}
            )()

    class FakeClient:
        def __init__(self):
            self.responses = FakeResponses()

    model = OpenAIResponsesModel(model="gpt-4", openai_client=FakeClient())
    await model.get_response(
        system_instructions=None,
        input="hi",
        model_settings=ModelSettings(extra_headers={"X-Test-Header": "test-value"}),
        tools=[],
        output_schema=None,
        handoffs=[],
        tracing=ModelTracing.DISABLED,
        previous_response_id=None,
    )
    assert "extra_headers" in captured
    assert captured["extra_headers"]["X-Test-Header"] == "test-value"
49+
50+
51+
52+
@pytest.mark.allow_call_model_methods
@pytest.mark.asyncio
async def test_extra_headers_passed_to_openai_client():
    """Verify that ModelSettings.extra_headers is forwarded to the chat-completions call."""
    captured: dict = {}

    class FakeCompletions:
        async def create(self, **kwargs):
            captured.update(kwargs)
            # get_response expects a genuine ChatCompletion object back from the SDK.
            message = ChatCompletionMessage(role="assistant", content="Hello")
            return ChatCompletion(
                id="resp-id",
                created=0,
                model="fake",
                object="chat.completion",
                choices=[Choice(index=0, finish_reason="stop", message=message)],
                usage=None,
            )

    class FakeChat:
        def __init__(self):
            self.completions = FakeCompletions()

    class FakeClient:
        def __init__(self):
            self.chat = FakeChat()
            self.base_url = "https://api.openai.com"

    model = OpenAIChatCompletionsModel(model="gpt-4", openai_client=FakeClient())
    await model.get_response(
        system_instructions=None,
        input="hi",
        model_settings=ModelSettings(extra_headers={"X-Test-Header": "test-value"}),
        tools=[],
        output_schema=None,
        handoffs=[],
        tracing=ModelTracing.DISABLED,
        previous_response_id=None,
    )
    assert "extra_headers" in captured
    assert captured["extra_headers"]["X-Test-Header"] == "test-value"
95+
96+
97+
@pytest.mark.allow_call_model_methods
@pytest.mark.asyncio
async def test_extra_headers_passed_to_litellm_model(monkeypatch):
    """Verify that ModelSettings.extra_headers is forwarded to litellm.acompletion."""
    from agents.extensions.models.litellm_model import LitellmModel

    captured: dict = {}

    async def fake_acompletion(*args, **kwargs):
        captured.update(kwargs)
        # Returning None makes the code after the call fail, but by then the
        # kwargs we care about have already been recorded.
        return None

    monkeypatch.setattr(
        "agents.extensions.models.litellm_model.litellm.acompletion", fake_acompletion
    )

    model = LitellmModel(model="any-model")
    # The dummy response is unusable downstream, so some exception is expected;
    # this test only cares that the call itself carried the header.
    with pytest.raises(Exception):
        await model.get_response(
            system_instructions=None,
            input="hi",
            model_settings=ModelSettings(extra_headers={"X-Test-Header": "test-value"}),
            tools=[],
            output_schema=None,
            handoffs=[],
            tracing=ModelTracing.DISABLED,
            previous_response_id=None,
        )
    assert "extra_headers" in captured
    assert captured["extra_headers"]["X-Test-Header"] == "test-value"

0 commit comments

Comments
 (0)