Skip to content

Commit 4579d4a

Browse files
committed
Fix mistake
1 parent 56756ee commit 4579d4a

1 file changed

Lines changed: 60 additions & 310 deletions

File tree

config.py

Lines changed: 60 additions & 310 deletions
Original file line numberDiff line numberDiff line change
@@ -17,319 +17,14 @@
1717
from pydantic_settings import BaseSettings, SettingsConfigDict
1818
from dotenv import load_dotenv
1919

20-
"""Azure authentication verification at startup."""
21-
import logging
22-
import time
23-
from dataclasses import dataclass, field
24-
from typing import Optional
25-
26-
from azure.identity import DefaultAzureCredential
27-
from azure.core.exceptions import (
28-
ClientAuthenticationError,
29-
HttpResponseError,
30-
ServiceRequestError,
31-
)
32-
from azure.keyvault.secrets import SecretClient
33-
from azure.appconfiguration import AzureAppConfigurationClient
34-
35-
logger = logging.getLogger(__name__)
36-
37-
38-
# Azure Resource Manager scope — works for any Azure resource for a basic token test
39-
ARM_SCOPE = "https://management.azure.com/.default"
40-
APPCONFIG_SCOPE = "https://azconfig.io/.default"
41-
KEYVAULT_SCOPE = "https://vault.azure.net/.default"
42-
20+
# At the top of config.py, add:
21+
from azure_auth import AzureAuthVerifier, AuthReport
4322

44-
@dataclass
45-
class AuthCheckResult:
46-
"""Result of an authentication check."""
47-
name: str
48-
success: bool
49-
duration_ms: int
50-
detail: str = ""
51-
error: Optional[str] = None
23+
# Module-level cache
24+
_auth_verified: bool = False
25+
_auth_report: Optional[AuthReport] = None
5226

5327

54-
@dataclass
55-
class AuthReport:
56-
"""Full report of all auth checks."""
57-
overall_success: bool
58-
checks: list[AuthCheckResult] = field(default_factory=list)
59-
identity_info: dict = field(default_factory=dict)
60-
61-
@property
62-
def failed_checks(self) -> list[AuthCheckResult]:
63-
return [c for c in self.checks if not c.success]
64-
65-
def summary(self) -> str:
66-
lines = [
67-
f"{'✅' if self.overall_success else '❌'} Azure auth: "
68-
f"{len([c for c in self.checks if c.success])}/{len(self.checks)} checks passed"
69-
]
70-
if self.identity_info:
71-
lines.append(
72-
f" Identity: client_id={self.identity_info.get('client_id', '?')} "
73-
f"tenant_id={self.identity_info.get('tenant_id', '?')}"
74-
)
75-
for c in self.checks:
76-
icon = "✅" if c.success else "❌"
77-
lines.append(f" {icon} {c.name} ({c.duration_ms}ms) {c.detail}")
78-
if c.error:
79-
lines.append(f" └─ {c.error}")
80-
return "\n".join(lines)
81-
82-
83-
class AzureAuthVerifier:
84-
"""Verifies Azure authentication and resource access at startup."""
85-
86-
def __init__(self, credential: Optional[DefaultAzureCredential] = None):
87-
self.credential = credential or DefaultAzureCredential(
88-
exclude_interactive_browser_credential=True,
89-
)
90-
self._identity_info: dict = {}
91-
92-
# -------------------------------------------------------------
93-
# Individual checks
94-
# -------------------------------------------------------------
95-
96-
def check_token_acquisition(self, scope: str = ARM_SCOPE) -> AuthCheckResult:
97-
"""Verify we can obtain *any* token from the credential chain."""
98-
start = time.monotonic()
99-
try:
100-
token = self.credential.get_token(scope)
101-
duration = int((time.monotonic() - start) * 1000)
102-
103-
# Decode the JWT (without verification — just to extract claims)
104-
self._identity_info = self._decode_token_claims(token.token)
105-
106-
return AuthCheckResult(
107-
name="Token acquisition",
108-
success=True,
109-
duration_ms=duration,
110-
detail=f"expires in {token.expires_on - int(time.time())}s",
111-
)
112-
except ClientAuthenticationError as e:
113-
return AuthCheckResult(
114-
name="Token acquisition",
115-
success=False,
116-
duration_ms=int((time.monotonic() - start) * 1000),
117-
error=self._format_auth_error(e),
118-
)
119-
except Exception as e:
120-
return AuthCheckResult(
121-
name="Token acquisition",
122-
success=False,
123-
duration_ms=int((time.monotonic() - start) * 1000),
124-
error=f"{type(e).__name__}: {e}",
125-
)
126-
127-
def check_key_vault(self, vault_url: str) -> AuthCheckResult:
128-
"""Verify Key Vault is reachable and we have read access."""
129-
start = time.monotonic()
130-
try:
131-
client = SecretClient(vault_url=vault_url, credential=self.credential)
132-
# List secret properties (lighter than fetching a specific secret)
133-
# We use `max_page_size=1` to minimize the response
134-
iterator = client.list_properties_of_secrets(max_page_size=1)
135-
# Consume the first page to actually trigger the request
136-
_ = next(iterator.by_page(), None)
137-
138-
duration = int((time.monotonic() - start) * 1000)
139-
return AuthCheckResult(
140-
name=f"Key Vault access ({vault_url})",
141-
success=True,
142-
duration_ms=duration,
143-
detail="list secrets OK",
144-
)
145-
except ClientAuthenticationError as e:
146-
return AuthCheckResult(
147-
name=f"Key Vault access ({vault_url})",
148-
success=False,
149-
duration_ms=int((time.monotonic() - start) * 1000),
150-
error=self._format_auth_error(e),
151-
)
152-
except HttpResponseError as e:
153-
return AuthCheckResult(
154-
name=f"Key Vault access ({vault_url})",
155-
success=False,
156-
duration_ms=int((time.monotonic() - start) * 1000),
157-
error=f"HTTP {e.status_code}: {e.reason} — check RBAC role 'Key Vault Secrets User'",
158-
)
159-
except ServiceRequestError as e:
160-
return AuthCheckResult(
161-
name=f"Key Vault access ({vault_url})",
162-
success=False,
163-
duration_ms=int((time.monotonic() - start) * 1000),
164-
error=f"Network error: {e} — check vault URL / DNS / firewall",
165-
)
166-
except Exception as e:
167-
return AuthCheckResult(
168-
name=f"Key Vault access ({vault_url})",
169-
success=False,
170-
duration_ms=int((time.monotonic() - start) * 1000),
171-
error=f"{type(e).__name__}: {e}",
172-
)
173-
174-
def check_specific_secret(self, vault_url: str, secret_name: str) -> AuthCheckResult:
175-
"""Verify we can actually read a specific required secret."""
176-
start = time.monotonic()
177-
try:
178-
client = SecretClient(vault_url=vault_url, credential=self.credential)
179-
secret = client.get_secret(secret_name)
180-
duration = int((time.monotonic() - start) * 1000)
181-
value_preview = "***" + secret.value[-4:] if secret.value else "(empty)"
182-
return AuthCheckResult(
183-
name=f"Secret '{secret_name}'",
184-
success=True,
185-
duration_ms=duration,
186-
detail=f"value={value_preview}",
187-
)
188-
except Exception as e:
189-
return AuthCheckResult(
190-
name=f"Secret '{secret_name}'",
191-
success=False,
192-
duration_ms=int((time.monotonic() - start) * 1000),
193-
error=f"{type(e).__name__}: {e}",
194-
)
195-
196-
def check_app_configuration(self, endpoint: str) -> AuthCheckResult:
197-
"""Verify App Configuration is reachable and we have read access."""
198-
start = time.monotonic()
199-
try:
200-
client = AzureAppConfigurationClient(
201-
base_url=endpoint,
202-
credential=self.credential,
203-
)
204-
# List config settings (lightweight check)
205-
iterator = client.list_configuration_settings()
206-
_ = next(iterator.by_page(), None)
207-
208-
duration = int((time.monotonic() - start) * 1000)
209-
return AuthCheckResult(
210-
name=f"App Configuration ({endpoint})",
211-
success=True,
212-
duration_ms=duration,
213-
detail="list settings OK",
214-
)
215-
except ClientAuthenticationError as e:
216-
return AuthCheckResult(
217-
name=f"App Configuration ({endpoint})",
218-
success=False,
219-
duration_ms=int((time.monotonic() - start) * 1000),
220-
error=self._format_auth_error(e),
221-
)
222-
except HttpResponseError as e:
223-
hint = ""
224-
if e.status_code == 403:
225-
hint = " — check RBAC role 'App Configuration Data Reader'"
226-
return AuthCheckResult(
227-
name=f"App Configuration ({endpoint})",
228-
success=False,
229-
duration_ms=int((time.monotonic() - start) * 1000),
230-
error=f"HTTP {e.status_code}: {e.reason}{hint}",
231-
)
232-
except Exception as e:
233-
return AuthCheckResult(
234-
name=f"App Configuration ({endpoint})",
235-
success=False,
236-
duration_ms=int((time.monotonic() - start) * 1000),
237-
error=f"{type(e).__name__}: {e}",
238-
)
239-
240-
# -------------------------------------------------------------
241-
# Orchestrator
242-
# -------------------------------------------------------------
243-
244-
def verify(
245-
self,
246-
*,
247-
key_vault_url: Optional[str] = None,
248-
appconfig_endpoint: Optional[str] = None,
249-
required_secrets: Optional[list[str]] = None,
250-
) -> AuthReport:
251-
"""Run all configured checks and return a full report."""
252-
report = AuthReport(overall_success=True)
253-
254-
# 1. Token acquisition (always)
255-
token_check = self.check_token_acquisition()
256-
report.checks.append(token_check)
257-
report.identity_info = self._identity_info
258-
if not token_check.success:
259-
# Can't continue if we can't even get a token
260-
report.overall_success = False
261-
return report
262-
263-
# 2. App Configuration (if configured)
264-
if appconfig_endpoint:
265-
check = self.check_app_configuration(appconfig_endpoint)
266-
report.checks.append(check)
267-
if not check.success:
268-
report.overall_success = False
269-
270-
# 3. Key Vault (if configured)
271-
if key_vault_url:
272-
check = self.check_key_vault(key_vault_url)
273-
report.checks.append(check)
274-
if not check.success:
275-
report.overall_success = False
276-
else:
277-
# Only check specific secrets if vault is reachable
278-
for secret_name in (required_secrets or []):
279-
sc = self.check_specific_secret(key_vault_url, secret_name)
280-
report.checks.append(sc)
281-
if not sc.success:
282-
report.overall_success = False
283-
284-
return report
285-
286-
# -------------------------------------------------------------
287-
# Helpers
288-
# -------------------------------------------------------------
289-
290-
@staticmethod
291-
def _decode_token_claims(token: str) -> dict:
292-
"""Decode JWT claims without verification (just for diagnostics)."""
293-
import base64
294-
import json
295-
try:
296-
parts = token.split(".")
297-
if len(parts) != 3:
298-
return {}
299-
# JWT base64 needs padding
300-
payload = parts[1] + "=" * (4 - len(parts[1]) % 4)
301-
claims = json.loads(base64.urlsafe_b64decode(payload))
302-
return {
303-
"client_id": claims.get("appid") or claims.get("azp", "unknown"),
304-
"tenant_id": claims.get("tid", "unknown"),
305-
"object_id": claims.get("oid", "unknown"),
306-
"identity_type": claims.get("idtyp", "unknown"),
307-
"expires_at": claims.get("exp", 0),
308-
}
309-
except Exception:
310-
return {}
311-
312-
@staticmethod
313-
def _format_auth_error(e: ClientAuthenticationError) -> str:
314-
"""Make ClientAuthenticationError messages actionable."""
315-
msg = str(e)
316-
hints = []
317-
318-
if "DefaultAzureCredential failed" in msg:
319-
hints.append("No credential in the chain succeeded. In K8s, check:")
320-
hints.append(" • Workload Identity label on pod: azure.workload.identity/use=true")
321-
hints.append(" • ServiceAccount annotation: azure.workload.identity/client-id")
322-
hints.append(" • Federated credential in Azure AD")
323-
if "AADSTS70021" in msg or "no matching federated identity" in msg.lower():
324-
hints.append("Federated credential subject doesn't match service account.")
325-
hints.append("Expected: system:serviceaccount:<namespace>:<sa-name>")
326-
if "AADSTS700016" in msg:
327-
hints.append("Application not found in tenant — check AZURE_CLIENT_ID.")
328-
329-
if hints:
330-
return msg + "\n " + "\n ".join(hints)
331-
return msg
332-
33328
load_dotenv()
33429
logger = logging.getLogger(__name__)
33530

@@ -665,6 +360,8 @@ def get_settings() -> Settings:
665360
with _settings_lock:
666361
# First time: build the settings object
667362
if _settings is None:
363+
verify_azure_auth(strict=True)
364+
668365
_settings = Settings()
669366
bootstrap = BootstrapSettings()
670367

@@ -715,3 +412,56 @@ def refresh_settings() -> Settings:
715412
if loader:
716413
loader.refresh()
717414
return get_settings()
415+
416+
def verify_azure_auth(strict: bool = True) -> AuthReport:
417+
"""Verify Azure authentication and resource access.
418+
419+
Args:
420+
strict: If True, raise RuntimeError on any failure.
421+
422+
Returns:
423+
AuthReport with detailed check results.
424+
"""
425+
global _auth_verified, _auth_report
426+
427+
bootstrap = BootstrapSettings()
428+
429+
# Skip if Azure isn't being used
430+
if not bootstrap.use_app_configuration and not bootstrap.use_key_vault:
431+
logger.info("Azure auth check skipped (no Azure services enabled)")
432+
_auth_verified = True
433+
_auth_report = AuthReport(overall_success=True)
434+
return _auth_report
435+
436+
logger.info("🔐 Verifying Azure authentication...")
437+
verifier = AzureAuthVerifier()
438+
439+
report = verifier.verify(
440+
key_vault_url=bootstrap.azure_key_vault_url if bootstrap.use_key_vault else None,
441+
appconfig_endpoint=(
442+
bootstrap.azure_appconfig_endpoint
443+
if bootstrap.use_app_configuration else None
444+
),
445+
# If using App Config with KV references, also verify the KV is reachable
446+
# (because the App Config provider will need it)
447+
# Specific secrets to verify exist:
448+
required_secrets=["openai-api-key"] if bootstrap.use_key_vault else None,
449+
)
450+
451+
logger.info("\n" + report.summary())
452+
453+
_auth_report = report
454+
_auth_verified = report.overall_success
455+
456+
if not report.overall_success and strict:
457+
raise RuntimeError(
458+
f"Azure authentication failed. {len(report.failed_checks)} check(s) failed:\n"
459+
+ "\n".join(f" - {c.name}: {c.error}" for c in report.failed_checks)
460+
)
461+
462+
return report
463+
464+
465+
def get_auth_report() -> Optional[AuthReport]:
466+
"""Return the last auth verification report."""
467+
return _auth_report

0 commit comments

Comments
 (0)