From d94b4b09d306c472991c1f39d258c5a303e44efc Mon Sep 17 00:00:00 2001 From: pgjones Date: Tue, 26 Dec 2023 10:25:32 +0000 Subject: [PATCH 01/21] Ensure the idle task is stopped on error Otherwise the task will persist and attempt to close an already closed connection. --- src/hypercorn/asyncio/tcp_server.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/hypercorn/asyncio/tcp_server.py b/src/hypercorn/asyncio/tcp_server.py index 025ec0a0..e90858b4 100644 --- a/src/hypercorn/asyncio/tcp_server.py +++ b/src/hypercorn/asyncio/tcp_server.py @@ -117,8 +117,8 @@ async def _close(self) -> None: await self.writer.wait_closed() except (BrokenPipeError, ConnectionAbortedError, ConnectionResetError, RuntimeError): pass # Already closed - - await self._stop_idle() + finally: + await self._stop_idle() async def _initiate_server_close(self) -> None: await self.protocol.handle(Closed()) From 0e4117da672c4d9ec09a7802a2c641539a03042c Mon Sep 17 00:00:00 2001 From: pgjones Date: Tue, 26 Dec 2023 10:26:24 +0000 Subject: [PATCH 02/21] Fix latest mypy issues --- src/hypercorn/asyncio/run.py | 2 +- src/hypercorn/utils.py | 2 +- tests/asyncio/test_sanity.py | 2 +- tests/middleware/test_dispatcher.py | 6 +++--- tests/trio/test_sanity.py | 2 +- 5 files changed, 7 insertions(+), 7 deletions(-) diff --git a/src/hypercorn/asyncio/run.py b/src/hypercorn/asyncio/run.py index 47745383..7c0982d9 100644 --- a/src/hypercorn/asyncio/run.py +++ b/src/hypercorn/asyncio/run.py @@ -65,7 +65,7 @@ def _signal_handler(*_: Any) -> None: # noqa: N803 # Add signal handler may not be implemented on Windows signal.signal(getattr(signal, signal_name), _signal_handler) - shutdown_trigger = signal_event.wait # type: ignore + shutdown_trigger = signal_event.wait lifespan = Lifespan(app, config, loop) diff --git a/src/hypercorn/utils.py b/src/hypercorn/utils.py index 5629ff71..9e3520d7 100644 --- a/src/hypercorn/utils.py +++ b/src/hypercorn/utils.py @@ -185,7 +185,7 @@ def write_pid_file(pid_path: str) -> None: def parse_socket_addr(family: int, address: tuple) -> Optional[Tuple[str, int]]: if family == socket.AF_INET: - return address # type: ignore + return address elif family == socket.AF_INET6: return (address[0], address[1]) else: diff --git a/tests/asyncio/test_sanity.py b/tests/asyncio/test_sanity.py index 287cd06d..2d7cb0bc 100644 --- a/tests/asyncio/test_sanity.py +++ b/tests/asyncio/test_sanity.py @@ -66,7 +66,7 @@ async def test_http1_request(event_loop: asyncio.AbstractEventLoop) -> None: reason=b"", ), h11.Data(data=b"Hello & Goodbye"), - h11.EndOfMessage(headers=[]), # type: ignore + h11.EndOfMessage(headers=[]), ] server.reader.close() # type: ignore await task diff --git a/tests/middleware/test_dispatcher.py b/tests/middleware/test_dispatcher.py index 1c3d7a28..dbb3f43e 100644 --- a/tests/middleware/test_dispatcher.py +++ b/tests/middleware/test_dispatcher.py @@ -36,9 +36,9 @@ async def send(message: dict) -> None: nonlocal sent_events sent_events.append(message) - await app({**http_scope, **{"path": "/api/x/b"}}, None, send) - await app({**http_scope, **{"path": "/api/b"}}, None, send) - await app({**http_scope, **{"path": "/"}}, None, send) + await app({**http_scope, **{"path": "/api/x/b"}}, None, send) # type: ignore + await app({**http_scope, **{"path": "/api/b"}}, None, send) # type: ignore + await app({**http_scope, **{"path": "/"}}, None, send) # type: ignore assert sent_events == [ {"type": "http.response.start", "status": 200, "headers": [(b"content-length", b"7")]}, {"type": "http.response.body", "body": b"apix-/b"}, diff --git a/tests/trio/test_sanity.py b/tests/trio/test_sanity.py index 3828e37b..6d4be8c5 100644 --- a/tests/trio/test_sanity.py +++ b/tests/trio/test_sanity.py @@ -68,7 +68,7 @@ async def test_http1_request(nursery: trio._core._run.Nursery) -> None: reason=b"", ), h11.Data(data=b"Hello & Goodbye"), - h11.EndOfMessage(headers=[]), # type: ignore + h11.EndOfMessage(headers=[]), ] From 926c4303a7298ce53a772cf6cec9a3da75be35a2 Mon Sep 17 00:00:00 2001 From: pgjones Date: Tue, 26 Dec 2023 18:31:00 +0000 Subject: [PATCH 03/21] Add a max keep alive requests configuration option This will cause HTTP/1 and HTTP/2 requests to close when the limit has been reached. This matches nginx's mitigation against the rapid reset HTTP/2 attack. --- docs/discussion/dos_mitigations.rst | 11 +++++++++++ docs/how_to_guides/configuring.rst | 2 ++ src/hypercorn/config.py | 1 + src/hypercorn/protocol/h11.py | 3 +++ src/hypercorn/protocol/h2.py | 6 ++++++ tests/protocol/test_h11.py | 12 ++++++++++++ tests/protocol/test_h2.py | 23 +++++++++++++++++++++++ 7 files changed, 58 insertions(+) diff --git a/docs/discussion/dos_mitigations.rst b/docs/discussion/dos_mitigations.rst index 358ba987..88cc48bb 100644 --- a/docs/discussion/dos_mitigations.rst +++ b/docs/discussion/dos_mitigations.rst @@ -169,3 +169,14 @@ data that it cannot send to the client. To mitigate this Hypercorn responds to the backpressure and pauses (blocks) the coroutine writing the response. + +Rapid reset +^^^^^^^^^^^ + +This attack works by opening and closing streams in quick succession +in the expectation that this is more costly for the server than the +client. + +To mitigate Hypercorn will only allow a maximum number of requests per +kept-alive connection before closing it. This ensures that cost of the +attack is equally born by the client. diff --git a/docs/how_to_guides/configuring.rst b/docs/how_to_guides/configuring.rst index 26607ba7..d72e4ed7 100644 --- a/docs/how_to_guides/configuring.rst +++ b/docs/how_to_guides/configuring.rst @@ -124,6 +124,8 @@ insecure_bind ``--insecure-bind`` The TCP host/address to See *bind* for formatting options. Care must be taken! See HTTP -> HTTPS redirection docs. +keep_alive_max_requests N/A Maximum number of requests before connection 1000 + is closed. HTTP/1 & HTTP/2 only. keep_alive_timeout ``--keep-alive`` Seconds to keep inactive connections alive 5s before closing. keyfile ``--keyfile`` Path to the SSL key file. diff --git a/src/hypercorn/config.py b/src/hypercorn/config.py index 26f50f00..fdc7a413 100644 --- a/src/hypercorn/config.py +++ b/src/hypercorn/config.py @@ -84,6 +84,7 @@ class Config: include_date_header = True include_server_header = True keep_alive_timeout = 5 * SECONDS + keep_alive_max_requests = 1000 keyfile: Optional[str] = None keyfile_password: Optional[str] = None logconfig: Optional[str] = None diff --git a/src/hypercorn/protocol/h11.py b/src/hypercorn/protocol/h11.py index 49f8b179..1cbf8778 100755 --- a/src/hypercorn/protocol/h11.py +++ b/src/hypercorn/protocol/h11.py @@ -97,6 +97,7 @@ def __init__( h11.SERVER, max_incomplete_event_size=self.config.h11_max_incomplete_size ) self.context = context + self.keep_alive_requests = 0 self.send = send self.server = server self.ssl = ssl @@ -234,6 +235,7 @@ async def _create_stream(self, request: h11.Request) -> None: raw_path=request.target, ) ) + self.keep_alive_requests += 1 async def _send_h11_event(self, event: H11SendableEvent) -> None: try: @@ -264,6 +266,7 @@ async def _maybe_recycle(self) -> None: not self.context.terminated.is_set() and self.connection.our_state is h11.DONE and self.connection.their_state is h11.DONE + and self.keep_alive_requests <= self.config.keep_alive_max_requests ): try: self.connection.start_next_cycle() diff --git a/src/hypercorn/protocol/h2.py b/src/hypercorn/protocol/h2.py index 6e76d493..776902e6 100755 --- a/src/hypercorn/protocol/h2.py +++ b/src/hypercorn/protocol/h2.py @@ -109,6 +109,7 @@ def __init__( }, ) + self.keep_alive_requests = 0 self.send = send self.server = server self.ssl = ssl @@ -244,6 +245,9 @@ async def _handle_events(self, events: List[h2.events.Event]) -> None: else: await self._create_stream(event) await self.send(Updated(idle=False)) + + if self.keep_alive_requests > self.config.keep_alive_max_requests: + self.connection.close_connection() elif isinstance(event, h2.events.DataReceived): await self.streams[event.stream_id].handle( Body(stream_id=event.stream_id, data=event.data) @@ -349,6 +353,7 @@ async def _create_stream(self, request: h2.events.RequestReceived) -> None: raw_path=raw_path, ) ) + self.keep_alive_requests += 1 async def _create_server_push( self, stream_id: int, path: bytes, headers: List[Tuple[bytes, bytes]] @@ -374,6 +379,7 @@ async def _create_server_push( event.headers = request_headers await self._create_stream(event) await self.streams[event.stream_id].handle(EndBody(stream_id=event.stream_id)) + self.keep_alive_max_requests += 1 async def _close_stream(self, stream_id: int) -> None: if stream_id in self.streams: diff --git a/tests/protocol/test_h11.py b/tests/protocol/test_h11.py index 86bb00a4..a136fde2 100755 --- a/tests/protocol/test_h11.py +++ b/tests/protocol/test_h11.py @@ -103,6 +103,18 @@ async def test_protocol_send_body(protocol: H11Protocol) -> None: ] +@pytest.mark.asyncio +async def test_protocol_keep_alive_max_requests(protocol: H11Protocol) -> None: + data = b"GET / HTTP/1.1\r\nHost: hypercorn\r\n\r\n" + protocol.config.keep_alive_max_requests = 0 + await protocol.handle(RawData(data=data)) + await protocol.stream_send(Response(stream_id=1, status_code=200, headers=[])) + await protocol.stream_send(EndBody(stream_id=1)) + await protocol.stream_send(StreamClosed(stream_id=1)) + protocol.send.assert_called() # type: ignore + assert protocol.send.call_args_list[3] == call(Closed()) # type: ignore + + @pytest.mark.asyncio @pytest.mark.parametrize("keep_alive, expected", [(True, Updated(idle=True)), (False, Closed())]) async def test_protocol_send_stream_closed( diff --git a/tests/protocol/test_h2.py b/tests/protocol/test_h2.py index c44f39ae..77bcf515 100644 --- a/tests/protocol/test_h2.py +++ b/tests/protocol/test_h2.py @@ -4,6 +4,8 @@ from unittest.mock import call, Mock import pytest +from h2.connection import H2Connection +from h2.events import ConnectionTerminated from hypercorn.asyncio.worker_context import EventWrapper, WorkerContext from hypercorn.config import Config @@ -78,3 +80,24 @@ async def test_protocol_handle_protocol_error() -> None: await protocol.handle(RawData(data=b"broken nonsense\r\n\r\n")) protocol.send.assert_awaited() # type: ignore assert protocol.send.call_args_list == [call(Closed())] # type: ignore + + +@pytest.mark.asyncio +async def test_protocol_keep_alive_max_requests() -> None: + protocol = H2Protocol( + Mock(), Config(), WorkerContext(), AsyncMock(), False, None, None, AsyncMock() + ) + protocol.config.keep_alive_max_requests = 0 + client = H2Connection() + client.initiate_connection() + headers = [ + (":method", "GET"), + (":path", "/reqinfo"), + (":authority", "hypercorn"), + (":scheme", "https"), + ] + client.send_headers(1, headers, end_stream=True) + await protocol.handle(RawData(data=client.data_to_send())) + protocol.send.assert_awaited() # type: ignore + events = client.receive_data(protocol.send.call_args_list[1].args[0].data) # type: ignore + assert isinstance(events[-1], ConnectionTerminated) From 5a77873ecf0693bdd4fba0baab225864c8d2ae87 Mon Sep 17 00:00:00 2001 From: pgjones Date: Wed, 27 Dec 2023 10:51:26 +0000 Subject: [PATCH 04/21] Correct 926c4303a7298ce53a772cf6cec9a3da75be35a2 --- src/hypercorn/protocol/h2.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/hypercorn/protocol/h2.py b/src/hypercorn/protocol/h2.py index 776902e6..26048780 100755 --- a/src/hypercorn/protocol/h2.py +++ b/src/hypercorn/protocol/h2.py @@ -379,7 +379,7 @@ async def _create_server_push( event.headers = request_headers await self._create_stream(event) await self.streams[event.stream_id].handle(EndBody(stream_id=event.stream_id)) - self.keep_alive_max_requests += 1 + self.keep_alive_requests += 1 async def _close_stream(self, stream_id: int) -> None: if stream_id in self.streams: From ebb09a6c606c2a9c4e6e3a2d4c7a27262cdf6573 Mon Sep 17 00:00:00 2001 From: pgjones Date: Wed, 27 Dec 2023 10:36:15 +0000 Subject: [PATCH 05/21] Revert "fix: Autoreload error because reausing old sockets" This reverts commit 4854ffd89e8661213ff20828b7568a9f004803a9. It doesn't fix the issue and creates additional reported issues. --- src/hypercorn/run.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/src/hypercorn/run.py b/src/hypercorn/run.py index a6d2fb08..563822fb 100644 --- a/src/hypercorn/run.py +++ b/src/hypercorn/run.py @@ -72,9 +72,6 @@ def shutdown(*args: Any) -> None: if config.use_reloader: wait_for_changes(shutdown_event) shutdown_event.set() - # Recreate the sockets to be used again in the next - # iteration of the loop. - sockets = config.create_sockets() else: active = False From 2b0aad3b1fde7362d785eb489f36e152d1deec16 Mon Sep 17 00:00:00 2001 From: Thomas Baker Date: Thu, 16 Nov 2023 16:26:25 -0500 Subject: [PATCH 06/21] Send the hinted error from h11 on RemoteProtocolErrors This properly punches through 431 status codes --- src/hypercorn/protocol/h11.py | 4 ++-- tests/protocol/test_h11.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/hypercorn/protocol/h11.py b/src/hypercorn/protocol/h11.py index 1cbf8778..ec04593a 100755 --- a/src/hypercorn/protocol/h11.py +++ b/src/hypercorn/protocol/h11.py @@ -155,9 +155,9 @@ async def _handle_events(self) -> None: try: event = self.connection.next_event() - except h11.RemoteProtocolError: + except h11.RemoteProtocolError as error: if self.connection.our_state in {h11.IDLE, h11.SEND_RESPONSE}: - await self._send_error_response(400) + await self._send_error_response(error.error_status_hint) await self.send(Closed()) break else: diff --git a/tests/protocol/test_h11.py b/tests/protocol/test_h11.py index a136fde2..27c80b56 100755 --- a/tests/protocol/test_h11.py +++ b/tests/protocol/test_h11.py @@ -314,7 +314,7 @@ async def test_protocol_handle_max_incomplete(monkeypatch: MonkeyPatch) -> None: assert protocol.send.call_args_list == [ # type: ignore call( RawData( - data=b"HTTP/1.1 400 \r\ncontent-length: 0\r\nconnection: close\r\n" + data=b"HTTP/1.1 431 \r\ncontent-length: 0\r\nconnection: close\r\n" b"date: Thu, 01 Jan 1970 01:23:20 GMT\r\nserver: hypercorn-h11\r\n\r\n" ) ), From 1f874fc2076541feeacff78f472fdddb01ccc0a7 Mon Sep 17 00:00:00 2001 From: stopdropandrew Date: Mon, 11 Dec 2023 11:43:39 -0800 Subject: [PATCH 07/21] Handle `asyncio.CancelledError` when socket is closed without flushing --- src/hypercorn/asyncio/tcp_server.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/src/hypercorn/asyncio/tcp_server.py b/src/hypercorn/asyncio/tcp_server.py index e90858b4..ed9d710f 100644 --- a/src/hypercorn/asyncio/tcp_server.py +++ b/src/hypercorn/asyncio/tcp_server.py @@ -115,7 +115,13 @@ async def _close(self) -> None: try: self.writer.close() await self.writer.wait_closed() - except (BrokenPipeError, ConnectionAbortedError, ConnectionResetError, RuntimeError): + except ( + BrokenPipeError, + ConnectionAbortedError, + ConnectionResetError, + RuntimeError, + asyncio.CancelledError, + ): pass # Already closed finally: await self._stop_idle() From 80fa1940089ad00ab5cc7dea2337c6d4aeec8b33 Mon Sep 17 00:00:00 2001 From: seidnerj Date: Mon, 25 Dec 2023 14:36:54 +0200 Subject: [PATCH 08/21] update .gitignore --- .gitignore | 2 ++ 1 file changed, 2 insertions(+) diff --git a/.gitignore b/.gitignore index 8ef153fe..c436bdd0 100644 --- a/.gitignore +++ b/.gitignore @@ -13,3 +13,5 @@ docs/reference/source/ dist/ .coverage poetry.lock +.idea/ +.DS_Store From cb443a4a4e0f4ff200cf94b83e52161413ea4501 Mon Sep 17 00:00:00 2001 From: seidnerj Date: Tue, 26 Dec 2023 00:30:27 +0200 Subject: [PATCH 09/21] if any of our subprocesses exits with a non-zero exit code, we should also exit with a non-zero exit code. --- src/hypercorn/__main__.py | 6 +++--- src/hypercorn/run.py | 10 +++++++++- 2 files changed, 12 insertions(+), 4 deletions(-) diff --git a/src/hypercorn/__main__.py b/src/hypercorn/__main__.py index b3dc0e80..769a5e09 100644 --- a/src/hypercorn/__main__.py +++ b/src/hypercorn/__main__.py @@ -23,7 +23,7 @@ def _load_config(config_path: Optional[str]) -> Config: return Config.from_toml(config_path) -def main(sys_args: Optional[List[str]] = None) -> None: +def main(sys_args: Optional[List[str]] = None) -> int: parser = argparse.ArgumentParser() parser.add_argument( "application", help="The application to dispatch to as path.to.module:instance.path" @@ -284,8 +284,8 @@ def _convert_verify_mode(value: str) -> ssl.VerifyMode: if len(args.server_names) > 0: config.server_names = args.server_names - run(config) + return run(config) if __name__ == "__main__": - main() + sys.exit(main()) diff --git a/src/hypercorn/run.py b/src/hypercorn/run.py index 563822fb..05ab2391 100644 --- a/src/hypercorn/run.py +++ b/src/hypercorn/run.py @@ -15,7 +15,9 @@ from .utils import load_application, wait_for_changes, write_pid_file -def run(config: Config) -> None: +def run(config: Config) -> int: + exit_code = 0 + if config.pid_path is not None: write_pid_file(config.pid_path) @@ -77,14 +79,20 @@ def shutdown(*args: Any) -> None: for process in processes: process.join() + if process.exitcode != 0: + exit_code = process.exitcode + for process in processes: process.terminate() for sock in sockets.secure_sockets: sock.close() + for sock in sockets.insecure_sockets: sock.close() + return exit_code + def start_processes( config: Config, From 2d2c62bac7b83a8c6766fe3a517f63ff842e5c38 Mon Sep 17 00:00:00 2001 From: pgjones Date: Wed, 27 Dec 2023 17:13:34 +0000 Subject: [PATCH 10/21] Improve WSGI compliance The response body is closed if it has a close method as per PEP 3333. In addition the response headers are only sent when the first response body byte is available to send. Finally, an error is raised if start_response has not been called by the app. --- src/hypercorn/app_wrappers.py | 23 ++++++++-- tests/test_app_wrappers.py | 79 ++++++++++++++++++++++------------- 2 files changed, 69 insertions(+), 33 deletions(-) diff --git a/src/hypercorn/app_wrappers.py b/src/hypercorn/app_wrappers.py index 769e014b..cfc41cfd 100644 --- a/src/hypercorn/app_wrappers.py +++ b/src/hypercorn/app_wrappers.py @@ -84,6 +84,8 @@ async def handle_http( def run_app(self, environ: dict, send: Callable) -> None: headers: List[Tuple[bytes, bytes]] + headers_sent = False + response_started = False status_code: Optional[int] = None def start_response( @@ -91,7 +93,7 @@ def start_response( response_headers: List[Tuple[str, str]], exc_info: Optional[Exception] = None, ) -> None: - nonlocal headers, status_code + nonlocal headers, response_started, status_code raw, _ = status.split(" ", 1) status_code = int(raw) @@ -99,10 +101,23 @@ def start_response( (name.lower().encode("ascii"), value.encode("ascii")) for name, value in response_headers ] - send({"type": "http.response.start", "status": status_code, "headers": headers}) + response_started = True - for output in self.app(environ, start_response): - send({"type": "http.response.body", "body": output, "more_body": True}) + response_body = self.app(environ, start_response) + + if not response_started: + raise RuntimeError("WSGI app did not call start_response") + + try: + for output in response_body: + if not headers_sent: + send({"type": "http.response.start", "status": status_code, "headers": headers}) + headers_sent = True + + send({"type": "http.response.body", "body": output, "more_body": True}) + finally: + if hasattr(response_body, "close"): + response_body.close() def _build_environ(scope: HTTPScope, body: bytes) -> dict: diff --git a/tests/test_app_wrappers.py b/tests/test_app_wrappers.py index bb7b5897..c68ba0cb 100644 --- a/tests/test_app_wrappers.py +++ b/tests/test_app_wrappers.py @@ -61,8 +61,28 @@ async def _send(message: ASGISendEvent) -> None: ] +async def _run_app(app: WSGIWrapper, scope: HTTPScope, body: bytes = b"") -> List[ASGISendEvent]: + queue: asyncio.Queue = asyncio.Queue() + await queue.put({"type": "http.request", "body": body}) + + messages = [] + + async def _send(message: ASGISendEvent) -> None: + nonlocal messages + messages.append(message) + + event_loop = asyncio.get_running_loop() + + def _call_soon(func: Callable, *args: Any) -> Any: + future = asyncio.run_coroutine_threadsafe(func(*args), event_loop) + return future.result() + + await app(scope, queue.get, _send, partial(event_loop.run_in_executor, None), _call_soon) + return messages + + @pytest.mark.asyncio -async def test_wsgi_asyncio(event_loop: asyncio.AbstractEventLoop) -> None: +async def test_wsgi_asyncio() -> None: app = WSGIWrapper(echo_body, 2**16) scope: HTTPScope = { "http_version": "1.1", @@ -79,20 +99,7 @@ async def test_wsgi_asyncio(event_loop: asyncio.AbstractEventLoop) -> None: "server": None, "extensions": {}, } - queue: asyncio.Queue = asyncio.Queue() - await queue.put({"type": "http.request"}) - - messages = [] - - async def _send(message: ASGISendEvent) -> None: - nonlocal messages - messages.append(message) - - def _call_soon(func: Callable, *args: Any) -> Any: - future = asyncio.run_coroutine_threadsafe(func(*args), event_loop) - return future.result() - - await app(scope, queue.get, _send, partial(event_loop.run_in_executor, None), _call_soon) + messages = await _run_app(app, scope) assert messages == [ { "headers": [(b"content-type", b"text/plain; charset=utf-8"), (b"content-length", b"0")], @@ -105,7 +112,7 @@ def _call_soon(func: Callable, *args: Any) -> Any: @pytest.mark.asyncio -async def test_max_body_size(event_loop: asyncio.AbstractEventLoop) -> None: +async def test_max_body_size() -> None: app = WSGIWrapper(echo_body, 4) scope: HTTPScope = { "http_version": "1.1", @@ -122,25 +129,39 @@ async def test_max_body_size(event_loop: asyncio.AbstractEventLoop) -> None: "server": None, "extensions": {}, } - queue: asyncio.Queue = asyncio.Queue() - await queue.put({"type": "http.request", "body": b"abcde"}) - messages = [] - - async def _send(message: ASGISendEvent) -> None: - nonlocal messages - messages.append(message) - - def _call_soon(func: Callable, *args: Any) -> Any: - future = asyncio.run_coroutine_threadsafe(func(*args), event_loop) - return future.result() - - await app(scope, queue.get, _send, partial(event_loop.run_in_executor, None), _call_soon) + messages = await _run_app(app, scope, b"abcde") assert messages == [ {"headers": [], "status": 400, "type": "http.response.start"}, {"body": bytearray(b""), "type": "http.response.body", "more_body": False}, ] +def no_start_response(environ: dict, start_response: Callable) -> List[bytes]: + return [b"result"] + + +@pytest.mark.asyncio +async def test_no_start_response() -> None: + app = WSGIWrapper(no_start_response, 2**16) + scope: HTTPScope = { + "http_version": "1.1", + "asgi": {}, + "method": "GET", + "headers": [], + "path": "/", + "root_path": "/", + "query_string": b"a=b", + "raw_path": b"/", + "scheme": "http", + "type": "http", + "client": ("localhost", 80), + "server": None, + "extensions": {}, + } + with pytest.raises(RuntimeError): + await _run_app(app, scope) + + def test_build_environ_encoding() -> None: scope: HTTPScope = { "http_version": "1.0", From 4fc0372483210257d28d9e0b5f7746df145449c6 Mon Sep 17 00:00:00 2001 From: pgjones Date: Thu, 28 Dec 2023 11:48:17 +0000 Subject: [PATCH 11/21] Add ProxyFix middleware This allows for Hypercorn to be used behind a proxy with the headers being "fixed" such that the proxy is not present as far as the app is concerned. This makes it easier to write applications that run behind proxies. Note I've defaulted to legacy mode as AWS's load balancers don't support the modern Forwarded header and I assume that makes up a large percentage of real world usage. --- docs/how_to_guides/index.rst | 1 + docs/how_to_guides/proxy_fix.rst | 33 ++++++++++++ src/hypercorn/middleware/__init__.py | 2 + src/hypercorn/middleware/proxy_fix.py | 78 +++++++++++++++++++++++++++ tests/middleware/test_proxy_fix.py | 64 ++++++++++++++++++++++ 5 files changed, 178 insertions(+) create mode 100644 docs/how_to_guides/proxy_fix.rst create mode 100644 src/hypercorn/middleware/proxy_fix.py create mode 100644 tests/middleware/test_proxy_fix.py diff --git a/docs/how_to_guides/index.rst b/docs/how_to_guides/index.rst index bccdd541..9f4bf2d7 100644 --- a/docs/how_to_guides/index.rst +++ b/docs/how_to_guides/index.rst @@ -11,6 +11,7 @@ How to guides dispatch_apps.rst http_https_redirect.rst logging.rst + proxy_fix.rst server_names.rst statsd.rst wsgi_apps.rst diff --git a/docs/how_to_guides/proxy_fix.rst b/docs/how_to_guides/proxy_fix.rst new file mode 100644 index 00000000..dd8d080f --- /dev/null +++ b/docs/how_to_guides/proxy_fix.rst @@ -0,0 +1,33 @@ +Fixing proxy headers +==================== + +If you are serving Hypercorn behind a proxy e.g. a load balancer the +client-address, scheme, and host-header will match that of the +connection between the proxy and Hypercorn rather than the user-agent +(client). However, most proxies provide headers with the original +user-agent (client) values which can be used to "fix" the headers to +these values. + +Modern proxies should provide this information via a ``Forwarded`` +header from `RFC 7239 +`_. However, this is +rare in practice with legacy proxies using a combination of +``X-Forwarded-For``, ``X-Forwarded-Proto`` and +``X-Forwarded-Host``. It is important that you chose the correct mode +(legacy, or modern) based on the proxy you use. + +To use the proxy fix middleware behind a single legacy proxy simply +wrap your app and serve the wrapped app, + +.. code-block:: python + + from hypercorn.middleware import ProxyFixMiddleware + + fixed_app = ProxyFixMiddleware(app, mode="legacy", trusted_hops=1) + +.. warning:: + + The mode and number of trusted hops must match your setup or the + user-agent (client) may be trusted and hence able to set + alternative for, proto, and host values. This can, depending on + your usage in the app, lead to security vulnerabilities. diff --git a/src/hypercorn/middleware/__init__.py b/src/hypercorn/middleware/__init__.py index 83ea29c7..e7f017c1 100644 --- a/src/hypercorn/middleware/__init__.py +++ b/src/hypercorn/middleware/__init__.py @@ -2,11 +2,13 @@ from .dispatcher import DispatcherMiddleware from .http_to_https import HTTPToHTTPSRedirectMiddleware +from .proxy_fix import ProxyFixMiddleware from .wsgi import AsyncioWSGIMiddleware, TrioWSGIMiddleware __all__ = ( "AsyncioWSGIMiddleware", "DispatcherMiddleware", "HTTPToHTTPSRedirectMiddleware", + "ProxyFixMiddleware", "TrioWSGIMiddleware", ) diff --git a/src/hypercorn/middleware/proxy_fix.py b/src/hypercorn/middleware/proxy_fix.py new file mode 100644 index 00000000..509941cf --- /dev/null +++ b/src/hypercorn/middleware/proxy_fix.py @@ -0,0 +1,78 @@ +from __future__ import annotations + +from copy import deepcopy +from typing import Callable, Iterable, Literal, Optional, Tuple + +from ..typing import ASGIFramework, Scope + + +class ProxyFixMiddleware: + def __init__( + self, + app: ASGIFramework, + mode: Literal["legacy", "modern"] = "legacy", + trusted_hops: int = 1, + ) -> None: + self.app = app + self.mode = mode + self.trusted_hops = trusted_hops + + async def __call__(self, scope: Scope, receive: Callable, send: Callable) -> None: + if scope["type"] in {"http", "websocket"}: + scope = deepcopy(scope) + headers = scope["headers"] # type: ignore + client: Optional[str] = None + scheme: Optional[str] = None + host: Optional[str] = None + + if ( + self.mode == "modern" + and (value := _get_trusted_value(b"forwarded", headers, self.trusted_hops)) + is not None + ): + for part in value.split(";"): + if part.startswith("for="): + client = part[4:].strip() + elif part.startswith("host="): + host = part[5:].strip() + elif part.startswith("proto="): + scheme = part[6:].strip() + + else: + client = _get_trusted_value(b"x-forwarded-for", headers, self.trusted_hops) + scheme = _get_trusted_value(b"x-forwarded-proto", headers, self.trusted_hops) + host = _get_trusted_value(b"x-forwarded-host", headers, self.trusted_hops) + + if client is not None: + scope["client"] = (client, 0) # type: ignore + + if scheme is not None: + scope["scheme"] = scheme # type: ignore + + if host is not None: + headers = [ + (name, header_value) + for name, header_value in headers + if name.lower() != b"host" + ] + headers.append((b"host", host)) + scope["headers"] = headers # type: ignore + + await self.app(scope, receive, send) + + +def _get_trusted_value( + name: bytes, headers: Iterable[Tuple[bytes, bytes]], trusted_hops: int +) -> Optional[str]: + if trusted_hops == 0: + return None + + values = [] + for header_name, header_value in headers: + if header_name.lower() == name: + values.extend([value.decode("latin1").strip() for value in header_value.split(b",")]) + + if len(values) >= trusted_hops: + return values[-trusted_hops] + + return None diff --git a/tests/middleware/test_proxy_fix.py b/tests/middleware/test_proxy_fix.py new file mode 100644 index 00000000..2a43b589 --- /dev/null +++ b/tests/middleware/test_proxy_fix.py @@ -0,0 +1,64 @@ +from __future__ import annotations + +from unittest.mock import AsyncMock + +import pytest + +from hypercorn.middleware import ProxyFixMiddleware +from hypercorn.typing import HTTPScope + + +@pytest.mark.asyncio +async def test_proxy_fix_legacy() -> None: + mock = AsyncMock() + app = ProxyFixMiddleware(mock) + scope: HTTPScope = { + "type": "http", + "asgi": {}, + "http_version": "2", + "method": "GET", + "scheme": "http", + "path": "/", + "raw_path": b"/", + "query_string": b"", + "root_path": "", + "headers": [ + (b"x-forwarded-for", b"127.0.0.1"), + (b"x-forwarded-for", b"127.0.0.2"), + (b"x-forwarded-proto", b"http,https"), + ], + "client": ("127.0.0.3", 80), + "server": None, + "extensions": {}, + } + await app(scope, None, None) + mock.assert_called() + assert mock.call_args[0][0]["client"] == ("127.0.0.2", 0) + assert mock.call_args[0][0]["scheme"] == "https" + + +@pytest.mark.asyncio +async def test_proxy_fix_modern() -> None: + mock = AsyncMock() + app = ProxyFixMiddleware(mock, mode="modern") + scope: HTTPScope = { + "type": "http", + "asgi": {}, + "http_version": "2", + "method": "GET", + "scheme": "http", + "path": "/", + "raw_path": b"/", + "query_string": b"", + "root_path": "", + "headers": [ + (b"forwarded", b"for=127.0.0.1;proto=http,for=127.0.0.2;proto=https"), + ], + "client": ("127.0.0.3", 80), + "server": None, + "extensions": {}, + } + await app(scope, None, None) + mock.assert_called() + assert mock.call_args[0][0]["client"] == ("127.0.0.2", 0) + assert mock.call_args[0][0]["scheme"] == "https" From 125bb002903c0e2d60ab6cb36d00dba3cfad6d03 Mon Sep 17 00:00:00 2001 From: pgjones Date: Thu, 28 Dec 2023 12:00:54 +0000 Subject: [PATCH 12/21] Switch wsgi.errors to stdout This matches other examples and the WSGI specification. --- src/hypercorn/app_wrappers.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/hypercorn/app_wrappers.py b/src/hypercorn/app_wrappers.py index cfc41cfd..633abb1f 100644 --- a/src/hypercorn/app_wrappers.py +++ b/src/hypercorn/app_wrappers.py @@ -1,5 +1,6 @@ from __future__ import annotations +import sys from functools import partial from io import BytesIO from typing import Callable, List, Optional, Tuple @@ -141,7 +142,7 @@ def _build_environ(scope: HTTPScope, body: bytes) -> dict: "wsgi.version": (1, 0), "wsgi.url_scheme": scope.get("scheme", "http"), "wsgi.input": BytesIO(body), - "wsgi.errors": BytesIO(), + "wsgi.errors": sys.stdout, "wsgi.multithread": True, "wsgi.multiprocess": True, "wsgi.run_once": False, From c0468e555c6e6617dd92377c2c2efff862268de7 Mon Sep 17 00:00:00 2001 From: Florian Apolloner Date: Fri, 29 Dec 2023 11:21:36 +0100 Subject: [PATCH 13/21] Remove old warning --- docs/how_to_guides/wsgi_apps.rst | 6 ------ 1 file changed, 6 deletions(-) diff --git a/docs/how_to_guides/wsgi_apps.rst b/docs/how_to_guides/wsgi_apps.rst index 396d8909..df6b72bd 100644 --- a/docs/how_to_guides/wsgi_apps.rst +++ b/docs/how_to_guides/wsgi_apps.rst @@ -9,12 +9,6 @@ Hypercorn directly serves WSGI applications: $ hypercorn module:wsgi_app -.. warning:: - - The full response from the WSGI app will be stored in memory - before being sent. This prevents the WSGI app from streaming a - response. - WSGI Middleware --------------- From 7c39c68b61012a3c30979176080861c8b00fb229 Mon Sep 17 00:00:00 2001 From: pgjones Date: Mon, 1 Jan 2024 02:47:10 +0000 Subject: [PATCH 14/21] Support restarting workers after max requests This is useful as a "solution" to memory leaks in apps as it ensures that after the max requests have been handled the worker will restart hence freeing any memory leak. The options match those used by Gunicorn. This also ensures that the workers self-heal such that if a worker crashes it will be restored. --- src/hypercorn/__main__.py | 17 ++++++ src/hypercorn/asyncio/run.py | 23 ++++++- src/hypercorn/asyncio/worker_context.py | 15 ++++- src/hypercorn/config.py | 2 + src/hypercorn/protocol/h11.py | 1 + src/hypercorn/protocol/h2.py | 1 + src/hypercorn/protocol/h3.py | 1 + src/hypercorn/run.py | 81 ++++++++++++++++--------- src/hypercorn/trio/run.py | 7 ++- src/hypercorn/trio/worker_context.py | 15 ++++- src/hypercorn/typing.py | 4 ++ src/hypercorn/utils.py | 30 ++++----- tests/asyncio/test_keep_alive.py | 2 +- tests/asyncio/test_sanity.py | 8 +-- tests/asyncio/test_tcp_server.py | 4 +- tests/protocol/test_h11.py | 2 + tests/protocol/test_h2.py | 4 +- tests/protocol/test_http_stream.py | 2 +- tests/protocol/test_ws_stream.py | 2 +- tests/trio/test_keep_alive.py | 2 +- tests/trio/test_sanity.py | 8 +-- 21 files changed, 163 insertions(+), 68 deletions(-) diff --git a/src/hypercorn/__main__.py b/src/hypercorn/__main__.py index 769a5e09..aed33b12 100644 --- a/src/hypercorn/__main__.py +++ b/src/hypercorn/__main__.py @@ -89,6 +89,19 @@ def main(sys_args: Optional[List[str]] = None) -> int: default=sentinel, type=int, ) + parser.add_argument( + "--max-requests", + help="""Maximum number of requests a worker will process before restarting""", + default=sentinel, + type=int, + ) + parser.add_argument( + "--max-requests-jitter", + help="This jitter causes the max-requests per worker to be " + "randomized by randint(0, max_requests_jitter)", + default=sentinel, + type=int, + ) parser.add_argument( "-g", "--group", help="Group to own any unix sockets.", default=sentinel, type=int ) @@ -252,6 +265,10 @@ def _convert_verify_mode(value: str) -> ssl.VerifyMode: config.keyfile_password = args.keyfile_password if args.log_config is not sentinel: config.logconfig = args.log_config + if args.max_requests is not sentinel: + config.max_requests = args.max_requests + if args.max_requests_jitter is not sentinel: + config.max_requests_jitter = args.max_requests if args.pid is not sentinel: config.pid_path = args.pid if args.root_path is not sentinel: diff --git a/src/hypercorn/asyncio/run.py b/src/hypercorn/asyncio/run.py index 7c0982d9..c633c5bd 100644 --- a/src/hypercorn/asyncio/run.py +++ b/src/hypercorn/asyncio/run.py @@ -4,9 +4,11 @@ import platform import signal import ssl +import sys from functools import partial from multiprocessing.synchronize import Event as EventType from os import getpid +from random import randint from socket import socket from typing import Any, Awaitable, Callable, Optional, Set @@ -30,6 +32,14 @@ except ImportError: from taskgroup import Runner # type: ignore +try: + from asyncio import TaskGroup +except ImportError: + from taskgroup import TaskGroup # type: ignore + +if sys.version_info < (3, 11): + from exceptiongroup import BaseExceptionGroup + def _share_socket(sock: socket) -> socket: # Windows requires the socket be explicitly shared across @@ -84,7 +94,10 @@ def _signal_handler(*_: Any) -> None: # noqa: N803 ssl_context = config.create_ssl_context() ssl_handshake_timeout = config.ssl_handshake_timeout - context = WorkerContext() + max_requests = None + if config.max_requests is not None: + max_requests = config.max_requests + randint(0, config.max_requests_jitter) + context = WorkerContext(max_requests) server_tasks: Set[asyncio.Task] = set() async def _server_callback(reader: asyncio.StreamReader, writer: asyncio.StreamWriter) -> None: @@ -136,7 +149,13 @@ async def _server_callback(reader: asyncio.StreamReader, writer: asyncio.StreamW await config.log.info(f"Running on https://{bind} (QUIC) (CTRL + C to quit)") try: - await raise_shutdown(shutdown_trigger) + async with TaskGroup() as task_group: + task_group.create_task(raise_shutdown(shutdown_trigger)) + task_group.create_task(raise_shutdown(context.terminate.wait)) + except BaseExceptionGroup as error: + _, other_errors = error.split((ShutdownError, KeyboardInterrupt)) + if other_errors is not None: + raise other_errors except (ShutdownError, KeyboardInterrupt): pass finally: diff --git a/src/hypercorn/asyncio/worker_context.py b/src/hypercorn/asyncio/worker_context.py index fe9ad1c7..d16f76ba 100644 --- a/src/hypercorn/asyncio/worker_context.py +++ b/src/hypercorn/asyncio/worker_context.py @@ -1,7 +1,7 @@ from __future__ import annotations import asyncio -from typing import Type, Union +from typing import Optional, Type, Union from ..typing import Event @@ -26,9 +26,20 @@ def is_set(self) -> bool: class WorkerContext: event_class: Type[Event] = EventWrapper - def __init__(self) -> None: + def __init__(self, max_requests: Optional[int]) -> None: + self.max_requests = max_requests + self.requests = 0 + self.terminate = self.event_class() self.terminated = self.event_class() + async def mark_request(self) -> None: + if self.max_requests is None: + return + + self.requests += 1 + if self.requests > self.max_requests: + await self.terminate.set() + @staticmethod async def sleep(wait: Union[float, int]) -> None: return await asyncio.sleep(wait) diff --git a/src/hypercorn/config.py b/src/hypercorn/config.py index fdc7a413..f00c7d5e 100644 --- a/src/hypercorn/config.py +++ b/src/hypercorn/config.py @@ -92,6 +92,8 @@ class Config: logger_class = Logger loglevel: str = "INFO" max_app_queue_size: int = 10 + max_requests: Optional[int] = None + max_requests_jitter: int = 0 pid_path: Optional[str] = None server_names: List[str] = [] shutdown_timeout = 60 * SECONDS diff --git a/src/hypercorn/protocol/h11.py b/src/hypercorn/protocol/h11.py index ec04593a..a33ad4ad 100755 --- a/src/hypercorn/protocol/h11.py +++ b/src/hypercorn/protocol/h11.py @@ -236,6 +236,7 @@ async def _create_stream(self, request: h11.Request) -> None: ) ) self.keep_alive_requests += 1 + await self.context.mark_request() async def _send_h11_event(self, event: H11SendableEvent) -> None: try: diff --git a/src/hypercorn/protocol/h2.py b/src/hypercorn/protocol/h2.py index 26048780..9c92ab3d 100755 --- a/src/hypercorn/protocol/h2.py +++ b/src/hypercorn/protocol/h2.py @@ -354,6 +354,7 @@ async def _create_stream(self, request: h2.events.RequestReceived) -> None: ) ) self.keep_alive_requests += 1 + await self.context.mark_request() async def _create_server_push( self, stream_id: int, path: bytes, headers: List[Tuple[bytes, bytes]] diff --git a/src/hypercorn/protocol/h3.py b/src/hypercorn/protocol/h3.py index 88d9a4d3..151c0667 100644 --- a/src/hypercorn/protocol/h3.py +++ b/src/hypercorn/protocol/h3.py @@ -125,6 +125,7 @@ async def _create_stream(self, request: HeadersReceived) -> None: raw_path=raw_path, ) ) + await self.context.mark_request() async def _create_server_push( self, stream_id: int, path: bytes, headers: List[Tuple[bytes, bytes]] diff --git a/src/hypercorn/run.py b/src/hypercorn/run.py index 05ab2391..cfe801aa 100644 --- a/src/hypercorn/run.py +++ b/src/hypercorn/run.py @@ -4,6 +4,7 @@ import signal import time from multiprocessing import get_context +from multiprocessing.connection import wait from multiprocessing.context import BaseContext from multiprocessing.process import BaseProcess from multiprocessing.synchronize import Event as EventType @@ -12,12 +13,10 @@ from .config import Config, Sockets from .typing import WorkerFunc -from .utils import load_application, wait_for_changes, write_pid_file +from .utils import check_for_updates, files_to_watch, load_application, write_pid_file def run(config: Config) -> int: - exit_code = 0 - if config.pid_path is not None: write_pid_file(config.pid_path) @@ -42,67 +41,82 @@ def run(config: Config) -> int: if config.use_reloader and config.workers == 0: raise RuntimeError("Cannot reload without workers") - if config.use_reloader or config.workers == 0: - # Load the application so that the correct paths are checked for - # changes, but only when the reloader is being used. - load_application(config.application_path, config.wsgi_max_body_size) - + exitcode = 0 if config.workers == 0: worker_func(config, sockets) else: + if config.use_reloader: + # Load the application so that the correct paths are checked for + # changes, but only when the reloader is being used. + load_application(config.application_path, config.wsgi_max_body_size) + ctx = get_context("spawn") active = True + shutdown_event = ctx.Event() + + def shutdown(*args: Any) -> None: + nonlocal active, shutdown_event + shutdown_event.set() + active = False + + processes: List[BaseProcess] = [] while active: # Ignore SIGINT before creating the processes, so that they # inherit the signal handling. This means that the shutdown # function controls the shutdown. signal.signal(signal.SIGINT, signal.SIG_IGN) - shutdown_event = ctx.Event() - processes = start_processes(config, worker_func, sockets, shutdown_event, ctx) - - def shutdown(*args: Any) -> None: - nonlocal active, shutdown_event - shutdown_event.set() - active = False + _populate(processes, config, worker_func, sockets, shutdown_event, ctx) for signal_name in {"SIGINT", "SIGTERM", "SIGBREAK"}: if hasattr(signal, signal_name): signal.signal(getattr(signal, signal_name), shutdown) if config.use_reloader: - wait_for_changes(shutdown_event) - shutdown_event.set() + files = files_to_watch() + while True: + finished = wait((process.sentinel for process in processes), timeout=1) + updated = check_for_updates(files) + if updated: + shutdown_event.set() + for process in processes: + process.join() + shutdown_event.clear() + break + if len(finished) > 0: + break else: - active = False + wait(process.sentinel for process in processes) - for process in processes: - process.join() - if process.exitcode != 0: - exit_code = process.exitcode + exitcode = _join_exited(processes) + if exitcode != 0: + shutdown_event.set() + active = False for process in processes: process.terminate() + exitcode = _join_exited(processes) if exitcode != 0 else exitcode + for sock in sockets.secure_sockets: sock.close() for sock in sockets.insecure_sockets: sock.close() - return exit_code + return exitcode -def start_processes( +def _populate( + processes: List[BaseProcess], config: Config, worker_func: WorkerFunc, sockets: Sockets, shutdown_event: EventType, ctx: BaseContext, -) -> List[BaseProcess]: - processes = [] - for _ in range(config.workers): +) -> None: + for _ in range(config.workers - len(processes)): process = ctx.Process( # type: ignore target=worker_func, kwargs={"config": config, "shutdown_event": shutdown_event, "sockets": sockets}, @@ -117,4 +131,15 @@ def start_processes( processes.append(process) if platform.system() == "Windows": time.sleep(0.1) - return processes + + +def _join_exited(processes: List[BaseProcess]) -> int: + exitcode = 0 + for index in reversed(range(len(processes))): + worker = processes[index] + if worker.exitcode is not None: + worker.join() + exitcode = worker.exitcode if exitcode == 0 else exitcode + del processes[index] + + return exitcode diff --git a/src/hypercorn/trio/run.py b/src/hypercorn/trio/run.py index d8721bbb..2cfe5db4 100644 --- a/src/hypercorn/trio/run.py +++ b/src/hypercorn/trio/run.py @@ -3,6 +3,7 @@ import sys from functools import partial from multiprocessing.synchronize import Event as EventType +from random import randint from typing import Awaitable, Callable, Optional import trio @@ -37,7 +38,10 @@ async def worker_serve( config.set_statsd_logger_class(StatsdLogger) lifespan = Lifespan(app, config) - context = WorkerContext() + max_requests = None + if config.max_requests is not None: + max_requests = config.max_requests + randint(0, config.max_requests_jitter) + context = WorkerContext(max_requests) async with trio.open_nursery() as lifespan_nursery: await lifespan_nursery.start(lifespan.handle_lifespan) @@ -82,6 +86,7 @@ async def worker_serve( async with trio.open_nursery(strict_exception_groups=True) as nursery: if shutdown_trigger is not None: nursery.start_soon(raise_shutdown, shutdown_trigger) + nursery.start_soon(raise_shutdown, context.terminate.wait) nursery.start_soon( partial( diff --git a/src/hypercorn/trio/worker_context.py b/src/hypercorn/trio/worker_context.py index bcfa1a51..c09c4fb6 100644 --- a/src/hypercorn/trio/worker_context.py +++ b/src/hypercorn/trio/worker_context.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import Type, Union +from typing import Optional, Type, Union import trio @@ -27,9 +27,20 @@ def is_set(self) -> bool: class WorkerContext: event_class: Type[Event] = EventWrapper - def __init__(self) -> None: + def __init__(self, max_requests: Optional[int]) -> None: + self.max_requests = max_requests + self.requests = 0 + self.terminate = self.event_class() self.terminated = self.event_class() + async def mark_request(self) -> None: + if self.max_requests is None: + return + + self.requests += 1 + if self.requests > self.max_requests: + await self.terminate.set() + @staticmethod async def sleep(wait: Union[float, int]) -> None: return await trio.sleep(wait) diff --git a/src/hypercorn/typing.py b/src/hypercorn/typing.py index 1299a776..2ebb711d 100644 --- a/src/hypercorn/typing.py +++ b/src/hypercorn/typing.py @@ -290,8 +290,12 @@ def is_set(self) -> bool: class WorkerContext(Protocol): event_class: Type[Event] + terminate: Event terminated: Event + async def mark_request(self) -> None: + ... + @staticmethod async def sleep(wait: Union[float, int]) -> None: ... diff --git a/src/hypercorn/utils.py b/src/hypercorn/utils.py index 9e3520d7..39249c53 100644 --- a/src/hypercorn/utils.py +++ b/src/hypercorn/utils.py @@ -4,7 +4,6 @@ import os import socket import sys -import time from enum import Enum from importlib import import_module from multiprocessing.synchronize import Event as EventType @@ -133,7 +132,7 @@ def wrap_app( return WSGIWrapper(cast(WSGIFramework, app), wsgi_max_body_size) -def wait_for_changes(shutdown_event: EventType) -> None: +def files_to_watch() -> Dict[Path, float]: last_updates: Dict[Path, float] = {} for module in list(sys.modules.values()): filename = getattr(module, "__file__", None) @@ -144,24 +143,21 @@ def wait_for_changes(shutdown_event: EventType) -> None: last_updates[Path(filename)] = path.stat().st_mtime except (FileNotFoundError, NotADirectoryError): pass + return last_updates - while not shutdown_event.is_set(): - time.sleep(1) - for index, (path, last_mtime) in enumerate(last_updates.items()): - if index % 10 == 0: - # Yield to the event loop - time.sleep(0) - - try: - mtime = path.stat().st_mtime - except FileNotFoundError: - return +def check_for_updates(files: Dict[Path, float]) -> bool: + for path, last_mtime in files.items(): + try: + mtime = path.stat().st_mtime + except FileNotFoundError: + return True + else: + if mtime > last_mtime: + return True else: - if mtime > last_mtime: - return - else: - last_updates[path] = mtime + files[path] = mtime + return False async def raise_shutdown(shutdown_event: Callable[..., Awaitable]) -> None: diff --git a/tests/asyncio/test_keep_alive.py b/tests/asyncio/test_keep_alive.py index 6b357f8f..a46f4cfd 100644 --- a/tests/asyncio/test_keep_alive.py +++ b/tests/asyncio/test_keep_alive.py @@ -50,7 +50,7 @@ async def _server(event_loop: asyncio.AbstractEventLoop) -> AsyncGenerator[TCPSe ASGIWrapper(slow_framework), event_loop, config, - WorkerContext(), + WorkerContext(None), MemoryReader(), # type: ignore MemoryWriter(), # type: ignore ) diff --git a/tests/asyncio/test_sanity.py b/tests/asyncio/test_sanity.py index 2d7cb0bc..cde29297 100644 --- a/tests/asyncio/test_sanity.py +++ b/tests/asyncio/test_sanity.py @@ -21,7 +21,7 @@ async def test_http1_request(event_loop: asyncio.AbstractEventLoop) -> None: ASGIWrapper(sanity_framework), event_loop, Config(), - WorkerContext(), + WorkerContext(None), MemoryReader(), # type: ignore MemoryWriter(), # type: ignore ) @@ -78,7 +78,7 @@ async def test_http1_websocket(event_loop: asyncio.AbstractEventLoop) -> None: ASGIWrapper(sanity_framework), event_loop, Config(), - WorkerContext(), + WorkerContext(None), MemoryReader(), # type: ignore MemoryWriter(), # type: ignore ) @@ -115,7 +115,7 @@ async def test_http2_request(event_loop: asyncio.AbstractEventLoop) -> None: ASGIWrapper(sanity_framework), event_loop, Config(), - WorkerContext(), + WorkerContext(None), MemoryReader(), # type: ignore MemoryWriter(http2=True), # type: ignore ) @@ -178,7 +178,7 @@ async def test_http2_websocket(event_loop: asyncio.AbstractEventLoop) -> None: ASGIWrapper(sanity_framework), event_loop, Config(), - WorkerContext(), + WorkerContext(None), MemoryReader(), # type: ignore MemoryWriter(http2=True), # type: ignore ) diff --git a/tests/asyncio/test_tcp_server.py b/tests/asyncio/test_tcp_server.py index afe00c20..1dfd4212 100644 --- a/tests/asyncio/test_tcp_server.py +++ b/tests/asyncio/test_tcp_server.py @@ -18,7 +18,7 @@ async def test_completes_on_closed(event_loop: asyncio.AbstractEventLoop) -> Non ASGIWrapper(echo_framework), event_loop, Config(), - WorkerContext(), + WorkerContext(None), MemoryReader(), # type: ignore MemoryWriter(), # type: ignore ) @@ -34,7 +34,7 @@ async def test_complets_on_half_close(event_loop: asyncio.AbstractEventLoop) -> ASGIWrapper(echo_framework), event_loop, Config(), - WorkerContext(), + WorkerContext(None), MemoryReader(), # type: ignore MemoryWriter(), # type: ignore ) diff --git a/tests/protocol/test_h11.py b/tests/protocol/test_h11.py index 27c80b56..09e85b78 100755 --- a/tests/protocol/test_h11.py +++ b/tests/protocol/test_h11.py @@ -35,6 +35,8 @@ async def _protocol(monkeypatch: MonkeyPatch) -> H11Protocol: monkeypatch.setattr(hypercorn.protocol.h11, "HTTPStream", MockHTTPStream) context = Mock() context.event_class.return_value = AsyncMock(spec=IOEvent) + context.mark_request = AsyncMock() + context.terminate = context.event_class() context.terminated = context.event_class() context.terminated.is_set.return_value = False return H11Protocol(AsyncMock(), Config(), context, AsyncMock(), False, None, None, AsyncMock()) diff --git a/tests/protocol/test_h2.py b/tests/protocol/test_h2.py index 77bcf515..cec6c263 100644 --- a/tests/protocol/test_h2.py +++ b/tests/protocol/test_h2.py @@ -75,7 +75,7 @@ async def test_stream_buffer_complete(event_loop: asyncio.AbstractEventLoop) -> @pytest.mark.asyncio async def test_protocol_handle_protocol_error() -> None: protocol = H2Protocol( - Mock(), Config(), WorkerContext(), AsyncMock(), False, None, None, AsyncMock() + Mock(), Config(), WorkerContext(None), AsyncMock(), False, None, None, AsyncMock() ) await protocol.handle(RawData(data=b"broken nonsense\r\n\r\n")) protocol.send.assert_awaited() # type: ignore @@ -85,7 +85,7 @@ async def test_protocol_handle_protocol_error() -> None: @pytest.mark.asyncio async def test_protocol_keep_alive_max_requests() -> None: protocol = H2Protocol( - Mock(), Config(), WorkerContext(), AsyncMock(), False, None, None, AsyncMock() + Mock(), Config(), WorkerContext(None), AsyncMock(), False, None, None, AsyncMock() ) protocol.config.keep_alive_max_requests = 0 client = H2Connection() diff --git a/tests/protocol/test_http_stream.py b/tests/protocol/test_http_stream.py index 24af5969..6f656de0 100644 --- a/tests/protocol/test_http_stream.py +++ b/tests/protocol/test_http_stream.py @@ -31,7 +31,7 @@ @pytest_asyncio.fixture(name="stream") # type: ignore[misc] async def _stream() -> HTTPStream: stream = HTTPStream( - AsyncMock(), Config(), WorkerContext(), AsyncMock(), False, None, None, AsyncMock(), 1 + AsyncMock(), Config(), WorkerContext(None), AsyncMock(), False, None, None, AsyncMock(), 1 ) stream.app_put = AsyncMock() stream.config._log = AsyncMock(spec=Logger) diff --git a/tests/protocol/test_ws_stream.py b/tests/protocol/test_ws_stream.py index 5f595828..05403130 100644 --- a/tests/protocol/test_ws_stream.py +++ b/tests/protocol/test_ws_stream.py @@ -165,7 +165,7 @@ def test_handshake_accept_additional_headers() -> None: @pytest_asyncio.fixture(name="stream") # type: ignore[misc] async def _stream() -> WSStream: stream = WSStream( - AsyncMock(), Config(), WorkerContext(), AsyncMock(), False, None, None, AsyncMock(), 1 + AsyncMock(), Config(), WorkerContext(None), AsyncMock(), False, None, None, AsyncMock(), 1 ) stream.task_group.spawn_app.return_value = AsyncMock() # type: ignore stream.app_put = AsyncMock() diff --git a/tests/trio/test_keep_alive.py b/tests/trio/test_keep_alive.py index d30d82db..6bed437f 100644 --- a/tests/trio/test_keep_alive.py +++ b/tests/trio/test_keep_alive.py @@ -47,7 +47,7 @@ def _client_stream( config.keep_alive_timeout = KEEP_ALIVE_TIMEOUT client_stream, server_stream = trio.testing.memory_stream_pair() server_stream.socket = MockSocket() - server = TCPServer(ASGIWrapper(slow_framework), config, WorkerContext(), server_stream) + server = TCPServer(ASGIWrapper(slow_framework), config, WorkerContext(None), server_stream) nursery.start_soon(server.run) yield client_stream diff --git a/tests/trio/test_sanity.py b/tests/trio/test_sanity.py index 6d4be8c5..b5bf75ba 100644 --- a/tests/trio/test_sanity.py +++ b/tests/trio/test_sanity.py @@ -25,7 +25,7 @@ async def test_http1_request(nursery: trio._core._run.Nursery) -> None: client_stream, server_stream = trio.testing.memory_stream_pair() server_stream.socket = MockSocket() - server = TCPServer(ASGIWrapper(sanity_framework), Config(), WorkerContext(), server_stream) + server = TCPServer(ASGIWrapper(sanity_framework), Config(), WorkerContext(None), server_stream) nursery.start_soon(server.run) client = h11.Connection(h11.CLIENT) await client_stream.send_all( @@ -76,7 +76,7 @@ async def test_http1_request(nursery: trio._core._run.Nursery) -> None: async def test_http1_websocket(nursery: trio._core._run.Nursery) -> None: client_stream, server_stream = trio.testing.memory_stream_pair() server_stream.socket = MockSocket() - server = TCPServer(ASGIWrapper(sanity_framework), Config(), WorkerContext(), server_stream) + server = TCPServer(ASGIWrapper(sanity_framework), Config(), WorkerContext(None), server_stream) nursery.start_soon(server.run) client = wsproto.WSConnection(wsproto.ConnectionType.CLIENT) await client_stream.send_all(client.send(wsproto.events.Request(host="hypercorn", target="/"))) @@ -103,7 +103,7 @@ async def test_http2_request(nursery: trio._core._run.Nursery) -> None: server_stream.transport_stream = Mock(return_value=PropertyMock(return_value=MockSocket())) server_stream.do_handshake = AsyncMock() server_stream.selected_alpn_protocol = Mock(return_value="h2") - server = TCPServer(ASGIWrapper(sanity_framework), Config(), WorkerContext(), server_stream) + server = TCPServer(ASGIWrapper(sanity_framework), Config(), WorkerContext(None), server_stream) nursery.start_soon(server.run) client = h2.connection.H2Connection() client.initiate_connection() @@ -158,7 +158,7 @@ async def test_http2_websocket(nursery: trio._core._run.Nursery) -> None: server_stream.transport_stream = Mock(return_value=PropertyMock(return_value=MockSocket())) server_stream.do_handshake = AsyncMock() server_stream.selected_alpn_protocol = Mock(return_value="h2") - server = TCPServer(ASGIWrapper(sanity_framework), Config(), WorkerContext(), server_stream) + server = TCPServer(ASGIWrapper(sanity_framework), Config(), WorkerContext(None), server_stream) nursery.start_soon(server.run) h2_client = h2.connection.H2Connection() h2_client.initiate_connection() From 0bb4fb9de5e00dbaece82a6c02617d1d9c0c8e56 Mon Sep 17 00:00:00 2001 From: pgjones Date: Mon, 1 Jan 2024 13:41:10 +0000 Subject: [PATCH 15/21] Don't error on LocalProtoclErrors for ws streams There is a race condition being hit in the autobahn compliance tests whereby the client closes the stream and the server responds with an acknowledgement. Whilst the server responds the app sends a message, which now errors as the WSConnection state is closed. As the state is managed by the WSStream, rather than the WSConnection it makes sense to ignore these errors. --- src/hypercorn/protocol/ws_stream.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/src/hypercorn/protocol/ws_stream.py b/src/hypercorn/protocol/ws_stream.py index 9011999e..5709952a 100644 --- a/src/hypercorn/protocol/ws_stream.py +++ b/src/hypercorn/protocol/ws_stream.py @@ -18,7 +18,7 @@ from wsproto.extensions import Extension, PerMessageDeflate from wsproto.frame_protocol import CloseReason from wsproto.handshake import server_extensions_handshake, WEBSOCKET_VERSION -from wsproto.utilities import generate_accept_token, split_comma_header +from wsproto.utilities import generate_accept_token, LocalProtocolError, split_comma_header from .events import Body, Data, EndBody, EndData, Event, Request, Response, StreamClosed from ..config import Config @@ -333,8 +333,12 @@ async def _send_error_response(self, status_code: int) -> None: ) async def _send_wsproto_event(self, event: WSProtoEvent) -> None: - data = self.connection.send(event) - await self.send(Data(stream_id=self.stream_id, data=data)) + try: + data = self.connection.send(event) + except LocalProtocolError: + pass + else: + await self.send(Data(stream_id=self.stream_id, data=data)) async def _accept(self, message: WebsocketAcceptEvent) -> None: self.state = ASGIWebsocketState.CONNECTED From f8e4e5de3aec7f8eb986535163c3d5b4f424465c Mon Sep 17 00:00:00 2001 From: pgjones Date: Mon, 1 Jan 2024 13:49:48 +0000 Subject: [PATCH 16/21] Bump and release 0.16.0 --- CHANGELOG.rst | 21 +++++++++++++++++++++ pyproject.toml | 2 +- 2 files changed, 22 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.rst b/CHANGELOG.rst index e18836c1..a68991e3 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -1,3 +1,24 @@ +0.16.0 2023-01-01 +----------------- + +* Add a max keep alive requests configuration option, this mitigates + the HTTP/2 rapid reset attack. +* Return subprocess exit code if non-zero. +* Add ProxyFix middleware to make it easier to run Hypercorn behind a + proxy. +* Support restarting workers after max requests to make it easier to + manage memory leaks in apps. +* Bugfix ensure the idle task is stopped on error. +* Bugfix revert autoreload error because reausing old sockets. +* Bugfix send the hinted error from h11 on RemoteProtocolErrors. +* Bugfix handle asyncio.CancelledError when socket is closed without + flushing. +* Bugfix improve WSGI compliance by closing iterators, only sending + headers on first response byte, erroring if ``start_response`` is + not called, and switching wsgi.errors to stdout. +* Don't error on LocalProtoclErrors for ws streams to better cope with + race conditions. + 0.15.0 2023-10-29 ----------------- diff --git a/pyproject.toml b/pyproject.toml index a620dfa9..37d199df 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "Hypercorn" -version = "0.15.0" +version = "0.16.0" description = "A ASGI Server based on Hyper libraries and inspired by Gunicorn" authors = ["pgjones "] classifiers = [ From c0279ce3d55224334233fd54e09da252c0a0cc47 Mon Sep 17 00:00:00 2001 From: Quentin Pradet Date: Tue, 28 Nov 2023 10:49:37 +0400 Subject: [PATCH 17/21] Rename ssl to tls The ssl name clashes with the ssl module name and I intend to reuse this variable to carry information about the ASGI TLS extension. --- src/hypercorn/asyncio/tcp_server.py | 6 +++--- src/hypercorn/protocol/__init__.py | 12 ++++++------ src/hypercorn/protocol/h11.py | 8 ++++---- src/hypercorn/protocol/h2.py | 8 ++++---- src/hypercorn/protocol/http_stream.py | 4 ++-- src/hypercorn/protocol/ws_stream.py | 4 ++-- src/hypercorn/trio/tcp_server.py | 6 +++--- 7 files changed, 24 insertions(+), 24 deletions(-) diff --git a/src/hypercorn/asyncio/tcp_server.py b/src/hypercorn/asyncio/tcp_server.py index ed9d710f..12fe013b 100644 --- a/src/hypercorn/asyncio/tcp_server.py +++ b/src/hypercorn/asyncio/tcp_server.py @@ -47,10 +47,10 @@ async def run(self) -> None: server = parse_socket_addr(socket.family, socket.getsockname()) ssl_object = self.writer.get_extra_info("ssl_object") if ssl_object is not None: - ssl = True + tls = True alpn_protocol = ssl_object.selected_alpn_protocol() else: - ssl = False + tls = False alpn_protocol = "http/1.1" async with TaskGroup(self.loop) as task_group: @@ -59,7 +59,7 @@ async def run(self) -> None: self.config, self.context, task_group, - ssl, + tls, client, server, self.protocol_send, diff --git a/src/hypercorn/protocol/__init__.py b/src/hypercorn/protocol/__init__.py index 39385681..8019fe5f 100755 --- a/src/hypercorn/protocol/__init__.py +++ b/src/hypercorn/protocol/__init__.py @@ -16,7 +16,7 @@ def __init__( config: Config, context: WorkerContext, task_group: TaskGroup, - ssl: bool, + tls: bool, client: Optional[Tuple[str, int]], server: Optional[Tuple[str, int]], send: Callable[[Event], Awaitable[None]], @@ -26,7 +26,7 @@ def __init__( self.config = config self.context = context self.task_group = task_group - self.ssl = ssl + self.tls = tls self.client = client self.server = server self.send = send @@ -37,7 +37,7 @@ def __init__( self.config, self.context, self.task_group, - self.ssl, + self.tls, self.client, self.server, self.send, @@ -48,7 +48,7 @@ def __init__( self.config, self.context, self.task_group, - self.ssl, + self.tls, self.client, self.server, self.send, @@ -66,7 +66,7 @@ async def handle(self, event: Event) -> None: self.config, self.context, self.task_group, - self.ssl, + self.tls, self.client, self.server, self.send, @@ -80,7 +80,7 @@ async def handle(self, event: Event) -> None: self.config, self.context, self.task_group, - self.ssl, + self.tls, self.client, self.server, self.send, diff --git a/src/hypercorn/protocol/h11.py b/src/hypercorn/protocol/h11.py index a33ad4ad..3c898afa 100755 --- a/src/hypercorn/protocol/h11.py +++ b/src/hypercorn/protocol/h11.py @@ -84,7 +84,7 @@ def __init__( config: Config, context: WorkerContext, task_group: TaskGroup, - ssl: bool, + tls: bool, client: Optional[Tuple[str, int]], server: Optional[Tuple[str, int]], send: Callable[[Event], Awaitable[None]], @@ -100,7 +100,7 @@ def __init__( self.keep_alive_requests = 0 self.send = send self.server = server - self.ssl = ssl + self.tls = tls self.stream: Optional[Union[HTTPStream, WSStream]] = None self.task_group = task_group @@ -201,7 +201,7 @@ async def _create_stream(self, request: h11.Request) -> None: self.config, self.context, self.task_group, - self.ssl, + self.tls, self.client, self.server, self.stream_send, @@ -214,7 +214,7 @@ async def _create_stream(self, request: h11.Request) -> None: self.config, self.context, self.task_group, - self.ssl, + self.tls, self.client, self.server, self.stream_send, diff --git a/src/hypercorn/protocol/h2.py b/src/hypercorn/protocol/h2.py index 9c92ab3d..0b57fec7 100755 --- a/src/hypercorn/protocol/h2.py +++ b/src/hypercorn/protocol/h2.py @@ -84,7 +84,7 @@ def __init__( config: Config, context: WorkerContext, task_group: TaskGroup, - ssl: bool, + tls: bool, client: Optional[Tuple[str, int]], server: Optional[Tuple[str, int]], send: Callable[[Event], Awaitable[None]], @@ -112,7 +112,7 @@ def __init__( self.keep_alive_requests = 0 self.send = send self.server = server - self.ssl = ssl + self.tls = tls self.streams: Dict[int, Union[HTTPStream, WSStream]] = {} # The below are used by the sending task self.has_data = self.context.event_class() @@ -317,7 +317,7 @@ async def _create_stream(self, request: h2.events.RequestReceived) -> None: self.config, self.context, self.task_group, - self.ssl, + self.tls, self.client, self.server, self.stream_send, @@ -329,7 +329,7 @@ async def _create_stream(self, request: h2.events.RequestReceived) -> None: self.config, self.context, self.task_group, - self.ssl, + self.tls, self.client, self.server, self.stream_send, diff --git a/src/hypercorn/protocol/http_stream.py b/src/hypercorn/protocol/http_stream.py index d244e7c3..fbfcdb2e 100644 --- a/src/hypercorn/protocol/http_stream.py +++ b/src/hypercorn/protocol/http_stream.py @@ -42,7 +42,7 @@ def __init__( config: Config, context: WorkerContext, task_group: TaskGroup, - ssl: bool, + tls: bool, client: Optional[Tuple[str, int]], server: Optional[Tuple[str, int]], send: Callable[[Event], Awaitable[None]], @@ -56,7 +56,7 @@ def __init__( self.response: HTTPResponseStartEvent self.scope: HTTPScope self.send = send - self.scheme = "https" if ssl else "http" + self.scheme = "https" if tls else "http" self.server = server self.start_time: float self.state = ASGIHTTPState.REQUEST diff --git a/src/hypercorn/protocol/ws_stream.py b/src/hypercorn/protocol/ws_stream.py index 5709952a..6da74150 100644 --- a/src/hypercorn/protocol/ws_stream.py +++ b/src/hypercorn/protocol/ws_stream.py @@ -167,7 +167,7 @@ def __init__( config: Config, context: WorkerContext, task_group: TaskGroup, - ssl: bool, + tls: bool, client: Optional[Tuple[str, int]], server: Optional[Tuple[str, int]], send: Callable[[Event], Awaitable[None]], @@ -185,7 +185,7 @@ def __init__( self.scope: WebsocketScope self.send = send # RFC 8441 for HTTP/2 says use http or https, ASGI says ws or wss - self.scheme = "wss" if ssl else "ws" + self.scheme = "wss" if tls else "ws" self.server = server self.start_time: float self.state = ASGIWebsocketState.HANDSHAKE diff --git a/src/hypercorn/trio/tcp_server.py b/src/hypercorn/trio/tcp_server.py index dbcc7a12..1a40bdd8 100644 --- a/src/hypercorn/trio/tcp_server.py +++ b/src/hypercorn/trio/tcp_server.py @@ -42,11 +42,11 @@ async def run(self) -> None: return # Handshake failed alpn_protocol = self.stream.selected_alpn_protocol() socket = self.stream.transport_stream.socket - ssl = True + tls = True except AttributeError: # Not SSL alpn_protocol = "http/1.1" socket = self.stream.socket - ssl = False + tls = False try: client = parse_socket_addr(socket.family, socket.getpeername()) @@ -59,7 +59,7 @@ async def run(self) -> None: self.config, self.context, task_group, - ssl, + tls, client, server, self.protocol_send, From ea2bbe6be22dd35cc8f05270fbdff755b0cb715a Mon Sep 17 00:00:00 2001 From: Quentin Pradet Date: Tue, 28 Nov 2023 10:54:01 +0400 Subject: [PATCH 18/21] Turn tls into optional dictionary The dictionary will carry information for the TLS ASGI extension. --- src/hypercorn/asyncio/tcp_server.py | 4 ++-- src/hypercorn/protocol/__init__.py | 4 ++-- src/hypercorn/protocol/h11.py | 4 ++-- src/hypercorn/protocol/h2.py | 4 ++-- src/hypercorn/protocol/http_stream.py | 6 +++--- src/hypercorn/protocol/ws_stream.py | 6 +++--- src/hypercorn/trio/tcp_server.py | 4 ++-- 7 files changed, 16 insertions(+), 16 deletions(-) diff --git a/src/hypercorn/asyncio/tcp_server.py b/src/hypercorn/asyncio/tcp_server.py index 12fe013b..31b2d2d5 100644 --- a/src/hypercorn/asyncio/tcp_server.py +++ b/src/hypercorn/asyncio/tcp_server.py @@ -47,10 +47,10 @@ async def run(self) -> None: server = parse_socket_addr(socket.family, socket.getsockname()) ssl_object = self.writer.get_extra_info("ssl_object") if ssl_object is not None: - tls = True + tls = {} alpn_protocol = ssl_object.selected_alpn_protocol() else: - tls = False + tls = None alpn_protocol = "http/1.1" async with TaskGroup(self.loop) as task_group: diff --git a/src/hypercorn/protocol/__init__.py b/src/hypercorn/protocol/__init__.py index 8019fe5f..1c0c5a96 100755 --- a/src/hypercorn/protocol/__init__.py +++ b/src/hypercorn/protocol/__init__.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import Awaitable, Callable, Optional, Tuple, Union +from typing import Any, Awaitable, Callable, Optional, Tuple, Union from .h2 import H2Protocol from .h11 import H2CProtocolRequiredError, H2ProtocolAssumedError, H11Protocol @@ -16,7 +16,7 @@ def __init__( config: Config, context: WorkerContext, task_group: TaskGroup, - tls: bool, + tls: Optional[dict[str, Any]], client: Optional[Tuple[str, int]], server: Optional[Tuple[str, int]], send: Callable[[Event], Awaitable[None]], diff --git a/src/hypercorn/protocol/h11.py b/src/hypercorn/protocol/h11.py index 3c898afa..ca7eea4d 100755 --- a/src/hypercorn/protocol/h11.py +++ b/src/hypercorn/protocol/h11.py @@ -1,7 +1,7 @@ from __future__ import annotations from itertools import chain -from typing import Awaitable, Callable, cast, Optional, Tuple, Type, Union +from typing import Any, Awaitable, Callable, cast, Optional, Tuple, Type, Union import h11 @@ -84,7 +84,7 @@ def __init__( config: Config, context: WorkerContext, task_group: TaskGroup, - tls: bool, + tls: Optional[dict[str, Any]], client: Optional[Tuple[str, int]], server: Optional[Tuple[str, int]], send: Callable[[Event], Awaitable[None]], diff --git a/src/hypercorn/protocol/h2.py b/src/hypercorn/protocol/h2.py index 0b57fec7..8b15f5e5 100755 --- a/src/hypercorn/protocol/h2.py +++ b/src/hypercorn/protocol/h2.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import Awaitable, Callable, Dict, List, Optional, Tuple, Type, Union +from typing import Any, Awaitable, Callable, Dict, List, Optional, Tuple, Type, Union import h2 import h2.connection @@ -84,7 +84,7 @@ def __init__( config: Config, context: WorkerContext, task_group: TaskGroup, - tls: bool, + tls: Optional[dict[str, Any]], client: Optional[Tuple[str, int]], server: Optional[Tuple[str, int]], send: Callable[[Event], Awaitable[None]], diff --git a/src/hypercorn/protocol/http_stream.py b/src/hypercorn/protocol/http_stream.py index fbfcdb2e..47270c7a 100644 --- a/src/hypercorn/protocol/http_stream.py +++ b/src/hypercorn/protocol/http_stream.py @@ -2,7 +2,7 @@ from enum import auto, Enum from time import time -from typing import Awaitable, Callable, Optional, Tuple +from typing import Any, Awaitable, Callable, Optional, Tuple from urllib.parse import unquote from .events import Body, EndBody, Event, InformationalResponse, Request, Response, StreamClosed @@ -42,7 +42,7 @@ def __init__( config: Config, context: WorkerContext, task_group: TaskGroup, - tls: bool, + tls: Optional[dict[str, Any]], client: Optional[Tuple[str, int]], server: Optional[Tuple[str, int]], send: Callable[[Event], Awaitable[None]], @@ -56,7 +56,7 @@ def __init__( self.response: HTTPResponseStartEvent self.scope: HTTPScope self.send = send - self.scheme = "https" if tls else "http" + self.scheme = "https" if tls is not None else "http" self.server = server self.start_time: float self.state = ASGIHTTPState.REQUEST diff --git a/src/hypercorn/protocol/ws_stream.py b/src/hypercorn/protocol/ws_stream.py index 6da74150..b0b99c89 100644 --- a/src/hypercorn/protocol/ws_stream.py +++ b/src/hypercorn/protocol/ws_stream.py @@ -3,7 +3,7 @@ from enum import auto, Enum from io import BytesIO, StringIO from time import time -from typing import Awaitable, Callable, Iterable, List, Optional, Tuple, Union +from typing import Any, Awaitable, Callable, Iterable, List, Optional, Tuple, Union from urllib.parse import unquote from wsproto.connection import Connection, ConnectionState, ConnectionType @@ -167,7 +167,7 @@ def __init__( config: Config, context: WorkerContext, task_group: TaskGroup, - tls: bool, + tls: Optional[dict[str, Any]], client: Optional[Tuple[str, int]], server: Optional[Tuple[str, int]], send: Callable[[Event], Awaitable[None]], @@ -185,7 +185,7 @@ def __init__( self.scope: WebsocketScope self.send = send # RFC 8441 for HTTP/2 says use http or https, ASGI says ws or wss - self.scheme = "wss" if tls else "ws" + self.scheme = "wss" if tls is not None else "ws" self.server = server self.start_time: float self.state = ASGIWebsocketState.HANDSHAKE diff --git a/src/hypercorn/trio/tcp_server.py b/src/hypercorn/trio/tcp_server.py index 1a40bdd8..c3ab4418 100644 --- a/src/hypercorn/trio/tcp_server.py +++ b/src/hypercorn/trio/tcp_server.py @@ -42,11 +42,11 @@ async def run(self) -> None: return # Handshake failed alpn_protocol = self.stream.selected_alpn_protocol() socket = self.stream.transport_stream.socket - tls = True + tls = {} except AttributeError: # Not SSL alpn_protocol = "http/1.1" socket = self.stream.socket - tls = False + tls = None try: client = parse_socket_addr(socket.family, socket.getpeername()) From 6b6553fd5e64a91d9496f4990fc84bf63237a5fc Mon Sep 17 00:00:00 2001 From: Quentin Pradet Date: Tue, 28 Nov 2023 14:02:00 +0400 Subject: [PATCH 19/21] Serve client_cert_name and alpn_protocol in tls extension --- src/hypercorn/protocol/http_stream.py | 4 ++++ src/hypercorn/trio/tcp_server.py | 9 ++++++++- 2 files changed, 12 insertions(+), 1 deletion(-) diff --git a/src/hypercorn/protocol/http_stream.py b/src/hypercorn/protocol/http_stream.py index 47270c7a..ed235813 100644 --- a/src/hypercorn/protocol/http_stream.py +++ b/src/hypercorn/protocol/http_stream.py @@ -57,6 +57,7 @@ def __init__( self.scope: HTTPScope self.send = send self.scheme = "https" if tls is not None else "http" + self.tls = tls self.server = server self.start_time: float self.state = ASGIHTTPState.REQUEST @@ -94,6 +95,9 @@ async def handle(self, event: Event) -> None: if event.http_version in EARLY_HINTS_VERSIONS: self.scope["extensions"]["http.response.early_hint"] = {} + if self.tls is not None: + self.scope["extensions"]["tls"] = self.tls + if valid_server_name(self.config, event): self.app_put = await self.task_group.spawn_app( self.app, self.config, self.scope, self.app_send diff --git a/src/hypercorn/trio/tcp_server.py b/src/hypercorn/trio/tcp_server.py index c3ab4418..8819a1bd 100644 --- a/src/hypercorn/trio/tcp_server.py +++ b/src/hypercorn/trio/tcp_server.py @@ -1,5 +1,6 @@ from __future__ import annotations +import ssl from math import inf from typing import Any, Generator, Optional @@ -42,7 +43,13 @@ async def run(self) -> None: return # Handshake failed alpn_protocol = self.stream.selected_alpn_protocol() socket = self.stream.transport_stream.socket - tls = {} + + tls = {"alpn_protocol": alpn_protocol} + client_certificate = self.stream.getpeercert(binary_form=False) + if client_certificate: + tls["client_cert_name"] = ", ".join( + [f"{part[0][0]}={part[0][1]}" for part in client_certificate["subject"]] + ) except AttributeError: # Not SSL alpn_protocol = "http/1.1" socket = self.stream.socket From a28c72caa064642757f8d664c3abd237affd8785 Mon Sep 17 00:00:00 2001 From: Quentin Pradet Date: Tue, 5 Dec 2023 23:38:02 +0400 Subject: [PATCH 20/21] Add unofficial `_transport` ASGI extension It makes the underlying transport available, and depends on the actual backend. urllib3 needs it to implement CONNECT. --- src/hypercorn/asyncio/tcp_server.py | 1 + src/hypercorn/protocol/__init__.py | 4 ++++ src/hypercorn/protocol/h11.py | 3 +++ src/hypercorn/protocol/h2.py | 2 ++ src/hypercorn/protocol/http_stream.py | 5 +++++ src/hypercorn/trio/tcp_server.py | 1 + 6 files changed, 16 insertions(+) diff --git a/src/hypercorn/asyncio/tcp_server.py b/src/hypercorn/asyncio/tcp_server.py index 31b2d2d5..ce8e8147 100644 --- a/src/hypercorn/asyncio/tcp_server.py +++ b/src/hypercorn/asyncio/tcp_server.py @@ -64,6 +64,7 @@ async def run(self) -> None: server, self.protocol_send, alpn_protocol, + (self.reader, self.writer), ) await self.protocol.initiate() await self._start_idle() diff --git a/src/hypercorn/protocol/__init__.py b/src/hypercorn/protocol/__init__.py index 1c0c5a96..e047da9d 100755 --- a/src/hypercorn/protocol/__init__.py +++ b/src/hypercorn/protocol/__init__.py @@ -21,6 +21,7 @@ def __init__( server: Optional[Tuple[str, int]], send: Callable[[Event], Awaitable[None]], alpn_protocol: Optional[str] = None, + transport=None, ) -> None: self.app = app self.config = config @@ -31,6 +32,7 @@ def __init__( self.server = server self.send = send self.protocol: Union[H11Protocol, H2Protocol] + self.transport = transport if alpn_protocol == "h2": self.protocol = H2Protocol( self.app, @@ -41,6 +43,7 @@ def __init__( self.client, self.server, self.send, + self.transport, ) else: self.protocol = H11Protocol( @@ -52,6 +55,7 @@ def __init__( self.client, self.server, self.send, + self.transport, ) async def initiate(self) -> None: diff --git a/src/hypercorn/protocol/h11.py b/src/hypercorn/protocol/h11.py index ca7eea4d..90cb781a 100755 --- a/src/hypercorn/protocol/h11.py +++ b/src/hypercorn/protocol/h11.py @@ -88,6 +88,7 @@ def __init__( client: Optional[Tuple[str, int]], server: Optional[Tuple[str, int]], send: Callable[[Event], Awaitable[None]], + transport=None, ) -> None: self.app = app self.can_read = context.event_class() @@ -103,6 +104,7 @@ def __init__( self.tls = tls self.stream: Optional[Union[HTTPStream, WSStream]] = None self.task_group = task_group + self.transport = transport async def initiate(self) -> None: pass @@ -219,6 +221,7 @@ async def _create_stream(self, request: h11.Request) -> None: self.server, self.stream_send, STREAM_ID, + self.transport, ) if self.config.h11_pass_raw_headers: diff --git a/src/hypercorn/protocol/h2.py b/src/hypercorn/protocol/h2.py index 8b15f5e5..a5b1c692 100755 --- a/src/hypercorn/protocol/h2.py +++ b/src/hypercorn/protocol/h2.py @@ -88,6 +88,7 @@ def __init__( client: Optional[Tuple[str, int]], server: Optional[Tuple[str, int]], send: Callable[[Event], Awaitable[None]], + transport=None, ) -> None: self.app = app self.client = client @@ -118,6 +119,7 @@ def __init__( self.has_data = self.context.event_class() self.priority = priority.PriorityTree() self.stream_buffers: Dict[int, StreamBuffer] = {} + self.transport = transport @property def idle(self) -> bool: diff --git a/src/hypercorn/protocol/http_stream.py b/src/hypercorn/protocol/http_stream.py index ed235813..c4b70773 100644 --- a/src/hypercorn/protocol/http_stream.py +++ b/src/hypercorn/protocol/http_stream.py @@ -47,6 +47,7 @@ def __init__( server: Optional[Tuple[str, int]], send: Callable[[Event], Awaitable[None]], stream_id: int, + transport=None, ) -> None: self.app = app self.client = client @@ -63,6 +64,7 @@ def __init__( self.state = ASGIHTTPState.REQUEST self.stream_id = stream_id self.task_group = task_group + self.transport = transport @property def idle(self) -> bool: @@ -98,6 +100,9 @@ async def handle(self, event: Event) -> None: if self.tls is not None: self.scope["extensions"]["tls"] = self.tls + if self.transport is not None: + self.scope["extensions"]["_transport"] = self.transport + if valid_server_name(self.config, event): self.app_put = await self.task_group.spawn_app( self.app, self.config, self.scope, self.app_send diff --git a/src/hypercorn/trio/tcp_server.py b/src/hypercorn/trio/tcp_server.py index 8819a1bd..e723f507 100644 --- a/src/hypercorn/trio/tcp_server.py +++ b/src/hypercorn/trio/tcp_server.py @@ -71,6 +71,7 @@ async def run(self) -> None: server, self.protocol_send, alpn_protocol, + self.stream, ) await self.protocol.initiate() await self._start_idle() From 01c4bdfce2aaf18de199f983eda1ed9614088439 Mon Sep 17 00:00:00 2001 From: Quentin Pradet Date: Tue, 5 Dec 2023 23:40:08 +0400 Subject: [PATCH 21/21] Fail loudly in case of h11/h2 errors This makes debugging much easier than the current silent behavior. --- src/hypercorn/protocol/h11.py | 1 + src/hypercorn/protocol/h2.py | 1 + 2 files changed, 2 insertions(+) diff --git a/src/hypercorn/protocol/h11.py b/src/hypercorn/protocol/h11.py index 90cb781a..464c2caf 100755 --- a/src/hypercorn/protocol/h11.py +++ b/src/hypercorn/protocol/h11.py @@ -161,6 +161,7 @@ async def _handle_events(self) -> None: if self.connection.our_state in {h11.IDLE, h11.SEND_RESPONSE}: await self._send_error_response(error.error_status_hint) await self.send(Closed()) + raise break else: if isinstance(event, h11.Request): diff --git a/src/hypercorn/protocol/h2.py b/src/hypercorn/protocol/h2.py index a5b1c692..8821bf40 100755 --- a/src/hypercorn/protocol/h2.py +++ b/src/hypercorn/protocol/h2.py @@ -187,6 +187,7 @@ async def handle(self, event: Event) -> None: except h2.exceptions.ProtocolError: await self._flush() await self.send(Closed()) + raise else: await self._handle_events(events) elif isinstance(event, Closed):