|
17 | 17 | from pydantic_settings import BaseSettings, SettingsConfigDict |
18 | 18 | from dotenv import load_dotenv |
19 | 19 |
|
20 | | -"""Azure authentication verification at startup.""" |
21 | | -import logging |
22 | | -import time |
23 | | -from dataclasses import dataclass, field |
24 | | -from typing import Optional |
25 | | - |
26 | | -from azure.identity import DefaultAzureCredential |
27 | | -from azure.core.exceptions import ( |
28 | | - ClientAuthenticationError, |
29 | | - HttpResponseError, |
30 | | - ServiceRequestError, |
31 | | -) |
32 | | -from azure.keyvault.secrets import SecretClient |
33 | | -from azure.appconfiguration import AzureAppConfigurationClient |
34 | | - |
35 | | -logger = logging.getLogger(__name__) |
36 | | - |
37 | | - |
38 | | -# Azure Resource Manager scope — works for any Azure resource for a basic token test |
39 | | -ARM_SCOPE = "https://management.azure.com/.default" |
40 | | -APPCONFIG_SCOPE = "https://azconfig.io/.default" |
41 | | -KEYVAULT_SCOPE = "https://vault.azure.net/.default" |
42 | | - |
| 20 | +# At the top of config.py, add: |
| 21 | +from azure_auth import AzureAuthVerifier, AuthReport |
43 | 22 |
|
44 | | -@dataclass |
45 | | -class AuthCheckResult: |
46 | | - """Result of an authentication check.""" |
47 | | - name: str |
48 | | - success: bool |
49 | | - duration_ms: int |
50 | | - detail: str = "" |
51 | | - error: Optional[str] = None |
| 23 | +# Module-level cache |
| 24 | +_auth_verified: bool = False |
| 25 | +_auth_report: Optional[AuthReport] = None |
52 | 26 |
|
53 | 27 |
|
54 | | -@dataclass |
55 | | -class AuthReport: |
56 | | - """Full report of all auth checks.""" |
57 | | - overall_success: bool |
58 | | - checks: list[AuthCheckResult] = field(default_factory=list) |
59 | | - identity_info: dict = field(default_factory=dict) |
60 | | - |
61 | | - @property |
62 | | - def failed_checks(self) -> list[AuthCheckResult]: |
63 | | - return [c for c in self.checks if not c.success] |
64 | | - |
65 | | - def summary(self) -> str: |
66 | | - lines = [ |
67 | | - f"{'✅' if self.overall_success else '❌'} Azure auth: " |
68 | | - f"{len([c for c in self.checks if c.success])}/{len(self.checks)} checks passed" |
69 | | - ] |
70 | | - if self.identity_info: |
71 | | - lines.append( |
72 | | - f" Identity: client_id={self.identity_info.get('client_id', '?')} " |
73 | | - f"tenant_id={self.identity_info.get('tenant_id', '?')}" |
74 | | - ) |
75 | | - for c in self.checks: |
76 | | - icon = "✅" if c.success else "❌" |
77 | | - lines.append(f" {icon} {c.name} ({c.duration_ms}ms) {c.detail}") |
78 | | - if c.error: |
79 | | - lines.append(f" └─ {c.error}") |
80 | | - return "\n".join(lines) |
81 | | - |
82 | | - |
83 | | -class AzureAuthVerifier: |
84 | | - """Verifies Azure authentication and resource access at startup.""" |
85 | | - |
86 | | - def __init__(self, credential: Optional[DefaultAzureCredential] = None): |
87 | | - self.credential = credential or DefaultAzureCredential( |
88 | | - exclude_interactive_browser_credential=True, |
89 | | - ) |
90 | | - self._identity_info: dict = {} |
91 | | - |
92 | | - # ------------------------------------------------------------- |
93 | | - # Individual checks |
94 | | - # ------------------------------------------------------------- |
95 | | - |
96 | | - def check_token_acquisition(self, scope: str = ARM_SCOPE) -> AuthCheckResult: |
97 | | - """Verify we can obtain *any* token from the credential chain.""" |
98 | | - start = time.monotonic() |
99 | | - try: |
100 | | - token = self.credential.get_token(scope) |
101 | | - duration = int((time.monotonic() - start) * 1000) |
102 | | - |
103 | | - # Decode the JWT (without verification — just to extract claims) |
104 | | - self._identity_info = self._decode_token_claims(token.token) |
105 | | - |
106 | | - return AuthCheckResult( |
107 | | - name="Token acquisition", |
108 | | - success=True, |
109 | | - duration_ms=duration, |
110 | | - detail=f"expires in {token.expires_on - int(time.time())}s", |
111 | | - ) |
112 | | - except ClientAuthenticationError as e: |
113 | | - return AuthCheckResult( |
114 | | - name="Token acquisition", |
115 | | - success=False, |
116 | | - duration_ms=int((time.monotonic() - start) * 1000), |
117 | | - error=self._format_auth_error(e), |
118 | | - ) |
119 | | - except Exception as e: |
120 | | - return AuthCheckResult( |
121 | | - name="Token acquisition", |
122 | | - success=False, |
123 | | - duration_ms=int((time.monotonic() - start) * 1000), |
124 | | - error=f"{type(e).__name__}: {e}", |
125 | | - ) |
126 | | - |
127 | | - def check_key_vault(self, vault_url: str) -> AuthCheckResult: |
128 | | - """Verify Key Vault is reachable and we have read access.""" |
129 | | - start = time.monotonic() |
130 | | - try: |
131 | | - client = SecretClient(vault_url=vault_url, credential=self.credential) |
132 | | - # List secret properties (lighter than fetching a specific secret) |
133 | | - # We use `max_page_size=1` to minimize the response |
134 | | - iterator = client.list_properties_of_secrets(max_page_size=1) |
135 | | - # Consume the first page to actually trigger the request |
136 | | - _ = next(iterator.by_page(), None) |
137 | | - |
138 | | - duration = int((time.monotonic() - start) * 1000) |
139 | | - return AuthCheckResult( |
140 | | - name=f"Key Vault access ({vault_url})", |
141 | | - success=True, |
142 | | - duration_ms=duration, |
143 | | - detail="list secrets OK", |
144 | | - ) |
145 | | - except ClientAuthenticationError as e: |
146 | | - return AuthCheckResult( |
147 | | - name=f"Key Vault access ({vault_url})", |
148 | | - success=False, |
149 | | - duration_ms=int((time.monotonic() - start) * 1000), |
150 | | - error=self._format_auth_error(e), |
151 | | - ) |
152 | | - except HttpResponseError as e: |
153 | | - return AuthCheckResult( |
154 | | - name=f"Key Vault access ({vault_url})", |
155 | | - success=False, |
156 | | - duration_ms=int((time.monotonic() - start) * 1000), |
157 | | - error=f"HTTP {e.status_code}: {e.reason} — check RBAC role 'Key Vault Secrets User'", |
158 | | - ) |
159 | | - except ServiceRequestError as e: |
160 | | - return AuthCheckResult( |
161 | | - name=f"Key Vault access ({vault_url})", |
162 | | - success=False, |
163 | | - duration_ms=int((time.monotonic() - start) * 1000), |
164 | | - error=f"Network error: {e} — check vault URL / DNS / firewall", |
165 | | - ) |
166 | | - except Exception as e: |
167 | | - return AuthCheckResult( |
168 | | - name=f"Key Vault access ({vault_url})", |
169 | | - success=False, |
170 | | - duration_ms=int((time.monotonic() - start) * 1000), |
171 | | - error=f"{type(e).__name__}: {e}", |
172 | | - ) |
173 | | - |
174 | | - def check_specific_secret(self, vault_url: str, secret_name: str) -> AuthCheckResult: |
175 | | - """Verify we can actually read a specific required secret.""" |
176 | | - start = time.monotonic() |
177 | | - try: |
178 | | - client = SecretClient(vault_url=vault_url, credential=self.credential) |
179 | | - secret = client.get_secret(secret_name) |
180 | | - duration = int((time.monotonic() - start) * 1000) |
181 | | - value_preview = "***" + secret.value[-4:] if secret.value else "(empty)" |
182 | | - return AuthCheckResult( |
183 | | - name=f"Secret '{secret_name}'", |
184 | | - success=True, |
185 | | - duration_ms=duration, |
186 | | - detail=f"value={value_preview}", |
187 | | - ) |
188 | | - except Exception as e: |
189 | | - return AuthCheckResult( |
190 | | - name=f"Secret '{secret_name}'", |
191 | | - success=False, |
192 | | - duration_ms=int((time.monotonic() - start) * 1000), |
193 | | - error=f"{type(e).__name__}: {e}", |
194 | | - ) |
195 | | - |
196 | | - def check_app_configuration(self, endpoint: str) -> AuthCheckResult: |
197 | | - """Verify App Configuration is reachable and we have read access.""" |
198 | | - start = time.monotonic() |
199 | | - try: |
200 | | - client = AzureAppConfigurationClient( |
201 | | - base_url=endpoint, |
202 | | - credential=self.credential, |
203 | | - ) |
204 | | - # List config settings (lightweight check) |
205 | | - iterator = client.list_configuration_settings() |
206 | | - _ = next(iterator.by_page(), None) |
207 | | - |
208 | | - duration = int((time.monotonic() - start) * 1000) |
209 | | - return AuthCheckResult( |
210 | | - name=f"App Configuration ({endpoint})", |
211 | | - success=True, |
212 | | - duration_ms=duration, |
213 | | - detail="list settings OK", |
214 | | - ) |
215 | | - except ClientAuthenticationError as e: |
216 | | - return AuthCheckResult( |
217 | | - name=f"App Configuration ({endpoint})", |
218 | | - success=False, |
219 | | - duration_ms=int((time.monotonic() - start) * 1000), |
220 | | - error=self._format_auth_error(e), |
221 | | - ) |
222 | | - except HttpResponseError as e: |
223 | | - hint = "" |
224 | | - if e.status_code == 403: |
225 | | - hint = " — check RBAC role 'App Configuration Data Reader'" |
226 | | - return AuthCheckResult( |
227 | | - name=f"App Configuration ({endpoint})", |
228 | | - success=False, |
229 | | - duration_ms=int((time.monotonic() - start) * 1000), |
230 | | - error=f"HTTP {e.status_code}: {e.reason}{hint}", |
231 | | - ) |
232 | | - except Exception as e: |
233 | | - return AuthCheckResult( |
234 | | - name=f"App Configuration ({endpoint})", |
235 | | - success=False, |
236 | | - duration_ms=int((time.monotonic() - start) * 1000), |
237 | | - error=f"{type(e).__name__}: {e}", |
238 | | - ) |
239 | | - |
240 | | - # ------------------------------------------------------------- |
241 | | - # Orchestrator |
242 | | - # ------------------------------------------------------------- |
243 | | - |
244 | | - def verify( |
245 | | - self, |
246 | | - *, |
247 | | - key_vault_url: Optional[str] = None, |
248 | | - appconfig_endpoint: Optional[str] = None, |
249 | | - required_secrets: Optional[list[str]] = None, |
250 | | - ) -> AuthReport: |
251 | | - """Run all configured checks and return a full report.""" |
252 | | - report = AuthReport(overall_success=True) |
253 | | - |
254 | | - # 1. Token acquisition (always) |
255 | | - token_check = self.check_token_acquisition() |
256 | | - report.checks.append(token_check) |
257 | | - report.identity_info = self._identity_info |
258 | | - if not token_check.success: |
259 | | - # Can't continue if we can't even get a token |
260 | | - report.overall_success = False |
261 | | - return report |
262 | | - |
263 | | - # 2. App Configuration (if configured) |
264 | | - if appconfig_endpoint: |
265 | | - check = self.check_app_configuration(appconfig_endpoint) |
266 | | - report.checks.append(check) |
267 | | - if not check.success: |
268 | | - report.overall_success = False |
269 | | - |
270 | | - # 3. Key Vault (if configured) |
271 | | - if key_vault_url: |
272 | | - check = self.check_key_vault(key_vault_url) |
273 | | - report.checks.append(check) |
274 | | - if not check.success: |
275 | | - report.overall_success = False |
276 | | - else: |
277 | | - # Only check specific secrets if vault is reachable |
278 | | - for secret_name in (required_secrets or []): |
279 | | - sc = self.check_specific_secret(key_vault_url, secret_name) |
280 | | - report.checks.append(sc) |
281 | | - if not sc.success: |
282 | | - report.overall_success = False |
283 | | - |
284 | | - return report |
285 | | - |
286 | | - # ------------------------------------------------------------- |
287 | | - # Helpers |
288 | | - # ------------------------------------------------------------- |
289 | | - |
290 | | - @staticmethod |
291 | | - def _decode_token_claims(token: str) -> dict: |
292 | | - """Decode JWT claims without verification (just for diagnostics).""" |
293 | | - import base64 |
294 | | - import json |
295 | | - try: |
296 | | - parts = token.split(".") |
297 | | - if len(parts) != 3: |
298 | | - return {} |
299 | | - # JWT base64 needs padding |
300 | | - payload = parts[1] + "=" * (4 - len(parts[1]) % 4) |
301 | | - claims = json.loads(base64.urlsafe_b64decode(payload)) |
302 | | - return { |
303 | | - "client_id": claims.get("appid") or claims.get("azp", "unknown"), |
304 | | - "tenant_id": claims.get("tid", "unknown"), |
305 | | - "object_id": claims.get("oid", "unknown"), |
306 | | - "identity_type": claims.get("idtyp", "unknown"), |
307 | | - "expires_at": claims.get("exp", 0), |
308 | | - } |
309 | | - except Exception: |
310 | | - return {} |
311 | | - |
312 | | - @staticmethod |
313 | | - def _format_auth_error(e: ClientAuthenticationError) -> str: |
314 | | - """Make ClientAuthenticationError messages actionable.""" |
315 | | - msg = str(e) |
316 | | - hints = [] |
317 | | - |
318 | | - if "DefaultAzureCredential failed" in msg: |
319 | | - hints.append("No credential in the chain succeeded. In K8s, check:") |
320 | | - hints.append(" • Workload Identity label on pod: azure.workload.identity/use=true") |
321 | | - hints.append(" • ServiceAccount annotation: azure.workload.identity/client-id") |
322 | | - hints.append(" • Federated credential in Azure AD") |
323 | | - if "AADSTS70021" in msg or "no matching federated identity" in msg.lower(): |
324 | | - hints.append("Federated credential subject doesn't match service account.") |
325 | | - hints.append("Expected: system:serviceaccount:<namespace>:<sa-name>") |
326 | | - if "AADSTS700016" in msg: |
327 | | - hints.append("Application not found in tenant — check AZURE_CLIENT_ID.") |
328 | | - |
329 | | - if hints: |
330 | | - return msg + "\n " + "\n ".join(hints) |
331 | | - return msg |
332 | | - |
333 | 28 | load_dotenv() |
334 | 29 | logger = logging.getLogger(__name__) |
335 | 30 |
|
@@ -665,6 +360,8 @@ def get_settings() -> Settings: |
665 | 360 | with _settings_lock: |
666 | 361 | # First time: build the settings object |
667 | 362 | if _settings is None: |
| 363 | + verify_azure_auth(strict=True) |
| 364 | + |
668 | 365 | _settings = Settings() |
669 | 366 | bootstrap = BootstrapSettings() |
670 | 367 |
|
@@ -715,3 +412,56 @@ def refresh_settings() -> Settings: |
715 | 412 | if loader: |
716 | 413 | loader.refresh() |
717 | 414 | return get_settings() |
| 415 | + |
| 416 | +def verify_azure_auth(strict: bool = True) -> AuthReport: |
| 417 | + """Verify Azure authentication and resource access. |
| 418 | + |
| 419 | + Args: |
| 420 | + strict: If True, raise RuntimeError on any failure. |
| 421 | + |
| 422 | + Returns: |
| 423 | + AuthReport with detailed check results. |
| 424 | + """ |
| 425 | + global _auth_verified, _auth_report |
| 426 | + |
| 427 | + bootstrap = BootstrapSettings() |
| 428 | + |
| 429 | + # Skip if Azure isn't being used |
| 430 | + if not bootstrap.use_app_configuration and not bootstrap.use_key_vault: |
| 431 | + logger.info("Azure auth check skipped (no Azure services enabled)") |
| 432 | + _auth_verified = True |
| 433 | + _auth_report = AuthReport(overall_success=True) |
| 434 | + return _auth_report |
| 435 | + |
| 436 | + logger.info("🔐 Verifying Azure authentication...") |
| 437 | + verifier = AzureAuthVerifier() |
| 438 | + |
| 439 | + report = verifier.verify( |
| 440 | + key_vault_url=bootstrap.azure_key_vault_url if bootstrap.use_key_vault else None, |
| 441 | + appconfig_endpoint=( |
| 442 | + bootstrap.azure_appconfig_endpoint |
| 443 | + if bootstrap.use_app_configuration else None |
| 444 | + ), |
| 445 | + # If using App Config with KV references, also verify the KV is reachable |
| 446 | + # (because the App Config provider will need it) |
| 447 | + # Specific secrets to verify exist: |
| 448 | + required_secrets=["openai-api-key"] if bootstrap.use_key_vault else None, |
| 449 | + ) |
| 450 | + |
| 451 | + logger.info("\n" + report.summary()) |
| 452 | + |
| 453 | + _auth_report = report |
| 454 | + _auth_verified = report.overall_success |
| 455 | + |
| 456 | + if not report.overall_success and strict: |
| 457 | + raise RuntimeError( |
| 458 | + f"Azure authentication failed. {len(report.failed_checks)} check(s) failed:\n" |
| 459 | + + "\n".join(f" - {c.name}: {c.error}" for c in report.failed_checks) |
| 460 | + ) |
| 461 | + |
| 462 | + return report |
| 463 | + |
| 464 | + |
| 465 | +def get_auth_report() -> Optional[AuthReport]: |
| 466 | + """Return the last auth verification report.""" |
| 467 | + return _auth_report |
0 commit comments