Skip to content

Commit 3334cbb

Browse files
authored
Eb cherry (#322)
* resolve conflicts * remove unused comments * resolve conflicts * reformat * resolve conflicts * remove lines * resolve conflicts * Add ernieb speed * resolve conflicts * resolve conflicts * Update * restore handle response * Update format * Update http_client * Update format
1 parent 5270b98 commit 3334cbb

File tree

3 files changed

+31
-4
lines changed

3 files changed

+31
-4
lines changed

erniebot/src/erniebot/backends/bce.py

-1
Original file line numberDiff line numberDiff line change
@@ -351,7 +351,6 @@ def handle_response(self, resp: EBResponse) -> EBResponse:
351351
if "error_code" in resp and "error_msg" in resp:
352352
ecode = resp["error_code"]
353353
emsg = resp["error_msg"]
354-
print(ecode)
355354
if ecode in (4, 17):
356355
raise errors.RequestLimitError(emsg, ecode=ecode)
357356
elif ecode in (13, 15, 18):

erniebot/src/erniebot/backends/custom.py

+17-1
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,11 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
import os
1516
from typing import Any, AsyncIterator, ClassVar, Dict, Iterator, Optional, Union
1617

1718
import erniebot.errors as errors
19+
import erniebot.utils.logging as logging
1820
from erniebot.api_types import APIType
1921
from erniebot.response import EBResponse
2022
from erniebot.types import FilesType, HeadersType, ParamsType
@@ -31,6 +33,10 @@ class CustomBackend(EBBackend):
3133

3234
def __init__(self, config_dict: Dict[str, Any]) -> None:
3335
super().__init__(config_dict=config_dict)
36+
access_token = self._cfg.get("access_token", None)
37+
if access_token is None:
38+
access_token = os.environ.get("AISTUDIO_ACCESS_TOKEN", None)
39+
self._access_token = access_token
3440

3541
def request(
3642
self,
@@ -79,7 +85,8 @@ async def arequest(
7985
params=params,
8086
files=files,
8187
)
82-
88+
if self._access_token is not None:
89+
headers = self._add_aistudio_fields_to_headers(headers)
8390
return await self._client.asend_request(
8491
method,
8592
url,
@@ -110,3 +117,12 @@ def handle_response(self, resp: EBResponse) -> EBResponse:
110117
raise errors.APIError(emsg, ecode=ecode)
111118
else:
112119
return resp
120+
121+
def _add_aistudio_fields_to_headers(self, headers: HeadersType) -> HeadersType:
122+
if "Authorization" in headers:
123+
logging.warning(
124+
"Key 'Authorization' already exists in `headers`: %r",
125+
headers["Authorization"],
126+
)
127+
headers["Authorization"] = f"{self._access_token}"
128+
return headers

erniebot/src/erniebot/resources/chat_completion.py

+14-2
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,15 @@ class ChatCompletion(EBResource, CreatableWithStreaming):
9292
"ernie-3.5": {
9393
"model_id": "completions",
9494
},
95+
"ernie-4.0": {
96+
"model_id": "completions_pro",
97+
},
98+
"ernie-longtext": {
99+
"model_id": "ernie_bot_8k",
100+
},
101+
"ernie-speed": {
102+
"model_id": "ernie_speed",
103+
},
95104
},
96105
},
97106
}
@@ -512,8 +521,11 @@ def _set_val_if_key_exists(src: dict, dst: dict, key: str) -> None:
512521
if "extra_params" in kwargs:
513522
params.update(kwargs["extra_params"])
514523

515-
# headers
516-
headers = kwargs.get("headers", None)
524+
headers: HeadersType = {}
525+
if self.api_type is APIType.AISTUDIO or self.api_type is APIType.CUSTOM:
526+
headers["Content-Type"] = "application/json"
527+
if "headers" in kwargs:
528+
headers.update(kwargs["headers"])
517529

518530
# request_timeout
519531
request_timeout = kwargs.get("request_timeout", None)

0 commit comments

Comments
 (0)