google/cloud/sql/connector/resolver.py (43 lines of code) (raw):
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import dns.asyncresolver
from google.cloud.sql.connector.connection_name import (
_parse_connection_name_with_domain_name,
)
from google.cloud.sql.connector.connection_name import _is_valid_domain
from google.cloud.sql.connector.connection_name import _parse_connection_name
from google.cloud.sql.connector.connection_name import ConnectionName
from google.cloud.sql.connector.exceptions import DnsResolutionError
class DefaultResolver:
    """Resolver that performs no lookup: it only parses the given name."""

    async def resolve(self, connection_name: str) -> ConnectionName:
        """Parse ``connection_name`` into a ``ConnectionName``.

        The work is delegated entirely to ``_parse_connection_name``; any
        exception that helper raises propagates to the caller unchanged.
        """
        parsed = _parse_connection_name(connection_name)
        return parsed
class DnsResolver(dns.asyncresolver.Resolver):
    """
    Resolver that accepts either an instance connection name or a DNS domain
    name, turning the latter into a connection name via its TXT records.
    """

    async def resolve(self, dns: str) -> ConnectionName:  # type: ignore
        """Return the ``ConnectionName`` for ``dns``.

        ``dns`` is first parsed directly as an instance connection name. If
        that parse raises ``ValueError`` and ``dns`` passes
        ``_is_valid_domain``, the result of ``query_dns`` is returned instead.

        Raises:
            ValueError: ``dns`` is neither parseable as a connection name nor
                a valid domain name.
        """
        try:
            return _parse_connection_name(dns)
        except ValueError:
            # Not in project:region:instance form. The fallback stays inside
            # this handler so anything raised below keeps the original
            # ValueError as its exception context.
            if not _is_valid_domain(dns):
                raise ValueError(
                    "Arg `instance_connection_string` must have "
                    "format: PROJECT:REGION:INSTANCE or be a valid DNS domain "
                    f"name, got {dns}."
                )
            # Valid domain name: look the connection name up in its TXT records.
            return await self.query_dns(dns)

    async def query_dns(self, dns: str) -> ConnectionName:
        """Look up the TXT records of ``dns`` and parse a connection name.

        Record values are tried in sorted order and the first one that
        ``_parse_connection_name_with_domain_name`` accepts is returned.

        Raises:
            DnsResolutionError: the TXT query failed, or no record value
                could be parsed.
        """
        try:
            # Query through the parent resolver; this class overrides
            # ``resolve`` with a different signature, hence ``super()``.
            answers = await super().resolve(dns, "TXT", raise_on_no_answer=True)
            # Record values may come back wrapped in double quotes; strip
            # them, and sort so candidates are tried in a stable order.
            rdata = sorted(answer.to_text().strip('"') for answer in answers)
            for candidate in rdata:
                try:
                    return _parse_connection_name_with_domain_name(candidate, dns)
                except Exception:
                    # This value is not usable; move on to the next one.
                    continue
            # No record value could be parsed.
            raise DnsResolutionError(
                f"Unable to parse TXT record for `{dns}` -> `{rdata[0]}`"
            )
        except DnsResolutionError:
            # Let the parse failure above through without re-wrapping it.
            raise
        except Exception as e:
            raise DnsResolutionError(f"Unable to resolve TXT record for `{dns}`") from e