
Commit 95b16b2

New _openai_token_param(): selects max_completion_tokens (instead of max_tokens) for gpt-5-* models.
New _openai_include_sampling_params(): skips temperature/top_p for gpt-5-*, which only support default sampling. Both helpers are applied in the sync and async chat.completions code paths.
1 parent e4473ab commit 95b16b2
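
In short, the change routes every OpenAI chat request through two model-aware checks. The following standalone Python sketch mirrors that decision logic for illustration only: in the commit these are methods on the LLM client, not free functions, and the model names in the asserts are merely examples.

from typing import Optional

def openai_token_param(model_name: Optional[str]) -> str:
    # GPT-5-series models take 'max_completion_tokens'; other chat.completions
    # models keep the classic 'max_tokens' parameter.
    name = (model_name or "").lower()
    return "max_completion_tokens" if name.startswith("gpt-5") else "max_tokens"

def include_sampling_params(model_name: Optional[str]) -> bool:
    # GPT-5-series models only support default sampling, so temperature/top_p
    # must be omitted from the request.
    return not (model_name or "").lower().startswith("gpt-5")

assert openai_token_param("gpt-5-mini") == "max_completion_tokens"
assert openai_token_param("gpt-4o") == "max_tokens"
assert include_sampling_params("gpt-4o") is True
assert include_sampling_params("gpt-5") is False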

File tree: 2 files changed (+49, -15 lines)

.gitignore

Lines changed: 1 addition & 1 deletion
@@ -40,7 +40,7 @@ outputs/
 Thumbs.db
 
 {new_directory}
-
+.env
 # Common misspellings / alternate names for generated artifacts
 inferenced/
 outptu/

synthetic_data_kit/models/llm_client.py

Lines changed: 48 additions & 14 deletions
@@ -181,6 +181,30 @@ def _init_openai_client(self):
             client_kwargs['base_url'] = self.api_base
 
         self.openai_client = OpenAI(**client_kwargs)
+
+    def _openai_token_param(self, model_name: Optional[str] = None) -> str:
+        """Return the correct token limit parameter name for the given OpenAI model.
+
+        Some newer models (e.g., GPT-5 series) do not accept 'max_tokens' and require
+        'max_completion_tokens' instead. This helper returns the appropriate key.
+        """
+        name = (model_name or self.model or "").lower()
+        # GPT-5 series require 'max_completion_tokens'
+        if name.startswith("gpt-5"):
+            return "max_completion_tokens"
+        # Default for chat.completions
+        return "max_tokens"
+
+    def _openai_include_sampling_params(self, model_name: Optional[str] = None) -> bool:
+        """Determine whether to include sampling params like temperature/top_p.
+
+        Some newer models (e.g., GPT-5 series) restrict sampling controls and only
+        support defaults. For those models, we must omit these parameters.
+        """
+        name = (model_name or self.model or "").lower()
+        if name.startswith("gpt-5"):
+            return False
+        return True
 
     def _check_vllm_server(self) -> tuple:
         """Check if the VLLM server is running and accessible"""
@@ -249,13 +273,18 @@ def _openai_chat_completion(self,
         for attempt in range(self.max_retries):
             try:
                 # Create the completion request
-                response = self.openai_client.chat.completions.create(
-                    model=self.model,
-                    messages=messages,
-                    temperature=temperature,
-                    max_tokens=max_tokens,
-                    top_p=top_p
-                )
+                token_param = self._openai_token_param(self.model)
+                req_kwargs = {
+                    "model": self.model,
+                    "messages": messages,
+                }
+                # Include sampling params only if allowed by the model
+                if self._openai_include_sampling_params(self.model):
+                    req_kwargs["temperature"] = temperature
+                    req_kwargs["top_p"] = top_p
+                req_kwargs[token_param] = max_tokens
+
+                response = self.openai_client.chat.completions.create(**req_kwargs)
 
                 if verbose:
                     logger.info(f"Received response from {self.provider}")
@@ -565,13 +594,18 @@ async def _process_message_async(self,
         for attempt in range(self.max_retries):
             try:
                 # Asynchronously call the API
-                response = await async_client.chat.completions.create(
-                    model=self.model,
-                    messages=messages,
-                    temperature=temperature,
-                    max_tokens=max_tokens,
-                    top_p=top_p
-                )
+                token_param = self._openai_token_param(self.model)
+                req_kwargs = {
+                    "model": self.model,
+                    "messages": messages,
+                }
+                # Include sampling params only if allowed by the model
+                if self._openai_include_sampling_params(self.model):
+                    req_kwargs["temperature"] = temperature
+                    req_kwargs["top_p"] = top_p
+                req_kwargs[token_param] = max_tokens
+
+                response = await async_client.chat.completions.create(**req_kwargs)
 
                 if verbose:
                     logger.info(f"Received response from {self.provider}")
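
Taken together, the two call-site changes mean the request payload itself changes shape with the model family. Below is a rough, self-contained sketch of the resulting kwargs; the build_request_kwargs helper, the default values 0.7/0.95/1024, and the model names are illustrative placeholders, not part of the diff.

def build_request_kwargs(model, messages, temperature=0.7, top_p=0.95, max_tokens=1024):
    """Assemble chat.completions kwargs the way the patched call sites do."""
    is_gpt5 = model.lower().startswith("gpt-5")
    kwargs = {"model": model, "messages": messages}
    if not is_gpt5:  # gpt-5-* only supports default sampling
        kwargs["temperature"] = temperature
        kwargs["top_p"] = top_p
    # gpt-5-* expects 'max_completion_tokens' instead of 'max_tokens'
    kwargs["max_completion_tokens" if is_gpt5 else "max_tokens"] = max_tokens
    return kwargs

messages = [{"role": "user", "content": "Hello"}]
assert build_request_kwargs("gpt-4o", messages) == {
    "model": "gpt-4o", "messages": messages,
    "temperature": 0.7, "top_p": 0.95, "max_tokens": 1024,
}
assert build_request_kwargs("gpt-5-mini", messages) == {
    "model": "gpt-5-mini", "messages": messages,
    "max_completion_tokens": 1024,
}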

0 commit comments