google/cloud/alloydb/connector/async_connector.py (141 lines of code) (raw):

# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

import asyncio
import logging
from types import TracebackType
from typing import Any, Optional, TYPE_CHECKING

import google.auth
from google.auth.credentials import with_scopes_if_required
import google.auth.transport.requests

import google.cloud.alloydb.connector.asyncpg as asyncpg
from google.cloud.alloydb.connector.client import AlloyDBClient
from google.cloud.alloydb.connector.enums import IPTypes
from google.cloud.alloydb.connector.enums import RefreshStrategy
from google.cloud.alloydb.connector.exceptions import ClosedConnectorError
from google.cloud.alloydb.connector.instance import RefreshAheadCache
from google.cloud.alloydb.connector.lazy import LazyRefreshCache
from google.cloud.alloydb.connector.types import CacheTypes
from google.cloud.alloydb.connector.utils import generate_keys
from google.cloud.alloydb.connector.utils import strip_http_prefix

if TYPE_CHECKING:
    from google.auth.credentials import Credentials

logger = logging.getLogger(name=__name__)


class AsyncConnector:
    """A class to configure and create connections to AlloyDB instances
    asynchronously.

    Args:
        credentials (google.auth.credentials.Credentials):
            A credentials object created from the google-auth Python library.
            If not specified, Application Default Credentials are used.
        quota_project (str): The Project ID for an existing Google Cloud
            project. The project specified is used for quota and billing
            purposes. Defaults to None, picking up project from environment.
        alloydb_api_endpoint (str): Base URL to use when calling the AlloyDB
            API endpoint. Defaults to "alloydb.googleapis.com".
        enable_iam_auth (bool): Enables automatic IAM database authentication.
        ip_type (str | IPTypes): Default IP type for all AlloyDB connections.
            Defaults to IPTypes.PRIVATE ("PRIVATE") for private IP connections.
        user_agent (str): Optional value forwarded as ``user_agent`` to the
            AlloyDBClient created on the first call to connect().
            Defaults to None.
        refresh_strategy (str | RefreshStrategy): The default refresh strategy
            used to refresh SSL/TLS cert and instance metadata. Can be one
            of the following: RefreshStrategy.LAZY ("LAZY") or
            RefreshStrategy.BACKGROUND ("BACKGROUND").
            Default: RefreshStrategy.BACKGROUND
    """

    def __init__(
        self,
        credentials: Optional[Credentials] = None,
        quota_project: Optional[str] = None,
        alloydb_api_endpoint: str = "alloydb.googleapis.com",
        enable_iam_auth: bool = False,
        ip_type: str | IPTypes = IPTypes.PRIVATE,
        user_agent: Optional[str] = None,
        refresh_strategy: str | RefreshStrategy = RefreshStrategy.BACKGROUND,
    ) -> None:
        # connection info caches, keyed by instance URI
        self._cache: dict[str, CacheTypes] = {}
        # initialize default params
        self._quota_project = quota_project
        self._alloydb_api_endpoint = strip_http_prefix(alloydb_api_endpoint)
        self._enable_iam_auth = enable_iam_auth
        # if ip_type is str, convert to IPTypes enum
        if isinstance(ip_type, str):
            ip_type = IPTypes(ip_type.upper())
        self._ip_type = ip_type
        # if refresh_strategy is str, convert to RefreshStrategy enum
        if isinstance(refresh_strategy, str):
            refresh_strategy = RefreshStrategy(refresh_strategy.upper())
        self._refresh_strategy = refresh_strategy
        self._user_agent = user_agent
        # initialize credentials
        scopes = ["https://www.googleapis.com/auth/cloud-platform"]
        if credentials:
            self._credentials = with_scopes_if_required(credentials, scopes=scopes)
        # otherwise use application default credentials
        else:
            self._credentials, _ = google.auth.default(scopes=scopes)

        # Start key generation eagerly only if an event loop is already
        # running; otherwise connect() starts it lazily. The loop is checked
        # *before* calling generate_keys() so that no coroutine object is
        # created and then dropped un-awaited (which would emit a
        # "coroutine ... was never awaited" RuntimeWarning).
        self._keys: Optional[asyncio.Task] = None
        try:
            asyncio.get_running_loop()
        except RuntimeError:
            # No running event loop: keys are generated lazily in connect().
            pass
        else:
            self._keys = asyncio.create_task(generate_keys())
        self._client: Optional[AlloyDBClient] = None
        self._closed = False

    async def connect(
        self,
        instance_uri: str,
        driver: str,
        **kwargs: Any,
    ) -> Any:
        """
        Asynchronously prepares and returns a database connection object.

        Starts tasks to refresh the certificates and get AlloyDB instance
        IP address. Creates a secure TLS connection to establish connection
        to AlloyDB instance.

        Args:
            instance_uri (str): The instance URI of the AlloyDB instance.
                ex. projects/<PROJECT>/locations/<REGION>/clusters/<CLUSTER>/instances/<INSTANCE>
            driver (str): A string representing the database driver to connect
                with. Supported drivers are asyncpg.
            **kwargs: Pass in any database driver-specific arguments needed
                to fine tune connection. ``enable_iam_auth`` and ``ip_type``
                override the connector-level defaults for this call;
                ``host``, ``ssl`` and ``port`` are ignored.

        Returns:
            connection: The connection object returned by the selected
                driver's connect function.

        Raises:
            ClosedConnectorError: If the connector has already been closed.
            ValueError: If ``driver`` is not supported, or if ``ip_type`` is
                a string that is not a valid IPTypes value.
        """
        if self._closed:
            raise ClosedConnectorError(
                "Connection attempt failed because the connector has already been closed."
            )

        # Validate the driver before creating the client or any cache, so an
        # unsupported driver neither fixes the client's driver name nor
        # starts a refresh for an instance it can never connect to.
        connect_func = {
            "asyncpg": asyncpg.connect,
        }
        # only accept supported database drivers
        try:
            connector = connect_func[driver]
        except KeyError:
            raise ValueError(f"Driver '{driver}' is not a supported database driver.")

        if self._keys is None:
            self._keys = asyncio.create_task(generate_keys())
        if self._client is None:
            # lazy init client as it has to be initialized in async context
            self._client = AlloyDBClient(
                self._alloydb_api_endpoint,
                self._quota_project,
                self._credentials,
                user_agent=self._user_agent,
                driver=driver,
            )
        enable_iam_auth = kwargs.pop("enable_iam_auth", self._enable_iam_auth)

        # use existing connection info if possible
        if instance_uri in self._cache:
            cache = self._cache[instance_uri]
        else:
            if self._refresh_strategy == RefreshStrategy.LAZY:
                logger.debug(
                    f"['{instance_uri}']: Refresh strategy is set to lazy refresh"
                )
                cache = LazyRefreshCache(instance_uri, self._client, self._keys)
            else:
                logger.debug(
                    f"['{instance_uri}']: Refresh strategy is set to background"
                    " refresh"
                )
                cache = RefreshAheadCache(instance_uri, self._client, self._keys)
            self._cache[instance_uri] = cache
            logger.debug(f"['{instance_uri}']: Connection info added to cache")

        # Host and ssl options come from the certificates and instance IP
        # address so we don't want the user to specify them.
        kwargs.pop("host", None)
        kwargs.pop("ssl", None)
        kwargs.pop("port", None)

        # get connection info for AlloyDB instance
        ip_type: str | IPTypes = kwargs.pop("ip_type", self._ip_type)
        # if ip_type is str, convert to IPTypes enum
        if isinstance(ip_type, str):
            ip_type = IPTypes(ip_type.upper())
        try:
            conn_info = await cache.connect_info()
            ip_address = conn_info.get_preferred_ip(ip_type)
        except Exception:
            # With an error from the AlloyDB API call or IP type, invalidate
            # the cache and re-raise the error. Only remove the entry if it
            # is still the cache that failed: a concurrent failed connect may
            # already have removed it (the only path that drops entries), and
            # a newer healthy cache may have replaced it since.
            if self._cache.get(instance_uri) is cache:
                await self._remove_cached(instance_uri)
            raise

        logger.debug(f"['{instance_uri}']: Connecting to {ip_address}:5433")

        # callable to be used for auto IAM authn
        def get_authentication_token() -> str:
            """Get OAuth2 access token to be used for IAM database authentication"""
            # Refresh credentials if expired.
            # NOTE(review): this refresh is synchronous and may block the
            # event loop while it runs.
            if not self._credentials.valid:
                request = google.auth.transport.requests.Request()
                self._credentials.refresh(request)
            return self._credentials.token

        # If enable_iam_auth is set, use the auth token as the database
        # password. The function itself (not a token string) is passed.
        if enable_iam_auth:
            kwargs["password"] = get_authentication_token
        try:
            return await connector(
                ip_address, await conn_info.create_ssl_context(), **kwargs
            )
        except Exception:
            # we attempt a force refresh, then throw the error
            await cache.force_refresh()
            raise

    async def _remove_cached(self, instance_uri: str) -> None:
        """Close the connection info cache for ``instance_uri`` and remove it
        from the map of caches.

        This is a no-op if no cache is stored for ``instance_uri`` (for example
        because it was already removed), so a caller re-raising another
        exception does not have it masked by a KeyError.
        """
        logger.debug(f"['{instance_uri}']: Removing connection info from cache")
        # remove cache from stored caches and close it
        cache = self._cache.pop(instance_uri, None)
        if cache is not None:
            await cache.close()

    async def __aenter__(self) -> Any:
        """Enter async context manager by returning Connector object"""
        return self

    async def __aexit__(
        self,
        exc_type: Optional[type[BaseException]],
        exc_val: Optional[BaseException],
        exc_tb: Optional[TracebackType],
    ) -> None:
        """Exit async context manager by closing Connector"""
        await self.close()

    async def close(self) -> None:
        """Close every connection info cache and mark the connector as closed.

        After this, connect() raises ClosedConnectorError.
        """
        # Mark closed before awaiting, so a connect() that starts while the
        # caches are being closed fails fast instead of adding a new cache
        # that would never be closed.
        self._closed = True
        await asyncio.gather(*[cache.close() for cache in self._cache.values()])
        # NOTE(review): self._client (AlloyDBClient) is not closed here;
        # confirm whether it owns resources (e.g. an HTTP session) that need
        # explicit cleanup.