diff --git a/CHANGELOG.md b/CHANGELOG.md index 10a1609..ea94098 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,6 @@ +# 1.6.7 +- Added optional substring exempt path matching when `exempt_prefix_match` is `True` + # 1.6.6 - Support long-lived connections in ASGI middleware diff --git a/mauth_client/middlewares/asgi.py b/mauth_client/middlewares/asgi.py index c571b25..8b079aa 100644 --- a/mauth_client/middlewares/asgi.py +++ b/mauth_client/middlewares/asgi.py @@ -19,16 +19,17 @@ ) from mauth_client.signable import RequestSignable from mauth_client.signed import Signed -from mauth_client.utils import decode +from mauth_client.utils import decode, is_exempt_request_path logger = logging.getLogger("mauth_asgi") class MAuthASGIMiddleware: - def __init__(self, app: ASGI3Application, exempt: Optional[set] = None) -> None: + def __init__(self, app: ASGI3Application, exempt: Optional[set] = None, exempt_prefix_match: bool = False) -> None: self._validate_configs() self.app = app self.exempt = exempt.copy() if exempt else set() + self.exempt_prefix_match = exempt_prefix_match async def __call__( self, scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable @@ -40,6 +41,9 @@ async def __call__( if path in self.exempt: return await self.app(scope, receive, send) + if self.exempt_prefix_match and is_exempt_request_path(path, self.exempt): + return await self.app(scope, receive, send) + query_string = scope["query_string"] url = f"{path}?{decode(query_string)}" if query_string else path headers = {decode(k): decode(v) for k, v in scope["headers"]} diff --git a/mauth_client/middlewares/wsgi.py b/mauth_client/middlewares/wsgi.py index 44657bf..a456d00 100644 --- a/mauth_client/middlewares/wsgi.py +++ b/mauth_client/middlewares/wsgi.py @@ -14,15 +14,17 @@ from mauth_client.signable import RequestSignable from mauth_client.signed import Signed +from mauth_client.utils import is_exempt_request_path logger = logging.getLogger("mauth_wsgi") class MAuthWSGIMiddleware: - def __init__(self, app, exempt=None): + def __init__(self, app, exempt=None, exempt_prefix_match=False): self._validate_configs() self.app = app self.exempt = exempt.copy() if exempt else set() + self.exempt_prefix_match = exempt_prefix_match def __call__(self, environ, start_response): path = environ.get("PATH_INFO", "") @@ -30,6 +32,9 @@ def __call__(self, environ, start_response): if path in self.exempt: return self.app(environ, start_response) + if self.exempt_prefix_match and is_exempt_request_path(path, self.exempt): + return self.app(environ, start_response) + signable = RequestSignable( method=environ["REQUEST_METHOD"], url=self._extract_url(environ), diff --git a/mauth_client/utils.py b/mauth_client/utils.py index 055e18d..1ed589a 100644 --- a/mauth_client/utils.py +++ b/mauth_client/utils.py @@ -32,3 +32,36 @@ def decode(byte_string: bytes) -> str: except UnicodeDecodeError: encoding = charset_normalizer.detect(byte_string)["encoding"] return byte_string.decode(encoding) + + +def is_exempt_request_path(path: str, exempt: set) -> bool: + """ + Check if a request path should be exempt from authentication based on prefix matching. + + This function performs prefix matching with path separator boundary checking to prevent + false positives. A path matches an exempt prefix only if it starts with the exempt path + followed by a path separator ('/'). + + :param str path: The request path to check (e.g., '/health/live', '/api/users') + :param set exempt: Set of exempt path prefixes (e.g., {'/health', '/metrics'}) + :return: True if the path matches any exempt prefix, False otherwise + :rtype: bool + + Examples: + Matching cases (returns True): + - path='/health/live', exempt={'/health'} -> True + - path='/health/ready', exempt={'/health'} -> True + - path='/metrics/prometheus', exempt={'/metrics'} -> True + + Non-matching cases (returns False): + - path='/health', exempt={'/health'} -> False (exact match without trailing slash) + - path='/api-admin', exempt={'/api'} -> False (not a path separator boundary) + - path='/app_status_admin', exempt={'/app_status'} -> False (underscore, not separator) + - path='/healthcare', exempt={'/health'} -> False (different path) + """ + for exempt_path in exempt: + # Exact match or prefix match with path separator + # For instance this prevents /api matching /api-admin or /app_status matching /app_status_admin + if path.startswith(exempt_path.rstrip('/') + '/'): + return True + return False diff --git a/pyproject.toml b/pyproject.toml index fb879ac..4dc7c1d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "mauth-client" -version = "1.6.6" +version = "1.6.7" description = "MAuth Client for Python" repository = "https://github.com/mdsol/mauth-client-python" authors = ["Medidata Solutions "] diff --git a/tests/middlewares/asgi_test.py b/tests/middlewares/asgi_test.py index d5624e2..7bb066f 100644 --- a/tests/middlewares/asgi_test.py +++ b/tests/middlewares/asgi_test.py @@ -278,3 +278,79 @@ async def mock_receive(): self.assertEqual(call_order[0], ("body", "http.request")) self.assertEqual(call_order[1], ("disconnect", "http.disconnect")) self.assertEqual(receive_calls, 2) # Called once for auth, once from app + + +class TestMAuthASGIMiddlewareWithPrefixMatch(unittest.TestCase): + def setUp(self): + self.app_uuid = str(uuid4()) + Config.APP_UUID = self.app_uuid + Config.MAUTH_URL = "https://mauth.com" + Config.MAUTH_API_VERSION = "v1" + Config.PRIVATE_KEY = "key" + + self.app = FastAPI() + self.app.add_middleware( + MAuthASGIMiddleware, + exempt={"/health", "/metrics"}, + exempt_prefix_match=True + ) + + @self.app.get("/") + async def root(): + return {"msg": "authenticated"} + + @self.app.get("/health") + async def health_exact(): + return {"msg": "exact health"} + + @self.app.get("/health/live") + async def health_live(): + return {"msg": "health live"} + + @self.app.get("/health/ready") + async def health_ready(): + return {"msg": "health ready"} + + @self.app.get("/metrics/prometheus") + async def metrics(): + return {"msg": "metrics"} + + @self.app.get("/api/health") + async def api_health(): + return {"msg": "api health"} + + self.client = TestClient(self.app) + + def test_prefix_match_allows_nested_paths(self): + """Test that nested paths under exempt prefix are allowed""" + response = self.client.get("/health/live") + self.assertEqual(response.status_code, 200) + self.assertEqual(response.json(), {"msg": "health live"}) + + response = self.client.get("/health/ready") + self.assertEqual(response.status_code, 200) + self.assertEqual(response.json(), {"msg": "health ready"}) + + response = self.client.get("/metrics/prometheus") + self.assertEqual(response.status_code, 200) + self.assertEqual(response.json(), {"msg": "metrics"}) + + def test_prefix_match_blocks_similar_paths(self): + """Test that similar but non-matching paths are still blocked""" + response = self.client.get("/api/health") + self.assertEqual(response.status_code, 401) + + def test_prefix_match_allows_exact_match_in_exempt_set(self): + """Test that exact match in exempt set is allowed (from exact match check)""" + response = self.client.get("/health") + self.assertEqual(response.status_code, 200) + self.assertEqual(response.json(), {"msg": "exact health"}) + + @patch.object(LocalAuthenticator, "is_authentic") + def test_prefix_match_still_authenticates_non_exempt_paths(self, is_authentic_mock): + """Test that non-exempt paths still require authentication""" + is_authentic_mock.return_value = (True, 200, "") + + response = self.client.get("/") + self.assertEqual(response.status_code, 200) + self.assertEqual(response.json(), {"msg": "authenticated"}) diff --git a/tests/middlewares/wsgi_test.py b/tests/middlewares/wsgi_test.py index aecac85..9e086a9 100644 --- a/tests/middlewares/wsgi_test.py +++ b/tests/middlewares/wsgi_test.py @@ -192,3 +192,79 @@ def post_test(): self.assertEqual(response.status_code, 200) self.assertEqual(response.json, body) + + +class TestMAuthWSGIMiddlewareWithPrefixMatch(unittest.TestCase): + def setUp(self): + self.app_uuid = str(uuid4()) + Config.APP_UUID = self.app_uuid + Config.MAUTH_URL = "https://mauth.com" + Config.MAUTH_API_VERSION = "v1" + Config.PRIVATE_KEY = "key" + + self.app = Flask("Test App") + self.app.wsgi_app = MAuthWSGIMiddleware( + self.app.wsgi_app, + exempt={"/health", "/metrics"}, + exempt_prefix_match=True + ) + + @self.app.get("/") + def root(): + return "authenticated!" + + @self.app.get("/health") + def health_exact(): + return "exact health" + + @self.app.get("/health/live") + def health_live(): + return "health live" + + @self.app.get("/health/ready") + def health_ready(): + return "health ready" + + @self.app.get("/metrics/prometheus") + def metrics(): + return "metrics" + + @self.app.get("/api/health") + def api_health(): + return "api health" + + self.client = self.app.test_client() + + def test_prefix_match_allows_nested_paths(self): + """Test that nested paths under exempt prefix are allowed""" + response = self.client.get("/health/live") + self.assertEqual(response.status_code, 200) + self.assertEqual(response.get_data(as_text=True), "health live") + + response = self.client.get("/health/ready") + self.assertEqual(response.status_code, 200) + self.assertEqual(response.get_data(as_text=True), "health ready") + + response = self.client.get("/metrics/prometheus") + self.assertEqual(response.status_code, 200) + self.assertEqual(response.get_data(as_text=True), "metrics") + + def test_prefix_match_blocks_similar_paths(self): + """Test that similar but non-matching paths are still blocked""" + response = self.client.get("/api/health") + self.assertEqual(response.status_code, 401) + + def test_prefix_match_allows_exact_match_in_exempt_set(self): + """Test that exact match in exempt set is allowed (from exact match check)""" + response = self.client.get("/health") + self.assertEqual(response.status_code, 200) + self.assertEqual(response.get_data(as_text=True), "exact health") + + @patch.object(LocalAuthenticator, "is_authentic") + def test_prefix_match_still_authenticates_non_exempt_paths(self, is_authentic_mock): + """Test that non-exempt paths still require authentication""" + is_authentic_mock.return_value = (True, 200, "") + + response = self.client.get("/") + self.assertEqual(response.status_code, 200) + self.assertEqual(response.get_data(as_text=True), "authenticated!") diff --git a/tests/utils_test.py b/tests/utils_test.py new file mode 100644 index 0000000..3dbe22f --- /dev/null +++ b/tests/utils_test.py @@ -0,0 +1,172 @@ +import unittest +from mauth_client.utils import is_exempt_request_path + + +class TestIsExemptRequestPath(unittest.TestCase): + """Test the is_exempt_request_path utility function.""" + + def test_exact_prefix_match_with_trailing_slash(self): + """Test that paths with trailing slashes in exempt set match correctly.""" + exempt = {"/health/"} + self.assertTrue(is_exempt_request_path("/health/check", exempt)) + self.assertTrue(is_exempt_request_path("/health/status", exempt)) + + def test_exact_prefix_match_without_trailing_slash(self): + """Test that paths without trailing slashes in exempt set match correctly.""" + exempt = {"/health"} + self.assertTrue(is_exempt_request_path("/health/check", exempt)) + self.assertTrue(is_exempt_request_path("/health/status", exempt)) + + def test_no_match_similar_prefix(self): + """Test that similar but non-matching paths return False.""" + exempt = {"/api"} + self.assertFalse(is_exempt_request_path("/api-admin", exempt)) + self.assertFalse(is_exempt_request_path("/api-v2", exempt)) + self.assertFalse(is_exempt_request_path("/apis", exempt)) + + def test_no_match_similar_prefix_with_underscore(self): + """Test prevention of false matches with underscores.""" + exempt = {"/app_status"} + self.assertFalse(is_exempt_request_path("/app_status_admin", exempt)) + self.assertFalse(is_exempt_request_path("/app_status_internal", exempt)) + + def test_nested_path_match(self): + """Test deeply nested paths under exempt prefix.""" + exempt = {"/api"} + self.assertTrue(is_exempt_request_path("/api/v1/users", exempt)) + self.assertTrue(is_exempt_request_path("/api/v2/products/123", exempt)) + + def test_exact_match_path_returns_false(self): + """Test that exact match of exempt path without trailing slash returns False.""" + exempt = {"/health"} + # This is by design - the function checks for prefix with '/' + self.assertFalse(is_exempt_request_path("/health", exempt)) + + def test_exact_match_with_trailing_slash_in_path(self): + """Test exact match when request path has trailing slash.""" + exempt = {"/health"} + self.assertTrue(is_exempt_request_path("/health/", exempt)) + + def test_multiple_exempt_paths(self): + """Test with multiple exempt paths.""" + exempt = {"/health", "/metrics", "/status"} + self.assertTrue(is_exempt_request_path("/health/check", exempt)) + self.assertTrue(is_exempt_request_path("/metrics/prometheus", exempt)) + self.assertTrue(is_exempt_request_path("/status/ready", exempt)) + self.assertFalse(is_exempt_request_path("/api/users", exempt)) + + def test_empty_exempt_set(self): + """Test with empty exempt set returns False.""" + exempt = set() + self.assertFalse(is_exempt_request_path("/any/path", exempt)) + self.assertFalse(is_exempt_request_path("/", exempt)) + + def test_root_path_exempt(self): + """Test root path exemption.""" + exempt = {"/"} + self.assertTrue(is_exempt_request_path("/anything", exempt)) + self.assertTrue(is_exempt_request_path("/api/users", exempt)) + + def test_no_leading_slash_in_exempt(self): + """Test behavior when exempt path doesn't have leading slash.""" + exempt = {"health"} + # Should not match since it becomes "health/" + self.assertFalse(is_exempt_request_path("/health/check", exempt)) + + def test_path_with_special_characters(self): + """Test paths with special characters.""" + exempt = {"/api-v1"} + self.assertTrue(is_exempt_request_path("/api-v1/users", exempt)) + self.assertFalse(is_exempt_request_path("/api-v2/users", exempt)) + + def test_path_with_numbers(self): + """Test paths with numbers.""" + exempt = {"/v1"} + self.assertTrue(is_exempt_request_path("/v1/api", exempt)) + self.assertFalse(is_exempt_request_path("/v2/api", exempt)) + + def test_case_sensitive_matching(self): + """Test that path matching is case-sensitive.""" + exempt = {"/Health"} + self.assertFalse(is_exempt_request_path("/health/check", exempt)) + self.assertTrue(is_exempt_request_path("/Health/check", exempt)) + + def test_path_with_query_string(self): + """Test paths that include query strings.""" + exempt = {"/search"} + # Query strings should be part of the path being checked + self.assertTrue(is_exempt_request_path("/search/results?q=test", exempt)) + + def test_overlapping_prefixes(self): + """Test with overlapping exempt prefixes.""" + exempt = {"/api", "/api/v1"} + self.assertTrue(is_exempt_request_path("/api/users", exempt)) + self.assertTrue(is_exempt_request_path("/api/v1/users", exempt)) + + def test_single_character_prefix(self): + """Test single character prefix.""" + exempt = {"/a"} + self.assertTrue(is_exempt_request_path("/a/b/c", exempt)) + self.assertFalse(is_exempt_request_path("/b/c", exempt)) + + def test_path_with_dots(self): + """Test paths with dots (e.g., file extensions).""" + exempt = {"/static"} + self.assertTrue(is_exempt_request_path("/static/images/logo.png", exempt)) + self.assertTrue(is_exempt_request_path("/static/css/style.css", exempt)) + + def test_path_with_unicode(self): + """Test paths with Unicode characters.""" + exempt = {"/api"} + self.assertTrue(is_exempt_request_path("/api/用户", exempt)) + exempt_unicode = {"/用户"} + self.assertTrue(is_exempt_request_path("/用户/profile", exempt_unicode)) + + def test_multiple_trailing_slashes(self): + """Test exempt paths with multiple trailing slashes.""" + exempt = {"/health//"} + # After rstrip('/'), becomes "/health/" + self.assertTrue(is_exempt_request_path("/health/check", exempt)) + + def test_path_separator_edge_case(self): + """Test that the path separator logic works correctly.""" + exempt = {"/app"} + # These should NOT match + self.assertFalse(is_exempt_request_path("/application", exempt)) + self.assertFalse(is_exempt_request_path("/app-admin", exempt)) + self.assertFalse(is_exempt_request_path("/apps", exempt)) + # This SHOULD match + self.assertTrue(is_exempt_request_path("/app/status", exempt)) + + def test_real_world_health_check_paths(self): + """Test real-world health check endpoint patterns.""" + exempt = {"/health", "/healthz", "/_health"} + self.assertTrue(is_exempt_request_path("/health/live", exempt)) + self.assertTrue(is_exempt_request_path("/health/ready", exempt)) + self.assertTrue(is_exempt_request_path("/healthz/status", exempt)) + self.assertTrue(is_exempt_request_path("/_health/check", exempt)) + self.assertFalse(is_exempt_request_path("/api/health", exempt)) + + def test_real_world_monitoring_paths(self): + """Test real-world monitoring endpoint patterns.""" + exempt = {"/metrics", "/status", "/actuator"} + self.assertTrue(is_exempt_request_path("/metrics/prometheus", exempt)) + self.assertTrue(is_exempt_request_path("/status/app", exempt)) + self.assertTrue(is_exempt_request_path("/actuator/health", exempt)) + self.assertFalse(is_exempt_request_path("/api/metrics", exempt)) + + def test_long_nested_paths(self): + """Test deeply nested paths with multiple levels.""" + exempt = {"/api"} + # Test very long nested paths + self.assertTrue(is_exempt_request_path("/api/v1/organizations/123/projects/456/resources/789/items", exempt)) + self.assertTrue(is_exempt_request_path("/api/internal/admin/users/settings/preferences/notifications", exempt)) + + exempt_nested = {"/admin/dashboard"} + self.assertTrue(is_exempt_request_path( + "/admin/dashboard/analytics/reports/monthly/2024/november/summary", + exempt_nested + )) + + # Should not match similar but non-matching long paths + self.assertFalse(is_exempt_request_path("/api-v1/dashboard-v2/analytics", exempt_nested))