Skip to content

Commit 80ca379

Browse files
authored
refactor: Use new ConnectSettings.DnsNames field determine the DNS Name of the instance. (#1242)
The Cloud SQL Instance ConnectSettings added a new field `dns_names` which contains a list of valid DNS names for an instance. The Python Connector will use these DNS names, falling back to the old `dns_name` field if `dns_names` is not populated. Other connectors use this DNS name for hostname validation for the instance's TLS server certificate. However, the python connector does not perform hostname validation due to limitations of python's TLS library. See also: GoogleCloudPlatform/cloud-sql-go-connector#954
1 parent 15934bd commit 80ca379

File tree

4 files changed

+55
-5
lines changed

4 files changed

+55
-5
lines changed

Diff for: .gitignore

+3
Original file line numberDiff line numberDiff line change
@@ -5,3 +5,6 @@ venv
55
.python-version
66
cloud_sql_python_connector.egg-info/
77
dist/
8+
.idea
9+
.coverage
10+
sponge_log.xml

Diff for: google/cloud/sql/connector/client.py

+17-4
Original file line numberDiff line numberDiff line change
@@ -156,10 +156,23 @@ async def _get_metadata(
156156
# resolve dnsName into IP address for PSC
157157
# Note that we have to check for PSC enablement also because CAS
158158
# instances also set the dnsName field.
159-
# Remove trailing period from DNS name. Required for SSL in Python
160-
dns_name = ret_dict.get("dnsName", "").rstrip(".")
161-
if dns_name and ret_dict.get("pscEnabled"):
162-
ip_addresses["PSC"] = dns_name
159+
if ret_dict.get("pscEnabled"):
160+
# Find PSC instance DNS name in the dns_names field
161+
psc_dns_names = [
162+
d["name"]
163+
for d in ret_dict.get("dnsNames", [])
164+
if d["connectionType"] == "PRIVATE_SERVICE_CONNECT"
165+
and d["dnsScope"] == "INSTANCE"
166+
]
167+
dns_name = psc_dns_names[0] if psc_dns_names else None
168+
169+
# Fall back do dns_name field if dns_names is not set
170+
if dns_name is None:
171+
dns_name = ret_dict.get("dnsName", None)
172+
173+
# Remove trailing period from DNS name. Required for SSL in Python
174+
if dns_name:
175+
ip_addresses["PSC"] = dns_name.rstrip(".")
163176

164177
return {
165178
"ip_addresses": ip_addresses,

Diff for: tests/unit/mocks.py

+13-1
Original file line numberDiff line numberDiff line change
@@ -225,6 +225,7 @@ def __init__(
225225
"PRIMARY": "127.0.0.1",
226226
"PRIVATE": "10.0.0.1",
227227
},
228+
legacy_dns_name: bool = False,
228229
cert_before: datetime = datetime.datetime.now(datetime.timezone.utc),
229230
cert_expiration: datetime = datetime.datetime.now(datetime.timezone.utc)
230231
+ datetime.timedelta(hours=1),
@@ -237,6 +238,7 @@ def __init__(
237238
self.psc_enabled = False
238239
self.cert_before = cert_before
239240
self.cert_expiration = cert_expiration
241+
self.legacy_dns_name = legacy_dns_name
240242
# create self signed CA cert
241243
self.server_ca, self.server_key = generate_cert(
242244
self.project, self.name, cert_before, cert_expiration
@@ -255,12 +257,22 @@ async def connect_settings(self, request: Any) -> web.Response:
255257
"instance": self.name,
256258
"expirationTime": str(self.cert_expiration),
257259
},
258-
"dnsName": "abcde.12345.us-central1.sql.goog",
259260
"pscEnabled": self.psc_enabled,
260261
"ipAddresses": ip_addrs,
261262
"region": self.region,
262263
"databaseVersion": self.db_version,
263264
}
265+
if self.legacy_dns_name:
266+
response["dnsName"] = "abcde.12345.us-central1.sql.goog"
267+
else:
268+
response["dnsNames"] = [
269+
{
270+
"name": "abcde.12345.us-central1.sql.goog",
271+
"connectionType": "PRIVATE_SERVICE_CONNECT",
272+
"dnsScope": "INSTANCE",
273+
}
274+
]
275+
264276
return web.Response(content_type="application/json", body=json.dumps(response))
265277

266278
async def generate_ephemeral(self, request: Any) -> web.Response:

Diff for: tests/unit/test_client.py

+22
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,28 @@ async def test_get_metadata_with_psc(fake_client: CloudSQLClient) -> None:
6565
assert isinstance(resp["server_ca_cert"], str)
6666

6767

68+
@pytest.mark.asyncio
69+
async def test_get_metadata_legacy_dns_with_psc(fake_client: CloudSQLClient) -> None:
70+
"""
71+
Test _get_metadata returns successfully with PSC IP type.
72+
"""
73+
# set PSC to enabled on test instance
74+
fake_client.instance.psc_enabled = True
75+
fake_client.instance.legacy_dns_name = True
76+
resp = await fake_client._get_metadata(
77+
"test-project",
78+
"test-region",
79+
"test-instance",
80+
)
81+
assert resp["database_version"] == "POSTGRES_15"
82+
assert resp["ip_addresses"] == {
83+
"PRIMARY": "127.0.0.1",
84+
"PRIVATE": "10.0.0.1",
85+
"PSC": "abcde.12345.us-central1.sql.goog",
86+
}
87+
assert isinstance(resp["server_ca_cert"], str)
88+
89+
6890
@pytest.mark.asyncio
6991
async def test_get_ephemeral(fake_client: CloudSQLClient) -> None:
7092
"""

0 commit comments

Comments
 (0)