Skip to content

Commit ec9d29a

Browse files
committed
cherry-pick 3d7e1ff
1 parent a58eca7 commit ec9d29a

File tree

2 files changed

+75
-7
lines changed

2 files changed

+75
-7
lines changed

erniebot/src/erniebot/intro.py

+4
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,10 @@ def list() -> List[Tuple[str, str]]:
2828
("ernie-turbo", "文心大模型(ernie-turbo)"),
2929
("ernie-4.0", "文心大模型(ernie-4.0)"),
3030
("ernie-longtext", "文心大模型(ernie-longtext)"),
31+
("ernie-speed", " 文心大模型(ernie-speed)"),
32+
("ernie-speed-128k", " 文心大模型(ernie-speed-128k)"),
33+
("ernie-tiny-8k", " 文心大模型(ernie-tiny-8k)"),
34+
("ernie-char-8k", " 文心大模型(ernie-char-8k)"),
3135
("ernie-text-embedding", "文心百中语义模型"),
3236
("ernie-vilg-v2", "文心一格模型"),
3337
]

erniebot/src/erniebot/resources/chat_completion.py

+71-7
Original file line numberDiff line numberDiff line change
@@ -55,14 +55,30 @@ class ChatCompletion(EBResource, CreatableWithStreaming):
5555
"ernie-3.5": {
5656
"model_id": "completions",
5757
},
58+
"ernie-3.5-8k": {
59+
"model_id": "completions",
60+
},
5861
"ernie-turbo": {
5962
"model_id": "eb-instant",
6063
},
6164
"ernie-4.0": {
6265
"model_id": "completions_pro",
6366
},
6467
"ernie-longtext": {
65-
"model_id": "ernie_bot_8k",
68+
# ernie-longtext (ernie_bot_8k) will be deprecated on 2024-04-11, so it now maps to "completions".
69+
"model_id": "completions",
70+
},
71+
"ernie-speed": {
72+
"model_id": "ernie_speed",
73+
},
74+
"ernie-speed-128k": {
75+
"model_id": "ernie-speed-128k",
76+
},
77+
"ernie-tiny-8k": {
78+
"model_id": "ernie-tiny-8k",
79+
},
80+
"ernie-char-8k": {
81+
"model_id": "ernie-char-8k",
6682
},
6783
},
6884
},
@@ -72,14 +88,30 @@ class ChatCompletion(EBResource, CreatableWithStreaming):
7288
"ernie-3.5": {
7389
"model_id": "completions",
7490
},
91+
"ernie-3.5-8k": {
92+
"model_id": "completions",
93+
},
7594
"ernie-turbo": {
7695
"model_id": "eb-instant",
7796
},
7897
"ernie-4.0": {
7998
"model_id": "completions_pro",
8099
},
81100
"ernie-longtext": {
82-
"model_id": "ernie_bot_8k",
101+
# ernie-longtext (ernie_bot_8k) will be deprecated on 2024-04-11, so it now maps to "completions".
102+
"model_id": "completions",
103+
},
104+
"ernie-speed": {
105+
"model_id": "ernie_speed",
106+
},
107+
"ernie-speed-128k": {
108+
"model_id": "ernie-speed-128k",
109+
},
110+
"ernie-tiny-8k": {
111+
"model_id": "ernie-tiny-8k",
112+
},
113+
"ernie-char-8k": {
114+
"model_id": "ernie-char-8k",
83115
},
84116
},
85117
},
@@ -89,6 +121,15 @@ class ChatCompletion(EBResource, CreatableWithStreaming):
89121
"ernie-3.5": {
90122
"model_id": "completions",
91123
},
124+
"ernie-4.0": {
125+
"model_id": "completions_pro",
126+
},
127+
"ernie-longtext": {
128+
"model_id": "completions",
129+
},
130+
"ernie-speed": {
131+
"model_id": "ernie_speed",
132+
},
92133
},
93134
},
94135
}
@@ -251,6 +292,7 @@ def create(
251292
kwargs["headers"] = headers
252293
if request_timeout is not None:
253294
kwargs["request_timeout"] = request_timeout
295+
254296
resp = resource.create_resource(**kwargs)
255297
return transform(ChatCompletionResponse.from_mapping, resp)
256298

@@ -412,9 +454,32 @@ async def acreate(
412454
kwargs["headers"] = headers
413455
if request_timeout is not None:
414456
kwargs["request_timeout"] = request_timeout
457+
415458
resp = await resource.acreate_resource(**kwargs)
416459
return transform(ChatCompletionResponse.from_mapping, resp)
417460

461+
def _check_model_kwargs(self, model_name: str, kwargs: Dict[str, Any]) -> None:
462+
if model_name in ("ernie-turbo",):
463+
for arg in (
464+
"functions",
465+
"stop",
466+
"disable_search",
467+
"enable_citation",
468+
"tool_choice",
469+
):
470+
if arg in kwargs:
471+
raise errors.InvalidArgumentError(f"`{arg}` is not supported by the {model_name} model.")
472+
473+
if model_name in ("ernie-speed", "ernie-speed-128k", "ernie-char-8k", "ernie-tiny-8k"):
474+
for arg in (
475+
"functions",
476+
"disable_search",
477+
"enable_citation",
478+
"tool_choice",
479+
):
480+
if arg in kwargs:
481+
raise errors.InvalidArgumentError(f"`{arg}` is not supported by the {model_name} model.")
482+
418483
def _prepare_create(self, kwargs: Dict[str, Any]) -> RequestWithStream:
419484
def _update_model_name(given_name: str, old_name_to_new_name: Dict[str, str]) -> str:
420485
if given_name in old_name_to_new_name:
@@ -467,7 +532,8 @@ def _set_val_if_key_exists(src: dict, dst: dict, key: str) -> None:
467532
"ernie-bot": "ernie-3.5",
468533
"ernie-bot-turbo": "ernie-turbo",
469534
"ernie-bot-4": "ernie-4.0",
470-
"ernie-bot-8k": "ernie-longtext",
535+
"ernie-bot-8k": "ernie-3.5-8k",
536+
"ernie-longtext": "ernie-3.5-8k",
471537
},
472538
)
473539

@@ -489,10 +555,8 @@ def _set_val_if_key_exists(src: dict, dst: dict, key: str) -> None:
489555

490556
# params
491557
params = {}
492-
if model == "ernie-turbo":
493-
for arg in ("functions", "stop", "disable_search", "enable_citation"):
494-
if arg in kwargs:
495-
raise errors.InvalidArgumentError(f"`{arg}` is not supported by the {model} model.")
558+
self._check_model_kwargs(model, kwargs)
559+
496560
params["messages"] = messages
497561
if "functions" in kwargs:
498562
functions = kwargs["functions"]

0 commit comments

Comments
 (0)