Skip to content

Commit 78d1be5

Browse files
authored
Fix client connection header not reflecting connector force_close value (#10003)
1 parent a334eef commit 78d1be5

File tree

5 files changed

+34
-61
lines changed

5 files changed

+34
-61
lines changed

CHANGES/10003.bugfix.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Fixed the HTTP client not considering the connector's ``force_close`` value when setting the ``Connection`` header -- by :user:`bdraco`.

aiohttp/client_reqrep.py

Lines changed: 7 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -572,15 +572,6 @@ def update_proxy(
572572
proxy_headers = CIMultiDict(proxy_headers)
573573
self.proxy_headers = proxy_headers
574574

575-
def keep_alive(self) -> bool:
576-
if self.version >= HttpVersion11:
577-
return self.headers.get(hdrs.CONNECTION) != "close"
578-
if self.version == HttpVersion10:
579-
# no headers means we close for Http 1.0
580-
return self.headers.get(hdrs.CONNECTION) == "keep-alive"
581-
# keep alive not supported at all
582-
return False
583-
584575
async def write_bytes(
585576
self, writer: AbstractStreamWriter, conn: "Connection"
586577
) -> None:
@@ -678,21 +669,15 @@ async def send(self, conn: "Connection") -> "ClientResponse":
678669
):
679670
self.headers[hdrs.CONTENT_TYPE] = "application/octet-stream"
680671

681-
# set the connection header
682-
connection = self.headers.get(hdrs.CONNECTION)
683-
if not connection:
684-
if self.keep_alive():
685-
if self.version == HttpVersion10:
686-
connection = "keep-alive"
687-
else:
688-
if self.version == HttpVersion11:
689-
connection = "close"
690-
691-
if connection is not None:
692-
self.headers[hdrs.CONNECTION] = connection
672+
v = self.version
673+
if hdrs.CONNECTION not in self.headers:
674+
if conn._connector.force_close:
675+
if v == HttpVersion11:
676+
self.headers[hdrs.CONNECTION] = "close"
677+
elif v == HttpVersion10:
678+
self.headers[hdrs.CONNECTION] = "keep-alive"
693679

694680
# status + headers
695-
v = self.version
696681
status_line = f"{self.method} {path} HTTP/{v.major}.{v.minor}"
697682
await writer.write_headers(status_line, self.headers)
698683
task: Optional["asyncio.Task[None]"]

tests/test_benchmarks_client_request.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -99,10 +99,16 @@ async def _drain_helper(self) -> None:
9999
def start_timeout(self) -> None:
100100
"""Swallow start_timeout."""
101101

102+
class MockConnector:
103+
104+
def __init__(self) -> None:
105+
self.force_close = False
106+
102107
class MockConnection:
103108
def __init__(self) -> None:
104109
self.transport = None
105110
self.protocol = MockProtocol()
111+
self._connector = MockConnector()
106112

107113
conn = MockConnection()
108114

tests/test_client_request.py

Lines changed: 20 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@
3333
_gen_default_accept_encoding,
3434
)
3535
from aiohttp.connector import Connection
36-
from aiohttp.http import HttpVersion
36+
from aiohttp.http import HttpVersion10, HttpVersion11
3737
from aiohttp.test_utils import make_mocked_coro
3838
from aiohttp.typedefs import LooseCookies
3939

@@ -156,30 +156,6 @@ def test_version_err(make_request: _RequestMaker) -> None:
156156
make_request("get", "http://python.org/", version="1.c")
157157

158158

159-
def test_keep_alive(make_request: _RequestMaker) -> None:
160-
req = make_request("get", "http://python.org/", version=(0, 9))
161-
assert not req.keep_alive()
162-
163-
req = make_request("get", "http://python.org/", version=(1, 0))
164-
assert not req.keep_alive()
165-
166-
req = make_request(
167-
"get",
168-
"http://python.org/",
169-
version=(1, 0),
170-
headers={"connection": "keep-alive"},
171-
)
172-
assert req.keep_alive()
173-
174-
req = make_request("get", "http://python.org/", version=(1, 1))
175-
assert req.keep_alive()
176-
177-
req = make_request(
178-
"get", "http://python.org/", version=(1, 1), headers={"connection": "close"}
179-
)
180-
assert not req.keep_alive()
181-
182-
183159
def test_host_port_default_http(make_request: _RequestMaker) -> None:
184160
req = make_request("get", "http://python.org/")
185161
assert req.host == "python.org"
@@ -624,25 +600,31 @@ async def test_connection_header(
624600
loop: asyncio.AbstractEventLoop, conn: mock.Mock
625601
) -> None:
626602
req = ClientRequest("get", URL("http://python.org"), loop=loop)
627-
with mock.patch.object(req, "keep_alive") as m:
628-
req.headers.clear()
603+
req.headers.clear()
604+
605+
req.version = HttpVersion11
606+
req.headers.clear()
607+
with mock.patch.object(conn._connector, "force_close", False):
608+
await req.send(conn)
609+
assert req.headers.get("CONNECTION") is None
629610

630-
m.return_value = True
631-
req.version = HttpVersion(1, 1)
632-
req.headers.clear()
611+
req.version = HttpVersion10
612+
req.headers.clear()
613+
with mock.patch.object(conn._connector, "force_close", False):
633614
await req.send(conn)
634-
assert req.headers.get("CONNECTION") is None
615+
assert req.headers.get("CONNECTION") == "keep-alive"
635616

636-
req.version = HttpVersion(1, 0)
637-
req.headers.clear()
617+
req.version = HttpVersion11
618+
req.headers.clear()
619+
with mock.patch.object(conn._connector, "force_close", True):
638620
await req.send(conn)
639-
assert req.headers.get("CONNECTION") == "keep-alive"
621+
assert req.headers.get("CONNECTION") == "close"
640622

641-
m.return_value = False
642-
req.version = HttpVersion(1, 1)
643-
req.headers.clear()
623+
req.version = HttpVersion10
624+
req.headers.clear()
625+
with mock.patch.object(conn._connector, "force_close", True):
644626
await req.send(conn)
645-
assert req.headers.get("CONNECTION") == "close"
627+
assert not req.headers.get("CONNECTION")
646628

647629

648630
async def test_no_content_length(

tests/test_web_functional.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -700,7 +700,6 @@ async def handler(request: web.Request) -> web.Response:
700700
resp.release()
701701

702702

703-
@pytest.mark.xfail
704703
async def test_http10_keep_alive_default(aiohttp_client: AiohttpClient) -> None:
705704
async def handler(request: web.Request) -> web.Response:
706705
return web.Response()

0 commit comments

Comments
 (0)