Skip to content

Commit 1325c3c

Browse files
Configurable ssl_ctx on NetworkBackend (#132)
1 parent 84cc6c7 commit 1325c3c

File tree

2 files changed

+14
-8
lines changed

2 files changed

+14
-8
lines changed

src/ahttpx/_network.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -81,8 +81,12 @@ async def __aexit__(
8181

8282

8383
class NetworkBackend:
84-
def __init__(self):
85-
self._ssl_context = ssl.create_default_context(cafile=certifi.where())
84+
def __init__(self, ssl_ctx: ssl.SSLContext | None = None):
85+
self._ssl_ctx = self.create_default_context() if ssl_ctx is None else ssl_ctx
86+
87+
def create_default_context(self) -> ssl.SSLContext:
88+
import certifi
89+
return ssl.create_default_context(cafile=certifi.where())
8690

8791
async def connect(self, host: str, port: int) -> NetworkStream:
8892
"""
@@ -98,7 +102,7 @@ async def connect_tls(self, host: str, port: int, hostname: str = '') -> Network
98102
"""
99103
address = f"{host}:{port}"
100104
reader, writer = await asyncio.open_connection(host, port)
101-
await writer.start_tls(self._ssl_context, server_hostname=hostname)
105+
await writer.start_tls(self._ssl_ctx, server_hostname=hostname)
102106
return NetworkStream(reader, writer, address=address)
103107

104108
async def serve(self, host: str, port: int, handler: typing.Callable[[NetworkStream], None]) -> NetworkServer:

src/httpx/_network.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,6 @@
99
import types
1010
import typing
1111

12-
import certifi
13-
1412
from ._streams import Stream
1513

1614

@@ -193,8 +191,12 @@ def _handler(self, stream):
193191

194192

195193
class NetworkBackend:
196-
def __init__(self):
197-
self._ssl_context = ssl.create_default_context(cafile=certifi.where())
194+
def __init__(self, ssl_ctx: ssl.SSLContext | None = None):
195+
self._ssl_ctx = self.create_default_context() if ssl_ctx is None else ssl_ctx
196+
197+
def create_default_context(self) -> ssl.SSLContext:
198+
import certifi
199+
return ssl.create_default_context(cafile=certifi.where())
198200

199201
def connect(self, host: str, port: int) -> NetworkStream:
200202
"""
@@ -213,7 +215,7 @@ def connect_tls(self, host: str, port: int, hostname: str = '') -> NetworkStream
213215
hostname = hostname or host
214216
timeout = get_current_timeout()
215217
sock = socket.create_connection(address, timeout=timeout)
216-
sock = self._ssl_context.wrap_socket(sock, server_hostname=hostname)
218+
sock = self._ssl_ctx.wrap_socket(sock, server_hostname=hostname)
217219
return NetworkStream(sock, address)
218220

219221
def listen(self, host: str, port: int) -> NetworkListener:

0 commit comments

Comments
 (0)