4
4
import logging
5
5
import urllib .request
6
6
from dataclasses import dataclass , field
7
- from typing import Annotated , Any , Callable , Optional , Sequence
7
+ from typing import Annotated , Optional , Sequence
8
8
9
9
import jwt
10
10
from fastapi import HTTPException , Security , security , status
@@ -25,8 +25,6 @@ class OpenIdConnectAuth:
25
25
# Generated attributes
26
26
auth_scheme : SecurityBase = field (init = False )
27
27
jwks_client : jwt .PyJWKClient = field (init = False )
28
- validated_user : Callable [..., Any ] = field (init = False )
29
- maybe_validated_user : Callable [..., Any ] = field (init = False )
30
28
31
29
def __post_init__ (self ):
32
30
"""Initialize the OIDC authentication class."""
@@ -50,70 +48,80 @@ def __post_init__(self):
50
48
openIdConnectUrl = str (self .openid_configuration_url ),
51
49
auto_error = False ,
52
50
)
53
- self .validated_user = self ._build (auto_error = True )
54
- self .maybe_validated_user = self ._build (auto_error = False )
55
-
56
- def _build (self , auto_error : bool = True ):
57
- """Build a dependency for validating an OIDC token."""
58
-
59
- def valid_token_dependency (
60
- auth_header : Annotated [str , Security (self .auth_scheme )],
61
- required_scopes : security .SecurityScopes ,
62
- ):
63
- """Dependency to validate an OIDC token."""
64
- if not auth_header :
51
+
52
+ # Update annotations to support FastAPI's dependency injection
53
+ for endpoint in [self .validated_user , self .maybe_validated_user ]:
54
+ endpoint .__annotations__ ["auth_header" ] = Annotated [
55
+ str ,
56
+ Security (self .auth_scheme ),
57
+ ]
58
+
59
+ def maybe_validated_user (
60
+ self ,
61
+ auth_header : Annotated [str , Security (...)],
62
+ required_scopes : security .SecurityScopes ,
63
+ ):
64
+ """Dependency to validate an OIDC token."""
65
+ return self .validated_user (auth_header , required_scopes , auto_error = False )
66
+
67
+ def validated_user (
68
+ self ,
69
+ auth_header : Annotated [str , Security (...)],
70
+ required_scopes : security .SecurityScopes ,
71
+ auto_error : bool = True ,
72
+ ):
73
+ """Dependency to validate an OIDC token."""
74
+ if not auth_header :
75
+ if auto_error :
76
+ raise HTTPException (
77
+ status_code = status .HTTP_403_FORBIDDEN ,
78
+ detail = "Not authenticated" ,
79
+ )
80
+ return None
81
+
82
+ # Extract token from header
83
+ token_parts = auth_header .split (" " )
84
+ if len (token_parts ) != 2 or token_parts [0 ].lower () != "bearer" :
85
+ logger .error (f"Invalid token: { auth_header } " )
86
+ raise HTTPException (
87
+ status_code = status .HTTP_401_UNAUTHORIZED ,
88
+ detail = "Could not validate credentials" ,
89
+ headers = {"WWW-Authenticate" : "Bearer" },
90
+ )
91
+ [_ , token ] = token_parts
92
+
93
+ # Parse & validate token
94
+ try :
95
+ key = self .jwks_client .get_signing_key_from_jwt (token ).key
96
+ payload = jwt .decode (
97
+ token ,
98
+ key ,
99
+ algorithms = ["RS256" ],
100
+ # NOTE: Audience validation MUST match audience claim if set in token (https://pyjwt.readthedocs.io/en/stable/changelog.html?highlight=audience#id40)
101
+ audience = self .allowed_jwt_audiences ,
102
+ )
103
+ except (jwt .exceptions .InvalidTokenError , jwt .exceptions .DecodeError ) as e :
104
+ logger .exception (f"InvalidTokenError: { e = } " )
105
+ raise HTTPException (
106
+ status_code = status .HTTP_401_UNAUTHORIZED ,
107
+ detail = "Could not validate credentials" ,
108
+ headers = {"WWW-Authenticate" : "Bearer" },
109
+ ) from e
110
+
111
+ # Validate scopes (if required)
112
+ for scope in required_scopes .scopes :
113
+ if scope not in payload ["scope" ]:
65
114
if auto_error :
66
115
raise HTTPException (
67
- status_code = status .HTTP_403_FORBIDDEN ,
68
- detail = "Not authenticated" ,
116
+ status_code = status .HTTP_401_UNAUTHORIZED ,
117
+ detail = "Not enough permissions" ,
118
+ headers = {
119
+ "WWW-Authenticate" : f'Bearer scope="{ required_scopes .scope_str } "'
120
+ },
69
121
)
70
122
return None
71
123
72
- # Extract token from header
73
- token_parts = auth_header .split (" " )
74
- if len (token_parts ) != 2 or token_parts [0 ].lower () != "bearer" :
75
- logger .error (f"Invalid token: { auth_header } " )
76
- raise HTTPException (
77
- status_code = status .HTTP_401_UNAUTHORIZED ,
78
- detail = "Could not validate credentials" ,
79
- headers = {"WWW-Authenticate" : "Bearer" },
80
- )
81
- [_ , token ] = token_parts
82
-
83
- # Parse & validate token
84
- try :
85
- key = self .jwks_client .get_signing_key_from_jwt (token ).key
86
- payload = jwt .decode (
87
- token ,
88
- key ,
89
- algorithms = ["RS256" ],
90
- # NOTE: Audience validation MUST match audience claim if set in token (https://pyjwt.readthedocs.io/en/stable/changelog.html?highlight=audience#id40)
91
- audience = self .allowed_jwt_audiences ,
92
- )
93
- except (jwt .exceptions .InvalidTokenError , jwt .exceptions .DecodeError ) as e :
94
- logger .exception (f"InvalidTokenError: { e = } " )
95
- raise HTTPException (
96
- status_code = status .HTTP_401_UNAUTHORIZED ,
97
- detail = "Could not validate credentials" ,
98
- headers = {"WWW-Authenticate" : "Bearer" },
99
- ) from e
100
-
101
- # Validate scopes (if required)
102
- for scope in required_scopes .scopes :
103
- if scope not in payload ["scope" ]:
104
- if auto_error :
105
- raise HTTPException (
106
- status_code = status .HTTP_401_UNAUTHORIZED ,
107
- detail = "Not enough permissions" ,
108
- headers = {
109
- "WWW-Authenticate" : f'Bearer scope="{ required_scopes .scope_str } "'
110
- },
111
- )
112
- return None
113
-
114
- return payload
115
-
116
- return valid_token_dependency
124
+ return payload
117
125
118
126
119
127
class OidcFetchError (Exception ):
0 commit comments