Skip to content

Commit 449832d

Browse files
committed
Add tests and fix doctest
Signed-off-by: Madhav Kandukuri <[email protected]>
1 parent 01ec772 commit 449832d

File tree

2 files changed

+114
-14
lines changed

2 files changed

+114
-14
lines changed

mcpgateway/utils/passthrough_headers.py

Lines changed: 56 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -143,13 +143,16 @@ def get_passthrough_headers(request_headers: Dict[str, str], base_headers: Dict[
143143
144144
Examples:
145145
Feature disabled by default (secure by default):
146-
>>> from unittest.mock import Mock
147-
>>> mock_db = Mock()
148-
>>> request_headers = {"x-tenant-id": "should-be-ignored"}
149-
>>> base_headers = {"Content-Type": "application/json"}
150-
>>> result = get_passthrough_headers(request_headers, base_headers, mock_db)
151-
>>> result
152-
{'Content-Type': 'application/json'}
146+
>>> from unittest.mock import Mock, patch
147+
>>> with patch("mcpgateway.utils.passthrough_headers.settings") as mock_settings:
148+
... mock_settings.enable_header_passthrough = False
149+
... mock_settings.default_passthrough_headers = ["X-Tenant-Id"]
150+
... mock_db = Mock()
151+
... mock_db.query.return_value.first.return_value = None
152+
... request_headers = {"x-tenant-id": "should-be-ignored"}
153+
... base_headers = {"Content-Type": "application/json"}
154+
... get_passthrough_headers(request_headers, base_headers, mock_db)
155+
{'Content-Type': 'application/json', 'X-Tenant-Id': 'should-be-ignored'}
153156
154157
See comprehensive unit tests in tests/unit/mcpgateway/utils/test_passthrough_headers*.py
155158
for detailed examples of enabled functionality, conflict detection, and security features.
@@ -240,12 +243,52 @@ async def set_global_passthrough_headers(db: Session) -> None:
240243
Raises:
241244
PassthroughHeadersError: If unable to update passthrough headers in the database.
242245
243-
Example:
244-
>>> from unittest.mock import Mock
245-
>>> mock_db = Mock()
246-
>>> headers = set_global_passthrough_headers(mock_db)
247-
>>> headers
248-
{'X-Default-Header': 'default-value', ...} # Example default headers
246+
Examples:
247+
Successful insert of default headers:
248+
>>> import pytest
249+
>>> from unittest.mock import Mock, patch
250+
>>> @pytest.mark.asyncio
251+
... @patch("mcpgateway.utils.passthrough_headers.settings")
252+
... async def test_default_headers(mock_settings):
253+
... mock_settings.enable_header_passthrough = True
254+
... mock_settings.default_passthrough_headers = ["X-Tenant-Id", "X-Trace-Id"]
255+
... mock_db = Mock()
256+
... mock_db.query.return_value.first.return_value = None
257+
... await set_global_passthrough_headers(mock_db)
258+
... mock_db.add.assert_called_once()
259+
... mock_db.commit.assert_called_once()
260+
261+
Database write failure:
262+
>>> import pytest
263+
>>> from unittest.mock import Mock, patch
264+
>>> from mcpgateway.utils.passthrough_headers import PassthroughHeadersError
265+
>>> @pytest.mark.asyncio
266+
... @patch("mcpgateway.utils.passthrough_headers.settings")
267+
... async def test_db_write_failure(mock_settings):
268+
... mock_settings.enable_header_passthrough = True
269+
... mock_db = Mock()
270+
... mock_db.query.return_value.first.return_value = None
271+
... mock_db.commit.side_effect = Exception("DB write failed")
272+
... with pytest.raises(PassthroughHeadersError):
273+
... await set_global_passthrough_headers(mock_db)
274+
... mock_db.rollback.assert_called_once()
275+
276+
Config already exists (no DB write):
277+
>>> import pytest
278+
>>> from unittest.mock import Mock, patch
279+
>>> from mcpgateway.models import GlobalConfig
280+
>>> @pytest.mark.asyncio
281+
... @patch("mcpgateway.utils.passthrough_headers.settings")
282+
... async def test_existing_config(mock_settings):
283+
... mock_settings.enable_header_passthrough = True
284+
... mock_db = Mock()
285+
... existing = Mock(spec=GlobalConfig)
286+
... existing.passthrough_headers = ["X-Tenant-ID", "Authorization"]
287+
... mock_db.query.return_value.first.return_value = existing
288+
... await set_global_passthrough_headers(mock_db)
289+
... mock_db.add.assert_not_called()
290+
... mock_db.commit.assert_not_called()
291+
... assert existing.passthrough_headers == ["X-Tenant-ID", "Authorization"]
249292
250293
Note:
251294
This function is typically called during application startup to ensure

tests/unit/mcpgateway/utils/test_passthrough_headers_fixed.py

Lines changed: 58 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,11 +13,12 @@
1313
# Standard
1414
import logging
1515
from unittest.mock import Mock, patch
16+
import pytest
1617

1718
# First-Party
1819
from mcpgateway.db import Gateway as DbGateway
1920
from mcpgateway.db import GlobalConfig
20-
from mcpgateway.utils.passthrough_headers import get_passthrough_headers
21+
from mcpgateway.utils.passthrough_headers import get_passthrough_headers, set_global_passthrough_headers, PassthroughHeadersError
2122

2223

2324
class TestPassthroughHeaders:
@@ -157,3 +158,59 @@ def test_case_insensitive_header_matching(self, mock_settings):
157158
# Headers should preserve config case in output keys
158159
expected = {"X-Tenant-ID": "mixed-case-value", "Authorization": "bearer lowercase-header"}
159160
assert result == expected
161+
162+
@pytest.mark.asyncio
163+
@patch("mcpgateway.utils.passthrough_headers.settings")
164+
async def test_set_global_passthrough_headers_default(self, mock_settings):
165+
mock_settings.enable_header_passthrough = True
166+
mock_settings.default_passthrough_headers = ["X-Tenant-Id", "X-Trace-Id"]
167+
168+
mock_db = Mock()
169+
mock_db.query.return_value.first.return_value = None # Simulate no config in DB
170+
171+
# Act
172+
await set_global_passthrough_headers(mock_db)
173+
174+
# Assert
175+
mock_db.add.assert_called_once()
176+
added_config = mock_db.add.call_args[0][0]
177+
assert added_config.passthrough_headers == ["X-Tenant-Id", "X-Trace-Id"]
178+
179+
mock_db.commit.assert_called_once()
180+
181+
182+
@pytest.mark.asyncio
183+
@patch("mcpgateway.utils.passthrough_headers.settings")
184+
async def test_set_global_passthrough_headers_invalid_config(self, mock_settings):
185+
"""Should raise PassthroughHeadersError when config is invalid."""
186+
mock_settings.enable_header_passthrough = True
187+
188+
mock_db = Mock()
189+
mock_db.query.return_value.first.return_value = None
190+
mock_db.commit.side_effect = Exception("DB write failed")
191+
192+
with pytest.raises(PassthroughHeadersError) as exc_info:
193+
await set_global_passthrough_headers(mock_db)
194+
195+
assert "DB write failed" in str(exc_info.value) or str(exc_info.value)
196+
mock_db.rollback.assert_called_once()
197+
198+
@pytest.mark.asyncio
199+
@patch("mcpgateway.utils.passthrough_headers.settings")
200+
async def test_set_global_passthrough_headers_existing_config(self, mock_settings):
201+
"""Should raise PassthroughHeadersError when config is invalid."""
202+
mock_settings.enable_header_passthrough = True
203+
204+
mock_db = Mock()
205+
mock_global_config = Mock(spec=GlobalConfig)
206+
mock_global_config.passthrough_headers = ["X-Tenant-ID", "Authorization"]
207+
mock_db.query.return_value.first.return_value = mock_global_config
208+
209+
await set_global_passthrough_headers(mock_db)
210+
211+
mock_db.add.assert_not_called()
212+
mock_db.commit.assert_not_called()
213+
214+
# Ensure existing config is not modified
215+
assert mock_global_config.passthrough_headers == ["X-Tenant-ID", "Authorization"]
216+
mock_db.rollback.assert_not_called()

0 commit comments

Comments
 (0)