google/cloud/sql/connector/refresh_utils.py (61 lines of code) (raw):

""" Copyright 2021 Google LLC Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at https://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ from __future__ import annotations import asyncio import copy import datetime import logging import random from typing import Any, Callable import aiohttp from google.auth.credentials import Credentials from google.auth.credentials import Scoped import google.auth.transport.requests logger = logging.getLogger(name=__name__) # _refresh_buffer is the amount of time before a refresh's result expires # that a new refresh operation begins. _refresh_buffer: int = 4 * 60 # 4 minutes def _seconds_until_refresh( expiration: datetime.datetime, ) -> int: """ Calculates the duration to wait before starting the next refresh. Usually the duration will be half of the time until certificate expiration. Args: expiration (datetime.datetime): The expiration time of the certificate. Returns: int: Time in seconds to wait before performing next refresh. """ duration = int( (expiration - datetime.datetime.now(datetime.timezone.utc)).total_seconds() ) # if certificate duration is less than 1 hour if duration < 3600: # something is wrong with certificate, refresh now if duration < _refresh_buffer: return 0 # otherwise wait until 4 minutes before expiration for next refresh return duration - _refresh_buffer return duration // 2 async def _is_valid(task: asyncio.Task) -> bool: try: metadata = await task # only valid if now is before the cert expires if datetime.datetime.now(datetime.timezone.utc) < metadata.expiration: return True except Exception: # supress any errors from task logger.debug("Current instance metadata is invalid.") return False def _downscope_credentials( credentials: Credentials, scopes: list[str] = ["https://www.googleapis.com/auth/sqlservice.login"], ) -> Credentials: """Generate a down-scoped credential. Args: credentials (google.auth.credentials.Credentials): Credentials object used to generate down-scoped credentials. scopes (list[str]): List of Google scopes to include in down-scoped credentials object. Returns: google.auth.credentials.Credentials: Down-scoped credentials object. """ # credentials sourced from a service account or metadata are children of # Scoped class and are capable of being re-scoped if isinstance(credentials, Scoped): scoped_creds = credentials.with_scopes(scopes=scopes) # authenticated user credentials can not be re-scoped else: # create shallow copy to not overwrite scopes on default credentials scoped_creds = copy.copy(credentials) # overwrite '_scopes' to down-scope user credentials # Cloud SDK reference: https://github.com/google-cloud-sdk-unofficial/google-cloud-sdk/blob/93920ccb6d2cce0fe6d1ce841e9e33410551d66b/lib/googlecloudsdk/command_lib/sql/generate_login_token_util.py#L116 scoped_creds._scopes = scopes # down-scoped credentials require refresh, are invalid after being re-scoped request = google.auth.transport.requests.Request() scoped_creds.refresh(request) return scoped_creds def _exponential_backoff(attempt: int) -> float: """Calculates a duration to backoff in milliseconds based on the attempt i. The formula is: base * multi^(attempt + 1 + random) With base = 200ms and multi = 1.618, and random = [0.0, 1.0), the backoff values would fall between the following low and high ends: Attempt Low (ms) High (ms) 0 324 524 1 524 847 2 847 1371 3 1371 2218 4 2218 3588 The theoretical worst case scenario would have a client wait 8.5s in total for an API request to complete (with the first four attempts failing, and the fifth succeeding). """ base = 200 multi = 1.618 exp = attempt + 1 + random.random() return base * pow(multi, exp) async def retry_50x( request_coro: Callable, *args: Any, **kwargs: Any ) -> aiohttp.ClientResponse: """Retry any 50x HTTP response up to X number of times.""" max_retries = 5 for i in range(max_retries): resp = await request_coro(*args, **kwargs) # backoff for any 50X errors if resp.status >= 500 and i < max_retries: # calculate backoff time backoff = _exponential_backoff(i) await asyncio.sleep(backoff / 1000) else: break return resp