Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
# 1.6.7
- Added optional substring exempt path matching when `exempt_prefix_match` is `True`

# 1.6.6
- Support long-lived connections in ASGI middleware

Expand Down
8 changes: 6 additions & 2 deletions mauth_client/middlewares/asgi.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,16 +19,17 @@
)
from mauth_client.signable import RequestSignable
from mauth_client.signed import Signed
from mauth_client.utils import decode
from mauth_client.utils import decode, is_exempt_request_path

logger = logging.getLogger("mauth_asgi")


class MAuthASGIMiddleware:
def __init__(self, app: ASGI3Application, exempt: Optional[set] = None) -> None:
def __init__(self, app: ASGI3Application, exempt: Optional[set] = None, exempt_prefix_match: bool = False) -> None:
self._validate_configs()
self.app = app
self.exempt = exempt.copy() if exempt else set()
self.exempt_prefix_match = exempt_prefix_match

async def __call__(
self, scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable
Expand All @@ -40,6 +41,9 @@ async def __call__(
if path in self.exempt:
return await self.app(scope, receive, send)

if self.exempt_prefix_match and is_exempt_request_path(path, self.exempt):
return await self.app(scope, receive, send)

query_string = scope["query_string"]
url = f"{path}?{decode(query_string)}" if query_string else path
headers = {decode(k): decode(v) for k, v in scope["headers"]}
Expand Down
7 changes: 6 additions & 1 deletion mauth_client/middlewares/wsgi.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,22 +14,27 @@

from mauth_client.signable import RequestSignable
from mauth_client.signed import Signed
from mauth_client.utils import is_exempt_request_path

logger = logging.getLogger("mauth_wsgi")


class MAuthWSGIMiddleware:
def __init__(self, app, exempt=None):
def __init__(self, app, exempt=None, exempt_prefix_match=False):
self._validate_configs()
self.app = app
self.exempt = exempt.copy() if exempt else set()
self.exempt_prefix_match = exempt_prefix_match

def __call__(self, environ, start_response):
path = environ.get("PATH_INFO", "")

if path in self.exempt:
return self.app(environ, start_response)

if self.exempt_prefix_match and is_exempt_request_path(path, self.exempt):
return self.app(environ, start_response)

signable = RequestSignable(
method=environ["REQUEST_METHOD"],
url=self._extract_url(environ),
Expand Down
8 changes: 8 additions & 0 deletions mauth_client/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,3 +32,11 @@ def decode(byte_string: bytes) -> str:
except UnicodeDecodeError:
encoding = charset_normalizer.detect(byte_string)["encoding"]
return byte_string.decode(encoding)

def is_exempt_request_path(path: str, exempt: set) -> bool:
for exempt_path in exempt:
# Exact match or prefix match with path separator
# For instance this prevents /api matching /api-admin or /app_status matching /app_status_admin
if path.startswith(exempt_path.rstrip('/') + '/'):
return True
return False
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "mauth-client"
version = "1.6.6"
version = "1.6.7"
description = "MAuth Client for Python"
repository = "https://github.com/mdsol/mauth-client-python"
authors = ["Medidata Solutions <[email protected]>"]
Expand Down
76 changes: 76 additions & 0 deletions tests/middlewares/asgi_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -278,3 +278,79 @@ async def mock_receive():
self.assertEqual(call_order[0], ("body", "http.request"))
self.assertEqual(call_order[1], ("disconnect", "http.disconnect"))
self.assertEqual(receive_calls, 2) # Called once for auth, once from app


class TestMAuthASGIMiddlewareWithPrefixMatch(unittest.TestCase):
def setUp(self):
self.app_uuid = str(uuid4())
Config.APP_UUID = self.app_uuid
Config.MAUTH_URL = "https://mauth.com"
Config.MAUTH_API_VERSION = "v1"
Config.PRIVATE_KEY = "key"

self.app = FastAPI()
self.app.add_middleware(
MAuthASGIMiddleware,
exempt={"/health", "/metrics"},
exempt_prefix_match=True
)

@self.app.get("/")
async def root():
return {"msg": "authenticated"}

@self.app.get("/health")
async def health_exact():
return {"msg": "exact health"}

@self.app.get("/health/live")
async def health_live():
return {"msg": "health live"}

@self.app.get("/health/ready")
async def health_ready():
return {"msg": "health ready"}

@self.app.get("/metrics/prometheus")
async def metrics():
return {"msg": "metrics"}

@self.app.get("/api/health")
async def api_health():
return {"msg": "api health"}

self.client = TestClient(self.app)

def test_prefix_match_allows_nested_paths(self):
"""Test that nested paths under exempt prefix are allowed"""
response = self.client.get("/health/live")
self.assertEqual(response.status_code, 200)
self.assertEqual(response.json(), {"msg": "health live"})

response = self.client.get("/health/ready")
self.assertEqual(response.status_code, 200)
self.assertEqual(response.json(), {"msg": "health ready"})

response = self.client.get("/metrics/prometheus")
self.assertEqual(response.status_code, 200)
self.assertEqual(response.json(), {"msg": "metrics"})

def test_prefix_match_blocks_similar_paths(self):
"""Test that similar but non-matching paths are still blocked"""
response = self.client.get("/api/health")
self.assertEqual(response.status_code, 401)

def test_prefix_match_allows_exact_match_in_exempt_set(self):
"""Test that exact match in exempt set is allowed (from exact match check)"""
response = self.client.get("/health")
self.assertEqual(response.status_code, 200)
self.assertEqual(response.json(), {"msg": "exact health"})

@patch.object(LocalAuthenticator, "is_authentic")
def test_prefix_match_still_authenticates_non_exempt_paths(self, is_authentic_mock):
"""Test that non-exempt paths still require authentication"""
is_authentic_mock.return_value = (True, 200, "")

response = self.client.get("/")
self.assertEqual(response.status_code, 200)
self.assertEqual(response.json(), {"msg": "authenticated"})
76 changes: 76 additions & 0 deletions tests/middlewares/wsgi_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,3 +192,79 @@ def post_test():

self.assertEqual(response.status_code, 200)
self.assertEqual(response.json, body)


class TestMAuthWSGIMiddlewareWithPrefixMatch(unittest.TestCase):
def setUp(self):
self.app_uuid = str(uuid4())
Config.APP_UUID = self.app_uuid
Config.MAUTH_URL = "https://mauth.com"
Config.MAUTH_API_VERSION = "v1"
Config.PRIVATE_KEY = "key"

self.app = Flask("Test App")
self.app.wsgi_app = MAuthWSGIMiddleware(
self.app.wsgi_app,
exempt={"/health", "/metrics"},
exempt_prefix_match=True
)

@self.app.get("/")
def root():
return "authenticated!"

@self.app.get("/health")
def health_exact():
return "exact health"

@self.app.get("/health/live")
def health_live():
return "health live"

@self.app.get("/health/ready")
def health_ready():
return "health ready"

@self.app.get("/metrics/prometheus")
def metrics():
return "metrics"

@self.app.get("/api/health")
def api_health():
return "api health"

self.client = self.app.test_client()

def test_prefix_match_allows_nested_paths(self):
"""Test that nested paths under exempt prefix are allowed"""
response = self.client.get("/health/live")
self.assertEqual(response.status_code, 200)
self.assertEqual(response.get_data(as_text=True), "health live")

response = self.client.get("/health/ready")
self.assertEqual(response.status_code, 200)
self.assertEqual(response.get_data(as_text=True), "health ready")

response = self.client.get("/metrics/prometheus")
self.assertEqual(response.status_code, 200)
self.assertEqual(response.get_data(as_text=True), "metrics")

def test_prefix_match_blocks_similar_paths(self):
"""Test that similar but non-matching paths are still blocked"""
response = self.client.get("/api/health")
self.assertEqual(response.status_code, 401)

def test_prefix_match_allows_exact_match_in_exempt_set(self):
"""Test that exact match in exempt set is allowed (from exact match check)"""
response = self.client.get("/health")
self.assertEqual(response.status_code, 200)
self.assertEqual(response.get_data(as_text=True), "exact health")

@patch.object(LocalAuthenticator, "is_authentic")
def test_prefix_match_still_authenticates_non_exempt_paths(self, is_authentic_mock):
"""Test that non-exempt paths still require authentication"""
is_authentic_mock.return_value = (True, 200, "")

response = self.client.get("/")
self.assertEqual(response.status_code, 200)
self.assertEqual(response.get_data(as_text=True), "authenticated!")
Loading
Loading