12
12
# See the License for the specific language governing permissions and
13
13
# limitations under the License.
14
14
15
+ import os
15
16
from typing import Any , AsyncIterator , ClassVar , Dict , Iterator , Optional , Union
16
17
18
+ import erniebot .utils .logging as logging
17
19
from erniebot .api_types import APIType
18
20
from erniebot .backends .bce import QianfanLegacyBackend
19
21
from erniebot .response import EBResponse
@@ -29,6 +31,10 @@ class CustomBackend(EBBackend):
29
31
30
32
def __init__ (self , config_dict : Dict [str , Any ]) -> None :
31
33
super ().__init__ (config_dict = config_dict )
34
+ access_token = self ._cfg .get ("access_token" , None )
35
+ if access_token is None :
36
+ access_token = os .environ .get ("AISTUDIO_ACCESS_TOKEN" , None )
37
+ self ._access_token = access_token
32
38
33
39
def request (
34
40
self ,
@@ -71,6 +77,8 @@ async def arequest(
71
77
supplied_headers = headers ,
72
78
params = params ,
73
79
)
80
+ if self ._access_token is not None :
81
+ headers = self ._add_aistudio_fields_to_headers (headers )
74
82
return await self ._client .asend_request (
75
83
method ,
76
84
url ,
@@ -83,3 +91,12 @@ async def arequest(
83
91
@classmethod
84
92
def handle_response (cls , resp : EBResponse ) -> EBResponse :
85
93
return QianfanLegacyBackend .handle_response (resp )
94
+
95
+ def _add_aistudio_fields_to_headers (self , headers : HeadersType ) -> HeadersType :
96
+ if "Authorization" in headers :
97
+ logging .warning (
98
+ "Key 'Authorization' already exists in `headers`: %r" ,
99
+ headers ["Authorization" ],
100
+ )
101
+ headers ["Authorization" ] = f"{ self ._access_token } "
102
+ return headers
0 commit comments