-
Notifications
You must be signed in to change notification settings - Fork 1.2k
/
Copy pathtest_extra_headers.py
93 lines (81 loc) · 2.97 KB
/
test_extra_headers.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
import pytest
from agents import (
OpenAIChatCompletionsModel,
OpenAIResponsesModel,
ModelSettings,
ModelTracing
)
from openai.types.chat.chat_completion import ChatCompletion, Choice
from openai.types.chat.chat_completion_message import ChatCompletionMessage
@pytest.mark.allow_call_model_methods
@pytest.mark.asyncio
async def test_extra_headers_passed_to_openai_responses_model():
    """
    Check that extra_headers set on ModelSettings reach the keyword arguments
    of the ``responses.create`` call made by OpenAIResponsesModel.
    """
    # Holds the keyword arguments of the most recent stubbed create() call.
    captured = {}

    # Zero-valued usage counters for the canned response.
    usage_fields = {"input_tokens": 0, "output_tokens": 0, "total_tokens": 0}

    class DummyResponse:
        # Canned response exposing only the id/output/usage attributes.
        id = "dummy"
        output = []
        usage = type("Usage", (), usage_fields)()

    class DummyResponses:
        async def create(self, **kwargs):
            # Keep only the latest call's kwargs, mirroring a plain rebind.
            captured.clear()
            captured.update(kwargs)
            return DummyResponse()

    class DummyClient:
        def __init__(self):
            self.responses = DummyResponses()

    headers = {"X-Test-Header": "test-value"}
    settings = ModelSettings(extra_headers=headers)
    model = OpenAIResponsesModel(model="gpt-4", openai_client=DummyClient())

    await model.get_response(
        system_instructions=None,
        input="hi",
        model_settings=settings,
        tools=[],
        output_schema=None,
        handoffs=[],
        tracing=ModelTracing.DISABLED,
        previous_response_id=None,
    )

    assert "extra_headers" in captured
    assert captured["extra_headers"]["X-Test-Header"] == "test-value"
@pytest.mark.allow_call_model_methods
@pytest.mark.asyncio
async def test_extra_headers_passed_to_openai_client():
    """
    Check that extra_headers set on ModelSettings reach the keyword arguments
    of the ``chat.completions.create`` call made by OpenAIChatCompletionsModel.
    """
    # Holds the keyword arguments of the most recent stubbed create() call.
    captured = {}

    class DummyCompletions:
        async def create(self, **kwargs):
            # Keep only the latest call's kwargs, mirroring a plain rebind.
            captured.clear()
            captured.update(kwargs)
            # Return a minimal one-choice completion as the canned reply.
            return ChatCompletion(
                id="resp-id",
                created=0,
                model="fake",
                object="chat.completion",
                choices=[
                    Choice(
                        index=0,
                        finish_reason="stop",
                        message=ChatCompletionMessage(role="assistant", content="Hello"),
                    )
                ],
                usage=None,
            )

    class DummyClient:
        def __init__(self):
            # Ad-hoc namespace so that client.chat.completions resolves to the stub.
            chat_namespace = type("_Chat", (), {"completions": DummyCompletions()})
            self.chat = chat_namespace()
            self.base_url = "https://api.openai.com"

    headers = {"X-Test-Header": "test-value"}
    settings = ModelSettings(extra_headers=headers)
    model = OpenAIChatCompletionsModel(model="gpt-4", openai_client=DummyClient())

    await model.get_response(
        system_instructions=None,
        input="hi",
        model_settings=settings,
        tools=[],
        output_schema=None,
        handoffs=[],
        tracing=ModelTracing.DISABLED,
        previous_response_id=None,
    )

    assert "extra_headers" in captured
    assert captured["extra_headers"]["X-Test-Header"] == "test-value"