google/cloud/sql/connector/client.py (191 lines of code) (raw):

# Copyright 2024 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. from __future__ import annotations import asyncio import datetime import logging from typing import Any, Optional, TYPE_CHECKING import aiohttp from cryptography.hazmat.backends import default_backend from cryptography.x509 import load_pem_x509_certificate from google.auth.credentials import TokenState from google.auth.transport import requests from google.cloud.sql.connector.connection_info import ConnectionInfo from google.cloud.sql.connector.connection_name import ConnectionName from google.cloud.sql.connector.exceptions import AutoIAMAuthNotSupported from google.cloud.sql.connector.refresh_utils import _downscope_credentials from google.cloud.sql.connector.refresh_utils import retry_50x from google.cloud.sql.connector.version import __version__ as version if TYPE_CHECKING: from google.auth.credentials import Credentials USER_AGENT: str = f"cloud-sql-python-connector/{version}" API_VERSION: str = "v1beta4" DEFAULT_SERVICE_ENDPOINT: str = "https://sqladmin.googleapis.com" logger = logging.getLogger(name=__name__) def _format_user_agent(driver: Optional[str], custom: Optional[str]) -> str: agent = f"{USER_AGENT}+{driver}" if driver else USER_AGENT if custom and isinstance(custom, str): agent = f"{agent} {custom}" return agent class CloudSQLClient: def __init__( self, sqladmin_api_endpoint: Optional[str], quota_project: Optional[str], credentials: Credentials, client: Optional[aiohttp.ClientSession] = None, driver: Optional[str] = None, user_agent: Optional[str] = None, ) -> None: """Establishes the client to be used for Cloud SQL Admin API requests. Args: sqladmin_api_endpoint (str): Base URL to use when calling the Cloud SQL Admin API endpoints. quota_project (str): The Project ID for an existing Google Cloud project. The project specified is used for quota and billing purposes. credentials (google.auth.credentials.Credentials): A credentials object created from the google-auth Python library. Must have the Cloud SQL Admin scopes. For more info check out https://google-auth.readthedocs.io/en/latest/. client (aiohttp.ClientSession): Async client used to make requests to Cloud SQL Admin APIs. Optional, defaults to None and creates new client. driver (str): Database driver to be used by the client. """ user_agent = _format_user_agent(driver, user_agent) headers = { "x-goog-api-client": user_agent, "User-Agent": user_agent, "Content-Type": "application/json", } if quota_project: headers["x-goog-user-project"] = quota_project self._client = client if client else aiohttp.ClientSession(headers=headers) self._credentials = credentials if sqladmin_api_endpoint is None: self._sqladmin_api_endpoint = DEFAULT_SERVICE_ENDPOINT else: self._sqladmin_api_endpoint = sqladmin_api_endpoint self._user_agent = user_agent async def _get_metadata( self, project: str, region: str, instance: str, ) -> dict[str, Any]: """Requests metadata from the Cloud SQL Instance and returns a dictionary containing the IP addresses and certificate authority of the Cloud SQL Instance. Args: project (str): A string representing the name of the project. region (str): A string representing the name of the region. instance (str): A string representing the name of the instance. Returns: A dictionary containing a dictionary of all IP addresses and their type and a string representing the certificate authority. Raises: ValueError: Provided region does not match the region of the Cloud SQL instance. """ headers = { "Authorization": f"Bearer {self._credentials.token}", } url = f"{self._sqladmin_api_endpoint}/sql/{API_VERSION}/projects/{project}/instances/{instance}/connectSettings" resp = await self._client.get(url, headers=headers) if resp.status >= 500: resp = await retry_50x(self._client.get, url, headers=headers) # try to get response json for better error message try: ret_dict = await resp.json() if resp.status >= 400: # if detailed error message is in json response, use as error message message = ret_dict.get("error", {}).get("message") if message: resp.reason = message # skip, raise_for_status will catch all errors in finally block except Exception: pass finally: resp.raise_for_status() if ret_dict["region"] != region: raise ValueError( f'[{project}:{region}:{instance}]: Provided region was mismatched - got region {region}, expected {ret_dict["region"]}.' ) ip_addresses = ( {ip["type"]: ip["ipAddress"] for ip in ret_dict["ipAddresses"]} if "ipAddresses" in ret_dict else {} ) # resolve dnsName into IP address for PSC # Note that we have to check for PSC enablement also because CAS # instances also set the dnsName field. if ret_dict.get("pscEnabled"): # Find PSC instance DNS name in the dns_names field psc_dns_names = [ d["name"] for d in ret_dict.get("dnsNames", []) if d["connectionType"] == "PRIVATE_SERVICE_CONNECT" and d["dnsScope"] == "INSTANCE" ] dns_name = psc_dns_names[0] if psc_dns_names else None # Fall back do dns_name field if dns_names is not set if dns_name is None: dns_name = ret_dict.get("dnsName", None) # Remove trailing period from DNS name. Required for SSL in Python if dns_name: ip_addresses["PSC"] = dns_name.rstrip(".") return { "ip_addresses": ip_addresses, "server_ca_cert": ret_dict["serverCaCert"]["cert"], "database_version": ret_dict["databaseVersion"], } async def _get_ephemeral( self, project: str, instance: str, pub_key: str, enable_iam_auth: bool = False, ) -> tuple[str, datetime.datetime]: """Asynchronously requests an ephemeral certificate from the Cloud SQL Instance. Args: project (str): A string representing the name of the project. instance (str): string representing the name of the instance. pub_key (str): A string representing PEM-encoded RSA public key. enable_iam_auth (bool): Enables automatic IAM database authentication for Postgres or MySQL instances. Returns: A tuple containing an ephemeral certificate from the Cloud SQL instance as well as a datetime object representing the expiration time of the certificate. """ headers = { "Authorization": f"Bearer {self._credentials.token}", } url = f"{self._sqladmin_api_endpoint}/sql/{API_VERSION}/projects/{project}/instances/{instance}:generateEphemeralCert" data = {"public_key": pub_key} if enable_iam_auth: # down-scope credentials with only IAM login scope (refreshes them too) login_creds = _downscope_credentials(self._credentials) data["access_token"] = login_creds.token resp = await self._client.post(url, headers=headers, json=data) if resp.status >= 500: resp = await retry_50x(self._client.post, url, headers=headers, json=data) # try to get response json for better error message try: ret_dict = await resp.json() if resp.status >= 400: # if detailed error message is in json response, use as error message message = ret_dict.get("error", {}).get("message") if message: resp.reason = message # skip, raise_for_status will catch all errors in finally block except Exception: pass finally: resp.raise_for_status() ephemeral_cert: str = ret_dict["ephemeralCert"]["cert"] # decode cert to read expiration x509 = load_pem_x509_certificate( ephemeral_cert.encode("UTF-8"), default_backend() ) expiration = x509.not_valid_after_utc # for IAM authentication OAuth2 token is embedded in cert so it # must still be valid for successful connection if enable_iam_auth: token_expiration: datetime.datetime = login_creds.expiry # google.auth library strips timezone info for backwards compatibality # reasons with Python 2. Add it back to allow timezone aware datetimes. # Ref: https://github.com/googleapis/google-auth-library-python/blob/49a5ff7411a2ae4d32a7d11700f9f961c55406a9/google/auth/_helpers.py#L93-L99 token_expiration = token_expiration.replace(tzinfo=datetime.timezone.utc) if expiration > token_expiration: expiration = token_expiration return ephemeral_cert, expiration async def get_connection_info( self, conn_name: ConnectionName, keys: asyncio.Future, enable_iam_auth: bool, ) -> ConnectionInfo: """Immediately performs a full refresh operation using the Cloud SQL Admin API. Args: conn_name (ConnectionName): The Cloud SQL instance's connection name. keys (asyncio.Future): A future to the client's public-private key pair. enable_iam_auth (bool): Whether an automatic IAM database authentication connection is being requested (Postgres and MySQL). Returns: ConnectionInfo: All the information required to connect securely to the Cloud SQL instance. Raises: AutoIAMAuthNotSupported: Database engine does not support automatic IAM authentication. """ priv_key, pub_key = await keys # before making Cloud SQL Admin API calls, refresh creds if required if not self._credentials.token_state == TokenState.FRESH: self._credentials.refresh(requests.Request()) metadata_task = asyncio.create_task( self._get_metadata( conn_name.project, conn_name.region, conn_name.instance_name, ) ) ephemeral_task = asyncio.create_task( self._get_ephemeral( conn_name.project, conn_name.instance_name, pub_key, enable_iam_auth, ) ) try: metadata = await metadata_task # check if automatic IAM database authn is supported for database engine if enable_iam_auth and not metadata["database_version"].startswith( ("POSTGRES", "MYSQL") ): raise AutoIAMAuthNotSupported( f"'{metadata['database_version']}' does not support " "automatic IAM authentication. It is only supported with " "Cloud SQL Postgres or MySQL instances." ) except Exception: # cancel ephemeral cert task if exception occurs before it is awaited ephemeral_task.cancel() raise ephemeral_cert, expiration = await ephemeral_task return ConnectionInfo( conn_name, ephemeral_cert, metadata["server_ca_cert"], priv_key, metadata["ip_addresses"], metadata["database_version"], expiration, ) async def close(self) -> None: """Close CloudSQLClient gracefully.""" logger.debug("Waiting for Connector's http client to close") await self._client.close() logger.debug("Closed Connector's http client")