google/cloud/sql/connector/connector.py (299 lines of code) (raw):

""" Copyright 2019 Google LLC Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at https://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ from __future__ import annotations import asyncio from functools import partial import logging import os import socket from threading import Thread from types import TracebackType from typing import Any, Callable, Optional, Union import google.auth from google.auth.credentials import Credentials from google.auth.credentials import with_scopes_if_required import google.cloud.sql.connector.asyncpg as asyncpg from google.cloud.sql.connector.client import CloudSQLClient from google.cloud.sql.connector.enums import DriverMapping from google.cloud.sql.connector.enums import IPTypes from google.cloud.sql.connector.enums import RefreshStrategy from google.cloud.sql.connector.instance import RefreshAheadCache from google.cloud.sql.connector.lazy import LazyRefreshCache from google.cloud.sql.connector.monitored_cache import MonitoredCache import google.cloud.sql.connector.pg8000 as pg8000 import google.cloud.sql.connector.pymysql as pymysql import google.cloud.sql.connector.pytds as pytds from google.cloud.sql.connector.resolver import DefaultResolver from google.cloud.sql.connector.resolver import DnsResolver from google.cloud.sql.connector.utils import format_database_user from google.cloud.sql.connector.utils import generate_keys logger = logging.getLogger(name=__name__) ASYNC_DRIVERS = ["asyncpg"] SERVER_PROXY_PORT = 3307 _DEFAULT_SCHEME = "https://" _DEFAULT_UNIVERSE_DOMAIN = "googleapis.com" _SQLADMIN_HOST_TEMPLATE = "sqladmin.{universe_domain}" class Connector: """Configure and create secure connections to Cloud SQL.""" def __init__( self, ip_type: str | IPTypes = IPTypes.PUBLIC, enable_iam_auth: bool = False, timeout: int = 30, credentials: Optional[Credentials] = None, loop: Optional[asyncio.AbstractEventLoop] = None, quota_project: Optional[str] = None, sqladmin_api_endpoint: Optional[str] = None, user_agent: Optional[str] = None, universe_domain: Optional[str] = None, refresh_strategy: str | RefreshStrategy = RefreshStrategy.BACKGROUND, resolver: type[DefaultResolver] | type[DnsResolver] = DefaultResolver, failover_period: int = 30, ) -> None: """Initializes a Connector instance. Args: ip_type (str | IPTypes): The default IP address type used to connect to Cloud SQL instances. Can be one of the following: IPTypes.PUBLIC ("PUBLIC"), IPTypes.PRIVATE ("PRIVATE"), or IPTypes.PSC ("PSC"). Default: IPTypes.PUBLIC enable_iam_auth (bool): Enables automatic IAM database authentication (Postgres and MySQL) as the default authentication method for all connections. timeout (int): The default time limit in seconds for a connection before raising a TimeoutError. credentials (google.auth.credentials.Credentials): A credentials object created from the google-auth Python library to be used. If not specified, Application Default Credentials (ADC) are used. quota_project (str): The Project ID for an existing Google Cloud project. The project specified is used for quota and billing purposes. If not specified, defaults to project sourced from environment. loop (asyncio.AbstractEventLoop): Event loop to run asyncio tasks, if not specified, defaults to creating new event loop on background thread. sqladmin_api_endpoint (str): Base URL to use when calling the Cloud SQL Admin API endpoint. Defaults to "https://sqladmin.googleapis.com", this argument should only be used in development. universe_domain (str): The universe domain for Cloud SQL API calls. Default: "googleapis.com". refresh_strategy (str | RefreshStrategy): The default refresh strategy used to refresh SSL/TLS cert and instance metadata. Can be one of the following: RefreshStrategy.LAZY ("LAZY") or RefreshStrategy.BACKGROUND ("BACKGROUND"). Default: RefreshStrategy.BACKGROUND resolver (DefaultResolver | DnsResolver): The class name of the resolver to use for resolving the Cloud SQL instance connection name. To resolve a DNS record to an instance connection name, use DnsResolver. Default: DefaultResolver failover_period (int): The time interval in seconds between each attempt to check if a failover has occured for a given instance. Must be used with `resolver=DnsResolver` to have any effect. Default: 30 """ # if refresh_strategy is str, convert to RefreshStrategy enum if isinstance(refresh_strategy, str): refresh_strategy = RefreshStrategy._from_str(refresh_strategy) self._refresh_strategy = refresh_strategy # if event loop is given, use for background tasks if loop: self._loop: asyncio.AbstractEventLoop = loop self._thread: Optional[Thread] = None # if lazy refresh is specified we should lazy init keys if self._refresh_strategy == RefreshStrategy.LAZY: self._keys: Optional[asyncio.Future] = None else: self._keys = loop.create_task(generate_keys()) # if no event loop is given, spin up new loop in background thread else: self._loop = asyncio.new_event_loop() self._thread = Thread(target=self._loop.run_forever, daemon=True) self._thread.start() # if lazy refresh is specified we should lazy init keys if self._refresh_strategy == RefreshStrategy.LAZY: self._keys = None else: self._keys = asyncio.wrap_future( asyncio.run_coroutine_threadsafe(generate_keys(), self._loop), loop=self._loop, ) # initialize dict to store caches, key is a tuple consisting of instance # connection name string and enable_iam_auth boolean flag self._cache: dict[tuple[str, bool], MonitoredCache] = {} self._client: Optional[CloudSQLClient] = None # initialize credentials scopes = ["https://www.googleapis.com/auth/sqlservice.admin"] if credentials: # verify custom credentials are proper type # and atleast base class of google.auth.credentials if not isinstance(credentials, Credentials): raise TypeError( "credentials must be of type google.auth.credentials.Credentials," f" got {type(credentials)}" ) self._credentials = with_scopes_if_required(credentials, scopes=scopes) # otherwise use application default credentials else: self._credentials, _ = google.auth.default(scopes=scopes) # set default params for connections self._timeout = timeout self._enable_iam_auth = enable_iam_auth self._user_agent = user_agent self._resolver = resolver() self._failover_period = failover_period # if ip_type is str, convert to IPTypes enum if isinstance(ip_type, str): ip_type = IPTypes._from_str(ip_type) self._ip_type = ip_type # check for quota project arg and then env var if quota_project: self._quota_project = quota_project else: self._quota_project = os.environ.get("GOOGLE_CLOUD_QUOTA_PROJECT") # type: ignore # check for universe domain arg and then env var if universe_domain: self._universe_domain = universe_domain else: self._universe_domain = os.environ.get("GOOGLE_CLOUD_UNIVERSE_DOMAIN") # type: ignore # construct service endpoint for Cloud SQL Admin API calls if not sqladmin_api_endpoint: self._sqladmin_api_endpoint = ( _DEFAULT_SCHEME + _SQLADMIN_HOST_TEMPLATE.format(universe_domain=self.universe_domain) ) # otherwise if endpoint override is passed in use it else: self._sqladmin_api_endpoint = sqladmin_api_endpoint # validate that the universe domain of the credentials matches the # universe domain of the service endpoint if self._credentials.universe_domain != self.universe_domain: raise ValueError( f"The configured universe domain ({self.universe_domain}) does " "not match the universe domain found in the credentials " f"({self._credentials.universe_domain}). If you haven't " "configured the universe domain explicitly, `googleapis.com` " "is the default." ) @property def universe_domain(self) -> str: return self._universe_domain or _DEFAULT_UNIVERSE_DOMAIN def connect( self, instance_connection_string: str, driver: str, **kwargs: Any ) -> Any: """Connect to a Cloud SQL instance. Prepares and returns a database connection object connected to a Cloud SQL instance using SSL/TLS. Starts a background refresh to periodically retrieve up-to-date ephemeral certificate and instance metadata. Args: instance_connection_string (str): The instance connection name of the Cloud SQL instance to connect to. Takes the form of "project-id:region:instance-name" Example: "my-project:us-central1:my-instance" driver (str): A string representing the database driver to connect with. Supported drivers are pymysql, pg8000, and pytds. **kwargs: Any driver-specific arguments to pass to the underlying driver .connect call. Returns: A DB-API connection to the specified Cloud SQL instance. """ # connect runs sync database connections on background thread. # Async database connections should call 'connect_async' directly to # avoid hanging indefinitely. connect_future = asyncio.run_coroutine_threadsafe( self.connect_async(instance_connection_string, driver, **kwargs), self._loop, ) return connect_future.result() async def connect_async( self, instance_connection_string: str, driver: str, **kwargs: Any ) -> Any: """Connect asynchronously to a Cloud SQL instance. Prepares and returns a database connection object connected to a Cloud SQL instance using SSL/TLS. Schedules a refresh to periodically retrieve up-to-date ephemeral certificate and instance metadata. Async version of Connector.connect. Args: instance_connection_string (str): The instance connection name of the Cloud SQL instance to connect to. Takes the form of "project-id:region:instance-name" Example: "my-project:us-central1:my-instance" driver (str): A string representing the database driver to connect with. Supported drivers are pymysql, asyncpg, pg8000, and pytds. **kwargs: Any driver-specific arguments to pass to the underlying driver .connect call. Returns: A DB-API connection to the specified Cloud SQL instance. Raises: ValueError: Connection attempt with built-in database authentication and then subsequent attempt with IAM database authentication. KeyError: Unsupported database driver Must be one of pymysql, asyncpg, pg8000, and pytds. """ if self._keys is None: self._keys = asyncio.create_task(generate_keys()) if self._client is None: # lazy init client as it has to be initialized in async context self._client = CloudSQLClient( self._sqladmin_api_endpoint, self._quota_project, self._credentials, user_agent=self._user_agent, driver=driver, ) enable_iam_auth = kwargs.pop("enable_iam_auth", self._enable_iam_auth) conn_name = await self._resolver.resolve(instance_connection_string) # Cache entry must exist and not be closed if (str(conn_name), enable_iam_auth) in self._cache and not self._cache[ (str(conn_name), enable_iam_auth) ].closed: monitored_cache = self._cache[(str(conn_name), enable_iam_auth)] else: if self._refresh_strategy == RefreshStrategy.LAZY: logger.debug( f"['{conn_name}']: Refresh strategy is set to lazy refresh" ) cache: Union[LazyRefreshCache, RefreshAheadCache] = LazyRefreshCache( conn_name, self._client, self._keys, enable_iam_auth, ) else: logger.debug( f"['{conn_name}']: Refresh strategy is set to backgound refresh" ) cache = RefreshAheadCache( conn_name, self._client, self._keys, enable_iam_auth, ) # wrap cache as a MonitoredCache monitored_cache = MonitoredCache( cache, self._failover_period, self._resolver, ) logger.debug(f"['{conn_name}']: Connection info added to cache") self._cache[(str(conn_name), enable_iam_auth)] = monitored_cache connect_func = { "pymysql": pymysql.connect, "pg8000": pg8000.connect, "asyncpg": asyncpg.connect, "pytds": pytds.connect, } # only accept supported database drivers try: connector: Callable = connect_func[driver] # type: ignore except KeyError: raise KeyError(f"Driver '{driver}' is not supported.") ip_type = kwargs.pop("ip_type", self._ip_type) # if ip_type is str, convert to IPTypes enum if isinstance(ip_type, str): ip_type = IPTypes._from_str(ip_type) kwargs["timeout"] = kwargs.get("timeout", self._timeout) # Host and ssl options come from the certificates and metadata, so we don't # want the user to specify them. kwargs.pop("host", None) kwargs.pop("ssl", None) kwargs.pop("port", None) # attempt to get connection info for Cloud SQL instance try: conn_info = await monitored_cache.connect_info() # validate driver matches intended database engine DriverMapping.validate_engine(driver, conn_info.database_version) ip_address = conn_info.get_preferred_ip(ip_type) except Exception: # with an error from Cloud SQL Admin API call or IP type, invalidate # the cache and re-raise the error await self._remove_cached(str(conn_name), enable_iam_auth) raise logger.debug(f"['{conn_info.conn_name}']: Connecting to {ip_address}:3307") # format `user` param for automatic IAM database authn if enable_iam_auth: formatted_user = format_database_user( conn_info.database_version, kwargs["user"] ) if formatted_user != kwargs["user"]: logger.debug( f"['{instance_connection_string}']: Truncated IAM database username from {kwargs['user']} to {formatted_user}" ) kwargs["user"] = formatted_user try: # async drivers are unblocking and can be awaited directly if driver in ASYNC_DRIVERS: return await connector( ip_address, await conn_info.create_ssl_context(enable_iam_auth), **kwargs, ) # Create socket with SSLContext for sync drivers ctx = await conn_info.create_ssl_context(enable_iam_auth) sock = ctx.wrap_socket( socket.create_connection((ip_address, SERVER_PROXY_PORT)), server_hostname=ip_address, ) # If this connection was opened using a domain name, then store it # for later in case we need to forcibly close it on failover. if conn_info.conn_name.domain_name: monitored_cache.sockets.append(sock) # Synchronous drivers are blocking and run using executor connect_partial = partial( connector, ip_address, sock, **kwargs, ) return await self._loop.run_in_executor(None, connect_partial) except Exception: # with any exception, we attempt a force refresh, then throw the error await monitored_cache.force_refresh() raise async def _remove_cached( self, instance_connection_string: str, enable_iam_auth: bool ) -> None: """Stops all background refreshes and deletes the connection info cache from the map of caches. """ logger.debug( f"['{instance_connection_string}']: Removing connection info from cache" ) # remove cache from stored caches and close it cache = self._cache.pop((instance_connection_string, enable_iam_auth)) await cache.close() def __enter__(self) -> Any: """Enter context manager by returning Connector object""" return self def __exit__( self, exc_type: Optional[type[BaseException]], exc_val: Optional[BaseException], exc_tb: Optional[TracebackType], ) -> None: """Exit context manager by closing Connector""" self.close() async def __aenter__(self) -> Any: """Enter async context manager by returning Connector object""" return self async def __aexit__( self, exc_type: Optional[type[BaseException]], exc_val: Optional[BaseException], exc_tb: Optional[TracebackType], ) -> None: """Exit async context manager by closing Connector""" await self.close_async() def close(self) -> None: """Close Connector by stopping tasks and releasing resources.""" if self._loop.is_running(): close_future = asyncio.run_coroutine_threadsafe( self.close_async(), loop=self._loop ) # Will attempt to safely shut down tasks for 3s close_future.result(timeout=3) # if background thread exists for Connector, clean it up if self._thread: if self._loop.is_running(): # stop event loop running in background thread self._loop.call_soon_threadsafe(self._loop.stop) # wait for thread to finish closing (i.e. loop to stop) self._thread.join() async def close_async(self) -> None: """Helper function to cancel the cache's tasks and close aiohttp.ClientSession.""" await asyncio.gather(*[cache.close() for cache in self._cache.values()]) if self._client: await self._client.close() async def create_async_connector( ip_type: str | IPTypes = IPTypes.PUBLIC, enable_iam_auth: bool = False, timeout: int = 30, credentials: Optional[Credentials] = None, loop: Optional[asyncio.AbstractEventLoop] = None, quota_project: Optional[str] = None, sqladmin_api_endpoint: Optional[str] = None, user_agent: Optional[str] = None, universe_domain: Optional[str] = None, refresh_strategy: str | RefreshStrategy = RefreshStrategy.BACKGROUND, resolver: type[DefaultResolver] | type[DnsResolver] = DefaultResolver, failover_period: int = 30, ) -> Connector: """Helper function to create Connector object for asyncio connections. Force use of Connector in an asyncio context. Auto-detect and use current thread's running event loop. Args: ip_type (str | IPTypes): The default IP address type used to connect to Cloud SQL instances. Can be one of the following: IPTypes.PUBLIC ("PUBLIC"), IPTypes.PRIVATE ("PRIVATE"), or IPTypes.PSC ("PSC"). Default: IPTypes.PUBLIC enable_iam_auth (bool): Enables automatic IAM database authentication (Postgres and MySQL) as the default authentication method for all connections. timeout (int): The default time limit in seconds for a connection before raising a TimeoutError. credentials (google.auth.credentials.Credentials): A credentials object created from the google-auth Python library to be used. If not specified, Application Default Credentials (ADC) are used. quota_project (str): The Project ID for an existing Google Cloud project. The project specified is used for quota and billing purposes. If not specified, defaults to project sourced from environment. loop (asyncio.AbstractEventLoop): Event loop to run asyncio tasks, if not specified, defaults to creating new event loop on background thread. sqladmin_api_endpoint (str): Base URL to use when calling the Cloud SQL Admin API endpoint. Defaults to "https://sqladmin.googleapis.com", this argument should only be used in development. universe_domain (str): The universe domain for Cloud SQL API calls. Default: "googleapis.com". refresh_strategy (str | RefreshStrategy): The default refresh strategy used to refresh SSL/TLS cert and instance metadata. Can be one of the following: RefreshStrategy.LAZY ("LAZY") or RefreshStrategy.BACKGROUND ("BACKGROUND"). Default: RefreshStrategy.BACKGROUND resolver (DefaultResolver | DnsResolver): The class name of the resolver to use for resolving the Cloud SQL instance connection name. To resolve a DNS record to an instance connection name, use DnsResolver. Default: DefaultResolver failover_period (int): The time interval in seconds between each attempt to check if a failover has occured for a given instance. Must be used with `resolver=DnsResolver` to have any effect. Default: 30 Returns: A Connector instance configured with running event loop. """ # if no loop given, automatically detect running event loop if loop is None: loop = asyncio.get_running_loop() return Connector( ip_type=ip_type, enable_iam_auth=enable_iam_auth, timeout=timeout, credentials=credentials, loop=loop, quota_project=quota_project, sqladmin_api_endpoint=sqladmin_api_endpoint, user_agent=user_agent, universe_domain=universe_domain, refresh_strategy=refresh_strategy, resolver=resolver, failover_period=failover_period, )