google/cloud/sql/connector/resolver.py (43 lines of code) (raw):

# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import dns.asyncresolver

from google.cloud.sql.connector.connection_name import (
    _parse_connection_name_with_domain_name,
)
from google.cloud.sql.connector.connection_name import _is_valid_domain
from google.cloud.sql.connector.connection_name import _parse_connection_name
from google.cloud.sql.connector.connection_name import ConnectionName
from google.cloud.sql.connector.exceptions import DnsResolutionError


class DefaultResolver:
    """DefaultResolver simply validates and parses instance connection name."""

    async def resolve(self, connection_name: str) -> ConnectionName:
        """Parse an instance connection name without any DNS lookup.

        Args:
            connection_name: Name in PROJECT:REGION:INSTANCE format.

        Returns:
            The parsed ConnectionName.

        Raises:
            ValueError: ``connection_name`` is not in PROJECT:REGION:INSTANCE
                format (raised by ``_parse_connection_name``).
        """
        return _parse_connection_name(connection_name)


class DnsResolver(dns.asyncresolver.Resolver):
    """
    DnsResolver resolves domain names into instance connection names using
    TXT records in DNS.
    """

    # NOTE: the ``dns`` parameter of the methods below shadows the imported
    # ``dns`` module inside those methods. The name is kept so keyword
    # callers keep working; neither method body refers to the module.

    async def resolve(self, dns: str) -> ConnectionName:  # type: ignore
        """Resolve a connection name or a DNS domain name.

        ``dns`` is first parsed as PROJECT:REGION:INSTANCE. If that fails and
        ``dns`` is a valid domain name, its TXT records are queried instead.

        Args:
            dns: An instance connection name or a DNS domain name.

        Returns:
            The resolved ConnectionName.

        Raises:
            ValueError: ``dns`` is neither a connection name nor a valid DNS
                domain name.
            DnsResolutionError: The TXT lookup failed or no TXT value parsed
                as a connection name.
        """
        # ``type: ignore``: this deliberately overrides
        # dns.asyncresolver.Resolver.resolve with a different signature.
        try:
            return _parse_connection_name(dns)
        except ValueError:
            # The connection name was not project:region:instance format.
            # Only fall back to DNS when the input is a valid domain name.
            if not _is_valid_domain(dns):
                raise ValueError(
                    "Arg `instance_connection_string` must have "
                    "format: PROJECT:REGION:INSTANCE or be a valid DNS domain "
                    f"name, got {dns}."
                )
        # Query outside the ``except`` block so DNS errors are not chained
        # onto the unrelated parse ValueError in tracebacks.
        return await self.query_dns(dns)

    async def query_dns(self, dns: str) -> ConnectionName:
        """Look up TXT records for ``dns`` and parse a connection name.

        Record values are sorted alphabetically and the first value that
        parses as a connection name is returned, so the choice is
        deterministic when several TXT records are published.

        Args:
            dns: The domain name to query.

        Returns:
            The ConnectionName parsed from the first valid TXT value.

        Raises:
            DnsResolutionError: The lookup failed or returned no values
                ("Unable to resolve ..."), or no value parsed as a
                connection name ("Unable to parse ...").
        """
        # Keep the try block limited to the lookup itself so the
        # parse-failure error raised below never needs to be re-caught.
        try:
            # Attempt to query the TXT records.
            records = await super().resolve(dns, "TXT", raise_on_no_answer=True)
            # Strip quotes, as record values can be returned as raw strings,
            # then sort alphabetically for a deterministic pick.
            rdata = sorted(record.to_text().strip('"') for record in records)
        except Exception as e:
            raise DnsResolutionError(
                f"Unable to resolve TXT record for `{dns}`"
            ) from e

        if not rdata:
            # raise_on_no_answer=True presumably already raises on an empty
            # answer; guard anyway so ``rdata[0]`` below can never fail.
            raise DnsResolutionError(f"Unable to resolve TXT record for `{dns}`")

        # Attempt to parse records, returning the first valid record.
        for record in rdata:
            try:
                return _parse_connection_name_with_domain_name(record, dns)
            except Exception:
                # Not a valid connection name; try the next value.
                continue

        # If all records failed to parse, throw error
        raise DnsResolutionError(
            f"Unable to parse TXT record for `{dns}` -> `{rdata[0]}`"
        )