|
1 | | -import time |
2 | 1 | from abc import ABC |
3 | | -from typing import List, Optional, Union |
| 2 | +from typing import Awaitable, Callable, List, Optional, Union |
4 | 3 | from urllib.parse import urljoin |
5 | 4 |
|
6 | 5 | import aiohttp |
7 | 6 | import tiktoken |
8 | | -from azure.core.credentials import AccessToken, AzureKeyCredential |
| 7 | +from azure.core.credentials import AzureKeyCredential |
9 | 8 | from azure.core.credentials_async import AsyncTokenCredential |
| 9 | +from azure.identity.aio import get_bearer_token_provider |
10 | 10 | from openai import AsyncAzureOpenAI, AsyncOpenAI, RateLimitError |
11 | 11 | from tenacity import ( |
12 | 12 | AsyncRetrying, |
13 | 13 | retry_if_exception_type, |
14 | 14 | stop_after_attempt, |
15 | 15 | wait_random_exponential, |
16 | 16 | ) |
| 17 | +from typing_extensions import TypedDict |
17 | 18 |
|
18 | 19 |
|
19 | 20 | class EmbeddingBatch: |
@@ -139,28 +140,29 @@ def __init__( |
139 | 140 | self.open_ai_service = open_ai_service |
140 | 141 | self.open_ai_deployment = open_ai_deployment |
141 | 142 | self.credential = credential |
142 | | - self.cached_token: Optional[AccessToken] = None |
143 | 143 |
|
144 | 144 | async def create_client(self) -> AsyncOpenAI: |
| 145 | + class AuthArgs(TypedDict, total=False): |
| 146 | + api_key: str |
| 147 | + azure_ad_token_provider: Callable[[], Union[str, Awaitable[str]]] |
| 148 | + |
| 149 | + auth_args = AuthArgs() |
| 150 | + if isinstance(self.credential, AzureKeyCredential): |
| 151 | + auth_args["api_key"] = self.credential.key |
| 152 | + elif isinstance(self.credential, AsyncTokenCredential): |
| 153 | + auth_args["azure_ad_token_provider"] = get_bearer_token_provider( |
| 154 | + self.credential, "https://cognitiveservices.azure.com/.default" |
| 155 | + ) |
| 156 | + else: |
| 157 | + raise TypeError("Invalid credential type") |
| 158 | + |
145 | 159 | return AsyncAzureOpenAI( |
146 | 160 | azure_endpoint=f"https://{self.open_ai_service}.openai.azure.com", |
147 | 161 | azure_deployment=self.open_ai_deployment, |
148 | | - api_key=await self.wrap_credential(), |
149 | 162 | api_version="2023-05-15", |
| 163 | + **auth_args, |
150 | 164 | ) |
151 | 165 |
|
152 | | - async def wrap_credential(self) -> str: |
153 | | - if isinstance(self.credential, AzureKeyCredential): |
154 | | - return self.credential.key |
155 | | - |
156 | | - if isinstance(self.credential, AsyncTokenCredential): |
157 | | - if not self.cached_token or self.cached_token.expires_on <= time.time(): |
158 | | - self.cached_token = await self.credential.get_token("https://cognitiveservices.azure.com/.default") |
159 | | - |
160 | | - return self.cached_token.token |
161 | | - |
162 | | - raise TypeError("Invalid credential type") |
163 | | - |
164 | 166 |
|
165 | 167 | class OpenAIEmbeddingService(OpenAIEmbeddings): |
166 | 168 | """ |
|
0 commit comments