Skip to content

Commit 2bfa1b1

Browse files
authored
[aistudio api] Update weipu api (#319)
* Update weipu api * Updare erniebot api * remove unused comments * restore erniebot * reformat * update * remove lines * fix ci * Add ernieb speed * suport no access token config * Fix unitest
1 parent 4aca819 commit 2bfa1b1

File tree

5 files changed

+28
-4
lines changed

5 files changed

+28
-4
lines changed

erniebot-agent/tests/unit_tests/tools/test_llama_index_retrieval_tool.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import pytest
2-
from llama_index.schema import NodeWithScore, TextNode
2+
from llama_index.core.schema import NodeWithScore, TextNode
33

44
from erniebot_agent.tools.llama_index_retrieval_tool import LlamaIndexRetrievalTool
55

erniebot/src/erniebot/backends/bce.py

-1
Original file line numberDiff line numberDiff line change
@@ -349,7 +349,6 @@ def handle_response(cls, resp: EBResponse) -> EBResponse:
349349
if "error_code" in resp and "error_msg" in resp:
350350
ecode = resp["error_code"]
351351
emsg = resp["error_msg"]
352-
print(ecode)
353352
if ecode in (4, 17):
354353
raise errors.RequestLimitError(emsg, ecode=ecode)
355354
elif ecode in (13, 15, 18):

erniebot/src/erniebot/backends/custom.py

+17
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,10 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
import os
1516
from typing import Any, AsyncIterator, ClassVar, Dict, Iterator, Optional, Union
1617

18+
import erniebot.utils.logging as logging
1719
from erniebot.api_types import APIType
1820
from erniebot.backends.bce import QianfanLegacyBackend
1921
from erniebot.response import EBResponse
@@ -29,6 +31,10 @@ class CustomBackend(EBBackend):
2931

3032
def __init__(self, config_dict: Dict[str, Any]) -> None:
3133
super().__init__(config_dict=config_dict)
34+
access_token = self._cfg.get("access_token", None)
35+
if access_token is None:
36+
access_token = os.environ.get("AISTUDIO_ACCESS_TOKEN", None)
37+
self._access_token = access_token
3238

3339
def request(
3440
self,
@@ -71,6 +77,8 @@ async def arequest(
7177
supplied_headers=headers,
7278
params=params,
7379
)
80+
if self._access_token is not None:
81+
headers = self._add_aistudio_fields_to_headers(headers)
7482
return await self._client.asend_request(
7583
method,
7684
url,
@@ -83,3 +91,12 @@ async def arequest(
8391
@classmethod
8492
def handle_response(cls, resp: EBResponse) -> EBResponse:
8593
return QianfanLegacyBackend.handle_response(resp)
94+
95+
def _add_aistudio_fields_to_headers(self, headers: HeadersType) -> HeadersType:
96+
if "Authorization" in headers:
97+
logging.warning(
98+
"Key 'Authorization' already exists in `headers`: %r",
99+
headers["Authorization"],
100+
)
101+
headers["Authorization"] = f"{self._access_token}"
102+
return headers

erniebot/src/erniebot/http_client.py

-1
Original file line numberDiff line numberDiff line change
@@ -411,7 +411,6 @@ def _interpret_response_line(
411411
logging.debug("Decoded response body: %r", decoded_rbody)
412412

413413
response = EBResponse(rcode=rcode, rbody=decoded_rbody, rheaders=dict(rheaders))
414-
415414
if rcode != http.HTTPStatus.OK:
416415
raise errors.HTTPRequestError(
417416
f"The status code is not {http.HTTPStatus.OK}.",

erniebot/src/erniebot/resources/chat_completion.py

+10-1
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,15 @@ class ChatCompletion(EBResource, CreatableWithStreaming):
9292
"ernie-3.5": {
9393
"model_id": "completions",
9494
},
95+
"ernie-4.0": {
96+
"model_id": "completions_pro",
97+
},
98+
"ernie-longtext": {
99+
"model_id": "ernie_bot_8k",
100+
},
101+
"ernie-speed": {
102+
"model_id": "ernie_speed",
103+
},
95104
},
96105
},
97106
}
@@ -514,7 +523,7 @@ def _set_val_if_key_exists(src: dict, dst: dict, key: str) -> None:
514523

515524
# headers
516525
headers: HeadersType = {}
517-
if self.api_type is APIType.AISTUDIO:
526+
if self.api_type is APIType.AISTUDIO or self.api_type is APIType.CUSTOM:
518527
headers["Content-Type"] = "application/json"
519528
if "headers" in kwargs:
520529
headers.update(kwargs["headers"])

0 commit comments

Comments
 (0)