12
12
# See the License for the specific language governing permissions and
13
13
# limitations under the License.
14
14
15
+ import os
15
16
from typing import Any , AsyncIterator , ClassVar , Dict , Iterator , Optional , Union
16
17
17
18
import erniebot .errors as errors
19
+ import erniebot .utils .logging as logging
18
20
from erniebot .api_types import APIType
19
21
from erniebot .response import EBResponse
20
22
from erniebot .types import FilesType , HeadersType , ParamsType
@@ -31,6 +33,10 @@ class CustomBackend(EBBackend):
31
33
32
34
def __init__ (self , config_dict : Dict [str , Any ]) -> None :
33
35
super ().__init__ (config_dict = config_dict )
36
+ access_token = self ._cfg .get ("access_token" , None )
37
+ if access_token is None :
38
+ access_token = os .environ .get ("AISTUDIO_ACCESS_TOKEN" , None )
39
+ self ._access_token = access_token
34
40
35
41
def request (
36
42
self ,
@@ -79,7 +85,8 @@ async def arequest(
79
85
params = params ,
80
86
files = files ,
81
87
)
82
-
88
+ if self ._access_token is not None :
89
+ headers = self ._add_aistudio_fields_to_headers (headers )
83
90
return await self ._client .asend_request (
84
91
method ,
85
92
url ,
@@ -110,3 +117,12 @@ def handle_response(self, resp: EBResponse) -> EBResponse:
110
117
raise errors .APIError (emsg , ecode = ecode )
111
118
else :
112
119
return resp
120
+
121
+ def _add_aistudio_fields_to_headers (self , headers : HeadersType ) -> HeadersType :
122
+ if "Authorization" in headers :
123
+ logging .warning (
124
+ "Key 'Authorization' already exists in `headers`: %r" ,
125
+ headers ["Authorization" ],
126
+ )
127
+ headers ["Authorization" ] = f"{ self ._access_token } "
128
+ return headers
0 commit comments