Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGES/6463.bugfix
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Allowed application-level ``on_response_prepare`` callbacks to run for parser-error responses so applications can adjust response headers such as ``Server`` consistently.
10 changes: 6 additions & 4 deletions aiohttp/web_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -364,11 +364,13 @@ def _prepare_middleware(self) -> Iterator[Middleware]:
yield _fix_request_current_app(self)

async def _handle(self, request: Request) -> StreamResponse:
match_info = await self._router.resolve(request)
match_info.add_app(self)
match_info.freeze()
match_info = request._match_info
if match_info is None:
match_info = await self._router.resolve(request)
match_info.add_app(self)
match_info.freeze()

request._match_info = match_info
request._match_info = match_info

if request.headers.get(hdrs.EXPECT):
resp = await match_info.expect_handler(request)
Expand Down
13 changes: 9 additions & 4 deletions aiohttp/web_protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -612,11 +612,10 @@ async def start(self) -> None:

manager.requests_count += 1
writer = StreamWriter(self, loop)
if not isinstance(message, _ErrInfo):
request_handler = self._request_handler
else:
err_info = None
if isinstance(message, _ErrInfo):
# make request_factory work
request_handler = self._make_error_handler(message)
err_info = message
message = ERROR

# Important don't hold a reference to the current task
Expand All @@ -629,6 +628,12 @@ async def start(self) -> None:
writer,
self._task_handler or asyncio.current_task(loop), # type: ignore[arg-type]
)
if err_info is not None and getattr(request, "_match_info", None) is None:
request_handler: _RequestHandler[_Request] = cast(
_RequestHandler[_Request], self._make_error_handler(err_info)
)
Comment on lines +631 to +634

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Doesn't this end up doing about the same thing as the old code? Seems like it's not going to solve the features requested..

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think the difference is where this lets the app request handler back in for the parser-error request.

For the default AppRunner path, web_runner._make_request() sees message is ERROR and attaches request._match_info = MatchInfoError(HTTPBadRequest()). With this change, RequestHandler.start() builds that request first; when _match_info exists, it uses the app-level request handler instead of the protocol _make_error_handler fallback.

Then Application._handle() takes the existing system-route match info, builds middleware for SystemRoute, and returns the HTTPBadRequest response through the normal app path. finish_response() still calls resp.prepare(request), so on_response_prepare fires. The fallback error handler remains for custom request factories that do not attach match info.

I rechecked the targeted coverage locally:

PYTHONPATH=$PWD uv run --no-project --with pytest --with pytest-aiohttp --with pytest-mock --with pytest-timeout --with trustme --with cryptography --with proxy.py --with gunicorn --with uvloop --with dirty-equals --with coverage pytest tests/test_web_functional.py::test_signal_on_error_handler tests/test_web_functional.py::test_signal_on_parser_error_handler tests/test_web_functional.py::test_bad_method_for_c_http_parser_not_hangs -q

Result: 2 passed, 1 skipped; the skipped one is the local checkout missing the C HTTP parser.

else:
request_handler = self._request_handler
try:
# a new task is used for copy context vars (#3406)
coro = self._handle_request(request, start, request_handler)
Expand Down
12 changes: 10 additions & 2 deletions aiohttp/web_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,12 @@
from .streams import StreamReader
from .typedefs import PathLike
from .web_app import Application
from .web_exceptions import HTTPBadRequest
from .web_log import AccessLogger
from .web_protocol import RequestHandler
from .web_protocol import ERROR, RequestHandler
from .web_request import BaseRequest, Request
from .web_server import Server
from .web_urldispatcher import MatchInfoError

try:
from ssl import SSLContext
Expand Down Expand Up @@ -439,7 +441,7 @@ def _make_request(
_cls: type[Request] = Request,
) -> Request:
loop = asyncio.get_running_loop()
return _cls(
request = _cls(
message,
payload,
protocol,
Expand All @@ -448,6 +450,12 @@ def _make_request(
loop,
client_max_size=self.app._client_max_size,
)
if message is ERROR:
match_info = MatchInfoError(HTTPBadRequest())
match_info.add_app(self._app)
match_info.freeze()
request._match_info = match_info
return request

async def _cleanup_server(self) -> None:
await self._app.cleanup()
36 changes: 36 additions & 0 deletions tests/test_web_functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -2227,6 +2227,42 @@ async def on_prepare(request: web.Request, response: web.StreamResponse) -> None
resp.release()


async def test_signal_on_parser_error_handler(
aiohttp_server: AiohttpServer,
) -> None:
async def middleware(
request: web.Request,
handler: Handler,
) -> web.StreamResponse:
try:
return await handler(request)
except web.HTTPBadRequest as exc:
exc.headers["X-Middleware"] = "val"
raise

async def on_prepare(request: web.Request, response: web.StreamResponse) -> None:
response.headers["X-Custom"] = "val"
response.headers.pop("Server", None)

app = web.Application(middlewares=[middleware])
app.on_response_prepare.append(on_prepare)

server = await aiohttp_server(app)
reader, writer = await asyncio.open_connection(server.host, server.port)
try:
writer.write(b"GE T / HTTP/1.1\r\nHost: localhost\r\n\r\n")
await writer.drain()
response = await asyncio.wait_for(reader.readuntil(b"\r\n\r\n"), timeout=5)
finally:
writer.close()
await writer.wait_closed()

assert b"400 Bad Request" in response
assert b"X-Middleware: val" in response
assert b"X-Custom: val" in response
assert b"\r\nServer:" not in response


@pytest.mark.skipif(
"HttpRequestParserC" not in dir(aiohttp.http_parser),
reason="C based HTTP parser not available",
Expand Down
Loading