# google/cloud/alloydb/connector/connector.py
# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import asyncio
from functools import partial
import io
import logging
import socket
import struct
from threading import Thread
from types import TracebackType
from typing import Any, Optional, TYPE_CHECKING
from google.auth import default
from google.auth.credentials import TokenState
from google.auth.credentials import with_scopes_if_required
from google.auth.transport import requests
from google.cloud.alloydb.connector.client import AlloyDBClient
from google.cloud.alloydb.connector.enums import IPTypes
from google.cloud.alloydb.connector.enums import RefreshStrategy
from google.cloud.alloydb.connector.exceptions import ClosedConnectorError
from google.cloud.alloydb.connector.instance import RefreshAheadCache
from google.cloud.alloydb.connector.lazy import LazyRefreshCache
import google.cloud.alloydb.connector.pg8000 as pg8000
from google.cloud.alloydb.connector.types import CacheTypes
from google.cloud.alloydb.connector.utils import generate_keys
from google.cloud.alloydb.connector.utils import strip_http_prefix
import google.cloud.alloydb_connectors_v1.proto.resources_pb2 as connectorspb
if TYPE_CHECKING:
import ssl
from google.auth.credentials import Credentials
# module-level logger, named after this module per stdlib convention
logger = logging.getLogger(name=__name__)
# the port the AlloyDB server-side proxy receives connections on
SERVER_PROXY_PORT = 5433
# the maximum amount of time (in seconds) to wait before aborting a metadata
# exchange; applied via socket.settimeout during the exchange
IO_TIMEOUT = 30
class Connector:
    """A class to configure and create connections to AlloyDB instances.

    Args:
        credentials (google.auth.credentials.Credentials):
            A credentials object created from the google-auth Python library.
            If not specified, Application Default Credentials are used.
        quota_project (str): The Project ID for an existing Google Cloud
            project. The project specified is used for quota and
            billing purposes.
            Defaults to None, picking up project from environment.
        alloydb_api_endpoint (str): Base URL to use when calling
            the AlloyDB API endpoint. Defaults to "alloydb.googleapis.com".
        enable_iam_auth (bool): Enables automatic IAM database authentication.
        ip_type (str | IPTypes): Default IP type for all AlloyDB connections.
            Defaults to IPTypes.PRIVATE ("PRIVATE") for private IP connections.
        user_agent (str): A custom user agent string to include with
            API requests. Defaults to None.
        refresh_strategy (str | RefreshStrategy): The default refresh strategy
            used to refresh SSL/TLS cert and instance metadata. Can be one
            of the following: RefreshStrategy.LAZY ("LAZY") or
            RefreshStrategy.BACKGROUND ("BACKGROUND").
            Default: RefreshStrategy.BACKGROUND
        static_conn_info (io.TextIOBase): A file-like JSON object that contains
            static connection info for the StaticConnectionInfoCache.
            Defaults to None, which will not use the StaticConnectionInfoCache.
            This is a *dev-only* option and should not be used in production as
            it will result in failed connections after the client certificate
            expires.
    """

    def __init__(
        self,
        credentials: Optional[Credentials] = None,
        quota_project: Optional[str] = None,
        alloydb_api_endpoint: str = "alloydb.googleapis.com",
        enable_iam_auth: bool = False,
        ip_type: str | IPTypes = IPTypes.PRIVATE,
        user_agent: Optional[str] = None,
        refresh_strategy: str | RefreshStrategy = RefreshStrategy.BACKGROUND,
        static_conn_info: Optional[io.TextIOBase] = None,
    ) -> None:
        # create event loop and start it in a background thread so that
        # synchronous callers (connect/close) can drive the async machinery
        self._loop: asyncio.AbstractEventLoop = asyncio.new_event_loop()
        self._thread = Thread(target=self._loop.run_forever, daemon=True)
        self._thread.start()
        # map of instance URI -> connection info cache
        self._cache: dict[str, CacheTypes] = {}
        # initialize default params
        self._quota_project = quota_project
        self._alloydb_api_endpoint = strip_http_prefix(alloydb_api_endpoint)
        self._enable_iam_auth = enable_iam_auth
        # if ip_type is str, convert to IPTypes enum
        if isinstance(ip_type, str):
            ip_type = IPTypes(ip_type.upper())
        self._ip_type = ip_type
        # if refresh_strategy is str, convert to RefreshStrategy enum
        if isinstance(refresh_strategy, str):
            refresh_strategy = RefreshStrategy(refresh_strategy.upper())
        self._refresh_strategy = refresh_strategy
        self._user_agent = user_agent
        # initialize credentials
        scopes = ["https://www.googleapis.com/auth/cloud-platform"]
        if credentials:
            self._credentials = with_scopes_if_required(credentials, scopes=scopes)
        # otherwise use application default credentials
        else:
            self._credentials, _ = default(scopes=scopes)
        # kick off key-pair generation on the background loop; wrap the
        # concurrent future so async code on that loop can await it
        self._keys = asyncio.wrap_future(
            asyncio.run_coroutine_threadsafe(generate_keys(), self._loop),
            loop=self._loop,
        )
        # client is lazily initialized inside connect_async, as it must be
        # created in an async context
        self._client: Optional[AlloyDBClient] = None
        self._static_conn_info = static_conn_info
        self._closed = False

    def connect(self, instance_uri: str, driver: str, **kwargs: Any) -> Any:
        """
        Prepares and returns a database DBAPI connection object.

        Starts background tasks to refresh the certificates and get
        AlloyDB instance IP address. Creates a secure TLS connection
        to establish connection to AlloyDB instance.

        Args:
            instance_uri (str): The instance URI of the AlloyDB instance.
                ex. projects/<PROJECT>/locations/<REGION>/clusters/<CLUSTER>/instances/<INSTANCE>
            driver (str): A string representing the database driver to connect with.
                Supported drivers are pg8000.
            **kwargs: Pass in any database driver-specific arguments needed
                to fine tune connection.

        Returns:
            connection: A DBAPI connection to the specified AlloyDB instance.

        Raises:
            ClosedConnectorError: If the connector has already been closed.
        """
        if self._closed:
            raise ClosedConnectorError(
                "Connection attempt failed because the connector has already been closed."
            )
        # call async connect on the background event loop and wait on result
        connect_task = asyncio.run_coroutine_threadsafe(
            self.connect_async(instance_uri, driver, **kwargs), self._loop
        )
        return connect_task.result()

    async def connect_async(self, instance_uri: str, driver: str, **kwargs: Any) -> Any:
        """
        Asynchronously prepares and returns a database connection object.

        Starts tasks to refresh the certificates and get
        AlloyDB instance IP address. Creates a secure TLS connection
        to establish connection to AlloyDB instance.

        Args:
            instance_uri (str): The instance URI of the AlloyDB instance.
                ex. projects/<PROJECT>/locations/<REGION>/clusters/<CLUSTER>/instances/<INSTANCE>
            driver (str): A string representing the database driver to connect with.
                Supported drivers are pg8000.
            **kwargs: Pass in any database driver-specific arguments needed
                to fine tune connection.

        Returns:
            connection: A DBAPI connection to the specified AlloyDB instance.

        Raises:
            ValueError: If the requested driver is not supported.
        """
        if self._client is None:
            # lazy init client as it has to be initialized in async context
            self._client = AlloyDBClient(
                self._alloydb_api_endpoint,
                self._quota_project,
                self._credentials,
                user_agent=self._user_agent,
                driver=driver,
            )
        enable_iam_auth = kwargs.pop("enable_iam_auth", self._enable_iam_auth)
        # use existing connection info if possible
        if instance_uri in self._cache:
            cache = self._cache[instance_uri]
        else:
            if self._refresh_strategy == RefreshStrategy.LAZY:
                logger.debug(
                    f"['{instance_uri}']: Refresh strategy is set to lazy refresh"
                )
                cache = LazyRefreshCache(instance_uri, self._client, self._keys)
            else:
                logger.debug(
                    f"['{instance_uri}']: Refresh strategy is set to background"
                    " refresh"
                )
                cache = RefreshAheadCache(instance_uri, self._client, self._keys)
            self._cache[instance_uri] = cache
            logger.debug(f"['{instance_uri}']: Connection info added to cache")
        connect_func = {
            "pg8000": pg8000.connect,
        }
        # only accept supported database drivers
        try:
            connector = connect_func[driver]
        except KeyError:
            # suppress the KeyError context; it is an internal lookup detail
            raise ValueError(
                f"Driver '{driver}' is not a supported database driver."
            ) from None
        # Host and ssl options come from the certificates and instance IP address
        # so we don't want the user to specify them.
        kwargs.pop("host", None)
        kwargs.pop("ssl", None)
        kwargs.pop("port", None)
        # get connection info for AlloyDB instance
        ip_type: IPTypes | str = kwargs.pop("ip_type", self._ip_type)
        # if ip_type is str, convert to IPTypes enum
        if isinstance(ip_type, str):
            ip_type = IPTypes(ip_type.upper())
        try:
            conn_info = await cache.connect_info()
            ip_address = conn_info.get_preferred_ip(ip_type)
        except Exception:
            # with an error from AlloyDB API call or IP type, invalidate the
            # cache and re-raise the error
            await self._remove_cached(instance_uri)
            raise
        logger.debug(f"['{instance_uri}']: Connecting to {ip_address}:5433")
        # synchronous drivers are blocking and run using executor
        try:
            metadata_partial = partial(
                self.metadata_exchange,
                ip_address,
                await conn_info.create_ssl_context(),
                enable_iam_auth,
            )
            sock = await self._loop.run_in_executor(None, metadata_partial)
            connect_partial = partial(connector, sock, **kwargs)
            return await self._loop.run_in_executor(None, connect_partial)
        except Exception:
            # we attempt a force refresh, then throw the error
            await cache.force_refresh()
            raise

    def metadata_exchange(
        self, ip_address: str, ctx: ssl.SSLContext, enable_iam_auth: bool
    ) -> ssl.SSLSocket:
        """
        Sends metadata about the connection prior to the database
        protocol taking over.

        The exchange consists of four steps:

        1. Prepare a MetadataExchangeRequest including the IAM Principal's
        OAuth2 token, the user agent, and the requested authentication type.

        2. Write the size of the message as a big endian uint32 (4 bytes) to
        the server followed by the serialized message. The length does not
        include the initial four bytes.

        3. Read a big endian uint32 (4 bytes) from the server. This is the
        MetadataExchangeResponse message length and does not include the
        initial four bytes.

        4. Parse the response using the message length in step 3. If the
        response is not OK, return the response's error. If there is no error,
        the metadata exchange has succeeded and the connection is complete.

        Args:
            ip_address (str): IP address of AlloyDB instance to connect to.
            ctx (ssl.SSLContext): Context used to create a TLS connection
                with AlloyDB instance ssl certificates.
            enable_iam_auth (bool): Flag to enable IAM database authentication.

        Returns:
            sock (ssl.SSLSocket): mTLS/SSL socket connected to AlloyDB Proxy server.

        Raises:
            RuntimeError: If the server closes the connection mid-exchange.
            ValueError: If the server responds with a non-OK response code.
        """
        # Create socket and wrap with SSL/TLS context
        sock = ctx.wrap_socket(
            socket.create_connection((ip_address, SERVER_PROXY_PORT)),
            server_hostname=ip_address,
        )
        # set auth type for metadata exchange
        auth_type = connectorspb.MetadataExchangeRequest.DB_NATIVE
        if enable_iam_auth:
            auth_type = connectorspb.MetadataExchangeRequest.AUTO_IAM
        # Ensure the credentials are in fact valid before proceeding.
        if self._credentials.token_state != TokenState.FRESH:
            self._credentials.refresh(requests.Request())
        # form metadata exchange request
        req = connectorspb.MetadataExchangeRequest(
            user_agent=f"{self._client._user_agent}",  # type: ignore
            auth_type=auth_type,
            oauth2_token=self._credentials.token,
        )
        # set I/O timeout
        sock.settimeout(IO_TIMEOUT)
        # pack big-endian unsigned integer (4 bytes)
        packed_len = struct.pack(">I", req.ByteSize())
        # send metadata message length and request message
        sock.sendall(packed_len + req.SerializeToString())
        # form metadata exchange response
        resp = connectorspb.MetadataExchangeResponse()
        # read metadata message length (4 bytes)
        message_len_buffer_size = struct.Struct(">I").size
        message_len_buffer = b""
        while message_len_buffer_size > 0:
            chunk = sock.recv(message_len_buffer_size)
            if not chunk:
                raise RuntimeError(
                    "Connection closed while getting metadata exchange length!"
                )
            message_len_buffer += chunk
            message_len_buffer_size -= len(chunk)
        (message_len,) = struct.unpack(">I", message_len_buffer)
        # read metadata exchange message
        buffer = b""
        while message_len > 0:
            chunk = sock.recv(message_len)
            if not chunk:
                raise RuntimeError(
                    "Connection closed while performing metadata exchange!"
                )
            buffer += chunk
            message_len -= len(chunk)
        # parse metadata exchange response from buffer
        resp.ParseFromString(buffer)
        # reset socket back to blocking mode (clears the I/O timeout)
        sock.setblocking(True)
        # validate metadata exchange response
        if resp.response_code != connectorspb.MetadataExchangeResponse.OK:
            raise ValueError(
                f"Metadata Exchange request has failed with error: {resp.error}"
            )
        return sock

    async def _remove_cached(self, instance_uri: str) -> None:
        """Stops all background refreshes and deletes the connection
        info cache from the map of caches.
        """
        logger.debug(f"['{instance_uri}']: Removing connection info from cache")
        # remove cache from stored caches and close it
        cache = self._cache.pop(instance_uri)
        await cache.close()

    def __enter__(self) -> "Connector":
        """Enter context manager by returning Connector object"""
        return self

    def __exit__(
        self,
        exc_type: Optional[type[BaseException]],
        exc_val: Optional[BaseException],
        exc_tb: Optional[TracebackType],
    ) -> None:
        """Exit context manager by closing Connector"""
        self.close()

    def close(self) -> None:
        """Close Connector by stopping tasks and releasing resources."""
        if self._loop.is_running():
            close_future = asyncio.run_coroutine_threadsafe(
                self.close_async(), loop=self._loop
            )
            # Will attempt to gracefully shut down tasks for 3s
            close_future.result(timeout=3)
        # if background thread exists for Connector, clean it up
        if self._thread.is_alive():
            if self._loop.is_running():
                # stop event loop running in background thread
                self._loop.call_soon_threadsafe(self._loop.stop)
            # wait for thread to finish closing (i.e. loop to stop)
            self._thread.join()
        self._closed = True

    async def close_async(self) -> None:
        """Helper function to cancel RefreshAheadCaches' tasks
        and close client."""
        await asyncio.gather(*[cache.close() for cache in self._cache.values()])