Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
# 1.6.7
- Added optional substring exempt path matching when `exempt_prefix_match` is `True`

# 1.6.6
- Support long-lived connections in ASGI middleware

Expand Down
8 changes: 6 additions & 2 deletions mauth_client/middlewares/asgi.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,16 +19,17 @@
)
from mauth_client.signable import RequestSignable
from mauth_client.signed import Signed
from mauth_client.utils import decode
from mauth_client.utils import decode, is_exempt_request_path

logger = logging.getLogger("mauth_asgi")


class MAuthASGIMiddleware:
def __init__(self, app: ASGI3Application, exempt: Optional[set] = None) -> None:
def __init__(self, app: ASGI3Application, exempt: Optional[set] = None, exempt_prefix_match: bool = False) -> None:
self._validate_configs()
self.app = app
self.exempt = exempt.copy() if exempt else set()
self.exempt_prefix_match = exempt_prefix_match

async def __call__(
self, scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable
Expand All @@ -40,6 +41,9 @@ async def __call__(
if path in self.exempt:
return await self.app(scope, receive, send)

if self.exempt_prefix_match and is_exempt_request_path(path, self.exempt):
return await self.app(scope, receive, send)

query_string = scope["query_string"]
url = f"{path}?{decode(query_string)}" if query_string else path
headers = {decode(k): decode(v) for k, v in scope["headers"]}
Expand Down
7 changes: 6 additions & 1 deletion mauth_client/middlewares/wsgi.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,22 +14,27 @@

from mauth_client.signable import RequestSignable
from mauth_client.signed import Signed
from mauth_client.utils import is_exempt_request_path

logger = logging.getLogger("mauth_wsgi")


class MAuthWSGIMiddleware:
def __init__(self, app, exempt=None):
def __init__(self, app, exempt=None, exempt_prefix_match=False):
self._validate_configs()
self.app = app
self.exempt = exempt.copy() if exempt else set()
self.exempt_prefix_match = exempt_prefix_match

def __call__(self, environ, start_response):
path = environ.get("PATH_INFO", "")

if path in self.exempt:
return self.app(environ, start_response)

if self.exempt_prefix_match and is_exempt_request_path(path, self.exempt):
return self.app(environ, start_response)

signable = RequestSignable(
method=environ["REQUEST_METHOD"],
url=self._extract_url(environ),
Expand Down
33 changes: 33 additions & 0 deletions mauth_client/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,3 +32,36 @@ def decode(byte_string: bytes) -> str:
except UnicodeDecodeError:
encoding = charset_normalizer.detect(byte_string)["encoding"]
return byte_string.decode(encoding)


def is_exempt_request_path(path: str, exempt: set) -> bool:
"""
Check if a request path should be exempt from authentication based on prefix matching.

This function performs prefix matching with path separator boundary checking to prevent
false positives. A path matches an exempt prefix only if it starts with the exempt path
followed by a path separator ('/').

:param str path: The request path to check (e.g., '/health/live', '/api/users')
:param set exempt: Set of exempt path prefixes (e.g., {'/health', '/metrics'})
:return: True if the path matches any exempt prefix, False otherwise
:rtype: bool

Examples:
Matching cases (returns True):
- path='/health/live', exempt={'/health'} -> True
- path='/health/ready', exempt={'/health'} -> True
- path='/metrics/prometheus', exempt={'/metrics'} -> True

Non-matching cases (returns False):
- path='/health', exempt={'/health'} -> False (exact match without trailing slash)
- path='/api-admin', exempt={'/api'} -> False (not a path separator boundary)
- path='/app_status_admin', exempt={'/app_status'} -> False (underscore, not separator)
- path='/healthcare', exempt={'/health'} -> False (different path)
"""
for exempt_path in exempt:
# Exact match or prefix match with path separator
# For instance this prevents /api matching /api-admin or /app_status matching /app_status_admin
if path.startswith(exempt_path.rstrip('/') + '/'):
return True
return False
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "mauth-client"
version = "1.6.6"
version = "1.6.7"
description = "MAuth Client for Python"
repository = "https://github.com/mdsol/mauth-client-python"
authors = ["Medidata Solutions <[email protected]>"]
Expand Down
76 changes: 76 additions & 0 deletions tests/middlewares/asgi_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -278,3 +278,79 @@ async def mock_receive():
self.assertEqual(call_order[0], ("body", "http.request"))
self.assertEqual(call_order[1], ("disconnect", "http.disconnect"))
self.assertEqual(receive_calls, 2) # Called once for auth, once from app


class TestMAuthASGIMiddlewareWithPrefixMatch(unittest.TestCase):
def setUp(self):
self.app_uuid = str(uuid4())
Config.APP_UUID = self.app_uuid
Config.MAUTH_URL = "https://mauth.com"
Config.MAUTH_API_VERSION = "v1"
Config.PRIVATE_KEY = "key"

self.app = FastAPI()
self.app.add_middleware(
MAuthASGIMiddleware,
exempt={"/health", "/metrics"},
exempt_prefix_match=True
)

@self.app.get("/")
async def root():
return {"msg": "authenticated"}

@self.app.get("/health")
async def health_exact():
return {"msg": "exact health"}

@self.app.get("/health/live")
async def health_live():
return {"msg": "health live"}

@self.app.get("/health/ready")
async def health_ready():
return {"msg": "health ready"}

@self.app.get("/metrics/prometheus")
async def metrics():
return {"msg": "metrics"}

@self.app.get("/api/health")
async def api_health():
return {"msg": "api health"}

self.client = TestClient(self.app)

def test_prefix_match_allows_nested_paths(self):
"""Test that nested paths under exempt prefix are allowed"""
response = self.client.get("/health/live")
self.assertEqual(response.status_code, 200)
self.assertEqual(response.json(), {"msg": "health live"})

response = self.client.get("/health/ready")
self.assertEqual(response.status_code, 200)
self.assertEqual(response.json(), {"msg": "health ready"})

response = self.client.get("/metrics/prometheus")
self.assertEqual(response.status_code, 200)
self.assertEqual(response.json(), {"msg": "metrics"})

def test_prefix_match_blocks_similar_paths(self):
"""Test that similar but non-matching paths are still blocked"""
response = self.client.get("/api/health")
self.assertEqual(response.status_code, 401)

def test_prefix_match_allows_exact_match_in_exempt_set(self):
"""Test that exact match in exempt set is allowed (from exact match check)"""
response = self.client.get("/health")
self.assertEqual(response.status_code, 200)
self.assertEqual(response.json(), {"msg": "exact health"})

@patch.object(LocalAuthenticator, "is_authentic")
def test_prefix_match_still_authenticates_non_exempt_paths(self, is_authentic_mock):
"""Test that non-exempt paths still require authentication"""
is_authentic_mock.return_value = (True, 200, "")

response = self.client.get("/")
self.assertEqual(response.status_code, 200)
self.assertEqual(response.json(), {"msg": "authenticated"})
76 changes: 76 additions & 0 deletions tests/middlewares/wsgi_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,3 +192,79 @@ def post_test():

self.assertEqual(response.status_code, 200)
self.assertEqual(response.json, body)


class TestMAuthWSGIMiddlewareWithPrefixMatch(unittest.TestCase):
def setUp(self):
self.app_uuid = str(uuid4())
Config.APP_UUID = self.app_uuid
Config.MAUTH_URL = "https://mauth.com"
Config.MAUTH_API_VERSION = "v1"
Config.PRIVATE_KEY = "key"

self.app = Flask("Test App")
self.app.wsgi_app = MAuthWSGIMiddleware(
self.app.wsgi_app,
exempt={"/health", "/metrics"},
exempt_prefix_match=True
)

@self.app.get("/")
def root():
return "authenticated!"

@self.app.get("/health")
def health_exact():
return "exact health"

@self.app.get("/health/live")
def health_live():
return "health live"

@self.app.get("/health/ready")
def health_ready():
return "health ready"

@self.app.get("/metrics/prometheus")
def metrics():
return "metrics"

@self.app.get("/api/health")
def api_health():
return "api health"

self.client = self.app.test_client()

def test_prefix_match_allows_nested_paths(self):
"""Test that nested paths under exempt prefix are allowed"""
response = self.client.get("/health/live")
self.assertEqual(response.status_code, 200)
self.assertEqual(response.get_data(as_text=True), "health live")

response = self.client.get("/health/ready")
self.assertEqual(response.status_code, 200)
self.assertEqual(response.get_data(as_text=True), "health ready")

response = self.client.get("/metrics/prometheus")
self.assertEqual(response.status_code, 200)
self.assertEqual(response.get_data(as_text=True), "metrics")

def test_prefix_match_blocks_similar_paths(self):
"""Test that similar but non-matching paths are still blocked"""
response = self.client.get("/api/health")
self.assertEqual(response.status_code, 401)

def test_prefix_match_allows_exact_match_in_exempt_set(self):
"""Test that exact match in exempt set is allowed (from exact match check)"""
response = self.client.get("/health")
self.assertEqual(response.status_code, 200)
self.assertEqual(response.get_data(as_text=True), "exact health")

@patch.object(LocalAuthenticator, "is_authentic")
def test_prefix_match_still_authenticates_non_exempt_paths(self, is_authentic_mock):
"""Test that non-exempt paths still require authentication"""
is_authentic_mock.return_value = (True, 200, "")

response = self.client.get("/")
self.assertEqual(response.status_code, 200)
self.assertEqual(response.get_data(as_text=True), "authenticated!")
Loading
Loading