Skip to content

Commit 2634402

Browse files
committed
Add support for the sslnegotiation parameter
Direct TLS connections are already supported via the `direct_tls` argument, however PostgreSQL 17 added native support for this via `sslnegotiation`, so recognize it in DSNs and the environment. I decided not to introduce the `sslnegotiation` connection constructor argument for now, `direct_tls` should continue to be used instead.
1 parent 8f2be4c commit 2634402

File tree

5 files changed

+114
-8
lines changed

5 files changed

+114
-8
lines changed

asyncpg/compat.py

+8
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66

77
from __future__ import annotations
88

9+
import enum
910
import pathlib
1011
import platform
1112
import typing
@@ -78,3 +79,10 @@ def markcoroutinefunction(c): # type: ignore
7879
from collections.abc import ( # noqa: F401
7980
Awaitable as Awaitable,
8081
)
82+
83+
if sys.version_info < (3, 11):
84+
class StrEnum(str, enum.Enum):
85+
__str__ = str.__str__
86+
__repr__ = enum.Enum.__repr__
87+
else:
88+
from enum import StrEnum as StrEnum # noqa: F401

asyncpg/connect_utils.py

+38-6
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,11 @@ def parse(cls, sslmode):
4545
return getattr(cls, sslmode.replace('-', '_'))
4646

4747

48+
class SSLNegotiation(compat.StrEnum):
49+
postgres = "postgres"
50+
direct = "direct"
51+
52+
4853
_ConnectionParameters = collections.namedtuple(
4954
'ConnectionParameters',
5055
[
@@ -53,7 +58,7 @@ def parse(cls, sslmode):
5358
'database',
5459
'ssl',
5560
'sslmode',
56-
'direct_tls',
61+
'ssl_negotiation',
5762
'server_settings',
5863
'target_session_attrs',
5964
'krbsrvname',
@@ -269,6 +274,7 @@ def _parse_connect_dsn_and_args(*, dsn, host, port, user,
269274
auth_hosts = None
270275
sslcert = sslkey = sslrootcert = sslcrl = sslpassword = None
271276
ssl_min_protocol_version = ssl_max_protocol_version = None
277+
sslnegotiation = None
272278

273279
if dsn:
274280
parsed = urllib.parse.urlparse(dsn)
@@ -362,6 +368,9 @@ def _parse_connect_dsn_and_args(*, dsn, host, port, user,
362368
if 'sslrootcert' in query:
363369
sslrootcert = query.pop('sslrootcert')
364370

371+
if 'sslnegotiation' in query:
372+
sslnegotiation = query.pop('sslnegotiation')
373+
365374
if 'sslcrl' in query:
366375
sslcrl = query.pop('sslcrl')
367376

@@ -503,13 +512,36 @@ def _parse_connect_dsn_and_args(*, dsn, host, port, user,
503512
if ssl is None and have_tcp_addrs:
504513
ssl = 'prefer'
505514

515+
if direct_tls is not None:
516+
sslneg = (
517+
SSLNegotiation.direct if direct_tls else SSLNegotiation.postgres
518+
)
519+
else:
520+
if sslnegotiation is None:
521+
sslnegotiation = os.environ.get("PGSSLNEGOTIATION")
522+
523+
if sslnegotiation is not None:
524+
try:
525+
sslneg = SSLNegotiation(sslnegotiation)
526+
except ValueError:
527+
modes = ', '.join(
528+
m.name.replace('_', '-')
529+
for m in SSLNegotiation
530+
)
531+
raise exceptions.ClientConfigurationError(
532+
f'`sslnegotiation` parameter must be one of: {modes}'
533+
) from None
534+
else:
535+
sslneg = SSLNegotiation.postgres
536+
506537
if isinstance(ssl, (str, SSLMode)):
507538
try:
508539
sslmode = SSLMode.parse(ssl)
509540
except AttributeError:
510541
modes = ', '.join(m.name.replace('_', '-') for m in SSLMode)
511542
raise exceptions.ClientConfigurationError(
512-
'`sslmode` parameter must be one of: {}'.format(modes))
543+
'`sslmode` parameter must be one of: {}'.format(modes)
544+
) from None
513545

514546
# docs at https://www.postgresql.org/docs/10/static/libpq-connect.html
515547
if sslmode < SSLMode.allow:
@@ -676,7 +708,7 @@ def _parse_connect_dsn_and_args(*, dsn, host, port, user,
676708

677709
params = _ConnectionParameters(
678710
user=user, password=password, database=database, ssl=ssl,
679-
sslmode=sslmode, direct_tls=direct_tls,
711+
sslmode=sslmode, ssl_negotiation=sslneg,
680712
server_settings=server_settings,
681713
target_session_attrs=target_session_attrs,
682714
krbsrvname=krbsrvname, gsslib=gsslib)
@@ -882,9 +914,9 @@ async def __connect_addr(
882914
# UNIX socket
883915
connector = loop.create_unix_connection(proto_factory, addr)
884916

885-
elif params.ssl and params.direct_tls:
886-
# if ssl and direct_tls are given, skip STARTTLS and perform direct
887-
# SSL connection
917+
elif params.ssl and params.ssl_negotiation is SSLNegotiation.direct:
918+
# if ssl and ssl_negotiation is `direct`, skip STARTTLS and perform
919+
# direct SSL connection
888920
connector = loop.create_connection(
889921
proto_factory, *addr, ssl=params.ssl
890922
)

asyncpg/connection.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -2001,7 +2001,7 @@ async def connect(dsn=None, *,
20012001
max_cacheable_statement_size=1024 * 15,
20022002
command_timeout=None,
20032003
ssl=None,
2004-
direct_tls=False,
2004+
direct_tls=None,
20052005
connection_class=Connection,
20062006
record_class=protocol.Record,
20072007
server_settings=None,

pyproject.toml

+9
Original file line numberDiff line numberDiff line change
@@ -112,6 +112,15 @@ exclude_lines = [
112112
show_missing = true
113113

114114
[tool.mypy]
115+
exclude = [
116+
"^.eggs",
117+
"^.github",
118+
"^.vscode",
119+
"^build",
120+
"^dist",
121+
"^docs",
122+
"^tests",
123+
]
115124
incremental = true
116125
strict = true
117126
implicit_reexport = true

tests/test_connect.py

+58-1
Original file line numberDiff line numberDiff line change
@@ -592,6 +592,58 @@ class TestConnectParams(tb.TestCase):
592592
'target_session_attrs': 'any'})
593593
},
594594

595+
{
596+
'name': 'params_ssl_negotiation_dsn',
597+
'env': {
598+
'PGSSLNEGOTIATION': 'postgres'
599+
},
600+
601+
'dsn': 'postgres://u:p@localhost/d?sslnegotiation=direct',
602+
603+
'result': ([('localhost', 5432)], {
604+
'user': 'u',
605+
'password': 'p',
606+
'database': 'd',
607+
'ssl_negotiation': 'direct',
608+
'target_session_attrs': 'any',
609+
})
610+
},
611+
612+
{
613+
'name': 'params_ssl_negotiation_env',
614+
'env': {
615+
'PGSSLNEGOTIATION': 'direct'
616+
},
617+
618+
'dsn': 'postgres://u:p@localhost/d',
619+
620+
'result': ([('localhost', 5432)], {
621+
'user': 'u',
622+
'password': 'p',
623+
'database': 'd',
624+
'ssl_negotiation': 'direct',
625+
'target_session_attrs': 'any',
626+
})
627+
},
628+
629+
{
630+
'name': 'params_ssl_negotiation_params',
631+
'env': {
632+
'PGSSLNEGOTIATION': 'direct'
633+
},
634+
635+
'dsn': 'postgres://u:p@localhost/d',
636+
'direct_tls': False,
637+
638+
'result': ([('localhost', 5432)], {
639+
'user': 'u',
640+
'password': 'p',
641+
'database': 'd',
642+
'ssl_negotiation': 'postgres',
643+
'target_session_attrs': 'any',
644+
})
645+
},
646+
595647
{
596648
'name': 'dsn_overrides_env_partially_ssl_prefer',
597649
'env': {
@@ -1067,6 +1119,7 @@ def run_testcase(self, testcase):
10671119
passfile = testcase.get('passfile')
10681120
database = testcase.get('database')
10691121
sslmode = testcase.get('ssl')
1122+
direct_tls = testcase.get('direct_tls')
10701123
server_settings = testcase.get('server_settings')
10711124
target_session_attrs = testcase.get('target_session_attrs')
10721125
krbsrvname = testcase.get('krbsrvname')
@@ -1093,7 +1146,7 @@ def run_testcase(self, testcase):
10931146
addrs, params = connect_utils._parse_connect_dsn_and_args(
10941147
dsn=dsn, host=host, port=port, user=user, password=password,
10951148
passfile=passfile, database=database, ssl=sslmode,
1096-
direct_tls=False,
1149+
direct_tls=direct_tls,
10971150
server_settings=server_settings,
10981151
target_session_attrs=target_session_attrs,
10991152
krbsrvname=krbsrvname, gsslib=gsslib)
@@ -1118,6 +1171,10 @@ def run_testcase(self, testcase):
11181171
# Avoid the hassle of specifying direct_tls
11191172
# unless explicitly tested for
11201173
params.pop('direct_tls', False)
1174+
if 'ssl_negotiation' not in expected[1]:
1175+
# Avoid the hassle of specifying sslnegotiation
1176+
# unless explicitly tested for
1177+
params.pop('ssl_negotiation', False)
11211178
if 'gsslib' not in expected[1]:
11221179
# Avoid the hassle of specifying gsslib
11231180
# unless explicitly tested for

0 commit comments

Comments
 (0)