Skip to content

Commit 8c3e144

Browse files
committed
Implement GSSAPI authentication
Most commonly used with Kerberos. Closes: #769
1 parent c2c8d20 commit 8c3e144

7 files changed

+89
-23
lines changed

asyncpg/connect_utils.py

+12-4
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,7 @@ def parse(cls, sslmode):
5656
'direct_tls',
5757
'server_settings',
5858
'target_session_attrs',
59+
'krbsrvname',
5960
])
6061

6162

@@ -261,7 +262,7 @@ def _dot_postgresql_path(filename) -> typing.Optional[pathlib.Path]:
261262
def _parse_connect_dsn_and_args(*, dsn, host, port, user,
262263
password, passfile, database, ssl,
263264
direct_tls, server_settings,
264-
target_session_attrs):
265+
target_session_attrs, krbsrvname):
265266
# `auth_hosts` is the version of host information for the purposes
266267
# of reading the pgpass file.
267268
auth_hosts = None
@@ -383,6 +384,11 @@ def _parse_connect_dsn_and_args(*, dsn, host, port, user,
383384
if target_session_attrs is None:
384385
target_session_attrs = dsn_target_session_attrs
385386

387+
if 'krbsrvname' in query:
388+
val = query.pop('krbsrvname')
389+
if krbsrvname is None:
390+
krbsrvname = val
391+
386392
if query:
387393
if server_settings is None:
388394
server_settings = query
@@ -654,7 +660,8 @@ def _parse_connect_dsn_and_args(*, dsn, host, port, user,
654660
user=user, password=password, database=database, ssl=ssl,
655661
sslmode=sslmode, direct_tls=direct_tls,
656662
server_settings=server_settings,
657-
target_session_attrs=target_session_attrs)
663+
target_session_attrs=target_session_attrs,
664+
krbsrvname=krbsrvname)
658665

659666
return addrs, params
660667

@@ -665,7 +672,7 @@ def _parse_connect_arguments(*, dsn, host, port, user, password, passfile,
665672
max_cached_statement_lifetime,
666673
max_cacheable_statement_size,
667674
ssl, direct_tls, server_settings,
668-
target_session_attrs):
675+
target_session_attrs, krbsrvname):
669676
local_vars = locals()
670677
for var_name in {'max_cacheable_statement_size',
671678
'max_cached_statement_lifetime',
@@ -694,7 +701,8 @@ def _parse_connect_arguments(*, dsn, host, port, user, password, passfile,
694701
password=password, passfile=passfile, ssl=ssl,
695702
direct_tls=direct_tls, database=database,
696703
server_settings=server_settings,
697-
target_session_attrs=target_session_attrs)
704+
target_session_attrs=target_session_attrs,
705+
krbsrvname=krbsrvname)
698706

699707
config = _ClientConfiguration(
700708
command_timeout=command_timeout,

asyncpg/connection.py

+8-2
Original file line numberDiff line numberDiff line change
@@ -2007,7 +2007,8 @@ async def connect(dsn=None, *,
20072007
connection_class=Connection,
20082008
record_class=protocol.Record,
20092009
server_settings=None,
2010-
target_session_attrs=None):
2010+
target_session_attrs=None,
2011+
krbsrvname=None):
20112012
r"""A coroutine to establish a connection to a PostgreSQL server.
20122013
20132014
The connection parameters may be specified either as a connection
@@ -2235,6 +2236,10 @@ async def connect(dsn=None, *,
22352236
or the value of the ``PGTARGETSESSIONATTRS`` environment variable,
22362237
or ``"any"`` if neither is specified.
22372238
2239+
:param str krbsrvname:
2240+
Kerberos service name to use when authenticating with GSSAPI. This
2241+
must match the server configuration. Defaults to 'postgres'.
2242+
22382243
:return: A :class:`~asyncpg.connection.Connection` instance.
22392244
22402245
Example:
@@ -2344,7 +2349,8 @@ async def connect(dsn=None, *,
23442349
statement_cache_size=statement_cache_size,
23452350
max_cached_statement_lifetime=max_cached_statement_lifetime,
23462351
max_cacheable_statement_size=max_cacheable_statement_size,
2347-
target_session_attrs=target_session_attrs
2352+
target_session_attrs=target_session_attrs,
2353+
krbsrvname=krbsrvname,
23482354
)
23492355

23502356

asyncpg/protocol/coreproto.pxd

+5-10
Original file line numberDiff line numberDiff line change
@@ -51,16 +51,6 @@ cdef enum AuthenticationMessage:
5151
AUTH_SASL_FINAL = 12
5252

5353

54-
AUTH_METHOD_NAME = {
55-
AUTH_REQUIRED_KERBEROS: 'kerberosv5',
56-
AUTH_REQUIRED_PASSWORD: 'password',
57-
AUTH_REQUIRED_PASSWORDMD5: 'md5',
58-
AUTH_REQUIRED_GSS: 'gss',
59-
AUTH_REQUIRED_SASL: 'scram-sha-256',
60-
AUTH_REQUIRED_SSPI: 'sspi',
61-
}
62-
63-
6454
cdef enum ResultType:
6555
RESULT_OK = 1
6656
RESULT_FAILED = 2
@@ -96,10 +86,13 @@ cdef class CoreProtocol:
9686

9787
object transport
9888

89+
object address
9990
# Instance of _ConnectionParameters
10091
object con_params
10192
# Instance of SCRAMAuthentication
10293
SCRAMAuthentication scram
94+
# Instance of gssapi.SecurityContext
95+
object gss_ctx
10396

10497
readonly int32_t backend_pid
10598
readonly int32_t backend_secret
@@ -145,6 +138,8 @@ cdef class CoreProtocol:
145138
cdef _auth_password_message_md5(self, bytes salt)
146139
cdef _auth_password_message_sasl_initial(self, list sasl_auth_methods)
147140
cdef _auth_password_message_sasl_continue(self, bytes server_response)
141+
cdef _auth_gss_init(self)
142+
cdef _auth_gss_step(self, bytes server_response)
148143

149144
cdef _write(self, buf)
150145
cdef _writelines(self, list buffers)

asyncpg/protocol/coreproto.pyx

+59-3
Original file line numberDiff line numberDiff line change
@@ -6,14 +6,26 @@
66

77

88
import hashlib
9+
import socket
910

1011

1112
include "scram.pyx"
1213

1314

15+
cdef dict AUTH_METHOD_NAME = {
16+
AUTH_REQUIRED_KERBEROS: 'kerberosv5',
17+
AUTH_REQUIRED_PASSWORD: 'password',
18+
AUTH_REQUIRED_PASSWORDMD5: 'md5',
19+
AUTH_REQUIRED_GSS: 'gss',
20+
AUTH_REQUIRED_SASL: 'scram-sha-256',
21+
AUTH_REQUIRED_SSPI: 'sspi',
22+
}
23+
24+
1425
cdef class CoreProtocol:
1526

16-
def __init__(self, con_params):
27+
def __init__(self, addr, con_params):
28+
self.address = addr
1729
# type of `con_params` is `_ConnectionParameters`
1830
self.buffer = ReadBuffer()
1931
self.user = con_params.user
@@ -26,6 +38,8 @@ cdef class CoreProtocol:
2638
self.encoding = 'utf-8'
2739
# type of `scram` is `SCRAMAuthentcation`
2840
self.scram = None
41+
# type of `gss_ctx` is `gssapi.SecurityContext`
42+
self.gss_ctx = None
2943

3044
self._reset_result()
3145

@@ -619,9 +633,17 @@ cdef class CoreProtocol:
619633
'could not verify server signature for '
620634
'SCRAM authentciation: scram-sha-256',
621635
)
636+
self.scram = None
637+
638+
elif status == AUTH_REQUIRED_GSS:
639+
self._auth_gss_init()
640+
self.auth_msg = self._auth_gss_step(None)
641+
642+
elif status == AUTH_REQUIRED_GSS_CONTINUE:
643+
server_response = self.buffer.consume_message()
644+
self.auth_msg = self._auth_gss_step(server_response)
622645

623646
elif status in (AUTH_REQUIRED_KERBEROS, AUTH_REQUIRED_SCMCRED,
624-
AUTH_REQUIRED_GSS, AUTH_REQUIRED_GSS_CONTINUE,
625647
AUTH_REQUIRED_SSPI):
626648
self.result_type = RESULT_FAILED
627649
self.result = apg_exc.InterfaceError(
@@ -634,7 +656,8 @@ cdef class CoreProtocol:
634656
'unsupported authentication method requested by the '
635657
'server: {}'.format(status))
636658

637-
if status not in [AUTH_SASL_CONTINUE, AUTH_SASL_FINAL]:
659+
if status not in [AUTH_SASL_CONTINUE, AUTH_SASL_FINAL,
660+
AUTH_REQUIRED_GSS_CONTINUE]:
638661
self.buffer.discard_message()
639662

640663
cdef _auth_password_message_cleartext(self):
@@ -691,6 +714,39 @@ cdef class CoreProtocol:
691714

692715
return msg
693716

717+
cdef _auth_gss_init(self):
718+
try:
719+
import gssapi
720+
except ModuleNotFoundError:
721+
raise RuntimeError(
722+
'gssapi module not found; please install asyncpg[gss] to use '
723+
'asyncpg with Kerberos or GSSAPI authentication'
724+
) from None
725+
726+
service_name = self.con_params.krbsrvname or 'postgres'
727+
# find the canonical name of the server host
728+
if isinstance(self.address, str):
729+
host = socket.gethostname()
730+
else:
731+
host = self.address[0]
732+
host_cname = socket.gethostbyname_ex(host)[0].rstrip('.')
733+
gss_name = gssapi.Name(f'{service_name}/{host_cname}')
734+
self.gss_ctx = gssapi.SecurityContext(name=gss_name, usage='initiate')
735+
736+
cdef _auth_gss_step(self, bytes server_response):
737+
cdef:
738+
WriteBuffer msg
739+
740+
token = self.gss_ctx.step(server_response)
741+
if not token:
742+
self.gss_ctx = None
743+
return None
744+
msg = WriteBuffer.new_message(b'p')
745+
msg.write_bytes(token)
746+
msg.end_message()
747+
748+
return msg
749+
694750
cdef _parse_msg_ready_for_query(self):
695751
cdef char status = self.buffer.read_byte()
696752

asyncpg/protocol/protocol.pxd

-1
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,6 @@ cdef class BaseProtocol(CoreProtocol):
3131

3232
cdef:
3333
object loop
34-
object address
3534
ConnectionSettings settings
3635
object cancel_sent_waiter
3736
object cancel_waiter

asyncpg/protocol/protocol.pyx

+2-3
Original file line numberDiff line numberDiff line change
@@ -75,16 +75,15 @@ NO_TIMEOUT = object()
7575
cdef class BaseProtocol(CoreProtocol):
7676
def __init__(self, addr, connected_fut, con_params, record_class: type, loop):
7777
# type of `con_params` is `_ConnectionParameters`
78-
CoreProtocol.__init__(self, con_params)
78+
CoreProtocol.__init__(self, addr, con_params)
7979

8080
self.loop = loop
8181
self.transport = None
8282
self.waiter = connected_fut
8383
self.cancel_waiter = None
8484
self.cancel_sent_waiter = None
8585

86-
self.address = addr
87-
self.settings = ConnectionSettings((self.address, con_params.database))
86+
self.settings = ConnectionSettings((addr, con_params.database))
8887
self.record_class = record_class
8988

9089
self.statement = None

pyproject.toml

+3
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,9 @@ dependencies = [
3535
github = "https://github.com/MagicStack/asyncpg"
3636

3737
[project.optional-dependencies]
38+
gss = [
39+
'gssapi',
40+
]
3841
test = [
3942
'flake8~=6.1',
4043
'uvloop>=0.15.3; platform_system != "Windows" and python_version < "3.12.0"',

0 commit comments

Comments
 (0)