google/cloud/alloydb/connector/instance.py (140 lines of code) (raw):

# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

import asyncio
from datetime import datetime
from datetime import timedelta
from datetime import timezone
import logging
import re
from typing import TYPE_CHECKING

from google.cloud.alloydb.connector.connection_info import ConnectionInfo
from google.cloud.alloydb.connector.exceptions import RefreshError
from google.cloud.alloydb.connector.rate_limiter import AsyncRateLimiter
from google.cloud.alloydb.connector.refresh_utils import _is_valid
from google.cloud.alloydb.connector.refresh_utils import _seconds_until_refresh

if TYPE_CHECKING:
    from cryptography.hazmat.primitives.asymmetric import rsa

    from google.cloud.alloydb.connector.client import AlloyDBClient

logger = logging.getLogger(name=__name__)

# Groups: 1 = project (optionally domain-scoped, "<DOMAIN>:<PROJECT>"),
# 2 = the optional ":<PROJECT>" suffix nested inside group 1,
# 3 = region, 4 = cluster, 5 = instance.
INSTANCE_URI_REGEX = re.compile(
    "projects/([^:]+(:[^:]+)?)/locations/([^:]+)/clusters/([^:]+)/instances/([^:]+)"
)


def _parse_instance_uri(instance_uri: str) -> tuple[str, str, str, str]:
    """Split an AlloyDB instance URI into its components.

    Args:
        instance_uri (str): URI of the form
            "projects/<PROJECT>/locations/<REGION>/clusters/<CLUSTER>/instances/<INSTANCE>",
            where <PROJECT> may also be domain-scoped ("<DOMAIN>:<PROJECT>").

    Returns:
        tuple[str, str, str, str]: (project, region, cluster, instance).

    Raises:
        ValueError: If `instance_uri` does not fully match the expected format.
    """
    match = INSTANCE_URI_REGEX.fullmatch(instance_uri)
    if match is None:
        raise ValueError(
            "Arg `instance_uri` must have "
            "format: projects/<PROJECT>/locations/<REGION>/clusters/<CLUSTER>/instances/<INSTANCE>, "
            "projects/<DOMAIN>:<PROJECT>/locations/<REGION>/clusters/<CLUSTER>/instances/<INSTANCE>, "
            f"got {instance_uri}."
        )
    # Group 2 only captures the ":<PROJECT>" part of a domain-scoped project,
    # which is already contained in group 1, so it is skipped.
    return match.group(1), match.group(3), match.group(4), match.group(5)


class RefreshAheadCache:
    """
    Manages the information used to connect to the AlloyDB instance.

    Periodically calls the AlloyDB API, automatically refreshing the
    required information approximately 4 minutes before the previous
    certificate expires (every ~56 minutes).

    Args:
        instance_uri (str): The instance URI of the AlloyDB instance.
            ex. projects/<PROJECT>/locations/<REGION>/clusters/<CLUSTER>/instances/<INSTANCE>
        client (AlloyDBClient): Client used to make requests to AlloyDB APIs.
        keys (asyncio.Future[tuple[rsa.RSAPrivateKey, str]]): Future
            resolving to the private and public key pair; it is passed
            through unchanged to `client.get_connection_info`.

    Raises:
        ValueError: If `instance_uri` is not a valid instance URI.
    """

    def __init__(
        self,
        instance_uri: str,
        client: AlloyDBClient,
        keys: asyncio.Future[tuple[rsa.RSAPrivateKey, str]],
    ) -> None:
        # validate and parse instance_uri
        self._project, self._region, self._cluster, self._name = _parse_instance_uri(
            instance_uri
        )

        self._instance_uri = instance_uri
        self._client = client
        self._keys = keys
        # Bound how often the AlloyDB API is hit: at most 2 back-to-back
        # refreshes, then one every 30 seconds.
        self._refresh_rate_limiter = AsyncRateLimiter(
            max_capacity=2,
            rate=1 / 30,
        )
        # Set while _perform_refresh is running; force_refresh reads it to
        # avoid cancelling a refresh that is already underway.
        self._refresh_in_progress = asyncio.Event()
        # For the initial refresh operation, set current = next so that
        # connection requests block until the first refresh is complete.
        self._current: asyncio.Task = self._schedule_refresh(0)
        self._next: asyncio.Task = self._current

    async def _perform_refresh(self) -> ConnectionInfo:
        """
        Perform a refresh operation on an AlloyDB instance.

        Retrieves metadata and generates new client certificate
        required to connect securely to the AlloyDB instance.

        Returns:
            ConnectionInfo: Result of the refresh operation.

        Raises:
            Exception: Any error raised by the rate limiter or the client is
                logged and re-raised unchanged.
        """
        self._refresh_in_progress.set()
        logger.debug(
            f"['{self._instance_uri}']: Connection info refresh operation started"
        )

        try:
            await self._refresh_rate_limiter.acquire()
            connection_info = await self._client.get_connection_info(
                self._project,
                self._region,
                self._cluster,
                self._name,
                self._keys,
            )
            logger.debug(
                f"['{self._instance_uri}']: Connection info refresh operation"
                " complete"
            )
            logger.debug(
                f"['{self._instance_uri}']: Current certificate expiration = "
                f"{connection_info.expiration.isoformat()}"
            )
        except Exception as e:
            logger.debug(
                f"['{self._instance_uri}']: Connection info refresh operation"
                f" failed: {str(e)}"
            )
            raise
        finally:
            # Always clear the flag, whether the refresh succeeded or failed.
            self._refresh_in_progress.clear()
        return connection_info

    def _schedule_refresh(self, delay: int) -> asyncio.Task:
        """
        Schedule a refresh operation.

        Args:
            delay (int): Time in seconds to sleep before performing refresh.

        Returns:
            asyncio.Task[ConnectionInfo]: A task representing the scheduled
                refresh operation.
        """
        return asyncio.create_task(self._refresh_operation(delay))

    async def _refresh_operation(self, delay: int) -> ConnectionInfo:
        """
        A coroutine that sleeps for the specified amount of time before
        running _perform_refresh.

        On success, the result replaces `self._current` and the next refresh
        is scheduled based on the certificate expiration. On failure, another
        refresh is scheduled immediately, and `self._current` is only replaced
        if it no longer holds a valid result.

        Args:
            delay (int): Time in seconds to sleep before performing refresh.

        Returns:
            ConnectionInfo: Refresh result for an AlloyDB instance.

        Raises:
            asyncio.CancelledError: If the scheduled refresh is cancelled.
            RefreshError: If the refreshed certificate is already invalid.
            Exception: Any error raised by _perform_refresh.
        """
        refresh_task: asyncio.Task
        try:
            if delay > 0:
                await asyncio.sleep(delay)
            refresh_task = asyncio.create_task(self._perform_refresh())
            refresh_result = await refresh_task
            # check that refresh is valid
            if not await _is_valid(refresh_task):
                raise RefreshError(
                    f"['{self._instance_uri}']: Invalid refresh operation. Certificate appears to be expired."
                )
        except asyncio.CancelledError:
            logger.debug(
                f"['{self._instance_uri}']: Scheduled refresh operation cancelled"
            )
            raise
        # bad refresh attempt
        except Exception:
            logger.info(
                f"['{self._instance_uri}']: "
                "An error occurred while performing refresh. "
                "Scheduling another refresh attempt immediately"
            )
            # check if current refresh result is invalid (expired),
            # don't want to replace valid result with invalid refresh.
            # NOTE(review): during the initial refresh self._current is this
            # very task, so the await below awaits itself; this appears to
            # rely on _is_valid treating the resulting error as "invalid" -
            # confirm against refresh_utils.
            if not await _is_valid(self._current):
                self._current = refresh_task
            # schedule new refresh attempt immediately
            self._next = self._schedule_refresh(0)
            raise
        # if valid refresh, replace current with valid refresh result and schedule next refresh
        self._current = refresh_task
        # calculate refresh delay based on certificate expiration
        delay = _seconds_until_refresh(refresh_result.expiration)
        logger.debug(
            f"['{self._instance_uri}']: Connection info refresh operation"
            " scheduled for "
            f"{(datetime.now(timezone.utc) + timedelta(seconds=delay)).isoformat(timespec='seconds')} "
            f"(now + {timedelta(seconds=delay)})"
        )
        self._next = self._schedule_refresh(delay)

        return refresh_result

    async def force_refresh(self) -> None:
        """
        Schedules a new refresh operation immediately to be used
        for future connection attempts.
        """
        # if next refresh is not already in progress, cancel it and schedule new one immediately
        if not self._refresh_in_progress.is_set():
            self._next.cancel()
            self._next = self._schedule_refresh(0)
        # block all sequential connection attempts on the next refresh result if current is invalid
        if not await _is_valid(self._current):
            self._current = self._next

    async def connect_info(self) -> ConnectionInfo:
        """Retrieves ConnectionInfo instance for establishing a secure
        connection to the AlloyDB instance.

        Awaits the current refresh task, so the first call waits for the
        initial refresh to finish.
        """
        return await self._current

    async def close(self) -> None:
        """
        Cancel refresh tasks.

        Raises:
            asyncio.TimeoutError: If the cancelled tasks have not finished
                within 2 seconds (propagated from asyncio.wait_for).
        """
        logger.debug(
            f"['{self._instance_uri}']: Canceling connection info refresh"
            " operation tasks"
        )
        self._current.cancel()
        self._next.cancel()
        # gracefully wait for tasks to cancel; return_exceptions=True keeps
        # the resulting CancelledErrors from propagating out of gather.
        tasks = asyncio.gather(self._current, self._next, return_exceptions=True)
        await asyncio.wait_for(tasks, timeout=2.0)