Skip to content

Commit 39f0a38

Browse files
authored May 31, 2024
[erniebot] Add new qianfan models, Add response_format argument (#349)
* add new models; add response_format
* remove aksk
* fix lint
1 parent 3f2ecc0 commit 39f0a38

File tree

2 files changed

+47
-10
lines changed

2 files changed

+47
-10
lines changed
 

‎erniebot/src/erniebot/resources/chat_completion.py

+34-9
Original file line numberDiff line numberDiff line change
@@ -58,15 +58,29 @@ class ChatCompletion(EBResource, CreatableWithStreaming):
5858
"ernie-3.5-8k": {
5959
"model_id": "completions",
6060
},
61+
"ernie-3.5-8k-0205": {
62+
"model_id": "ernie-3.5-8k-0205",
63+
},
64+
"ernie-3.5-8k-0329": {
65+
"model_id": "ernie-3.5-8k-0329",
66+
},
67+
"ernie-3.5-128k": {
68+
"model_id": "ernie-3.5-128k",
69+
},
6170
"ernie-lite": {
6271
"model_id": "eb-instant",
6372
},
73+
"ernie-lite-8k-0308": {
74+
"model_id": "ernie-lite-8k",
75+
},
6476
"ernie-4.0": {
6577
"model_id": "completions_pro",
6678
},
67-
"ernie-longtext": {
68-
# ernie-longtext(ernie_bot_8k) will be deprecated in 2024.4.11
69-
"model_id": "completions",
79+
"ernie-4.0-8k-0329": {
80+
"model_id": "ernie-4.0-8k-0329",
81+
},
82+
"ernie-4.0-8k-0104": {
83+
"model_id": "ernie-4.0-8k-0104",
7084
},
7185
"ernie-speed": {
7286
"model_id": "ernie_speed",
@@ -97,10 +111,6 @@ class ChatCompletion(EBResource, CreatableWithStreaming):
97111
"ernie-4.0": {
98112
"model_id": "completions_pro",
99113
},
100-
"ernie-longtext": {
101-
# ernie-longtext(ernie_bot_8k) will be deprecated in 2024.4.11
102-
"model_id": "completions",
103-
},
104114
"ernie-speed": {
105115
"model_id": "ernie_speed",
106116
},
@@ -156,6 +166,7 @@ def create(
156166
extra_params: Optional[dict] = ...,
157167
headers: Optional[HeadersType] = ...,
158168
request_timeout: Optional[float] = ...,
169+
response_format: Optional[Literal["json_object", "text"]] = ...,
159170
max_output_tokens: Optional[int] = ...,
160171
_config_: Optional[ConfigDictType] = ...,
161172
) -> "ChatCompletionResponse":
@@ -183,6 +194,7 @@ def create(
183194
extra_params: Optional[dict] = ...,
184195
headers: Optional[HeadersType] = ...,
185196
request_timeout: Optional[float] = ...,
197+
response_format: Optional[Literal["json_object", "text"]] = ...,
186198
max_output_tokens: Optional[int] = ...,
187199
_config_: Optional[ConfigDictType] = ...,
188200
) -> Iterator["ChatCompletionResponse"]:
@@ -210,6 +222,7 @@ def create(
210222
extra_params: Optional[dict] = ...,
211223
headers: Optional[HeadersType] = ...,
212224
request_timeout: Optional[float] = ...,
225+
response_format: Optional[Literal["json_object", "text"]] = ...,
213226
max_output_tokens: Optional[int] = ...,
214227
_config_: Optional[ConfigDictType] = ...,
215228
) -> Union["ChatCompletionResponse", Iterator["ChatCompletionResponse"]]:
@@ -236,6 +249,7 @@ def create(
236249
extra_params: Optional[dict] = None,
237250
headers: Optional[HeadersType] = None,
238251
request_timeout: Optional[float] = None,
252+
response_format: Optional[Literal["json_object", "text"]] = None,
239253
max_output_tokens: Optional[int] = None,
240254
_config_: Optional[ConfigDictType] = None,
241255
) -> Union["ChatCompletionResponse", Iterator["ChatCompletionResponse"]]:
@@ -292,6 +306,8 @@ def create(
292306
kwargs["headers"] = headers
293307
if request_timeout is not None:
294308
kwargs["request_timeout"] = request_timeout
309+
if response_format is not None:
310+
kwargs["response_format"] = response_format
295311

296312
resp = resource.create_resource(**kwargs)
297313
return transform(ChatCompletionResponse.from_mapping, resp)
@@ -318,6 +334,7 @@ async def acreate(
318334
extra_params: Optional[dict] = ...,
319335
headers: Optional[HeadersType] = ...,
320336
request_timeout: Optional[float] = ...,
337+
response_format: Optional[Literal["json_object", "text"]] = ...,
321338
max_output_tokens: Optional[int] = ...,
322339
_config_: Optional[ConfigDictType] = ...,
323340
) -> EBResponse:
@@ -345,6 +362,7 @@ async def acreate(
345362
extra_params: Optional[dict] = ...,
346363
headers: Optional[HeadersType] = ...,
347364
request_timeout: Optional[float] = ...,
365+
response_format: Optional[Literal["json_object", "text"]] = ...,
348366
max_output_tokens: Optional[int] = ...,
349367
_config_: Optional[ConfigDictType] = ...,
350368
) -> AsyncIterator["ChatCompletionResponse"]:
@@ -372,6 +390,7 @@ async def acreate(
372390
extra_params: Optional[dict] = ...,
373391
headers: Optional[HeadersType] = ...,
374392
request_timeout: Optional[float] = ...,
393+
response_format: Optional[Literal["json_object", "text"]] = ...,
375394
max_output_tokens: Optional[int] = ...,
376395
_config_: Optional[ConfigDictType] = ...,
377396
) -> Union["ChatCompletionResponse", AsyncIterator["ChatCompletionResponse"]]:
@@ -398,6 +417,7 @@ async def acreate(
398417
extra_params: Optional[dict] = None,
399418
headers: Optional[HeadersType] = None,
400419
request_timeout: Optional[float] = None,
420+
response_format: Optional[Literal["json_object", "text"]] = None,
401421
max_output_tokens: Optional[int] = None,
402422
_config_: Optional[ConfigDictType] = None,
403423
) -> Union["ChatCompletionResponse", AsyncIterator["ChatCompletionResponse"]]:
@@ -423,6 +443,7 @@ async def acreate(
423443
validate_functions: Whether to validate the function descriptions.
424444
headers: Custom headers to send with the request.
425445
request_timeout: Timeout for a single request.
446+
response_format: Format of the response.
426447
_config_: Overrides the global settings.
427448
428449
Returns:
@@ -460,9 +481,11 @@ async def acreate(
460481

461482
def _check_model_kwargs(self, model_name: str, kwargs: Dict[str, Any]) -> None:
462483
if model_name in ("ernie-speed", "ernie-speed-128k", "ernie-char-8k", "ernie-tiny-8k", "ernie-lite"):
463-
for arg in ("functions", "disable_search", "enable_citation", "tool_choice"):
484+
for arg in ("functions", "disable_search", "enable_citation", "tool_choice", "response_format"):
464485
if arg in kwargs:
465-
raise errors.InvalidArgumentError(f"`{arg}` is not supported by the {model_name} model.")
486+
raise errors.InvalidArgumentError(
487+
f"`{arg}` is not supported by the `{model_name}` model."
488+
)
466489

467490
def _prepare_create(self, kwargs: Dict[str, Any]) -> RequestWithStream:
468491
def _update_model_name(given_name: str, old_name_to_new_name: Dict[str, str]) -> str:
@@ -497,6 +520,7 @@ def _set_val_if_key_exists(src: dict, dst: dict, key: str) -> None:
497520
"extra_params",
498521
"headers",
499522
"request_timeout",
523+
"response_format",
500524
"max_output_tokens",
501525
}
502526

@@ -561,6 +585,7 @@ def _set_val_if_key_exists(src: dict, dst: dict, key: str) -> None:
561585
_set_val_if_key_exists(kwargs, params, "tool_choice")
562586
_set_val_if_key_exists(kwargs, params, "stream")
563587
_set_val_if_key_exists(kwargs, params, "max_output_tokens")
588+
_set_val_if_key_exists(kwargs, params, "response_format")
564589

565590
if "extra_params" in kwargs:
566591
params.update(kwargs["extra_params"])

‎erniebot/tests/test_chat_completion.py

+13-1
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,18 @@ def create_chat_completion(model):
3939
print(response.get_result())
4040

4141

42+
def create_chat_completion_json_mode(model):
43+
response = erniebot.ChatCompletion.create(
44+
model=model,
45+
messages=[
46+
{"role": "user", "content": "文心一言是哪个公司开发的?"},
47+
],
48+
stream=False,
49+
response_format="json_object",
50+
)
51+
print(response.get_result())
52+
53+
4254
def create_chat_completion_stream(model):
4355
response = erniebot.ChatCompletion.create(
4456
model=model,
@@ -68,5 +80,5 @@ def create_chat_completion_stream(model):
6880
erniebot.api_type = "qianfan"
6981

7082
create_chat_completion(model="ernie-turbo")
71-
7283
create_chat_completion_stream(model="ernie-turbo")
84+
create_chat_completion_json_mode(model="ernie-lite")

0 commit comments

Comments
 (0)
Please sign in to comment.