diff --git a/.gitignore b/.gitignore index 8ef153fe..c436bdd0 100644 --- a/.gitignore +++ b/.gitignore @@ -13,3 +13,5 @@ docs/reference/source/ dist/ .coverage poetry.lock +.idea/ +.DS_Store diff --git a/CHANGELOG.rst b/CHANGELOG.rst index e18836c1..a68991e3 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -1,3 +1,24 @@ +0.16.0 2023-01-01 +----------------- + +* Add a max keep alive requests configuration option, this mitigates + the HTTP/2 rapid reset attack. +* Return subprocess exit code if non-zero. +* Add ProxyFix middleware to make it easier to run Hypercorn behind a + proxy. +* Support restarting workers after max requests to make it easier to + manage memory leaks in apps. +* Bugfix ensure the idle task is stopped on error. +* Bugfix revert autoreload error because reausing old sockets. +* Bugfix send the hinted error from h11 on RemoteProtocolErrors. +* Bugfix handle asyncio.CancelledError when socket is closed without + flushing. +* Bugfix improve WSGI compliance by closing iterators, only sending + headers on first response byte, erroring if ``start_response`` is + not called, and switching wsgi.errors to stdout. +* Don't error on LocalProtoclErrors for ws streams to better cope with + race conditions. + 0.15.0 2023-10-29 ----------------- diff --git a/docs/discussion/dos_mitigations.rst b/docs/discussion/dos_mitigations.rst index 358ba987..88cc48bb 100644 --- a/docs/discussion/dos_mitigations.rst +++ b/docs/discussion/dos_mitigations.rst @@ -169,3 +169,14 @@ data that it cannot send to the client. To mitigate this Hypercorn responds to the backpressure and pauses (blocks) the coroutine writing the response. + +Rapid reset +^^^^^^^^^^^ + +This attack works by opening and closing streams in quick succession +in the expectation that this is more costly for the server than the +client. + +To mitigate Hypercorn will only allow a maximum number of requests per +kept-alive connection before closing it. This ensures that cost of the +attack is equally born by the client. diff --git a/docs/how_to_guides/configuring.rst b/docs/how_to_guides/configuring.rst index 26607ba7..d72e4ed7 100644 --- a/docs/how_to_guides/configuring.rst +++ b/docs/how_to_guides/configuring.rst @@ -124,6 +124,8 @@ insecure_bind ``--insecure-bind`` The TCP host/address to See *bind* for formatting options. Care must be taken! See HTTP -> HTTPS redirection docs. +keep_alive_max_requests N/A Maximum number of requests before connection 1000 + is closed. HTTP/1 & HTTP/2 only. keep_alive_timeout ``--keep-alive`` Seconds to keep inactive connections alive 5s before closing. keyfile ``--keyfile`` Path to the SSL key file. diff --git a/docs/how_to_guides/index.rst b/docs/how_to_guides/index.rst index bccdd541..9f4bf2d7 100644 --- a/docs/how_to_guides/index.rst +++ b/docs/how_to_guides/index.rst @@ -11,6 +11,7 @@ How to guides dispatch_apps.rst http_https_redirect.rst logging.rst + proxy_fix.rst server_names.rst statsd.rst wsgi_apps.rst diff --git a/docs/how_to_guides/proxy_fix.rst b/docs/how_to_guides/proxy_fix.rst new file mode 100644 index 00000000..dd8d080f --- /dev/null +++ b/docs/how_to_guides/proxy_fix.rst @@ -0,0 +1,33 @@ +Fixing proxy headers +==================== + +If you are serving Hypercorn behind a proxy e.g. a load balancer the +client-address, scheme, and host-header will match that of the +connection between the proxy and Hypercorn rather than the user-agent +(client). However, most proxies provide headers with the original +user-agent (client) values which can be used to "fix" the headers to +these values. + +Modern proxies should provide this information via a ``Forwarded`` +header from `RFC 7239 +`_. However, this is +rare in practice with legacy proxies using a combination of +``X-Forwarded-For``, ``X-Forwarded-Proto`` and +``X-Forwarded-Host``. It is important that you chose the correct mode +(legacy, or modern) based on the proxy you use. + +To use the proxy fix middleware behind a single legacy proxy simply +wrap your app and serve the wrapped app, + +.. code-block:: python + + from hypercorn.middleware import ProxyFixMiddleware + + fixed_app = ProxyFixMiddleware(app, mode="legacy", trusted_hops=1) + +.. warning:: + + The mode and number of trusted hops must match your setup or the + user-agent (client) may be trusted and hence able to set + alternative for, proto, and host values. This can, depending on + your usage in the app, lead to security vulnerabilities. diff --git a/docs/how_to_guides/wsgi_apps.rst b/docs/how_to_guides/wsgi_apps.rst index 396d8909..df6b72bd 100644 --- a/docs/how_to_guides/wsgi_apps.rst +++ b/docs/how_to_guides/wsgi_apps.rst @@ -9,12 +9,6 @@ Hypercorn directly serves WSGI applications: $ hypercorn module:wsgi_app -.. warning:: - - The full response from the WSGI app will be stored in memory - before being sent. This prevents the WSGI app from streaming a - response. - WSGI Middleware --------------- diff --git a/pyproject.toml b/pyproject.toml index a620dfa9..37d199df 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "Hypercorn" -version = "0.15.0" +version = "0.16.0" description = "A ASGI Server based on Hyper libraries and inspired by Gunicorn" authors = ["pgjones "] classifiers = [ diff --git a/src/hypercorn/__main__.py b/src/hypercorn/__main__.py index b3dc0e80..aed33b12 100644 --- a/src/hypercorn/__main__.py +++ b/src/hypercorn/__main__.py @@ -23,7 +23,7 @@ def _load_config(config_path: Optional[str]) -> Config: return Config.from_toml(config_path) -def main(sys_args: Optional[List[str]] = None) -> None: +def main(sys_args: Optional[List[str]] = None) -> int: parser = argparse.ArgumentParser() parser.add_argument( "application", help="The application to dispatch to as path.to.module:instance.path" @@ -89,6 +89,19 @@ def main(sys_args: Optional[List[str]] = None) -> None: default=sentinel, type=int, ) + parser.add_argument( + "--max-requests", + help="""Maximum number of requests a worker will process before restarting""", + default=sentinel, + type=int, + ) + parser.add_argument( + "--max-requests-jitter", + help="This jitter causes the max-requests per worker to be " + "randomized by randint(0, max_requests_jitter)", + default=sentinel, + type=int, + ) parser.add_argument( "-g", "--group", help="Group to own any unix sockets.", default=sentinel, type=int ) @@ -252,6 +265,10 @@ def _convert_verify_mode(value: str) -> ssl.VerifyMode: config.keyfile_password = args.keyfile_password if args.log_config is not sentinel: config.logconfig = args.log_config + if args.max_requests is not sentinel: + config.max_requests = args.max_requests + if args.max_requests_jitter is not sentinel: + config.max_requests_jitter = args.max_requests if args.pid is not sentinel: config.pid_path = args.pid if args.root_path is not sentinel: @@ -284,8 +301,8 @@ def _convert_verify_mode(value: str) -> ssl.VerifyMode: if len(args.server_names) > 0: config.server_names = args.server_names - run(config) + return run(config) if __name__ == "__main__": - main() + sys.exit(main()) diff --git a/src/hypercorn/app_wrappers.py b/src/hypercorn/app_wrappers.py index 769e014b..633abb1f 100644 --- a/src/hypercorn/app_wrappers.py +++ b/src/hypercorn/app_wrappers.py @@ -1,5 +1,6 @@ from __future__ import annotations +import sys from functools import partial from io import BytesIO from typing import Callable, List, Optional, Tuple @@ -84,6 +85,8 @@ async def handle_http( def run_app(self, environ: dict, send: Callable) -> None: headers: List[Tuple[bytes, bytes]] + headers_sent = False + response_started = False status_code: Optional[int] = None def start_response( @@ -91,7 +94,7 @@ def start_response( response_headers: List[Tuple[str, str]], exc_info: Optional[Exception] = None, ) -> None: - nonlocal headers, status_code + nonlocal headers, response_started, status_code raw, _ = status.split(" ", 1) status_code = int(raw) @@ -99,10 +102,23 @@ def start_response( (name.lower().encode("ascii"), value.encode("ascii")) for name, value in response_headers ] - send({"type": "http.response.start", "status": status_code, "headers": headers}) + response_started = True - for output in self.app(environ, start_response): - send({"type": "http.response.body", "body": output, "more_body": True}) + response_body = self.app(environ, start_response) + + if not response_started: + raise RuntimeError("WSGI app did not call start_response") + + try: + for output in response_body: + if not headers_sent: + send({"type": "http.response.start", "status": status_code, "headers": headers}) + headers_sent = True + + send({"type": "http.response.body", "body": output, "more_body": True}) + finally: + if hasattr(response_body, "close"): + response_body.close() def _build_environ(scope: HTTPScope, body: bytes) -> dict: @@ -126,7 +142,7 @@ def _build_environ(scope: HTTPScope, body: bytes) -> dict: "wsgi.version": (1, 0), "wsgi.url_scheme": scope.get("scheme", "http"), "wsgi.input": BytesIO(body), - "wsgi.errors": BytesIO(), + "wsgi.errors": sys.stdout, "wsgi.multithread": True, "wsgi.multiprocess": True, "wsgi.run_once": False, diff --git a/src/hypercorn/asyncio/run.py b/src/hypercorn/asyncio/run.py index 47745383..c633c5bd 100644 --- a/src/hypercorn/asyncio/run.py +++ b/src/hypercorn/asyncio/run.py @@ -4,9 +4,11 @@ import platform import signal import ssl +import sys from functools import partial from multiprocessing.synchronize import Event as EventType from os import getpid +from random import randint from socket import socket from typing import Any, Awaitable, Callable, Optional, Set @@ -30,6 +32,14 @@ except ImportError: from taskgroup import Runner # type: ignore +try: + from asyncio import TaskGroup +except ImportError: + from taskgroup import TaskGroup # type: ignore + +if sys.version_info < (3, 11): + from exceptiongroup import BaseExceptionGroup + def _share_socket(sock: socket) -> socket: # Windows requires the socket be explicitly shared across @@ -65,7 +75,7 @@ def _signal_handler(*_: Any) -> None: # noqa: N803 # Add signal handler may not be implemented on Windows signal.signal(getattr(signal, signal_name), _signal_handler) - shutdown_trigger = signal_event.wait # type: ignore + shutdown_trigger = signal_event.wait lifespan = Lifespan(app, config, loop) @@ -84,7 +94,10 @@ def _signal_handler(*_: Any) -> None: # noqa: N803 ssl_context = config.create_ssl_context() ssl_handshake_timeout = config.ssl_handshake_timeout - context = WorkerContext() + max_requests = None + if config.max_requests is not None: + max_requests = config.max_requests + randint(0, config.max_requests_jitter) + context = WorkerContext(max_requests) server_tasks: Set[asyncio.Task] = set() async def _server_callback(reader: asyncio.StreamReader, writer: asyncio.StreamWriter) -> None: @@ -136,7 +149,13 @@ async def _server_callback(reader: asyncio.StreamReader, writer: asyncio.StreamW await config.log.info(f"Running on https://{bind} (QUIC) (CTRL + C to quit)") try: - await raise_shutdown(shutdown_trigger) + async with TaskGroup() as task_group: + task_group.create_task(raise_shutdown(shutdown_trigger)) + task_group.create_task(raise_shutdown(context.terminate.wait)) + except BaseExceptionGroup as error: + _, other_errors = error.split((ShutdownError, KeyboardInterrupt)) + if other_errors is not None: + raise other_errors except (ShutdownError, KeyboardInterrupt): pass finally: diff --git a/src/hypercorn/asyncio/tcp_server.py b/src/hypercorn/asyncio/tcp_server.py index 025ec0a0..ce8e8147 100644 --- a/src/hypercorn/asyncio/tcp_server.py +++ b/src/hypercorn/asyncio/tcp_server.py @@ -47,10 +47,10 @@ async def run(self) -> None: server = parse_socket_addr(socket.family, socket.getsockname()) ssl_object = self.writer.get_extra_info("ssl_object") if ssl_object is not None: - ssl = True + tls = {} alpn_protocol = ssl_object.selected_alpn_protocol() else: - ssl = False + tls = None alpn_protocol = "http/1.1" async with TaskGroup(self.loop) as task_group: @@ -59,11 +59,12 @@ async def run(self) -> None: self.config, self.context, task_group, - ssl, + tls, client, server, self.protocol_send, alpn_protocol, + (self.reader, self.writer), ) await self.protocol.initiate() await self._start_idle() @@ -115,10 +116,16 @@ async def _close(self) -> None: try: self.writer.close() await self.writer.wait_closed() - except (BrokenPipeError, ConnectionAbortedError, ConnectionResetError, RuntimeError): + except ( + BrokenPipeError, + ConnectionAbortedError, + ConnectionResetError, + RuntimeError, + asyncio.CancelledError, + ): pass # Already closed - - await self._stop_idle() + finally: + await self._stop_idle() async def _initiate_server_close(self) -> None: await self.protocol.handle(Closed()) diff --git a/src/hypercorn/asyncio/worker_context.py b/src/hypercorn/asyncio/worker_context.py index fe9ad1c7..d16f76ba 100644 --- a/src/hypercorn/asyncio/worker_context.py +++ b/src/hypercorn/asyncio/worker_context.py @@ -1,7 +1,7 @@ from __future__ import annotations import asyncio -from typing import Type, Union +from typing import Optional, Type, Union from ..typing import Event @@ -26,9 +26,20 @@ def is_set(self) -> bool: class WorkerContext: event_class: Type[Event] = EventWrapper - def __init__(self) -> None: + def __init__(self, max_requests: Optional[int]) -> None: + self.max_requests = max_requests + self.requests = 0 + self.terminate = self.event_class() self.terminated = self.event_class() + async def mark_request(self) -> None: + if self.max_requests is None: + return + + self.requests += 1 + if self.requests > self.max_requests: + await self.terminate.set() + @staticmethod async def sleep(wait: Union[float, int]) -> None: return await asyncio.sleep(wait) diff --git a/src/hypercorn/config.py b/src/hypercorn/config.py index 26f50f00..f00c7d5e 100644 --- a/src/hypercorn/config.py +++ b/src/hypercorn/config.py @@ -84,6 +84,7 @@ class Config: include_date_header = True include_server_header = True keep_alive_timeout = 5 * SECONDS + keep_alive_max_requests = 1000 keyfile: Optional[str] = None keyfile_password: Optional[str] = None logconfig: Optional[str] = None @@ -91,6 +92,8 @@ class Config: logger_class = Logger loglevel: str = "INFO" max_app_queue_size: int = 10 + max_requests: Optional[int] = None + max_requests_jitter: int = 0 pid_path: Optional[str] = None server_names: List[str] = [] shutdown_timeout = 60 * SECONDS diff --git a/src/hypercorn/middleware/__init__.py b/src/hypercorn/middleware/__init__.py index 83ea29c7..e7f017c1 100644 --- a/src/hypercorn/middleware/__init__.py +++ b/src/hypercorn/middleware/__init__.py @@ -2,11 +2,13 @@ from .dispatcher import DispatcherMiddleware from .http_to_https import HTTPToHTTPSRedirectMiddleware +from .proxy_fix import ProxyFixMiddleware from .wsgi import AsyncioWSGIMiddleware, TrioWSGIMiddleware __all__ = ( "AsyncioWSGIMiddleware", "DispatcherMiddleware", "HTTPToHTTPSRedirectMiddleware", + "ProxyFixMiddleware", "TrioWSGIMiddleware", ) diff --git a/src/hypercorn/middleware/proxy_fix.py b/src/hypercorn/middleware/proxy_fix.py new file mode 100644 index 00000000..509941cf --- /dev/null +++ b/src/hypercorn/middleware/proxy_fix.py @@ -0,0 +1,78 @@ +from __future__ import annotations + +from copy import deepcopy +from typing import Callable, Iterable, Literal, Optional, Tuple + +from ..typing import ASGIFramework, Scope + + +class ProxyFixMiddleware: + def __init__( + self, + app: ASGIFramework, + mode: Literal["legacy", "modern"] = "legacy", + trusted_hops: int = 1, + ) -> None: + self.app = app + self.mode = mode + self.trusted_hops = trusted_hops + + async def __call__(self, scope: Scope, receive: Callable, send: Callable) -> None: + if scope["type"] in {"http", "websocket"}: + scope = deepcopy(scope) + headers = scope["headers"] # type: ignore + client: Optional[str] = None + scheme: Optional[str] = None + host: Optional[str] = None + + if ( + self.mode == "modern" + and (value := _get_trusted_value(b"forwarded", headers, self.trusted_hops)) + is not None + ): + for part in value.split(";"): + if part.startswith("for="): + client = part[4:].strip() + elif part.startswith("host="): + host = part[5:].strip() + elif part.startswith("proto="): + scheme = part[6:].strip() + + else: + client = _get_trusted_value(b"x-forwarded-for", headers, self.trusted_hops) + scheme = _get_trusted_value(b"x-forwarded-proto", headers, self.trusted_hops) + host = _get_trusted_value(b"x-forwarded-host", headers, self.trusted_hops) + + if client is not None: + scope["client"] = (client, 0) # type: ignore + + if scheme is not None: + scope["scheme"] = scheme # type: ignore + + if host is not None: + headers = [ + (name, header_value) + for name, header_value in headers + if name.lower() != b"host" + ] + headers.append((b"host", host)) + scope["headers"] = headers # type: ignore + + await self.app(scope, receive, send) + + +def _get_trusted_value( + name: bytes, headers: Iterable[Tuple[bytes, bytes]], trusted_hops: int +) -> Optional[str]: + if trusted_hops == 0: + return None + + values = [] + for header_name, header_value in headers: + if header_name.lower() == name: + values.extend([value.decode("latin1").strip() for value in header_value.split(b",")]) + + if len(values) >= trusted_hops: + return values[-trusted_hops] + + return None diff --git a/src/hypercorn/protocol/__init__.py b/src/hypercorn/protocol/__init__.py index 39385681..e047da9d 100755 --- a/src/hypercorn/protocol/__init__.py +++ b/src/hypercorn/protocol/__init__.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import Awaitable, Callable, Optional, Tuple, Union +from typing import Any, Awaitable, Callable, Optional, Tuple, Union from .h2 import H2Protocol from .h11 import H2CProtocolRequiredError, H2ProtocolAssumedError, H11Protocol @@ -16,31 +16,34 @@ def __init__( config: Config, context: WorkerContext, task_group: TaskGroup, - ssl: bool, + tls: Optional[dict[str, Any]], client: Optional[Tuple[str, int]], server: Optional[Tuple[str, int]], send: Callable[[Event], Awaitable[None]], alpn_protocol: Optional[str] = None, + transport=None, ) -> None: self.app = app self.config = config self.context = context self.task_group = task_group - self.ssl = ssl + self.tls = tls self.client = client self.server = server self.send = send self.protocol: Union[H11Protocol, H2Protocol] + self.transport = transport if alpn_protocol == "h2": self.protocol = H2Protocol( self.app, self.config, self.context, self.task_group, - self.ssl, + self.tls, self.client, self.server, self.send, + self.transport, ) else: self.protocol = H11Protocol( @@ -48,10 +51,11 @@ def __init__( self.config, self.context, self.task_group, - self.ssl, + self.tls, self.client, self.server, self.send, + self.transport, ) async def initiate(self) -> None: @@ -66,7 +70,7 @@ async def handle(self, event: Event) -> None: self.config, self.context, self.task_group, - self.ssl, + self.tls, self.client, self.server, self.send, @@ -80,7 +84,7 @@ async def handle(self, event: Event) -> None: self.config, self.context, self.task_group, - self.ssl, + self.tls, self.client, self.server, self.send, diff --git a/src/hypercorn/protocol/h11.py b/src/hypercorn/protocol/h11.py index 49f8b179..464c2caf 100755 --- a/src/hypercorn/protocol/h11.py +++ b/src/hypercorn/protocol/h11.py @@ -1,7 +1,7 @@ from __future__ import annotations from itertools import chain -from typing import Awaitable, Callable, cast, Optional, Tuple, Type, Union +from typing import Any, Awaitable, Callable, cast, Optional, Tuple, Type, Union import h11 @@ -84,10 +84,11 @@ def __init__( config: Config, context: WorkerContext, task_group: TaskGroup, - ssl: bool, + tls: Optional[dict[str, Any]], client: Optional[Tuple[str, int]], server: Optional[Tuple[str, int]], send: Callable[[Event], Awaitable[None]], + transport=None, ) -> None: self.app = app self.can_read = context.event_class() @@ -97,11 +98,13 @@ def __init__( h11.SERVER, max_incomplete_event_size=self.config.h11_max_incomplete_size ) self.context = context + self.keep_alive_requests = 0 self.send = send self.server = server - self.ssl = ssl + self.tls = tls self.stream: Optional[Union[HTTPStream, WSStream]] = None self.task_group = task_group + self.transport = transport async def initiate(self) -> None: pass @@ -154,10 +157,11 @@ async def _handle_events(self) -> None: try: event = self.connection.next_event() - except h11.RemoteProtocolError: + except h11.RemoteProtocolError as error: if self.connection.our_state in {h11.IDLE, h11.SEND_RESPONSE}: - await self._send_error_response(400) + await self._send_error_response(error.error_status_hint) await self.send(Closed()) + raise break else: if isinstance(event, h11.Request): @@ -200,7 +204,7 @@ async def _create_stream(self, request: h11.Request) -> None: self.config, self.context, self.task_group, - self.ssl, + self.tls, self.client, self.server, self.stream_send, @@ -213,11 +217,12 @@ async def _create_stream(self, request: h11.Request) -> None: self.config, self.context, self.task_group, - self.ssl, + self.tls, self.client, self.server, self.stream_send, STREAM_ID, + self.transport, ) if self.config.h11_pass_raw_headers: @@ -234,6 +239,8 @@ async def _create_stream(self, request: h11.Request) -> None: raw_path=request.target, ) ) + self.keep_alive_requests += 1 + await self.context.mark_request() async def _send_h11_event(self, event: H11SendableEvent) -> None: try: @@ -264,6 +271,7 @@ async def _maybe_recycle(self) -> None: not self.context.terminated.is_set() and self.connection.our_state is h11.DONE and self.connection.their_state is h11.DONE + and self.keep_alive_requests <= self.config.keep_alive_max_requests ): try: self.connection.start_next_cycle() diff --git a/src/hypercorn/protocol/h2.py b/src/hypercorn/protocol/h2.py index 6e76d493..8821bf40 100755 --- a/src/hypercorn/protocol/h2.py +++ b/src/hypercorn/protocol/h2.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import Awaitable, Callable, Dict, List, Optional, Tuple, Type, Union +from typing import Any, Awaitable, Callable, Dict, List, Optional, Tuple, Type, Union import h2 import h2.connection @@ -84,10 +84,11 @@ def __init__( config: Config, context: WorkerContext, task_group: TaskGroup, - ssl: bool, + tls: Optional[dict[str, Any]], client: Optional[Tuple[str, int]], server: Optional[Tuple[str, int]], send: Callable[[Event], Awaitable[None]], + transport=None, ) -> None: self.app = app self.client = client @@ -109,14 +110,16 @@ def __init__( }, ) + self.keep_alive_requests = 0 self.send = send self.server = server - self.ssl = ssl + self.tls = tls self.streams: Dict[int, Union[HTTPStream, WSStream]] = {} # The below are used by the sending task self.has_data = self.context.event_class() self.priority = priority.PriorityTree() self.stream_buffers: Dict[int, StreamBuffer] = {} + self.transport = transport @property def idle(self) -> bool: @@ -184,6 +187,7 @@ async def handle(self, event: Event) -> None: except h2.exceptions.ProtocolError: await self._flush() await self.send(Closed()) + raise else: await self._handle_events(events) elif isinstance(event, Closed): @@ -244,6 +248,9 @@ async def _handle_events(self, events: List[h2.events.Event]) -> None: else: await self._create_stream(event) await self.send(Updated(idle=False)) + + if self.keep_alive_requests > self.config.keep_alive_max_requests: + self.connection.close_connection() elif isinstance(event, h2.events.DataReceived): await self.streams[event.stream_id].handle( Body(stream_id=event.stream_id, data=event.data) @@ -313,7 +320,7 @@ async def _create_stream(self, request: h2.events.RequestReceived) -> None: self.config, self.context, self.task_group, - self.ssl, + self.tls, self.client, self.server, self.stream_send, @@ -325,7 +332,7 @@ async def _create_stream(self, request: h2.events.RequestReceived) -> None: self.config, self.context, self.task_group, - self.ssl, + self.tls, self.client, self.server, self.stream_send, @@ -349,6 +356,8 @@ async def _create_stream(self, request: h2.events.RequestReceived) -> None: raw_path=raw_path, ) ) + self.keep_alive_requests += 1 + await self.context.mark_request() async def _create_server_push( self, stream_id: int, path: bytes, headers: List[Tuple[bytes, bytes]] @@ -374,6 +383,7 @@ async def _create_server_push( event.headers = request_headers await self._create_stream(event) await self.streams[event.stream_id].handle(EndBody(stream_id=event.stream_id)) + self.keep_alive_requests += 1 async def _close_stream(self, stream_id: int) -> None: if stream_id in self.streams: diff --git a/src/hypercorn/protocol/h3.py b/src/hypercorn/protocol/h3.py index 88d9a4d3..151c0667 100644 --- a/src/hypercorn/protocol/h3.py +++ b/src/hypercorn/protocol/h3.py @@ -125,6 +125,7 @@ async def _create_stream(self, request: HeadersReceived) -> None: raw_path=raw_path, ) ) + await self.context.mark_request() async def _create_server_push( self, stream_id: int, path: bytes, headers: List[Tuple[bytes, bytes]] diff --git a/src/hypercorn/protocol/http_stream.py b/src/hypercorn/protocol/http_stream.py index d244e7c3..c4b70773 100644 --- a/src/hypercorn/protocol/http_stream.py +++ b/src/hypercorn/protocol/http_stream.py @@ -2,7 +2,7 @@ from enum import auto, Enum from time import time -from typing import Awaitable, Callable, Optional, Tuple +from typing import Any, Awaitable, Callable, Optional, Tuple from urllib.parse import unquote from .events import Body, EndBody, Event, InformationalResponse, Request, Response, StreamClosed @@ -42,11 +42,12 @@ def __init__( config: Config, context: WorkerContext, task_group: TaskGroup, - ssl: bool, + tls: Optional[dict[str, Any]], client: Optional[Tuple[str, int]], server: Optional[Tuple[str, int]], send: Callable[[Event], Awaitable[None]], stream_id: int, + transport=None, ) -> None: self.app = app self.client = client @@ -56,12 +57,14 @@ def __init__( self.response: HTTPResponseStartEvent self.scope: HTTPScope self.send = send - self.scheme = "https" if ssl else "http" + self.scheme = "https" if tls is not None else "http" + self.tls = tls self.server = server self.start_time: float self.state = ASGIHTTPState.REQUEST self.stream_id = stream_id self.task_group = task_group + self.transport = transport @property def idle(self) -> bool: @@ -94,6 +97,12 @@ async def handle(self, event: Event) -> None: if event.http_version in EARLY_HINTS_VERSIONS: self.scope["extensions"]["http.response.early_hint"] = {} + if self.tls is not None: + self.scope["extensions"]["tls"] = self.tls + + if self.transport is not None: + self.scope["extensions"]["_transport"] = self.transport + if valid_server_name(self.config, event): self.app_put = await self.task_group.spawn_app( self.app, self.config, self.scope, self.app_send diff --git a/src/hypercorn/protocol/ws_stream.py b/src/hypercorn/protocol/ws_stream.py index 9011999e..b0b99c89 100644 --- a/src/hypercorn/protocol/ws_stream.py +++ b/src/hypercorn/protocol/ws_stream.py @@ -3,7 +3,7 @@ from enum import auto, Enum from io import BytesIO, StringIO from time import time -from typing import Awaitable, Callable, Iterable, List, Optional, Tuple, Union +from typing import Any, Awaitable, Callable, Iterable, List, Optional, Tuple, Union from urllib.parse import unquote from wsproto.connection import Connection, ConnectionState, ConnectionType @@ -18,7 +18,7 @@ from wsproto.extensions import Extension, PerMessageDeflate from wsproto.frame_protocol import CloseReason from wsproto.handshake import server_extensions_handshake, WEBSOCKET_VERSION -from wsproto.utilities import generate_accept_token, split_comma_header +from wsproto.utilities import generate_accept_token, LocalProtocolError, split_comma_header from .events import Body, Data, EndBody, EndData, Event, Request, Response, StreamClosed from ..config import Config @@ -167,7 +167,7 @@ def __init__( config: Config, context: WorkerContext, task_group: TaskGroup, - ssl: bool, + tls: Optional[dict[str, Any]], client: Optional[Tuple[str, int]], server: Optional[Tuple[str, int]], send: Callable[[Event], Awaitable[None]], @@ -185,7 +185,7 @@ def __init__( self.scope: WebsocketScope self.send = send # RFC 8441 for HTTP/2 says use http or https, ASGI says ws or wss - self.scheme = "wss" if ssl else "ws" + self.scheme = "wss" if tls is not None else "ws" self.server = server self.start_time: float self.state = ASGIWebsocketState.HANDSHAKE @@ -333,8 +333,12 @@ async def _send_error_response(self, status_code: int) -> None: ) async def _send_wsproto_event(self, event: WSProtoEvent) -> None: - data = self.connection.send(event) - await self.send(Data(stream_id=self.stream_id, data=data)) + try: + data = self.connection.send(event) + except LocalProtocolError: + pass + else: + await self.send(Data(stream_id=self.stream_id, data=data)) async def _accept(self, message: WebsocketAcceptEvent) -> None: self.state = ASGIWebsocketState.CONNECTED diff --git a/src/hypercorn/run.py b/src/hypercorn/run.py index a6d2fb08..cfe801aa 100644 --- a/src/hypercorn/run.py +++ b/src/hypercorn/run.py @@ -4,6 +4,7 @@ import signal import time from multiprocessing import get_context +from multiprocessing.connection import wait from multiprocessing.context import BaseContext from multiprocessing.process import BaseProcess from multiprocessing.synchronize import Event as EventType @@ -12,10 +13,10 @@ from .config import Config, Sockets from .typing import WorkerFunc -from .utils import load_application, wait_for_changes, write_pid_file +from .utils import check_for_updates, files_to_watch, load_application, write_pid_file -def run(config: Config) -> None: +def run(config: Config) -> int: if config.pid_path is not None: write_pid_file(config.pid_path) @@ -40,64 +41,82 @@ def run(config: Config) -> None: if config.use_reloader and config.workers == 0: raise RuntimeError("Cannot reload without workers") - if config.use_reloader or config.workers == 0: - # Load the application so that the correct paths are checked for - # changes, but only when the reloader is being used. - load_application(config.application_path, config.wsgi_max_body_size) - + exitcode = 0 if config.workers == 0: worker_func(config, sockets) else: + if config.use_reloader: + # Load the application so that the correct paths are checked for + # changes, but only when the reloader is being used. + load_application(config.application_path, config.wsgi_max_body_size) + ctx = get_context("spawn") active = True + shutdown_event = ctx.Event() + + def shutdown(*args: Any) -> None: + nonlocal active, shutdown_event + shutdown_event.set() + active = False + + processes: List[BaseProcess] = [] while active: # Ignore SIGINT before creating the processes, so that they # inherit the signal handling. This means that the shutdown # function controls the shutdown. signal.signal(signal.SIGINT, signal.SIG_IGN) - shutdown_event = ctx.Event() - processes = start_processes(config, worker_func, sockets, shutdown_event, ctx) - - def shutdown(*args: Any) -> None: - nonlocal active, shutdown_event - shutdown_event.set() - active = False + _populate(processes, config, worker_func, sockets, shutdown_event, ctx) for signal_name in {"SIGINT", "SIGTERM", "SIGBREAK"}: if hasattr(signal, signal_name): signal.signal(getattr(signal, signal_name), shutdown) if config.use_reloader: - wait_for_changes(shutdown_event) - shutdown_event.set() - # Recreate the sockets to be used again in the next - # iteration of the loop. - sockets = config.create_sockets() + files = files_to_watch() + while True: + finished = wait((process.sentinel for process in processes), timeout=1) + updated = check_for_updates(files) + if updated: + shutdown_event.set() + for process in processes: + process.join() + shutdown_event.clear() + break + if len(finished) > 0: + break else: + wait(process.sentinel for process in processes) + + exitcode = _join_exited(processes) + if exitcode != 0: + shutdown_event.set() active = False - for process in processes: - process.join() for process in processes: process.terminate() + exitcode = _join_exited(processes) if exitcode != 0 else exitcode + for sock in sockets.secure_sockets: sock.close() + for sock in sockets.insecure_sockets: sock.close() + return exitcode -def start_processes( + +def _populate( + processes: List[BaseProcess], config: Config, worker_func: WorkerFunc, sockets: Sockets, shutdown_event: EventType, ctx: BaseContext, -) -> List[BaseProcess]: - processes = [] - for _ in range(config.workers): +) -> None: + for _ in range(config.workers - len(processes)): process = ctx.Process( # type: ignore target=worker_func, kwargs={"config": config, "shutdown_event": shutdown_event, "sockets": sockets}, @@ -112,4 +131,15 @@ def start_processes( processes.append(process) if platform.system() == "Windows": time.sleep(0.1) - return processes + + +def _join_exited(processes: List[BaseProcess]) -> int: + exitcode = 0 + for index in reversed(range(len(processes))): + worker = processes[index] + if worker.exitcode is not None: + worker.join() + exitcode = worker.exitcode if exitcode == 0 else exitcode + del processes[index] + + return exitcode diff --git a/src/hypercorn/trio/run.py b/src/hypercorn/trio/run.py index d8721bbb..2cfe5db4 100644 --- a/src/hypercorn/trio/run.py +++ b/src/hypercorn/trio/run.py @@ -3,6 +3,7 @@ import sys from functools import partial from multiprocessing.synchronize import Event as EventType +from random import randint from typing import Awaitable, Callable, Optional import trio @@ -37,7 +38,10 @@ async def worker_serve( config.set_statsd_logger_class(StatsdLogger) lifespan = Lifespan(app, config) - context = WorkerContext() + max_requests = None + if config.max_requests is not None: + max_requests = config.max_requests + randint(0, config.max_requests_jitter) + context = WorkerContext(max_requests) async with trio.open_nursery() as lifespan_nursery: await lifespan_nursery.start(lifespan.handle_lifespan) @@ -82,6 +86,7 @@ async def worker_serve( async with trio.open_nursery(strict_exception_groups=True) as nursery: if shutdown_trigger is not None: nursery.start_soon(raise_shutdown, shutdown_trigger) + nursery.start_soon(raise_shutdown, context.terminate.wait) nursery.start_soon( partial( diff --git a/src/hypercorn/trio/tcp_server.py b/src/hypercorn/trio/tcp_server.py index dbcc7a12..e723f507 100644 --- a/src/hypercorn/trio/tcp_server.py +++ b/src/hypercorn/trio/tcp_server.py @@ -1,5 +1,6 @@ from __future__ import annotations +import ssl from math import inf from typing import Any, Generator, Optional @@ -42,11 +43,17 @@ async def run(self) -> None: return # Handshake failed alpn_protocol = self.stream.selected_alpn_protocol() socket = self.stream.transport_stream.socket - ssl = True + + tls = {"alpn_protocol": alpn_protocol} + client_certificate = self.stream.getpeercert(binary_form=False) + if client_certificate: + tls["client_cert_name"] = ", ".join( + [f"{part[0][0]}={part[0][1]}" for part in client_certificate["subject"]] + ) except AttributeError: # Not SSL alpn_protocol = "http/1.1" socket = self.stream.socket - ssl = False + tls = None try: client = parse_socket_addr(socket.family, socket.getpeername()) @@ -59,11 +66,12 @@ async def run(self) -> None: self.config, self.context, task_group, - ssl, + tls, client, server, self.protocol_send, alpn_protocol, + self.stream, ) await self.protocol.initiate() await self._start_idle() diff --git a/src/hypercorn/trio/worker_context.py b/src/hypercorn/trio/worker_context.py index bcfa1a51..c09c4fb6 100644 --- a/src/hypercorn/trio/worker_context.py +++ b/src/hypercorn/trio/worker_context.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import Type, Union +from typing import Optional, Type, Union import trio @@ -27,9 +27,20 @@ def is_set(self) -> bool: class WorkerContext: event_class: Type[Event] = EventWrapper - def __init__(self) -> None: + def __init__(self, max_requests: Optional[int]) -> None: + self.max_requests = max_requests + self.requests = 0 + self.terminate = self.event_class() self.terminated = self.event_class() + async def mark_request(self) -> None: + if self.max_requests is None: + return + + self.requests += 1 + if self.requests > self.max_requests: + await self.terminate.set() + @staticmethod async def sleep(wait: Union[float, int]) -> None: return await trio.sleep(wait) diff --git a/src/hypercorn/typing.py b/src/hypercorn/typing.py index 1299a776..2ebb711d 100644 --- a/src/hypercorn/typing.py +++ b/src/hypercorn/typing.py @@ -290,8 +290,12 @@ def is_set(self) -> bool: class WorkerContext(Protocol): event_class: Type[Event] + terminate: Event terminated: Event + async def mark_request(self) -> None: + ... + @staticmethod async def sleep(wait: Union[float, int]) -> None: ... diff --git a/src/hypercorn/utils.py b/src/hypercorn/utils.py index 5629ff71..39249c53 100644 --- a/src/hypercorn/utils.py +++ b/src/hypercorn/utils.py @@ -4,7 +4,6 @@ import os import socket import sys -import time from enum import Enum from importlib import import_module from multiprocessing.synchronize import Event as EventType @@ -133,7 +132,7 @@ def wrap_app( return WSGIWrapper(cast(WSGIFramework, app), wsgi_max_body_size) -def wait_for_changes(shutdown_event: EventType) -> None: +def files_to_watch() -> Dict[Path, float]: last_updates: Dict[Path, float] = {} for module in list(sys.modules.values()): filename = getattr(module, "__file__", None) @@ -144,24 +143,21 @@ def wait_for_changes(shutdown_event: EventType) -> None: last_updates[Path(filename)] = path.stat().st_mtime except (FileNotFoundError, NotADirectoryError): pass + return last_updates - while not shutdown_event.is_set(): - time.sleep(1) - for index, (path, last_mtime) in enumerate(last_updates.items()): - if index % 10 == 0: - # Yield to the event loop - time.sleep(0) - - try: - mtime = path.stat().st_mtime - except FileNotFoundError: - return +def check_for_updates(files: Dict[Path, float]) -> bool: + for path, last_mtime in files.items(): + try: + mtime = path.stat().st_mtime + except FileNotFoundError: + return True + else: + if mtime > last_mtime: + return True else: - if mtime > last_mtime: - return - else: - last_updates[path] = mtime + files[path] = mtime + return False async def raise_shutdown(shutdown_event: Callable[..., Awaitable]) -> None: @@ -185,7 +181,7 @@ def write_pid_file(pid_path: str) -> None: def parse_socket_addr(family: int, address: tuple) -> Optional[Tuple[str, int]]: if family == socket.AF_INET: - return address # type: ignore + return address elif family == socket.AF_INET6: return (address[0], address[1]) else: diff --git a/tests/asyncio/test_keep_alive.py b/tests/asyncio/test_keep_alive.py index 6b357f8f..a46f4cfd 100644 --- a/tests/asyncio/test_keep_alive.py +++ b/tests/asyncio/test_keep_alive.py @@ -50,7 +50,7 @@ async def _server(event_loop: asyncio.AbstractEventLoop) -> AsyncGenerator[TCPSe ASGIWrapper(slow_framework), event_loop, config, - WorkerContext(), + WorkerContext(None), MemoryReader(), # type: ignore MemoryWriter(), # type: ignore ) diff --git a/tests/asyncio/test_sanity.py b/tests/asyncio/test_sanity.py index 287cd06d..cde29297 100644 --- a/tests/asyncio/test_sanity.py +++ b/tests/asyncio/test_sanity.py @@ -21,7 +21,7 @@ async def test_http1_request(event_loop: asyncio.AbstractEventLoop) -> None: ASGIWrapper(sanity_framework), event_loop, Config(), - WorkerContext(), + WorkerContext(None), MemoryReader(), # type: ignore MemoryWriter(), # type: ignore ) @@ -66,7 +66,7 @@ async def test_http1_request(event_loop: asyncio.AbstractEventLoop) -> None: reason=b"", ), h11.Data(data=b"Hello & Goodbye"), - h11.EndOfMessage(headers=[]), # type: ignore + h11.EndOfMessage(headers=[]), ] server.reader.close() # type: ignore await task @@ -78,7 +78,7 @@ async def test_http1_websocket(event_loop: asyncio.AbstractEventLoop) -> None: ASGIWrapper(sanity_framework), event_loop, Config(), - WorkerContext(), + WorkerContext(None), MemoryReader(), # type: ignore MemoryWriter(), # type: ignore ) @@ -115,7 +115,7 @@ async def test_http2_request(event_loop: asyncio.AbstractEventLoop) -> None: ASGIWrapper(sanity_framework), event_loop, Config(), - WorkerContext(), + WorkerContext(None), MemoryReader(), # type: ignore MemoryWriter(http2=True), # type: ignore ) @@ -178,7 +178,7 @@ async def test_http2_websocket(event_loop: asyncio.AbstractEventLoop) -> None: ASGIWrapper(sanity_framework), event_loop, Config(), - WorkerContext(), + WorkerContext(None), MemoryReader(), # type: ignore MemoryWriter(http2=True), # type: ignore ) diff --git a/tests/asyncio/test_tcp_server.py b/tests/asyncio/test_tcp_server.py index afe00c20..1dfd4212 100644 --- a/tests/asyncio/test_tcp_server.py +++ b/tests/asyncio/test_tcp_server.py @@ -18,7 +18,7 @@ async def test_completes_on_closed(event_loop: asyncio.AbstractEventLoop) -> Non ASGIWrapper(echo_framework), event_loop, Config(), - WorkerContext(), + WorkerContext(None), MemoryReader(), # type: ignore MemoryWriter(), # type: ignore ) @@ -34,7 +34,7 @@ async def test_complets_on_half_close(event_loop: asyncio.AbstractEventLoop) -> ASGIWrapper(echo_framework), event_loop, Config(), - WorkerContext(), + WorkerContext(None), MemoryReader(), # type: ignore MemoryWriter(), # type: ignore ) diff --git a/tests/middleware/test_dispatcher.py b/tests/middleware/test_dispatcher.py index 1c3d7a28..dbb3f43e 100644 --- a/tests/middleware/test_dispatcher.py +++ b/tests/middleware/test_dispatcher.py @@ -36,9 +36,9 @@ async def send(message: dict) -> None: nonlocal sent_events sent_events.append(message) - await app({**http_scope, **{"path": "/api/x/b"}}, None, send) - await app({**http_scope, **{"path": "/api/b"}}, None, send) - await app({**http_scope, **{"path": "/"}}, None, send) + await app({**http_scope, **{"path": "/api/x/b"}}, None, send) # type: ignore + await app({**http_scope, **{"path": "/api/b"}}, None, send) # type: ignore + await app({**http_scope, **{"path": "/"}}, None, send) # type: ignore assert sent_events == [ {"type": "http.response.start", "status": 200, "headers": [(b"content-length", b"7")]}, {"type": "http.response.body", "body": b"apix-/b"}, diff --git a/tests/middleware/test_proxy_fix.py b/tests/middleware/test_proxy_fix.py new file mode 100644 index 00000000..2a43b589 --- /dev/null +++ b/tests/middleware/test_proxy_fix.py @@ -0,0 +1,64 @@ +from __future__ import annotations + +from unittest.mock import AsyncMock + +import pytest + +from hypercorn.middleware import ProxyFixMiddleware +from hypercorn.typing import HTTPScope + + +@pytest.mark.asyncio +async def test_proxy_fix_legacy() -> None: + mock = AsyncMock() + app = ProxyFixMiddleware(mock) + scope: HTTPScope = { + "type": "http", + "asgi": {}, + "http_version": "2", + "method": "GET", + "scheme": "http", + "path": "/", + "raw_path": b"/", + "query_string": b"", + "root_path": "", + "headers": [ + (b"x-forwarded-for", b"127.0.0.1"), + (b"x-forwarded-for", b"127.0.0.2"), + (b"x-forwarded-proto", b"http,https"), + ], + "client": ("127.0.0.3", 80), + "server": None, + "extensions": {}, + } + await app(scope, None, None) + mock.assert_called() + assert mock.call_args[0][0]["client"] == ("127.0.0.2", 0) + assert mock.call_args[0][0]["scheme"] == "https" + + +@pytest.mark.asyncio +async def test_proxy_fix_modern() -> None: + mock = AsyncMock() + app = ProxyFixMiddleware(mock, mode="modern") + scope: HTTPScope = { + "type": "http", + "asgi": {}, + "http_version": "2", + "method": "GET", + "scheme": "http", + "path": "/", + "raw_path": b"/", + "query_string": b"", + "root_path": "", + "headers": [ + (b"forwarded", b"for=127.0.0.1;proto=http,for=127.0.0.2;proto=https"), + ], + "client": ("127.0.0.3", 80), + "server": None, + "extensions": {}, + } + await app(scope, None, None) + mock.assert_called() + assert mock.call_args[0][0]["client"] == ("127.0.0.2", 0) + assert mock.call_args[0][0]["scheme"] == "https" diff --git a/tests/protocol/test_h11.py b/tests/protocol/test_h11.py index 86bb00a4..09e85b78 100755 --- a/tests/protocol/test_h11.py +++ b/tests/protocol/test_h11.py @@ -35,6 +35,8 @@ async def _protocol(monkeypatch: MonkeyPatch) -> H11Protocol: monkeypatch.setattr(hypercorn.protocol.h11, "HTTPStream", MockHTTPStream) context = Mock() context.event_class.return_value = AsyncMock(spec=IOEvent) + context.mark_request = AsyncMock() + context.terminate = context.event_class() context.terminated = context.event_class() context.terminated.is_set.return_value = False return H11Protocol(AsyncMock(), Config(), context, AsyncMock(), False, None, None, AsyncMock()) @@ -103,6 +105,18 @@ async def test_protocol_send_body(protocol: H11Protocol) -> None: ] +@pytest.mark.asyncio +async def test_protocol_keep_alive_max_requests(protocol: H11Protocol) -> None: + data = b"GET / HTTP/1.1\r\nHost: hypercorn\r\n\r\n" + protocol.config.keep_alive_max_requests = 0 + await protocol.handle(RawData(data=data)) + await protocol.stream_send(Response(stream_id=1, status_code=200, headers=[])) + await protocol.stream_send(EndBody(stream_id=1)) + await protocol.stream_send(StreamClosed(stream_id=1)) + protocol.send.assert_called() # type: ignore + assert protocol.send.call_args_list[3] == call(Closed()) # type: ignore + + @pytest.mark.asyncio @pytest.mark.parametrize("keep_alive, expected", [(True, Updated(idle=True)), (False, Closed())]) async def test_protocol_send_stream_closed( @@ -302,7 +316,7 @@ async def test_protocol_handle_max_incomplete(monkeypatch: MonkeyPatch) -> None: assert protocol.send.call_args_list == [ # type: ignore call( RawData( - data=b"HTTP/1.1 400 \r\ncontent-length: 0\r\nconnection: close\r\n" + data=b"HTTP/1.1 431 \r\ncontent-length: 0\r\nconnection: close\r\n" b"date: Thu, 01 Jan 1970 01:23:20 GMT\r\nserver: hypercorn-h11\r\n\r\n" ) ), diff --git a/tests/protocol/test_h2.py b/tests/protocol/test_h2.py index c44f39ae..cec6c263 100644 --- a/tests/protocol/test_h2.py +++ b/tests/protocol/test_h2.py @@ -4,6 +4,8 @@ from unittest.mock import call, Mock import pytest +from h2.connection import H2Connection +from h2.events import ConnectionTerminated from hypercorn.asyncio.worker_context import EventWrapper, WorkerContext from hypercorn.config import Config @@ -73,8 +75,29 @@ async def test_stream_buffer_complete(event_loop: asyncio.AbstractEventLoop) -> @pytest.mark.asyncio async def test_protocol_handle_protocol_error() -> None: protocol = H2Protocol( - Mock(), Config(), WorkerContext(), AsyncMock(), False, None, None, AsyncMock() + Mock(), Config(), WorkerContext(None), AsyncMock(), False, None, None, AsyncMock() ) await protocol.handle(RawData(data=b"broken nonsense\r\n\r\n")) protocol.send.assert_awaited() # type: ignore assert protocol.send.call_args_list == [call(Closed())] # type: ignore + + +@pytest.mark.asyncio +async def test_protocol_keep_alive_max_requests() -> None: + protocol = H2Protocol( + Mock(), Config(), WorkerContext(None), AsyncMock(), False, None, None, AsyncMock() + ) + protocol.config.keep_alive_max_requests = 0 + client = H2Connection() + client.initiate_connection() + headers = [ + (":method", "GET"), + (":path", "/reqinfo"), + (":authority", "hypercorn"), + (":scheme", "https"), + ] + client.send_headers(1, headers, end_stream=True) + await protocol.handle(RawData(data=client.data_to_send())) + protocol.send.assert_awaited() # type: ignore + events = client.receive_data(protocol.send.call_args_list[1].args[0].data) # type: ignore + assert isinstance(events[-1], ConnectionTerminated) diff --git a/tests/protocol/test_http_stream.py b/tests/protocol/test_http_stream.py index 24af5969..6f656de0 100644 --- a/tests/protocol/test_http_stream.py +++ b/tests/protocol/test_http_stream.py @@ -31,7 +31,7 @@ @pytest_asyncio.fixture(name="stream") # type: ignore[misc] async def _stream() -> HTTPStream: stream = HTTPStream( - AsyncMock(), Config(), WorkerContext(), AsyncMock(), False, None, None, AsyncMock(), 1 + AsyncMock(), Config(), WorkerContext(None), AsyncMock(), False, None, None, AsyncMock(), 1 ) stream.app_put = AsyncMock() stream.config._log = AsyncMock(spec=Logger) diff --git a/tests/protocol/test_ws_stream.py b/tests/protocol/test_ws_stream.py index 5f595828..05403130 100644 --- a/tests/protocol/test_ws_stream.py +++ b/tests/protocol/test_ws_stream.py @@ -165,7 +165,7 @@ def test_handshake_accept_additional_headers() -> None: @pytest_asyncio.fixture(name="stream") # type: ignore[misc] async def _stream() -> WSStream: stream = WSStream( - AsyncMock(), Config(), WorkerContext(), AsyncMock(), False, None, None, AsyncMock(), 1 + AsyncMock(), Config(), WorkerContext(None), AsyncMock(), False, None, None, AsyncMock(), 1 ) stream.task_group.spawn_app.return_value = AsyncMock() # type: ignore stream.app_put = AsyncMock() diff --git a/tests/test_app_wrappers.py b/tests/test_app_wrappers.py index bb7b5897..c68ba0cb 100644 --- a/tests/test_app_wrappers.py +++ b/tests/test_app_wrappers.py @@ -61,8 +61,28 @@ async def _send(message: ASGISendEvent) -> None: ] +async def _run_app(app: WSGIWrapper, scope: HTTPScope, body: bytes = b"") -> List[ASGISendEvent]: + queue: asyncio.Queue = asyncio.Queue() + await queue.put({"type": "http.request", "body": body}) + + messages = [] + + async def _send(message: ASGISendEvent) -> None: + nonlocal messages + messages.append(message) + + event_loop = asyncio.get_running_loop() + + def _call_soon(func: Callable, *args: Any) -> Any: + future = asyncio.run_coroutine_threadsafe(func(*args), event_loop) + return future.result() + + await app(scope, queue.get, _send, partial(event_loop.run_in_executor, None), _call_soon) + return messages + + @pytest.mark.asyncio -async def test_wsgi_asyncio(event_loop: asyncio.AbstractEventLoop) -> None: +async def test_wsgi_asyncio() -> None: app = WSGIWrapper(echo_body, 2**16) scope: HTTPScope = { "http_version": "1.1", @@ -79,20 +99,7 @@ async def test_wsgi_asyncio(event_loop: asyncio.AbstractEventLoop) -> None: "server": None, "extensions": {}, } - queue: asyncio.Queue = asyncio.Queue() - await queue.put({"type": "http.request"}) - - messages = [] - - async def _send(message: ASGISendEvent) -> None: - nonlocal messages - messages.append(message) - - def _call_soon(func: Callable, *args: Any) -> Any: - future = asyncio.run_coroutine_threadsafe(func(*args), event_loop) - return future.result() - - await app(scope, queue.get, _send, partial(event_loop.run_in_executor, None), _call_soon) + messages = await _run_app(app, scope) assert messages == [ { "headers": [(b"content-type", b"text/plain; charset=utf-8"), (b"content-length", b"0")], @@ -105,7 +112,7 @@ def _call_soon(func: Callable, *args: Any) -> Any: @pytest.mark.asyncio -async def test_max_body_size(event_loop: asyncio.AbstractEventLoop) -> None: +async def test_max_body_size() -> None: app = WSGIWrapper(echo_body, 4) scope: HTTPScope = { "http_version": "1.1", @@ -122,25 +129,39 @@ async def test_max_body_size(event_loop: asyncio.AbstractEventLoop) -> None: "server": None, "extensions": {}, } - queue: asyncio.Queue = asyncio.Queue() - await queue.put({"type": "http.request", "body": b"abcde"}) - messages = [] - - async def _send(message: ASGISendEvent) -> None: - nonlocal messages - messages.append(message) - - def _call_soon(func: Callable, *args: Any) -> Any: - future = asyncio.run_coroutine_threadsafe(func(*args), event_loop) - return future.result() - - await app(scope, queue.get, _send, partial(event_loop.run_in_executor, None), _call_soon) + messages = await _run_app(app, scope, b"abcde") assert messages == [ {"headers": [], "status": 400, "type": "http.response.start"}, {"body": bytearray(b""), "type": "http.response.body", "more_body": False}, ] +def no_start_response(environ: dict, start_response: Callable) -> List[bytes]: + return [b"result"] + + +@pytest.mark.asyncio +async def test_no_start_response() -> None: + app = WSGIWrapper(no_start_response, 2**16) + scope: HTTPScope = { + "http_version": "1.1", + "asgi": {}, + "method": "GET", + "headers": [], + "path": "/", + "root_path": "/", + "query_string": b"a=b", + "raw_path": b"/", + "scheme": "http", + "type": "http", + "client": ("localhost", 80), + "server": None, + "extensions": {}, + } + with pytest.raises(RuntimeError): + await _run_app(app, scope) + + def test_build_environ_encoding() -> None: scope: HTTPScope = { "http_version": "1.0", diff --git a/tests/trio/test_keep_alive.py b/tests/trio/test_keep_alive.py index d30d82db..6bed437f 100644 --- a/tests/trio/test_keep_alive.py +++ b/tests/trio/test_keep_alive.py @@ -47,7 +47,7 @@ def _client_stream( config.keep_alive_timeout = KEEP_ALIVE_TIMEOUT client_stream, server_stream = trio.testing.memory_stream_pair() server_stream.socket = MockSocket() - server = TCPServer(ASGIWrapper(slow_framework), config, WorkerContext(), server_stream) + server = TCPServer(ASGIWrapper(slow_framework), config, WorkerContext(None), server_stream) nursery.start_soon(server.run) yield client_stream diff --git a/tests/trio/test_sanity.py b/tests/trio/test_sanity.py index 3828e37b..b5bf75ba 100644 --- a/tests/trio/test_sanity.py +++ b/tests/trio/test_sanity.py @@ -25,7 +25,7 @@ async def test_http1_request(nursery: trio._core._run.Nursery) -> None: client_stream, server_stream = trio.testing.memory_stream_pair() server_stream.socket = MockSocket() - server = TCPServer(ASGIWrapper(sanity_framework), Config(), WorkerContext(), server_stream) + server = TCPServer(ASGIWrapper(sanity_framework), Config(), WorkerContext(None), server_stream) nursery.start_soon(server.run) client = h11.Connection(h11.CLIENT) await client_stream.send_all( @@ -68,7 +68,7 @@ async def test_http1_request(nursery: trio._core._run.Nursery) -> None: reason=b"", ), h11.Data(data=b"Hello & Goodbye"), - h11.EndOfMessage(headers=[]), # type: ignore + h11.EndOfMessage(headers=[]), ] @@ -76,7 +76,7 @@ async def test_http1_request(nursery: trio._core._run.Nursery) -> None: async def test_http1_websocket(nursery: trio._core._run.Nursery) -> None: client_stream, server_stream = trio.testing.memory_stream_pair() server_stream.socket = MockSocket() - server = TCPServer(ASGIWrapper(sanity_framework), Config(), WorkerContext(), server_stream) + server = TCPServer(ASGIWrapper(sanity_framework), Config(), WorkerContext(None), server_stream) nursery.start_soon(server.run) client = wsproto.WSConnection(wsproto.ConnectionType.CLIENT) await client_stream.send_all(client.send(wsproto.events.Request(host="hypercorn", target="/"))) @@ -103,7 +103,7 @@ async def test_http2_request(nursery: trio._core._run.Nursery) -> None: server_stream.transport_stream = Mock(return_value=PropertyMock(return_value=MockSocket())) server_stream.do_handshake = AsyncMock() server_stream.selected_alpn_protocol = Mock(return_value="h2") - server = TCPServer(ASGIWrapper(sanity_framework), Config(), WorkerContext(), server_stream) + server = TCPServer(ASGIWrapper(sanity_framework), Config(), WorkerContext(None), server_stream) nursery.start_soon(server.run) client = h2.connection.H2Connection() client.initiate_connection() @@ -158,7 +158,7 @@ async def test_http2_websocket(nursery: trio._core._run.Nursery) -> None: server_stream.transport_stream = Mock(return_value=PropertyMock(return_value=MockSocket())) server_stream.do_handshake = AsyncMock() server_stream.selected_alpn_protocol = Mock(return_value="h2") - server = TCPServer(ASGIWrapper(sanity_framework), Config(), WorkerContext(), server_stream) + server = TCPServer(ASGIWrapper(sanity_framework), Config(), WorkerContext(None), server_stream) nursery.start_soon(server.run) h2_client = h2.connection.H2Connection() h2_client.initiate_connection()