Skip to content

Commit eea8a37

Browse files
sleepyStick, NoahStapp, and ShaneHarvey
authored
PYTHON-3636 AsyncMongoClient should perform SRV resolution lazily (#2191)
Co-authored-by: Noah Stapp <[email protected]> Co-authored-by: Shane Harvey <[email protected]>
1 parent 38ceda4 commit eea8a37

31 files changed

+1627
-850
lines changed

doc/changelog.rst

+2
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,8 @@ PyMongo 4.12 brings a number of changes including:
99
- Support for configuring DEK cache lifetime via the ``key_expiration_ms`` argument to
1010
:class:`~pymongo.encryption_options.AutoEncryptionOpts`.
1111
- Support for $lookup in CSFLE and QE supported on MongoDB 8.1+.
12+
- AsyncMongoClient no longer performs DNS resolution for "mongodb+srv://" connection strings on creation.
13+
To avoid blocking the asyncio loop, the resolution is now deferred until the client is first connected.
1214
- Added index hinting support to the
1315
:meth:`~pymongo.asynchronous.collection.AsyncCollection.distinct` and
1416
:meth:`~pymongo.collection.Collection.distinct` commands.

pymongo/asynchronous/encryption.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,7 @@
8787
from pymongo.results import BulkWriteResult, DeleteResult
8888
from pymongo.ssl_support import get_ssl_context
8989
from pymongo.typings import _DocumentType, _DocumentTypeArg
90-
from pymongo.uri_parser import parse_host
90+
from pymongo.uri_parser_shared import parse_host
9191
from pymongo.write_concern import WriteConcern
9292

9393
if TYPE_CHECKING:

pymongo/asynchronous/mongo_client.py

+190-67
Large diffs are not rendered by default.

pymongo/asynchronous/monitor.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525

2626
from pymongo import common, periodic_executor
2727
from pymongo._csot import MovingMinimum
28+
from pymongo.asynchronous.srv_resolver import _SrvResolver
2829
from pymongo.errors import NetworkTimeout, _OperationCancelled
2930
from pymongo.hello import Hello
3031
from pymongo.lock import _async_create_lock
@@ -33,7 +34,6 @@
3334
from pymongo.pool_options import _is_faas
3435
from pymongo.read_preferences import MovingAverage
3536
from pymongo.server_description import ServerDescription
36-
from pymongo.srv_resolver import _SrvResolver
3737

3838
if TYPE_CHECKING:
3939
from pymongo.asynchronous.pool import AsyncConnection, Pool, _CancellationContext
@@ -395,7 +395,7 @@ async def _run(self) -> None:
395395
# Don't poll right after creation, wait 60 seconds first
396396
if time.monotonic() < self._startup_time + common.MIN_SRV_RESCAN_INTERVAL:
397397
return
398-
seedlist = self._get_seedlist()
398+
seedlist = await self._get_seedlist()
399399
if seedlist:
400400
self._seedlist = seedlist
401401
try:
@@ -404,7 +404,7 @@ async def _run(self) -> None:
404404
# Topology was garbage-collected.
405405
await self.close()
406406

407-
def _get_seedlist(self) -> Optional[list[tuple[str, Any]]]:
407+
async def _get_seedlist(self) -> Optional[list[tuple[str, Any]]]:
408408
"""Poll SRV records for a seedlist.
409409
410410
Returns a list of ServerDescriptions.
@@ -415,7 +415,7 @@ def _get_seedlist(self) -> Optional[list[tuple[str, Any]]]:
415415
self._settings.pool_options.connect_timeout,
416416
self._settings.srv_service_name,
417417
)
418-
seedlist, ttl = resolver.get_hosts_and_min_ttl()
418+
seedlist, ttl = await resolver.get_hosts_and_min_ttl()
419419
if len(seedlist) == 0:
420420
# As per the spec: this should be treated as a failure.
421421
raise Exception

pymongo/asynchronous/srv_resolver.py

+160
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,160 @@
1+
# Copyright 2019-present MongoDB, Inc.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"); you
4+
# may not use this file except in compliance with the License. You
5+
# may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
12+
# implied. See the License for the specific language governing
13+
# permissions and limitations under the License.
14+
15+
"""Support for resolving hosts and options from mongodb+srv:// URIs."""
16+
from __future__ import annotations
17+
18+
import ipaddress
19+
import random
20+
from typing import TYPE_CHECKING, Any, Optional, Union
21+
22+
from pymongo.common import CONNECT_TIMEOUT
23+
from pymongo.errors import ConfigurationError
24+
25+
if TYPE_CHECKING:
26+
from dns import resolver
27+
28+
_IS_SYNC = False
29+
30+
31+
def _have_dnspython() -> bool:
32+
try:
33+
import dns # noqa: F401
34+
35+
return True
36+
except ImportError:
37+
return False
38+
39+
40+
# Depending on its version, dnspython hands back either bytes or str from
# various parts of its API. Callers here always want str.
def maybe_decode(text: Union[str, bytes]) -> str:
    """Return *text* as ``str``, decoding it first when it is ``bytes``."""
    return text.decode() if isinstance(text, bytes) else text
46+
47+
48+
# PYTHON-2667 Lazily call dns.resolver methods for compatibility with eventlet.
async def _resolve(*args: Any, **kwargs: Any) -> resolver.Answer:
    """Run a DNS query through dnspython, forwarding all arguments unchanged.

    The dnspython modules are imported inside the function rather than at
    module import time (see the PYTHON-2667 note above).

    When ``_IS_SYNC`` is true the blocking ``dns.resolver`` API is used:
    ``resolve`` when it exists (dnspython >= 2), otherwise ``query``
    (dnspython 1.X). When ``_IS_SYNC`` is false, ``dns.asyncresolver.resolve``
    is awaited.

    :param args: Positional arguments passed straight to the resolver call.
    :param kwargs: Keyword arguments passed straight to the resolver call.
    :return: The answer object produced by dnspython.
    :raises ConfigurationError: In async mode, when the installed dnspython
        does not provide ``dns.asyncresolver.resolve``.
    :raises ImportError: When dnspython is not installed at all.
    """
    if _IS_SYNC:
        from dns import resolver

        if hasattr(resolver, "resolve"):
            # dnspython >= 2
            return resolver.resolve(*args, **kwargs)
        # dnspython 1.X
        return resolver.query(*args, **kwargs)
    else:
        # A missing dnspython must keep surfacing as ImportError, exactly as
        # before; only a missing *submodule* is translated below.
        import dns  # noqa: F401

        try:
            from dns import asyncresolver
        except ImportError:
            # dnspython 1.X has no dns.asyncresolver module, so the import
            # itself fails. Fall through to the explicit upgrade message
            # instead of leaking a bare "cannot import name" error.
            asyncresolver = None  # type: ignore[assignment]

        if hasattr(asyncresolver, "resolve"):
            # dnspython >= 2
            return await asyncresolver.resolve(*args, **kwargs)  # type:ignore[return-value]
        raise ConfigurationError(
            "Upgrade to dnspython version >= 2.0 to use AsyncMongoClient with mongodb+srv:// connections."
        )
67+
68+
69+
_INVALID_HOST_MSG = (
70+
"Invalid URI host: %s is not a valid hostname for 'mongodb+srv://'. "
71+
"Did you mean to use 'mongodb://'?"
72+
)
73+
74+
75+
class _SrvResolver:
76+
def __init__(
77+
self,
78+
fqdn: str,
79+
connect_timeout: Optional[float],
80+
srv_service_name: str,
81+
srv_max_hosts: int = 0,
82+
):
83+
self.__fqdn = fqdn
84+
self.__srv = srv_service_name
85+
self.__connect_timeout = connect_timeout or CONNECT_TIMEOUT
86+
self.__srv_max_hosts = srv_max_hosts or 0
87+
# Validate the fully qualified domain name.
88+
try:
89+
ipaddress.ip_address(fqdn)
90+
raise ConfigurationError(_INVALID_HOST_MSG % ("an IP address",))
91+
except ValueError:
92+
pass
93+
94+
try:
95+
self.__plist = self.__fqdn.split(".")[1:]
96+
except Exception:
97+
raise ConfigurationError(_INVALID_HOST_MSG % (fqdn,)) from None
98+
self.__slen = len(self.__plist)
99+
if self.__slen < 2:
100+
raise ConfigurationError(_INVALID_HOST_MSG % (fqdn,))
101+
102+
async def get_options(self) -> Optional[str]:
103+
from dns import resolver
104+
105+
try:
106+
results = await _resolve(self.__fqdn, "TXT", lifetime=self.__connect_timeout)
107+
except (resolver.NoAnswer, resolver.NXDOMAIN):
108+
# No TXT records
109+
return None
110+
except Exception as exc:
111+
raise ConfigurationError(str(exc)) from None
112+
if len(results) > 1:
113+
raise ConfigurationError("Only one TXT record is supported")
114+
return (b"&".join([b"".join(res.strings) for res in results])).decode("utf-8") # type: ignore[attr-defined]
115+
116+
async def _resolve_uri(self, encapsulate_errors: bool) -> resolver.Answer:
117+
try:
118+
results = await _resolve(
119+
"_" + self.__srv + "._tcp." + self.__fqdn, "SRV", lifetime=self.__connect_timeout
120+
)
121+
except Exception as exc:
122+
if not encapsulate_errors:
123+
# Raise the original error.
124+
raise
125+
# Else, raise all errors as ConfigurationError.
126+
raise ConfigurationError(str(exc)) from None
127+
return results
128+
129+
async def _get_srv_response_and_hosts(
130+
self, encapsulate_errors: bool
131+
) -> tuple[resolver.Answer, list[tuple[str, Any]]]:
132+
results = await self._resolve_uri(encapsulate_errors)
133+
134+
# Construct address tuples
135+
nodes = [
136+
(maybe_decode(res.target.to_text(omit_final_dot=True)), res.port) # type: ignore[attr-defined]
137+
for res in results
138+
]
139+
140+
# Validate hosts
141+
for node in nodes:
142+
try:
143+
nlist = node[0].lower().split(".")[1:][-self.__slen :]
144+
except Exception:
145+
raise ConfigurationError(f"Invalid SRV host: {node[0]}") from None
146+
if self.__plist != nlist:
147+
raise ConfigurationError(f"Invalid SRV host: {node[0]}")
148+
if self.__srv_max_hosts:
149+
nodes = random.sample(nodes, min(self.__srv_max_hosts, len(nodes)))
150+
return results, nodes
151+
152+
async def get_hosts(self) -> list[tuple[str, Any]]:
153+
_, nodes = await self._get_srv_response_and_hosts(True)
154+
return nodes
155+
156+
async def get_hosts_and_min_ttl(self) -> tuple[list[tuple[str, Any]], int]:
157+
results, nodes = await self._get_srv_response_and_hosts(False)
158+
rrset = results.rrset
159+
ttl = rrset.ttl if rrset else 0
160+
return nodes, ttl

0 commit comments

Comments
 (0)