# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

import asyncio
from datetime import datetime
from datetime import timedelta
from datetime import timezone
import logging
import re
from typing import TYPE_CHECKING

from google.cloud.alloydb.connector.connection_info import ConnectionInfo
from google.cloud.alloydb.connector.exceptions import RefreshError
from google.cloud.alloydb.connector.rate_limiter import AsyncRateLimiter
from google.cloud.alloydb.connector.refresh_utils import _is_valid
from google.cloud.alloydb.connector.refresh_utils import _seconds_until_refresh

if TYPE_CHECKING:
    from cryptography.hazmat.primitives.asymmetric import rsa

    from google.cloud.alloydb.connector.client import AlloyDBClient

# Module-level logger, named after this module (`__name__`).
logger = logging.getLogger(name=__name__)

INSTANCE_URI_REGEX = re.compile(
    "projects/([^:]+(:[^:]+)?)/locations/([^:]+)/clusters/([^:]+)/instances/([^:]+)"
)


def _parse_instance_uri(instance_uri: str) -> tuple[str, str, str, str]:
    # should take form "projects/<PROJECT>/locations/<REGION>/clusters/<CLUSTER>/instances/<INSTANCE>"
    if INSTANCE_URI_REGEX.fullmatch(instance_uri) is None:
        raise ValueError(
            "Arg `instance_uri` must have "
            "format: projects/<PROJECT>/locations/<REGION>/clusters/<CLUSTER>/instances/<INSTANCE>, projects/<DOMAIN>:<PROJECT>/locations/<REGION>/clusters/<CLUSTER>/instances/<INSTANCE>"
            f"got {instance_uri}."
        )
    instance_uri_split = INSTANCE_URI_REGEX.split(instance_uri)
    return (
        instance_uri_split[1],
        instance_uri_split[3],
        instance_uri_split[4],
        instance_uri_split[5],
    )


class RefreshAheadCache:
    """
    Manages the information used to connect to the AlloyDB instance.

    Periodically calls the AlloyDB API, automatically refreshing the
    required information approximately 4 minutes before the previous
    certificate expires (every ~56 minutes).

    The constructor schedules the first refresh with `asyncio.create_task`,
    so instances must be created while an event loop is running.

    Args:
        instance_uri (str): The instance URI of the AlloyDB instance.
            ex. projects/<PROJECT>/locations/<REGION>/clusters/<CLUSTER>/instances/<INSTANCE>
        client (AlloyDBClient): Client used to make requests to AlloyDB APIs.
        keys (asyncio.Future[tuple[rsa.RSAPrivateKey, str]]): Future that
            resolves to the private and public key pair.

    Raises:
        ValueError: If `instance_uri` does not have the expected format.
    """

    def __init__(
        self,
        instance_uri: str,
        client: AlloyDBClient,
        keys: asyncio.Future[tuple[rsa.RSAPrivateKey, str]],
    ) -> None:
        # validate and parse instance_uri
        self._project, self._region, self._cluster, self._name = _parse_instance_uri(
            instance_uri
        )

        self._instance_uri = instance_uri
        self._client = client
        self._keys = keys
        # every refresh attempt acquires this limiter before calling the API
        self._refresh_rate_limiter = AsyncRateLimiter(
            max_capacity=2,
            rate=1 / 30,
        )
        # set while _perform_refresh is running, cleared when it finishes
        self._refresh_in_progress = asyncio.Event()
        # For the initial refresh operation, set current = next so that
        # connection requests block until the first refresh is complete.
        self._current: asyncio.Task = self._schedule_refresh(0)
        self._next: asyncio.Task = self._current

    async def _perform_refresh(self) -> ConnectionInfo:
        """
        Perform a refresh operation on an AlloyDB instance.

        Retrieves metadata and generates new client certificate
        required to connect securely to the AlloyDB instance.

        Marks the refresh as in progress for the duration of the call and
        waits on the rate limiter before contacting the API. Any exception
        is logged at debug level and re-raised unchanged.

        Returns:
            ConnectionInfo: Result of the refresh operation.
        """
        self._refresh_in_progress.set()
        logger.debug(
            f"['{self._instance_uri}']: Connection info refresh operation started"
        )

        try:
            await self._refresh_rate_limiter.acquire()
            connection_info = await self._client.get_connection_info(
                self._project,
                self._region,
                self._cluster,
                self._name,
                self._keys,
            )
            logger.debug(
                f"['{self._instance_uri}']: Connection info refresh operation"
                " complete"
            )
            logger.debug(
                f"['{self._instance_uri}']: Current certificate expiration = "
                f"{connection_info.expiration.isoformat()}"
            )

        except Exception as e:
            logger.debug(
                f"['{self._instance_uri}']: Connection info refresh operation"
                f" failed: {str(e)}"
            )
            raise

        finally:
            # always clear the flag so force_refresh can reschedule afterwards
            self._refresh_in_progress.clear()
        return connection_info

    def _schedule_refresh(self, delay: int) -> asyncio.Task:
        """
        Schedule a refresh operation.

        Args:
            delay (int): Time in seconds to sleep before performing refresh.

        Returns:
            asyncio.Task[ConnectionInfo]: A task representing the scheduled
                refresh operation.
        """
        return asyncio.create_task(self._refresh_operation(delay))

    async def _refresh_operation(self, delay: int) -> ConnectionInfo:
        """
        A coroutine that sleeps for the specified amount of time before
        running _perform_refresh.

        On success the result becomes the current connection info and the
        next refresh is scheduled based on the certificate expiration. On
        failure a new refresh attempt is scheduled immediately and the
        original exception is re-raised.

        Args:
            delay (int): Time in seconds to sleep before performing refresh.

        Returns:
            ConnectionInfo: Refresh result for an AlloyDB instance.

        Raises:
            RefreshError: If the refreshed connection info is not valid.
            asyncio.CancelledError: If the scheduled operation is cancelled.
        """
        # Stays None until the refresh task has actually been created, so the
        # error handler below never touches an unbound local.
        refresh_task: asyncio.Task | None = None
        try:
            if delay > 0:
                await asyncio.sleep(delay)
            refresh_task = asyncio.create_task(self._perform_refresh())
            refresh_result = await refresh_task
            # check that refresh is valid
            if not await _is_valid(refresh_task):
                raise RefreshError(
                    f"['{self._instance_uri}']: Invalid refresh operation. Certificate appears to be expired."
                )
        except asyncio.CancelledError:
            logger.debug(
                f"['{self._instance_uri}']: Scheduled refresh operation cancelled"
            )
            raise
        # bad refresh attempt
        except Exception:
            logger.info(
                f"['{self._instance_uri}']: "
                "An error occurred while performing refresh. "
                "Scheduling another refresh attempt immediately"
            )
            # check if current refresh result is invalid (expired),
            # don't want to replace valid result with invalid refresh;
            # only possible when a refresh task was actually started
            if refresh_task is not None and not await _is_valid(self._current):
                self._current = refresh_task
            # schedule new refresh attempt immediately
            self._next = self._schedule_refresh(0)
            raise
        # if valid refresh, replace current with valid refresh result and schedule next refresh
        self._current = refresh_task
        # calculate refresh delay based on certificate expiration
        delay = _seconds_until_refresh(refresh_result.expiration)
        logger.debug(
            f"['{self._instance_uri}']: Connection info refresh operation"
            " scheduled for "
            f"{(datetime.now(timezone.utc) + timedelta(seconds=delay)).isoformat(timespec='seconds')} "
            f"(now + {timedelta(seconds=delay)})"
        )
        self._next = self._schedule_refresh(delay)

        return refresh_result

    async def force_refresh(self) -> None:
        """
        Schedules a new refresh operation immediately to be used
        for future connection attempts.

        If a refresh is already in progress, it is left running instead of
        being replaced.
        """
        # if next refresh is not already in progress, cancel it and schedule new one immediately
        if not self._refresh_in_progress.is_set():
            self._next.cancel()
            self._next = self._schedule_refresh(0)
        # block all sequential connection attempts on the next refresh result if current is invalid
        if not await _is_valid(self._current):
            self._current = self._next

    async def connect_info(self) -> ConnectionInfo:
        """Retrieves ConnectionInfo instance for establishing a secure
        connection to the AlloyDB instance.

        Awaits the current refresh task, so the call blocks until that
        task has completed and re-raises any exception it failed with.
        """
        return await self._current

    async def close(self) -> None:
        """
        Cancel refresh tasks and wait up to 2 seconds for them to finish.

        Raises:
            asyncio.TimeoutError: If the tasks have not finished within
                the timeout.
        """
        logger.debug(
            f"['{self._instance_uri}']: Canceling connection info refresh"
            " operation tasks"
        )
        self._current.cancel()
        self._next.cancel()
        # gracefully wait for tasks to cancel; return_exceptions=True keeps
        # the tasks' own errors/cancellations from propagating out of gather
        tasks = asyncio.gather(self._current, self._next, return_exceptions=True)
        await asyncio.wait_for(tasks, timeout=2.0)
