Skip to content

Commit dbf3499

Browse files
authored
[Feat] Move config to kwargs and Give Clearer Error Message on Unsupported API Types
2 parents 18eca30 + ebc9409 commit dbf3499

11 files changed

+79
-54
lines changed

docs/authentication.md

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -37,13 +37,13 @@ erniebot.ak = "<EB-ACCESS-KEY>"
3737
erniebot.sk = "<EB-SECRET-KEY>"
3838
```
3939

40-
3) 使用`config`参数:
40+
3) 使用`_config_`参数:
4141
``` {.py .copy}
4242
import erniebot
4343

4444
# Take erniebot.ChatCompletion as an example
4545
chat_completion = erniebot.ChatCompletion.create(
46-
config=dict(
46+
_config_=dict(
4747
api_type="<EB-API-TYPE>",
4848
ak="<EB-ACCESS-KEY>",
4949
sk="<EB-SECRET-KEY>",
@@ -58,7 +58,7 @@ chat_completion = erniebot.ChatCompletion.create(
5858

5959
注意事项:
6060

61-
* 允许同时使用多种方式设置鉴权信息,程序将根据设置方式的优先级确定配置项的最终取值。三种设置方式的优先级从高到低依次为:使用`config`参数,使用全局变量,使用环境变量。
61+
* 允许同时使用多种方式设置鉴权信息,程序将根据设置方式的优先级确定配置项的最终取值。三种设置方式的优先级从高到低依次为:使用`_config_`参数,使用全局变量,使用环境变量。
6262
* **使用特定模型,请准确设置对应后端平台的认证鉴权参数。**
6363

6464
## 申请千帆大模型平台的AK/SK

erniebot/resources/abc/creatable.py

Lines changed: 13 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -26,37 +26,28 @@ class Creatable(Resource, Protocol):
2626
"""Creatable resource protocol."""
2727

2828
@classmethod
29-
def create(cls,
30-
*,
31-
config: Optional[Dict[str, Any]]=None,
32-
**create_kwargs: Any) -> Union[EBResponse, Iterator[EBResponse]]:
33-
"""Create a new resource.
34-
35-
Args:
36-
config: Configuration dictionary.
37-
**create_kwargs: Parameters for creating the resource.
38-
39-
Returns:
40-
Response from the server.
41-
"""
42-
config = config or {}
29+
def create(cls, **kwargs: Any) -> Union[EBResponse, Iterator[EBResponse]]:
30+
"""Create a resource."""
31+
config = kwargs.pop('_config_', {})
4332
resource = cls.new_object(**config)
33+
create_kwargs = kwargs
4434
return resource.create_resource(**create_kwargs)
4535

4636
@classmethod
4737
async def acreate(
48-
cls, *, config: Optional[Dict[str, Any]]=None,
49-
**create_kwargs: Any) -> Union[EBResponse, AsyncIterator[EBResponse]]:
38+
cls, **kwargs: Any) -> Union[EBResponse, AsyncIterator[EBResponse]]:
5039
"""Asynchronous version of `create`."""
51-
config = config or {}
40+
config = kwargs.pop('_config_', {})
5241
resource = cls.new_object(**config)
42+
create_kwargs = kwargs
5343
resp = await resource.acreate_resource(**create_kwargs)
5444
return resp
5545

5646
def create_resource(
57-
self, **kwargs: Any) -> Union[EBResponse, Iterator[EBResponse]]:
47+
self,
48+
**create_kwargs: Any) -> Union[EBResponse, Iterator[EBResponse]]:
5849
url, params, headers, files, stream, request_timeout = self._prepare_create(
59-
kwargs)
50+
create_kwargs)
6051
resp = self.request(
6152
method='POST',
6253
url=url,
@@ -70,9 +61,10 @@ def create_resource(
7061
return resp
7162

7263
async def acreate_resource(
73-
self, **kwargs: Any) -> Union[EBResponse, AsyncIterator[EBResponse]]:
64+
self,
65+
**create_kwargs: Any) -> Union[EBResponse, AsyncIterator[EBResponse]]:
7466
url, params, headers, files, stream, request_timeout = self._prepare_create(
75-
kwargs)
67+
create_kwargs)
7668
resp = await self.arequest(
7769
method='POST',
7870
url=url,

erniebot/resources/abc/protocol.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ class Resource(Protocol):
2929
"""
3030

3131
@classmethod
32-
def new_object(cls, **kwargs: Any) -> Self:
32+
def new_object(cls, **config: Any) -> Self:
3333
...
3434

3535
def request(

erniebot/resources/chat_completion.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,8 @@
2424
class ChatCompletion(EBResource, Creatable):
2525
"""Given a conversation, get a new reply from the model."""
2626

27+
SUPPORTED_API_TYPES: ClassVar[Tuple[APIType, ...]] = (APIType.QIANFAN,
28+
APIType.AI_STUDIO)
2729
_API_INFO_DICT: ClassVar[Dict[APIType, Dict[str, Any]]] = {
2830
APIType.QIANFAN: {
2931
'prefix': 'chat',
@@ -85,14 +87,15 @@ def _set_val_if_key_exists(src: dict, dst: dict, key: str) -> None:
8587
messages = kwargs['messages']
8688

8789
# url
88-
if self.api_type in self._API_INFO_DICT:
90+
if self.api_type in self.SUPPORTED_API_TYPES:
8991
api_info = self._API_INFO_DICT[self.api_type]
9092
if model not in api_info['models']:
9193
raise errors.InvalidArgumentError(
9294
f"{repr(model)} is not a supported model.")
9395
url = f"/{api_info['prefix']}/{api_info['models'][model]['suffix']}"
9496
else:
95-
raise errors.UnsupportedAPITypeError
97+
raise errors.UnsupportedAPITypeError(
98+
f"Supported API types: {self.get_supported_api_type_names()}")
9699

97100
# params
98101
params = {}

erniebot/resources/chat_file.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,8 @@
2424
class ChatFile(EBResource, Creatable):
2525
"""Chat with the model about the content of a given file."""
2626

27+
SUPPORTED_API_TYPES: ClassVar[Tuple[APIType, ...]] = (APIType.QIANFAN, )
28+
2729
def _prepare_create(self,
2830
kwargs: Dict[str, Any]) -> Tuple[str,
2931
Optional[ParamsType],
@@ -46,10 +48,12 @@ def _prepare_create(self,
4648
messages = kwargs['messages']
4749

4850
# url
51+
assert self.SUPPORTED_API_TYPES == (APIType.QIANFAN, )
4952
if self.api_type is APIType.QIANFAN:
5053
url = "/chat/chatfile_adv"
5154
else:
52-
raise errors.UnsupportedAPITypeError
55+
raise errors.UnsupportedAPITypeError(
56+
f"Supported API types: {self.get_supported_api_type_names()}")
5357

5458
# params
5559
params = {}

erniebot/resources/embedding.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,8 @@
2424
class Embedding(EBResource, Creatable):
2525
"""Get the embeddings of a given text input."""
2626

27+
SUPPORTED_API_TYPES: ClassVar[Tuple[APIType, ...]] = (APIType.QIANFAN,
28+
APIType.AI_STUDIO)
2729
_API_INFO_DICT: ClassVar[Dict[APIType, Dict[str, Any]]] = {
2830
APIType.QIANFAN: {
2931
'prefix': 'embeddings',
@@ -74,14 +76,15 @@ def _set_val_if_key_exists(src: dict, dst: dict, key: str) -> None:
7476
input = kwargs['input']
7577

7678
# url
77-
if self.api_type in self._API_INFO_DICT:
79+
if self.api_type in self.SUPPORTED_API_TYPES:
7880
api_info = self._API_INFO_DICT[self.api_type]
7981
if model not in api_info['models']:
8082
raise errors.InvalidArgumentError(
8183
f"{repr(model)} is not a supported model.")
8284
url = f"/{api_info['prefix']}/{api_info['models'][model]['suffix']}"
8385
else:
84-
raise errors.UnsupportedAPITypeError
86+
raise errors.UnsupportedAPITypeError(
87+
f"Supported API types: {self.get_supported_api_type_names()}")
8588

8689
# params
8790
params = {}

erniebot/resources/image.py

Lines changed: 33 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
from typing import (Any, Dict, Optional, Tuple)
15+
from typing import (Any, ClassVar, Dict, Optional, Tuple)
1616

1717
from typing_extensions import TypeAlias
1818

@@ -25,18 +25,25 @@
2525

2626
class _Image(EBResource):
2727
@classmethod
28-
def create(cls, **create_kwargs: Any) -> EBResponse:
29-
resource = cls.new_object()
28+
def create(cls, **kwargs: Any) -> EBResponse:
29+
"""Create a resource."""
30+
config = kwargs.pop('_config_', {})
31+
resource = cls.new_object(**config)
32+
create_kwargs = kwargs
3033
return resource.create_resource(**create_kwargs)
3134

3235
@classmethod
33-
async def acreate(cls, **create_kwargs: Any) -> EBResponse:
34-
resource = cls.new_object()
36+
async def acreate(cls, **kwargs: Any) -> EBResponse:
37+
"""Asynchronous version of `create`."""
38+
config = kwargs.pop('_config_', {})
39+
resource = cls.new_object(**config)
40+
create_kwargs = kwargs
3541
resp = await resource.acreate_resource(**create_kwargs)
3642
return resp
3743

38-
def create_resource(self, **kwargs: Any) -> EBResponse:
39-
url, params, headers, request_timeout = self._prepare_paint(kwargs)
44+
def create_resource(self, **create_kwargs: Any) -> EBResponse:
45+
url, params, headers, request_timeout = self._prepare_paint(
46+
create_kwargs)
4047
resp_p = self.request(
4148
method='POST',
4249
url=url,
@@ -59,8 +66,9 @@ def create_resource(self, **kwargs: Any) -> EBResponse:
5966

6067
return resp_f
6168

62-
async def acreate_resource(self, **kwargs: Any) -> EBResponse:
63-
url, params, headers, request_timeout = self._prepare_paint(kwargs)
69+
async def acreate_resource(self, **create_kwargs: Any) -> EBResponse:
70+
url, params, headers, request_timeout = self._prepare_paint(
71+
create_kwargs)
6472
resp_p = await self.arequest(
6573
method='POST',
6674
url=url,
@@ -105,6 +113,8 @@ def _check_status(resp: EBResponse) -> bool:
105113
class ImageV1(_Image):
106114
"""Generate a new image based on a given prompt."""
107115

116+
SUPPORTED_API_TYPES: ClassVar[Tuple[APIType, ...]] = (APIType.YINIAN, )
117+
108118
def _prepare_paint(self,
109119
kwargs: Dict[str, Any]) -> Tuple[str,
110120
Optional[ParamsType],
@@ -141,10 +151,12 @@ def _set_val_if_key_exists(src: dict, dst: dict, key: str) -> None:
141151
style = kwargs['style']
142152

143153
# url
154+
assert self.SUPPORTED_API_TYPES == (APIType.YINIAN, )
144155
if self.api_type is APIType.YINIAN:
145156
url = "/txt2img"
146157
else:
147-
raise errors.UnsupportedAPITypeError
158+
raise errors.UnsupportedAPITypeError(
159+
f"Supported API types: {self.get_supported_api_type_names()}")
148160

149161
# params
150162
params = {}
@@ -166,10 +178,12 @@ def _prepare_fetch(self, resp_p: EBResponse) -> Tuple[str,
166178
Optional[HeadersType],
167179
]:
168180
# url
181+
assert self.SUPPORTED_API_TYPES == (APIType.YINIAN, )
169182
if self.api_type is APIType.YINIAN:
170183
url = "/getImg"
171184
else:
172-
raise errors.UnsupportedAPITypeError
185+
raise errors.UnsupportedAPITypeError(
186+
f"Supported API types: {self.get_supported_api_type_names()}")
173187

174188
# params
175189
params = {}
@@ -189,6 +203,8 @@ def _check_status(resp: EBResponse) -> bool:
189203
class ImageV2(_Image):
190204
"""Generate a new image based on a given prompt."""
191205

206+
SUPPORTED_API_TYPES: ClassVar[Tuple[APIType, ...]] = (APIType.YINIAN, )
207+
192208
def _prepare_paint(self,
193209
kwargs: Dict[str, Any]) -> Tuple[str,
194210
Optional[ParamsType],
@@ -231,13 +247,15 @@ def _set_val_if_key_exists(src: dict, dst: dict, key: str) -> None:
231247
height = kwargs['height']
232248

233249
# url
250+
assert self.SUPPORTED_API_TYPES == (APIType.YINIAN, )
234251
if self.api_type is APIType.YINIAN:
235252
url = "/txt2imgv2"
236253
if model != 'ernie-vilg-v2':
237254
raise errors.InvalidArgumentError(
238255
f"{repr(model)} is not a supported model.")
239256
else:
240-
raise errors.UnsupportedAPITypeError
257+
raise errors.UnsupportedAPITypeError(
258+
f"Supported API types: {self.get_supported_api_type_names()}")
241259

242260
# params
243261
params = {}
@@ -260,10 +278,12 @@ def _prepare_fetch(self, resp_p: EBResponse) -> Tuple[str,
260278
Optional[HeadersType],
261279
]:
262280
# url
281+
assert self.SUPPORTED_API_TYPES == (APIType.YINIAN, )
263282
if self.api_type is APIType.YINIAN:
264283
url = "/getImgv2"
265284
else:
266-
raise errors.UnsupportedAPITypeError
285+
raise errors.UnsupportedAPITypeError(
286+
f"Supported API types: {self.get_supported_api_type_names()}")
267287

268288
# params
269289
params = {}

erniebot/resources/resource.py

Lines changed: 13 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -13,13 +13,15 @@
1313
# limitations under the License.
1414

1515
import asyncio
16+
import operator
1617
import time
17-
from typing import (Any, AsyncIterator, Callable, cast, Dict, Iterator,
18-
Optional, Union)
18+
from typing import (Any, AsyncIterator, Callable, cast, ClassVar, Dict,
19+
Iterator, List, Optional, Tuple, Union)
1920

2021
from typing_extensions import final, Self
2122

2223
import erniebot.errors as errors
24+
from erniebot.api_types import APIType
2325
from erniebot.backends import build_backend, convert_str_to_api_type
2426
from erniebot.client import EBClient
2527
from erniebot.config import GlobalConfig
@@ -41,20 +43,18 @@ class EBResource(object):
4143
facilitate reuse of concrete implementations. Most methods of this class are
4244
marked as final (e.g., `request`, `arequest`), while some methods can be
4345
overridden to change the default behavior (e.g., `_create_config_dict`).
44-
45-
Attributes:
46-
cfg (Dict[str, Any]): Dictionary that stores global settings.
47-
client (erniebot.client.EBClient): Low-level client instance.
4846
"""
4947

48+
SUPPORTED_API_TYPES: ClassVar[Tuple[APIType, ...]] = ()
49+
5050
MAX_POLLING_RETRIES: int = 20
5151
POLLING_INTERVAL: int = 5
5252
_MAX_TOKEN_UPDATE_RETRIES: int = 3
5353

54-
def __init__(self, **kwargs: Any) -> None:
54+
def __init__(self, **config: Any) -> None:
5555
super().__init__()
5656

57-
self._cfg = self._create_config_dict(**kwargs)
57+
self._cfg = self._create_config_dict(config)
5858

5959
self.api_type = self._cfg['api_type']
6060
self.timeout = self._cfg['timeout']
@@ -67,6 +67,10 @@ def __init__(self, **kwargs: Any) -> None:
6767
def new_object(cls, **kwargs: Any) -> Self:
6868
return cls(**kwargs)
6969

70+
@classmethod
71+
def get_supported_api_type_names(cls) -> List[str]:
72+
return list(map(operator.attrgetter('name'), cls.SUPPORTED_API_TYPES))
73+
7074
@final
7175
def request(
7276
self,
@@ -285,7 +289,7 @@ async def _arequest(
285289
else:
286290
raise
287291

288-
def _create_config_dict(self, **overrides: Any) -> Dict[str, Any]:
292+
def _create_config_dict(self, overrides: Any) -> Dict[str, Any]:
289293
cfg_dict = cast(Dict[str, Any], GlobalConfig().create_dict(**overrides))
290294
api_type_str = cfg_dict['api_type']
291295
if not isinstance(api_type_str, str):

examples/requirements.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1 +1,2 @@
11
gradio
2+
numpy

requirements.txt

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
aiohttp
22
bce-python-sdk
33
colorlog
4-
numpy
54
requests >= 2.20
65
typing_extensions

tests/embedding.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
import numpy as np
21
import erniebot
32

43
if __name__ == '__main__':

0 commit comments

Comments
 (0)