google/cloud/sql/connector/connector.py (299 lines of code) (raw):
"""
Copyright 2019 Google LLC
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
https://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
from __future__ import annotations
import asyncio
from functools import partial
import logging
import os
import socket
from threading import Thread
from types import TracebackType
from typing import Any, Callable, Optional, Union
import google.auth
from google.auth.credentials import Credentials
from google.auth.credentials import with_scopes_if_required
import google.cloud.sql.connector.asyncpg as asyncpg
from google.cloud.sql.connector.client import CloudSQLClient
from google.cloud.sql.connector.enums import DriverMapping
from google.cloud.sql.connector.enums import IPTypes
from google.cloud.sql.connector.enums import RefreshStrategy
from google.cloud.sql.connector.instance import RefreshAheadCache
from google.cloud.sql.connector.lazy import LazyRefreshCache
from google.cloud.sql.connector.monitored_cache import MonitoredCache
import google.cloud.sql.connector.pg8000 as pg8000
import google.cloud.sql.connector.pymysql as pymysql
import google.cloud.sql.connector.pytds as pytds
from google.cloud.sql.connector.resolver import DefaultResolver
from google.cloud.sql.connector.resolver import DnsResolver
from google.cloud.sql.connector.utils import format_database_user
from google.cloud.sql.connector.utils import generate_keys
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Drivers whose connect function is awaited directly instead of being run
# in an executor.
ASYNC_DRIVERS = ["asyncpg"]

# Port used for every socket opened to a Cloud SQL instance by this module.
SERVER_PROXY_PORT = 3307

# Pieces used to build the default Cloud SQL Admin API endpoint:
# <scheme> + "sqladmin." + <universe domain>.
_DEFAULT_UNIVERSE_DOMAIN = "googleapis.com"
_DEFAULT_SCHEME = "https://"
_SQLADMIN_HOST_TEMPLATE = "sqladmin.{universe_domain}"
class Connector:
    """Configure and create secure connections to Cloud SQL."""

    def __init__(
        self,
        ip_type: str | IPTypes = IPTypes.PUBLIC,
        enable_iam_auth: bool = False,
        timeout: int = 30,
        credentials: Optional[Credentials] = None,
        loop: Optional[asyncio.AbstractEventLoop] = None,
        quota_project: Optional[str] = None,
        sqladmin_api_endpoint: Optional[str] = None,
        user_agent: Optional[str] = None,
        universe_domain: Optional[str] = None,
        refresh_strategy: str | RefreshStrategy = RefreshStrategy.BACKGROUND,
        resolver: type[DefaultResolver] | type[DnsResolver] = DefaultResolver,
        failover_period: int = 30,
    ) -> None:
        """Initializes a Connector instance.

        All arguments are validated before any event loop, thread or task is
        created, so a constructor that raises leaves nothing running.

        Args:
            ip_type (str | IPTypes): The default IP address type used to connect to
                Cloud SQL instances. Can be one of the following:
                IPTypes.PUBLIC ("PUBLIC"), IPTypes.PRIVATE ("PRIVATE"), or
                IPTypes.PSC ("PSC"). Default: IPTypes.PUBLIC
            enable_iam_auth (bool): Enables automatic IAM database authentication
                (Postgres and MySQL) as the default authentication method for all
                connections.
            timeout (int): The default `timeout` value (in seconds) passed to
                the database driver for each connection, unless overridden
                per call.
            credentials (google.auth.credentials.Credentials): A credentials object
                created from the google-auth Python library to be used.
                If not specified, Application Default Credentials (ADC) are used.
            loop (asyncio.AbstractEventLoop): Event loop to run asyncio tasks, if
                not specified, defaults to creating new event loop on background
                thread.
            quota_project (str): The Project ID for an existing Google Cloud
                project. The project specified is used for quota and billing
                purposes. If not specified, falls back to the
                GOOGLE_CLOUD_QUOTA_PROJECT environment variable.
            sqladmin_api_endpoint (str): Base URL to use when calling the Cloud SQL
                Admin API endpoint. Defaults to "https://sqladmin.googleapis.com",
                this argument should only be used in development.
            user_agent (str): Forwarded unchanged to the Cloud SQL Admin API
                client as its `user_agent` argument.
            universe_domain (str): The universe domain for Cloud SQL API calls.
                If not specified, falls back to the GOOGLE_CLOUD_UNIVERSE_DOMAIN
                environment variable and then to "googleapis.com".
            refresh_strategy (str | RefreshStrategy): The default refresh strategy
                used to refresh SSL/TLS cert and instance metadata. Can be one
                of the following: RefreshStrategy.LAZY ("LAZY") or
                RefreshStrategy.BACKGROUND ("BACKGROUND").
                Default: RefreshStrategy.BACKGROUND
            resolver (DefaultResolver | DnsResolver): The class name of the
                resolver to use for resolving the Cloud SQL instance connection
                name. To resolve a DNS record to an instance connection name, use
                DnsResolver.
                Default: DefaultResolver
            failover_period (int): The time interval in seconds between each
                attempt to check if a failover has occurred for a given instance.
                Must be used with `resolver=DnsResolver` to have any effect.
                Default: 30

        Raises:
            TypeError: If `credentials` is given but is not a
                google.auth.credentials.Credentials instance.
            ValueError: If the configured universe domain does not match the
                universe domain of the credentials.
        """
        # if refresh_strategy is str, convert to RefreshStrategy enum
        if isinstance(refresh_strategy, str):
            refresh_strategy = RefreshStrategy._from_str(refresh_strategy)
        self._refresh_strategy = refresh_strategy

        # if ip_type is str, convert to IPTypes enum
        if isinstance(ip_type, str):
            ip_type = IPTypes._from_str(ip_type)
        self._ip_type = ip_type

        # initialize credentials
        scopes = ["https://www.googleapis.com/auth/sqlservice.admin"]
        if credentials:
            # verify custom credentials are proper type
            # and at least base class of google.auth.credentials
            if not isinstance(credentials, Credentials):
                raise TypeError(
                    "credentials must be of type google.auth.credentials.Credentials,"
                    f" got {type(credentials)}"
                )
            self._credentials = with_scopes_if_required(credentials, scopes=scopes)
        # otherwise use application default credentials
        else:
            self._credentials, _ = google.auth.default(scopes=scopes)

        # set default params for connections
        self._timeout = timeout
        self._enable_iam_auth = enable_iam_auth
        self._user_agent = user_agent
        self._resolver = resolver()
        self._failover_period = failover_period

        # check for quota project arg and then env var
        self._quota_project: Optional[str] = quota_project or os.environ.get(
            "GOOGLE_CLOUD_QUOTA_PROJECT"
        )
        # check for universe domain arg and then env var
        self._universe_domain: Optional[str] = universe_domain or os.environ.get(
            "GOOGLE_CLOUD_UNIVERSE_DOMAIN"
        )

        # construct service endpoint for Cloud SQL Admin API calls
        if not sqladmin_api_endpoint:
            self._sqladmin_api_endpoint = (
                _DEFAULT_SCHEME
                + _SQLADMIN_HOST_TEMPLATE.format(universe_domain=self.universe_domain)
            )
        # otherwise if endpoint override is passed in use it
        else:
            self._sqladmin_api_endpoint = sqladmin_api_endpoint

        # validate that the universe domain of the credentials matches the
        # universe domain of the service endpoint
        if self._credentials.universe_domain != self.universe_domain:
            raise ValueError(
                f"The configured universe domain ({self.universe_domain}) does "
                "not match the universe domain found in the credentials "
                f"({self._credentials.universe_domain}). If you haven't "
                "configured the universe domain explicitly, `googleapis.com` "
                "is the default."
            )

        # initialize dict to store caches, key is a tuple consisting of instance
        # connection name string and enable_iam_auth boolean flag
        self._cache: dict[tuple[str, bool], MonitoredCache] = {}
        # the API client is created lazily inside connect_async
        self._client: Optional[CloudSQLClient] = None

        # Start background machinery last: everything above can raise, and an
        # exception after this point would leave a thread/task that no caller
        # holds a reference to and therefore could never close.
        self._loop: asyncio.AbstractEventLoop
        self._thread: Optional[Thread]
        self._keys: Optional[asyncio.Future]
        # if lazy refresh is specified we should lazy init keys
        lazy_keys = self._refresh_strategy == RefreshStrategy.LAZY
        # if event loop is given, use for background tasks
        if loop:
            self._loop = loop
            self._thread = None
            self._keys = None if lazy_keys else loop.create_task(generate_keys())
        # if no event loop is given, spin up new loop in background thread
        else:
            self._loop = asyncio.new_event_loop()
            self._thread = Thread(target=self._loop.run_forever, daemon=True)
            self._thread.start()
            if lazy_keys:
                self._keys = None
            else:
                self._keys = asyncio.wrap_future(
                    asyncio.run_coroutine_threadsafe(generate_keys(), self._loop),
                    loop=self._loop,
                )

    @property
    def universe_domain(self) -> str:
        """Return the configured universe domain, or "googleapis.com" if unset."""
        return self._universe_domain or _DEFAULT_UNIVERSE_DOMAIN

    def connect(
        self, instance_connection_string: str, driver: str, **kwargs: Any
    ) -> Any:
        """Connect to a Cloud SQL instance.

        Submits `connect_async` to the Connector's event loop and blocks the
        calling thread until it completes.

        Args:
            instance_connection_string (str): The instance connection name of the
                Cloud SQL instance to connect to. Takes the form of
                "project-id:region:instance-name"
                Example: "my-project:us-central1:my-instance"
            driver (str): A string representing the database driver to connect
                with. Supported drivers are pymysql, pg8000, and pytds.
            **kwargs: Any driver-specific arguments to pass to the underlying
                driver .connect call.

        Returns:
            A DB-API connection to the specified Cloud SQL instance.
        """
        # connect runs sync database connections on background thread.
        # Async database connections should call 'connect_async' directly to
        # avoid hanging indefinitely.
        connect_future = asyncio.run_coroutine_threadsafe(
            self.connect_async(instance_connection_string, driver, **kwargs),
            self._loop,
        )
        return connect_future.result()

    async def connect_async(
        self, instance_connection_string: str, driver: str, **kwargs: Any
    ) -> Any:
        """Connect asynchronously to a Cloud SQL instance.

        Resolves the instance, obtains (or creates and caches) its connection
        info, builds an SSL context from it and hands the result to the
        selected driver. Async version of Connector.connect.

        Args:
            instance_connection_string (str): The instance connection name of the
                Cloud SQL instance to connect to. Takes the form of
                "project-id:region:instance-name"
                Example: "my-project:us-central1:my-instance"
            driver (str): A string representing the database driver to connect
                with. Supported drivers are pymysql, asyncpg, pg8000, and pytds.
            **kwargs: Any driver-specific arguments to pass to the underlying
                driver .connect call. `enable_iam_auth` and `ip_type` are
                consumed here as per-connection overrides; `host`, `ssl` and
                `port` are discarded.

        Returns:
            A DB-API connection to the specified Cloud SQL instance.

        Raises:
            KeyError: Unsupported database driver. Must be one of pymysql,
                asyncpg, pg8000, and pytds.
            Exception: Any error raised while fetching connection info is
                re-raised after the cache entry is removed; any error raised
                while connecting is re-raised after a forced refresh.
        """
        connect_func = {
            "pymysql": pymysql.connect,
            "pg8000": pg8000.connect,
            "asyncpg": asyncpg.connect,
            "pytds": pytds.connect,
        }
        # only accept supported database drivers; checked before anything is
        # allocated so a bad driver name does not leave a client or a cache
        # entry behind
        try:
            connector: Callable = connect_func[driver]  # type: ignore
        except KeyError:
            raise KeyError(f"Driver '{driver}' is not supported.") from None

        if self._keys is None:
            self._keys = asyncio.create_task(generate_keys())
        if self._client is None:
            # lazy init client as it has to be initialized in async context
            self._client = CloudSQLClient(
                self._sqladmin_api_endpoint,
                self._quota_project,
                self._credentials,
                user_agent=self._user_agent,
                driver=driver,
            )
        enable_iam_auth = kwargs.pop("enable_iam_auth", self._enable_iam_auth)
        conn_name = await self._resolver.resolve(instance_connection_string)
        cache_key = (str(conn_name), enable_iam_auth)

        # Cache entry must exist and not be closed
        monitored_cache = self._cache.get(cache_key)
        if monitored_cache is None or monitored_cache.closed:
            if self._refresh_strategy == RefreshStrategy.LAZY:
                logger.debug(
                    f"['{conn_name}']: Refresh strategy is set to lazy refresh"
                )
                cache: Union[LazyRefreshCache, RefreshAheadCache] = LazyRefreshCache(
                    conn_name,
                    self._client,
                    self._keys,
                    enable_iam_auth,
                )
            else:
                logger.debug(
                    f"['{conn_name}']: Refresh strategy is set to backgound refresh"
                )
                cache = RefreshAheadCache(
                    conn_name,
                    self._client,
                    self._keys,
                    enable_iam_auth,
                )
            # wrap cache as a MonitoredCache
            monitored_cache = MonitoredCache(
                cache,
                self._failover_period,
                self._resolver,
            )
            logger.debug(f"['{conn_name}']: Connection info added to cache")
            self._cache[cache_key] = monitored_cache

        ip_type = kwargs.pop("ip_type", self._ip_type)
        # if ip_type is str, convert to IPTypes enum
        if isinstance(ip_type, str):
            ip_type = IPTypes._from_str(ip_type)
        kwargs["timeout"] = kwargs.get("timeout", self._timeout)

        # Host and ssl options come from the certificates and metadata, so we don't
        # want the user to specify them.
        kwargs.pop("host", None)
        kwargs.pop("ssl", None)
        kwargs.pop("port", None)

        # attempt to get connection info for Cloud SQL instance
        try:
            conn_info = await monitored_cache.connect_info()
            # validate driver matches intended database engine
            DriverMapping.validate_engine(driver, conn_info.database_version)
            ip_address = conn_info.get_preferred_ip(ip_type)
        except Exception:
            # with an error from Cloud SQL Admin API call or IP type, invalidate
            # the cache and re-raise the error
            await self._remove_cached(str(conn_name), enable_iam_auth)
            raise
        logger.debug(
            f"['{conn_info.conn_name}']: Connecting to {ip_address}:{SERVER_PROXY_PORT}"
        )
        # format `user` param for automatic IAM database authn
        if enable_iam_auth:
            formatted_user = format_database_user(
                conn_info.database_version, kwargs["user"]
            )
            if formatted_user != kwargs["user"]:
                logger.debug(
                    f"['{instance_connection_string}']: Truncated IAM database username from {kwargs['user']} to {formatted_user}"
                )
                kwargs["user"] = formatted_user
        try:
            # async drivers are unblocking and can be awaited directly
            if driver in ASYNC_DRIVERS:
                return await connector(
                    ip_address,
                    await conn_info.create_ssl_context(enable_iam_auth),
                    **kwargs,
                )
            # Create socket with SSLContext for sync drivers
            ctx = await conn_info.create_ssl_context(enable_iam_auth)
            sock = ctx.wrap_socket(
                socket.create_connection((ip_address, SERVER_PROXY_PORT)),
                server_hostname=ip_address,
            )
            # If this connection was opened using a domain name, then store it
            # for later in case we need to forcibly close it on failover.
            if conn_info.conn_name.domain_name:
                monitored_cache.sockets.append(sock)
            # Synchronous drivers are blocking and run using executor
            connect_partial = partial(
                connector,
                ip_address,
                sock,
                **kwargs,
            )
            return await self._loop.run_in_executor(None, connect_partial)
        except Exception:
            # with any exception, we attempt a force refresh, then throw the error
            await monitored_cache.force_refresh()
            raise

    async def _remove_cached(
        self, instance_connection_string: str, enable_iam_auth: bool
    ) -> None:
        """Remove one entry from the map of caches and close it.

        A missing entry is ignored: several failing connection attempts for
        the same instance may all try to remove it, and a KeyError raised here
        would mask the error that triggered the removal.

        Args:
            instance_connection_string (str): First element of the cache key.
            enable_iam_auth (bool): Second element of the cache key.
        """
        logger.debug(
            f"['{instance_connection_string}']: Removing connection info from cache"
        )
        # remove cache from stored caches and close it
        cache = self._cache.pop((instance_connection_string, enable_iam_auth), None)
        if cache is not None:
            await cache.close()

    def __enter__(self) -> Any:
        """Enter context manager by returning Connector object"""
        return self

    def __exit__(
        self,
        exc_type: Optional[type[BaseException]],
        exc_val: Optional[BaseException],
        exc_tb: Optional[TracebackType],
    ) -> None:
        """Exit context manager by closing Connector"""
        self.close()

    async def __aenter__(self) -> Any:
        """Enter async context manager by returning Connector object"""
        return self

    async def __aexit__(
        self,
        exc_type: Optional[type[BaseException]],
        exc_val: Optional[BaseException],
        exc_tb: Optional[TracebackType],
    ) -> None:
        """Exit async context manager by closing Connector"""
        await self.close_async()

    def close(self) -> None:
        """Close Connector by stopping tasks and releasing resources."""
        if self._loop.is_running():
            close_future = asyncio.run_coroutine_threadsafe(
                self.close_async(), loop=self._loop
            )
            # Will attempt to safely shut down tasks for 3s
            close_future.result(timeout=3)
        # if background thread exists for Connector, clean it up
        if self._thread:
            if self._loop.is_running():
                # stop event loop running in background thread
                self._loop.call_soon_threadsafe(self._loop.stop)
                # wait for thread to finish closing (i.e. loop to stop)
                self._thread.join()

    async def close_async(self) -> None:
        """Close every cached entry concurrently, then close the API client
        if one was created."""
        await asyncio.gather(*[cache.close() for cache in self._cache.values()])
        if self._client:
            await self._client.close()
async def create_async_connector(
    ip_type: str | IPTypes = IPTypes.PUBLIC,
    enable_iam_auth: bool = False,
    timeout: int = 30,
    credentials: Optional[Credentials] = None,
    loop: Optional[asyncio.AbstractEventLoop] = None,
    quota_project: Optional[str] = None,
    sqladmin_api_endpoint: Optional[str] = None,
    user_agent: Optional[str] = None,
    universe_domain: Optional[str] = None,
    refresh_strategy: str | RefreshStrategy = RefreshStrategy.BACKGROUND,
    resolver: type[DefaultResolver] | type[DnsResolver] = DefaultResolver,
    failover_period: int = 30,
) -> Connector:
    """Build a Connector bound to an event loop for asyncio usage.

    Every argument is forwarded unchanged to `Connector`, except that a
    missing `loop` is replaced by the event loop currently running this
    coroutine.

    Args:
        ip_type (str | IPTypes): The default IP address type used to connect to
            Cloud SQL instances. Can be one of the following:
            IPTypes.PUBLIC ("PUBLIC"), IPTypes.PRIVATE ("PRIVATE"), or
            IPTypes.PSC ("PSC"). Default: IPTypes.PUBLIC
        enable_iam_auth (bool): Enables automatic IAM database authentication
            (Postgres and MySQL) as the default authentication method for all
            connections.
        timeout (int): The default time limit in seconds for a connection.
        credentials (google.auth.credentials.Credentials): A credentials object
            created from the google-auth Python library to be used.
            If not specified, Application Default Credentials (ADC) are used.
        loop (asyncio.AbstractEventLoop): Event loop to run asyncio tasks. If
            not specified, the currently running event loop is used.
        quota_project (str): The Project ID for an existing Google Cloud
            project. The project specified is used for quota and billing
            purposes. If not specified, defaults to project sourced from
            environment.
        sqladmin_api_endpoint (str): Base URL to use when calling the Cloud SQL
            Admin API endpoint. Defaults to "https://sqladmin.googleapis.com",
            this argument should only be used in development.
        user_agent (str): Forwarded unchanged to `Connector`.
        universe_domain (str): The universe domain for Cloud SQL API calls.
            Default: "googleapis.com".
        refresh_strategy (str | RefreshStrategy): The default refresh strategy
            used to refresh SSL/TLS cert and instance metadata. Can be one
            of the following: RefreshStrategy.LAZY ("LAZY") or
            RefreshStrategy.BACKGROUND ("BACKGROUND").
            Default: RefreshStrategy.BACKGROUND
        resolver (DefaultResolver | DnsResolver): The class name of the
            resolver to use for resolving the Cloud SQL instance connection
            name. To resolve a DNS record to an instance connection name, use
            DnsResolver.
            Default: DefaultResolver
        failover_period (int): The time interval in seconds between each
            attempt to check if a failover has occurred for a given instance.
            Must be used with `resolver=DnsResolver` to have any effect.
            Default: 30

    Returns:
        A Connector instance configured with the chosen event loop.
    """
    # prefer the caller's loop; otherwise pick up the one running this coroutine
    event_loop = asyncio.get_running_loop() if loop is None else loop

    connector_options: dict[str, Any] = {
        "ip_type": ip_type,
        "enable_iam_auth": enable_iam_auth,
        "timeout": timeout,
        "credentials": credentials,
        "loop": event_loop,
        "quota_project": quota_project,
        "sqladmin_api_endpoint": sqladmin_api_endpoint,
        "user_agent": user_agent,
        "universe_domain": universe_domain,
        "refresh_strategy": refresh_strategy,
        "resolver": resolver,
        "failover_period": failover_period,
    }
    return Connector(**connector_options)