diff --git a/.env.example b/.env.example index 7c2cd9f0..ed36af59 100644 --- a/.env.example +++ b/.env.example @@ -80,7 +80,7 @@ CELERY_ENABLE_UTC=true CELERY_TIMEZONE=Asia/Kolkata -# Callback Timeouts (in seconds) +# Callback Timeouts and size limit(in seconds and MB respectively) CALLBACK_CONNECT_TIMEOUT = 3 CALLBACK_READ_TIMEOUT = 10 diff --git a/backend/app/api/routes/collections.py b/backend/app/api/routes/collections.py index 78a40ae7..ed66cd04 100644 --- a/backend/app/api/routes/collections.py +++ b/backend/app/api/routes/collections.py @@ -26,7 +26,7 @@ DeletionRequest, CollectionPublic, ) -from app.utils import APIResponse, load_description +from app.utils import APIResponse, load_description, validate_callback_url from app.services.collections import ( create_collection as create_service, delete_collection as delete_service, @@ -81,6 +81,9 @@ def create_collection( current_user: CurrentUserOrgProject, request: CreationRequest, ): + if request.callback_url: + validate_callback_url(str(request.callback_url)) + collection_job_crud = CollectionJobCrud(session, current_user.project_id) collection_job = collection_job_crud.create( CollectionJobCreate( @@ -130,6 +133,9 @@ def delete_collection( collection_id: UUID = FastPath(description="Collection to delete"), request: CallbackRequest | None = Body(default=None), ): + if request and request.callback_url: + validate_callback_url(str(request.callback_url)) + _ = CollectionCrud(session, current_user.project_id).read_one(collection_id) deletion_request = DeletionRequest( diff --git a/backend/app/api/routes/documents.py b/backend/app/api/routes/documents.py index ec046aac..27671ef3 100644 --- a/backend/app/api/routes/documents.py +++ b/backend/app/api/routes/documents.py @@ -10,6 +10,7 @@ Query, UploadFile, ) +from pydantic import HttpUrl from fastapi import Path as FastPath from app.api.deps import CurrentUserOrgProject, SessionDep @@ -32,7 +33,12 @@ build_document_schema, build_document_schemas, ) -from app.utils import APIResponse, get_openai_client, load_description +from app.utils import ( + APIResponse, + get_openai_client, + load_description, + validate_callback_url, +) logger = logging.getLogger(__name__) @@ -111,6 +117,9 @@ async def upload_doc( callback_url: str | None = Form(None, description="URL to call to report doc transformation status"), ): + if callback_url: + validate_callback_url(callback_url) + source_format, actual_transformer = pre_transform_validation( src_filename=src.filename, target_format=target_format, diff --git a/backend/app/api/routes/llm.py b/backend/app/api/routes/llm.py index c443941c..e244b225 100644 --- a/backend/app/api/routes/llm.py +++ b/backend/app/api/routes/llm.py @@ -5,7 +5,7 @@ from app.api.deps import AuthContextDep, SessionDep from app.models import LLMCallRequest, LLMCallResponse, Message from app.services.llm.jobs import start_job -from app.utils import APIResponse, load_description +from app.utils import APIResponse, validate_callback_url, load_description logger = logging.getLogger(__name__) @@ -45,6 +45,9 @@ def llm_call( project_id = _current_user.project.id organization_id = _current_user.organization.id + if request.callback_url: + validate_callback_url(str(request.callback_url)) + start_job( db=_session, request=request, diff --git a/backend/app/core/config.py b/backend/app/core/config.py index 515874af..d318ce98 100644 --- a/backend/app/core/config.py +++ b/backend/app/core/config.py @@ -118,7 +118,7 @@ def AWS_S3_BUCKET(self) -> str: CELERY_ENABLE_UTC: bool = True CELERY_TIMEZONE: str = "UTC" - # callback timeouts + # callback timeouts and limits CALLBACK_CONNECT_TIMEOUT: int = 3 CALLBACK_READ_TIMEOUT: int = 10 diff --git a/backend/app/core/util.py b/backend/app/core/util.py index 41f2eb37..56a05dbc 100644 --- a/backend/app/core/util.py +++ b/backend/app/core/util.py @@ -2,8 +2,6 @@ from datetime import datetime, timezone from fastapi import HTTPException -from requests import Session, RequestException -from pydantic import BaseModel, HttpUrl from openai import OpenAI @@ -24,19 +22,6 @@ def raise_from_unknown(error: Exception, status_code=500): raise HTTPException(status_code=status_code, detail=str(error)) -def post_callback(url: HttpUrl, payload: BaseModel): - errno = 0 - with Session() as session: - response = session.post(str(url), json=payload.model_dump()) - try: - response.raise_for_status() - except RequestException as err: - logger.warning(f"Callback failure: {err}") - errno += 1 - - return not errno - - def configure_openai(credentials: dict) -> tuple[OpenAI, bool]: """ Configure OpenAI client with the provided credentials. diff --git a/backend/app/tests/utils/test_callback_ssrf.py b/backend/app/tests/utils/test_callback_ssrf.py new file mode 100644 index 00000000..3f8219ae --- /dev/null +++ b/backend/app/tests/utils/test_callback_ssrf.py @@ -0,0 +1,318 @@ +"""Tests for callback SSRF protection in utils.py""" + +import pytest +from unittest.mock import patch, MagicMock +import socket +import requests + +from app.utils import _is_private_ip, validate_callback_url, send_callback + + +class TestIsPrivateIP: + """Test suite for _is_private_ip function.""" + + def test_private_ipv4_addresses(self): + """Test that private IPv4 addresses are correctly identified.""" + private_ips = [ + "10.0.0.1", + "10.255.255.255", + "172.16.0.1", + "172.31.255.255", + "192.168.0.1", + "192.168.255.255", + ] + for ip in private_ips: + is_blocked, reason = _is_private_ip(ip) + assert is_blocked is True, f"{ip} should be identified as private" + assert reason == "private", f"{ip} should have reason 'private'" + + def test_localhost_addresses(self): + """Test that localhost/loopback addresses are blocked.""" + localhost_ips = [ + "127.0.0.1", + "127.0.0.2", + "127.255.255.255", + "::1", + ] + for ip in localhost_ips: + is_blocked, reason = _is_private_ip(ip) + assert is_blocked is True, f"{ip} should be identified as loopback" + assert ( + reason == "loopback/localhost" + ), f"{ip} should have reason 'loopback/localhost'" + + def test_link_local_addresses(self): + """Test that link-local addresses are blocked.""" + link_local_ips = [ + "169.254.0.1", + "169.254.169.254", + "169.254.255.255", + ] + for ip in link_local_ips: + is_blocked, reason = _is_private_ip(ip) + assert is_blocked is True, f"{ip} should be identified as link-local" + assert reason == "link-local", f"{ip} should have reason 'link-local'" + + def test_multicast_addresses(self): + """Test that multicast addresses are blocked.""" + multicast_ips = [ + "224.0.0.1", + "239.255.255.255", + ] + for ip in multicast_ips: + is_blocked, reason = _is_private_ip(ip) + assert is_blocked is True, f"{ip} should be identified as multicast" + assert reason == "multicast", f"{ip} should have reason 'multicast'" + + def test_public_ipv4_addresses(self): + """Test that public IPv4 addresses are not blocked.""" + public_ips = [ + "8.8.8.8", + "1.1.1.1", + "93.184.216.34", + "151.101.1.140", + ] + for ip in public_ips: + is_blocked, reason = _is_private_ip(ip) + assert is_blocked is False, f"{ip} should be identified as public" + assert reason == "", f"{ip} should have empty reason" + + def test_public_ipv6_addresses(self): + """Test that public IPv6 addresses are not blocked.""" + public_ipv6 = [ + "2001:4860:4860::8888", + "2606:4700:4700::1111", + ] + for ip in public_ipv6: + is_blocked, reason = _is_private_ip(ip) + assert is_blocked is False, f"{ip} should be identified as public" + assert reason == "", f"{ip} should have empty reason" + + def test_invalid_ip_addresses(self): + """Test that invalid IP addresses return False.""" + invalid_ips = [ + "not_an_ip", + "999.999.999.999", + "example.com", + ] + for ip in invalid_ips: + is_blocked, reason = _is_private_ip(ip) + assert is_blocked is False, f"{ip} should return False" + assert reason == "", f"{ip} should have empty reason" + + +class TestValidateCallbackURL: + """Test suite for validate_callback_url function.""" + + def test_reject_non_https_schemes(self): + """Test that non-HTTPS URL schemes are rejected.""" + non_https_urls = [ + "http://example.com/callback", + "ftp://example.com/callback", + "file:///etc/passwd", + ] + for url in non_https_urls: + with pytest.raises(ValueError, match="Only HTTPS URLs are allowed"): + validate_callback_url(url) + + @patch("socket.getaddrinfo") + def test_reject_localhost_by_name(self, mock_getaddrinfo): + """Test that localhost is rejected.""" + mock_getaddrinfo.return_value = [ + (socket.AF_INET, socket.SOCK_STREAM, 6, "", ("127.0.0.1", 443)) + ] + + with pytest.raises(ValueError, match="loopback/localhost IP address"): + validate_callback_url("https://localhost/callback") + + @patch("socket.getaddrinfo") + def test_reject_private_ip_addresses(self, mock_getaddrinfo): + """Test that private IPs in all RFC 1918 ranges are rejected.""" + private_ips = [ + ("10.0.0.1", "https://internal.company.com/callback"), + ("192.168.1.1", "https://router.local/callback"), + ("172.16.0.1", "https://internal-api.local/callback"), + ] + + for ip, url in private_ips: + mock_getaddrinfo.return_value = [ + (socket.AF_INET, socket.SOCK_STREAM, 6, "", (ip, 443)) + ] + + with pytest.raises(ValueError, match="private IP address"): + validate_callback_url(url) + + @patch("socket.getaddrinfo") + def test_reject_link_local_addresses(self, mock_getaddrinfo): + """Test that link-local addresses are rejected (including cloud metadata endpoints).""" + link_local_ips = [ + ( + "169.254.169.254", + "https://metadata.aws/callback", + ), # AWS metadata endpoint + ("169.254.0.1", "https://link-local.example/callback"), + ] + + for ip, url in link_local_ips: + mock_getaddrinfo.return_value = [ + (socket.AF_INET, socket.SOCK_STREAM, 6, "", (ip, 443)) + ] + + with pytest.raises(ValueError, match="link-local IP address"): + validate_callback_url(url) + + @patch("socket.getaddrinfo") + def test_accept_public_ip_addresses(self, mock_getaddrinfo): + """Test that valid HTTPS URLs with public IP addresses are accepted.""" + public_ips = [ + ("8.8.8.8", "https://api.example.com/callback"), + ("151.101.1.140", "https://webhook.site/unique-id"), + ] + + for ip, url in public_ips: + mock_getaddrinfo.return_value = [ + (socket.AF_INET, socket.SOCK_STREAM, 6, "", (ip, 443)) + ] + + validate_callback_url(url) + + def test_reject_url_without_hostname(self): + """Test that URLs without hostname are rejected.""" + with pytest.raises(ValueError, match="URL must have a valid hostname"): + validate_callback_url("https:///callback") + + def test_reject_invalid_url_format(self): + """Test that invalid URL formats are rejected.""" + with pytest.raises(ValueError, match="Only HTTPS URLs are allowed"): + validate_callback_url("not a url at all") + + @patch("socket.getaddrinfo") + def test_check_all_resolved_ips(self, mock_getaddrinfo): + """Test that all resolved IPs are checked (DNS round-robin).""" + mock_getaddrinfo.return_value = [ + (socket.AF_INET, socket.SOCK_STREAM, 6, "", ("8.8.8.8", 443)), + (socket.AF_INET, socket.SOCK_STREAM, 6, "", ("192.168.1.1", 443)), + ] + + with pytest.raises(ValueError, match="private IP address"): + validate_callback_url("https://malicious-dns.example/callback") + + @patch("socket.getaddrinfo") + def test_ipv6_public_address_accepted(self, mock_getaddrinfo): + """Test that public IPv6 addresses are accepted.""" + mock_getaddrinfo.return_value = [ + (socket.AF_INET6, socket.SOCK_STREAM, 6, "", ("2001:4860:4860::8888", 443)) + ] + + validate_callback_url("https://ipv6.example.com/callback") + + @patch("socket.getaddrinfo") + def test_ipv6_localhost_rejected(self, mock_getaddrinfo): + """Test that IPv6 localhost is rejected.""" + mock_getaddrinfo.return_value = [ + (socket.AF_INET6, socket.SOCK_STREAM, 6, "", ("::1", 443)) + ] + + with pytest.raises(ValueError, match="loopback/localhost IP address"): + validate_callback_url("https://localhost6/callback") + + +class TestSendCallback: + """Test suite for send_callback function.""" + + @patch("app.utils.validate_callback_url") + @patch("requests.Session") + def test_successful_callback(self, mock_session_class, mock_validate): + """Test successful callback execution.""" + mock_session = MagicMock() + mock_response = MagicMock() + mock_response.raise_for_status.return_value = None + mock_response.iter_content.return_value = [b"test"] + mock_session.post.return_value = mock_response + mock_session_class.return_value.__enter__.return_value = mock_session + + result = send_callback( + "https://api.example.com/callback", {"status": "success"} + ) + + assert result is True + mock_session.post.assert_called_once() + assert mock_session.post.call_args[1]["allow_redirects"] is False + + @patch("app.utils.validate_callback_url") + @patch("requests.Session") + def test_callback_network_error(self, mock_session_class, mock_validate): + """Test that callback returns False on network errors.""" + mock_session = MagicMock() + mock_session.post.side_effect = requests.RequestException("Connection refused") + mock_session_class.return_value.__enter__.return_value = mock_session + + result = send_callback("https://api.example.com/callback", {"data": "test"}) + + assert result is False + + @patch("app.utils.validate_callback_url") + @patch("requests.Session") + def test_callback_http_error(self, mock_session_class, mock_validate): + """Test that callback returns False on HTTP errors.""" + mock_session = MagicMock() + mock_response = MagicMock() + mock_response.raise_for_status.side_effect = requests.HTTPError("404 Not Found") + mock_session.post.return_value = mock_response + mock_session_class.return_value.__enter__.return_value = mock_session + + result = send_callback("https://api.example.com/callback", {"data": "test"}) + + assert result is False + + @patch("app.utils.validate_callback_url") + @patch("requests.Session") + def test_callback_disables_redirects(self, mock_session_class, mock_validate): + """Test that redirects are disabled to prevent redirect-based SSRF.""" + mock_session = MagicMock() + mock_response = MagicMock() + mock_response.raise_for_status.return_value = None + mock_response.iter_content.return_value = [b"test"] + mock_session.post.return_value = mock_response + mock_session_class.return_value.__enter__.return_value = mock_session + + send_callback("https://api.example.com/callback", {"data": "test"}) + + call_kwargs = mock_session.post.call_args[1] + assert call_kwargs["allow_redirects"] is False + + @patch("app.utils.validate_callback_url") + @patch("requests.Session") + def test_callback_uses_timeout(self, mock_session_class, mock_validate): + """Test that callback uses configured timeouts.""" + mock_session = MagicMock() + mock_response = MagicMock() + mock_response.raise_for_status.return_value = None + mock_response.iter_content.return_value = [b"test"] + mock_session.post.return_value = mock_response + mock_session_class.return_value.__enter__.return_value = mock_session + + send_callback("https://api.example.com/callback", {"data": "test"}) + + call_kwargs = mock_session.post.call_args[1] + assert "timeout" in call_kwargs + assert isinstance(call_kwargs["timeout"], tuple) + assert len(call_kwargs["timeout"]) == 2 + + @patch("app.utils.validate_callback_url") + @patch("requests.Session") + def test_callback_sends_json_data(self, mock_session_class, mock_validate): + """Test that callback sends data as JSON.""" + mock_session = MagicMock() + mock_response = MagicMock() + mock_response.raise_for_status.return_value = None + mock_response.iter_content.return_value = [b"test"] + mock_session.post.return_value = mock_response + mock_session_class.return_value.__enter__.return_value = mock_session + + test_data = {"status": "completed", "result": 42} + send_callback("https://api.example.com/callback", test_data) + + call_kwargs = mock_session.post.call_args[1] + assert "json" in call_kwargs + assert call_kwargs["json"] == test_data diff --git a/backend/app/utils.py b/backend/app/utils.py index 094c3682..78877d35 100644 --- a/backend/app/utils.py +++ b/backend/app/utils.py @@ -1,10 +1,13 @@ import functools as ft +import ipaddress import logging from dataclasses import dataclass from datetime import datetime, timedelta, timezone from pathlib import Path import requests +import socket from typing import Any, Dict, Generic, Optional, TypeVar +from urllib.parse import urlparse import jwt import emails @@ -262,12 +265,110 @@ def handle_openai_error(e: openai.OpenAIError) -> str: return str(e) -def send_callback(callback_url: str, data: dict): - """Send results to the callback URL (synchronously).""" +def _is_private_ip(ip: str) -> tuple[bool, str]: + """Check if an IP address is private, localhost, or reserved.""" + try: + ip_obj = ipaddress.ip_address(ip) + + checks = [ + (ip_obj.is_loopback, "loopback/localhost"), + (ip_obj.is_link_local, "link-local"), + (ip_obj.is_multicast, "multicast"), + (ip_obj.is_private, "private"), + (ip_obj.is_reserved, "reserved"), + ] + + for is_blocked, reason in checks: + if is_blocked: + return (True, reason) + + return (False, "") + + except ValueError: + return (False, "") + + +def validate_callback_url(url: str) -> None: + """ + Validate callback URL to prevent SSRF attacks. + + Blocks: + - Non-HTTPS URLs + - Private IP addresses (10.0.0.0/8, 172.16.0.0/12, 192.168.0.0/16) + - Localhost/loopback addresses (127.0.0.0/8, ::1) + - Link-local addresses (169.254.0.0/16) + - Cloud metadata endpoints (169.254.169.254) + - Reserved IP ranges + + Args: + url: The callback URL to validate + + Raises: + ValueError: If URL is not allowed + """ + try: + parsed = urlparse(url) + + if parsed.scheme != "https": + raise ValueError( + f"Only HTTPS URLs are allowed for callbacks. Got: {parsed.scheme}" + ) + + if not parsed.hostname: + raise ValueError("URL must have a valid hostname") + + addr_info = socket.getaddrinfo( + parsed.hostname, + parsed.port or 443, + socket.AF_UNSPEC, + socket.SOCK_STREAM, + ) + + for info in addr_info: + ip_address = info[4][0] + is_blocked, reason = _is_private_ip(ip_address) + if is_blocked: + raise ValueError( + f"Callback URL resolves to {reason} IP address: {ip_address}. " + f"This IP type is not allowed for callbacks." + ) + + except ValueError: + raise + except Exception as e: + raise ValueError(f"Error validating callback URL: {str(e)}") from e + + +def send_callback(callback_url: str, data: dict[str, Any]) -> bool: + """ + Send results to the callback URL (synchronously) with SSRF protection. + + Security features: + - HTTPS-only enforcement + - Private IP blocking (RFC 1918) + - Localhost/loopback blocking + - Cloud metadata endpoint blocking + - DNS rebinding protection + - Redirect following disabled + - Strict timeouts + + Args: + callback_url: The HTTPS URL to send the callback to + data: The JSON data to send in the POST request + + Returns: + bool: True if callback succeeded, False otherwise + """ + try: + validate_callback_url(str(callback_url)) + except ValueError as ve: + logger.error(f"[send_callback] Invalid callback URL: {ve}", exc_info=True) + return False + try: with requests.Session() as session: - # uncomment this to run locally without SSL - # session.verify = False + session.trust_env = False # Ignores environment proxies and other implicit settings for SSRF safety + response = session.post( callback_url, json=data, @@ -275,10 +376,14 @@ def send_callback(callback_url: str, data: dict): settings.CALLBACK_CONNECT_TIMEOUT, settings.CALLBACK_READ_TIMEOUT, ), + allow_redirects=False, ) + response.raise_for_status() - logger.info(f"[send_callback] Callback sent successfully to {callback_url}") + + logger.info("[send_callback] Callback sent successfully") return True + except requests.RequestException as e: logger.error(f"[send_callback] Callback failed: {str(e)}", exc_info=True) return False