Skip to content

Commit 621ce45

Browse files
authored
[API ADD]Add max_out_token (#343)
* add `max_out_token` * fix flake
1 parent 8de8fdb commit 621ce45

File tree

1 file changed

+14
-6
lines changed

1 file changed

+14
-6
lines changed

erniebot/src/erniebot/resources/chat_completion.py

+14-6
Original file line numberDiff line numberDiff line change
@@ -156,6 +156,7 @@ def create(
156156
extra_params: Optional[dict] = ...,
157157
headers: Optional[HeadersType] = ...,
158158
request_timeout: Optional[float] = ...,
159+
max_output_tokens: Optional[int] = ...,
159160
_config_: Optional[ConfigDictType] = ...,
160161
) -> "ChatCompletionResponse":
161162
...
@@ -182,6 +183,7 @@ def create(
182183
extra_params: Optional[dict] = ...,
183184
headers: Optional[HeadersType] = ...,
184185
request_timeout: Optional[float] = ...,
186+
max_output_tokens: Optional[int] = ...,
185187
_config_: Optional[ConfigDictType] = ...,
186188
) -> Iterator["ChatCompletionResponse"]:
187189
...
@@ -208,6 +210,7 @@ def create(
208210
extra_params: Optional[dict] = ...,
209211
headers: Optional[HeadersType] = ...,
210212
request_timeout: Optional[float] = ...,
213+
max_output_tokens: Optional[int] = ...,
211214
_config_: Optional[ConfigDictType] = ...,
212215
) -> Union["ChatCompletionResponse", Iterator["ChatCompletionResponse"]]:
213216
...
@@ -233,6 +236,7 @@ def create(
233236
extra_params: Optional[dict] = None,
234237
headers: Optional[HeadersType] = None,
235238
request_timeout: Optional[float] = None,
239+
max_output_tokens: Optional[int] = None,
236240
_config_: Optional[ConfigDictType] = None,
237241
) -> Union["ChatCompletionResponse", Iterator["ChatCompletionResponse"]]:
238242
"""Creates a model response for the given conversation.
@@ -279,6 +283,7 @@ def create(
279283
user_id=user_id,
280284
tool_choice=tool_choice,
281285
stream=stream,
286+
max_output_tokens=max_output_tokens,
282287
)
283288
kwargs["validate_functions"] = validate_functions
284289
if extra_params is not None:
@@ -313,6 +318,7 @@ async def acreate(
313318
extra_params: Optional[dict] = ...,
314319
headers: Optional[HeadersType] = ...,
315320
request_timeout: Optional[float] = ...,
321+
max_output_tokens: Optional[int] = ...,
316322
_config_: Optional[ConfigDictType] = ...,
317323
) -> EBResponse:
318324
...
@@ -339,6 +345,7 @@ async def acreate(
339345
extra_params: Optional[dict] = ...,
340346
headers: Optional[HeadersType] = ...,
341347
request_timeout: Optional[float] = ...,
348+
max_output_tokens: Optional[int] = ...,
342349
_config_: Optional[ConfigDictType] = ...,
343350
) -> AsyncIterator["ChatCompletionResponse"]:
344351
...
@@ -365,6 +372,7 @@ async def acreate(
365372
extra_params: Optional[dict] = ...,
366373
headers: Optional[HeadersType] = ...,
367374
request_timeout: Optional[float] = ...,
375+
max_output_tokens: Optional[int] = ...,
368376
_config_: Optional[ConfigDictType] = ...,
369377
) -> Union["ChatCompletionResponse", AsyncIterator["ChatCompletionResponse"]]:
370378
...
@@ -390,6 +398,7 @@ async def acreate(
390398
extra_params: Optional[dict] = None,
391399
headers: Optional[HeadersType] = None,
392400
request_timeout: Optional[float] = None,
401+
max_output_tokens: Optional[int] = None,
393402
_config_: Optional[ConfigDictType] = None,
394403
) -> Union["ChatCompletionResponse", AsyncIterator["ChatCompletionResponse"]]:
395404
"""Creates a model response for the given conversation.
@@ -436,6 +445,7 @@ async def acreate(
436445
user_id=user_id,
437446
tool_choice=tool_choice,
438447
stream=stream,
448+
max_output_tokens=max_output_tokens,
439449
)
440450
kwargs["validate_functions"] = validate_functions
441451
if extra_params is not None:
@@ -450,12 +460,7 @@ async def acreate(
450460

451461
def _check_model_kwargs(self, model_name: str, kwargs: Dict[str, Any]) -> None:
452462
if model_name in ("ernie-speed", "ernie-speed-128k", "ernie-char-8k", "ernie-tiny-8k", "ernie-lite"):
453-
for arg in (
454-
"functions",
455-
"disable_search",
456-
"enable_citation",
457-
"tool_choice",
458-
):
463+
for arg in ("functions", "disable_search", "enable_citation", "tool_choice"):
459464
if arg in kwargs:
460465
raise errors.InvalidArgumentError(f"`{arg}` is not supported by the {model_name} model.")
461466

@@ -492,6 +497,7 @@ def _set_val_if_key_exists(src: dict, dst: dict, key: str) -> None:
492497
"extra_params",
493498
"headers",
494499
"request_timeout",
500+
"max_output_tokens",
495501
}
496502

497503
invalid_keys = kwargs.keys() - valid_keys
@@ -554,6 +560,8 @@ def _set_val_if_key_exists(src: dict, dst: dict, key: str) -> None:
554560
_set_val_if_key_exists(kwargs, params, "user_id")
555561
_set_val_if_key_exists(kwargs, params, "tool_choice")
556562
_set_val_if_key_exists(kwargs, params, "stream")
563+
_set_val_if_key_exists(kwargs, params, "max_output_tokens")
564+
557565
if "extra_params" in kwargs:
558566
params.update(kwargs["extra_params"])
559567

0 commit comments

Comments
 (0)