# mypy: disable-error-code=name-defined
"""Amazon Redshift Utils Module (PRIVATE)."""

from __future__ import annotations

import json
import logging
import uuid
from typing import TYPE_CHECKING, Literal

import boto3
import botocore
import pandas as pd

from awswrangler import _data_types, _sql_utils, _utils, exceptions, s3

if TYPE_CHECKING:
    try:
        import redshift_connector
    except ImportError:
        pass
else:
    redshift_connector = _utils.import_optional_dependency("redshift_connector")

_logger: logging.Logger = logging.getLogger(__name__)


# Distribution styles accepted for the `diststyle` argument (checked in `_validate_parameters`).
_RS_DISTSTYLES: list[str] = ["AUTO", "EVEN", "ALL", "KEY"]
# Sort styles accepted for the `sortstyle` argument (checked in `_validate_parameters`).
_RS_SORTSTYLES: list[str] = ["COMPOUND", "INTERLEAVED"]


def _identifier(sql: str) -> str:
    """Return ``sql`` processed by ``_sql_utils.identifier`` using the ``"ansi"`` SQL mode."""
    ansi_identifier: str = _sql_utils.identifier(sql, sql_mode="ansi")
    return ansi_identifier


def _make_s3_auth_string(
    aws_access_key_id: str | None = None,
    aws_secret_access_key: str | None = None,
    aws_session_token: str | None = None,
    iam_role: str | None = None,
    boto3_session: boto3.Session | None = None,
) -> str:
    """Build the authorization clause appended to Redshift statements that access S3.

    Resolution order:

    1. ``aws_access_key_id`` together with ``aws_secret_access_key``
       (plus ``aws_session_token`` when given).
    2. ``iam_role``.
    3. Credentials obtained from ``boto3_session`` through
       ``_utils.get_credentials_from_session``.

    Returns
    -------
    str
        Either an ``IAM_ROLE '...'`` line, or ``ACCESS_KEY_ID``/``SECRET_ACCESS_KEY``
        lines optionally followed by a ``SESSION_TOKEN`` line. Every line ends with
        a newline.

    Raises
    ------
    exceptions.InvalidArgument
        If the session fallback is used and it yields no access key or secret key.
    """
    has_explicit_keys = not (aws_access_key_id is None or aws_secret_access_key is None)

    if has_explicit_keys:
        key_id, secret, token = aws_access_key_id, aws_secret_access_key, aws_session_token
    elif iam_role is not None:
        # A role needs no key material, so the clause is complete already.
        return f"IAM_ROLE '{iam_role}'\n"
    else:
        _logger.debug("Attempting to get S3 authorization credentials from boto3 session.")
        session_credentials: botocore.credentials.ReadOnlyCredentials
        session_credentials = _utils.get_credentials_from_session(boto3_session=boto3_session)
        if session_credentials.access_key is None or session_credentials.secret_key is None:
            raise exceptions.InvalidArgument(
                "One of IAM Role or AWS ACCESS_KEY_ID and SECRET_ACCESS_KEY must be "
                "given. Unable to find ACCESS_KEY_ID and SECRET_ACCESS_KEY in boto3 "
                "session."
            )
        key_id = session_credentials.access_key
        secret = session_credentials.secret_key
        token = session_credentials.token

    clause = f"ACCESS_KEY_ID '{key_id}'\nSECRET_ACCESS_KEY '{secret}'\n"
    if token is not None:
        clause += f"SESSION_TOKEN '{token}'\n"
    return clause


def _begin_transaction(cursor: "redshift_connector.Cursor") -> None:
    """Execute ``BEGIN TRANSACTION`` on ``cursor``, logging the statement first."""
    statement = "BEGIN TRANSACTION"
    _logger.debug("Executing begin transaction query:\n%s", statement)
    cursor.execute(statement)


def _drop_table(cursor: "redshift_connector.Cursor", schema: str | None, table: str, cascade: bool = False) -> None:
    """Execute ``DROP TABLE IF EXISTS`` for ``table``.

    Parameters
    ----------
    cursor
        Cursor used to execute the statement.
    schema
        Schema of the table. When falsy, the table name is left unqualified.
    table
        Name of the table to drop.
    cascade
        When True, `` CASCADE`` is appended to the statement.
    """
    # Build the name through `_identifier`, exactly as the TRUNCATE / DELETE / LOCK
    # helpers in this module do, rather than wrapping the raw strings in double
    # quotes by hand.
    # NOTE(review): presumably `_sql_utils.identifier` also validates the name, in
    # which case a malformed schema/table now raises instead of being interpolated
    # into the SQL text - confirm against `_sql_utils`.
    qualified_name = f"{_identifier(schema)}.{_identifier(table)}" if schema else _identifier(table)
    cascade_str = " CASCADE" if cascade else ""
    sql = f"DROP TABLE IF EXISTS {qualified_name}{cascade_str}"
    _logger.debug("Executing drop table query:\n%s", sql)
    cursor.execute(sql)


def _truncate_table(cursor: "redshift_connector.Cursor", schema: str | None, table: str) -> None:
    """Execute ``TRUNCATE TABLE`` for ``table``, schema-qualified when ``schema`` is truthy."""
    target = f"{_identifier(schema)}.{_identifier(table)}" if schema else _identifier(table)
    statement = f"TRUNCATE TABLE {target}"
    _logger.debug("Executing truncate table query:\n%s", statement)
    cursor.execute(statement)


def _delete_all(cursor: "redshift_connector.Cursor", schema: str | None, table: str) -> None:
    """Execute an unfiltered ``DELETE FROM`` for ``table``, schema-qualified when ``schema`` is truthy."""
    target = f"{_identifier(schema)}.{_identifier(table)}" if schema else _identifier(table)
    statement = f"DELETE FROM {target}"
    _logger.debug("Executing delete query:\n%s", statement)
    cursor.execute(statement)


def _get_primary_keys(cursor: "redshift_connector.Cursor", schema: str, table: str) -> list[str]:
    """Read the key column names of ``schema.table`` from ``pg_indexes``.

    The first ``indexdef`` row returned is parsed: the text after the first ``(``
    is split on commas and each piece is stripped of whitespace and double quotes.

    Returns
    -------
    list[str]
        Column names found in the index definition, or an empty list when the
        query returns no rows.
    """
    sql = f"SELECT indexdef FROM pg_indexes WHERE schemaname = '{schema}' AND tablename = '{table}'"
    _logger.debug("Executing select query:\n%s", sql)
    cursor.execute(sql)
    rows = cursor.fetchall()
    if not rows:
        # No index row for this table. Returning an empty list (instead of failing
        # with IndexError on rows[0]) lets `_upsert` reach its own
        # "no primary keys" check and raise InvalidRedshiftPrimaryKeys.
        return []
    indexdef: str = rows[0][0]
    raw_fields: list[str] = indexdef.split("(")[1].strip(")").split(",")
    return [field.strip().strip('"') for field in raw_fields]


def _get_table_columns(cursor: "redshift_connector.Cursor", schema: str, table: str) -> list[str]:
    """Return the ``column_name`` values that ``svv_columns`` lists for ``schema.table``."""
    query = f"SELECT column_name FROM svv_columns\n WHERE table_schema = '{schema}' AND table_name = '{table}'"
    _logger.debug("Executing select query:\n%s", query)
    cursor.execute(query)
    rows: tuple[list[str]] = cursor.fetchall()
    # Each fetched row is collapsed into one string by joining its items.
    return ["".join(row) for row in rows]


def _add_table_columns(
    cursor: "redshift_connector.Cursor", schema: str, table: str, new_columns: dict[str, str]
) -> None:
    """Execute one ``ALTER TABLE ... ADD COLUMN`` per entry of ``new_columns``.

    ``new_columns`` maps column name to the column type text placed after the name.
    """
    for name, rs_type in new_columns.items():
        alter_stmt = (
            f"ALTER TABLE {_identifier(schema)}.{_identifier(table)}"
            f"\nADD COLUMN {_identifier(name)} {rs_type};"
        )
        _logger.debug("Executing alter query:\n%s", alter_stmt)
        cursor.execute(alter_stmt)


def _does_table_exist(cursor: "redshift_connector.Cursor", schema: str | None, table: str) -> bool:
    """Return True when ``INFORMATION_SCHEMA.TABLES`` has a row for ``table``.

    The lookup is restricted to ``schema`` only when ``schema`` is truthy.
    """
    schema_filter = "" if not schema else f"TABLE_SCHEMA = '{schema}' AND"
    query = (
        "SELECT true WHERE EXISTS (SELECT * FROM INFORMATION_SCHEMA.TABLES "
        f"WHERE {schema_filter} TABLE_NAME = '{table}');"
    )
    _logger.debug("Executing select query:\n%s", query)
    cursor.execute(query)
    rows = cursor.fetchall()
    return len(rows) != 0


def _get_paths_from_manifest(path: str, boto3_session: boto3.Session | None = None) -> list[str]:
    """Download the JSON manifest stored at ``path`` and return the ``url`` of every entry."""
    client_s3 = _utils.client(service_name="s3", session=boto3_session)
    bucket, key = _utils.parse_path(path)
    response = client_s3.get_object(Bucket=bucket, Key=key)
    raw_manifest = response["Body"].read().decode("utf-8")
    manifest = json.loads(raw_manifest)
    urls = [entry["url"] for entry in manifest["entries"]]
    _logger.debug("Read %d paths from manifest file in: %s", len(urls), path)
    return urls


def _get_parameter_setting(cursor: "redshift_connector.Cursor", parameter_name: str) -> str:
    """Run ``SHOW <parameter_name>`` and return the first column of the first row as a string."""
    sql = f"SHOW {parameter_name}"
    _logger.debug("Executing select query:\n%s", sql)
    cursor.execute(sql)
    result = cursor.fetchall()
    status = str(result[0][0])
    # Lazy %-style arguments, like every other log call in this module, so the
    # message is only formatted when DEBUG logging is enabled.
    _logger.debug("%s='%s'", parameter_name, status)
    return status


def _lock(
    cursor: "redshift_connector.Cursor",
    table_names: list[str],
    schema: str | None = None,
) -> None:
    """Execute a single ``LOCK`` statement covering every table in ``table_names``.

    Names are schema-qualified when ``schema`` is truthy.
    """
    qualified_names: list[str] = []
    for name in table_names:
        if schema:
            qualified_names.append(f"{_identifier(schema)}.{_identifier(name)}")
        else:
            qualified_names.append(_identifier(name))
    tables = ", ".join(qualified_names)
    statement: str = f"LOCK {tables};\n"
    _logger.debug("Executing lock query:\n%s", statement)
    cursor.execute(statement)


def _upsert(
    cursor: "redshift_connector.Cursor",
    table: str,
    temp_table: str,
    schema: str,
    primary_keys: list[str] | None = None,
    precombine_key: str | None = None,
    column_names: list[str] | None = None,
) -> None:
    """Merge the rows staged in ``temp_table`` into ``schema.table``.

    Steps, in order:

    1. Resolve the primary keys (argument, else ``_get_primary_keys``).
    2. Delete target rows that match a staged row on every primary key. With
       ``precombine_key`` the target row is deleted only when its value is ``<=``
       the staged one, and staged rows whose target row has a greater value are
       deleted from ``temp_table`` instead.
    3. Insert the remaining staged rows (restricted to ``column_names`` if given).
    4. Call ``_drop_table`` for ``temp_table``.

    Raises
    ------
    exceptions.InvalidRedshiftPrimaryKeys
        If no primary keys were given and none could be looked up.
    """
    keys = primary_keys or _get_primary_keys(cursor=cursor, schema=schema, table=table)
    _logger.debug("primary_keys: %s", keys)
    if not keys:
        raise exceptions.InvalidRedshiftPrimaryKeys()

    table_id = _identifier(table)
    temp_id = _identifier(temp_table)
    target = f"{_identifier(schema)}.{table_id}"

    key_match_template: str = f"{table_id}.%s = {temp_id}.%s"
    join_clause: str = " AND ".join(key_match_template % (pk, pk) for pk in keys)

    if precombine_key:
        stale_in_target: str = f"AND {table_id}.{precombine_key} <= {temp_id}.{precombine_key}"
        stale_in_temp: str = f"AND {table_id}.{precombine_key} > {temp_id}.{precombine_key}"
        delete_statements = [
            f"DELETE FROM {target} USING {temp_id} WHERE {join_clause} {stale_in_target}",
            f"DELETE FROM {temp_id} USING {target} WHERE {join_clause} {stale_in_temp}",
        ]
    else:
        delete_statements = [f"DELETE FROM {target} USING {temp_id} WHERE {join_clause}"]

    for delete_sql in delete_statements:
        _logger.debug("Executing delete query:\n%s", delete_sql)
        cursor.execute(delete_sql)

    if not column_names:
        insert_sql = f"INSERT INTO {target} SELECT * FROM {temp_id}"
    else:
        column_list = ",".join(column_names)
        insert_sql = f"INSERT INTO {target}({column_list}) SELECT {column_list} FROM {temp_id}"
    _logger.debug("Executing insert query:\n%s", insert_sql)
    cursor.execute(insert_sql)

    # NOTE(review): the staging table is dropped qualified with the *target* schema;
    # confirm this is intended for tables created via CREATE TEMPORARY TABLE.
    _drop_table(cursor=cursor, schema=schema, table=temp_table)


def _validate_parameters(
    redshift_types: dict[str, str],
    diststyle: str,
    distkey: str | None,
    sortstyle: str,
    sortkey: list[str] | None,
    primary_keys: list[str] | None,
) -> None:
    """Validate the table-layout arguments against the known styles and the column list.

    The column list is the key set of ``redshift_types``.

    Raises
    ------
    exceptions.InvalidRedshiftDiststyle
        ``diststyle`` is not one of ``_RS_DISTSTYLES``.
    exceptions.InvalidRedshiftDistkey
        ``diststyle`` is ``"KEY"`` without a ``distkey``, or ``distkey`` is not a column.
    exceptions.InvalidRedshiftSortstyle
        ``sortstyle`` is set and not one of ``_RS_SORTSTYLES``.
    exceptions.InvalidRedshiftSortkey
        ``sortkey`` is not a list, or contains a name that is not a column.
    exceptions.InvalidArgumentType
        ``primary_keys`` is set and is not a list.
    """
    if diststyle not in _RS_DISTSTYLES:
        raise exceptions.InvalidRedshiftDiststyle(f"diststyle must be in {_RS_DISTSTYLES}")

    cols = list(redshift_types)
    _logger.debug("Redshift columns: %s", cols)

    if diststyle == "KEY" and not distkey:
        raise exceptions.InvalidRedshiftDistkey("You must pass a distkey if you intend to use KEY diststyle")
    if distkey and distkey not in cols:
        raise exceptions.InvalidRedshiftDistkey(f"distkey ({distkey}) must be in the columns list: {cols})")

    if sortstyle and sortstyle not in _RS_SORTSTYLES:
        raise exceptions.InvalidRedshiftSortstyle(f"sortstyle must be in {_RS_SORTSTYLES}")

    if sortkey:
        if not isinstance(sortkey, list):
            raise exceptions.InvalidRedshiftSortkey(
                f"sortkey must be a List of items in the columns list: {cols}. Currently value: {sortkey}"
            )
        # Only the first offending key is reported.
        unknown_keys = [key for key in sortkey if key not in cols]
        if unknown_keys:
            raise exceptions.InvalidRedshiftSortkey(
                f"sortkey must be a List of items in the columns list: {cols}. Currently value: {unknown_keys[0]}"
            )

    if primary_keys:
        if not isinstance(primary_keys, list):
            raise exceptions.InvalidArgumentType(
                f"""
                    primary keys should be of type list[str].
                    Current value: {primary_keys} is of type {type(primary_keys)}
                """
            )


def _redshift_types_from_path(
    path: str | list[str],
    data_format: Literal["parquet", "orc"],
    varchar_lengths_default: int,
    varchar_lengths: dict[str, int] | None,
    parquet_infer_sampling: float,
    path_suffix: str | None,
    path_ignore_suffix: str | list[str] | None,
    use_threads: bool | int,
    boto3_session: boto3.Session | None,
    s3_additional_kwargs: dict[str, str] | None,
) -> dict[str, str]:
    """Derive Redshift column types from the metadata of files under an S3 path.

    Column types are read with ``s3.read_orc_metadata`` when ``data_format`` is
    ``"orc"`` and with ``s3.read_parquet_metadata`` otherwise (only the latter
    receives ``parquet_infer_sampling``). Each type is then converted with
    ``_data_types.athena2redshift``, using the per-column length from
    ``varchar_lengths`` when present and ``varchar_lengths_default`` otherwise.

    Returns
    -------
    dict[str, str]
        Mapping of column name to converted Redshift type.
    """
    length_overrides: dict[str, int] = varchar_lengths if varchar_lengths is not None else {}
    _logger.debug("Scanning parquet schemas in S3 path: %s", path)
    if data_format != "orc":
        athena_types, _ = s3.read_parquet_metadata(
            path=path,
            sampling=parquet_infer_sampling,
            path_suffix=path_suffix,
            path_ignore_suffix=path_ignore_suffix,
            dataset=False,
            use_threads=use_threads,
            boto3_session=boto3_session,
            s3_additional_kwargs=s3_additional_kwargs,
        )
    else:
        athena_types, _ = s3.read_orc_metadata(
            path=path,
            path_suffix=path_suffix,
            path_ignore_suffix=path_ignore_suffix,
            dataset=False,
            use_threads=use_threads,
            boto3_session=boto3_session,
            s3_additional_kwargs=s3_additional_kwargs,
        )
    _logger.debug("Parquet metadata types: %s", athena_types)
    redshift_types: dict[str, str] = {
        col_name: _data_types.athena2redshift(
            dtype=col_type,
            varchar_length=length_overrides.get(col_name, varchar_lengths_default),
        )
        for col_name, col_type in athena_types.items()
    }
    _logger.debug("Converted redshift types: %s", redshift_types)
    return redshift_types


def _get_rsh_columns_types(
    df: pd.DataFrame | None,
    path: str | list[str] | None,
    index: bool,
    dtype: dict[str, str] | None,
    varchar_lengths_default: int,
    varchar_lengths: dict[str, int] | None,
    data_format: Literal["parquet", "orc", "csv"] = "parquet",
    redshift_column_types: dict[str, str] | None = None,
    parquet_infer_sampling: float = 1.0,
    path_suffix: str | None = None,
    path_ignore_suffix: str | list[str] | None = None,
    manifest: bool | None = False,
    use_threads: bool | int = True,
    boto3_session: boto3.Session | None = None,
    s3_additional_kwargs: dict[str, str] | None = None,
) -> dict[str, str]:
    """Resolve the column-name to Redshift-type mapping for a write.

    Sources, in priority order:

    * ``df`` - types come from ``_data_types.database_types_from_pandas``.
    * ``path`` with ``data_format`` ``"parquet"``/``"orc"`` - types come from
      ``_redshift_types_from_path``. With ``manifest`` set, ``path`` must be a
      string and is first expanded through ``_get_paths_from_manifest``.
    * ``path`` with any other ``data_format`` - ``redshift_column_types`` is
      returned as given.

    Raises
    ------
    TypeError
        ``manifest`` is truthy and ``path`` is not a string.
    ValueError
        Both ``df`` and ``path`` are None, or ``redshift_column_types`` is None
        for a format other than Parquet/ORC.
    """
    if df is not None:
        types_from_df: dict[str, str] = _data_types.database_types_from_pandas(
            df=df,
            index=index,
            dtype=dtype,
            varchar_lengths_default=varchar_lengths_default,
            varchar_lengths=varchar_lengths,
            converter_func=_data_types.pyarrow2redshift,
        )
        _logger.debug("Converted redshift types from pandas: %s", types_from_df)
        return types_from_df

    if path is None:
        raise ValueError("df and path are None. You MUST pass at least one.")

    if manifest:
        if not isinstance(path, str):
            raise TypeError(
                f"""type: {type(path)} is not a valid type for 'path' when 'manifest' is set to True;
                    must be a string"""
            )
        path = _get_paths_from_manifest(
            path=path,
            boto3_session=boto3_session,
        )

    if data_format in ("parquet", "orc"):
        return _redshift_types_from_path(
            path=path,
            data_format=data_format,  # type: ignore[arg-type]
            varchar_lengths_default=varchar_lengths_default,
            varchar_lengths=varchar_lengths,
            parquet_infer_sampling=parquet_infer_sampling,
            path_suffix=path_suffix,
            path_ignore_suffix=path_ignore_suffix,
            use_threads=use_threads,
            boto3_session=boto3_session,
            s3_additional_kwargs=s3_additional_kwargs,
        )

    # Other formats carry no readable schema here, so the caller must supply the types.
    if redshift_column_types is None:
        raise ValueError(
            "redshift_column_types is None. It must be specified for files formats other than Parquet or ORC."
        )
    return redshift_column_types


def _add_new_table_columns(
    cursor: "redshift_connector.Cursor", schema: str, table: str, redshift_columns_types: dict[str, str]
) -> None:
    """Add to ``schema.table`` every column of ``redshift_columns_types`` it does not have yet.

    Column names are lower-cased before the comparison unless the
    ``enable_case_sensitive_identifier`` setting reads ``on`` or ``true``.
    """
    setting = _get_parameter_setting(cursor=cursor, parameter_name="enable_case_sensitive_identifier")
    case_sensitive = setting.lower() in ("on", "true")

    # With case-insensitive identifiers, compare on lower-cased names.
    if case_sensitive:
        wanted_columns = redshift_columns_types
    else:
        wanted_columns = {name.lower(): rs_type for name, rs_type in redshift_columns_types.items()}

    existing_columns = set(_get_table_columns(cursor=cursor, schema=schema, table=table))
    missing_columns = {
        name: rs_type for name, rs_type in wanted_columns.items() if name not in existing_columns
    }

    _add_table_columns(cursor=cursor, schema=schema, table=table, new_columns=missing_columns)


def _create_table(  # noqa: PLR0913
    df: pd.DataFrame | None,
    path: str | list[str] | None,
    con: "redshift_connector.Connection",
    cursor: "redshift_connector.Cursor",
    table: str,
    schema: str,
    mode: str,
    overwrite_method: str,
    index: bool,
    dtype: dict[str, str] | None,
    diststyle: str,
    sortstyle: str,
    distkey: str | None,
    sortkey: list[str] | None,
    primary_keys: list[str] | None,
    varchar_lengths_default: int,
    varchar_lengths: dict[str, int] | None,
    data_format: Literal["parquet", "orc", "csv"] = "parquet",
    redshift_column_types: dict[str, str] | None = None,
    parquet_infer_sampling: float = 1.0,
    path_suffix: str | None = None,
    path_ignore_suffix: str | list[str] | None = None,
    manifest: bool | None = False,
    use_threads: bool | int = True,
    boto3_session: boto3.Session | None = None,
    s3_additional_kwargs: dict[str, str] | None = None,
    lock: bool = False,
) -> tuple[str, str | None]:
    """Prepare the destination table for a write and report where rows should be loaded.

    Behaviour by ``mode``, as implemented below:

    * ``"overwrite"`` - clears the destination according to ``overwrite_method``:
      ``"truncate"`` truncates and then begins a new transaction, ``"delete"``
      deletes all rows if the table exists, and any other value drops the table
      (with CASCADE only when the value is ``"cascade"``). Execution then falls
      through to ``CREATE TABLE IF NOT EXISTS``.
    * any other mode, table already exists - the table is optionally locked; for
      ``"upsert"`` a temporary table ``LIKE`` the target is created and
      ``(temp_table, None)`` is returned, otherwise ``(table, schema)`` is
      returned without creating anything.
    * any other mode, table missing - the table is created.

    For creation, column types come from ``_get_rsh_columns_types`` (from ``df``
    or from ``path``) and the layout arguments are checked by
    ``_validate_parameters``. A falsy ``diststyle`` defaults to ``"AUTO"`` and a
    falsy ``sortstyle`` to ``"COMPOUND"``; both are upper-cased.

    Returns
    -------
    tuple[str, str | None]
        Name of the table to load into and its schema (``None`` for the
        temporary upsert table).

    Raises
    ------
    redshift_connector.error.ProgrammingError
        Re-raised from the truncate step unless its error code is ``"42P01"``.
    """
    _logger.debug("Creating table %s with mode %s, and overwrite method %s", table, mode, overwrite_method)
    if mode == "overwrite":
        if overwrite_method == "truncate":
            try:
                # Truncate commits current transaction, if successful.
                # Fast, but not atomic.
                _truncate_table(cursor=cursor, schema=schema, table=table)
            except redshift_connector.error.ProgrammingError as e:
                # Caught "relation does not exist".
                if e.args[0]["C"] != "42P01":
                    raise e
                _logger.debug(str(e))
                # The failed TRUNCATE is rolled back before a new transaction is started.
                con.rollback()
            _begin_transaction(cursor=cursor)
            if lock:
                _lock(cursor, [table], schema=schema)
        elif overwrite_method == "delete":
            if _does_table_exist(cursor=cursor, schema=schema, table=table):
                if lock:
                    _lock(cursor, [table], schema=schema)
                # Atomic, but slow.
                _delete_all(cursor=cursor, schema=schema, table=table)
        else:
            # Fast, atomic, but either fails if there are any dependent views or, in cascade mode, deletes them.
            _drop_table(cursor=cursor, schema=schema, table=table, cascade=bool(overwrite_method == "cascade"))
            # No point in locking here, the oid will change.
        # None of the overwrite branches return: they continue to the CREATE TABLE IF NOT EXISTS below.
    elif _does_table_exist(cursor=cursor, schema=schema, table=table) is True:
        # Non-overwrite mode with an existing table: nothing is created in the target schema.
        _logger.debug("Table %s exists", table)
        if lock:
            _lock(cursor, [table], schema=schema)
        if mode == "upsert":
            # Upserts are staged in a uniquely named temporary table shaped like the target.
            guid: str = uuid.uuid4().hex
            temp_table: str = f"temp_redshift_{guid}"
            sql: str = f"CREATE TEMPORARY TABLE {temp_table} (LIKE {_identifier(schema)}.{_identifier(table)})"
            _logger.debug("Executing create temporary table query:\n%s", sql)
            cursor.execute(sql)
            return temp_table, None
        return table, schema
    # Normalise the layout options; falsy values fall back to AUTO / COMPOUND.
    diststyle = diststyle.upper() if diststyle else "AUTO"
    sortstyle = sortstyle.upper() if sortstyle else "COMPOUND"

    redshift_types = _get_rsh_columns_types(
        df=df,
        path=path,
        index=index,
        dtype=dtype,
        varchar_lengths_default=varchar_lengths_default,
        varchar_lengths=varchar_lengths,
        parquet_infer_sampling=parquet_infer_sampling,
        path_suffix=path_suffix,
        path_ignore_suffix=path_ignore_suffix,
        use_threads=use_threads,
        boto3_session=boto3_session,
        s3_additional_kwargs=s3_additional_kwargs,
        data_format=data_format,
        redshift_column_types=redshift_column_types,
        manifest=manifest,
    )
    _validate_parameters(
        redshift_types=redshift_types,
        diststyle=diststyle,
        distkey=distkey,
        sortstyle=sortstyle,
        sortkey=sortkey,
        primary_keys=primary_keys,
    )
    # One '"name" type,\n' fragment per column; [:-2] removes the trailing ",\n" of the last one.
    cols_str: str = "".join([f'"{k}" {v},\n' for k, v in redshift_types.items()])[:-2]
    primary_keys_str: str = (
        ",\nPRIMARY KEY ({})".format(", ".join('"' + pk + '"' for pk in primary_keys)) if primary_keys else ""
    )
    # DISTKEY is emitted only together with the KEY distribution style.
    distkey_str: str = f"\nDISTKEY({distkey})" if distkey and diststyle == "KEY" else ""
    sortkey_str: str = f"\n{sortstyle} SORTKEY({','.join(sortkey)})" if sortkey else ""
    sql = (
        f"CREATE TABLE IF NOT EXISTS {_identifier(schema)}.{_identifier(table)} (\n"
        f"{cols_str}"
        f"{primary_keys_str}"
        f")\nDISTSTYLE {diststyle}"
        f"{distkey_str}"
        f"{sortkey_str}"
    )
    _logger.debug("Executing create table query:\n%s", sql)
    cursor.execute(sql)
    _logger.info("Created table %s", table)
    if lock:
        _lock(cursor, [table], schema=schema)
    return table, schema
