@@ -156,6 +156,7 @@ def create(
156
156
extra_params : Optional [dict ] = ...,
157
157
headers : Optional [HeadersType ] = ...,
158
158
request_timeout : Optional [float ] = ...,
159
+ max_output_tokens : Optional [int ] = ...,
159
160
_config_ : Optional [ConfigDictType ] = ...,
160
161
) -> "ChatCompletionResponse" :
161
162
...
@@ -182,6 +183,7 @@ def create(
182
183
extra_params : Optional [dict ] = ...,
183
184
headers : Optional [HeadersType ] = ...,
184
185
request_timeout : Optional [float ] = ...,
186
+ max_output_tokens : Optional [int ] = ...,
185
187
_config_ : Optional [ConfigDictType ] = ...,
186
188
) -> Iterator ["ChatCompletionResponse" ]:
187
189
...
@@ -208,6 +210,7 @@ def create(
208
210
extra_params : Optional [dict ] = ...,
209
211
headers : Optional [HeadersType ] = ...,
210
212
request_timeout : Optional [float ] = ...,
213
+ max_output_tokens : Optional [int ] = ...,
211
214
_config_ : Optional [ConfigDictType ] = ...,
212
215
) -> Union ["ChatCompletionResponse" , Iterator ["ChatCompletionResponse" ]]:
213
216
...
@@ -233,6 +236,7 @@ def create(
233
236
extra_params : Optional [dict ] = None ,
234
237
headers : Optional [HeadersType ] = None ,
235
238
request_timeout : Optional [float ] = None ,
239
+ max_output_tokens : Optional [int ] = None ,
236
240
_config_ : Optional [ConfigDictType ] = None ,
237
241
) -> Union ["ChatCompletionResponse" , Iterator ["ChatCompletionResponse" ]]:
238
242
"""Creates a model response for the given conversation.
@@ -279,6 +283,7 @@ def create(
279
283
user_id = user_id ,
280
284
tool_choice = tool_choice ,
281
285
stream = stream ,
286
+ max_output_tokens = max_output_tokens ,
282
287
)
283
288
kwargs ["validate_functions" ] = validate_functions
284
289
if extra_params is not None :
@@ -313,6 +318,7 @@ async def acreate(
313
318
extra_params : Optional [dict ] = ...,
314
319
headers : Optional [HeadersType ] = ...,
315
320
request_timeout : Optional [float ] = ...,
321
+ max_output_tokens : Optional [int ] = ...,
316
322
_config_ : Optional [ConfigDictType ] = ...,
317
323
) -> EBResponse :
318
324
...
@@ -339,6 +345,7 @@ async def acreate(
339
345
extra_params : Optional [dict ] = ...,
340
346
headers : Optional [HeadersType ] = ...,
341
347
request_timeout : Optional [float ] = ...,
348
+ max_output_tokens : Optional [int ] = ...,
342
349
_config_ : Optional [ConfigDictType ] = ...,
343
350
) -> AsyncIterator ["ChatCompletionResponse" ]:
344
351
...
@@ -365,6 +372,7 @@ async def acreate(
365
372
extra_params : Optional [dict ] = ...,
366
373
headers : Optional [HeadersType ] = ...,
367
374
request_timeout : Optional [float ] = ...,
375
+ max_output_tokens : Optional [int ] = ...,
368
376
_config_ : Optional [ConfigDictType ] = ...,
369
377
) -> Union ["ChatCompletionResponse" , AsyncIterator ["ChatCompletionResponse" ]]:
370
378
...
@@ -390,6 +398,7 @@ async def acreate(
390
398
extra_params : Optional [dict ] = None ,
391
399
headers : Optional [HeadersType ] = None ,
392
400
request_timeout : Optional [float ] = None ,
401
+ max_output_tokens : Optional [int ] = None ,
393
402
_config_ : Optional [ConfigDictType ] = None ,
394
403
) -> Union ["ChatCompletionResponse" , AsyncIterator ["ChatCompletionResponse" ]]:
395
404
"""Creates a model response for the given conversation.
@@ -436,6 +445,7 @@ async def acreate(
436
445
user_id = user_id ,
437
446
tool_choice = tool_choice ,
438
447
stream = stream ,
448
+ max_output_tokens = max_output_tokens ,
439
449
)
440
450
kwargs ["validate_functions" ] = validate_functions
441
451
if extra_params is not None :
@@ -450,12 +460,7 @@ async def acreate(
450
460
451
461
def _check_model_kwargs (self , model_name : str , kwargs : Dict [str , Any ]) -> None :
452
462
if model_name in ("ernie-speed" , "ernie-speed-128k" , "ernie-char-8k" , "ernie-tiny-8k" , "ernie-lite" ):
453
- for arg in (
454
- "functions" ,
455
- "disable_search" ,
456
- "enable_citation" ,
457
- "tool_choice" ,
458
- ):
463
+ for arg in ("functions" , "disable_search" , "enable_citation" , "tool_choice" ):
459
464
if arg in kwargs :
460
465
raise errors .InvalidArgumentError (f"`{ arg } ` is not supported by the { model_name } model." )
461
466
@@ -492,6 +497,7 @@ def _set_val_if_key_exists(src: dict, dst: dict, key: str) -> None:
492
497
"extra_params" ,
493
498
"headers" ,
494
499
"request_timeout" ,
500
+ "max_output_tokens" ,
495
501
}
496
502
497
503
invalid_keys = kwargs .keys () - valid_keys
@@ -554,6 +560,8 @@ def _set_val_if_key_exists(src: dict, dst: dict, key: str) -> None:
554
560
_set_val_if_key_exists (kwargs , params , "user_id" )
555
561
_set_val_if_key_exists (kwargs , params , "tool_choice" )
556
562
_set_val_if_key_exists (kwargs , params , "stream" )
563
+ _set_val_if_key_exists (kwargs , params , "max_output_tokens" )
564
+
557
565
if "extra_params" in kwargs :
558
566
params .update (kwargs ["extra_params" ])
559
567
0 commit comments