"""RDS Data API Connector."""

from __future__ import annotations

import datetime as dt
import logging
import time
import uuid
from decimal import Decimal
from typing import TYPE_CHECKING, Any, Callable, TypeVar, cast

import boto3
from typing_extensions import Literal

import awswrangler.pandas as pd
from awswrangler import _data_types, _databases, _utils, exceptions
from awswrangler._sql_utils import identifier
from awswrangler.data_api import _connector

if TYPE_CHECKING:
    from mypy_boto3_rds_data.client import BotocoreClientError  # type: ignore[attr-defined]
    from mypy_boto3_rds_data.type_defs import BatchExecuteStatementResponseTypeDef, ExecuteStatementResponseTypeDef


_logger = logging.getLogger(__name__)


_ExecuteStatementResponseType = TypeVar(
    "_ExecuteStatementResponseType", "ExecuteStatementResponseTypeDef", "BatchExecuteStatementResponseTypeDef"
)


class RdsDataApi(_connector.DataApiConnector):
    """Provides access to the RDS Data API.

    Parameters
    ----------
    resource_arn
        ARN for the RDS resource.
    database
        Target database name.
    secret_arn
        The ARN for the secret to be used for authentication.
    sleep
        Number of seconds to sleep between connection attempts to paused clusters - defaults to 0.5.
    backoff
        Factor by which to increase the sleep between connection attempts to paused clusters - defaults to 1.0.
    retries
        Maximum number of connection attempts to paused clusters - defaults to 30.
    boto3_session
        The default boto3 session will be used if **boto3_session** is ``None``.
    """

    def __init__(
        self,
        resource_arn: str,
        database: str,
        secret_arn: str = "",
        sleep: float = 0.5,
        backoff: float = 1.0,
        retries: int = 30,
        boto3_session: boto3.Session | None = None,
    ) -> None:
        super().__init__()

        self.client = _utils.client(service_name="rds-data", session=boto3_session)

        self.resource_arn = resource_arn
        self.database = database
        self.secret_arn = secret_arn
        # Retry/backoff settings used by _execute_with_retry when the target
        # rejects requests (e.g. a paused cluster that is still resuming).
        self.wait_config = _connector.WaitConfig(sleep, backoff, retries)
        # Statement responses keyed by the request id returned from
        # _execute_with_retry; entries are popped by _get_statement_result.
        self.results: dict[str, "ExecuteStatementResponseTypeDef" | "BatchExecuteStatementResponseTypeDef"] = {}

    def close(self) -> None:
        """Close underlying endpoint connections."""
        self.client.close()

    def begin_transaction(self, database: str | None = None, schema: str | None = None) -> str:
        """Start an SQL transaction and return its transaction id."""
        if database is None:
            database = self.database

        # Only forward the optional arguments that were actually provided.
        kwargs: dict[str, str] = {}
        if database:
            kwargs["database"] = database
        if schema:
            kwargs["schema"] = schema

        response = self.client.begin_transaction(
            resourceArn=self.resource_arn,
            secretArn=self.secret_arn,
            **kwargs,
        )
        return response["transactionId"]

    def commit_transaction(self, transaction_id: str) -> str:
        """Commit an SQL transaction and return the transaction status."""
        response = self.client.commit_transaction(
            resourceArn=self.resource_arn,
            secretArn=self.secret_arn,
            transactionId=transaction_id,
        )
        return response["transactionStatus"]

    def rollback_transaction(self, transaction_id: str) -> str:
        """Roll back an SQL transaction and return the transaction status."""
        response = self.client.rollback_transaction(
            resourceArn=self.resource_arn,
            secretArn=self.secret_arn,
            transactionId=transaction_id,
        )
        return response["transactionStatus"]

    def _execute_with_retry(
        self,
        sql: str,
        function: Callable[[str], _ExecuteStatementResponseType],
    ) -> str:
        """Invoke ``function(sql)``, retrying on BadRequestException.

        BadRequestException can mean the cluster is paused and still resuming,
        so each failure sleeps (with exponential backoff per ``wait_config``)
        before retrying, up to ``wait_config.retries`` attempts. On success the
        response is stored in ``self.results`` under a fresh request id, which
        is returned for later retrieval via ``_get_statement_result``.
        """
        sleep: float = self.wait_config.sleep
        total_tries: int = 0
        total_sleep: float = 0
        response: _ExecuteStatementResponseType | None = None
        last_exception: "BotocoreClientError" | None = None
        while total_tries < self.wait_config.retries:
            try:
                response = function(sql)
                _logger.debug(
                    "Response received after %s tries and sleeping for a total of %s seconds", total_tries, total_sleep
                )
                break
            except self.client.exceptions.BadRequestException as exception:
                last_exception = exception
                total_sleep += sleep
                _logger.debug("BadRequestException occurred: %s", exception)
                _logger.debug(
                    "Cluster may be paused - sleeping for %s seconds for a total of %s before retrying",
                    sleep,
                    total_sleep,
                )
                time.sleep(sleep)
                total_tries += 1
                sleep *= self.wait_config.backoff

        if response is None:
            # NOTE(review): _logger.exception is called outside an except block
            # here, so no traceback is attached to the log record. Also assumes
            # wait_config.retries >= 1 - with 0 retries the loop never runs and
            # ``raise last_exception`` would raise None (TypeError) - confirm.
            _logger.exception("Maximum BadRequestException retries reached for query %s", sql)
            raise last_exception  # type: ignore[misc]

        request_id: str = uuid.uuid4().hex
        self.results[request_id] = response
        return request_id

    def _execute_statement(
        self,
        sql: str,
        database: str | None = None,
        transaction_id: str | None = None,
        parameters: list[dict[str, Any]] | None = None,
    ) -> str:
        """Run a single statement via ExecuteStatement; return a request id.

        ``database`` falls back to the connection default. ``transaction_id``
        and ``parameters`` are forwarded only when provided.
        """
        if database is None:
            database = self.database

        additional_kwargs: dict[str, Any] = {}
        if transaction_id:
            additional_kwargs["transactionId"] = transaction_id
        if parameters:
            additional_kwargs["parameters"] = parameters

        def function(sql: str) -> "ExecuteStatementResponseTypeDef":
            # includeResultMetadata is required so _get_statement_result can
            # read column names/types from "columnMetadata".
            return self.client.execute_statement(
                resourceArn=self.resource_arn,
                database=database,
                sql=sql,
                secretArn=self.secret_arn,
                includeResultMetadata=True,
                **additional_kwargs,
            )

        return self._execute_with_retry(sql=sql, function=function)

    def _batch_execute_statement(
        self,
        sql: str | list[str],
        database: str | None = None,
        transaction_id: str | None = None,
        parameter_sets: list[list[dict[str, Any]]] | None = None,
    ) -> str:
        """Run one statement with many parameter sets via BatchExecuteStatement.

        Raises
        ------
        awswrangler.exceptions.InvalidArgumentType
            If ``sql`` is a list - only a single statement string is supported.
        """
        if isinstance(sql, list):
            raise exceptions.InvalidArgumentType("`sql` parameter cannot be list.")

        if database is None:
            database = self.database

        additional_kwargs: dict[str, Any] = {}
        if transaction_id:
            additional_kwargs["transactionId"] = transaction_id
        if parameter_sets:
            additional_kwargs["parameterSets"] = parameter_sets

        def function(sql: str) -> "BatchExecuteStatementResponseTypeDef":
            return self.client.batch_execute_statement(
                resourceArn=self.resource_arn,
                database=database,
                sql=sql,
                secretArn=self.secret_arn,
                **additional_kwargs,
            )

        return self._execute_with_retry(sql=sql, function=function)

    def _get_statement_result(self, request_id: str) -> pd.DataFrame:
        """Pop the stored response for ``request_id`` and build a DataFrame.

        Raises
        ------
        KeyError
            If ``request_id`` is unknown (result already consumed or never stored).
        """
        try:
            result = cast("ExecuteStatementResponseTypeDef", self.results.pop(request_id))
        except KeyError as exception:
            raise KeyError(f"Request {request_id} not found in results {self.results}") from exception

        # Statements that produce no result set (e.g. DDL/DML) carry no "records".
        if "records" not in result:
            return pd.DataFrame()

        rows: list[list[Any]] = []
        column_types = [col.get("typeName") for col in result["columnMetadata"]]

        for record in result["records"]:
            # Each record is a list of Data API Field dicts; convert each to a
            # plain Python value using the column's type name.
            row: list[Any] = [
                _connector.DataApiConnector._get_column_value(column, col_type)  # type: ignore[arg-type]
                for column, col_type in zip(record, column_types)
            ]
            rows.append(row)

        column_names: list[str] = [column["name"] for column in result["columnMetadata"]]
        dataframe = pd.DataFrame(rows, columns=column_names)
        return dataframe


def connect(
    resource_arn: str, database: str, secret_arn: str = "", boto3_session: boto3.Session | None = None, **kwargs: Any
) -> RdsDataApi:
    """Create a RDS Data API connection.

    Parameters
    ----------
    resource_arn
        ARN for the RDS resource.
    database
        Target database name.
    secret_arn
        The ARN for the secret to be used for authentication.
    boto3_session
        The default boto3 session will be used if **boto3_session** is ``None``.
    **kwargs
        Any additional kwargs are passed to the underlying RdsDataApi class.

    Returns
    -------
        A RdsDataApi connection instance that can be used with `wr.rds.data_api.read_sql_query`.
    """
    return RdsDataApi(
        resource_arn=resource_arn,
        database=database,
        secret_arn=secret_arn,
        boto3_session=boto3_session,
        **kwargs,
    )


def read_sql_query(
    sql: str, con: RdsDataApi, database: str | None = None, parameters: list[dict[str, Any]] | None = None
) -> pd.DataFrame:
    """Run an SQL query on an RdsDataApi connection and return the result as a DataFrame.

    Parameters
    ----------
    sql
        SQL query to run.
    con
        A RdsDataApi connection instance
    database
        Database to run query on - defaults to the database specified by `con`.
    parameters
        A list of named parameters e.g. [{"name": "col", "value": {"stringValue": "val1"}}].

    Returns
    -------
        A Pandas DataFrame containing the query results.

    Examples
    --------
    >>> import awswrangler as wr
    >>> df = wr.data_api.rds.read_sql_query(
    ...     sql="SELECT * FROM public.my_table",
    ...     con=con,
    ... )

    >>> import awswrangler as wr
    >>> df = wr.data_api.rds.read_sql_query(
    ...     sql="SELECT * FROM public.my_table WHERE col = :name",
    ...     con=con,
    ...     parameters=[
    ...        {"name": "col1", "value": {"stringValue": "val1"}}
    ...     ],
    ... )
    """
    # Delegate straight to the connection; it handles retries and result parsing.
    result = con.execute(sql, database=database, parameters=parameters)
    return result


def _drop_table(con: RdsDataApi, table: str, database: str, transaction_id: str, sql_mode: str) -> None:
    """Drop ``table`` (if it exists) inside the given transaction."""
    table_ident = identifier(table, sql_mode=sql_mode)
    sql = f"DROP TABLE IF EXISTS {table_ident}"
    _logger.debug("Drop table query:\n%s", sql)
    con.execute(sql, database=database, transaction_id=transaction_id)


def _does_table_exist(con: RdsDataApi, table: str, database: str, transaction_id: str) -> bool:
    res = con.execute(
        "SELECT * FROM INFORMATION_SCHEMA.TABLES WHERE TABLE_NAME = :table",
        parameters=[
            {
                "name": "table",
                "value": {"stringValue": table},
            },
        ],
    )
    return not res.empty


def _create_table(
    df: pd.DataFrame,
    con: RdsDataApi,
    table: str,
    database: str,
    transaction_id: str,
    mode: str,
    index: bool,
    dtype: dict[str, str] | None,
    varchar_lengths: dict[str, int] | None,
    sql_mode: str,
) -> None:
    """Create ``table`` from the DataFrame's schema unless it already exists.

    In ``overwrite`` mode any pre-existing table is dropped first; in any
    other mode an existing table is left untouched and this returns early.
    """
    if mode == "overwrite":
        _drop_table(con=con, table=table, database=database, transaction_id=transaction_id, sql_mode=sql_mode)
    elif _does_table_exist(con=con, table=table, database=database, transaction_id=transaction_id):
        return

    # Derive a database column type for every DataFrame column (and index, if requested).
    column_types: dict[str, str] = _data_types.database_types_from_pandas(
        df=df,
        index=index,
        dtype=dtype,
        varchar_lengths_default="TEXT",
        varchar_lengths=varchar_lengths,
        converter_func=_data_types.pyarrow2mysql,
    )
    column_defs = ",\n".join(
        f"{identifier(name, sql_mode=sql_mode)} {db_type}" for name, db_type in column_types.items()
    )
    sql = f"CREATE TABLE IF NOT EXISTS {identifier(table, sql_mode=sql_mode)} (\n{column_defs})"

    _logger.debug("Create table query:\n%s", sql)
    con.execute(sql, database=database, transaction_id=transaction_id)


def _create_value_dict(  # noqa: PLR0911
    value: Any,
) -> tuple[dict[str, Any], str | None]:
    if value is None or pd.isnull(value):
        return {"isNull": True}, None

    if isinstance(value, bool):
        return {"booleanValue": value}, None

    if isinstance(value, int):
        return {"longValue": value}, None

    if isinstance(value, float):
        return {"doubleValue": value}, None

    if isinstance(value, str):
        return {"stringValue": value}, None

    if isinstance(value, bytes):
        return {"blobValue": value}, None

    if isinstance(value, dt.datetime):
        return {"stringValue": str(value)}, "TIMESTAMP"

    if isinstance(value, dt.date):
        return {"stringValue": str(value)}, "DATE"

    if isinstance(value, dt.time):
        return {"stringValue": str(value)}, "TIME"

    if isinstance(value, Decimal):
        return {"stringValue": str(value)}, "DECIMAL"

    if isinstance(value, uuid.UUID):
        return {"stringValue": str(value)}, "UUID"

    raise exceptions.InvalidArgumentType(f"Value {value} not supported.")


def _generate_parameters(columns: list[str], values: list[Any]) -> list[dict[str, Any]]:
    """Build the Data API named-parameter list for a single row of values."""
    parameters: list[dict[str, Any]] = []

    for name, raw_value in zip(columns, values):
        value_dict, type_hint = _create_value_dict(raw_value)

        entry: dict[str, Any] = {
            "name": name,
            "value": value_dict,
        }
        # typeHint is only required for values sent as strings (dates, decimals, ...).
        if type_hint:
            entry["typeHint"] = type_hint

        parameters.append(entry)

    return parameters


def _generate_parameter_sets(df: pd.DataFrame) -> list[list[dict[str, Any]]]:
    """Convert every DataFrame row into a Data API parameter set.

    NOTE(review): ``df.values`` upcasts mixed-dtype columns to a common dtype
    before ``tolist()`` (e.g. ints alongside floats become floats) - confirm
    this is acceptable for the target column types.
    """
    columns = df.columns.tolist()
    return [_generate_parameters(columns, row) for row in df.values.tolist()]


def to_sql(
    df: pd.DataFrame,
    con: RdsDataApi,
    table: str,
    database: str,
    mode: Literal["append", "overwrite"] = "append",
    index: bool = False,
    dtype: dict[str, str] | None = None,
    varchar_lengths: dict[str, int] | None = None,
    use_column_names: bool = False,
    chunksize: int = 200,
    sql_mode: str = "mysql",
) -> None:
    """
    Insert data using an SQL query on a Data API connection.

    The whole insert runs inside a single transaction: the table is created
    (or replaced) as needed, rows are inserted in chunks, and the transaction
    is committed on success or rolled back on any failure.

    Parameters
    ----------
    df
        `Pandas DataFrame <https://pandas.pydata.org/pandas-docs/stable/reference/api/pandas.DataFrame.html>`_
    con
        A RdsDataApi connection instance
    table
        Table name
    database
        Database to run query on - defaults to the database specified by `con`.
    mode
        `append` (inserts new records into table), `overwrite` (drops table and recreates)
    index
        True to store the DataFrame index as a column in the table,
        otherwise False to ignore it.
    dtype
        Dictionary of columns names and MySQL types to be casted.
        Useful when you have columns with undetermined or mixed data types.
        (e.g. ```{'col name': 'TEXT', 'col2 name': 'FLOAT'}```)
    varchar_lengths
        Dict of VARCHAR length by columns. (e.g. ```{"col1": 10, "col5": 200}```).
    use_column_names
        If set to True, will use the column names of the DataFrame for generating the INSERT SQL Query.
        E.g. If the DataFrame has two columns `col1` and `col3` and `use_column_names` is True, data will only be
        inserted into the database columns `col1` and `col3`.
    chunksize
        Number of rows which are inserted with each SQL query. Defaults to inserting 200 rows per query.
    sql_mode
        "mysql" for default MySQL identifiers (backticks) or "ansi" for ANSI-compatible identifiers (double quotes).

    Raises
    ------
    awswrangler.exceptions.EmptyDataFrame
        If ``df`` has no rows.
    awswrangler.exceptions.InvalidArgumentValue
        If ``mode`` is not one of "append"/"overwrite" (raised by validate_mode).
    """
    # Idiomatic truthiness check (was ``df.empty is True``).
    if df.empty:
        raise exceptions.EmptyDataFrame("DataFrame cannot be empty.")

    _databases.validate_mode(mode=mode, allowed_modes=["append", "overwrite"])

    transaction_id: str | None = None
    try:
        transaction_id = con.begin_transaction(database=database)

        # Create (or, in overwrite mode, recreate) the destination table.
        _create_table(
            df=df,
            con=con,
            table=table,
            database=database,
            transaction_id=transaction_id,
            mode=mode,
            index=index,
            dtype=dtype,
            varchar_lengths=varchar_lengths,
            sql_mode=sql_mode,
        )

        if index:
            # Promote the index level(s) to regular columns so they are inserted too.
            df = df.reset_index(level=df.index.names)

        if use_column_names:
            insertion_columns = "(" + ", ".join([f"{identifier(col, sql_mode=sql_mode)}" for col in df.columns]) + ")"
        else:
            insertion_columns = ""

        # Named placeholders (:col) matched by the parameter sets generated below.
        placeholders = ", ".join([f":{col}" for col in df.columns])

        sql = f"INSERT INTO {identifier(table, sql_mode=sql_mode)} {insertion_columns} VALUES ({placeholders})"
        parameter_sets = _generate_parameter_sets(df)

        # BatchExecuteStatement has request-size limits, so insert in chunks.
        for parameter_sets_chunk in _utils.chunkify(parameter_sets, max_length=chunksize):
            con.batch_execute(
                sql=sql,
                database=database,
                transaction_id=transaction_id,
                parameter_sets=parameter_sets_chunk,
            )

    except Exception as ex:
        # Undo any partial work before propagating the failure.
        if transaction_id:
            con.rollback_transaction(transaction_id=transaction_id)
        _logger.error(ex)
        raise

    else:
        con.commit_transaction(transaction_id=transaction_id)
