1
+ import pytest
2
+ from agents import (
3
+ OpenAIChatCompletionsModel ,
4
+ OpenAIResponsesModel ,
5
+ ModelSettings ,
6
+ ModelTracing
7
+ )
8
+ from openai .types .chat .chat_completion import ChatCompletion , Choice
9
+ from openai .types .chat .chat_completion_message import ChatCompletionMessage
10
+
11
+ @pytest .mark .allow_call_model_methods
12
+ @pytest .mark .asyncio
13
+ async def test_extra_headers_passed_to_openai_responses_model ():
14
+ """
15
+ Ensure extra_headers in ModelSettings is passed to the OpenAIResponsesModel client.
16
+ """
17
+ called_kwargs = {}
18
+
19
+ class DummyResponses :
20
+ async def create (self , ** kwargs ):
21
+ nonlocal called_kwargs
22
+ called_kwargs = kwargs
23
+ class DummyResponse :
24
+ id = "dummy"
25
+ output = []
26
+ usage = type ("Usage" , (), {"input_tokens" : 0 , "output_tokens" : 0 , "total_tokens" : 0 })()
27
+ return DummyResponse ()
28
+
29
+ class DummyClient :
30
+ def __init__ (self ):
31
+ self .responses = DummyResponses ()
32
+
33
+
34
+
35
+ model = OpenAIResponsesModel (model = "gpt-4" , openai_client = DummyClient ())
36
+ extra_headers = {"X-Test-Header" : "test-value" }
37
+ await model .get_response (
38
+ system_instructions = None ,
39
+ input = "hi" ,
40
+ model_settings = ModelSettings (extra_headers = extra_headers ),
41
+ tools = [],
42
+ output_schema = None ,
43
+ handoffs = [],
44
+ tracing = ModelTracing .DISABLED ,
45
+ previous_response_id = None ,
46
+ )
47
+ assert "extra_headers" in called_kwargs
48
+ assert called_kwargs ["extra_headers" ]["X-Test-Header" ] == "test-value"
49
+
50
+
51
+
52
+ @pytest .mark .allow_call_model_methods
53
+ @pytest .mark .asyncio
54
+ async def test_extra_headers_passed_to_openai_client ():
55
+ """
56
+ Ensure extra_headers in ModelSettings is passed to the OpenAI client.
57
+ """
58
+ called_kwargs = {}
59
+
60
+ class DummyCompletions :
61
+ async def create (self , ** kwargs ):
62
+ nonlocal called_kwargs
63
+ called_kwargs = kwargs
64
+ # Return a real ChatCompletion object as expected by get_response
65
+ msg = ChatCompletionMessage (role = "assistant" , content = "Hello" )
66
+ choice = Choice (index = 0 , finish_reason = "stop" , message = msg )
67
+ return ChatCompletion (
68
+ id = "resp-id" ,
69
+ created = 0 ,
70
+ model = "fake" ,
71
+ object = "chat.completion" ,
72
+ choices = [choice ],
73
+ usage = None ,
74
+ )
75
+
76
+ class DummyClient :
77
+ def __init__ (self ):
78
+ self .chat = type ("_Chat" , (), {"completions" : DummyCompletions ()})()
79
+ self .base_url = "https://api.openai.com"
80
+
81
+ model = OpenAIChatCompletionsModel (model = "gpt-4" , openai_client = DummyClient ())
82
+ extra_headers = {"X-Test-Header" : "test-value" }
83
+ await model .get_response (
84
+ system_instructions = None ,
85
+ input = "hi" ,
86
+ model_settings = ModelSettings (extra_headers = extra_headers ),
87
+ tools = [],
88
+ output_schema = None ,
89
+ handoffs = [],
90
+ tracing = ModelTracing .DISABLED ,
91
+ previous_response_id = None ,
92
+ )
93
+ assert "extra_headers" in called_kwargs
94
+ assert called_kwargs ["extra_headers" ]["X-Test-Header" ] == "test-value"
95
+
96
+
97
+ @pytest .mark .allow_call_model_methods
98
+ @pytest .mark .asyncio
99
+ async def test_extra_headers_passed_to_litellm_model (monkeypatch ):
100
+ """
101
+ Ensure extra_headers in ModelSettings is passed to the LitellmModel.
102
+ """
103
+ from agents .extensions .models .litellm_model import LitellmModel
104
+
105
+ called_kwargs = {}
106
+
107
+ async def dummy_acompletion (* args , ** kwargs ):
108
+ nonlocal called_kwargs
109
+ called_kwargs = kwargs
110
+ # Return a minimal object to trigger downstream error after call
111
+ return None
112
+
113
+ monkeypatch .setattr ("agents.extensions.models.litellm_model.litellm.acompletion" , dummy_acompletion )
114
+
115
+ model = LitellmModel (model = "any-model" )
116
+ extra_headers = {"X-Test-Header" : "test-value" }
117
+ with pytest .raises (Exception ): # We expect an error, but we only care about the call
118
+ await model .get_response (
119
+ system_instructions = None ,
120
+ input = "hi" ,
121
+ model_settings = ModelSettings (extra_headers = extra_headers ),
122
+ tools = [],
123
+ output_schema = None ,
124
+ handoffs = [],
125
+ tracing = ModelTracing .DISABLED ,
126
+ previous_response_id = None ,
127
+ )
128
+ assert "extra_headers" in called_kwargs
129
+ assert called_kwargs ["extra_headers" ]["X-Test-Header" ] == "test-value"
0 commit comments